From 317dba5e78e93d242183dfbb58af3a773cf2e842 Mon Sep 17 00:00:00 2001 From: SunBK201 Date: Sun, 7 Dec 2025 15:33:53 +0800 Subject: [PATCH] feat: add so_mark for socks5 and http --- src/internal/server/base/tcp.go | 18 +++------------ src/internal/server/base/tcp_linux.go | 22 ++++++++++++++----- src/internal/server/http/http.go | 5 ++++- .../server/redirect/redirect_linux.go | 4 ++-- src/internal/server/socks5/socks5.go | 7 ++++-- src/internal/server/tproxy/tproxy_linux.go | 4 ++-- 6 files changed, 33 insertions(+), 27 deletions(-) diff --git a/src/internal/server/base/tcp.go b/src/internal/server/base/tcp.go index 06806df..d36d872 100644 --- a/src/internal/server/base/tcp.go +++ b/src/internal/server/base/tcp.go @@ -1,27 +1,15 @@ +//go:build !linux + package base import ( - "errors" "fmt" "net" ) -func Connect(addr string) (target net.Conn, err error) { +func Connect(addr string, mark int) (target net.Conn, err error) { if target, err = net.Dial("tcp", addr); err != nil { return nil, fmt.Errorf("net.Dial: %v", err) } return target, nil } - -func GetConnFD(conn net.Conn) (fd int, err error) { - tcpConn, ok := conn.(*net.TCPConn) - if !ok { - return 0, errors.New("GetConnFD connection is not *net.TCPConn") - } - file, err := tcpConn.File() - if err != nil { - return 0, fmt.Errorf("tcpConn.File: %v", err) - } - - return int(file.Fd()), nil -} diff --git a/src/internal/server/base/tcp_linux.go b/src/internal/server/base/tcp_linux.go index 2ade219..aa1a370 100644 --- a/src/internal/server/base/tcp_linux.go +++ b/src/internal/server/base/tcp_linux.go @@ -3,6 +3,7 @@ package base import ( + "errors" "fmt" "net" "syscall" @@ -10,10 +11,8 @@ import ( "golang.org/x/sys/unix" ) -const SO_MARK = 0xc9 - -// ConnectWithMark dials the target address with SO_MARK set and returns the connection. -func ConnectWithMark(addr string, mark int) (target net.Conn, err error) { +// Connect dials the target address with SO_MARK set and returns the connection. +func Connect(addr string, mark int) (target net.Conn, err error) { dialer := net.Dialer{ Control: func(network, address string, c syscall.RawConn) error { return c.Control(func(fd uintptr) { @@ -24,7 +23,7 @@ func ConnectWithMark(addr string, mark int) (target net.Conn, err error) { conn, err := dialer.Dial("tcp", addr) if err != nil { - return nil, fmt.Errorf("ConnectWithMark dialer.Dial SO_MARK(%d): %v", mark, err) + return nil, fmt.Errorf("Connect dialer.Dial SO_MARK(%d): %v", mark, err) } return conn, nil } @@ -44,3 +43,16 @@ func GetOriginalDstAddr(conn net.Conn) (addr string, err error) { port := uint16(raw.Multiaddr[2])<<8 + uint16(raw.Multiaddr[3]) return fmt.Sprintf("%s:%d", ip.String(), port), nil } + +func GetConnFD(conn net.Conn) (fd int, err error) { + tcpConn, ok := conn.(*net.TCPConn) + if !ok { + return 0, errors.New("GetConnFD connection is not *net.TCPConn") + } + file, err := tcpConn.File() + if err != nil { + return 0, fmt.Errorf("tcpConn.File: %v", err) + } + + return int(file.Fd()), nil +} diff --git a/src/internal/server/http/http.go b/src/internal/server/http/http.go index 059021a..cdbdd1d 100644 --- a/src/internal/server/http/http.go +++ b/src/internal/server/http/http.go @@ -10,6 +10,7 @@ import ( "github.com/hashicorp/golang-lru/v2/expirable" "github.com/sunbk201/ua3f/internal/config" "github.com/sunbk201/ua3f/internal/log" + "github.com/sunbk201/ua3f/internal/netfilter" "github.com/sunbk201/ua3f/internal/rewrite" "github.com/sunbk201/ua3f/internal/rule" "github.com/sunbk201/ua3f/internal/server/base" @@ -19,6 +20,7 @@ import ( type Server struct { base.Server + so_mark int } func New(cfg *config.Config, rw *rewrite.Rewriter, rc *statistics.Recorder) *Server { @@ -29,6 +31,7 @@ func New(cfg *config.Config, rw *rewrite.Rewriter, rc *statistics.Recorder) *Ser Recorder: rc, Cache: expirable.NewLRU[string, struct{}](1024, nil, 30*time.Minute), }, + so_mark: netfilter.SO_MARK, } } @@ -116,7 +119,7 @@ func (s *Server) rewrite(req *http.Request, srcAddr, dstAddr string) (*http.Requ func (s *Server) handleTunneling(w http.ResponseWriter, req *http.Request) { slog.Info("HTTP CONNECT request", slog.String("host", req.Host)) destAddr := req.Host - dest, err := base.Connect(destAddr) + dest, err := base.Connect(destAddr, s.so_mark) if err != nil { http.Error(w, err.Error(), http.StatusServiceUnavailable) return diff --git a/src/internal/server/redirect/redirect_linux.go b/src/internal/server/redirect/redirect_linux.go index b91a396..85a6d35 100644 --- a/src/internal/server/redirect/redirect_linux.go +++ b/src/internal/server/redirect/redirect_linux.go @@ -99,10 +99,10 @@ func (s *Server) HandleClient(client net.Conn) { return } - target, err := base.ConnectWithMark(addr, s.so_mark) + target, err := base.Connect(addr, s.so_mark) if err != nil { _ = client.Close() - slog.Warn("base.ConnectWithMark", slog.String("addr", addr), slog.Any("error", err)) + slog.Warn("base.Connect", slog.String("addr", addr), slog.Any("error", err)) return } diff --git a/src/internal/server/socks5/socks5.go b/src/internal/server/socks5/socks5.go index 3f1b5c4..206c17f 100644 --- a/src/internal/server/socks5/socks5.go +++ b/src/internal/server/socks5/socks5.go @@ -12,6 +12,7 @@ import ( "github.com/hashicorp/golang-lru/v2/expirable" "github.com/luyuhuang/subsocks/socks" "github.com/sunbk201/ua3f/internal/config" + "github.com/sunbk201/ua3f/internal/netfilter" "github.com/sunbk201/ua3f/internal/rewrite" "github.com/sunbk201/ua3f/internal/server/base" "github.com/sunbk201/ua3f/internal/statistics" @@ -20,6 +21,7 @@ import ( type Server struct { base.Server listener net.Listener + so_mark int } func New(cfg *config.Config, rw *rewrite.Rewriter, rc *statistics.Recorder) *Server { @@ -30,6 +32,7 @@ func New(cfg *config.Config, rw *rewrite.Rewriter, rc *statistics.Recorder) *Ser Recorder: rc, Cache: expirable.NewLRU[string, struct{}](1024, nil, 30*time.Minute), }, + so_mark: netfilter.SO_MARK, } } @@ -135,12 +138,12 @@ func (s *Server) handleConnect(src net.Conn, req *socks.Request) error { srcAddr := src.RemoteAddr().String() destAddr := req.Addr.String() - dest, err := net.Dial("tcp", destAddr) + dest, err := base.Connect(destAddr, s.so_mark) if err != nil { if err := socks.NewReply(socks.HostUnreachable, nil).Write(src); err != nil { slog.Error("socks.NewReply.Write", slog.String("srcAddr", srcAddr), slog.Any("error", err)) } - return fmt.Errorf("net.Dial: %w, dest: %s", err, destAddr) + return fmt.Errorf("base.Connect: %w, dest: %s", err, destAddr) } if err := socks.NewReply(socks.Succeeded, nil).Write(src); err != nil { diff --git a/src/internal/server/tproxy/tproxy_linux.go b/src/internal/server/tproxy/tproxy_linux.go index e15a49b..475181a 100644 --- a/src/internal/server/tproxy/tproxy_linux.go +++ b/src/internal/server/tproxy/tproxy_linux.go @@ -130,10 +130,10 @@ func (s *Server) HandleClient(client net.Conn) { return } - target, err := base.ConnectWithMark(addr, s.so_mark) + target, err := base.Connect(addr, s.so_mark) if err != nil { _ = client.Close() - slog.Warn("base.ConnectWithMark", slog.String("addr", addr), slog.Any("error", err)) + slog.Warn("base.Connect", slog.String("addr", addr), slog.Any("error", err)) return }