Загрузка...

Скрипт Ограничить количество подключений к порту

Тема в разделе Go создана пользователем krisssss 4 июн 2025 в 03:29. 98 просмотров

Загрузка...
  1. krisssss
    krisssss Автор темы 4 июн 2025 в 03:29 СЕО-Услуги | aida.biz | Лучшая СММ панель 10 719 24 дек 2024
    Код

    package main

    import (
    "context"
    "errors"
    "flag"
    "fmt"
    "io"
    "log/slog"
    "net"
    "os"
    "os/signal"
    "syscall"
    "time"

    "golang.org/x/sync/errgroup"
    )

    const (
    TransportProtocol = "tcp"
    )

    type Config struct {
    Logger *slog.Logger
    Dialer *net.Dialer
    IsTestSuite bool
    TargetPort int
    ProxyPort int
    MaxConnections int
    }

    var (
    proxyPortFlag = flag.Int(
    "proxy-port", 0,
    "Порт ******-сервера для входящих соединений, обязательный",
    )
    targetPortFlag = flag.Int(
    "target-port", 443,
    "Порт локального VLESS-сервера (по умолчанию 443)",
    )
    maxConnsFlag = flag.Int(
    "limit", 0,
    "Максимальное число одновременных подключений (обязательный, >0)",
    )
    testLingerFlag = flag.Bool(
    "test-linger", false,
    "Устанавливать TCP Linger=0 для тестов",
    )
    )

    func setupConfig(logger *slog.Logger) (*Config, error) {
    if *proxyPortFlag <= 0 || *proxyPortFlag > 65535 {
    flag.Usage()
    return nil, fmt.Errorf("-proxy-port должен быть в диапазоне 1-65535, указан: %d", *proxyPortFlag)
    }
    if *targetPortFlag <= 0 || *targetPortFlag > 65535 {
    flag.Usage()
    return nil, fmt.Errorf("-target-port должен быть в диапазоне 1-65535, указан: %d", *targetPortFlag)
    }
    if *maxConnsFlag <= 0 {
    flag.Usage()
    return nil, fmt.Errorf("-limit должен быть > 0, указан: %d", *maxConnsFlag)
    }

    return &Config{
    Logger: logger,
    Dialer: &net.Dialer{Timeout: 30 * time.Second},
    ProxyPort: *proxyPortFlag,
    TargetPort: *targetPortFlag,
    MaxConnections: *maxConnsFlag,
    IsTestSuite: *testLingerFlag,
    }, nil
    }

    func main() {
    flag.Usage = func() {
    fmt.Fprintf(os.Stderr, "Использование: %s [OPTIONS]\n", os.Args[0])
    flag.PrintDefaults()
    }
    flag.Parse()

    logger := slog.New(slog.NewJSONHandler(os.Stdout, nil))

    ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
    defer stop()

    if err := run(ctx, logger); err != nil && !errors.Is(err, context.Canceled) {
    logger.Warn("application error", slog.String("error", err.Error()))
    }
    logger.Info("application shutdown")
    }

    func run(ctx context.Context, logger *slog.Logger) error {
    cfg, err := setupConfig(logger)
    if err != nil {
    return fmt.Errorf("setup config: %w", err)
    }

    listenAddr := fmt.Sprintf(":%d", cfg.ProxyPort)
    forwardAddr := fmt.Sprintf("127.0.0.1:%d", cfg.TargetPort)

    g, ctx := errgroup.WithContext(ctx)
    g.SetLimit(cfg.MaxConnections)

    listener, err := net.Listen(TransportProtocol, listenAddr)
    if err != nil {
    return fmt.Errorf("listen: %w", err)
    }
    defer listener.Close()

    logger.Info("proxy ready",
    slog.String("listen", listenAddr),
    slog.String("forwardTo", forwardAddr),
    slog.Int("maxConns", cfg.MaxConnections),
    )

    for {
    conn, err := listener.Accept()
    if err != nil {
    if errors.Is(err, net.ErrClosed) || ctx.Err() != nil {
    break
    }
    logger.Warn("accept error", slog.String("error", err.Error()))
    continue
    }

    if ok := g.TryGo(func() error {
    defer conn.Close()
    return proxy(ctx, cfg, conn, forwardAddr)
    }); !ok {
    // сброс без логирования при превышении лимита
    if tcpConn, ok := conn.(*net.TCPConn); ok {
    tcpConn.SetLinger(0)
    tcpConn.SetKeepAlive(false)
    }
    conn.Close()
    continue
    }
    }

    return g.Wait()
    }

    func proxy(ctx context.Context, cfg *Config, client net.Conn, forwardAddr string) error {
    target, err := cfg.Dialer.DialContext(ctx, TransportProtocol, forwardAddr)
    if err != nil {
    return fmt.Errorf("dial target: %w", err)
    }
    defer target.Close()

    go io.Copy(target, client)
    _, err = io.Copy(client, target)
    return err
    }




     
    4 июн 2025 в 03:29 Изменено
  2. yoona
    Дело было вечером, делать было нечего

    В идеале отдавать ConnectionRefused, но мне стало лень возиться с этим, и хотелось сохранить изначальную логику редиректа

    Немного косметических процедур и рефакторинга, но разделять по файлам уж было лень в силу того, что сюда структуру потом не залить :(

    В тесте используется финалайзер, что потом может сказаться негативно на большем количестве тестов, ибо не отдает порт системе до закрытия теста. Можно переписать что бы оно чисто в массиве их держало тогда, но мне было лень.

    Нужно для того, что бы GC сам не закрывал порт, ибо тогда получился бы флапающий тест
    Код
    package main

    import (
    "context"
    "errors"
    "fmt"
    "io"
    "log/slog"
    "net"
    "os"
    "os/signal"
    "strconv"
    "syscall"
    "time"

    "golang.org/x/sync/errgroup"
    )

    const (
    envPrefix = "PROXY_APP"
    EnvKeyTargetAddr = envPrefix + "TARGET_ADDR"
    EnvKeyMaxConnections = envPrefix + "MAX_CONNECTIONS"
    EnvKeyProxyAddr = envPrefix + "PROXY_ADDR"
    EnvKeyTestLinger = envPrefix + "TEST_LINGER"

    TransportProtocol = "tcp"
    )

    type Config struct {
    Logger *slog.Logger
    Dialer *net.Dialer

    IsTestSuite bool
    TargetAddr string
    ProxyAddr string
    MaxConnections int
    }

    func setupConfig(logger *slog.Logger) (*Config, error) {
    maxConnectionsRaw := os.Getenv(EnvKeyMaxConnections)
    maxConnections, err := strconv.Atoi(maxConnectionsRaw)
    if err != nil {
    return nil, fmt.Errorf("atoi max connections: %w", err)
    }

    config := &Config{
    Logger: logger,
    Dialer: &net.Dialer{},
    TargetAddr: os.Getenv(EnvKeyTargetAddr),
    ProxyAddr: os.Getenv(EnvKeyProxyAddr),
    IsTestSuite: os.Getenv(EnvKeyTestLinger) != "",
    MaxConnections: maxConnections,
    }

    if config.TargetAddr == "" || config.ProxyAddr == "" || config.MaxConnections == 0 {
    return nil, fmt.Errorf("one of the required arguments missing") // Мне лень их было переписывать, а лучше вообще валидатором пройтись, сами уже
    }

    return config, nil
    }

    func main() {
    errGroup, ctx := errgroup.WithContext(context.Background())

    ctx, cancel := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM)
    defer cancel()

    logger := slog.New(slog.NewJSONHandler(os.Stdout, nil))

    errGroup.Go(func() error {
    if err := run(ctx, logger); err != nil {
    return fmt.Errorf("run: %w", err)
    }

    return nil
    })

    if err := errGroup.Wait(); err != nil {
    logger.Warn("failed to run application", slog.String("err", err.Error()))
    }

    <-ctx.Done()

    // Лучше заменить на какую нибудь имплементацию Closer'a, но мне лень стало ее сюда тащить
    // И потом контексты поменять на более 'чистые', что бы красивше завершать через клозеры, а не насильно контекстом
    <-time.After(time.Second * 5)

    logger.Info("application shutdown")
    }

    func run(ctx context.Context, logger *slog.Logger) error {
    config, err := setupConfig(logger)
    if err != nil {
    return fmt.Errorf("setup config: %w", err)
    }

    errGroup := errgroup.Group{}
    errGroup.SetLimit(config.MaxConnections)

    listenConfig := net.ListenConfig{}
    listener, err := listenConfig.Listen(context.Background(), TransportProtocol, config.ProxyAddr)
    if err != nil {
    return fmt.Errorf("listen: %w", err)
    }

    defer listener.Close()

    config.Logger.Info(
    "ready to accept connections",
    slog.String("proxyAddr", config.ProxyAddr),
    slog.String("targetAddr", config.TargetAddr),
    slog.Int("limit", config.MaxConnections),
    )

    for {
    conn, err := listener.Accept()
    if err != nil {
    config.Logger.Warn("failed to accept connection", slog.String("err", err.Error()))
    if errors.Is(err, net.ErrClosed) {
    break
    }

    continue
    }

    if config.IsTestSuite {
    if tcp, ok := conn.(*net.TCPConn); ok {
    tcp.SetLinger(0)
    }
    }

    config.Logger.Debug("accepted connection", slog.String("addr", conn.RemoteAddr().String()))

    if ok := errGroup.TryGo(func() error {
    if err := handleConn(ctx, config, conn); err != nil {
    config.Logger.Warn("handle connection", slog.String("err", err.Error()))
    }

    if err := conn.Close(); err != nil {
    config.Logger.Warn("failed to close listener", slog.String("err", err.Error()))
    }

    config.Logger.Debug("closing connection", slog.String("addr", conn.RemoteAddr().String()))

    return nil
    }); !ok {
    if err := conn.Close(); err != nil {
    config.Logger.Warn("failed to close listener", slog.String("err", err.Error()))
    }

    config.Logger.Warn("exceeded limit", slog.String("remoteAddr", conn.RemoteAddr().String()), slog.Int("limit", config.MaxConnections))
    }

    }

    return nil
    }

    func handleConn(ctx context.Context, config *Config, client net.Conn) error {
    target, err := config.Dialer.DialContext(ctx, TransportProtocol, config.TargetAddr)
    if err != nil {
    return fmt.Errorf("dial: %w", err)
    }

    defer target.Close()

    config.Logger.Debug("starting connection redirect", slog.String("addr", client.RemoteAddr().String()))

    go io.Copy(target, client)
    io.Copy(client, target)

    return nil
    }
    Код
    package main

    import (
    "fmt"
    "log/slog"
    "net"
    "os"
    "runtime"
    "strconv"
    "sync/atomic"
    "testing"
    "time"

    "github.com/stretchr/testify/require"
    )

    func TestApp(t *testing.T) {
    const (
    targetPort = 40066
    proxyPort = 40077
    maxConnections = 3

    expectedDeclinedConnections = 10
    expectedBytesToRead = 8
    )

    logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{
    Level: slog.LevelDebug,
    }))

    targetAddr := fmt.Sprintf("127.0.0.1:%d", targetPort)

    require.NoError(t, os.Setenv(EnvKeyProxyAddr, fmt.Sprintf(":%d", proxyPort)))
    require.NoError(t, os.Setenv(EnvKeyTargetAddr, targetAddr))
    require.NoError(t, os.Setenv(EnvKeyMaxConnections, strconv.Itoa(maxConnections)))
    require.NoError(t, os.Setenv(EnvKeyTestLinger, "true"))

    go func() {
    require.NoError(t, run(t.Context(), logger))
    }()

    listener, err := net.Listen(TransportProtocol, targetAddr)
    require.NoError(t, err)

    passedConnections := atomic.Int32{}
    declinedConnections := atomic.Int32{}

    go func() {
    for {
    conn, err := listener.Accept()
    if err != nil {
    continue
    }

    runtime.SetFinalizer(conn, nil)

    receivedBytesNum, err := conn.Read(make([]byte, expectedBytesToRead*2))
    require.Equal(t, receivedBytesNum, expectedBytesToRead)

    passedConnections.Add(1)
    }

    }()

    for range maxConnections + expectedDeclinedConnections {
    conn, err := net.DialTimeout(TransportProtocol, fmt.Sprintf("localhost:%d", proxyPort), time.Second)
    require.NoError(t, err)

    <-time.After(time.Second / 2)

    if _, err = conn.Write(make([]byte, expectedBytesToRead)); err != nil {
    declinedConnections.Add(1)

    continue
    }

    runtime.SetFinalizer(conn, nil)
    }

    require.Equal(t, int32(maxConnections), passedConnections.Load())
    require.Equal(t, int32(expectedDeclinedConnections), declinedConnections.Load())
    }
     
    4 июн 2025 в 23:35 Изменено
    1. Посмотреть предыдущие комментарии (1)
    2. krisssss Автор темы
    3. krisssss Автор темы
Top