refactor: refactor socks5 server

This commit is contained in:
SunBK201 2025-11-29 15:02:28 +08:00
parent a005803862
commit 68c1952b30
3 changed files with 247 additions and 275 deletions

View File

@ -9,6 +9,7 @@ require (
github.com/gonetx/ipset v0.1.0
github.com/google/gopacket v1.1.19
github.com/hashicorp/golang-lru/v2 v2.0.7
github.com/luyuhuang/subsocks v0.5.0
github.com/mdlayher/netlink v1.7.2
golang.org/x/sys v0.30.0
gopkg.in/natefinch/lumberjack.v2 v2.2.1
@ -20,6 +21,8 @@ require (
github.com/josharian/native v1.1.0 // indirect
github.com/mdlayher/socket v0.5.1 // indirect
github.com/stretchr/testify v1.11.1 // indirect
github.com/tg123/go-htpasswd v1.0.0 // indirect
golang.org/x/crypto v0.33.0 // indirect
golang.org/x/net v0.35.0 // indirect
golang.org/x/sync v0.11.0 // indirect
)

View File

@ -7,30 +7,40 @@ github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZ
github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/florianl/go-nfqueue/v2 v2.0.2 h1:FL5lQTeetgpCvac1TRwSfgaXUn0YSO7WzGvWNIp3JPE=
github.com/florianl/go-nfqueue/v2 v2.0.2/go.mod h1:VA09+iPOT43OMoCKNfXHyzujQUty2xmzyCRkBOlmabc=
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
github.com/gonetx/ipset v0.1.0 h1:LFkRdTbedg2UYXFN/2mOtgbvdWyo+OERrwVbtrPVuYY=
github.com/gonetx/ipset v0.1.0/go.mod h1:AwNAf1Vtqg0cJ4bha4w1ROX5cO/8T50UYoegxM20AH8=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA=
github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
github.com/lithammer/dedent v1.1.0 h1:VNzHMVCBNG1j0fh3OrsFRkVUwStdDArbgBWoPAffktY=
github.com/lithammer/dedent v1.1.0/go.mod h1:jrXYCQtgg0nJiN+StA2KgR7w6CiQNv9Fd/Z9BP0jIOc=
github.com/luyuhuang/subsocks v0.5.0 h1:jOyQxU2Xw/7HFLQbd9YEGEnvDv59a7bOhk7ecxjq6TA=
github.com/luyuhuang/subsocks v0.5.0/go.mod h1:Rz6D6+j9D0BRZ33LivuE66gquUm2VBVtha8I8TlUIVM=
github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g=
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos=
github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ=
github.com/pelletier/go-toml v1.8.1/go.mod h1:T2/BmBdy8dvIRq1a/8aqjN41wvWlN4lrapLU/GW4pbc=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/tg123/go-htpasswd v1.0.0 h1:Ze/pZsz73JiCwXIyJBPvNs75asKBgfodCf8iTEkgkXs=
github.com/tg123/go-htpasswd v1.0.0/go.mod h1:eQTgl67UrNKQvEPKrDLGBssjVwYQClFZjALVLhIv8C0=
golang.org/x/crypto v0.0.0-20190228161510-8dd112bcdc25/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
@ -42,6 +52,7 @@ golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=

View File

@ -1,46 +1,26 @@
package socks5
import (
"encoding/binary"
"bytes"
"errors"
"fmt"
"io"
"log/slog"
"net"
"strings"
"syscall"
"time"
"github.com/hashicorp/golang-lru/v2/expirable"
"github.com/luyuhuang/subsocks/socks"
"github.com/sunbk201/ua3f/internal/config"
"github.com/sunbk201/ua3f/internal/rewrite"
"github.com/sunbk201/ua3f/internal/server/base"
)
// SOCKS5 constants
const (
socksVer5 = 0x05
socksNoAuth = 0x00
socksCmdConn = 0x01
socksCmdUDP = 0x03
socksATYPv4 = 0x01
socksATYDomain = 0x03
socksATYPv6 = 0x04
)
var (
ErrInvalidSocksVersion = errors.New("invalid socks version")
ErrInvalidSocksCmd = errors.New("invalid socks cmd")
)
// Server is a minimal SOCKS5 server that delegates HTTP UA rewriting to Rewriter.
type Server struct {
base.Server
listener net.Listener
}
// New returns a new Server with given config, rewriter, and version string.
func New(cfg *config.Config, rw *rewrite.Rewriter) *Server {
return &Server{
Server: base.Server{
@ -58,7 +38,6 @@ func (s *Server) Close() (err error) {
return
}
// Start begins listening for SOCKS5 clients.
func (s *Server) Start() (err error) {
if s.listener, err = net.Listen("tcp", s.Cfg.ListenAddr); err != nil {
return fmt.Errorf("net.Listen: %w", err)
@ -76,299 +55,278 @@ func (s *Server) Start() (err error) {
slog.Error("s.listener.Accept", slog.Any("error", err))
continue
}
slog.Debug("Accept connection", slog.String("addr", client.RemoteAddr().String()))
go s.HandleClient(client)
}
}()
return nil
}
// handleClient performs SOCKS5 negotiation and dispatches TCP/UDP handling.
func (s *Server) HandleClient(client net.Conn) {
// Handshake (no auth)
if err := s.socks5Auth(client); err != nil {
_ = client.Close()
func (s *Server) HandleClient(conn net.Conn) {
defer func() {
_ = conn.Close()
}()
srcAddr := conn.RemoteAddr().String()
slog.Info("New socks5 connection", slog.String("srcAddr", srcAddr))
if err := s.handShake(conn); err != nil {
slog.Error("s.handShake", slog.String("srcAddr", srcAddr), slog.Any("error", err))
return
}
destAddrPort, cmd, err := s.parseSocks5Request(client)
request, err := socks.ReadRequest(conn)
if err != nil {
if cmd == socksCmdUDP {
// UDP Associate
s.handleUDPAssociate(client)
_ = client.Close()
return
slog.Error("socks.ReadRequest", slog.String("srcAddr", srcAddr), slog.Any("error", err))
return
}
switch request.Cmd {
case socks.CmdConnect:
err = s.handleConnect(conn, request)
if err != nil {
err = fmt.Errorf("s.handleConnect: %w", err)
}
slog.Debug("ParseSocks5Request failed", slog.String("src", client.RemoteAddr().String()), slog.String("dst", destAddrPort), slog.Any("error", err))
_ = client.Close()
return
case socks.CmdBind:
err = s.handleBind(conn)
if err != nil {
err = fmt.Errorf("s.handleBind: %w", err)
}
case socks.CmdUDP:
err = s.handleUDPAssociate(conn)
if err != nil {
err = fmt.Errorf("s.handleUDPAssociate: %w", err)
}
default:
err = fmt.Errorf("socks5 unsupported command %d", request.Cmd)
}
// TCP CONNECT
target, err := s.socks5Connect(client, destAddrPort)
if err != nil {
slog.Warn("s.socks5Connect", slog.String("addr", destAddrPort), slog.Any("error", err))
_ = client.Close()
slog.Error("HandleClient", slog.String("srcAddr", srcAddr), slog.Any("error", err))
return
}
s.ServeConnLink(&base.ConnLink{
LConn: client,
RConn: target,
LAddr: client.RemoteAddr().String(),
RAddr: target.RemoteAddr().String(),
})
}
// socks5Auth performs a minimal "no-auth" negotiation.
func (s *Server) socks5Auth(client net.Conn) error {
buf := make([]byte, 256)
// Read VER, NMETHODS
n, err := io.ReadFull(client, buf[:2])
if n != 2 {
if errors.Is(err, io.EOF) {
slog.Warn("socks5Auth read EOF", slog.String("addr", client.RemoteAddr().String()))
} else {
slog.Error("socks5Auth read header", slog.String("addr", client.RemoteAddr().String()), slog.Any("error", err))
func (s *Server) handShake(conn net.Conn) error {
methods, err := socks.ReadMethods(conn)
if err != nil {
return fmt.Errorf("socks.ReadMethods: %w", err)
}
method := socks.MethodNoAcceptable
for _, m := range methods {
if m == socks.MethodNoAuth {
method = m
}
return fmt.Errorf("io.ReadFull reading header: %w", err)
}
ver, nMethods := int(buf[0]), int(buf[1])
if ver != socksVer5 {
slog.Error("socks5Auth invalid ver", slog.String("addr", client.RemoteAddr().String()))
return ErrInvalidSocksVersion
}
// Read METHODS
n, err = io.ReadFull(client, buf[:nMethods])
if n != nMethods {
slog.Error("socks5Auth read methods", slog.String("addr", client.RemoteAddr().String()), slog.Any("error", err))
return fmt.Errorf("io.ReadFull read methods: %w", err)
}
// Reply: no-auth
n, err = client.Write([]byte{socksVer5, socksNoAuth})
if n != 2 || err != nil {
slog.Error("socks5Auth write rsp", slog.String("addr", client.RemoteAddr().String()), slog.Any("error", err))
return fmt.Errorf("client.Write rsp: %w", err)
if err := socks.WriteMethod(socks.MethodNoAuth, conn); err != nil || method == socks.MethodNoAcceptable {
if err != nil {
return fmt.Errorf("socks.WriteMethod: %w", err)
} else {
return fmt.Errorf("socks5 methods is not acceptable")
}
}
return nil
}
// parseSocks5Request reads a single SOCKS5 request. Returns dest, cmd, and error.
func (s *Server) parseSocks5Request(client net.Conn) (string, byte, error) {
buf := make([]byte, 256)
func (s *Server) handleConnect(src net.Conn, req *socks.Request) error {
srcAddr := src.RemoteAddr().String()
destAddr := req.Addr.String()
// VER, CMD, RSV, ATYP
if _, err := io.ReadFull(client, buf[:4]); err != nil {
return "", 0, fmt.Errorf("read header: %w", err)
}
ver, cmd, atyp := buf[0], buf[1], buf[3]
if ver != socksVer5 {
return "", cmd, ErrInvalidSocksVersion
}
// UDP associate: let caller handle
if cmd == socksCmdUDP {
return "", socksCmdUDP, errors.New("UDP Associate")
}
if cmd != socksCmdConn {
return "", cmd, ErrInvalidSocksCmd
}
var addr string
switch atyp {
case socksATYPv4:
if _, err := io.ReadFull(client, buf[:4]); err != nil {
return "", cmd, fmt.Errorf("invalid IPv4: %w", err)
dest, err := net.Dial("tcp", destAddr)
if err != nil {
if err := socks.NewReply(socks.HostUnreachable, nil).Write(src); err != nil {
slog.Error("socks.NewReply.Write", slog.String("srcAddr", srcAddr), slog.Any("error", err))
}
addr = fmt.Sprintf("%d.%d.%d.%d", buf[0], buf[1], buf[2], buf[3])
case socksATYDomain:
if _, err := io.ReadFull(client, buf[:1]); err != nil {
return "", cmd, fmt.Errorf("invalid hostname(len): %w", err)
}
addrLen := int(buf[0])
if _, err := io.ReadFull(client, buf[:addrLen]); err != nil {
return "", cmd, fmt.Errorf("invalid hostname: %w", err)
}
addr = string(buf[:addrLen])
case socksATYPv6:
return "", cmd, errors.New("IPv6: not supported yet")
default:
return "", cmd, errors.New("invalid atyp")
return fmt.Errorf("net.Dial: %w, dest: %s", err, destAddr)
}
if _, err := io.ReadFull(client, buf[:2]); err != nil {
return "", cmd, fmt.Errorf("read port: %w", err)
if err := socks.NewReply(socks.Succeeded, nil).Write(src); err != nil {
_ = dest.Close()
return fmt.Errorf("socks.NewReply.Write: %w", err)
}
port := binary.BigEndian.Uint16(buf[:2])
return fmt.Sprintf("%s:%d", addr, port), cmd, nil
s.ServeConnLink(&base.ConnLink{
LConn: src,
RConn: dest,
LAddr: srcAddr,
RAddr: destAddr,
})
return nil
}
// socks5Connect dials the target and responds success to the client.
func (s *Server) socks5Connect(client net.Conn, destAddrPort string) (net.Conn, error) {
target, err := base.Connect(destAddrPort)
func (s *Server) handleBind(conn net.Conn) error {
srcAddr := conn.RemoteAddr().String()
listener, err := net.ListenTCP("tcp", nil)
if err != nil {
// Reply failure
_, _ = client.Write([]byte{socksVer5, 0x01, 0x00, socksATYPv4, 0, 0, 0, 0, 0, 0})
return nil, fmt.Errorf("dial target %s: %w", destAddrPort, err)
if err := socks.NewReply(socks.Failure, nil).Write(conn); err != nil {
slog.Error("socks.NewReply.Write", slog.String("srcAddr", srcAddr), slog.Any("error", err))
}
return fmt.Errorf("net.ListenTCP: %w", err)
}
// Reply success (bind set to 0.0.0.0:0)
if _, err = client.Write([]byte{socksVer5, 0x00, 0x00, socksATYPv4, 0, 0, 0, 0, 0, 0}); err != nil {
_ = target.Close()
return nil, err
}
return target, nil
}
// handleUDPAssociate handles a UDP ASSOCIATE request by creating a UDP relay socket.
// Only IPv4 and domain ATYP are supported (no IPv6).
func (s *Server) handleUDPAssociate(client net.Conn) {
udpServer, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
addr, _ := socks.NewAddrFromAddr(listener.Addr(), conn.LocalAddr())
if err := socks.NewReply(socks.Succeeded, addr).Write(conn); err != nil {
_ = listener.Close()
return fmt.Errorf("socks.NewReply.Write: %w", err)
}
newConn, err := listener.AcceptTCP()
_ = listener.Close()
if err != nil {
slog.Error("net.ListenUDP failed", slog.String("addr", client.RemoteAddr().String()), slog.Any("error", err))
return
if err := socks.NewReply(socks.Failure, nil).Write(conn); err != nil {
slog.Error("socks.NewReply.Write", slog.String("srcAddr", srcAddr), slog.Any("error", err))
}
return fmt.Errorf("listener.AcceptTCP: %w", err)
}
defer func() {
if err := udpServer.Close(); err != nil {
slog.Warn("udpServer.Close", slog.String("addr", client.RemoteAddr().String()), slog.Any("error", err))
_ = newConn.Close()
}()
raddr, _ := socks.NewAddr(newConn.RemoteAddr().String())
if err := socks.NewReply(socks.Succeeded, raddr).Write(conn); err != nil {
return fmt.Errorf("socks.NewReply.Write: %w", err)
}
s.ServeConnLink(&base.ConnLink{
LConn: conn,
RConn: newConn,
LAddr: srcAddr,
RAddr: newConn.RemoteAddr().String(),
})
return nil
}
func (s *Server) handleUDPAssociate(conn net.Conn) error {
srcAddr := conn.RemoteAddr().String()
udp, err := net.ListenUDP("udp", nil)
if err != nil {
if err := socks.NewReply(socks.Failure, nil).Write(conn); err != nil {
slog.Error("socks.NewReply.Write", slog.String("srcAddr", srcAddr), slog.Any("error", err))
}
return fmt.Errorf("net.ListenUDP: %w", err)
}
addr, _ := socks.NewAddrFromAddr(udp.LocalAddr(), conn.LocalAddr())
if err := socks.NewReply(socks.Succeeded, addr).Write(conn); err != nil {
_ = udp.Close()
return fmt.Errorf("socks.NewReply.Write: %w", err)
}
slog.Info("UDP associate established", slog.String("srcAddr", srcAddr), slog.String("udpAddr", udp.LocalAddr().String()))
s.tunnelUDP(conn, udp)
return nil
}
func (s *Server) tunnelUDP(conn net.Conn, udp *net.UDPConn) {
srcAddr := conn.RemoteAddr().String()
tcpRemote := conn.RemoteAddr().(*net.TCPAddr)
var clientUDPAddr *net.UDPAddr
done := make(chan struct{})
go func() {
defer func() {
_ = udp.Close()
}()
b := make([]byte, 64*1024)
for {
select {
case <-done:
return
default:
}
_ = udp.SetReadDeadline(time.Now().Add(time.Second * 30))
n, addr, err := udp.ReadFrom(b)
if err != nil {
if errors.Is(err, net.ErrClosed) {
return
}
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
continue
}
slog.Error("udp.ReadFrom", slog.String("srcAddr", srcAddr), slog.Any("error", err))
continue
}
udpAddr, ok := addr.(*net.UDPAddr)
if !ok {
continue
}
isFromClient := udpAddr.IP.Equal(tcpRemote.IP)
if isFromClient {
clientUDPAddr = udpAddr
dgram, err := socks.ReadUDPDatagram(bytes.NewReader(b[:n]))
if err != nil {
slog.Error("socks.ReadUDPDatagram error", slog.String("srcAddr", srcAddr), slog.Any("error", err))
continue
}
destAddr, err := net.ResolveUDPAddr("udp", dgram.Header.Addr.String())
if err != nil {
slog.Error("net.ResolveUDPAddr error",
slog.String("srcAddr", srcAddr),
slog.String("destAddr", dgram.Header.Addr.String()),
slog.Any("error", err))
continue
}
if _, err := udp.WriteTo(dgram.Data, destAddr); err != nil {
slog.Error("udp.WriteTo dest error",
slog.String("srcAddr", srcAddr),
slog.String("destAddr", destAddr.String()),
slog.Any("error", err))
continue
}
slog.Debug("UDP relay request",
slog.String("from", addr.String()),
slog.String("to", destAddr.String()),
slog.Int("bytes", len(dgram.Data)))
} else {
if clientUDPAddr == nil {
continue
}
saddr, _ := socks.NewAddr(addr.String())
dgram := socks.NewUDPDatagram(
socks.NewUDPHeader(0, 0, saddr), b[:n])
var writer bytes.Buffer
if err := dgram.Write(&writer); err != nil {
slog.Debug("dgram.Write error", slog.String("srcAddr", srcAddr), slog.Any("error", err))
continue
}
if _, err := udp.WriteTo(writer.Bytes(), clientUDPAddr); err != nil {
slog.Debug("udp.WriteTo client error", slog.String("srcAddr", srcAddr), slog.Any("error", err))
continue
}
slog.Debug("UDP relay response",
slog.String("from", addr.String()),
slog.String("to", clientUDPAddr.String()),
slog.Int("bytes", n))
}
}
}()
_, portStr, _ := net.SplitHostPort(udpServer.LocalAddr().String())
slog.Debug("net.SplitHostPort", slog.String("addr", client.RemoteAddr().String()), slog.String("port", portStr))
portInt, _ := net.LookupPort("udp", portStr)
portBytes := make([]byte, 2)
binary.BigEndian.PutUint16(portBytes, uint16(portInt))
// Reply with chosen UDP port (bind addr set to 0.0.0.0)
if _, err = client.Write([]byte{socksVer5, 0x00, 0x00, socksATYPv4, 0, 0, 0, 0, portBytes[0], portBytes[1]}); err != nil {
slog.Error("client.Write rsp", slog.String("addr", client.RemoteAddr().String()), slog.Any("error", err))
return
}
buf := make([]byte, 65535)
udpPortMap := make(map[string][]byte)
var clientAddr *net.UDPAddr
isDomain := false
// tcp connection monitor
b := make([]byte, 1)
for {
_ = udpServer.SetReadDeadline(time.Now().Add(10 * time.Second))
n, fromAddr, err := udpServer.ReadFromUDP(buf)
if err != nil {
if strings.Contains(err.Error(), "i/o timeout") {
slog.Debug("ReadFromUDP timeout", slog.String("addr", client.RemoteAddr().String()), slog.Any("error", err))
if !isAlive(client) {
slog.Debug("client is not alive", slog.String("addr", client.RemoteAddr().String()))
break
}
} else {
slog.Error("udpServer.ReadFromUDP failed", slog.String("addr", client.RemoteAddr().String()), slog.Any("error", err))
}
continue
}
if clientAddr == nil {
clientAddr = fromAddr
}
if clientAddr.IP.Equal(fromAddr.IP) && clientAddr.Port == fromAddr.Port {
// Packet from client -> forward to remote
atyp := buf[3]
var (
targetAddr string
targetPort uint16
payload []byte
header []byte
targetIP net.IP
)
switch atyp {
case socksATYPv4:
isDomain = false
targetAddr = fmt.Sprintf("%d.%d.%d.%d", buf[4], buf[5], buf[6], buf[7])
targetIP = net.ParseIP(targetAddr)
targetPort = binary.BigEndian.Uint16(buf[8:10])
payload = buf[10:n]
header = buf[0:10]
case socksATYDomain:
isDomain = true
addrLen := int(buf[4])
targetAddr = string(buf[5 : 5+addrLen])
targetIPAddr, err := net.ResolveIPAddr("ip", targetAddr)
if err != nil {
slog.Error("net.ResolveIPAddr", slog.String("addr", client.RemoteAddr().String()), slog.Any("error", err))
break
}
targetIP = targetIPAddr.IP
targetPort = binary.BigEndian.Uint16(buf[5+addrLen : 5+addrLen+2])
payload = buf[5+addrLen+2 : n]
header = buf[0 : 5+addrLen+2]
case socksATYPv6:
slog.Error("IPv6: not supported yet", slog.String("addr", client.RemoteAddr().String()))
return
default:
slog.Error("invalid atyp", slog.String("addr", client.RemoteAddr().String()))
return
}
remoteAddr := &net.UDPAddr{IP: targetIP, Port: int(targetPort)}
udpPortMap[remoteAddr.String()] = make([]byte, len(header))
copy(udpPortMap[remoteAddr.String()], header)
_ = udpServer.SetWriteDeadline(time.Now().Add(10 * time.Second))
if _, err = udpServer.WriteToUDP(payload, remoteAddr); err != nil {
slog.Debug("WriteToUDP to remote failed", slog.String("addr", client.RemoteAddr().String()), slog.Any("error", err))
continue
}
} else {
// Packet from remote -> forward to client (rebuild header)
header := udpPortMap[fromAddr.String()]
if header == nil {
slog.Error("udpPortMap invalid header", slog.String("addr", client.RemoteAddr().String()))
continue
}
// For domain ATYP, preserve original head section size
if isDomain {
header = header[0:4]
}
body := append(header, buf[:n]...)
if _, err = udpServer.WriteToUDP(body, clientAddr); err != nil {
slog.Debug("WriteToUDP to client failed", slog.String("addr", client.RemoteAddr().String()), slog.Any("error", err))
continue
}
_ = conn.SetReadDeadline(time.Now().Add(time.Minute))
if _, err := conn.Read(b); err != nil {
slog.Info("TCP connection closed, stopping UDP relay", slog.String("srcAddr", srcAddr), slog.String("udpAddr", udp.LocalAddr().String()))
close(done)
return
}
}
}
// isAlive checks if a connection is still alive using a short read deadline.
func isAlive(conn net.Conn) bool {
one := make([]byte, 1)
_ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
_, err := conn.Read(one)
if err != nil {
switch {
case errors.Is(err, io.EOF):
slog.Debug("isAlive: EOF", slog.String("addr", conn.RemoteAddr().String()))
return false
case strings.Contains(err.Error(), "use of closed network connection"):
slog.Debug("isAlive: closed", slog.String("addr", conn.RemoteAddr().String()))
return false
case strings.Contains(err.Error(), "i/o timeout"):
slog.Debug("isAlive: timeout", slog.String("addr", conn.RemoteAddr().String()))
return true
default:
slog.Debug("isAlive: error", slog.String("addr", conn.RemoteAddr().String()), slog.Any("error", err))
return false
}
}
_ = conn.SetReadDeadline(time.Time{})
return true
}