From 153628808dd269eb4611cf214e8b747833fba62b Mon Sep 17 00:00:00 2001 From: SunBK201 Date: Mon, 27 Oct 2025 18:06:23 +0800 Subject: [PATCH] refactor: code tidy --- src/internal/config/config.go | 43 +++ src/internal/log/log.go | 10 + src/internal/rewrite/rewriter.go | 229 ++++++++++++ src/internal/{ => server}/http/http.go | 0 src/internal/server/socks5/socks5.go | 399 ++++++++++++++++++++ src/main.go | 482 +------------------------ 6 files changed, 695 insertions(+), 468 deletions(-) create mode 100644 src/internal/config/config.go create mode 100644 src/internal/rewrite/rewriter.go rename src/internal/{ => server}/http/http.go (100%) create mode 100644 src/internal/server/socks5/socks5.go diff --git a/src/internal/config/config.go b/src/internal/config/config.go new file mode 100644 index 0000000..a2da344 --- /dev/null +++ b/src/internal/config/config.go @@ -0,0 +1,43 @@ +package config + +import "flag" + +type Config struct { + BindAddr string + Port int + LogLevel string + PayloadUA string + UAPattern string + EnablePartialReplace bool +} + +func Parse() (*Config, bool) { + var ( + bindAddr string + port int + loglevel string + payloadUA string + uaPattern string + partial bool + showVer bool + ) + + flag.StringVar(&bindAddr, "b", "127.0.0.1", "bind address (default: 127.0.0.1)") + flag.IntVar(&port, "p", 1080, "port") + flag.StringVar(&payloadUA, "f", "FFF", "User-Agent") + flag.StringVar(&uaPattern, "r", "", "UA-Pattern") + flag.BoolVar(&partial, "s", false, "Enable Regex Partial Replace") + flag.StringVar(&loglevel, "l", "info", "Log level (default: info)") + flag.BoolVar(&showVer, "v", false, "show version") + flag.Parse() + + cfg := &Config{ + BindAddr: bindAddr, + Port: port, + LogLevel: loglevel, + PayloadUA: payloadUA, + UAPattern: uaPattern, + EnablePartialReplace: partial, + } + return cfg, showVer +} diff --git a/src/internal/log/log.go b/src/internal/log/log.go index bbea0ce..8cfcffc 100644 --- a/src/internal/log/log.go +++ b/src/internal/log/log.go @@ -8,6 +8,7 @@ import ( "strings" "github.com/sirupsen/logrus" + "github.com/sunbk201/ua3f/internal/config" "gopkg.in/natefinch/lumberjack.v2" ) @@ -61,3 +62,12 @@ func SetLogConf(level string) { logrus.SetLevel(logrus.InfoLevel) } } + +func LogHeader(version string, addr string, cfg *config.Config) { + logrus.Info("UA3F v" + version) + logrus.Infof("Listen on %s", addr) + logrus.Infof("User-Agent: %s", cfg.PayloadUA) + logrus.Infof("User-Agent Regex Pattern: '%s'", cfg.UAPattern) + logrus.Infof("Enable Partial Replace: %v", cfg.EnablePartialReplace) + logrus.Infof("Log level: %s", cfg.LogLevel) +} diff --git a/src/internal/rewrite/rewriter.go b/src/internal/rewrite/rewriter.go new file mode 100644 index 0000000..7ddcdcc --- /dev/null +++ b/src/internal/rewrite/rewriter.go @@ -0,0 +1,229 @@ +package rewrite + +import ( + "bufio" + "errors" + "io" + "net" + "net/http" + "strings" + "time" + + "github.com/dlclark/regexp2" + "github.com/hashicorp/golang-lru/v2/expirable" + "github.com/sirupsen/logrus" + + "github.com/sunbk201/ua3f/internal/config" + "github.com/sunbk201/ua3f/internal/statistics" +) + +const ( + ErrUseClosedConn = "use of closed network connection" + ErrConnResetByPeer = "connection reset by peer" + ErrIOTimeout = "i/o timeout" +) + +// HTTP methods used to detect HTTP by request line. +var httpMethods = []string{"GET", "POST", "HEAD", "PUT", "DELETE", "OPTIONS", "TRACE", "CONNECT"} + +// Hardcoded whitelist of UAs that should be left untouched. +var defaultWhitelist = []string{ + "MicroMessenger Client", + "ByteDancePcdn", + "Go-http-client/1.1", +} + +// Rewriter encapsulates HTTP UA rewrite behavior and pass-through cache. +type Rewriter struct { + payloadUA string + pattern string + enablePartialReplace bool + + uaRegex *regexp2.Regexp + cache *expirable.LRU[string, string] + whitelist map[string]struct{} +} + +// New constructs a Rewriter from config. Compiles regex and allocates cache. +func New(cfg *config.Config) (*Rewriter, error) { + // UA pattern is compiled with case-insensitive prefix (?i) + pattern := "(?i)" + cfg.UAPattern + uaRegex, err := regexp2.Compile(pattern, regexp2.None) + if err != nil { + return nil, err + } + + cache := expirable.NewLRU[string, string](300, nil, 10*time.Minute) + + whitelist := make(map[string]struct{}, len(defaultWhitelist)) + for _, s := range defaultWhitelist { + whitelist[s] = struct{}{} + } + + return &Rewriter{ + payloadUA: cfg.PayloadUA, + pattern: cfg.UAPattern, + enablePartialReplace: cfg.EnablePartialReplace, + uaRegex: uaRegex, + cache: cache, + whitelist: whitelist, + }, nil +} + +// ProxyHTTPOrRaw reads traffic from src and writes to dst. +// - If target in LRU cache: pass-through (raw). +// - Else if HTTP: rewrite UA (unless whitelisted or pattern not matched). +// - Else: mark target in LRU and pass-through. +func (r *Rewriter) ProxyHTTPOrRaw(dst net.Conn, src net.Conn, destAddrPort string) error { + // Fast path: known pass-through + if r.cache.Contains(destAddrPort) { + logrus.Debugf("Hit LRU Relay Cache: %s", destAddrPort) + _, _ = io.Copy(dst, src) + return nil + } + + reader := bufio.NewReader(src) + + isHTTP, err := r.isHTTP(reader) + if err != nil { + if strings.Contains(err.Error(), ErrUseClosedConn) { + logrus.Warnf("[%s] isHTTP error: %s", destAddrPort, err.Error()) + return err + } + // Other read errors terminate the direction. + return err + } + + if !isHTTP { + r.cache.Add(destAddrPort, destAddrPort) + logrus.Debugf("Not HTTP, Add LRU Relay Cache: %s, Cache Len: %d", destAddrPort, r.cache.Len()) + _, _ = io.Copy(dst, reader) + return nil + } + + srcAddr := src.RemoteAddr().String() + + // HTTP request loop (handles keep-alive) + for { + req, err := http.ReadRequest(reader) + if err != nil { + r.logReadErr(destAddrPort, src, err) + return err + } + + originalUA := req.Header.Get("User-Agent") + + // No UA header: pass-through after writing this first request + if originalUA == "" { + r.cache.Add(destAddrPort, destAddrPort) + logrus.Debugf("[%s] Not found User-Agent, Add LRU Relay Cache, Cache Len: %d", + destAddrPort, r.cache.Len()) + if err := req.Write(dst); err != nil { + logrus.Errorf("[%s][%s] write error: %s", destAddrPort, srcAddr, err.Error()) + return err + } + _, _ = io.Copy(dst, reader) + return nil + } + + isWhitelist := r.inWhitelist(originalUA) + matches := true + if r.pattern != "" { + matches, err = r.uaRegex.MatchString(originalUA) + if err != nil { + logrus.Errorf("[%s][%s] User-Agent Regex Pattern Match Error: %s", + destAddrPort, srcAddr, err.Error()) + matches = true + } + } + + // If UA is whitelisted or does not match target pattern, write once then pass-through. + if isWhitelist || !matches { + if !matches { + logrus.Debugf("[%s][%s] Not Hit User-Agent Pattern: %s", + destAddrPort, srcAddr, originalUA) + } + if isWhitelist { + logrus.Debugf("[%s][%s] Hit User-Agent Whitelist: %s, Add LRU Relay Cache, Cache Len: %d", + destAddrPort, srcAddr, originalUA, r.cache.Len()) + r.cache.Add(destAddrPort, destAddrPort) + } + if err := req.Write(dst); err != nil { + logrus.Errorf("[%s][%s] write error: %s", destAddrPort, srcAddr, err.Error()) + return err + } + _, _ = io.Copy(dst, reader) + return nil + } + + // Rewrite UA and forward the request (including body) + logrus.Debugf("[%s][%s] Hit User-Agent: %s", destAddrPort, srcAddr, originalUA) + mockedUA := r.buildNewUA(originalUA) + req.Header.Set("User-Agent", mockedUA) + if err := req.Write(dst); err != nil { + logrus.Errorf("[%s][%s] write error after replace user-agent: %s", + destAddrPort, srcAddr, err.Error()) + return err + } + + statistics.AddStat(&statistics.StatRecord{ + Host: destAddrPort, + OriginUA: originalUA, + MockedUA: mockedUA, + }) + } +} + +// isHTTP peeks the first few bytes and checks for a known HTTP method prefix. +func (r *Rewriter) isHTTP(reader *bufio.Reader) (bool, error) { + buf, err := reader.Peek(7) + if err != nil { + if strings.Contains(err.Error(), "EOF") { + logrus.Debugf("Peek EOF: %s", err.Error()) + } else { + logrus.Errorf("Peek error: %s", err.Error()) + } + return false, err + } + hint := string(buf) + for _, m := range httpMethods { + if strings.HasPrefix(hint, m) { + return true, nil + } + } + return false, nil +} + +// buildNewUA returns either a partial replacement (regex) or full overwrite. +func (r *Rewriter) buildNewUA(originUA string) string { + if r.enablePartialReplace && r.uaRegex != nil { + newUA, err := r.uaRegex.Replace(originUA, r.payloadUA, -1, -1) + if err != nil { + logrus.Errorf("User-Agent Replace Error: %s, use full overwrite", err.Error()) + return r.payloadUA + } + return newUA + } + return r.payloadUA +} + +func (r *Rewriter) inWhitelist(ua string) bool { + _, ok := r.whitelist[ua] + return ok +} + +func (r *Rewriter) logReadErr(destAddrPort string, src net.Conn, err error) { + remote := src.RemoteAddr().String() + switch { + case errors.Is(err, io.EOF): + logrus.Debugf("[%s][%s] read EOF in first phase", destAddrPort, remote) + case strings.Contains(err.Error(), ErrUseClosedConn): + logrus.Debugf("[%s][%s] read closed in first phase: %s", destAddrPort, remote, err.Error()) + case strings.Contains(err.Error(), ErrConnResetByPeer): + logrus.Debugf("[%s][%s] read reset in first phase: %s", destAddrPort, remote, err.Error()) + case strings.Contains(err.Error(), ErrIOTimeout): + logrus.Debugf("[%s][%s] read timeout in first phase: %s", destAddrPort, remote, err.Error()) + default: + logrus.Errorf("[%s][%s] read error in first phase: %s", destAddrPort, remote, err.Error()) + } +} diff --git a/src/internal/http/http.go b/src/internal/server/http/http.go similarity index 100% rename from src/internal/http/http.go rename to src/internal/server/http/http.go diff --git a/src/internal/server/socks5/socks5.go b/src/internal/server/socks5/socks5.go new file mode 100644 index 0000000..bfc4242 --- /dev/null +++ b/src/internal/server/socks5/socks5.go @@ -0,0 +1,399 @@ +package socks5 + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + "net" + "strings" + "time" + + "github.com/sirupsen/logrus" + "github.com/sunbk201/ua3f/internal/config" + "github.com/sunbk201/ua3f/internal/rewrite" + "github.com/sunbk201/ua3f/internal/statistics" +) + +// SOCKS5 constants +const ( + socksVer5 = 0x05 + socksNoAuth = 0x00 + socksCmdConn = 0x01 + socksCmdUDP = 0x03 + + socksATYPv4 = 0x01 + socksATYDomain = 0x03 + socksATYPv6 = 0x04 +) + +var ( + ErrInvalidSocksVersion = errors.New("invalid socks version") + ErrInvalidSocksCmd = errors.New("invalid socks cmd") +) + +// Server is a minimal SOCKS5 server that delegates HTTP UA rewriting to Rewriter. +type Server struct { + cfg *config.Config + rw *rewrite.Rewriter + listener net.Listener + ListenAddr string +} + +// New returns a new Server with given config, rewriter, and version string. +func New(cfg *config.Config, rw *rewrite.Rewriter) *Server { + return &Server{ + cfg: cfg, + rw: rw, + ListenAddr: fmt.Sprintf("%s:%d", cfg.BindAddr, cfg.Port), + } +} + +// Start begins listening for SOCKS5 clients. +func (s *Server) Start() (err error) { + if s.listener, err = net.Listen("tcp", s.ListenAddr); err != nil { + return fmt.Errorf("listen failed: %w", err) + } + + // Start statistics worker + go statistics.StartStatWorker() + + var client net.Conn + for { + if client, err = s.listener.Accept(); err != nil { + logrus.Error("Accept failed: ", err) + continue + } + logrus.Debugf("Accept connection from %s", client.RemoteAddr().String()) + go s.handleClient(client) + } +} + +// handleClient performs SOCKS5 negotiation and dispatches TCP/UDP handling. +func (s *Server) handleClient(client net.Conn) { + // Handshake (no auth) + if err := s.socks5Auth(client); err != nil { + _ = client.Close() + return + } + + destAddrPort, cmd, err := s.parseSocks5Request(client) + if err != nil { + if cmd == socksCmdUDP { + // UDP Associate + s.handleUDPAssociate(client) + _ = client.Close() + return + } + logrus.Debugf("[%s][%s] ParseSocks5Request failed: %s", + client.RemoteAddr().String(), destAddrPort, err.Error()) + _ = client.Close() + return + } + + // TCP CONNECT + target, err := s.socks5Connect(client, destAddrPort) + if err != nil { + logrus.Debug("Connect failed: ", err) + _ = client.Close() + return + } + s.forwardTCP(client, target, destAddrPort) +} + +// socks5Auth performs a minimal "no-auth" negotiation. +func (s *Server) socks5Auth(client net.Conn) error { + buf := make([]byte, 256) + + // Read VER, NMETHODS + n, err := io.ReadFull(client, buf[:2]) + if n != 2 { + if errors.Is(err, io.EOF) { + logrus.Warnf("[%s][Auth] read EOF", client.RemoteAddr().String()) + } else { + logrus.Errorf("[%s][Auth] read header: %v", client.RemoteAddr().String(), err) + } + return fmt.Errorf("reading header: %w", err) + } + ver, nMethods := int(buf[0]), int(buf[1]) + if ver != socksVer5 { + logrus.Errorf("[%s][Auth] invalid ver", client.RemoteAddr().String()) + return ErrInvalidSocksVersion + } + + // Read METHODS + n, err = io.ReadFull(client, buf[:nMethods]) + if n != nMethods { + logrus.Errorf("[%s][Auth] read methods: %v", client.RemoteAddr().String(), err) + return fmt.Errorf("read methods: %w", err) + } + + // Reply: no-auth + n, err = client.Write([]byte{socksVer5, socksNoAuth}) + if n != 2 || err != nil { + logrus.Errorf("[%s][Auth] write rsp: %v", client.RemoteAddr().String(), err) + return fmt.Errorf("write rsp: %w", err) + } + return nil +} + +// parseSocks5Request reads a single SOCKS5 request. Returns dest, cmd, and error. +func (s *Server) parseSocks5Request(client net.Conn) (string, byte, error) { + buf := make([]byte, 256) + + // VER, CMD, RSV, ATYP + if _, err := io.ReadFull(client, buf[:4]); err != nil { + return "", 0, fmt.Errorf("read header: %w", err) + } + ver, cmd, atyp := buf[0], buf[1], buf[3] + if ver != socksVer5 { + return "", cmd, ErrInvalidSocksVersion + } + + // UDP associate: let caller handle + if cmd == socksCmdUDP { + return "", socksCmdUDP, errors.New("UDP Associate") + } + if cmd != socksCmdConn { + return "", cmd, ErrInvalidSocksCmd + } + + var addr string + switch atyp { + case socksATYPv4: + if _, err := io.ReadFull(client, buf[:4]); err != nil { + return "", cmd, fmt.Errorf("invalid IPv4: %w", err) + } + addr = fmt.Sprintf("%d.%d.%d.%d", buf[0], buf[1], buf[2], buf[3]) + + case socksATYDomain: + if _, err := io.ReadFull(client, buf[:1]); err != nil { + return "", cmd, fmt.Errorf("invalid hostname(len): %w", err) + } + addrLen := int(buf[0]) + if _, err := io.ReadFull(client, buf[:addrLen]); err != nil { + return "", cmd, fmt.Errorf("invalid hostname: %w", err) + } + addr = string(buf[:addrLen]) + + case socksATYPv6: + return "", cmd, errors.New("IPv6: not supported yet") + + default: + return "", cmd, errors.New("invalid atyp") + } + + if _, err := io.ReadFull(client, buf[:2]); err != nil { + return "", cmd, fmt.Errorf("read port: %w", err) + } + port := binary.BigEndian.Uint16(buf[:2]) + + return fmt.Sprintf("%s:%d", addr, port), cmd, nil +} + +// socks5Connect dials the target and responds success to the client. +func (s *Server) socks5Connect(client net.Conn, destAddrPort string) (net.Conn, error) { + logrus.Debugf("Connecting %s", destAddrPort) + target, err := net.Dial("tcp", destAddrPort) + if err != nil { + return nil, err + } + logrus.Debugf("Connected %s", destAddrPort) + + // Reply success (bind set to 0.0.0.0:0) + if _, err = client.Write([]byte{socksVer5, 0x00, 0x00, socksATYPv4, 0, 0, 0, 0, 0, 0}); err != nil { + _ = target.Close() + return nil, err + } + return target, nil +} + +// forwardTCP proxies traffic in both directions. +// target->client uses raw copy. +// client->target is processed by the rewriter (or raw if cached). +func (s *Server) forwardTCP(client, target net.Conn, destAddrPort string) { + // Server -> Client (raw) + go s.copyHalf(client, target) + + // Client -> Server (rewriter) + go s.proxyHalf(target, client, destAddrPort) +} + +// copyHalf copies from src to dst and half-closes both sides when done. +func (s *Server) copyHalf(dst, src net.Conn) { + defer func() { + // Prefer TCP half-close to allow the opposite direction to drain. + if tc, ok := dst.(*net.TCPConn); ok { + _ = tc.CloseWrite() + } else { + _ = dst.Close() + } + if tc, ok := src.(*net.TCPConn); ok { + _ = tc.CloseRead() + } else { + _ = src.Close() + } + }() + _, _ = io.Copy(dst, src) +} + +// proxyHalf runs the rewriter proxy on src->dst and then half-closes both sides. +func (s *Server) proxyHalf(dst, src net.Conn, destAddrPort string) { + defer func() { + if tc, ok := dst.(*net.TCPConn); ok { + _ = tc.CloseWrite() + } else { + _ = dst.Close() + } + if tc, ok := src.(*net.TCPConn); ok { + _ = tc.CloseRead() + } else { + _ = src.Close() + } + }() + _ = s.rw.ProxyHTTPOrRaw(dst, src, destAddrPort) +} + +// handleUDPAssociate handles a UDP ASSOCIATE request by creating a UDP relay socket. +// Only IPv4 and domain ATYP are supported (no IPv6). +func (s *Server) handleUDPAssociate(client net.Conn) { + udpServer, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + logrus.Errorf("[%s][UDP] ListenUDP failed: %v", client.RemoteAddr().String(), err) + return + } + defer udpServer.Close() + + _, portStr, _ := net.SplitHostPort(udpServer.LocalAddr().String()) + logrus.Debugf("[%s][UDP] ListenUDP on %s", client.RemoteAddr().String(), portStr) + + portInt, _ := net.LookupPort("udp", portStr) + portBytes := make([]byte, 2) + binary.BigEndian.PutUint16(portBytes, uint16(portInt)) + + // Reply with chosen UDP port (bind addr set to 0.0.0.0) + if _, err = client.Write([]byte{socksVer5, 0x00, 0x00, socksATYPv4, 0, 0, 0, 0, portBytes[0], portBytes[1]}); err != nil { + logrus.Errorf("[%s][UDP] Write rsp failed: %v", client.RemoteAddr().String(), err) + return + } + + buf := make([]byte, 65535) + udpPortMap := make(map[string][]byte) + var clientAddr *net.UDPAddr + isDomain := false + + for { + _ = udpServer.SetReadDeadline(time.Now().Add(10 * time.Second)) + n, fromAddr, err := udpServer.ReadFromUDP(buf) + if err != nil { + if strings.Contains(err.Error(), "i/o timeout") { + logrus.Debugf("[%s][UDP] ReadFromUDP timeout: %v", client.RemoteAddr().String(), err) + if !isAlive(client) { + logrus.Debugf("[%s][UDP] client is not alive", client.RemoteAddr().String()) + break + } + } else { + logrus.Errorf("[%s][UDP] ReadFromUDP failed: %v", client.RemoteAddr().String(), err) + } + continue + } + if clientAddr == nil { + clientAddr = fromAddr + } + + if clientAddr.IP.Equal(fromAddr.IP) && clientAddr.Port == fromAddr.Port { + // Packet from client -> forward to remote + atyp := buf[3] + var ( + targetAddr string + targetPort uint16 + payload []byte + header []byte + targetIP net.IP + ) + + switch atyp { + case socksATYPv4: + isDomain = false + targetAddr = fmt.Sprintf("%d.%d.%d.%d", buf[4], buf[5], buf[6], buf[7]) + targetIP = net.ParseIP(targetAddr) + targetPort = binary.BigEndian.Uint16(buf[8:10]) + payload = buf[10:n] + header = buf[0:10] + + case socksATYDomain: + isDomain = true + addrLen := int(buf[4]) + targetAddr = string(buf[5 : 5+addrLen]) + targetIPAddr, err := net.ResolveIPAddr("ip", targetAddr) + if err != nil { + logrus.Errorf("[%s][UDP] ResolveIPAddr failed: %v", client.RemoteAddr().String(), err) + break + } + targetIP = targetIPAddr.IP + targetPort = binary.BigEndian.Uint16(buf[5+addrLen : 5+addrLen+2]) + payload = buf[5+addrLen+2 : n] + header = buf[0 : 5+addrLen+2] + + case socksATYPv6: + logrus.Errorf("[%s][UDP] IPv6: not supported yet", client.RemoteAddr().String()) + break + + default: + logrus.Errorf("[%s][UDP] invalid atyp", client.RemoteAddr().String()) + break + } + + remoteAddr := &net.UDPAddr{IP: targetIP, Port: int(targetPort)} + udpPortMap[remoteAddr.String()] = make([]byte, len(header)) + copy(udpPortMap[remoteAddr.String()], header) + + _ = udpServer.SetWriteDeadline(time.Now().Add(10 * time.Second)) + if _, err = udpServer.WriteToUDP(payload, remoteAddr); err != nil { + logrus.Debugf("[%s][UDP] WriteToUDP to remote failed: %v", client.RemoteAddr().String(), err) + continue + } + } else { + // Packet from remote -> forward to client (rebuild header) + header := udpPortMap[fromAddr.String()] + if header == nil { + logrus.Errorf("[%s][UDP] udpPortMap invalid header", client.RemoteAddr().String()) + continue + } + // For domain ATYP, preserve original head section size + if isDomain { + header = header[0:4] + } + body := append(header, buf[:n]...) + if _, err = udpServer.WriteToUDP(body, clientAddr); err != nil { + logrus.Debugf("[%s][UDP] WriteToUDP to client failed: %v", client.RemoteAddr().String(), err) + continue + } + } + } +} + +// isAlive checks if a connection is still alive using a short read deadline. +func isAlive(conn net.Conn) bool { + one := make([]byte, 1) + _ = conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + _, err := conn.Read(one) + if err != nil { + switch { + case errors.Is(err, io.EOF): + logrus.Debugf("[%s] isAlive: EOF", conn.RemoteAddr().String()) + return false + case strings.Contains(err.Error(), "use of closed network connection"): + logrus.Debugf("[%s] isAlive: closed", conn.RemoteAddr().String()) + return false + case strings.Contains(err.Error(), "i/o timeout"): + logrus.Debugf("[%s] isAlive: timeout", conn.RemoteAddr().String()) + return true + default: + logrus.Debugf("[%s] isAlive: %s", conn.RemoteAddr().String(), err.Error()) + return false + } + } + _ = conn.SetReadDeadline(time.Time{}) + return true +} diff --git a/src/main.go b/src/main.go index c9e9328..6e97310 100644 --- a/src/main.go +++ b/src/main.go @@ -1,488 +1,34 @@ package main import ( - "bufio" - "encoding/binary" - "errors" - "flag" "fmt" - "io" - "net" - "net/http" - "strings" - "time" - "github.com/dlclark/regexp2" - "github.com/hashicorp/golang-lru/v2/expirable" "github.com/sirupsen/logrus" + + "github.com/sunbk201/ua3f/internal/config" "github.com/sunbk201/ua3f/internal/log" - "github.com/sunbk201/ua3f/internal/statistics" + "github.com/sunbk201/ua3f/internal/rewrite" + "github.com/sunbk201/ua3f/internal/server/socks5" ) -var ( - ErrInvalidSocksVersion = errors.New("invalid socks version") - ErrInvalidSocksCmd = errors.New("invalid socks cmd") -) - -var version = "0.8.0" -var payload string -var uaPattern string -var uaRegexp *regexp2.Regexp -var enablePartialReplace bool -var cache *expirable.LRU[string, string] -var HTTP_METHOD = []string{"GET", "POST", "HEAD", "PUT", "DELETE", "OPTIONS", "TRACE", "CONNECT"} -var whitelist = []string{ - "MicroMessenger Client", - "ByteDancePcdn", - "Go-http-client/1.1", -} +const version = "0.8.0" func main() { - var addr string - var port int - var loglevel string - - flag.StringVar(&addr, "b", "127.0.0.1", "bind address (default: 127.0.0.1)") - flag.IntVar(&port, "p", 1080, "port") - flag.StringVar(&payload, "f", "FFF", "User-Agent") - flag.StringVar(&uaPattern, "r", "(iPhone|iPad|Android|Macintosh|Windows|Linux|Apple|Mac OS X|Mobile)", "UA-Pattern") - flag.BoolVar(&enablePartialReplace, "s", false, "Enable Regex Partial Replace") - flag.StringVar(&loglevel, "l", "info", "Log level (default: info)") - flag.Bool("v", false, "show version") - flag.Parse() - - if flag.Lookup("v").Value.String() == "true" { + cfg, showVer := config.Parse() + if showVer { fmt.Println("UA3F v" + version) return } + log.SetLogConf(cfg.LogLevel) - log.SetLogConf(loglevel) - logrus.Info("UA3F v" + version) - logrus.Info(fmt.Sprintf("Port: %d", port)) - logrus.Info(fmt.Sprintf("User-Agent: %s", payload)) - logrus.Info(fmt.Sprintf("User-Agent Regex Pattern: %s", uaPattern)) - logrus.Info(fmt.Sprintf("Enable Partial Replace: %v", enablePartialReplace)) - logrus.Info(fmt.Sprintf("Log level: %s", loglevel)) - - cache = expirable.NewLRU[string, string](300, nil, time.Second*600) - - server, err := net.Listen("tcp", fmt.Sprintf("%s:%d", addr, port)) + rw, err := rewrite.New(cfg) if err != nil { - logrus.Fatal("Listen failed: ", err) - return - } - logrus.Info(fmt.Sprintf("Listen on %s:%d", addr, port)) - - // ignore case - uaPattern = "(?i)" + uaPattern - uaRegexp, err = regexp2.Compile(uaPattern, regexp2.None) - if err != nil { - logrus.Fatal("Invalid User-Agent Regex Pattern: ", err) - return + logrus.Fatal(err) } - go statistics.StartStatWorker() - - for { - client, err := server.Accept() - if err != nil { - logrus.Error("Accept failed: ", err) - continue - } - logrus.Debug(fmt.Sprintf("Accept %s", client.RemoteAddr().String())) - go process(client) - } -} - -func process(client net.Conn) { - if err := Socks5Auth(client); err != nil { - client.Close() - return - } - destAddrPort, err := ParseSocks5Request(client) - if err != nil { - if strings.Contains(err.Error(), "UDP Associate") { - Socks5UDP(client) - client.Close() - return - } else { - logrus.Debug(fmt.Sprintf("[%s][%s] ParseSocks5Request failed: %s", client.RemoteAddr().String(), destAddrPort, err.Error())) - client.Close() - return - } - } - target, err := Socks5Connect(client, destAddrPort) - if err != nil { - logrus.Debug("Connect failed: ", err) - client.Close() - return - } - Socks5Forward(client, target, destAddrPort) -} - -func Socks5Auth(client net.Conn) (err error) { - buf := make([]byte, 256) - n, err := io.ReadFull(client, buf[:2]) - if n != 2 { - if errors.Is(err, io.EOF) { - logrus.Warn(fmt.Sprintf("[%s][Auth] read EOF", client.RemoteAddr().String())) - } else { - logrus.Error(fmt.Sprintf("[%s][Auth] read header: %s", client.RemoteAddr().String(), err.Error())) - } - return errors.New("reading header:" + err.Error()) - } - ver, nMethods := int(buf[0]), int(buf[1]) - if ver != 5 { - logrus.Error(fmt.Sprintf("[%s][Auth] invalid ver", client.RemoteAddr().String())) - return ErrInvalidSocksVersion - } - n, err = io.ReadFull(client, buf[:nMethods]) - if n != nMethods { - logrus.Error(fmt.Sprintf("[%s][Auth] read methods: %s", client.RemoteAddr().String(), err.Error())) - return errors.New("read methods:" + err.Error()) - } - n, err = client.Write([]byte{0x05, 0x00}) - if n != 2 || err != nil { - logrus.Error(fmt.Sprintf("[%s][Auth] write rsp: %s", client.RemoteAddr().String(), err.Error())) - return errors.New("write rsp:" + err.Error()) - } - return nil -} - -func isAlive(conn net.Conn) bool { - one := make([]byte, 1) - conn.SetReadDeadline(time.Now().Add(time.Second * 5)) - _, err := conn.Read(one) - if err != nil { - if err == io.EOF { - logrus.Debug(fmt.Sprintf("[%s] isAlive: EOF", conn.RemoteAddr().String())) - return false - } else if strings.Contains(err.Error(), "use of closed network connection") { - logrus.Debug(fmt.Sprintf("[%s] isAlive: closed", conn.RemoteAddr().String())) - return false - } else if strings.Contains(err.Error(), "i/o timeout") { - logrus.Debug(fmt.Sprintf("[%s] isAlive: timeout", conn.RemoteAddr().String())) - return true - } else { - logrus.Debug(fmt.Sprintf("[%s] isAlive: %s", conn.RemoteAddr().String(), err.Error())) - return false - } - } - conn.SetReadDeadline(time.Time{}) - return true -} - -func Socks5UDP(client net.Conn) { - udpserver, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) - if err != nil { - logrus.Error(fmt.Sprintf("[%s][UDP] ListenUDP failed: %s", client.RemoteAddr().String(), err.Error())) - return - } - _, port, _ := net.SplitHostPort(udpserver.LocalAddr().String()) - logrus.Debug(fmt.Sprintf("[%s][UDP] ListenUDP on %s", client.RemoteAddr().String(), port)) - portInt, _ := net.LookupPort("udp", port) - portBytes := make([]byte, 2) - binary.BigEndian.PutUint16(portBytes, uint16(portInt)) - _, err = client.Write([]byte{0x05, 0x00, 0x00, 0x01, 0, 0, 0, 0, portBytes[0], portBytes[1]}) - if err != nil { - logrus.Error(fmt.Sprintf("[%s][UDP] Write rsp failed: %s", client.RemoteAddr().String(), err.Error())) - return - } - buf := make([]byte, 65535) - udpPortMap := make(map[string][]byte) - var clientAddr *net.UDPAddr - var isDomain bool = false - for { - udpserver.SetReadDeadline(time.Now().Add(time.Second * 10)) - n, fromAddr, err := udpserver.ReadFromUDP(buf) - if err != nil { - if strings.Contains(err.Error(), "i/o timeout") { - logrus.Debug(fmt.Sprintf("[%s][UDP] ReadFromUDP failed: %s", client.RemoteAddr().String(), err.Error())) - if !isAlive(client) { - logrus.Debug(fmt.Sprintf("[%s][UDP] client is not alive", client.RemoteAddr().String())) - break - } - } else { - logrus.Error(fmt.Sprintf("[%s][UDP] ReadFromUDP failed: %s", client.RemoteAddr().String(), err.Error())) - } - continue - } - if clientAddr == nil { - clientAddr = fromAddr - } - - if clientAddr.IP.Equal(fromAddr.IP) && clientAddr.Port == fromAddr.Port { - // from client - atyp := buf[3] - targetAddr := "" - var targetPort uint16 = 0 - var payload []byte - var header []byte - var targetIP net.IP - if atyp == 1 { - isDomain = false - targetAddr = fmt.Sprintf("%d.%d.%d.%d", buf[4], buf[5], buf[6], buf[7]) - targetIP = net.ParseIP(targetAddr) - targetPort = binary.BigEndian.Uint16(buf[8:10]) - payload = buf[10:n] - header = buf[0:10] - } else if atyp == 3 { - isDomain = true - addrLen := int(buf[4]) - targetAddr = string(buf[5 : 5+addrLen]) - targetIPaddr, err := net.ResolveIPAddr("ip", targetAddr) - if err != nil { - logrus.Error(fmt.Sprintf("[%s][UDP] ResolveIPAddr failed: %s", client.RemoteAddr().String(), err.Error())) - break - } - targetIP = targetIPaddr.IP - targetPort = binary.BigEndian.Uint16(buf[5+addrLen : 5+addrLen+2]) - payload = buf[5+addrLen+2 : n] - header = buf[0 : 5+addrLen+2] - } else if atyp == 4 { - logrus.Error(fmt.Sprintf("[%s][UDP] IPv6: no supported yet", client.RemoteAddr().String())) - break - } else { - logrus.Error(fmt.Sprintf("[%s][UDP] invalid atyp", client.RemoteAddr().String())) - break - } - // targetAddrPort := fmt.Sprintf("%s:%d", targetAddr, targetPort) - remoteAddr := &net.UDPAddr{IP: targetIP, Port: int(targetPort)} - udpPortMap[remoteAddr.String()] = make([]byte, len(header)) - copy(udpPortMap[remoteAddr.String()], header) - udpserver.SetWriteDeadline(time.Now().Add(time.Second * 10)) - if _, err = udpserver.WriteToUDP(payload, remoteAddr); err != nil { - logrus.Debug(fmt.Sprintf("[%s][UDP] WriteToUDP to remote failed: %s", client.RemoteAddr().String(), err.Error())) - continue - } - } else { - // from remote - header := udpPortMap[fromAddr.String()] - if header == nil { - logrus.Error(fmt.Sprintf("[%s][UDP] udpPortMap invalid header", client.RemoteAddr().String())) - continue - } - // header + body - if isDomain { - header = header[0:4] - } - body := append(header, buf[:n]...) - if _, err = udpserver.WriteToUDP(body, clientAddr); err != nil { - logrus.Debug(fmt.Sprintf("[%s][UDP] WriteToUDP to client failed: %s", client.RemoteAddr().String(), err.Error())) - continue - } - } - } - udpserver.Close() -} - -func Socks5Connect(client net.Conn, destAddrPort string) (target net.Conn, err error) { - logrus.Debug(fmt.Sprintf("Connecting %s", destAddrPort)) - dest, err := net.Dial("tcp", destAddrPort) - if err != nil { - return nil, err - } - logrus.Debug(fmt.Sprintf("Connected %s", destAddrPort)) - _, err = client.Write([]byte{0x05, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 0}) - if err != nil { - dest.Close() - return nil, err - } - return dest, nil -} - -func ParseSocks5Request(client net.Conn) (destAddrPort string, err error) { - buf := make([]byte, 256) - if _, err = io.ReadFull(client, buf[:4]); err != nil { - err = fmt.Errorf("read header: %w", err) - return - } - ver, cmd, _, atyp := buf[0], buf[1], buf[2], buf[3] - if ver != 5 { - err = ErrInvalidSocksVersion - return - } - if cmd == 3 { - return "", errors.New("UDP Associate") - } else if cmd != 1 { - err = ErrInvalidSocksCmd - return - } - var addr string - switch atyp { - case 1: - if _, err = io.ReadFull(client, buf[:4]); err != nil { - err = fmt.Errorf("invalid IPv4: %w", err) - return - } - addr = fmt.Sprintf("%d.%d.%d.%d", buf[0], buf[1], buf[2], buf[3]) - case 3: - if _, err = io.ReadFull(client, buf[:1]); err != nil { - err = fmt.Errorf("invalid hostname: %w", err) - return - } - addrLen := int(buf[0]) - if _, err = io.ReadFull(client, buf[:addrLen]); err != nil { - err = fmt.Errorf("invalid hostname: %w", err) - return - } - addr = string(buf[:addrLen]) - case 4: - err = errors.New("IPv6: no supported yet") - return - default: - err = errors.New("invalid atyp") - return - } - - if _, err = io.ReadFull(client, buf[:2]); err != nil { - err = fmt.Errorf("read port: %w", err) - return - } - port := binary.BigEndian.Uint16(buf[:2]) - destAddrPort = fmt.Sprintf("%s:%d", addr, port) - return destAddrPort, nil -} - -func Socks5Forward(client, target net.Conn, destAddrPort string) { - forward := func(src, dest net.Conn) { - defer src.Close() - defer dest.Close() - io.Copy(src, dest) - } - - gforward := func(dst, src net.Conn) { - defer dst.Close() - defer src.Close() - transfer(dst, src, destAddrPort) - } - - go forward(client, target) - if cache.Contains(destAddrPort) { - logrus.Debug(fmt.Sprintf("Hit LRU Relay Cache: %s", destAddrPort)) - go forward(target, client) - } else { - go gforward(target, client) - } -} - -func isHTTP(reader *bufio.Reader) (bool, error) { - buf, err := reader.Peek(7) - if err != nil { - if strings.Contains(err.Error(), "EOF") { - logrus.Debug(fmt.Sprintf("Peek EOF: %s", err.Error())) - } else { - logrus.Error(fmt.Sprintf("Peek error: %s", err.Error())) - } - return false, err - } - hint := string(buf) - is_http := false - for _, v := range HTTP_METHOD { - if strings.HasPrefix(hint, v) { - is_http = true - break - } - } - return is_http, nil -} - -func buildNewUA(originUA string, targetUA string, uaRegexp *regexp2.Regexp, enablePartialReplace bool) string { - if enablePartialReplace && uaRegexp != nil { - newUaHearder, err := uaRegexp.Replace(originUA, targetUA, -1, -1) - if err != nil { - logrus.Error(fmt.Sprintf("User-Agent Replace Error: %s", err.Error())) - return targetUA - } - return newUaHearder - } - return targetUA -} - -func transfer(dst net.Conn, src net.Conn, destAddrPort string) { - srcReader := bufio.NewReader(src) - is_http, err := isHTTP(srcReader) - if err != nil { - if strings.Contains(err.Error(), "use of closed network connection") { - logrus.Warn(fmt.Sprintf("[%s] isHTTP error: %s", destAddrPort, err.Error())) - return - } - } - if !is_http && err == nil { - cache.Add(destAddrPort, destAddrPort) - logrus.Debug(fmt.Sprintf("Not HTTP, Add LRU Relay Cache: %s, Cache Len: %d", destAddrPort, cache.Len())) - io.Copy(dst, srcReader) - return - } - for { - request, err := http.ReadRequest(srcReader) - if err != nil { - if err == io.EOF { - logrus.Debug(fmt.Sprintf("[%s][%s] read EOF in first phase", destAddrPort, src.(*net.TCPConn).RemoteAddr().String())) - } else if strings.Contains(err.Error(), "use of closed network connection") { - logrus.Debug(fmt.Sprintf("[%s][%s] read closed in first phase: %s", destAddrPort, src.(*net.TCPConn).RemoteAddr().String(), err.Error())) - } else if strings.Contains(err.Error(), "connection reset by peer") { - logrus.Debug(fmt.Sprintf("[%s][%s] read reset in first phase: %s", destAddrPort, src.(*net.TCPConn).RemoteAddr().String(), err.Error())) - } else if strings.Contains(err.Error(), "i/o timeout") { - logrus.Debug(fmt.Sprintf("[%s][%s] read timeout in first phase: %s", destAddrPort, src.(*net.TCPConn).RemoteAddr().String(), err.Error())) - } else { - logrus.Error(fmt.Sprintf("[%s][%s] read error in first phase: %s", destAddrPort, src.(*net.TCPConn).RemoteAddr().String(), err.Error())) - } - return - } - uaStr := request.Header.Get("User-Agent") - if uaStr == "" { - cache.Add(destAddrPort, destAddrPort) - logrus.Debug(fmt.Sprintf("[%s] Not found User-Agent, Add LRU Relay Cache, Cache Len: %d", destAddrPort, cache.Len())) - if err = request.Write(dst); err != nil { - logrus.Error(fmt.Sprintf("[%s][%s] write error: %s", destAddrPort, src.(*net.TCPConn).RemoteAddr().String(), err.Error())) - } - io.Copy(dst, srcReader) - return - } - isInWhiteList := false - isMatchUaPattern := true - if uaPattern != "" { - isMatchUaPattern, err = uaRegexp.MatchString(uaStr) - if err != nil { - logrus.Error(fmt.Sprintf("[%s][%s] User-Agent Regex Pattern Match Error: %s", destAddrPort, src.(*net.TCPConn).RemoteAddr().String(), err.Error())) - isMatchUaPattern = true - } - } - for _, v := range whitelist { - if v == uaStr { - isInWhiteList = true - break - } - } - if isInWhiteList || !isMatchUaPattern { - if !isMatchUaPattern { - logrus.Debug(fmt.Sprintf("[%s][%s] Not Hit User-Agent Pattern: %s", destAddrPort, src.(*net.TCPConn).RemoteAddr().String(), uaStr)) - } - if isInWhiteList { - logrus.Debug(fmt.Sprintf("[%s][%s] Hit User-Agent Whitelist: %s, Add LRU Relay Cache, Cache Len: %d", destAddrPort, src.(*net.TCPConn).RemoteAddr().String(), uaStr, cache.Len())) - cache.Add(destAddrPort, destAddrPort) - } - err = request.Write(dst) - if err != nil { - logrus.Error(fmt.Sprintf("[%s][%s] write error: %s", destAddrPort, src.(*net.TCPConn).RemoteAddr().String(), err.Error())) - break - } - io.Copy(dst, srcReader) - return - } - logrus.Debug(fmt.Sprintf("[%s][%s] Hit User-Agent: %s", destAddrPort, src.(*net.TCPConn).RemoteAddr().String(), uaStr)) - mockedUA := buildNewUA(uaStr, payload, uaRegexp, enablePartialReplace) - request.Header.Set("User-Agent", mockedUA) - err = request.Write(dst) - if err != nil { - logrus.Error(fmt.Sprintf("[%s][%s] write error after replace user-agent: %s", destAddrPort, src.(*net.TCPConn).RemoteAddr().String(), err.Error())) - break - } - statistics.AddStat(&statistics.StatRecord{ - Host: destAddrPort, - OriginUA: uaStr, - MockedUA: mockedUA, - }) + srv := socks5.New(cfg, rw) + log.LogHeader(version, srv.ListenAddr, cfg) + if err := srv.Start(); err != nil { + logrus.Fatal(err) } }