Загрузка...

Script Limit the number of connections to the port

Thread in Go created by krisssss Jun 4, 2025. 201 view

  1. krisssss
    krisssss Topic starter Jun 4, 2025 Banned 10,787 Dec 24, 2024
    Code

    package main

    import (
    "context"
    "errors"
    "flag"
    "fmt"
    "io"
    "log/slog"
    "net"
    "os"
    "os/signal"
    "syscall"
    "time"

    "golang.org/x/sync/errgroup"
    )

    const (
    TransportProtocol = "tcp"
    )

    type Config struct {
    Logger *slog.Logger
    Dialer *net.Dialer
    IsTestSuite bool
    TargetPort int
    ProxyPort int
    MaxConnections int
    }

    var (
    proxyPortFlag = flag.Int(
    "proxy-port", 0,
    "Порт ******-сервера для входящих соединений, обязательный",
    )
    targetPortFlag = flag.Int(
    "target-port", 443,
    "Порт локального VLESS-сервера (по умолчанию 443)",
    )
    maxConnsFlag = flag.Int(
    "limit", 0,
    "Максимальное число одновременных подключений (обязательный, >0)",
    )
    testLingerFlag = flag.Bool(
    "test-linger", false,
    "Устанавливать TCP Linger=0 для тестов",
    )
    )

    func setupConfig(logger *slog.Logger) (*Config, error) {
    if *proxyPortFlag <= 0 || *proxyPortFlag > 65535 {
    flag.Usage()
    return nil, fmt.Errorf("-proxy-port должен быть в диапазоне 1-65535, указан: %d", *proxyPortFlag)
    }
    if *targetPortFlag <= 0 || *targetPortFlag > 65535 {
    flag.Usage()
    return nil, fmt.Errorf("-target-port должен быть в диапазоне 1-65535, указан: %d", *targetPortFlag)
    }
    if *maxConnsFlag <= 0 {
    flag.Usage()
    return nil, fmt.Errorf("-limit должен быть > 0, указан: %d", *maxConnsFlag)
    }

    return &Config{
    Logger: logger,
    Dialer: &net.Dialer{Timeout: 30 * time.Second},
    ProxyPort: *proxyPortFlag,
    TargetPort: *targetPortFlag,
    MaxConnections: *maxConnsFlag,
    IsTestSuite: *testLingerFlag,
    }, nil
    }

    func main() {
    flag.Usage = func() {
    fmt.Fprintf(os.Stderr, "Использование: %s [OPTIONS]\n", os.Args[0])
    flag.PrintDefaults()
    }
    flag.Parse()

    logger := slog.New(slog.NewJSONHandler(os.Stdout, nil))

    ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
    defer stop()

    if err := run(ctx, logger); err != nil && !errors.Is(err, context.Canceled) {
    logger.Warn("application error", slog.String("error", err.Error()))
    }
    logger.Info("application shutdown")
    }

    func run(ctx context.Context, logger *slog.Logger) error {
    cfg, err := setupConfig(logger)
    if err != nil {
    return fmt.Errorf("setup config: %w", err)
    }

    listenAddr := fmt.Sprintf(":%d", cfg.ProxyPort)
    forwardAddr := fmt.Sprintf("127.0.0.1:%d", cfg.TargetPort)

    g, ctx := errgroup.WithContext(ctx)
    g.SetLimit(cfg.MaxConnections)

    listener, err := net.Listen(TransportProtocol, listenAddr)
    if err != nil {
    return fmt.Errorf("listen: %w", err)
    }
    defer listener.Close()

    logger.Info("proxy ready",
    slog.String("listen", listenAddr),
    slog.String("forwardTo", forwardAddr),
    slog.Int("maxConns", cfg.MaxConnections),
    )

    for {
    conn, err := listener.Accept()
    if err != nil {
    if errors.Is(err, net.ErrClosed) || ctx.Err() != nil {
    break
    }
    logger.Warn("accept error", slog.String("error", err.Error()))
    continue
    }

    if ok := g.TryGo(func() error {
    defer conn.Close()
    return proxy(ctx, cfg, conn, forwardAddr)
    }); !ok {
    // сброс без логирования при превышении лимита
    if tcpConn, ok := conn.(*net.TCPConn); ok {
    tcpConn.SetLinger(0)
    tcpConn.SetKeepAlive(false)
    }
    conn.Close()
    continue
    }
    }

    return g.Wait()
    }

    func proxy(ctx context.Context, cfg *Config, client net.Conn, forwardAddr string) error {
    target, err := cfg.Dialer.DialContext(ctx, TransportProtocol, forwardAddr)
    if err != nil {
    return fmt.Errorf("dial target: %w", err)
    }
    defer target.Close()

    go io.Copy(target, client)
    _, err = io.Copy(client, target)
    return err
    }




     
  2. yoona
    yoona Jun 4, 2025 58 Sep 10, 2017
    Дело было вечером, делать было нечего

    В идеале отдавать ConnectionRefused, но мне стало лень возиться с этим, и хотелось сохранить изначальную логику редиректа

    Немного косметических процедур и рефакторинга, но разделять по файлам уж было лень в силу того, что сюда структуру потом не залить :(

    В тесте используется финалайзер, что потом может сказаться негативно на большем количестве тестов, ибо не отдает порт системе до закрытия теста. Можно переписать что бы оно чисто в массиве их держало тогда, но мне было лень.

    Нужно для того, что бы GC сам не закрывал порт, ибо тогда получился бы флапающий тест
    Code
    package main

    import (
    "context"
    "errors"
    "fmt"
    "io"
    "log/slog"
    "net"
    "os"
    "os/signal"
    "strconv"
    "syscall"
    "time"

    "golang.org/x/sync/errgroup"
    )

    const (
    envPrefix = "PROXY_APP"
    EnvKeyTargetAddr = envPrefix + "TARGET_ADDR"
    EnvKeyMaxConnections = envPrefix + "MAX_CONNECTIONS"
    EnvKeyProxyAddr = envPrefix + "PROXY_ADDR"
    EnvKeyTestLinger = envPrefix + "TEST_LINGER"

    TransportProtocol = "tcp"
    )

    type Config struct {
    Logger *slog.Logger
    Dialer *net.Dialer

    IsTestSuite bool
    TargetAddr string
    ProxyAddr string
    MaxConnections int
    }

    func setupConfig(logger *slog.Logger) (*Config, error) {
    maxConnectionsRaw := os.Getenv(EnvKeyMaxConnections)
    maxConnections, err := strconv.Atoi(maxConnectionsRaw)
    if err != nil {
    return nil, fmt.Errorf("atoi max connections: %w", err)
    }

    config := &Config{
    Logger: logger,
    Dialer: &net.Dialer{},
    TargetAddr: os.Getenv(EnvKeyTargetAddr),
    ProxyAddr: os.Getenv(EnvKeyProxyAddr),
    IsTestSuite: os.Getenv(EnvKeyTestLinger) != "",
    MaxConnections: maxConnections,
    }

    if config.TargetAddr == "" || config.ProxyAddr == "" || config.MaxConnections == 0 {
    return nil, fmt.Errorf("one of the required arguments missing") // Мне лень их было переписывать, а лучше вообще валидатором пройтись, сами уже
    }

    return config, nil
    }

    func main() {
    errGroup, ctx := errgroup.WithContext(context.Background())

    ctx, cancel := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM)
    defer cancel()

    logger := slog.New(slog.NewJSONHandler(os.Stdout, nil))

    errGroup.Go(func() error {
    if err := run(ctx, logger); err != nil {
    return fmt.Errorf("run: %w", err)
    }

    return nil
    })

    if err := errGroup.Wait(); err != nil {
    logger.Warn("failed to run application", slog.String("err", err.Error()))
    }

    <-ctx.Done()

    // Лучше заменить на какую нибудь имплементацию Closer'a, но мне лень стало ее сюда тащить
    // И потом контексты поменять на более 'чистые', что бы красивше завершать через клозеры, а не насильно контекстом
    <-time.After(time.Second * 5)

    logger.Info("application shutdown")
    }

    func run(ctx context.Context, logger *slog.Logger) error {
    config, err := setupConfig(logger)
    if err != nil {
    return fmt.Errorf("setup config: %w", err)
    }

    errGroup := errgroup.Group{}
    errGroup.SetLimit(config.MaxConnections)

    listenConfig := net.ListenConfig{}
    listener, err := listenConfig.Listen(context.Background(), TransportProtocol, config.ProxyAddr)
    if err != nil {
    return fmt.Errorf("listen: %w", err)
    }

    defer listener.Close()

    config.Logger.Info(
    "ready to accept connections",
    slog.String("proxyAddr", config.ProxyAddr),
    slog.String("targetAddr", config.TargetAddr),
    slog.Int("limit", config.MaxConnections),
    )

    for {
    conn, err := listener.Accept()
    if err != nil {
    config.Logger.Warn("failed to accept connection", slog.String("err", err.Error()))
    if errors.Is(err, net.ErrClosed) {
    break
    }

    continue
    }

    if config.IsTestSuite {
    if tcp, ok := conn.(*net.TCPConn); ok {
    tcp.SetLinger(0)
    }
    }

    config.Logger.Debug("accepted connection", slog.String("addr", conn.RemoteAddr().String()))

    if ok := errGroup.TryGo(func() error {
    if err := handleConn(ctx, config, conn); err != nil {
    config.Logger.Warn("handle connection", slog.String("err", err.Error()))
    }

    if err := conn.Close(); err != nil {
    config.Logger.Warn("failed to close listener", slog.String("err", err.Error()))
    }

    config.Logger.Debug("closing connection", slog.String("addr", conn.RemoteAddr().String()))

    return nil
    }); !ok {
    if err := conn.Close(); err != nil {
    config.Logger.Warn("failed to close listener", slog.String("err", err.Error()))
    }

    config.Logger.Warn("exceeded limit", slog.String("remoteAddr", conn.RemoteAddr().String()), slog.Int("limit", config.MaxConnections))
    }

    }

    return nil
    }

    func handleConn(ctx context.Context, config *Config, client net.Conn) error {
    target, err := config.Dialer.DialContext(ctx, TransportProtocol, config.TargetAddr)
    if err != nil {
    return fmt.Errorf("dial: %w", err)
    }

    defer target.Close()

    config.Logger.Debug("starting connection redirect", slog.String("addr", client.RemoteAddr().String()))

    go io.Copy(target, client)
    io.Copy(client, target)

    return nil
    }
    Code
    package main

    import (
    "fmt"
    "log/slog"
    "net"
    "os"
    "runtime"
    "strconv"
    "sync/atomic"
    "testing"
    "time"

    "github.com/stretchr/testify/require"
    )

    func TestApp(t *testing.T) {
    const (
    targetPort = 40066
    proxyPort = 40077
    maxConnections = 3

    expectedDeclinedConnections = 10
    expectedBytesToRead = 8
    )

    logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{
    Level: slog.LevelDebug,
    }))

    targetAddr := fmt.Sprintf("127.0.0.1:%d", targetPort)

    require.NoError(t, os.Setenv(EnvKeyProxyAddr, fmt.Sprintf(":%d", proxyPort)))
    require.NoError(t, os.Setenv(EnvKeyTargetAddr, targetAddr))
    require.NoError(t, os.Setenv(EnvKeyMaxConnections, strconv.Itoa(maxConnections)))
    require.NoError(t, os.Setenv(EnvKeyTestLinger, "true"))

    go func() {
    require.NoError(t, run(t.Context(), logger))
    }()

    listener, err := net.Listen(TransportProtocol, targetAddr)
    require.NoError(t, err)

    passedConnections := atomic.Int32{}
    declinedConnections := atomic.Int32{}

    go func() {
    for {
    conn, err := listener.Accept()
    if err != nil {
    continue
    }

    runtime.SetFinalizer(conn, nil)

    receivedBytesNum, err := conn.Read(make([]byte, expectedBytesToRead*2))
    require.Equal(t, receivedBytesNum, expectedBytesToRead)

    passedConnections.Add(1)
    }

    }()

    for range maxConnections + expectedDeclinedConnections {
    conn, err := net.DialTimeout(TransportProtocol, fmt.Sprintf("localhost:%d", proxyPort), time.Second)
    require.NoError(t, err)

    <-time.After(time.Second / 2)

    if _, err = conn.Write(make([]byte, expectedBytesToRead)); err != nil {
    declinedConnections.Add(1)

    continue
    }

    runtime.SetFinalizer(conn, nil)
    }

    require.Equal(t, int32(maxConnections), passedConnections.Load())
    require.Equal(t, int32(expectedDeclinedConnections), declinedConnections.Load())
    }
     
    1. View previous comments (1)
    2. krisssss Topic starter
    3. yoona
      krisssss, и что получилось?
    4. krisssss Topic starter
Loading...
Top