From 0a01c9f04f2c03c0c1f95122a4504d5b3aa08f78 Mon Sep 17 00:00:00 2001 From: SunBK201 Date: Tue, 4 Nov 2025 22:18:06 +0800 Subject: [PATCH] refactor: streamline rewriting --- src/internal/rewrite/rewriter.go | 431 ++++++++------------------ src/internal/rewrite/rewriter_test.go | 29 +- src/internal/server/http/http.go | 14 +- src/internal/server/utils/tcp.go | 2 +- src/internal/sniff/common.go | 75 +++++ src/internal/sniff/http.go | 75 +++++ src/internal/sniff/tls.go | 25 ++ src/internal/sniff/websocket.go | 34 ++ 8 files changed, 353 insertions(+), 332 deletions(-) create mode 100644 src/internal/sniff/common.go create mode 100644 src/internal/sniff/http.go create mode 100644 src/internal/sniff/tls.go create mode 100644 src/internal/sniff/websocket.go diff --git a/src/internal/rewrite/rewriter.go b/src/internal/rewrite/rewriter.go index 95cd637..5f25e7a 100644 --- a/src/internal/rewrite/rewriter.go +++ b/src/internal/rewrite/rewriter.go @@ -2,8 +2,6 @@ package rewrite import ( "bufio" - "bytes" - "errors" "fmt" "io" "net" @@ -17,44 +15,19 @@ import ( "github.com/sunbk201/ua3f/internal/config" "github.com/sunbk201/ua3f/internal/log" + "github.com/sunbk201/ua3f/internal/sniff" "github.com/sunbk201/ua3f/internal/statistics" ) -const ( - ErrUseClosedConn = "use of closed network connection" - ErrConnResetByPeer = "connection reset by peer" - ErrIOTimeout = "i/o timeout" -) - -// HTTP methods used to detect HTTP by request line. -var httpMethods = map[string]struct{}{ - "GET": {}, - "POST": {}, - "HEAD": {}, - "PUT": {}, - "PATCH": {}, - "DELETE": {}, - "OPTIONS": {}, - "TRACE": {}, - "CONNECT": {}, -} - -// Hardcoded whitelist of UAs that should be left untouched. -var defaultWhitelist = []string{ - "MicroMessenger Client", - "ByteDancePcdn", - "Go-http-client/1.1", -} - // Rewriter encapsulates HTTP UA rewrite behavior and pass-through cache. type Rewriter struct { - payloadUA string - pattern string - enablePartialReplace bool + payloadUA string + pattern string + partialReplace bool uaRegex *regexp2.Regexp - Cache *expirable.LRU[string, string] - whitelist map[string]struct{} + whitelist []string + Cache *expirable.LRU[string, struct{}] } // New constructs a Rewriter from config. Compiles regex and allocates cache. @@ -66,167 +39,33 @@ func New(cfg *config.Config) (*Rewriter, error) { return nil, err } - cache := expirable.NewLRU[string, string](1024, nil, 10*time.Minute) - - whitelist := make(map[string]struct{}, len(defaultWhitelist)) - for _, s := range defaultWhitelist { - whitelist[s] = struct{}{} - } - return &Rewriter{ - payloadUA: cfg.PayloadUA, - pattern: cfg.UAPattern, - enablePartialReplace: cfg.EnablePartialReplace, - uaRegex: uaRegex, - Cache: cache, - whitelist: whitelist, + payloadUA: cfg.PayloadUA, + pattern: cfg.UAPattern, + partialReplace: cfg.EnablePartialReplace, + uaRegex: uaRegex, + Cache: expirable.NewLRU[string, struct{}](1024, nil, 30*time.Minute), + whitelist: []string{ + "MicroMessenger Client", + "Bilibili Freedoooooom/MarkII", + "Go-http-client/1.1", + "ByteDancePcdn", + }, }, nil } -// RewriteAndForward rewrites the User-Agent header if needed and forwards the request. -// Returns pass=true if the request has been forwarded as-is (no rewrite). -func (r *Rewriter) RewriteAndForward(dst net.Conn, req *http.Request, destAddr string, srcAddr string) (pass bool, err error) { - - originalUA := req.Header.Get("User-Agent") - - // No UA header: pass-through after writing this first request - if originalUA == "" { - r.Cache.Add(destAddr, destAddr) - log.LogDebugWithAddr(srcAddr, destAddr, "Not found User-Agent, Add LRU Relay Cache") - if err = req.Write(dst); err != nil { - err = fmt.Errorf("req.Write: %w", err) - } - pass = true - return - } - - log.LogInfoWithAddr(srcAddr, destAddr, fmt.Sprintf("Original User-Agent: %s", originalUA)) - - isWhitelist := r.inWhitelist(originalUA) - matches := true - if r.pattern != "" { - matches, err = r.uaRegex.MatchString(originalUA) - if err != nil { - log.LogErrorWithAddr(srcAddr, destAddr, fmt.Sprintf("User-Agent Regex Pattern Match Error: %s", err.Error())) - matches = true +func (r *Rewriter) inWhitelist(ua string) bool { + for _, w := range r.whitelist { + if w == ua { + return true } } - - // If UA is whitelisted or does not match target pattern, write once then pass-through. - if isWhitelist || !matches { - if !matches { - log.LogDebugWithAddr(srcAddr, destAddr, fmt.Sprintf("Not Hit User-Agent Regex: %s", originalUA)) - } - if isWhitelist { - log.LogInfoWithAddr(srcAddr, destAddr, fmt.Sprintf("Hit User-Agent Whitelist: %s", originalUA)) - r.Cache.Add(destAddr, destAddr) - } - statistics.AddPassThroughRecord(&statistics.PassThroughRecord{ - Host: destAddr, - UA: originalUA, - }) - if err = req.Write(dst); err != nil { - err = fmt.Errorf("req.Write: %w", err) - } - pass = true - return - } - - // Rewrite UA and forward the request (including body) - rewritedUA := r.buildNewUA(originalUA) - log.LogInfoWithAddr(srcAddr, destAddr, fmt.Sprintf("Rewrite User-Agent from (%s) to (%s)", originalUA, rewritedUA)) - req.Header.Set("User-Agent", rewritedUA) - if err = req.Write(dst); err != nil { - err = fmt.Errorf("req.Write: %w", err) - pass = true - return - } - - statistics.AddRewriteRecord(&statistics.RewriteRecord{ - Host: destAddr, - OriginalUA: originalUA, - MockedUA: rewritedUA, - }) - - return false, nil + return false } -// ProxyHTTPOrRaw reads traffic from src and writes to dst. -// - If target in LRU cache: pass-through (raw). -// - Else if HTTP: rewrite UA (unless whitelisted or pattern not matched). -// - Else: mark target in LRU and pass-through. -func (r *Rewriter) ProxyHTTPOrRaw(dst net.Conn, src net.Conn, destAddr string, srcAddr string) (err error) { - reader := bufio.NewReader(src) - - defer func() { - if err != nil { - log.LogDebugWithAddr(srcAddr, destAddr, fmt.Sprintf("ProxyHTTPOrRaw: %s", err.Error())) - } - io.Copy(dst, reader) - }() - - if strings.HasSuffix(destAddr, "443") && isTLSClientHello(reader) { - r.Cache.Add(destAddr, destAddr) - log.LogInfoWithAddr(srcAddr, destAddr, "tls client hello detected, pass forward") - return - } - isHTTP, err := r.isHTTP(reader) - if err != nil { - err = fmt.Errorf("isHTTP: %w", err) - return - } - if !isHTTP { - r.Cache.Add(destAddr, destAddr) - log.LogInfoWithAddr(srcAddr, destAddr, "Not HTTP, added to LRU Relay Cache") - return - } - - var req *http.Request - var pass bool - - // HTTP request loop (handles keep-alive) - for { - if isHTTP, err = r.isHTTP(reader); err != nil { - err = fmt.Errorf("isHTTP: %w", err) - return - } - if !isHTTP { - h2, _ := reader.Peek(2) // ensure we have at least 2 bytes - if isWebSocket(h2) { - log.LogInfoWithAddr(srcAddr, destAddr, "WebSocket detected, pass-through") - } else { - r.Cache.Add(destAddr, destAddr) - log.LogInfoWithAddr(srcAddr, destAddr, "Not HTTP, added to LRU Relay Cache") - } - return - } - if req, err = http.ReadRequest(reader); err != nil { - err = fmt.Errorf("http.ReadRequest: %w", err) - return - } - if pass, err = r.RewriteAndForward(dst, req, destAddr, srcAddr); pass { - return - } - } -} - -// isHTTP peeks the first few bytes and checks for a known HTTP method prefix. -func (r *Rewriter) isHTTP(reader *bufio.Reader) (bool, error) { - // Fast check: peek first word to see if it's a known HTTP method - const maxMethodLen = 7 - hintSlice, err := reader.Peek(maxMethodLen) - if err != nil { - return false, err - } - hint := string(hintSlice) - method, _, _ := strings.Cut(hint, " ") - _, exists := httpMethods[method] - return exists, nil -} - -// buildNewUA returns either a partial replacement (regex) or full overwrite. -func (r *Rewriter) buildNewUA(originUA string) string { - if r.enablePartialReplace && r.uaRegex != nil && r.pattern != "" { +// buildUserAgent returns either a partial replacement (regex) or full overwrite. +func (r *Rewriter) buildUserAgent(originUA string) string { + if r.partialReplace && r.uaRegex != nil && r.pattern != "" { newUA, err := r.uaRegex.Replace(originUA, r.payloadUA, -1, -1) if err != nil { logrus.Errorf("User-Agent Replace Error: %s, use full overwrite", err.Error()) @@ -237,131 +76,117 @@ func (r *Rewriter) buildNewUA(originUA string) string { return r.payloadUA } -func (r *Rewriter) inWhitelist(ua string) bool { - _, ok := r.whitelist[ua] - return ok +func (r *Rewriter) ShouldRewrite(req *http.Request, srcAddr, destAddr string) bool { + originalUA := req.Header.Get("User-Agent") + log.LogInfoWithAddr(srcAddr, destAddr, fmt.Sprintf("Original User-Agent: (%s)", originalUA)) + + var err error + matches := true + isWhitelist := r.inWhitelist(originalUA) + + if !isWhitelist && r.pattern != "" { + matches, err = r.uaRegex.MatchString(originalUA) + if err != nil { + log.LogErrorWithAddr(srcAddr, destAddr, fmt.Sprintf("User-Agent Regex Match Error: %s", err.Error())) + matches = true + } + } + if isWhitelist { + log.LogInfoWithAddr(srcAddr, destAddr, fmt.Sprintf("Hit User-Agent Whitelist: %s", originalUA)) + r.Cache.Add(destAddr, struct{}{}) + } + if !matches { + log.LogDebugWithAddr(srcAddr, destAddr, fmt.Sprintf("Not Hit User-Agent Regex: %s", originalUA)) + } + + hit := !isWhitelist && matches + if !hit { + statistics.AddPassThroughRecord(&statistics.PassThroughRecord{ + Host: destAddr, + UA: originalUA, + }) + } + return hit } -// peekLineSlice reads a line from bufio.Reader without consuming it. -// returns the line bytes (without CRLF) or error. -func peekLineSlice(br *bufio.Reader) ([]byte, error) { - const chunkSize = 256 - var line []byte +func (r *Rewriter) Rewrite(req *http.Request, srcAddr string, destAddr string) *http.Request { + originalUA := req.Header.Get("User-Agent") + rewritedUA := r.buildUserAgent(originalUA) + req.Header.Set("User-Agent", rewritedUA) + + log.LogInfoWithAddr(srcAddr, destAddr, fmt.Sprintf("Rewrite User-Agent from (%s) to (%s)", originalUA, rewritedUA)) + + statistics.AddRewriteRecord(&statistics.RewriteRecord{ + Host: destAddr, + OriginalUA: originalUA, + MockedUA: rewritedUA, + }) + return req +} + +func (r *Rewriter) Forward(dst net.Conn, req *http.Request) error { + if err := req.Write(dst); err != nil { + return fmt.Errorf("req.Write: %w", err) + } + req.Body.Close() + return nil +} + +// Process handles the proxying with UA rewriting logic. +func (r *Rewriter) Process(dst net.Conn, src net.Conn, destAddr string, srcAddr string) (err error) { + reader := bufio.NewReader(src) + + defer func() { + if err != nil { + log.LogDebugWithAddr(srcAddr, destAddr, fmt.Sprintf("Process: %s", err.Error())) + } + io.Copy(dst, reader) + }() + + if strings.HasSuffix(destAddr, "443") && sniff.SniffTLSClientHello(reader) { + r.Cache.Add(destAddr, struct{}{}) + log.LogInfoWithAddr(srcAddr, destAddr, "tls client hello detected, pass forward") + return + } + + var isHTTP bool + + if isHTTP, err = sniff.SniffHTTP(reader); err != nil { + err = fmt.Errorf("sniff.SniffHTTP: %w", err) + return + } + if !isHTTP { + r.Cache.Add(destAddr, struct{}{}) + log.LogInfoWithAddr(srcAddr, destAddr, "Not HTTP, added to cache") + return + } + + var req *http.Request - offset := 0 for { - // Ensure there is data in the buffer - n := br.Buffered() - if n == 0 { - // No data in buffer, try to fill it - _, err := br.Peek(1) - if err != nil { - return nil, err - } - n = br.Buffered() + if isHTTP, err = sniff.SniffHTTPFast(reader); err != nil { + err = fmt.Errorf("isHTTP: %w", err) + return } - - // Limit to chunkSize - if n > chunkSize { - n = chunkSize + if !isHTTP { + r.Cache.Add(destAddr, struct{}{}) + log.LogInfoWithAddr(srcAddr, destAddr, "Not HTTP, added to LRU Relay Cache") + return } - - buf, err := br.Peek(offset + n) - if err != nil && !errors.Is(err, bufio.ErrBufferFull) && !errors.Is(err, io.EOF) { - return nil, err + if req, err = http.ReadRequest(reader); err != nil { + err = fmt.Errorf("http.ReadRequest: %w", err) + return } - - data := buf[offset:] - if i := bytes.IndexByte(data, '\n'); i >= 0 { - line = append(line, data[:i]...) - // Remove trailing CR if present - if len(line) > 0 && line[len(line)-1] == '\r' { - line = line[:len(line)-1] - } - return line, nil + if r.ShouldRewrite(req, srcAddr, destAddr) { + req = r.Rewrite(req, srcAddr, destAddr) } - - line = append(line, data...) - offset += len(data) - - // If EOF reached without finding a newline, return what we have - if errors.Is(err, io.EOF) { - return line, io.EOF + if err = r.Forward(dst, req); err != nil { + err = fmt.Errorf("r.forward: %w", err) + return + } + if req.Header.Get("Upgrade") == "websocket" && req.Header.Get("Connection") == "Upgrade" { + log.LogInfoWithAddr(srcAddr, destAddr, "WebSocket Upgrade detected, switching to raw proxy") + return } } } - -// peekLineString reads a line from bufio.Reader without consuming it. -// returns the line string (without CRLF) or error. -func peekLineString(br *bufio.Reader) (string, error) { - lineBytes, err := peekLineSlice(br) - if err != nil { - return "", err - } - return string(lineBytes), nil -} - -// parseRequestLine parses "GET /foo HTTP/1.1" into its three parts. -func parseRequestLine(line string) (method, requestURI, proto string, ok bool) { - method, rest, ok1 := strings.Cut(line, " ") - requestURI, proto, ok2 := strings.Cut(rest, " ") - if !ok1 || !ok2 { - return "", "", "", false - } - return method, requestURI, proto, true -} - -func isWebSocket(header []byte) bool { - if len(header) < 2 { - return false - } - - b0 := header[0] - b1 := header[1] - - rsv := b0 & 0x70 // RSV1-3 - opcode := b0 & 0x0F // opcode - mask := b1 & 0x80 // MASK - - // requested frames from client to server must be masked - if mask == 0 { - return false - } - // Control frames must have FIN set - if rsv != 0 { - return false - } - // opcode must be in valid range - if opcode > 0xA { - return false - } - // payload length - payloadLen := b1 & 0x7F - if payloadLen > 0 && payloadLen <= 125 { - return true - } - - return true -} - -func isTLSClientHello(reader *bufio.Reader) bool { - header, err := reader.Peek(3) - if err != nil { - return false - } - // TLS record type 0x16 = Handshake - if header[0] != 0x16 { - return false - } - // TLS version - versionMajor := header[1] - versionMinor := header[2] - if versionMajor != 0x03 { - return false - } - if versionMinor < 0x01 || versionMinor > 0x04 { - return false - } - - return true -} diff --git a/src/internal/rewrite/rewriter_test.go b/src/internal/rewrite/rewriter_test.go index 9d19d1e..50d3962 100644 --- a/src/internal/rewrite/rewriter_test.go +++ b/src/internal/rewrite/rewriter_test.go @@ -1,7 +1,6 @@ package rewrite import ( - "bufio" "bytes" "io" "net" @@ -46,35 +45,11 @@ func TestNewRewriter(t *testing.T) { assert.NoError(t, err) assert.Equal(t, cfg.PayloadUA, rewriter.payloadUA) assert.Equal(t, cfg.UAPattern, rewriter.pattern) - assert.Equal(t, cfg.EnablePartialReplace, rewriter.enablePartialReplace) + assert.Equal(t, cfg.EnablePartialReplace, rewriter.partialReplace) assert.NotNil(t, rewriter.uaRegex) assert.NotNil(t, rewriter.Cache) } -func TestIsHTTP(t *testing.T) { - r := newTestRewriter(t) - - tests := []struct { - name string - input string - expected bool - }{ - {"HTTP Get", "GET / HTTP/1.1\r\n", true}, - {"HTTP Post", "POST /test HTTP/1.1\r\n", true}, - {"HTTP Connect", "CONNECT example.com:443 HTTP/1.1\r\n", true}, - {"Not HTTP", "HELLO WORLD\r\n", false}, - {"Not HTTP", "SSH-2.0-OpenSSH_8.4\r\n", false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - reader := bufio.NewReader(strings.NewReader(tt.input)) - isHTTP, _ := r.isHTTP(reader) - assert.Equal(t, tt.expected, isHTTP) - }) - } -} - func TestProxyHTTPOrRaw_HTTPRewrite(t *testing.T) { r := newTestRewriter(t) @@ -83,7 +58,7 @@ func TestProxyHTTPOrRaw_HTTPRewrite(t *testing.T) { dstBuf := &bytes.Buffer{} dst := &mockConn{Reader: nil, Writer: dstBuf} - r.ProxyHTTPOrRaw(dst, src, "example.com:80", "srcAddr") + r.Process(dst, src, "example.com:80", "srcAddr") out := dstBuf.String() assert.Contains(t, out, "User-Agent: MockUA/1.0") diff --git a/src/internal/server/http/http.go b/src/internal/server/http/http.go index a6b8851..5255764 100644 --- a/src/internal/server/http/http.go +++ b/src/internal/server/http/http.go @@ -58,7 +58,7 @@ func (s *Server) handleHTTP(w http.ResponseWriter, req *http.Request) { } defer target.Close() - _, err = s.rw.RewriteAndForward(target, req, req.Host, req.RemoteAddr) + err = s.rewriteAndForward(target, req, req.Host, req.RemoteAddr) if err != nil { http.Error(w, err.Error(), http.StatusServiceUnavailable) return @@ -101,6 +101,18 @@ func (s *Server) handleTunneling(w http.ResponseWriter, req *http.Request) { s.ForwardTCP(client, dest, destAddr) } +func (s *Server) rewriteAndForward(target net.Conn, req *http.Request, dstAddr, srcAddr string) (err error) { + rw := s.rw + if rw.ShouldRewrite(req, srcAddr, dstAddr) { + req = rw.Rewrite(req, srcAddr, dstAddr) + } + if err = rw.Forward(target, req); err != nil { + err = fmt.Errorf("r.forward: %w", err) + return + } + return nil +} + func (s *Server) HandleClient(client net.Conn) { } diff --git a/src/internal/server/utils/tcp.go b/src/internal/server/utils/tcp.go index 0ad90fd..09890b7 100644 --- a/src/internal/server/utils/tcp.go +++ b/src/internal/server/utils/tcp.go @@ -60,7 +60,7 @@ func ProxyHalf(dst, src net.Conn, rw *rewrite.Rewriter, destAddr string) { io.Copy(dst, src) return } - _ = rw.ProxyHTTPOrRaw(dst, src, destAddr, srcAddr) + _ = rw.Process(dst, src, destAddr, srcAddr) } func GetConnFD(conn net.Conn) (fd int, err error) { diff --git a/src/internal/sniff/common.go b/src/internal/sniff/common.go new file mode 100644 index 0000000..87a1dbc --- /dev/null +++ b/src/internal/sniff/common.go @@ -0,0 +1,75 @@ +package sniff + +import ( + "bufio" + "bytes" + "errors" + "io" + "time" +) + +var ErrPeekTimeout = errors.New("peek timeout") + +// peekLineSlice reads a line from bufio.Reader without consuming it. +// returns the line bytes (without CRLF) or error. +func peekLineSlice(br *bufio.Reader, maxSize int) ([]byte, error) { + var line []byte + + peekSize := maxSize + if peekSize == 0 { + return nil, io.EOF + } + if buffered := br.Buffered(); buffered < peekSize { + peekSize = buffered + } + + buf, err := br.Peek(peekSize) + if err != nil { + return nil, err + } + + if i := bytes.IndexByte(buf, '\n'); i >= 0 { + line = append(line, buf[:i]...) + // Remove trailing CR if present + if len(line) > 0 && line[len(line)-1] == '\r' { + line = line[:len(line)-1] + } + return line, nil + } + return nil, io.EOF +} + +// peekLineString reads a line from bufio.Reader without consuming it. +// returns the line string (without CRLF) or error. +func peekLineString(br *bufio.Reader, maxSize int) (string, error) { + lineBytes, err := peekLineSlice(br, maxSize) + if err != nil { + return "", err + } + return string(lineBytes), nil +} + +// PeekWithTimeout peeks n bytes from bufio.Reader with a timeout. +func PeekWithTimeout(r *bufio.Reader, n int, timeout time.Duration) ([]byte, error) { + if buffered := r.Buffered(); buffered >= n { + data, err := r.Peek(n) + return data, err + } + type result struct { + data []byte + err error + } + ch := make(chan result, 1) + + go func() { + data, err := r.Peek(n) + ch <- result{data, err} + }() + + select { + case res := <-ch: + return res.data, res.err + case <-time.After(timeout): + return nil, ErrPeekTimeout + } +} diff --git a/src/internal/sniff/http.go b/src/internal/sniff/http.go new file mode 100644 index 0000000..eb29db6 --- /dev/null +++ b/src/internal/sniff/http.go @@ -0,0 +1,75 @@ +package sniff + +import ( + "bufio" + "strings" + "time" +) + +// HTTP methods used to detect HTTP by request line. +var methods = [...]string{"GET", "POST", "HEAD", "CONNECT", "PUT", "DELETE", "OPTIONS", "PATCH", "TRACE"} + +// parseRequestLine parses "GET /foo HTTP/1.1" into its three parts. +func parseRequestLine(line string) (method, requestURI, proto string, ok bool) { + method, rest, ok1 := strings.Cut(line, " ") + requestURI, proto, ok2 := strings.Cut(rest, " ") + if !ok1 || !ok2 { + return "", "", "", false + } + return method, requestURI, proto, true +} + +// beginWithHTTPMethod peeks the first few bytes to check for known HTTP method prefixes. +func beginWithHTTPMethod(reader *bufio.Reader) (bool, error) { + const maxMethodLen = 7 + const minMethodLen = 3 + var hint []byte + hint, err := PeekWithTimeout(reader, maxMethodLen, 3*time.Second) + if err != nil { + if err != ErrPeekTimeout { + return false, err + } + hint, err = PeekWithTimeout(reader, minMethodLen+1, time.Second) + if err != nil { + return false, err + } + } + method, _, _ := strings.Cut(string(hint), " ") + for _, m := range methods { + if method == m { + return true, nil + } + } + return false, nil +} + +// SniffHTTP peeks the first few bytes and checks for a known HTTP method prefix. +func SniffHTTP(reader *bufio.Reader) (bool, error) { + // Fast check: peek first word to see if it's a known HTTP method + beginHTTP, err := beginWithHTTPMethod(reader) + if err != nil { + return false, err + } + + // Detailed check: parse request line + line, err := peekLineString(reader, 128) + if err != nil { + return beginHTTP, nil + } + _, _, proto, ok := parseRequestLine(line) + if !ok { + return beginHTTP, nil + } + if proto != "HTTP/1.1" && proto != "HTTP/1.0" { + return false, nil + } + return beginHTTP, nil +} + +func SniffHTTPFast(reader *bufio.Reader) (bool, error) { + beginHTTP, err := beginWithHTTPMethod(reader) + if err != nil { + return false, err + } + return beginHTTP, nil +} diff --git a/src/internal/sniff/tls.go b/src/internal/sniff/tls.go new file mode 100644 index 0000000..83042fe --- /dev/null +++ b/src/internal/sniff/tls.go @@ -0,0 +1,25 @@ +package sniff + +import "bufio" + +func SniffTLSClientHello(reader *bufio.Reader) bool { + header, err := reader.Peek(3) + if err != nil { + return false + } + // TLS record type 0x16 = Handshake + if header[0] != 0x16 { + return false + } + // TLS version + versionMajor := header[1] + versionMinor := header[2] + if versionMajor != 0x03 { + return false + } + if versionMinor < 0x01 || versionMinor > 0x04 { + return false + } + + return true +} diff --git a/src/internal/sniff/websocket.go b/src/internal/sniff/websocket.go new file mode 100644 index 0000000..572005d --- /dev/null +++ b/src/internal/sniff/websocket.go @@ -0,0 +1,34 @@ +package sniff + +func SniffWebSocket(header []byte) bool { + if len(header) < 2 { + return false + } + + b0 := header[0] + b1 := header[1] + + rsv := b0 & 0x70 // RSV1-3 + opcode := b0 & 0x0F // opcode + mask := b1 & 0x80 // MASK + + // requested frames from client to server must be masked + if mask == 0 { + return false + } + // Control frames must have FIN set + if rsv != 0 { + return false + } + // opcode must be in valid range + if opcode > 0xA { + return false + } + // payload length + payloadLen := b1 & 0x7F + if payloadLen > 0 && payloadLen <= 125 { + return true + } + + return true +}