refactor: streamline rewriting

This commit is contained in:
SunBK201 2025-11-04 22:18:06 +08:00
parent 57a344e563
commit ab0cf0efdc
8 changed files with 353 additions and 332 deletions

View File

@ -2,8 +2,6 @@ package rewrite
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"net"
@ -17,44 +15,19 @@ import (
"github.com/sunbk201/ua3f/internal/config"
"github.com/sunbk201/ua3f/internal/log"
"github.com/sunbk201/ua3f/internal/sniff"
"github.com/sunbk201/ua3f/internal/statistics"
)
const (
ErrUseClosedConn = "use of closed network connection"
ErrConnResetByPeer = "connection reset by peer"
ErrIOTimeout = "i/o timeout"
)
// HTTP methods used to detect HTTP by request line.
var httpMethods = map[string]struct{}{
"GET": {},
"POST": {},
"HEAD": {},
"PUT": {},
"PATCH": {},
"DELETE": {},
"OPTIONS": {},
"TRACE": {},
"CONNECT": {},
}
// Hardcoded whitelist of UAs that should be left untouched.
var defaultWhitelist = []string{
"MicroMessenger Client",
"ByteDancePcdn",
"Go-http-client/1.1",
}
// Rewriter encapsulates HTTP UA rewrite behavior and pass-through cache.
type Rewriter struct {
payloadUA string
pattern string
enablePartialReplace bool
payloadUA string
pattern string
partialReplace bool
uaRegex *regexp2.Regexp
Cache *expirable.LRU[string, string]
whitelist map[string]struct{}
whitelist []string
Cache *expirable.LRU[string, struct{}]
}
// New constructs a Rewriter from config. Compiles regex and allocates cache.
@ -66,167 +39,33 @@ func New(cfg *config.Config) (*Rewriter, error) {
return nil, err
}
cache := expirable.NewLRU[string, string](1024, nil, 10*time.Minute)
whitelist := make(map[string]struct{}, len(defaultWhitelist))
for _, s := range defaultWhitelist {
whitelist[s] = struct{}{}
}
return &Rewriter{
payloadUA: cfg.PayloadUA,
pattern: cfg.UAPattern,
enablePartialReplace: cfg.EnablePartialReplace,
uaRegex: uaRegex,
Cache: cache,
whitelist: whitelist,
payloadUA: cfg.PayloadUA,
pattern: cfg.UAPattern,
partialReplace: cfg.EnablePartialReplace,
uaRegex: uaRegex,
Cache: expirable.NewLRU[string, struct{}](1024, nil, 30*time.Minute),
whitelist: []string{
"MicroMessenger Client",
"Bilibili Freedoooooom/MarkII",
"Go-http-client/1.1",
"ByteDancePcdn",
},
}, nil
}
// RewriteAndForward rewrites the User-Agent header if needed and forwards the request.
// Returns pass=true if the request has been forwarded as-is (no rewrite).
func (r *Rewriter) RewriteAndForward(dst net.Conn, req *http.Request, destAddr string, srcAddr string) (pass bool, err error) {
originalUA := req.Header.Get("User-Agent")
// No UA header: pass-through after writing this first request
if originalUA == "" {
r.Cache.Add(destAddr, destAddr)
log.LogDebugWithAddr(srcAddr, destAddr, "Not found User-Agent, Add LRU Relay Cache")
if err = req.Write(dst); err != nil {
err = fmt.Errorf("req.Write: %w", err)
}
pass = true
return
}
log.LogInfoWithAddr(srcAddr, destAddr, fmt.Sprintf("Original User-Agent: %s", originalUA))
isWhitelist := r.inWhitelist(originalUA)
matches := true
if r.pattern != "" {
matches, err = r.uaRegex.MatchString(originalUA)
if err != nil {
log.LogErrorWithAddr(srcAddr, destAddr, fmt.Sprintf("User-Agent Regex Pattern Match Error: %s", err.Error()))
matches = true
func (r *Rewriter) inWhitelist(ua string) bool {
for _, w := range r.whitelist {
if w == ua {
return true
}
}
// If UA is whitelisted or does not match target pattern, write once then pass-through.
if isWhitelist || !matches {
if !matches {
log.LogDebugWithAddr(srcAddr, destAddr, fmt.Sprintf("Not Hit User-Agent Regex: %s", originalUA))
}
if isWhitelist {
log.LogInfoWithAddr(srcAddr, destAddr, fmt.Sprintf("Hit User-Agent Whitelist: %s", originalUA))
r.Cache.Add(destAddr, destAddr)
}
statistics.AddPassThroughRecord(&statistics.PassThroughRecord{
Host: destAddr,
UA: originalUA,
})
if err = req.Write(dst); err != nil {
err = fmt.Errorf("req.Write: %w", err)
}
pass = true
return
}
// Rewrite UA and forward the request (including body)
rewritedUA := r.buildNewUA(originalUA)
log.LogInfoWithAddr(srcAddr, destAddr, fmt.Sprintf("Rewrite User-Agent from (%s) to (%s)", originalUA, rewritedUA))
req.Header.Set("User-Agent", rewritedUA)
if err = req.Write(dst); err != nil {
err = fmt.Errorf("req.Write: %w", err)
pass = true
return
}
statistics.AddRewriteRecord(&statistics.RewriteRecord{
Host: destAddr,
OriginalUA: originalUA,
MockedUA: rewritedUA,
})
return false, nil
return false
}
// ProxyHTTPOrRaw reads traffic from src and writes to dst.
// - If target in LRU cache: pass-through (raw).
// - Else if HTTP: rewrite UA (unless whitelisted or pattern not matched).
// - Else: mark target in LRU and pass-through.
func (r *Rewriter) ProxyHTTPOrRaw(dst net.Conn, src net.Conn, destAddr string, srcAddr string) (err error) {
reader := bufio.NewReader(src)
defer func() {
if err != nil {
log.LogDebugWithAddr(srcAddr, destAddr, fmt.Sprintf("ProxyHTTPOrRaw: %s", err.Error()))
}
io.Copy(dst, reader)
}()
if strings.HasSuffix(destAddr, "443") && isTLSClientHello(reader) {
r.Cache.Add(destAddr, destAddr)
log.LogInfoWithAddr(srcAddr, destAddr, "tls client hello detected, pass forward")
return
}
isHTTP, err := r.isHTTP(reader)
if err != nil {
err = fmt.Errorf("isHTTP: %w", err)
return
}
if !isHTTP {
r.Cache.Add(destAddr, destAddr)
log.LogInfoWithAddr(srcAddr, destAddr, "Not HTTP, added to LRU Relay Cache")
return
}
var req *http.Request
var pass bool
// HTTP request loop (handles keep-alive)
for {
if isHTTP, err = r.isHTTP(reader); err != nil {
err = fmt.Errorf("isHTTP: %w", err)
return
}
if !isHTTP {
h2, _ := reader.Peek(2) // ensure we have at least 2 bytes
if isWebSocket(h2) {
log.LogInfoWithAddr(srcAddr, destAddr, "WebSocket detected, pass-through")
} else {
r.Cache.Add(destAddr, destAddr)
log.LogInfoWithAddr(srcAddr, destAddr, "Not HTTP, added to LRU Relay Cache")
}
return
}
if req, err = http.ReadRequest(reader); err != nil {
err = fmt.Errorf("http.ReadRequest: %w", err)
return
}
if pass, err = r.RewriteAndForward(dst, req, destAddr, srcAddr); pass {
return
}
}
}
// isHTTP peeks the first few bytes and checks for a known HTTP method prefix.
func (r *Rewriter) isHTTP(reader *bufio.Reader) (bool, error) {
// Fast check: peek first word to see if it's a known HTTP method
const maxMethodLen = 7
hintSlice, err := reader.Peek(maxMethodLen)
if err != nil {
return false, err
}
hint := string(hintSlice)
method, _, _ := strings.Cut(hint, " ")
_, exists := httpMethods[method]
return exists, nil
}
// buildNewUA returns either a partial replacement (regex) or full overwrite.
func (r *Rewriter) buildNewUA(originUA string) string {
if r.enablePartialReplace && r.uaRegex != nil && r.pattern != "" {
// buildUserAgent returns either a partial replacement (regex) or full overwrite.
func (r *Rewriter) buildUserAgent(originUA string) string {
if r.partialReplace && r.uaRegex != nil && r.pattern != "" {
newUA, err := r.uaRegex.Replace(originUA, r.payloadUA, -1, -1)
if err != nil {
logrus.Errorf("User-Agent Replace Error: %s, use full overwrite", err.Error())
@ -237,131 +76,117 @@ func (r *Rewriter) buildNewUA(originUA string) string {
return r.payloadUA
}
func (r *Rewriter) inWhitelist(ua string) bool {
_, ok := r.whitelist[ua]
return ok
// ShouldRewrite reports whether the User-Agent of req has to be replaced.
//
// A request is rewritten only when its User-Agent is not whitelisted and,
// if a pattern is configured, the pattern matches it. A regex evaluation
// error is logged and treated as a match. A whitelist hit also records
// destAddr in the cache. Every request that is not rewritten is reported to
// the pass-through statistics.
func (r *Rewriter) ShouldRewrite(req *http.Request, srcAddr, destAddr string) bool {
	ua := req.Header.Get("User-Agent")
	log.LogInfoWithAddr(srcAddr, destAddr, fmt.Sprintf("Original User-Agent: (%s)", ua))

	whitelisted := r.inWhitelist(ua)

	// The regex is only consulted for non-whitelisted agents, so matched
	// stays true whenever whitelisted is true.
	matched := true
	if !whitelisted && r.pattern != "" {
		ok, matchErr := r.uaRegex.MatchString(ua)
		if matchErr != nil {
			log.LogErrorWithAddr(srcAddr, destAddr, fmt.Sprintf("User-Agent Regex Match Error: %s", matchErr.Error()))
			ok = true
		}
		matched = ok
	}

	switch {
	case whitelisted:
		log.LogInfoWithAddr(srcAddr, destAddr, fmt.Sprintf("Hit User-Agent Whitelist: %s", ua))
		r.Cache.Add(destAddr, struct{}{})
	case !matched:
		log.LogDebugWithAddr(srcAddr, destAddr, fmt.Sprintf("Not Hit User-Agent Regex: %s", ua))
	}

	if whitelisted || !matched {
		record := statistics.PassThroughRecord{
			Host: destAddr,
			UA:   ua,
		}
		statistics.AddPassThroughRecord(&record)
		return false
	}
	return true
}
// peekLineSlice reads a line from bufio.Reader without consuming it.
// returns the line bytes (without CRLF) or error.
func peekLineSlice(br *bufio.Reader) ([]byte, error) {
const chunkSize = 256
var line []byte
// Rewrite replaces the User-Agent header of req with the value produced by
// buildUserAgent, logs the change and records it in the rewrite statistics.
// The request is modified in place and the same pointer is returned.
func (r *Rewriter) Rewrite(req *http.Request, srcAddr string, destAddr string) *http.Request {
	before := req.Header.Get("User-Agent")
	after := r.buildUserAgent(before)
	req.Header.Set("User-Agent", after)

	message := fmt.Sprintf("Rewrite User-Agent from (%s) to (%s)", before, after)
	log.LogInfoWithAddr(srcAddr, destAddr, message)

	record := statistics.RewriteRecord{
		Host:       destAddr,
		OriginalUA: before,
		MockedUA:   after,
	}
	statistics.AddRewriteRecord(&record)
	return req
}
// Forward writes req to dst in HTTP/1.x wire format and then closes the
// request body.
//
// A write failure is returned wrapped with the "req.Write" prefix; in that
// case the body is left as the failed write left it, so no extra reads from
// the client are triggered here.
func (r *Rewriter) Forward(dst net.Conn, req *http.Request) error {
	if err := req.Write(dst); err != nil {
		return fmt.Errorf("req.Write: %w", err)
	}
	// Body is nil for requests constructed without one (for example
	// http.NewRequest with a nil body); calling Close on it would panic.
	if req.Body != nil {
		// Best effort: the request is already on the wire, so a Close error
		// is not actionable for the caller.
		_ = req.Body.Close()
	}
	return nil
}
// Process handles the proxying with UA rewriting logic.
func (r *Rewriter) Process(dst net.Conn, src net.Conn, destAddr string, srcAddr string) (err error) {
reader := bufio.NewReader(src)
defer func() {
if err != nil {
log.LogDebugWithAddr(srcAddr, destAddr, fmt.Sprintf("Process: %s", err.Error()))
}
io.Copy(dst, reader)
}()
if strings.HasSuffix(destAddr, "443") && sniff.SniffTLSClientHello(reader) {
r.Cache.Add(destAddr, struct{}{})
log.LogInfoWithAddr(srcAddr, destAddr, "tls client hello detected, pass forward")
return
}
var isHTTP bool
if isHTTP, err = sniff.SniffHTTP(reader); err != nil {
err = fmt.Errorf("sniff.SniffHTTP: %w", err)
return
}
if !isHTTP {
r.Cache.Add(destAddr, struct{}{})
log.LogInfoWithAddr(srcAddr, destAddr, "Not HTTP, added to cache")
return
}
var req *http.Request
offset := 0
for {
// Ensure there is data in the buffer
n := br.Buffered()
if n == 0 {
// No data in buffer, try to fill it
_, err := br.Peek(1)
if err != nil {
return nil, err
}
n = br.Buffered()
if isHTTP, err = sniff.SniffHTTPFast(reader); err != nil {
err = fmt.Errorf("isHTTP: %w", err)
return
}
// Limit to chunkSize
if n > chunkSize {
n = chunkSize
if !isHTTP {
r.Cache.Add(destAddr, struct{}{})
log.LogInfoWithAddr(srcAddr, destAddr, "Not HTTP, added to LRU Relay Cache")
return
}
buf, err := br.Peek(offset + n)
if err != nil && !errors.Is(err, bufio.ErrBufferFull) && !errors.Is(err, io.EOF) {
return nil, err
if req, err = http.ReadRequest(reader); err != nil {
err = fmt.Errorf("http.ReadRequest: %w", err)
return
}
data := buf[offset:]
if i := bytes.IndexByte(data, '\n'); i >= 0 {
line = append(line, data[:i]...)
// Remove trailing CR if present
if len(line) > 0 && line[len(line)-1] == '\r' {
line = line[:len(line)-1]
}
return line, nil
if r.ShouldRewrite(req, srcAddr, destAddr) {
req = r.Rewrite(req, srcAddr, destAddr)
}
line = append(line, data...)
offset += len(data)
// If EOF reached without finding a newline, return what we have
if errors.Is(err, io.EOF) {
return line, io.EOF
if err = r.Forward(dst, req); err != nil {
err = fmt.Errorf("r.forward: %w", err)
return
}
if req.Header.Get("Upgrade") == "websocket" && req.Header.Get("Connection") == "Upgrade" {
log.LogInfoWithAddr(srcAddr, destAddr, "WebSocket Upgrade detected, switching to raw proxy")
return
}
}
}
// peekLineString reads a line from bufio.Reader without consuming it.
// returns the line string (without CRLF) or error.
func peekLineString(br *bufio.Reader) (string, error) {
lineBytes, err := peekLineSlice(br)
if err != nil {
return "", err
}
return string(lineBytes), nil
}
// parseRequestLine parses "GET /foo HTTP/1.1" into its three parts.
func parseRequestLine(line string) (method, requestURI, proto string, ok bool) {
method, rest, ok1 := strings.Cut(line, " ")
requestURI, proto, ok2 := strings.Cut(rest, " ")
if !ok1 || !ok2 {
return "", "", "", false
}
return method, requestURI, proto, true
}
func isWebSocket(header []byte) bool {
if len(header) < 2 {
return false
}
b0 := header[0]
b1 := header[1]
rsv := b0 & 0x70 // RSV1-3
opcode := b0 & 0x0F // opcode
mask := b1 & 0x80 // MASK
// requested frames from client to server must be masked
if mask == 0 {
return false
}
// Control frames must have FIN set
if rsv != 0 {
return false
}
// opcode must be in valid range
if opcode > 0xA {
return false
}
// payload length
payloadLen := b1 & 0x7F
if payloadLen > 0 && payloadLen <= 125 {
return true
}
return true
}
func isTLSClientHello(reader *bufio.Reader) bool {
header, err := reader.Peek(3)
if err != nil {
return false
}
// TLS record type 0x16 = Handshake
if header[0] != 0x16 {
return false
}
// TLS version
versionMajor := header[1]
versionMinor := header[2]
if versionMajor != 0x03 {
return false
}
if versionMinor < 0x01 || versionMinor > 0x04 {
return false
}
return true
}

View File

@ -1,7 +1,6 @@
package rewrite
import (
"bufio"
"bytes"
"io"
"net"
@ -46,35 +45,11 @@ func TestNewRewriter(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, cfg.PayloadUA, rewriter.payloadUA)
assert.Equal(t, cfg.UAPattern, rewriter.pattern)
assert.Equal(t, cfg.EnablePartialReplace, rewriter.enablePartialReplace)
assert.Equal(t, cfg.EnablePartialReplace, rewriter.partialReplace)
assert.NotNil(t, rewriter.uaRegex)
assert.NotNil(t, rewriter.Cache)
}
func TestIsHTTP(t *testing.T) {
r := newTestRewriter(t)
tests := []struct {
name string
input string
expected bool
}{
{"HTTP Get", "GET / HTTP/1.1\r\n", true},
{"HTTP Post", "POST /test HTTP/1.1\r\n", true},
{"HTTP Connect", "CONNECT example.com:443 HTTP/1.1\r\n", true},
{"Not HTTP", "HELLO WORLD\r\n", false},
{"Not HTTP", "SSH-2.0-OpenSSH_8.4\r\n", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reader := bufio.NewReader(strings.NewReader(tt.input))
isHTTP, _ := r.isHTTP(reader)
assert.Equal(t, tt.expected, isHTTP)
})
}
}
func TestProxyHTTPOrRaw_HTTPRewrite(t *testing.T) {
r := newTestRewriter(t)
@ -83,7 +58,7 @@ func TestProxyHTTPOrRaw_HTTPRewrite(t *testing.T) {
dstBuf := &bytes.Buffer{}
dst := &mockConn{Reader: nil, Writer: dstBuf}
r.ProxyHTTPOrRaw(dst, src, "example.com:80", "srcAddr")
r.Process(dst, src, "example.com:80", "srcAddr")
out := dstBuf.String()
assert.Contains(t, out, "User-Agent: MockUA/1.0")

View File

@ -58,7 +58,7 @@ func (s *Server) handleHTTP(w http.ResponseWriter, req *http.Request) {
}
defer target.Close()
_, err = s.rw.RewriteAndForward(target, req, req.Host, req.RemoteAddr)
err = s.rewriteAndForward(target, req, req.Host, req.RemoteAddr)
if err != nil {
http.Error(w, err.Error(), http.StatusServiceUnavailable)
return
@ -101,6 +101,18 @@ func (s *Server) handleTunneling(w http.ResponseWriter, req *http.Request) {
s.ForwardTCP(client, dest, destAddr)
}
// rewriteAndForward applies the server's rewriter to req (replacing the
// User-Agent when ShouldRewrite says so) and writes the resulting request to
// target. A forwarding failure is returned wrapped with the "r.forward"
// prefix.
func (s *Server) rewriteAndForward(target net.Conn, req *http.Request, dstAddr, srcAddr string) (err error) {
	rewriter := s.rw

	outgoing := req
	if rewriter.ShouldRewrite(outgoing, srcAddr, dstAddr) {
		outgoing = rewriter.Rewrite(outgoing, srcAddr, dstAddr)
	}

	if fwdErr := rewriter.Forward(target, outgoing); fwdErr != nil {
		return fmt.Errorf("r.forward: %w", fwdErr)
	}
	return nil
}
func (s *Server) HandleClient(client net.Conn) {
}

View File

@ -60,7 +60,7 @@ func ProxyHalf(dst, src net.Conn, rw *rewrite.Rewriter, destAddr string) {
io.Copy(dst, src)
return
}
_ = rw.ProxyHTTPOrRaw(dst, src, destAddr, srcAddr)
_ = rw.Process(dst, src, destAddr, srcAddr)
}
func GetConnFD(conn net.Conn) (fd int, err error) {

View File

@ -0,0 +1,75 @@
package sniff
import (
"bufio"
"bytes"
"errors"
"io"
"time"
)
var ErrPeekTimeout = errors.New("peek timeout")
// peekLineSlice returns a copy of the first line that is already sitting in
// br's buffer, without consuming it and without reading from the underlying
// source.
//
// At most maxSize bytes of buffered data are inspected. The returned line has
// its trailing LF and, if present, one preceding CR removed. io.EOF is
// returned when maxSize is 0 or when no newline is found in the inspected
// window (including the case where nothing is buffered yet); any error from
// Peek is returned unchanged.
func peekLineSlice(br *bufio.Reader, maxSize int) ([]byte, error) {
	if maxSize == 0 {
		return nil, io.EOF
	}

	// Never ask Peek for more than is buffered, so it cannot block on I/O.
	limit := br.Buffered()
	if maxSize < limit {
		limit = maxSize
	}

	window, err := br.Peek(limit)
	if err != nil {
		return nil, err
	}

	end := bytes.IndexByte(window, '\n')
	if end < 0 {
		return nil, io.EOF
	}

	// Copy out of the reader's buffer: the peeked slice is only valid until
	// the next read on br.
	trimmed := bytes.TrimSuffix(window[:end], []byte{'\r'})
	return append([]byte(nil), trimmed...), nil
}
// peekLineString is the string-returning variant of peekLineSlice: it yields
// the first buffered line of br (CRLF/LF stripped) without consuming it,
// looking at no more than maxSize bytes. On error the string is empty.
func peekLineString(br *bufio.Reader, maxSize int) (string, error) {
	raw, err := peekLineSlice(br, maxSize)
	if err == nil {
		return string(raw), nil
	}
	return "", err
}
// PeekWithTimeout returns the next n bytes of r without consuming them,
// waiting at most timeout for them to arrive.
//
// If at least n bytes are already buffered the peek is served directly.
// Otherwise the blocking Peek runs in a helper goroutine and its result
// (data and error, exactly as returned by Peek) is handed back; if it does
// not finish within timeout, ErrPeekTimeout is returned with a nil slice.
//
// NOTE(review): after a timeout the helper goroutine is still blocked inside
// r.Peek and only exits once the underlying read returns. Until then it
// accesses r concurrently with whatever the caller does next; r has no
// internal locking visible here, so callers should confirm they do not touch
// r (or that the source gets closed) in that window.
func PeekWithTimeout(r *bufio.Reader, n int, timeout time.Duration) ([]byte, error) {
	// Fast path: Peek cannot block when the data is already buffered.
	if r.Buffered() >= n {
		return r.Peek(n)
	}

	type result struct {
		data []byte
		err  error
	}
	// Capacity 1 lets the helper deliver its result and exit even when the
	// caller has already given up.
	ch := make(chan result, 1)
	go func() {
		data, err := r.Peek(n)
		ch <- result{data, err}
	}()

	// An explicit timer (instead of time.After) is released as soon as the
	// peek wins the race.
	timer := time.NewTimer(timeout)
	defer timer.Stop()

	select {
	case res := <-ch:
		return res.data, res.err
	case <-timer.C:
		return nil, ErrPeekTimeout
	}
}

View File

@ -0,0 +1,75 @@
package sniff
import (
"bufio"
"strings"
"time"
)
// HTTP methods used to detect HTTP by request line.
var methods = [...]string{"GET", "POST", "HEAD", "CONNECT", "PUT", "DELETE", "OPTIONS", "PATCH", "TRACE"}
// parseRequestLine splits an HTTP request line such as "GET /foo HTTP/1.1"
// at its first two spaces into method, request URI and protocol.
// ok is false, and all three strings are empty, when the line contains fewer
// than two spaces. Anything after the second space, further spaces included,
// is returned as part of proto.
func parseRequestLine(line string) (method, requestURI, proto string, ok bool) {
	fields := strings.SplitN(line, " ", 3)
	if len(fields) != 3 {
		return "", "", "", false
	}
	return fields[0], fields[1], fields[2], true
}
// beginWithHTTPMethod reports whether the data at the head of reader starts
// with one of the known HTTP method names.
//
// It first waits up to three seconds for 7 bytes (the length of the longest
// method name). If that times out it retries for one more second with 4
// bytes, enough for a three-letter method plus the following space. Any
// other peek error, or a failure of the retry, is returned to the caller.
// The token before the first space of the peeked bytes (or all of them if
// there is no space) must equal a known method exactly.
func beginWithHTTPMethod(reader *bufio.Reader) (bool, error) {
	const (
		longestMethod  = 7
		shortestMethod = 3
	)

	prefix, err := PeekWithTimeout(reader, longestMethod, 3*time.Second)
	switch {
	case err == nil:
		// Got the full-size hint on the first attempt.
	case err == ErrPeekTimeout:
		// The peer may have sent a very short first packet; settle for less.
		prefix, err = PeekWithTimeout(reader, shortestMethod+1, time.Second)
		if err != nil {
			return false, err
		}
	default:
		return false, err
	}

	token, _, _ := strings.Cut(string(prefix), " ")
	for i := range methods {
		if methods[i] == token {
			return true, nil
		}
	}
	return false, nil
}
// SniffHTTP reports whether the data at the head of reader looks like an
// HTTP/1.x request, without consuming anything.
//
// The verdict starts from the method-prefix check. If, in addition, a
// complete request line is available within the first 128 buffered bytes and
// splits into three parts, its protocol must be "HTTP/1.1" or "HTTP/1.0";
// otherwise the result is false. When no such line can be inspected, the
// method-prefix verdict stands. Only errors from the method-prefix check are
// returned.
func SniffHTTP(reader *bufio.Reader) (bool, error) {
	verdict, err := beginWithHTTPMethod(reader)
	if err != nil {
		return false, err
	}

	// Refinement is best effort: a missing or unparsable request line
	// leaves the verdict unchanged.
	if requestLine, lineErr := peekLineString(reader, 128); lineErr == nil {
		if _, _, proto, ok := parseRequestLine(requestLine); ok {
			switch proto {
			case "HTTP/1.1", "HTTP/1.0":
				// Supported version; keep the verdict.
			default:
				verdict = false
			}
		}
	}
	return verdict, nil
}
// SniffHTTPFast reports whether the data at the head of reader begins with a
// known HTTP method name. Unlike SniffHTTP it does not look at the rest of
// the request line. On error the boolean result is always false.
func SniffHTTPFast(reader *bufio.Reader) (bool, error) {
	startsWithMethod, err := beginWithHTTPMethod(reader)
	return startsWithMethod && err == nil, err
}

25
src/internal/sniff/tls.go Normal file
View File

@ -0,0 +1,25 @@
package sniff
import "bufio"
// SniffTLSClientHello reports whether the next three bytes of reader look
// like the start of a TLS handshake record: content type 0x16 followed by a
// record version with major byte 0x03 and minor byte in 0x01..0x04.
// Nothing is consumed. If three bytes cannot be peeked the result is false.
func SniffTLSClientHello(reader *bufio.Reader) bool {
	const (
		recordTypeHandshake = 0x16
		tlsMajor            = 0x03
		tlsMinorLow         = 0x01
		tlsMinorHigh        = 0x04
	)

	hdr, err := reader.Peek(3)
	if err != nil {
		return false
	}

	contentType, major, minor := hdr[0], hdr[1], hdr[2]
	return contentType == recordTypeHandshake &&
		major == tlsMajor &&
		minor >= tlsMinorLow && minor <= tlsMinorHigh
}

View File

@ -0,0 +1,34 @@
package sniff
// SniffWebSocket reports whether the first two bytes of header are
// compatible with the start of a client-to-server WebSocket frame.
//
// The check requires the MASK bit to be set, the RSV1-3 bits to be zero and
// the opcode to be no greater than 0xA. The FIN bit is not inspected and
// every 7-bit payload-length value is accepted. Fewer than two bytes yields
// false.
func SniffWebSocket(header []byte) bool {
	if len(header) < 2 {
		return false
	}

	const (
		rsvBits    = 0x70 // RSV1-3 in the first byte
		opcodeBits = 0x0F // opcode in the first byte
		maskBit    = 0x80 // MASK flag in the second byte
		maxOpcode  = 0xA
	)

	first, second := header[0], header[1]

	masked := second&maskBit != 0
	rsvClear := first&rsvBits == 0
	opcodeInRange := first&opcodeBits <= maxOpcode

	return masked && rsvClear && opcodeInRange
}