feat: add so_mark for socks5 and http

This commit is contained in:
SunBK201 2025-12-07 15:33:53 +08:00
parent db914c1f3c
commit 317dba5e78
6 changed files with 33 additions and 27 deletions

View File

@ -1,27 +1,15 @@
//go:build !linux
package base
import (
"errors"
"fmt"
"net"
)
func Connect(addr string) (target net.Conn, err error) {
func Connect(addr string, mark int) (target net.Conn, err error) {
if target, err = net.Dial("tcp", addr); err != nil {
return nil, fmt.Errorf("net.Dial: %v", err)
}
return target, nil
}
func GetConnFD(conn net.Conn) (fd int, err error) {
tcpConn, ok := conn.(*net.TCPConn)
if !ok {
return 0, errors.New("GetConnFD connection is not *net.TCPConn")
}
file, err := tcpConn.File()
if err != nil {
return 0, fmt.Errorf("tcpConn.File: %v", err)
}
return int(file.Fd()), nil
}

View File

@ -3,6 +3,7 @@
package base
import (
"errors"
"fmt"
"net"
"syscall"
@ -10,10 +11,8 @@ import (
"golang.org/x/sys/unix"
)
const SO_MARK = 0xc9
// ConnectWithMark dials the target address with SO_MARK set and returns the connection.
func ConnectWithMark(addr string, mark int) (target net.Conn, err error) {
// Connect dials the target address with SO_MARK set and returns the connection.
func Connect(addr string, mark int) (target net.Conn, err error) {
dialer := net.Dialer{
Control: func(network, address string, c syscall.RawConn) error {
return c.Control(func(fd uintptr) {
@ -24,7 +23,7 @@ func ConnectWithMark(addr string, mark int) (target net.Conn, err error) {
conn, err := dialer.Dial("tcp", addr)
if err != nil {
return nil, fmt.Errorf("ConnectWithMark dialer.Dial SO_MARK(%d): %v", mark, err)
return nil, fmt.Errorf("Connect dialer.Dial SO_MARK(%d): %v", mark, err)
}
return conn, nil
}
@ -44,3 +43,16 @@ func GetOriginalDstAddr(conn net.Conn) (addr string, err error) {
port := uint16(raw.Multiaddr[2])<<8 + uint16(raw.Multiaddr[3])
return fmt.Sprintf("%s:%d", ip.String(), port), nil
}
func GetConnFD(conn net.Conn) (fd int, err error) {
tcpConn, ok := conn.(*net.TCPConn)
if !ok {
return 0, errors.New("GetConnFD connection is not *net.TCPConn")
}
file, err := tcpConn.File()
if err != nil {
return 0, fmt.Errorf("tcpConn.File: %v", err)
}
return int(file.Fd()), nil
}

View File

@ -10,6 +10,7 @@ import (
"github.com/hashicorp/golang-lru/v2/expirable"
"github.com/sunbk201/ua3f/internal/config"
"github.com/sunbk201/ua3f/internal/log"
"github.com/sunbk201/ua3f/internal/netfilter"
"github.com/sunbk201/ua3f/internal/rewrite"
"github.com/sunbk201/ua3f/internal/rule"
"github.com/sunbk201/ua3f/internal/server/base"
@ -19,6 +20,7 @@ import (
type Server struct {
base.Server
so_mark int
}
func New(cfg *config.Config, rw *rewrite.Rewriter, rc *statistics.Recorder) *Server {
@ -29,6 +31,7 @@ func New(cfg *config.Config, rw *rewrite.Rewriter, rc *statistics.Recorder) *Ser
Recorder: rc,
Cache: expirable.NewLRU[string, struct{}](1024, nil, 30*time.Minute),
},
so_mark: netfilter.SO_MARK,
}
}
@ -116,7 +119,7 @@ func (s *Server) rewrite(req *http.Request, srcAddr, dstAddr string) (*http.Requ
func (s *Server) handleTunneling(w http.ResponseWriter, req *http.Request) {
slog.Info("HTTP CONNECT request", slog.String("host", req.Host))
destAddr := req.Host
dest, err := base.Connect(destAddr)
dest, err := base.Connect(destAddr, s.so_mark)
if err != nil {
http.Error(w, err.Error(), http.StatusServiceUnavailable)
return

View File

@ -99,10 +99,10 @@ func (s *Server) HandleClient(client net.Conn) {
return
}
target, err := base.ConnectWithMark(addr, s.so_mark)
target, err := base.Connect(addr, s.so_mark)
if err != nil {
_ = client.Close()
slog.Warn("base.ConnectWithMark", slog.String("addr", addr), slog.Any("error", err))
slog.Warn("base.Connect", slog.String("addr", addr), slog.Any("error", err))
return
}

View File

@ -12,6 +12,7 @@ import (
"github.com/hashicorp/golang-lru/v2/expirable"
"github.com/luyuhuang/subsocks/socks"
"github.com/sunbk201/ua3f/internal/config"
"github.com/sunbk201/ua3f/internal/netfilter"
"github.com/sunbk201/ua3f/internal/rewrite"
"github.com/sunbk201/ua3f/internal/server/base"
"github.com/sunbk201/ua3f/internal/statistics"
@ -20,6 +21,7 @@ import (
type Server struct {
base.Server
listener net.Listener
so_mark int
}
func New(cfg *config.Config, rw *rewrite.Rewriter, rc *statistics.Recorder) *Server {
@ -30,6 +32,7 @@ func New(cfg *config.Config, rw *rewrite.Rewriter, rc *statistics.Recorder) *Ser
Recorder: rc,
Cache: expirable.NewLRU[string, struct{}](1024, nil, 30*time.Minute),
},
so_mark: netfilter.SO_MARK,
}
}
@ -135,12 +138,12 @@ func (s *Server) handleConnect(src net.Conn, req *socks.Request) error {
srcAddr := src.RemoteAddr().String()
destAddr := req.Addr.String()
dest, err := net.Dial("tcp", destAddr)
dest, err := base.Connect(destAddr, s.so_mark)
if err != nil {
if err := socks.NewReply(socks.HostUnreachable, nil).Write(src); err != nil {
slog.Error("socks.NewReply.Write", slog.String("srcAddr", srcAddr), slog.Any("error", err))
}
return fmt.Errorf("net.Dial: %w, dest: %s", err, destAddr)
return fmt.Errorf("base.Connect: %w, dest: %s", err, destAddr)
}
if err := socks.NewReply(socks.Succeeded, nil).Write(src); err != nil {

View File

@ -130,10 +130,10 @@ func (s *Server) HandleClient(client net.Conn) {
return
}
target, err := base.ConnectWithMark(addr, s.so_mark)
target, err := base.Connect(addr, s.so_mark)
if err != nil {
_ = client.Close()
slog.Warn("base.ConnectWithMark", slog.String("addr", addr), slog.Any("error", err))
slog.Warn("base.Connect", slog.String("addr", addr), slog.Any("error", err))
return
}