diff --git a/src/go.mod b/src/go.mod index fcd6941..0584ce6 100644 --- a/src/go.mod +++ b/src/go.mod @@ -10,4 +10,9 @@ require ( gopkg.in/natefinch/lumberjack.v2 v2.2.1 ) -require github.com/stretchr/testify v1.8.2 // indirect +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/stretchr/testify v1.11.1 + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/src/go.sum b/src/go.sum index 9920d77..e08cf79 100644 --- a/src/go.sum +++ b/src/go.sum @@ -10,16 +10,13 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= -github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc= diff --git a/src/internal/log/log.go b/src/internal/log/log.go index b2ed512..e5c4174 100644 --- a/src/internal/log/log.go +++ b/src/internal/log/log.go @@ -72,3 +72,19 @@ func LogHeader(version string, cfg *config.Config) { logrus.Infof("Enable Partial Replace: %v", cfg.EnablePartialReplace) logrus.Infof("Log level: %s", cfg.LogLevel) } + +func LogDebugWithAddr(src string, dest string, msg string) { + logrus.Debugf("[%s -> %s] %s", src, dest, msg) +} + +func LogInfoWithAddr(src string, dest string, msg string) { + logrus.Infof("[%s -> %s] %s", src, dest, msg) +} + +func LogWarnWithAddr(src string, dest string, msg string) { + logrus.Warnf("[%s -> %s] %s", src, dest, msg) +} + +func LogErrorWithAddr(src string, dest string, msg string) { + logrus.Errorf("[%s -> %s] %s", src, dest, msg) +} diff --git a/src/internal/rewrite/rewriter.go b/src/internal/rewrite/rewriter.go index ffe2d14..7a52193 100644 --- a/src/internal/rewrite/rewriter.go +++ b/src/internal/rewrite/rewriter.go @@ -2,7 +2,9 @@ package rewrite import ( "bufio" + "bytes" "errors" + "fmt" "io" "net" "net/http" @@ -14,6 +16,7 @@ import ( "github.com/sirupsen/logrus" "github.com/sunbk201/ua3f/internal/config" + "github.com/sunbk201/ua3f/internal/log" "github.com/sunbk201/ua3f/internal/statistics" ) @@ -24,7 +27,17 @@ const ( ) // HTTP methods used to detect HTTP by request line. -var httpMethods = []string{"GET", "POST", "HEAD", "PUT", "DELETE", "OPTIONS", "TRACE", "CONNECT"} +var httpMethods = map[string]struct{}{ + "GET": {}, + "POST": {}, + "HEAD": {}, + "PUT": {}, + "PATCH": {}, + "DELETE": {}, + "OPTIONS": {}, + "TRACE": {}, + "CONNECT": {}, +} // Hardcoded whitelist of UAs that should be left untouched. var defaultWhitelist = []string{ @@ -74,56 +87,70 @@ func New(cfg *config.Config) (*Rewriter, error) { // - If target in LRU cache: pass-through (raw). // - Else if HTTP: rewrite UA (unless whitelisted or pattern not matched). // - Else: mark target in LRU and pass-through. -func (r *Rewriter) ProxyHTTPOrRaw(dst net.Conn, src net.Conn, destAddrPort string) error { +func (r *Rewriter) ProxyHTTPOrRaw(dst net.Conn, src net.Conn, destAddr string) (err error) { + srcAddr := src.RemoteAddr().String() + // Fast path: known pass-through - if r.cache.Contains(destAddrPort) { - logrus.Debugf("Hit LRU Relay Cache: %s", destAddrPort) - _, _ = io.Copy(dst, src) + if r.cache.Contains(destAddr) { + log.LogDebugWithAddr(src.RemoteAddr().String(), destAddr, "LRU Relay Cache Hit, pass-through") + io.Copy(dst, src) return nil } reader := bufio.NewReader(src) + defer func() { + if err != nil { + log.LogDebugWithAddr(srcAddr, destAddr, fmt.Sprintf("ProxyHTTPOrRaw Error: %s", err.Error())) + } + io.Copy(dst, reader) + }() isHTTP, err := r.isHTTP(reader) if err != nil { - if strings.Contains(err.Error(), ErrUseClosedConn) { - logrus.Warnf("[%s] isHTTP error: %s", destAddrPort, err.Error()) - return err - } - // Other read errors terminate the direction. - return err + err = fmt.Errorf("isHTTP: %w", err) + return } - if !isHTTP { - r.cache.Add(destAddrPort, destAddrPort) - logrus.Debugf("Not HTTP, Add LRU Relay Cache: %s, Cache Len: %d", destAddrPort, r.cache.Len()) - _, _ = io.Copy(dst, reader) - return nil + r.cache.Add(destAddr, destAddr) + log.LogDebugWithAddr(srcAddr, destAddr, "Not HTTP, added to LRU Relay Cache") + return } - srcAddr := src.RemoteAddr().String() + var req *http.Request // HTTP request loop (handles keep-alive) for { - req, err := http.ReadRequest(reader) + isHTTP, err = r.isHTTP(reader) if err != nil { - r.logReadErr(destAddrPort, src, err) - return err + err = fmt.Errorf("isHTTP: %w", err) + return + } + if !isHTTP { + h2, _ := reader.Peek(2) // ensure we have at least 2 bytes + if isWebSocket(h2) { + log.LogDebugWithAddr(srcAddr, destAddr, "WebSocket detected, pass-through") + } else { + r.cache.Add(destAddr, destAddr) + log.LogDebugWithAddr(srcAddr, destAddr, "Not HTTP, added to LRU Relay Cache") + } + return + } + req, err = http.ReadRequest(reader) + if err != nil { + err = fmt.Errorf("http.ReadRequest: %w", err) + return } originalUA := req.Header.Get("User-Agent") // No UA header: pass-through after writing this first request if originalUA == "" { - r.cache.Add(destAddrPort, destAddrPort) - logrus.Debugf("[%s] Not found User-Agent, Add LRU Relay Cache, Cache Len: %d", - destAddrPort, r.cache.Len()) - if err := req.Write(dst); err != nil { - logrus.Errorf("[%s][%s] write error: %s", destAddrPort, srcAddr, err.Error()) - return err + r.cache.Add(destAddr, destAddr) + log.LogDebugWithAddr(srcAddr, destAddr, "Not found User-Agent, Add LRU Relay Cache") + if err = req.Write(dst); err != nil { + err = fmt.Errorf("req.Write: %w", err) } - _, _ = io.Copy(dst, reader) - return nil + return } isWhitelist := r.inWhitelist(originalUA) @@ -131,8 +158,7 @@ func (r *Rewriter) ProxyHTTPOrRaw(dst net.Conn, src net.Conn, destAddrPort strin if r.pattern != "" { matches, err = r.uaRegex.MatchString(originalUA) if err != nil { - logrus.Errorf("[%s][%s] User-Agent Regex Pattern Match Error: %s", - destAddrPort, srcAddr, err.Error()) + log.LogErrorWithAddr(srcAddr, destAddr, fmt.Sprintf("User-Agent Regex Pattern Match Error: %s", err.Error())) matches = true } } @@ -140,38 +166,33 @@ func (r *Rewriter) ProxyHTTPOrRaw(dst net.Conn, src net.Conn, destAddrPort strin // If UA is whitelisted or does not match target pattern, write once then pass-through. if isWhitelist || !matches { if !matches { - logrus.Debugf("[%s][%s] Not Hit User-Agent Pattern: %s", - destAddrPort, srcAddr, originalUA) + log.LogDebugWithAddr(srcAddr, destAddr, fmt.Sprintf("Not Hit User-Agent Regex: %s", originalUA)) } if isWhitelist { - logrus.Debugf("[%s][%s] Hit User-Agent Whitelist: %s, Add LRU Relay Cache, Cache Len: %d", - destAddrPort, srcAddr, originalUA, r.cache.Len()) - r.cache.Add(destAddrPort, destAddrPort) + log.LogDebugWithAddr(srcAddr, destAddr, fmt.Sprintf("Hit User-Agent Whitelist: %s", originalUA)) + r.cache.Add(destAddr, destAddr) } statistics.AddPassThroughRecord(&statistics.PassThroughRecord{ - Host: destAddrPort, + Host: destAddr, UA: originalUA, }) - if err := req.Write(dst); err != nil { - logrus.Errorf("[%s][%s] write error: %s", destAddrPort, srcAddr, err.Error()) - return err + if err = req.Write(dst); err != nil { + err = fmt.Errorf("req.Write: %w", err) } - _, _ = io.Copy(dst, reader) - return nil + return } // Rewrite UA and forward the request (including body) - logrus.Debugf("[%s][%s] Hit User-Agent: %s", destAddrPort, srcAddr, originalUA) + log.LogDebugWithAddr(srcAddr, destAddr, fmt.Sprintf("Hit User-Agent: %s", originalUA)) mockedUA := r.buildNewUA(originalUA) req.Header.Set("User-Agent", mockedUA) - if err := req.Write(dst); err != nil { - logrus.Errorf("[%s][%s] write error after replace user-agent: %s", - destAddrPort, srcAddr, err.Error()) - return err + if err = req.Write(dst); err != nil { + err = fmt.Errorf("req.Write: %w", err) + return } statistics.AddRewriteRecord(&statistics.RewriteRecord{ - Host: destAddrPort, + Host: destAddr, OriginalUA: originalUA, MockedUA: mockedUA, }) @@ -180,20 +201,32 @@ func (r *Rewriter) ProxyHTTPOrRaw(dst net.Conn, src net.Conn, destAddrPort strin // isHTTP peeks the first few bytes and checks for a known HTTP method prefix. func (r *Rewriter) isHTTP(reader *bufio.Reader) (bool, error) { - buf, err := reader.Peek(7) + // Fast check: peek first word to see if it's a known HTTP method + const maxMethodLen = 7 + hintSlice, err := reader.Peek(maxMethodLen) if err != nil { - if strings.Contains(err.Error(), "EOF") { - logrus.Debugf("Peek EOF: %s", err.Error()) - } else { - logrus.Errorf("Peek error: %s", err.Error()) - } return false, err } - hint := string(buf) - for _, m := range httpMethods { - if strings.HasPrefix(hint, m) { - return true, nil - } + hint := string(hintSlice) + method, _, _ := strings.Cut(hint, " ") + if _, exists := httpMethods[method]; !exists { + return false, nil + } + + // Detailed check: parse request line + line, err := peekLineString(reader) + if err != nil { + return false, err + } + method, _, proto, ok := parseRequestLine(line) + if !ok { + return false, nil + } + if proto != "HTTP/1.1" && proto != "HTTP/1.0" { + return false, nil + } + if _, exists := httpMethods[method]; exists { + return true, nil } return false, nil } @@ -216,18 +249,104 @@ func (r *Rewriter) inWhitelist(ua string) bool { return ok } -func (r *Rewriter) logReadErr(destAddrPort string, src net.Conn, err error) { - remote := src.RemoteAddr().String() - switch { - case errors.Is(err, io.EOF): - logrus.Debugf("[%s][%s] read EOF in first phase", destAddrPort, remote) - case strings.Contains(err.Error(), ErrUseClosedConn): - logrus.Debugf("[%s][%s] read closed in first phase: %s", destAddrPort, remote, err.Error()) - case strings.Contains(err.Error(), ErrConnResetByPeer): - logrus.Debugf("[%s][%s] read reset in first phase: %s", destAddrPort, remote, err.Error()) - case strings.Contains(err.Error(), ErrIOTimeout): - logrus.Debugf("[%s][%s] read timeout in first phase: %s", destAddrPort, remote, err.Error()) - default: - logrus.Errorf("[%s][%s] read error in first phase: %s", destAddrPort, remote, err.Error()) +// peekLineSlice reads a line from bufio.Reader without consuming it. +// returns the line bytes (without CRLF) or error. +func peekLineSlice(br *bufio.Reader) ([]byte, error) { + const chunkSize = 256 + var line []byte + + offset := 0 + for { + // Ensure there is data in the buffer + n := br.Buffered() + if n == 0 { + // No data in buffer, try to fill it + _, err := br.Peek(1) + if err != nil { + return nil, err + } + n = br.Buffered() + } + + // Limit to chunkSize + if n > chunkSize { + n = chunkSize + } + + buf, err := br.Peek(offset + n) + if err != nil && !errors.Is(err, bufio.ErrBufferFull) && !errors.Is(err, io.EOF) { + return nil, err + } + + data := buf[offset:] + if i := bytes.IndexByte(data, '\n'); i >= 0 { + line = append(line, data[:i]...) + // Remove trailing CR if present + if len(line) > 0 && line[len(line)-1] == '\r' { + line = line[:len(line)-1] + } + return line, nil + } + + line = append(line, data...) + offset += len(data) + + // If EOF reached without finding a newline, return what we have + if errors.Is(err, io.EOF) { + return line, io.EOF + } } } + +// peekLineString reads a line from bufio.Reader without consuming it. +// returns the line string (without CRLF) or error. +func peekLineString(br *bufio.Reader) (string, error) { + lineBytes, err := peekLineSlice(br) + if err != nil { + return "", err + } + return string(lineBytes), nil +} + +// parseRequestLine parses "GET /foo HTTP/1.1" into its three parts. +func parseRequestLine(line string) (method, requestURI, proto string, ok bool) { + method, rest, ok1 := strings.Cut(line, " ") + requestURI, proto, ok2 := strings.Cut(rest, " ") + if !ok1 || !ok2 { + return "", "", "", false + } + return method, requestURI, proto, true +} + +func isWebSocket(header []byte) bool { + if len(header) < 2 { + return false + } + + b0 := header[0] + b1 := header[1] + + rsv := b0 & 0x70 // RSV1-3 + opcode := b0 & 0x0F // opcode + mask := b1 & 0x80 // MASK + + // requested frames from client to server must be masked + if mask == 0 { + return false + } + // Control frames must have FIN set + if rsv != 0 { + return false + } + // opcode must be in valid range + if opcode > 0xA { + return false + } + // payload length + payloadLen := b1 & 0x7F + if payloadLen > 0 && payloadLen <= 125 { + return true + } + + return true +}