From 68c1952b303f29319786a42c2b13ffd08eaffb7f Mon Sep 17 00:00:00 2001 From: SunBK201 Date: Sat, 29 Nov 2025 15:02:28 +0800 Subject: [PATCH] refactor: refactor socks5 server --- src/go.mod | 3 + src/go.sum | 11 + src/internal/server/socks5/socks5.go | 508 ++++++++++++--------------- 3 files changed, 247 insertions(+), 275 deletions(-) diff --git a/src/go.mod b/src/go.mod index f5f3016..461a47c 100644 --- a/src/go.mod +++ b/src/go.mod @@ -9,6 +9,7 @@ require ( github.com/gonetx/ipset v0.1.0 github.com/google/gopacket v1.1.19 github.com/hashicorp/golang-lru/v2 v2.0.7 + github.com/luyuhuang/subsocks v0.5.0 github.com/mdlayher/netlink v1.7.2 golang.org/x/sys v0.30.0 gopkg.in/natefinch/lumberjack.v2 v2.2.1 @@ -20,6 +21,8 @@ require ( github.com/josharian/native v1.1.0 // indirect github.com/mdlayher/socket v0.5.1 // indirect github.com/stretchr/testify v1.11.1 // indirect + github.com/tg123/go-htpasswd v1.0.0 // indirect + golang.org/x/crypto v0.33.0 // indirect golang.org/x/net v0.35.0 // indirect golang.org/x/sync v0.11.0 // indirect ) diff --git a/src/go.sum b/src/go.sum index a758cda..ba794ee 100644 --- a/src/go.sum +++ b/src/go.sum @@ -7,30 +7,40 @@ github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZ github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/florianl/go-nfqueue/v2 v2.0.2 h1:FL5lQTeetgpCvac1TRwSfgaXUn0YSO7WzGvWNIp3JPE= github.com/florianl/go-nfqueue/v2 v2.0.2/go.mod h1:VA09+iPOT43OMoCKNfXHyzujQUty2xmzyCRkBOlmabc= +github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/gonetx/ipset v0.1.0 h1:LFkRdTbedg2UYXFN/2mOtgbvdWyo+OERrwVbtrPVuYY= github.com/gonetx/ipset v0.1.0/go.mod h1:AwNAf1Vtqg0cJ4bha4w1ROX5cO/8T50UYoegxM20AH8= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= +github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA= github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w= github.com/lithammer/dedent v1.1.0 h1:VNzHMVCBNG1j0fh3OrsFRkVUwStdDArbgBWoPAffktY= github.com/lithammer/dedent v1.1.0/go.mod h1:jrXYCQtgg0nJiN+StA2KgR7w6CiQNv9Fd/Z9BP0jIOc= +github.com/luyuhuang/subsocks v0.5.0 h1:jOyQxU2Xw/7HFLQbd9YEGEnvDv59a7bOhk7ecxjq6TA= +github.com/luyuhuang/subsocks v0.5.0/go.mod h1:Rz6D6+j9D0BRZ33LivuE66gquUm2VBVtha8I8TlUIVM= github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g= github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw= github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos= github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ= +github.com/pelletier/go-toml v1.8.1/go.mod h1:T2/BmBdy8dvIRq1a/8aqjN41wvWlN4lrapLU/GW4pbc= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/tg123/go-htpasswd v1.0.0 h1:Ze/pZsz73JiCwXIyJBPvNs75asKBgfodCf8iTEkgkXs= +github.com/tg123/go-htpasswd v1.0.0/go.mod h1:eQTgl67UrNKQvEPKrDLGBssjVwYQClFZjALVLhIv8C0= +golang.org/x/crypto v0.0.0-20190228161510-8dd112bcdc25/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus= +golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= @@ -42,6 +52,7 @@ golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w= golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/src/internal/server/socks5/socks5.go b/src/internal/server/socks5/socks5.go index d560b2a..3e6de76 100644 --- a/src/internal/server/socks5/socks5.go +++ b/src/internal/server/socks5/socks5.go @@ -1,46 +1,26 @@ package socks5 import ( - "encoding/binary" + "bytes" "errors" "fmt" - "io" "log/slog" "net" - "strings" "syscall" "time" "github.com/hashicorp/golang-lru/v2/expirable" + "github.com/luyuhuang/subsocks/socks" "github.com/sunbk201/ua3f/internal/config" "github.com/sunbk201/ua3f/internal/rewrite" "github.com/sunbk201/ua3f/internal/server/base" ) -// SOCKS5 constants -const ( - socksVer5 = 0x05 - socksNoAuth = 0x00 - socksCmdConn = 0x01 - socksCmdUDP = 0x03 - - socksATYPv4 = 0x01 - socksATYDomain = 0x03 - socksATYPv6 = 0x04 -) - -var ( - ErrInvalidSocksVersion = errors.New("invalid socks version") - ErrInvalidSocksCmd = errors.New("invalid socks cmd") -) - -// Server is a minimal SOCKS5 server that delegates HTTP UA rewriting to Rewriter. type Server struct { base.Server listener net.Listener } -// New returns a new Server with given config, rewriter, and version string. func New(cfg *config.Config, rw *rewrite.Rewriter) *Server { return &Server{ Server: base.Server{ @@ -58,7 +38,6 @@ func (s *Server) Close() (err error) { return } -// Start begins listening for SOCKS5 clients. func (s *Server) Start() (err error) { if s.listener, err = net.Listen("tcp", s.Cfg.ListenAddr); err != nil { return fmt.Errorf("net.Listen: %w", err) @@ -76,299 +55,278 @@ func (s *Server) Start() (err error) { slog.Error("s.listener.Accept", slog.Any("error", err)) continue } - slog.Debug("Accept connection", slog.String("addr", client.RemoteAddr().String())) go s.HandleClient(client) } }() return nil } -// handleClient performs SOCKS5 negotiation and dispatches TCP/UDP handling. -func (s *Server) HandleClient(client net.Conn) { - // Handshake (no auth) - if err := s.socks5Auth(client); err != nil { - _ = client.Close() +func (s *Server) HandleClient(conn net.Conn) { + defer func() { + _ = conn.Close() + }() + + srcAddr := conn.RemoteAddr().String() + + slog.Info("New socks5 connection", slog.String("srcAddr", srcAddr)) + + if err := s.handShake(conn); err != nil { + slog.Error("s.handShake", slog.String("srcAddr", srcAddr), slog.Any("error", err)) return } - destAddrPort, cmd, err := s.parseSocks5Request(client) + request, err := socks.ReadRequest(conn) if err != nil { - if cmd == socksCmdUDP { - // UDP Associate - s.handleUDPAssociate(client) - _ = client.Close() - return + slog.Error("socks.ReadRequest", slog.String("srcAddr", srcAddr), slog.Any("error", err)) + return + } + + switch request.Cmd { + case socks.CmdConnect: + err = s.handleConnect(conn, request) + if err != nil { + err = fmt.Errorf("s.handleConnect: %w", err) } - slog.Debug("ParseSocks5Request failed", slog.String("src", client.RemoteAddr().String()), slog.String("dst", destAddrPort), slog.Any("error", err)) - _ = client.Close() - return + case socks.CmdBind: + err = s.handleBind(conn) + if err != nil { + err = fmt.Errorf("s.handleBind: %w", err) + } + case socks.CmdUDP: + err = s.handleUDPAssociate(conn) + if err != nil { + err = fmt.Errorf("s.handleUDPAssociate: %w", err) + } + default: + err = fmt.Errorf("socks5 unsupported command %d", request.Cmd) } - - // TCP CONNECT - target, err := s.socks5Connect(client, destAddrPort) if err != nil { - slog.Warn("s.socks5Connect", slog.String("addr", destAddrPort), slog.Any("error", err)) - _ = client.Close() + slog.Error("HandleClient", slog.String("srcAddr", srcAddr), slog.Any("error", err)) return } - s.ServeConnLink(&base.ConnLink{ - LConn: client, - RConn: target, - LAddr: client.RemoteAddr().String(), - RAddr: target.RemoteAddr().String(), - }) } -// socks5Auth performs a minimal "no-auth" negotiation. -func (s *Server) socks5Auth(client net.Conn) error { - buf := make([]byte, 256) - - // Read VER, NMETHODS - n, err := io.ReadFull(client, buf[:2]) - if n != 2 { - if errors.Is(err, io.EOF) { - slog.Warn("socks5Auth read EOF", slog.String("addr", client.RemoteAddr().String())) - } else { - slog.Error("socks5Auth read header", slog.String("addr", client.RemoteAddr().String()), slog.Any("error", err)) +func (s *Server) handShake(conn net.Conn) error { + methods, err := socks.ReadMethods(conn) + if err != nil { + return fmt.Errorf("socks.ReadMethods: %w", err) + } + method := socks.MethodNoAcceptable + for _, m := range methods { + if m == socks.MethodNoAuth { + method = m } - return fmt.Errorf("io.ReadFull reading header: %w", err) } - ver, nMethods := int(buf[0]), int(buf[1]) - if ver != socksVer5 { - slog.Error("socks5Auth invalid ver", slog.String("addr", client.RemoteAddr().String())) - return ErrInvalidSocksVersion - } - - // Read METHODS - n, err = io.ReadFull(client, buf[:nMethods]) - if n != nMethods { - slog.Error("socks5Auth read methods", slog.String("addr", client.RemoteAddr().String()), slog.Any("error", err)) - return fmt.Errorf("io.ReadFull read methods: %w", err) - } - - // Reply: no-auth - n, err = client.Write([]byte{socksVer5, socksNoAuth}) - if n != 2 || err != nil { - slog.Error("socks5Auth write rsp", slog.String("addr", client.RemoteAddr().String()), slog.Any("error", err)) - return fmt.Errorf("client.Write rsp: %w", err) + if err := socks.WriteMethod(socks.MethodNoAuth, conn); err != nil || method == socks.MethodNoAcceptable { + if err != nil { + return fmt.Errorf("socks.WriteMethod: %w", err) + } else { + return fmt.Errorf("socks5 methods is not acceptable") + } } return nil } -// parseSocks5Request reads a single SOCKS5 request. Returns dest, cmd, and error. -func (s *Server) parseSocks5Request(client net.Conn) (string, byte, error) { - buf := make([]byte, 256) +func (s *Server) handleConnect(src net.Conn, req *socks.Request) error { + srcAddr := src.RemoteAddr().String() + destAddr := req.Addr.String() - // VER, CMD, RSV, ATYP - if _, err := io.ReadFull(client, buf[:4]); err != nil { - return "", 0, fmt.Errorf("read header: %w", err) - } - ver, cmd, atyp := buf[0], buf[1], buf[3] - if ver != socksVer5 { - return "", cmd, ErrInvalidSocksVersion - } - - // UDP associate: let caller handle - if cmd == socksCmdUDP { - return "", socksCmdUDP, errors.New("UDP Associate") - } - if cmd != socksCmdConn { - return "", cmd, ErrInvalidSocksCmd - } - - var addr string - switch atyp { - case socksATYPv4: - if _, err := io.ReadFull(client, buf[:4]); err != nil { - return "", cmd, fmt.Errorf("invalid IPv4: %w", err) + dest, err := net.Dial("tcp", destAddr) + if err != nil { + if err := socks.NewReply(socks.HostUnreachable, nil).Write(src); err != nil { + slog.Error("socks.NewReply.Write", slog.String("srcAddr", srcAddr), slog.Any("error", err)) } - addr = fmt.Sprintf("%d.%d.%d.%d", buf[0], buf[1], buf[2], buf[3]) - - case socksATYDomain: - if _, err := io.ReadFull(client, buf[:1]); err != nil { - return "", cmd, fmt.Errorf("invalid hostname(len): %w", err) - } - addrLen := int(buf[0]) - if _, err := io.ReadFull(client, buf[:addrLen]); err != nil { - return "", cmd, fmt.Errorf("invalid hostname: %w", err) - } - addr = string(buf[:addrLen]) - - case socksATYPv6: - return "", cmd, errors.New("IPv6: not supported yet") - - default: - return "", cmd, errors.New("invalid atyp") + return fmt.Errorf("net.Dial: %w, dest: %s", err, destAddr) } - if _, err := io.ReadFull(client, buf[:2]); err != nil { - return "", cmd, fmt.Errorf("read port: %w", err) + if err := socks.NewReply(socks.Succeeded, nil).Write(src); err != nil { + _ = dest.Close() + return fmt.Errorf("socks.NewReply.Write: %w", err) } - port := binary.BigEndian.Uint16(buf[:2]) - return fmt.Sprintf("%s:%d", addr, port), cmd, nil + s.ServeConnLink(&base.ConnLink{ + LConn: src, + RConn: dest, + LAddr: srcAddr, + RAddr: destAddr, + }) + + return nil } -// socks5Connect dials the target and responds success to the client. -func (s *Server) socks5Connect(client net.Conn, destAddrPort string) (net.Conn, error) { - target, err := base.Connect(destAddrPort) +func (s *Server) handleBind(conn net.Conn) error { + srcAddr := conn.RemoteAddr().String() + listener, err := net.ListenTCP("tcp", nil) if err != nil { - // Reply failure - _, _ = client.Write([]byte{socksVer5, 0x01, 0x00, socksATYPv4, 0, 0, 0, 0, 0, 0}) - return nil, fmt.Errorf("dial target %s: %w", destAddrPort, err) + if err := socks.NewReply(socks.Failure, nil).Write(conn); err != nil { + slog.Error("socks.NewReply.Write", slog.String("srcAddr", srcAddr), slog.Any("error", err)) + } + return fmt.Errorf("net.ListenTCP: %w", err) } - // Reply success (bind set to 0.0.0.0:0) - if _, err = client.Write([]byte{socksVer5, 0x00, 0x00, socksATYPv4, 0, 0, 0, 0, 0, 0}); err != nil { - _ = target.Close() - return nil, err - } - return target, nil -} -// handleUDPAssociate handles a UDP ASSOCIATE request by creating a UDP relay socket. -// Only IPv4 and domain ATYP are supported (no IPv6). -func (s *Server) handleUDPAssociate(client net.Conn) { - udpServer, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + addr, _ := socks.NewAddrFromAddr(listener.Addr(), conn.LocalAddr()) + if err := socks.NewReply(socks.Succeeded, addr).Write(conn); err != nil { + _ = listener.Close() + return fmt.Errorf("socks.NewReply.Write: %w", err) + } + + newConn, err := listener.AcceptTCP() + _ = listener.Close() if err != nil { - slog.Error("net.ListenUDP failed", slog.String("addr", client.RemoteAddr().String()), slog.Any("error", err)) - return + if err := socks.NewReply(socks.Failure, nil).Write(conn); err != nil { + slog.Error("socks.NewReply.Write", slog.String("srcAddr", srcAddr), slog.Any("error", err)) + } + return fmt.Errorf("listener.AcceptTCP: %w", err) } defer func() { - if err := udpServer.Close(); err != nil { - slog.Warn("udpServer.Close", slog.String("addr", client.RemoteAddr().String()), slog.Any("error", err)) + _ = newConn.Close() + }() + + raddr, _ := socks.NewAddr(newConn.RemoteAddr().String()) + if err := socks.NewReply(socks.Succeeded, raddr).Write(conn); err != nil { + return fmt.Errorf("socks.NewReply.Write: %w", err) + } + + s.ServeConnLink(&base.ConnLink{ + LConn: conn, + RConn: newConn, + LAddr: srcAddr, + RAddr: newConn.RemoteAddr().String(), + }) + return nil +} + +func (s *Server) handleUDPAssociate(conn net.Conn) error { + srcAddr := conn.RemoteAddr().String() + + udp, err := net.ListenUDP("udp", nil) + if err != nil { + if err := socks.NewReply(socks.Failure, nil).Write(conn); err != nil { + slog.Error("socks.NewReply.Write", slog.String("srcAddr", srcAddr), slog.Any("error", err)) + } + return fmt.Errorf("net.ListenUDP: %w", err) + } + + addr, _ := socks.NewAddrFromAddr(udp.LocalAddr(), conn.LocalAddr()) + if err := socks.NewReply(socks.Succeeded, addr).Write(conn); err != nil { + _ = udp.Close() + return fmt.Errorf("socks.NewReply.Write: %w", err) + } + + slog.Info("UDP associate established", slog.String("srcAddr", srcAddr), slog.String("udpAddr", udp.LocalAddr().String())) + + s.tunnelUDP(conn, udp) + return nil +} + +func (s *Server) tunnelUDP(conn net.Conn, udp *net.UDPConn) { + srcAddr := conn.RemoteAddr().String() + tcpRemote := conn.RemoteAddr().(*net.TCPAddr) + + var clientUDPAddr *net.UDPAddr + + done := make(chan struct{}) + + go func() { + defer func() { + _ = udp.Close() + }() + + b := make([]byte, 64*1024) + + for { + select { + case <-done: + return + default: + } + + _ = udp.SetReadDeadline(time.Now().Add(time.Second * 30)) + n, addr, err := udp.ReadFrom(b) + if err != nil { + if errors.Is(err, net.ErrClosed) { + return + } + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + continue + } + slog.Error("udp.ReadFrom", slog.String("srcAddr", srcAddr), slog.Any("error", err)) + continue + } + + udpAddr, ok := addr.(*net.UDPAddr) + if !ok { + continue + } + + isFromClient := udpAddr.IP.Equal(tcpRemote.IP) + if isFromClient { + clientUDPAddr = udpAddr + + dgram, err := socks.ReadUDPDatagram(bytes.NewReader(b[:n])) + if err != nil { + slog.Error("socks.ReadUDPDatagram error", slog.String("srcAddr", srcAddr), slog.Any("error", err)) + continue + } + + destAddr, err := net.ResolveUDPAddr("udp", dgram.Header.Addr.String()) + if err != nil { + slog.Error("net.ResolveUDPAddr error", + slog.String("srcAddr", srcAddr), + slog.String("destAddr", dgram.Header.Addr.String()), + slog.Any("error", err)) + continue + } + + if _, err := udp.WriteTo(dgram.Data, destAddr); err != nil { + slog.Error("udp.WriteTo dest error", + slog.String("srcAddr", srcAddr), + slog.String("destAddr", destAddr.String()), + slog.Any("error", err)) + continue + } + + slog.Debug("UDP relay request", + slog.String("from", addr.String()), + slog.String("to", destAddr.String()), + slog.Int("bytes", len(dgram.Data))) + + } else { + if clientUDPAddr == nil { + continue + } + + saddr, _ := socks.NewAddr(addr.String()) + dgram := socks.NewUDPDatagram( + socks.NewUDPHeader(0, 0, saddr), b[:n]) + + var writer bytes.Buffer + if err := dgram.Write(&writer); err != nil { + slog.Debug("dgram.Write error", slog.String("srcAddr", srcAddr), slog.Any("error", err)) + continue + } + + if _, err := udp.WriteTo(writer.Bytes(), clientUDPAddr); err != nil { + slog.Debug("udp.WriteTo client error", slog.String("srcAddr", srcAddr), slog.Any("error", err)) + continue + } + + slog.Debug("UDP relay response", + slog.String("from", addr.String()), + slog.String("to", clientUDPAddr.String()), + slog.Int("bytes", n)) + } } }() - _, portStr, _ := net.SplitHostPort(udpServer.LocalAddr().String()) - slog.Debug("net.SplitHostPort", slog.String("addr", client.RemoteAddr().String()), slog.String("port", portStr)) - - portInt, _ := net.LookupPort("udp", portStr) - portBytes := make([]byte, 2) - binary.BigEndian.PutUint16(portBytes, uint16(portInt)) - - // Reply with chosen UDP port (bind addr set to 0.0.0.0) - if _, err = client.Write([]byte{socksVer5, 0x00, 0x00, socksATYPv4, 0, 0, 0, 0, portBytes[0], portBytes[1]}); err != nil { - slog.Error("client.Write rsp", slog.String("addr", client.RemoteAddr().String()), slog.Any("error", err)) - return - } - - buf := make([]byte, 65535) - udpPortMap := make(map[string][]byte) - var clientAddr *net.UDPAddr - isDomain := false - + // tcp connection monitor + b := make([]byte, 1) for { - _ = udpServer.SetReadDeadline(time.Now().Add(10 * time.Second)) - n, fromAddr, err := udpServer.ReadFromUDP(buf) - if err != nil { - if strings.Contains(err.Error(), "i/o timeout") { - slog.Debug("ReadFromUDP timeout", slog.String("addr", client.RemoteAddr().String()), slog.Any("error", err)) - if !isAlive(client) { - slog.Debug("client is not alive", slog.String("addr", client.RemoteAddr().String())) - break - } - } else { - slog.Error("udpServer.ReadFromUDP failed", slog.String("addr", client.RemoteAddr().String()), slog.Any("error", err)) - } - continue - } - if clientAddr == nil { - clientAddr = fromAddr - } - - if clientAddr.IP.Equal(fromAddr.IP) && clientAddr.Port == fromAddr.Port { - // Packet from client -> forward to remote - atyp := buf[3] - var ( - targetAddr string - targetPort uint16 - payload []byte - header []byte - targetIP net.IP - ) - - switch atyp { - case socksATYPv4: - isDomain = false - targetAddr = fmt.Sprintf("%d.%d.%d.%d", buf[4], buf[5], buf[6], buf[7]) - targetIP = net.ParseIP(targetAddr) - targetPort = binary.BigEndian.Uint16(buf[8:10]) - payload = buf[10:n] - header = buf[0:10] - - case socksATYDomain: - isDomain = true - addrLen := int(buf[4]) - targetAddr = string(buf[5 : 5+addrLen]) - targetIPAddr, err := net.ResolveIPAddr("ip", targetAddr) - if err != nil { - slog.Error("net.ResolveIPAddr", slog.String("addr", client.RemoteAddr().String()), slog.Any("error", err)) - break - } - targetIP = targetIPAddr.IP - targetPort = binary.BigEndian.Uint16(buf[5+addrLen : 5+addrLen+2]) - payload = buf[5+addrLen+2 : n] - header = buf[0 : 5+addrLen+2] - - case socksATYPv6: - slog.Error("IPv6: not supported yet", slog.String("addr", client.RemoteAddr().String())) - return - - default: - slog.Error("invalid atyp", slog.String("addr", client.RemoteAddr().String())) - return - } - - remoteAddr := &net.UDPAddr{IP: targetIP, Port: int(targetPort)} - udpPortMap[remoteAddr.String()] = make([]byte, len(header)) - copy(udpPortMap[remoteAddr.String()], header) - - _ = udpServer.SetWriteDeadline(time.Now().Add(10 * time.Second)) - if _, err = udpServer.WriteToUDP(payload, remoteAddr); err != nil { - slog.Debug("WriteToUDP to remote failed", slog.String("addr", client.RemoteAddr().String()), slog.Any("error", err)) - continue - } - } else { - // Packet from remote -> forward to client (rebuild header) - header := udpPortMap[fromAddr.String()] - if header == nil { - slog.Error("udpPortMap invalid header", slog.String("addr", client.RemoteAddr().String())) - continue - } - // For domain ATYP, preserve original head section size - if isDomain { - header = header[0:4] - } - body := append(header, buf[:n]...) - if _, err = udpServer.WriteToUDP(body, clientAddr); err != nil { - slog.Debug("WriteToUDP to client failed", slog.String("addr", client.RemoteAddr().String()), slog.Any("error", err)) - continue - } + _ = conn.SetReadDeadline(time.Now().Add(time.Minute)) + if _, err := conn.Read(b); err != nil { + slog.Info("TCP connection closed, stopping UDP relay", slog.String("srcAddr", srcAddr), slog.String("udpAddr", udp.LocalAddr().String())) + close(done) + return } } } - -// isAlive checks if a connection is still alive using a short read deadline. -func isAlive(conn net.Conn) bool { - one := make([]byte, 1) - _ = conn.SetReadDeadline(time.Now().Add(5 * time.Second)) - _, err := conn.Read(one) - if err != nil { - switch { - case errors.Is(err, io.EOF): - slog.Debug("isAlive: EOF", slog.String("addr", conn.RemoteAddr().String())) - return false - case strings.Contains(err.Error(), "use of closed network connection"): - slog.Debug("isAlive: closed", slog.String("addr", conn.RemoteAddr().String())) - return false - case strings.Contains(err.Error(), "i/o timeout"): - slog.Debug("isAlive: timeout", slog.String("addr", conn.RemoteAddr().String())) - return true - default: - slog.Debug("isAlive: error", slog.String("addr", conn.RemoteAddr().String()), slog.Any("error", err)) - return false - } - } - _ = conn.SetReadDeadline(time.Time{}) - return true -}