mirror of
https://github.com/SunBK201/UA3F.git
synced 2025-12-16 16:57:08 +00:00
feat: add so_mark for socks5 and http
This commit is contained in:
parent
db914c1f3c
commit
317dba5e78
@ -1,27 +1,15 @@
|
||||
//go:build !linux
|
||||
|
||||
package base
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
)
|
||||
|
||||
func Connect(addr string) (target net.Conn, err error) {
|
||||
func Connect(addr string, mark int) (target net.Conn, err error) {
|
||||
if target, err = net.Dial("tcp", addr); err != nil {
|
||||
return nil, fmt.Errorf("net.Dial: %v", err)
|
||||
}
|
||||
return target, nil
|
||||
}
|
||||
|
||||
func GetConnFD(conn net.Conn) (fd int, err error) {
|
||||
tcpConn, ok := conn.(*net.TCPConn)
|
||||
if !ok {
|
||||
return 0, errors.New("GetConnFD connection is not *net.TCPConn")
|
||||
}
|
||||
file, err := tcpConn.File()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("tcpConn.File: %v", err)
|
||||
}
|
||||
|
||||
return int(file.Fd()), nil
|
||||
}
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
package base
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"syscall"
|
||||
@ -10,10 +11,8 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
const SO_MARK = 0xc9
|
||||
|
||||
// ConnectWithMark dials the target address with SO_MARK set and returns the connection.
|
||||
func ConnectWithMark(addr string, mark int) (target net.Conn, err error) {
|
||||
// Connect dials the target address with SO_MARK set and returns the connection.
|
||||
func Connect(addr string, mark int) (target net.Conn, err error) {
|
||||
dialer := net.Dialer{
|
||||
Control: func(network, address string, c syscall.RawConn) error {
|
||||
return c.Control(func(fd uintptr) {
|
||||
@ -24,7 +23,7 @@ func ConnectWithMark(addr string, mark int) (target net.Conn, err error) {
|
||||
|
||||
conn, err := dialer.Dial("tcp", addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ConnectWithMark dialer.Dial SO_MARK(%d): %v", mark, err)
|
||||
return nil, fmt.Errorf("Connect dialer.Dial SO_MARK(%d): %v", mark, err)
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
@ -44,3 +43,16 @@ func GetOriginalDstAddr(conn net.Conn) (addr string, err error) {
|
||||
port := uint16(raw.Multiaddr[2])<<8 + uint16(raw.Multiaddr[3])
|
||||
return fmt.Sprintf("%s:%d", ip.String(), port), nil
|
||||
}
|
||||
|
||||
func GetConnFD(conn net.Conn) (fd int, err error) {
|
||||
tcpConn, ok := conn.(*net.TCPConn)
|
||||
if !ok {
|
||||
return 0, errors.New("GetConnFD connection is not *net.TCPConn")
|
||||
}
|
||||
file, err := tcpConn.File()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("tcpConn.File: %v", err)
|
||||
}
|
||||
|
||||
return int(file.Fd()), nil
|
||||
}
|
||||
|
||||
@ -10,6 +10,7 @@ import (
|
||||
"github.com/hashicorp/golang-lru/v2/expirable"
|
||||
"github.com/sunbk201/ua3f/internal/config"
|
||||
"github.com/sunbk201/ua3f/internal/log"
|
||||
"github.com/sunbk201/ua3f/internal/netfilter"
|
||||
"github.com/sunbk201/ua3f/internal/rewrite"
|
||||
"github.com/sunbk201/ua3f/internal/rule"
|
||||
"github.com/sunbk201/ua3f/internal/server/base"
|
||||
@ -19,6 +20,7 @@ import (
|
||||
|
||||
type Server struct {
|
||||
base.Server
|
||||
so_mark int
|
||||
}
|
||||
|
||||
func New(cfg *config.Config, rw *rewrite.Rewriter, rc *statistics.Recorder) *Server {
|
||||
@ -29,6 +31,7 @@ func New(cfg *config.Config, rw *rewrite.Rewriter, rc *statistics.Recorder) *Ser
|
||||
Recorder: rc,
|
||||
Cache: expirable.NewLRU[string, struct{}](1024, nil, 30*time.Minute),
|
||||
},
|
||||
so_mark: netfilter.SO_MARK,
|
||||
}
|
||||
}
|
||||
|
||||
@ -116,7 +119,7 @@ func (s *Server) rewrite(req *http.Request, srcAddr, dstAddr string) (*http.Requ
|
||||
func (s *Server) handleTunneling(w http.ResponseWriter, req *http.Request) {
|
||||
slog.Info("HTTP CONNECT request", slog.String("host", req.Host))
|
||||
destAddr := req.Host
|
||||
dest, err := base.Connect(destAddr)
|
||||
dest, err := base.Connect(destAddr, s.so_mark)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusServiceUnavailable)
|
||||
return
|
||||
|
||||
@ -99,10 +99,10 @@ func (s *Server) HandleClient(client net.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
target, err := base.ConnectWithMark(addr, s.so_mark)
|
||||
target, err := base.Connect(addr, s.so_mark)
|
||||
if err != nil {
|
||||
_ = client.Close()
|
||||
slog.Warn("base.ConnectWithMark", slog.String("addr", addr), slog.Any("error", err))
|
||||
slog.Warn("base.Connect", slog.String("addr", addr), slog.Any("error", err))
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@ -12,6 +12,7 @@ import (
|
||||
"github.com/hashicorp/golang-lru/v2/expirable"
|
||||
"github.com/luyuhuang/subsocks/socks"
|
||||
"github.com/sunbk201/ua3f/internal/config"
|
||||
"github.com/sunbk201/ua3f/internal/netfilter"
|
||||
"github.com/sunbk201/ua3f/internal/rewrite"
|
||||
"github.com/sunbk201/ua3f/internal/server/base"
|
||||
"github.com/sunbk201/ua3f/internal/statistics"
|
||||
@ -20,6 +21,7 @@ import (
|
||||
type Server struct {
|
||||
base.Server
|
||||
listener net.Listener
|
||||
so_mark int
|
||||
}
|
||||
|
||||
func New(cfg *config.Config, rw *rewrite.Rewriter, rc *statistics.Recorder) *Server {
|
||||
@ -30,6 +32,7 @@ func New(cfg *config.Config, rw *rewrite.Rewriter, rc *statistics.Recorder) *Ser
|
||||
Recorder: rc,
|
||||
Cache: expirable.NewLRU[string, struct{}](1024, nil, 30*time.Minute),
|
||||
},
|
||||
so_mark: netfilter.SO_MARK,
|
||||
}
|
||||
}
|
||||
|
||||
@ -135,12 +138,12 @@ func (s *Server) handleConnect(src net.Conn, req *socks.Request) error {
|
||||
srcAddr := src.RemoteAddr().String()
|
||||
destAddr := req.Addr.String()
|
||||
|
||||
dest, err := net.Dial("tcp", destAddr)
|
||||
dest, err := base.Connect(destAddr, s.so_mark)
|
||||
if err != nil {
|
||||
if err := socks.NewReply(socks.HostUnreachable, nil).Write(src); err != nil {
|
||||
slog.Error("socks.NewReply.Write", slog.String("srcAddr", srcAddr), slog.Any("error", err))
|
||||
}
|
||||
return fmt.Errorf("net.Dial: %w, dest: %s", err, destAddr)
|
||||
return fmt.Errorf("base.Connect: %w, dest: %s", err, destAddr)
|
||||
}
|
||||
|
||||
if err := socks.NewReply(socks.Succeeded, nil).Write(src); err != nil {
|
||||
|
||||
@ -130,10 +130,10 @@ func (s *Server) HandleClient(client net.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
target, err := base.ConnectWithMark(addr, s.so_mark)
|
||||
target, err := base.Connect(addr, s.so_mark)
|
||||
if err != nil {
|
||||
_ = client.Close()
|
||||
slog.Warn("base.ConnectWithMark", slog.String("addr", addr), slog.Any("error", err))
|
||||
slog.Warn("base.Connect", slog.String("addr", addr), slog.Any("error", err))
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user