package main import ( "context" "errors" "flag" "fmt" "io" "log/slog" "net" "os" "os/signal" "syscall" "time" "golang.org/x/sync/errgroup" ) const ( TransportProtocol = "tcp" ) type Config struct { Logger *slog.Logger Dialer *net.Dialer IsTestSuite bool TargetPort int ProxyPort int MaxConnections int } var ( proxyPortFlag = flag.Int( "proxy-port", 0, "Порт ******-сервера для входящих соединений, обязательный", ) targetPortFlag = flag.Int( "target-port", 443, "Порт локального VLESS-сервера (по умолчанию 443)", ) maxConnsFlag = flag.Int( "limit", 0, "Максимальное число одновременных подключений (обязательный, >0)", ) testLingerFlag = flag.Bool( "test-linger", false, "Устанавливать TCP Linger=0 для тестов", ) ) func setupConfig(logger *slog.Logger) (*Config, error) { if *proxyPortFlag <= 0 || *proxyPortFlag > 65535 { flag.Usage() return nil, fmt.Errorf("-proxy-port должен быть в диапазоне 1-65535, указан: %d", *proxyPortFlag) } if *targetPortFlag <= 0 || *targetPortFlag > 65535 { flag.Usage() return nil, fmt.Errorf("-target-port должен быть в диапазоне 1-65535, указан: %d", *targetPortFlag) } if *maxConnsFlag <= 0 { flag.Usage() return nil, fmt.Errorf("-limit должен быть > 0, указан: %d", *maxConnsFlag) } return &Config{ Logger: logger, Dialer: &net.Dialer{Timeout: 30 * time.Second}, ProxyPort: *proxyPortFlag, TargetPort: *targetPortFlag, MaxConnections: *maxConnsFlag, IsTestSuite: *testLingerFlag, }, nil } func main() { flag.Usage = func() { fmt.Fprintf(os.Stderr, "Использование: %s [OPTIONS]\n", os.Args[0]) flag.PrintDefaults() } flag.Parse() logger := slog.New(slog.NewJSONHandler(os.Stdout, nil)) ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) defer stop() if err := run(ctx, logger); err != nil && !errors.Is(err, context.Canceled) { logger.Warn("application error", slog.String("error", err.Error())) } logger.Info("application shutdown") } func run(ctx context.Context, logger *slog.Logger) error { cfg, err := setupConfig(logger) if err != nil { return fmt.Errorf("setup config: %w", err) } listenAddr := fmt.Sprintf(":%d", cfg.ProxyPort) forwardAddr := fmt.Sprintf("127.0.0.1:%d", cfg.TargetPort) g, ctx := errgroup.WithContext(ctx) g.SetLimit(cfg.MaxConnections) listener, err := net.Listen(TransportProtocol, listenAddr) if err != nil { return fmt.Errorf("listen: %w", err) } defer listener.Close() logger.Info("proxy ready", slog.String("listen", listenAddr), slog.String("forwardTo", forwardAddr), slog.Int("maxConns", cfg.MaxConnections), ) for { conn, err := listener.Accept() if err != nil { if errors.Is(err, net.ErrClosed) || ctx.Err() != nil { break } logger.Warn("accept error", slog.String("error", err.Error())) continue } if ok := g.TryGo(func() error { defer conn.Close() return proxy(ctx, cfg, conn, forwardAddr) }); !ok { // сброс без логирования при превышении лимита if tcpConn, ok := conn.(*net.TCPConn); ok { tcpConn.SetLinger(0) tcpConn.SetKeepAlive(false) } conn.Close() continue } } return g.Wait() } func proxy(ctx context.Context, cfg *Config, client net.Conn, forwardAddr string) error { target, err := cfg.Dialer.DialContext(ctx, TransportProtocol, forwardAddr) if err != nil { return fmt.Errorf("dial target: %w", err) } defer target.Close() go io.Copy(target, client) _, err = io.Copy(client, target) return err } Код package main import ( "context" "errors" "flag" "fmt" "io" "log/slog" "net" "os" "os/signal" "syscall" "time" "golang.org/x/sync/errgroup" ) const ( TransportProtocol = "tcp" ) type Config struct { Logger *slog.Logger Dialer *net.Dialer IsTestSuite bool TargetPort int ProxyPort int MaxConnections int } var ( proxyPortFlag = flag.Int( "proxy-port", 0, "Порт ******-сервера для входящих соединений, обязательный", ) targetPortFlag = flag.Int( "target-port", 443, "Порт локального VLESS-сервера (по умолчанию 443)", ) maxConnsFlag = flag.Int( "limit", 0, "Максимальное число одновременных подключений (обязательный, >0)", ) testLingerFlag = flag.Bool( "test-linger", false, "Устанавливать TCP Linger=0 для тестов", ) ) func setupConfig(logger *slog.Logger) (*Config, error) { if *proxyPortFlag <= 0 || *proxyPortFlag > 65535 { flag.Usage() return nil, fmt.Errorf("-proxy-port должен быть в диапазоне 1-65535, указан: %d", *proxyPortFlag) } if *targetPortFlag <= 0 || *targetPortFlag > 65535 { flag.Usage() return nil, fmt.Errorf("-target-port должен быть в диапазоне 1-65535, указан: %d", *targetPortFlag) } if *maxConnsFlag <= 0 { flag.Usage() return nil, fmt.Errorf("-limit должен быть > 0, указан: %d", *maxConnsFlag) } return &Config{ Logger: logger, Dialer: &net.Dialer{Timeout: 30 * time.Second}, ProxyPort: *proxyPortFlag, TargetPort: *targetPortFlag, MaxConnections: *maxConnsFlag, IsTestSuite: *testLingerFlag, }, nil } func main() { flag.Usage = func() { fmt.Fprintf(os.Stderr, "Использование: %s [OPTIONS]\n", os.Args[0]) flag.PrintDefaults() } flag.Parse() logger := slog.New(slog.NewJSONHandler(os.Stdout, nil)) ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) defer stop() if err := run(ctx, logger); err != nil && !errors.Is(err, context.Canceled) { logger.Warn("application error", slog.String("error", err.Error())) } logger.Info("application shutdown") } func run(ctx context.Context, logger *slog.Logger) error { cfg, err := setupConfig(logger) if err != nil { return fmt.Errorf("setup config: %w", err) } listenAddr := fmt.Sprintf(":%d", cfg.ProxyPort) forwardAddr := fmt.Sprintf("127.0.0.1:%d", cfg.TargetPort) g, ctx := errgroup.WithContext(ctx) g.SetLimit(cfg.MaxConnections) listener, err := net.Listen(TransportProtocol, listenAddr) if err != nil { return fmt.Errorf("listen: %w", err) } defer listener.Close() logger.Info("proxy ready", slog.String("listen", listenAddr), slog.String("forwardTo", forwardAddr), slog.Int("maxConns", cfg.MaxConnections), ) for { conn, err := listener.Accept() if err != nil { if errors.Is(err, net.ErrClosed) || ctx.Err() != nil { break } logger.Warn("accept error", slog.String("error", err.Error())) continue } if ok := g.TryGo(func() error { defer conn.Close() return proxy(ctx, cfg, conn, forwardAddr) }); !ok { // сброс без логирования при превышении лимита if tcpConn, ok := conn.(*net.TCPConn); ok { tcpConn.SetLinger(0) tcpConn.SetKeepAlive(false) } conn.Close() continue } } return g.Wait() } func proxy(ctx context.Context, cfg *Config, client net.Conn, forwardAddr string) error { target, err := cfg.Dialer.DialContext(ctx, TransportProtocol, forwardAddr) if err != nil { return fmt.Errorf("dial target: %w", err) } defer target.Close() go io.Copy(target, client) _, err = io.Copy(client, target) return err }
Дело было вечером, делать было нечего В идеале отдавать ConnectionRefused, но мне стало лень возиться с этим, и хотелось сохранить изначальную логику редиректа Немного косметических процедур и рефакторинга, но разделять по файлам уж было лень в силу того, что сюда структуру потом не залить :( В тесте используется финалайзер, что потом может сказаться негативно на большем количестве тестов, ибо не отдает порт системе до закрытия теста. Можно переписать что бы оно чисто в массиве их держало тогда, но мне было лень. Нужно для того, что бы GC сам не закрывал порт, ибо тогда получился бы флапающий тест main.go package main import ( "context" "errors" "fmt" "io" "log/slog" "net" "os" "os/signal" "strconv" "syscall" "time" "golang.org/x/sync/errgroup" ) const ( envPrefix = "PROXY_APP" EnvKeyTargetAddr = envPrefix + "TARGET_ADDR" EnvKeyMaxConnections = envPrefix + "MAX_CONNECTIONS" EnvKeyProxyAddr = envPrefix + "PROXY_ADDR" EnvKeyTestLinger = envPrefix + "TEST_LINGER" TransportProtocol = "tcp" ) type Config struct { Logger *slog.Logger Dialer *net.Dialer IsTestSuite bool TargetAddr string ProxyAddr string MaxConnections int } func setupConfig(logger *slog.Logger) (*Config, error) { maxConnectionsRaw := os.Getenv(EnvKeyMaxConnections) maxConnections, err := strconv.Atoi(maxConnectionsRaw) if err != nil { return nil, fmt.Errorf("atoi max connections: %w", err) } config := &Config{ Logger: logger, Dialer: &net.Dialer{}, TargetAddr: os.Getenv(EnvKeyTargetAddr), ProxyAddr: os.Getenv(EnvKeyProxyAddr), IsTestSuite: os.Getenv(EnvKeyTestLinger) != "", MaxConnections: maxConnections, } if config.TargetAddr == "" || config.ProxyAddr == "" || config.MaxConnections == 0 { return nil, fmt.Errorf("one of the required arguments missing") // Мне лень их было переписывать, а лучше вообще валидатором пройтись, сами уже } return config, nil } func main() { errGroup, ctx := errgroup.WithContext(context.Background()) ctx, cancel := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM) defer cancel() logger := slog.New(slog.NewJSONHandler(os.Stdout, nil)) errGroup.Go(func() error { if err := run(ctx, logger); err != nil { return fmt.Errorf("run: %w", err) } return nil }) if err := errGroup.Wait(); err != nil { logger.Warn("failed to run application", slog.String("err", err.Error())) } <-ctx.Done() // Лучше заменить на какую нибудь имплементацию Closer'a, но мне лень стало ее сюда тащить // И потом контексты поменять на более 'чистые', что бы красивше завершать через клозеры, а не насильно контекстом <-time.After(time.Second * 5) logger.Info("application shutdown") } func run(ctx context.Context, logger *slog.Logger) error { config, err := setupConfig(logger) if err != nil { return fmt.Errorf("setup config: %w", err) } errGroup := errgroup.Group{} errGroup.SetLimit(config.MaxConnections) listenConfig := net.ListenConfig{} listener, err := listenConfig.Listen(context.Background(), TransportProtocol, config.ProxyAddr) if err != nil { return fmt.Errorf("listen: %w", err) } defer listener.Close() config.Logger.Info( "ready to accept connections", slog.String("proxyAddr", config.ProxyAddr), slog.String("targetAddr", config.TargetAddr), slog.Int("limit", config.MaxConnections), ) for { conn, err := listener.Accept() if err != nil { config.Logger.Warn("failed to accept connection", slog.String("err", err.Error())) if errors.Is(err, net.ErrClosed) { break } continue } if config.IsTestSuite { if tcp, ok := conn.(*net.TCPConn); ok { tcp.SetLinger(0) } } config.Logger.Debug("accepted connection", slog.String("addr", conn.RemoteAddr().String())) if ok := errGroup.TryGo(func() error { if err := handleConn(ctx, config, conn); err != nil { config.Logger.Warn("handle connection", slog.String("err", err.Error())) } if err := conn.Close(); err != nil { config.Logger.Warn("failed to close listener", slog.String("err", err.Error())) } config.Logger.Debug("closing connection", slog.String("addr", conn.RemoteAddr().String())) return nil }); !ok { if err := conn.Close(); err != nil { config.Logger.Warn("failed to close listener", slog.String("err", err.Error())) } config.Logger.Warn("exceeded limit", slog.String("remoteAddr", conn.RemoteAddr().String()), slog.Int("limit", config.MaxConnections)) } } return nil } func handleConn(ctx context.Context, config *Config, client net.Conn) error { target, err := config.Dialer.DialContext(ctx, TransportProtocol, config.TargetAddr) if err != nil { return fmt.Errorf("dial: %w", err) } defer target.Close() config.Logger.Debug("starting connection redirect", slog.String("addr", client.RemoteAddr().String())) go io.Copy(target, client) io.Copy(client, target) return nil } Код package main import ( "context" "errors" "fmt" "io" "log/slog" "net" "os" "os/signal" "strconv" "syscall" "time" "golang.org/x/sync/errgroup" ) const ( envPrefix = "PROXY_APP" EnvKeyTargetAddr = envPrefix + "TARGET_ADDR" EnvKeyMaxConnections = envPrefix + "MAX_CONNECTIONS" EnvKeyProxyAddr = envPrefix + "PROXY_ADDR" EnvKeyTestLinger = envPrefix + "TEST_LINGER" TransportProtocol = "tcp" ) type Config struct { Logger *slog.Logger Dialer *net.Dialer IsTestSuite bool TargetAddr string ProxyAddr string MaxConnections int } func setupConfig(logger *slog.Logger) (*Config, error) { maxConnectionsRaw := os.Getenv(EnvKeyMaxConnections) maxConnections, err := strconv.Atoi(maxConnectionsRaw) if err != nil { return nil, fmt.Errorf("atoi max connections: %w", err) } config := &Config{ Logger: logger, Dialer: &net.Dialer{}, TargetAddr: os.Getenv(EnvKeyTargetAddr), ProxyAddr: os.Getenv(EnvKeyProxyAddr), IsTestSuite: os.Getenv(EnvKeyTestLinger) != "", MaxConnections: maxConnections, } if config.TargetAddr == "" || config.ProxyAddr == "" || config.MaxConnections == 0 { return nil, fmt.Errorf("one of the required arguments missing") // Мне лень их было переписывать, а лучше вообще валидатором пройтись, сами уже } return config, nil } func main() { errGroup, ctx := errgroup.WithContext(context.Background()) ctx, cancel := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM) defer cancel() logger := slog.New(slog.NewJSONHandler(os.Stdout, nil)) errGroup.Go(func() error { if err := run(ctx, logger); err != nil { return fmt.Errorf("run: %w", err) } return nil }) if err := errGroup.Wait(); err != nil { logger.Warn("failed to run application", slog.String("err", err.Error())) } <-ctx.Done() // Лучше заменить на какую нибудь имплементацию Closer'a, но мне лень стало ее сюда тащить // И потом контексты поменять на более 'чистые', что бы красивше завершать через клозеры, а не насильно контекстом <-time.After(time.Second * 5) logger.Info("application shutdown") } func run(ctx context.Context, logger *slog.Logger) error { config, err := setupConfig(logger) if err != nil { return fmt.Errorf("setup config: %w", err) } errGroup := errgroup.Group{} errGroup.SetLimit(config.MaxConnections) listenConfig := net.ListenConfig{} listener, err := listenConfig.Listen(context.Background(), TransportProtocol, config.ProxyAddr) if err != nil { return fmt.Errorf("listen: %w", err) } defer listener.Close() config.Logger.Info( "ready to accept connections", slog.String("proxyAddr", config.ProxyAddr), slog.String("targetAddr", config.TargetAddr), slog.Int("limit", config.MaxConnections), ) for { conn, err := listener.Accept() if err != nil { config.Logger.Warn("failed to accept connection", slog.String("err", err.Error())) if errors.Is(err, net.ErrClosed) { break } continue } if config.IsTestSuite { if tcp, ok := conn.(*net.TCPConn); ok { tcp.SetLinger(0) } } config.Logger.Debug("accepted connection", slog.String("addr", conn.RemoteAddr().String())) if ok := errGroup.TryGo(func() error { if err := handleConn(ctx, config, conn); err != nil { config.Logger.Warn("handle connection", slog.String("err", err.Error())) } if err := conn.Close(); err != nil { config.Logger.Warn("failed to close listener", slog.String("err", err.Error())) } config.Logger.Debug("closing connection", slog.String("addr", conn.RemoteAddr().String())) return nil }); !ok { if err := conn.Close(); err != nil { config.Logger.Warn("failed to close listener", slog.String("err", err.Error())) } config.Logger.Warn("exceeded limit", slog.String("remoteAddr", conn.RemoteAddr().String()), slog.Int("limit", config.MaxConnections)) } } return nil } func handleConn(ctx context.Context, config *Config, client net.Conn) error { target, err := config.Dialer.DialContext(ctx, TransportProtocol, config.TargetAddr) if err != nil { return fmt.Errorf("dial: %w", err) } defer target.Close() config.Logger.Debug("starting connection redirect", slog.String("addr", client.RemoteAddr().String())) go io.Copy(target, client) io.Copy(client, target) return nil } main_test.go package main import ( "fmt" "log/slog" "net" "os" "runtime" "strconv" "sync/atomic" "testing" "time" "github.com/stretchr/testify/require" ) func TestApp(t *testing.T) { const ( targetPort = 40066 proxyPort = 40077 maxConnections = 3 expectedDeclinedConnections = 10 expectedBytesToRead = 8 ) logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{ Level: slog.LevelDebug, })) targetAddr := fmt.Sprintf("127.0.0.1:%d", targetPort) require.NoError(t, os.Setenv(EnvKeyProxyAddr, fmt.Sprintf(":%d", proxyPort))) require.NoError(t, os.Setenv(EnvKeyTargetAddr, targetAddr)) require.NoError(t, os.Setenv(EnvKeyMaxConnections, strconv.Itoa(maxConnections))) require.NoError(t, os.Setenv(EnvKeyTestLinger, "true")) go func() { require.NoError(t, run(t.Context(), logger)) }() listener, err := net.Listen(TransportProtocol, targetAddr) require.NoError(t, err) passedConnections := atomic.Int32{} declinedConnections := atomic.Int32{} go func() { for { conn, err := listener.Accept() if err != nil { continue } runtime.SetFinalizer(conn, nil) receivedBytesNum, err := conn.Read(make([]byte, expectedBytesToRead*2)) require.Equal(t, receivedBytesNum, expectedBytesToRead) passedConnections.Add(1) } }() for range maxConnections + expectedDeclinedConnections { conn, err := net.DialTimeout(TransportProtocol, fmt.Sprintf("localhost:%d", proxyPort), time.Second) require.NoError(t, err) <-time.After(time.Second / 2) if _, err = conn.Write(make([]byte, expectedBytesToRead)); err != nil { declinedConnections.Add(1) continue } runtime.SetFinalizer(conn, nil) } require.Equal(t, int32(maxConnections), passedConnections.Load()) require.Equal(t, int32(expectedDeclinedConnections), declinedConnections.Load()) } Код package main import ( "fmt" "log/slog" "net" "os" "runtime" "strconv" "sync/atomic" "testing" "time" "github.com/stretchr/testify/require" ) func TestApp(t *testing.T) { const ( targetPort = 40066 proxyPort = 40077 maxConnections = 3 expectedDeclinedConnections = 10 expectedBytesToRead = 8 ) logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{ Level: slog.LevelDebug, })) targetAddr := fmt.Sprintf("127.0.0.1:%d", targetPort) require.NoError(t, os.Setenv(EnvKeyProxyAddr, fmt.Sprintf(":%d", proxyPort))) require.NoError(t, os.Setenv(EnvKeyTargetAddr, targetAddr)) require.NoError(t, os.Setenv(EnvKeyMaxConnections, strconv.Itoa(maxConnections))) require.NoError(t, os.Setenv(EnvKeyTestLinger, "true")) go func() { require.NoError(t, run(t.Context(), logger)) }() listener, err := net.Listen(TransportProtocol, targetAddr) require.NoError(t, err) passedConnections := atomic.Int32{} declinedConnections := atomic.Int32{} go func() { for { conn, err := listener.Accept() if err != nil { continue } runtime.SetFinalizer(conn, nil) receivedBytesNum, err := conn.Read(make([]byte, expectedBytesToRead*2)) require.Equal(t, receivedBytesNum, expectedBytesToRead) passedConnections.Add(1) } }() for range maxConnections + expectedDeclinedConnections { conn, err := net.DialTimeout(TransportProtocol, fmt.Sprintf("localhost:%d", proxyPort), time.Second) require.NoError(t, err) <-time.After(time.Second / 2) if _, err = conn.Write(make([]byte, expectedBytesToRead)); err != nil { declinedConnections.Add(1) continue } runtime.SetFinalizer(conn, nil) } require.Equal(t, int32(maxConnections), passedConnections.Load()) require.Equal(t, int32(expectedDeclinedConnections), declinedConnections.Load()) }