feat: add so_mark for socks5 and http

This commit is contained in:
SunBK201 2025-12-07 15:33:53 +08:00
parent db914c1f3c
commit 317dba5e78
6 changed files with 33 additions and 27 deletions

View File

@ -1,27 +1,15 @@
//go:build !linux
package base package base
import ( import (
"errors"
"fmt" "fmt"
"net" "net"
) )
func Connect(addr string) (target net.Conn, err error) { func Connect(addr string, mark int) (target net.Conn, err error) {
if target, err = net.Dial("tcp", addr); err != nil { if target, err = net.Dial("tcp", addr); err != nil {
return nil, fmt.Errorf("net.Dial: %v", err) return nil, fmt.Errorf("net.Dial: %v", err)
} }
return target, nil return target, nil
} }
func GetConnFD(conn net.Conn) (fd int, err error) {
tcpConn, ok := conn.(*net.TCPConn)
if !ok {
return 0, errors.New("GetConnFD connection is not *net.TCPConn")
}
file, err := tcpConn.File()
if err != nil {
return 0, fmt.Errorf("tcpConn.File: %v", err)
}
return int(file.Fd()), nil
}

View File

@ -3,6 +3,7 @@
package base package base
import ( import (
"errors"
"fmt" "fmt"
"net" "net"
"syscall" "syscall"
@ -10,10 +11,8 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
const SO_MARK = 0xc9 // Connect dials the target address with SO_MARK set and returns the connection.
func Connect(addr string, mark int) (target net.Conn, err error) {
// ConnectWithMark dials the target address with SO_MARK set and returns the connection.
func ConnectWithMark(addr string, mark int) (target net.Conn, err error) {
dialer := net.Dialer{ dialer := net.Dialer{
Control: func(network, address string, c syscall.RawConn) error { Control: func(network, address string, c syscall.RawConn) error {
return c.Control(func(fd uintptr) { return c.Control(func(fd uintptr) {
@ -24,7 +23,7 @@ func ConnectWithMark(addr string, mark int) (target net.Conn, err error) {
conn, err := dialer.Dial("tcp", addr) conn, err := dialer.Dial("tcp", addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("ConnectWithMark dialer.Dial SO_MARK(%d): %v", mark, err) return nil, fmt.Errorf("Connect dialer.Dial SO_MARK(%d): %v", mark, err)
} }
return conn, nil return conn, nil
} }
@ -44,3 +43,16 @@ func GetOriginalDstAddr(conn net.Conn) (addr string, err error) {
port := uint16(raw.Multiaddr[2])<<8 + uint16(raw.Multiaddr[3]) port := uint16(raw.Multiaddr[2])<<8 + uint16(raw.Multiaddr[3])
return fmt.Sprintf("%s:%d", ip.String(), port), nil return fmt.Sprintf("%s:%d", ip.String(), port), nil
} }
func GetConnFD(conn net.Conn) (fd int, err error) {
tcpConn, ok := conn.(*net.TCPConn)
if !ok {
return 0, errors.New("GetConnFD connection is not *net.TCPConn")
}
file, err := tcpConn.File()
if err != nil {
return 0, fmt.Errorf("tcpConn.File: %v", err)
}
return int(file.Fd()), nil
}

View File

@ -10,6 +10,7 @@ import (
"github.com/hashicorp/golang-lru/v2/expirable" "github.com/hashicorp/golang-lru/v2/expirable"
"github.com/sunbk201/ua3f/internal/config" "github.com/sunbk201/ua3f/internal/config"
"github.com/sunbk201/ua3f/internal/log" "github.com/sunbk201/ua3f/internal/log"
"github.com/sunbk201/ua3f/internal/netfilter"
"github.com/sunbk201/ua3f/internal/rewrite" "github.com/sunbk201/ua3f/internal/rewrite"
"github.com/sunbk201/ua3f/internal/rule" "github.com/sunbk201/ua3f/internal/rule"
"github.com/sunbk201/ua3f/internal/server/base" "github.com/sunbk201/ua3f/internal/server/base"
@ -19,6 +20,7 @@ import (
type Server struct { type Server struct {
base.Server base.Server
so_mark int
} }
func New(cfg *config.Config, rw *rewrite.Rewriter, rc *statistics.Recorder) *Server { func New(cfg *config.Config, rw *rewrite.Rewriter, rc *statistics.Recorder) *Server {
@ -29,6 +31,7 @@ func New(cfg *config.Config, rw *rewrite.Rewriter, rc *statistics.Recorder) *Ser
Recorder: rc, Recorder: rc,
Cache: expirable.NewLRU[string, struct{}](1024, nil, 30*time.Minute), Cache: expirable.NewLRU[string, struct{}](1024, nil, 30*time.Minute),
}, },
so_mark: netfilter.SO_MARK,
} }
} }
@ -116,7 +119,7 @@ func (s *Server) rewrite(req *http.Request, srcAddr, dstAddr string) (*http.Requ
func (s *Server) handleTunneling(w http.ResponseWriter, req *http.Request) { func (s *Server) handleTunneling(w http.ResponseWriter, req *http.Request) {
slog.Info("HTTP CONNECT request", slog.String("host", req.Host)) slog.Info("HTTP CONNECT request", slog.String("host", req.Host))
destAddr := req.Host destAddr := req.Host
dest, err := base.Connect(destAddr) dest, err := base.Connect(destAddr, s.so_mark)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusServiceUnavailable) http.Error(w, err.Error(), http.StatusServiceUnavailable)
return return

View File

@ -99,10 +99,10 @@ func (s *Server) HandleClient(client net.Conn) {
return return
} }
target, err := base.ConnectWithMark(addr, s.so_mark) target, err := base.Connect(addr, s.so_mark)
if err != nil { if err != nil {
_ = client.Close() _ = client.Close()
slog.Warn("base.ConnectWithMark", slog.String("addr", addr), slog.Any("error", err)) slog.Warn("base.Connect", slog.String("addr", addr), slog.Any("error", err))
return return
} }

View File

@ -12,6 +12,7 @@ import (
"github.com/hashicorp/golang-lru/v2/expirable" "github.com/hashicorp/golang-lru/v2/expirable"
"github.com/luyuhuang/subsocks/socks" "github.com/luyuhuang/subsocks/socks"
"github.com/sunbk201/ua3f/internal/config" "github.com/sunbk201/ua3f/internal/config"
"github.com/sunbk201/ua3f/internal/netfilter"
"github.com/sunbk201/ua3f/internal/rewrite" "github.com/sunbk201/ua3f/internal/rewrite"
"github.com/sunbk201/ua3f/internal/server/base" "github.com/sunbk201/ua3f/internal/server/base"
"github.com/sunbk201/ua3f/internal/statistics" "github.com/sunbk201/ua3f/internal/statistics"
@ -20,6 +21,7 @@ import (
type Server struct { type Server struct {
base.Server base.Server
listener net.Listener listener net.Listener
so_mark int
} }
func New(cfg *config.Config, rw *rewrite.Rewriter, rc *statistics.Recorder) *Server { func New(cfg *config.Config, rw *rewrite.Rewriter, rc *statistics.Recorder) *Server {
@ -30,6 +32,7 @@ func New(cfg *config.Config, rw *rewrite.Rewriter, rc *statistics.Recorder) *Ser
Recorder: rc, Recorder: rc,
Cache: expirable.NewLRU[string, struct{}](1024, nil, 30*time.Minute), Cache: expirable.NewLRU[string, struct{}](1024, nil, 30*time.Minute),
}, },
so_mark: netfilter.SO_MARK,
} }
} }
@ -135,12 +138,12 @@ func (s *Server) handleConnect(src net.Conn, req *socks.Request) error {
srcAddr := src.RemoteAddr().String() srcAddr := src.RemoteAddr().String()
destAddr := req.Addr.String() destAddr := req.Addr.String()
dest, err := net.Dial("tcp", destAddr) dest, err := base.Connect(destAddr, s.so_mark)
if err != nil { if err != nil {
if err := socks.NewReply(socks.HostUnreachable, nil).Write(src); err != nil { if err := socks.NewReply(socks.HostUnreachable, nil).Write(src); err != nil {
slog.Error("socks.NewReply.Write", slog.String("srcAddr", srcAddr), slog.Any("error", err)) slog.Error("socks.NewReply.Write", slog.String("srcAddr", srcAddr), slog.Any("error", err))
} }
return fmt.Errorf("net.Dial: %w, dest: %s", err, destAddr) return fmt.Errorf("base.Connect: %w, dest: %s", err, destAddr)
} }
if err := socks.NewReply(socks.Succeeded, nil).Write(src); err != nil { if err := socks.NewReply(socks.Succeeded, nil).Write(src); err != nil {

View File

@ -130,10 +130,10 @@ func (s *Server) HandleClient(client net.Conn) {
return return
} }
target, err := base.ConnectWithMark(addr, s.so_mark) target, err := base.Connect(addr, s.so_mark)
if err != nil { if err != nil {
_ = client.Close() _ = client.Close()
slog.Warn("base.ConnectWithMark", slog.String("addr", addr), slog.Any("error", err)) slog.Warn("base.Connect", slog.String("addr", addr), slog.Any("error", err))
return return
} }