Загрузка...

Скрипт Ограничить количество подключений к порту

Тема в разделе Go создана пользователем krisssss 4 июн 2025 в 03:29. 89 просмотров

Загрузка...
  1. krisssss
    krisssss Автор темы 4 июн 2025 в 03:29 СЕО-Услуги | aida.biz | Лучшая СММ панель 10 714 24 дек 2024
    Код

    package main

    import (
    "context"
    "errors"
    "flag"
    "fmt"
    "io"
    "log/slog"
    "net"
    "os"
    "os/signal"
    "syscall"
    "time"

    "golang.org/x/sync/errgroup"
    )

    const (
    TransportProtocol = "tcp"
    )

    type Config struct {
    Logger *slog.Logger
    Dialer *net.Dialer
    IsTestSuite bool
    TargetPort int
    ProxyPort int
    MaxConnections int
    }

    var (
    proxyPortFlag = flag.Int(
    "proxy-port", 0,
    "Порт ******-сервера для входящих соединений, обязательный",
    )
    targetPortFlag = flag.Int(
    "target-port", 443,
    "Порт локального VLESS-сервера (по умолчанию 443)",
    )
    maxConnsFlag = flag.Int(
    "limit", 0,
    "Максимальное число одновременных подключений (обязательный, >0)",
    )
    testLingerFlag = flag.Bool(
    "test-linger", false,
    "Устанавливать TCP Linger=0 для тестов",
    )
    )

    func setupConfig(logger *slog.Logger) (*Config, error) {
    if *proxyPortFlag <= 0 || *proxyPortFlag > 65535 {
    flag.Usage()
    return nil, fmt.Errorf("-proxy-port должен быть в диапазоне 1-65535, указан: %d", *proxyPortFlag)
    }
    if *targetPortFlag <= 0 || *targetPortFlag > 65535 {
    flag.Usage()
    return nil, fmt.Errorf("-target-port должен быть в диапазоне 1-65535, указан: %d", *targetPortFlag)
    }
    if *maxConnsFlag <= 0 {
    flag.Usage()
    return nil, fmt.Errorf("-limit должен быть > 0, указан: %d", *maxConnsFlag)
    }

    return &Config{
    Logger: logger,
    Dialer: &net.Dialer{Timeout: 30 * time.Second},
    ProxyPort: *proxyPortFlag,
    TargetPort: *targetPortFlag,
    MaxConnections: *maxConnsFlag,
    IsTestSuite: *testLingerFlag,
    }, nil
    }

    func main() {
    flag.Usage = func() {
    fmt.Fprintf(os.Stderr, "Использование: %s [OPTIONS]\n", os.Args[0])
    flag.PrintDefaults()
    }
    flag.Parse()

    logger := slog.New(slog.NewJSONHandler(os.Stdout, nil))

    ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
    defer stop()

    if err := run(ctx, logger); err != nil && !errors.Is(err, context.Canceled) {
    logger.Warn("application error", slog.String("error", err.Error()))
    }
    logger.Info("application shutdown")
    }

    func run(ctx context.Context, logger *slog.Logger) error {
    cfg, err := setupConfig(logger)
    if err != nil {
    return fmt.Errorf("setup config: %w", err)
    }

    listenAddr := fmt.Sprintf(":%d", cfg.ProxyPort)
    forwardAddr := fmt.Sprintf("127.0.0.1:%d", cfg.TargetPort)

    g, ctx := errgroup.WithContext(ctx)
    g.SetLimit(cfg.MaxConnections)

    listener, err := net.Listen(TransportProtocol, listenAddr)
    if err != nil {
    return fmt.Errorf("listen: %w", err)
    }
    defer listener.Close()

    logger.Info("proxy ready",
    slog.String("listen", listenAddr),
    slog.String("forwardTo", forwardAddr),
    slog.Int("maxConns", cfg.MaxConnections),
    )

    for {
    conn, err := listener.Accept()
    if err != nil {
    if errors.Is(err, net.ErrClosed) || ctx.Err() != nil {
    break
    }
    logger.Warn("accept error", slog.String("error", err.Error()))
    continue
    }

    if ok := g.TryGo(func() error {
    defer conn.Close()
    return proxy(ctx, cfg, conn, forwardAddr)
    }); !ok {
    // сброс без логирования при превышении лимита
    if tcpConn, ok := conn.(*net.TCPConn); ok {
    tcpConn.SetLinger(0)
    tcpConn.SetKeepAlive(false)
    }
    conn.Close()
    continue
    }
    }

    return g.Wait()
    }

    func proxy(ctx context.Context, cfg *Config, client net.Conn, forwardAddr string) error {
    target, err := cfg.Dialer.DialContext(ctx, TransportProtocol, forwardAddr)
    if err != nil {
    return fmt.Errorf("dial target: %w", err)
    }
    defer target.Close()

    go io.Copy(target, client)
    _, err = io.Copy(client, target)
    return err
    }




     
    4 июн 2025 в 03:29 Изменено
  2. yoona
    Дело было вечером, делать было нечего

    В идеале отдавать ConnectionRefused, но мне стало лень возиться с этим, и хотелось сохранить изначальную логику редиректа

    Немного косметических процедур и рефакторинга, но разделять по файлам уж было лень в силу того, что сюда структуру потом не залить :(

    В тесте используется финалайзер, что потом может сказаться негативно на большем количестве тестов, ибо не отдает порт системе до закрытия теста. Можно переписать что бы оно чисто в массиве их держало тогда, но мне было лень.

    Нужно для того, что бы GC сам не закрывал порт, ибо тогда получился бы флапающий тест
    Код
    package main

    import (
    "context"
    "errors"
    "fmt"
    "io"
    "log/slog"
    "net"
    "os"
    "os/signal"
    "strconv"
    "syscall"
    "time"

    "golang.org/x/sync/errgroup"
    )

    const (
    envPrefix = "PROXY_APP"
    EnvKeyTargetAddr = envPrefix + "TARGET_ADDR"
    EnvKeyMaxConnections = envPrefix + "MAX_CONNECTIONS"
    EnvKeyProxyAddr = envPrefix + "PROXY_ADDR"
    EnvKeyTestLinger = envPrefix + "TEST_LINGER"

    TransportProtocol = "tcp"
    )

    type Config struct {
    Logger *slog.Logger
    Dialer *net.Dialer

    IsTestSuite bool
    TargetAddr string
    ProxyAddr string
    MaxConnections int
    }

    func setupConfig(logger *slog.Logger) (*Config, error) {
    maxConnectionsRaw := os.Getenv(EnvKeyMaxConnections)
    maxConnections, err := strconv.Atoi(maxConnectionsRaw)
    if err != nil {
    return nil, fmt.Errorf("atoi max connections: %w", err)
    }

    config := &Config{
    Logger: logger,
    Dialer: &net.Dialer{},
    TargetAddr: os.Getenv(EnvKeyTargetAddr),
    ProxyAddr: os.Getenv(EnvKeyProxyAddr),
    IsTestSuite: os.Getenv(EnvKeyTestLinger) != "",
    MaxConnections: maxConnections,
    }

    if config.TargetAddr == "" || config.ProxyAddr == "" || config.MaxConnections == 0 {
    return nil, fmt.Errorf("one of the required arguments missing") // Мне лень их было переписывать, а лучше вообще валидатором пройтись, сами уже
    }

    return config, nil
    }

    func main() {
    errGroup, ctx := errgroup.WithContext(context.Background())

    ctx, cancel := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM)
    defer cancel()

    logger := slog.New(slog.NewJSONHandler(os.Stdout, nil))

    errGroup.Go(func() error {
    if err := run(ctx, logger); err != nil {
    return fmt.Errorf("run: %w", err)
    }

    return nil
    })

    if err := errGroup.Wait(); err != nil {
    logger.Warn("failed to run application", slog.String("err", err.Error()))
    }

    <-ctx.Done()

    // Лучше заменить на какую нибудь имплементацию Closer'a, но мне лень стало ее сюда тащить
    // И потом контексты поменять на более 'чистые', что бы красивше завершать через клозеры, а не насильно контекстом
    <-time.After(time.Second * 5)

    logger.Info("application shutdown")
    }

    func run(ctx context.Context, logger *slog.Logger) error {
    config, err := setupConfig(logger)
    if err != nil {
    return fmt.Errorf("setup config: %w", err)
    }

    errGroup := errgroup.Group{}
    errGroup.SetLimit(config.MaxConnections)

    listenConfig := net.ListenConfig{}
    listener, err := listenConfig.Listen(context.Background(), TransportProtocol, config.ProxyAddr)
    if err != nil {
    return fmt.Errorf("listen: %w", err)
    }

    defer listener.Close()

    config.Logger.Info(
    "ready to accept connections",
    slog.String("proxyAddr", config.ProxyAddr),
    slog.String("targetAddr", config.TargetAddr),
    slog.Int("limit", config.MaxConnections),
    )

    for {
    conn, err := listener.Accept()
    if err != nil {
    config.Logger.Warn("failed to accept connection", slog.String("err", err.Error()))
    if errors.Is(err, net.ErrClosed) {
    break
    }

    continue
    }

    if config.IsTestSuite {
    if tcp, ok := conn.(*net.TCPConn); ok {
    tcp.SetLinger(0)
    }
    }

    config.Logger.Debug("accepted connection", slog.String("addr", conn.RemoteAddr().String()))

    if ok := errGroup.TryGo(func() error {
    if err := handleConn(ctx, config, conn); err != nil {
    config.Logger.Warn("handle connection", slog.String("err", err.Error()))
    }

    if err := conn.Close(); err != nil {
    config.Logger.Warn("failed to close listener", slog.String("err", err.Error()))
    }

    config.Logger.Debug("closing connection", slog.String("addr", conn.RemoteAddr().String()))

    return nil
    }); !ok {
    if err := conn.Close(); err != nil {
    config.Logger.Warn("failed to close listener", slog.String("err", err.Error()))
    }

    config.Logger.Warn("exceeded limit", slog.String("remoteAddr", conn.RemoteAddr().String()), slog.Int("limit", config.MaxConnections))
    }

    }

    return nil
    }

    func handleConn(ctx context.Context, config *Config, client net.Conn) error {
    target, err := config.Dialer.DialContext(ctx, TransportProtocol, config.TargetAddr)
    if err != nil {
    return fmt.Errorf("dial: %w", err)
    }

    defer target.Close()

    config.Logger.Debug("starting connection redirect", slog.String("addr", client.RemoteAddr().String()))

    go io.Copy(target, client)
    io.Copy(client, target)

    return nil
    }
    Код
    package main

    import (
    "fmt"
    "log/slog"
    "net"
    "os"
    "runtime"
    "strconv"
    "sync/atomic"
    "testing"
    "time"

    "github.com/stretchr/testify/require"
    )

    func TestApp(t *testing.T) {
    const (
    targetPort = 40066
    proxyPort = 40077
    maxConnections = 3

    expectedDeclinedConnections = 10
    expectedBytesToRead = 8
    )

    logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{
    Level: slog.LevelDebug,
    }))

    targetAddr := fmt.Sprintf("127.0.0.1:%d", targetPort)

    require.NoError(t, os.Setenv(EnvKeyProxyAddr, fmt.Sprintf(":%d", proxyPort)))
    require.NoError(t, os.Setenv(EnvKeyTargetAddr, targetAddr))
    require.NoError(t, os.Setenv(EnvKeyMaxConnections, strconv.Itoa(maxConnections)))
    require.NoError(t, os.Setenv(EnvKeyTestLinger, "true"))

    go func() {
    require.NoError(t, run(t.Context(), logger))
    }()

    listener, err := net.Listen(TransportProtocol, targetAddr)
    require.NoError(t, err)

    passedConnections := atomic.Int32{}
    declinedConnections := atomic.Int32{}

    go func() {
    for {
    conn, err := listener.Accept()
    if err != nil {
    continue
    }

    runtime.SetFinalizer(conn, nil)

    receivedBytesNum, err := conn.Read(make([]byte, expectedBytesToRead*2))
    require.Equal(t, receivedBytesNum, expectedBytesToRead)

    passedConnections.Add(1)
    }

    }()

    for range maxConnections + expectedDeclinedConnections {
    conn, err := net.DialTimeout(TransportProtocol, fmt.Sprintf("localhost:%d", proxyPort), time.Second)
    require.NoError(t, err)

    <-time.After(time.Second / 2)

    if _, err = conn.Write(make([]byte, expectedBytesToRead)); err != nil {
    declinedConnections.Add(1)

    continue
    }

    runtime.SetFinalizer(conn, nil)
    }

    require.Equal(t, int32(maxConnections), passedConnections.Load())
    require.Equal(t, int32(expectedDeclinedConnections), declinedConnections.Load())
    }
     
    4 июн 2025 в 23:35 Изменено
    1. yoona
      yoona, к слову, этот кусок кода не очень прям работоспособный, как по идеи не "заметит" закрытие коннекта, лучше реализовать свои пайпы, где ты будешь вручную редиректить запись и чтение⁡. Но в рамках рефакторинга мне было слишком лень уже это фиксить
      Код
      go io.Copy(target, client)
      io.Copy(client, target)
      Тест пройдет если пофиксить закрытие. (Там есть немного потенциала для флапа на
      ⁡<-time.After()
      ⁡, но мне уже слишком лень было делать как белые люди)

      Код
      package main

      import (
      "fmt"
      "log/slog"
      "net"
      "os"
      "runtime"
      "strconv"
      "sync/atomic"
      "testing"
      "time"

      "github.com/stretchr/testify/require"
      )

      func TestApp(t *testing.T) {
      const (
      targetPort = 40066
      proxyPort = 40077
      maxConnections = 3

      expectedDeclinedConnections = 10
      expectedBytesToRead = 8
      )

      logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{
      Level: slog.LevelDebug,
      }))

      targetAddr := fmt.Sprintf("127.0.0.1:%d", targetPort)

      require.NoError(t, os.Setenv(EnvKeyProxyAddr, fmt.Sprintf(":%d", proxyPort)))
      require.NoError(t, os.Setenv(EnvKeyTargetAddr, targetAddr))
      require.NoError(t, os.Setenv(EnvKeyMaxConnections, strconv.Itoa(maxConnections)))
      require.NoError(t, os.Setenv(EnvKeyTestLinger, "true"))

      go func() {
      require.NoError(t, run(t.Context(), logger))
      }()

      listener, err := net.Listen(TransportProtocol, targetAddr)
      require.NoError(t, err)

      passedConnections := atomic.Int32{}
      declinedConnections := atomic.Int32{}

      go func() {
      for {
      conn, err := listener.Accept()
      if err != nil {
      continue
      }

      runtime.SetFinalizer(conn, nil)

      receivedBytesNum, err := conn.Read(make([]byte, expectedBytesToRead*2))
      require.Equal(t, receivedBytesNum, expectedBytesToRead)

      passedConnections.Add(1)
      }

      }()

      var firstConnection *net.Conn
      tryConnection := func() {
      conn, err := net.DialTimeout(TransportProtocol, fmt.Sprintf("localhost:%d", proxyPort), time.Second)
      require.NoError(t, err)

      <-time.After(time.Second / 2)

      if _, err = conn.Write(make([]byte, expectedBytesToRead)); err != nil {
      declinedConnections.Add(1)

      return
      }

      if firstConnection == nil {
      firstConnection = &conn
      }

      runtime.SetFinalizer(conn, nil)
      }

      for range maxConnections + expectedDeclinedConnections {
      tryConnection()
      }

      if tcp, ok := (*firstConnection).(*net.TCPConn); ok {
      require.NoError(t, tcp.SetLinger(0))
      }

      require.NoError(t, (*firstConnection).Close())
      <-time.After(time.Second / 2)

      tryConnection()

      require.Eventually(t, func() bool {
      return int32(maxConnections)+1 == passedConnections.Load()
      }, time.Second, time.Millisecond*100)
      require.Equal(t, int32(expectedDeclinedConnections), declinedConnections.Load())
      }

      Ну и на последок, атомики в таком стиле тут реализовать была идея не очень, ибо в итоге у тебя критической секцией стала не сама переменная с количеством коннектов, а сам атомик, иронично выходит конечно (между загрузкой и обновлением). Нужна семафора, еррор группа как раз реализует ее
      4 июн 2025 в 23:53 Изменено
    2. krisssss Автор темы
Top