UA3F/src/internal/rewrite/rewriter.go
2025-11-01 00:30:27 +08:00

380 lines
9.3 KiB
Go

package rewrite
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"net"
"net/http"
"strings"
"time"
"github.com/dlclark/regexp2"
"github.com/hashicorp/golang-lru/v2/expirable"
"github.com/sirupsen/logrus"
"github.com/sunbk201/ua3f/internal/config"
"github.com/sunbk201/ua3f/internal/log"
"github.com/sunbk201/ua3f/internal/statistics"
)
// Message texts of common network errors.
// None of these constants is referenced in the visible part of this file;
// presumably callers match them against err.Error() — verify at the call sites.
const (
// Message text of net.ErrClosed (I/O on an already closed connection).
ErrUseClosedConn = "use of closed network connection"
// Message text reported when the remote side resets the connection.
ErrConnResetByPeer = "connection reset by peer"
// Message text of a read/write deadline expiry.
ErrIOTimeout = "i/o timeout"
)
// httpMethods is the set of request methods recognised when sniffing a stream
// for HTTP by its request line. Lookups are exact and case-sensitive, so only
// the upper-case spellings listed here are treated as HTTP.
var httpMethods = map[string]struct{}{
"GET": {},
"POST": {},
"HEAD": {},
"PUT": {},
"PATCH": {},
"DELETE": {},
"OPTIONS": {},
"TRACE": {},
"CONNECT": {},
}
// defaultWhitelist is the hardcoded list of User-Agent values that must be
// left untouched. Entries are compared against the complete header value
// (exact, case-sensitive match), not as substrings.
var defaultWhitelist = []string{
"MicroMessenger Client",
"ByteDancePcdn",
"Go-http-client/1.1",
}
// Rewriter encapsulates HTTP User-Agent rewrite behavior together with the
// cache of destinations that are relayed without inspection.
type Rewriter struct {
// payloadUA is the replacement User-Agent (full value, or the regex
// replacement text when partial replace is enabled).
payloadUA string
// pattern is the user-supplied UA regex source, without the "(?i)" prefix.
// An empty pattern means every User-Agent is considered a match.
pattern string
// enablePartialReplace selects regex replacement of the matched part
// instead of overwriting the whole header value.
enablePartialReplace bool
// uaRegex is the compiled, case-insensitive form of pattern.
uaRegex *regexp2.Regexp
// cache holds destination addresses that should be passed through raw;
// the key and the value are both the destination address.
cache *expirable.LRU[string, string]
// whitelist is the set of User-Agent values that are never rewritten.
whitelist map[string]struct{}
}
// New builds a Rewriter from cfg.
//
// The UA pattern is compiled case-insensitively by prefixing it with "(?i)";
// a compile failure is returned as-is and no Rewriter is created. On success
// the Rewriter gets a pass-through cache (1024 entries, 10 minute TTL) and a
// whitelist set populated from defaultWhitelist.
func New(cfg *config.Config) (*Rewriter, error) {
	compiled, err := regexp2.Compile("(?i)"+cfg.UAPattern, regexp2.None)
	if err != nil {
		return nil, err
	}

	// The cache is only allocated once the pattern is known to be valid.
	rw := &Rewriter{
		payloadUA:            cfg.PayloadUA,
		pattern:              cfg.UAPattern,
		enablePartialReplace: cfg.EnablePartialReplace,
		uaRegex:              compiled,
		cache:                expirable.NewLRU[string, string](1024, nil, 10*time.Minute),
		whitelist:            make(map[string]struct{}, len(defaultWhitelist)),
	}
	for _, ua := range defaultWhitelist {
		rw.whitelist[ua] = struct{}{}
	}
	return rw, nil
}
// ProxyHTTPOrRaw relays traffic from src to dst, rewriting the User-Agent of
// HTTP/1.x requests where applicable.
//   - If destAddr is in the LRU cache: copy src to dst untouched and return nil.
//   - Else if destAddr ends in "443" and the stream starts like a TLS handshake
//     record: cache destAddr and pass through.
//   - Else if HTTP: handle requests one by one, rewriting the UA unless it is
//     missing, whitelisted, or not matched by the pattern.
//   - Else: cache destAddr and pass through.
//
// After the cache check, every return path runs the deferred io.Copy, which
// forwards whatever is still unread from src (buffered or not) to dst, so the
// function only returns once that copy ends. Errors from the io.Copy calls are
// discarded; err carries the first probe, parse, or write failure and is also
// reported through log.LogDebugWithAddr.
func (r *Rewriter) ProxyHTTPOrRaw(dst net.Conn, src net.Conn, destAddr string) (err error) {
srcAddr := src.RemoteAddr().String()
// Fast path: known pass-through. The deferred copy below is not registered
// yet, so this branch relays directly from src.
if r.cache.Contains(destAddr) {
log.LogDebugWithAddr(src.RemoteAddr().String(), destAddr, "LRU Relay Cache Hit, pass-through")
io.Copy(dst, src)
return nil
}
// All probing below goes through this buffered reader; bytes that were only
// peeked stay in it and are flushed to dst by the deferred copy.
reader := bufio.NewReader(src)
defer func() {
if err != nil {
log.LogDebugWithAddr(srcAddr, destAddr, fmt.Sprintf("ProxyHTTPOrRaw: %s", err.Error()))
}
// Relay the remainder of the connection unchanged, whatever the outcome.
io.Copy(dst, reader)
}()
// The port test is a plain suffix match, so it also accepts ports such as 8443.
if strings.HasSuffix(destAddr, "443") && isTLSClientHello(reader) {
r.cache.Add(destAddr, destAddr)
log.LogDebugWithAddr(srcAddr, destAddr, "TLS ClientHello detected")
return
}
// Initial probe: isHTTP only peeks, so the request itself stays in reader.
isHTTP, err := r.isHTTP(reader)
if err != nil {
err = fmt.Errorf("isHTTP: %w", err)
return
}
if !isHTTP {
r.cache.Add(destAddr, destAddr)
log.LogDebugWithAddr(srcAddr, destAddr, "Not HTTP, added to LRU Relay Cache")
return
}
var req *http.Request
// HTTP request loop (handles keep-alive). It has no loop condition; it ends
// through one of the returns below.
for {
// Probe again before each request: after the first one the stream may
// stop being HTTP (for example once a WebSocket session starts) or end.
isHTTP, err = r.isHTTP(reader)
if err != nil {
err = fmt.Errorf("isHTTP: %w", err)
return
}
if !isHTTP {
h2, _ := reader.Peek(2) // Peek error is ignored; isWebSocket rejects input shorter than 2 bytes
if isWebSocket(h2) {
// WebSocket traffic is relayed, but destAddr is deliberately not cached.
log.LogDebugWithAddr(srcAddr, destAddr, "WebSocket detected, pass-through")
} else {
r.cache.Add(destAddr, destAddr)
log.LogDebugWithAddr(srcAddr, destAddr, "Not HTTP, added to LRU Relay Cache")
}
return
}
// Consume and parse exactly one request from the buffered stream.
req, err = http.ReadRequest(reader)
if err != nil {
err = fmt.Errorf("http.ReadRequest: %w", err)
return
}
originalUA := req.Header.Get("User-Agent")
// No UA header: cache the destination, forward this request, then
// pass-through the rest of the connection.
if originalUA == "" {
r.cache.Add(destAddr, destAddr)
log.LogDebugWithAddr(srcAddr, destAddr, "Not found User-Agent, Add LRU Relay Cache")
if err = req.Write(dst); err != nil {
err = fmt.Errorf("req.Write: %w", err)
}
return
}
isWhitelist := r.inWhitelist(originalUA)
// With an empty pattern every UA counts as a match.
matches := true
if r.pattern != "" {
matches, err = r.uaRegex.MatchString(originalUA)
if err != nil {
// A regex evaluation failure is logged and treated as a match; err is
// overwritten by the req.Write call on whichever branch follows.
log.LogErrorWithAddr(srcAddr, destAddr, fmt.Sprintf("User-Agent Regex Pattern Match Error: %s", err.Error()))
matches = true
}
}
// If UA is whitelisted or does not match target pattern, write once then pass-through.
// Only the whitelist case caches destAddr; a pattern miss affects this
// connection alone.
if isWhitelist || !matches {
if !matches {
log.LogDebugWithAddr(srcAddr, destAddr, fmt.Sprintf("Not Hit User-Agent Regex: %s", originalUA))
}
if isWhitelist {
log.LogDebugWithAddr(srcAddr, destAddr, fmt.Sprintf("Hit User-Agent Whitelist: %s", originalUA))
r.cache.Add(destAddr, destAddr)
}
statistics.AddPassThroughRecord(&statistics.PassThroughRecord{
Host: destAddr,
UA: originalUA,
})
if err = req.Write(dst); err != nil {
err = fmt.Errorf("req.Write: %w", err)
}
return
}
// Rewrite UA and forward the request (including body)
log.LogDebugWithAddr(srcAddr, destAddr, fmt.Sprintf("Hit User-Agent: %s", originalUA))
mockedUA := r.buildNewUA(originalUA)
req.Header.Set("User-Agent", mockedUA)
if err = req.Write(dst); err != nil {
err = fmt.Errorf("req.Write: %w", err)
return
}
// The rewrite is recorded only after the request was forwarded successfully.
statistics.AddRewriteRecord(&statistics.RewriteRecord{
Host: destAddr,
OriginalUA: originalUA,
MockedUA: mockedUA,
})
}
}
// isHTTP reports whether the data waiting in reader starts with an HTTP/1.0
// or HTTP/1.1 request line. Nothing is consumed: the reader is only peeked.
//
// The check runs in two stages. First the leading 7 bytes (the length of the
// longest known method) are peeked and the text up to the first space must be
// a known method. Only then is the whole first line peeked and parsed. Any
// error from peeking, including fewer than 7 bytes being available before the
// stream ends, is returned with a false result.
func (r *Rewriter) isHTTP(reader *bufio.Reader) (bool, error) {
	const maxMethodLen = 7

	// Stage 1: cheap rejection based on the first word.
	prefix, err := reader.Peek(maxMethodLen)
	if err != nil {
		return false, err
	}
	candidate, _, _ := strings.Cut(string(prefix), " ")
	if _, known := httpMethods[candidate]; !known {
		return false, nil
	}

	// Stage 2: validate the complete request line.
	requestLine, err := peekLineString(reader)
	if err != nil {
		return false, err
	}
	method, _, proto, ok := parseRequestLine(requestLine)
	switch {
	case !ok:
		return false, nil
	case proto != "HTTP/1.1" && proto != "HTTP/1.0":
		return false, nil
	}
	_, known := httpMethods[method]
	return known, nil
}
// buildNewUA produces the User-Agent that replaces originUA.
//
// When partial replace is enabled and a non-empty pattern has been compiled,
// every regex match inside originUA is replaced with the payload UA. In all
// other cases, and when the regex replacement itself fails (which is logged),
// the payload UA is returned as the complete new value.
func (r *Rewriter) buildNewUA(originUA string) string {
	partial := r.enablePartialReplace && r.uaRegex != nil && r.pattern != ""
	if !partial {
		return r.payloadUA
	}

	replaced, err := r.uaRegex.Replace(originUA, r.payloadUA, -1, -1)
	if err == nil {
		return replaced
	}
	logrus.Errorf("User-Agent Replace Error: %s, use full overwrite", err.Error())
	return r.payloadUA
}
// inWhitelist reports whether ua is, character for character, one of the
// whitelisted User-Agent values.
func (r *Rewriter) inWhitelist(ua string) bool {
	if _, found := r.whitelist[ua]; found {
		return true
	}
	return false
}
// peekLineSlice returns the first line buffered in br without consuming it.
//
// The returned slice is a copy (it stays valid after later reads) and does
// not contain the line terminator: the '\n' is dropped together with one
// directly preceding '\r', if any.
//
// Errors:
//   - io.EOF when the stream ends before a '\n' is seen; the bytes read so far
//     are returned alongside it (nil for an empty stream).
//   - bufio.ErrBufferFull when the line does not fit into br's buffer. The
//     line cannot be peeked in that case, so the search stops instead of
//     retrying forever.
//   - any other read error, unchanged.
//
// The underlying reader is asked for more data only while no '\n' is
// buffered, and then only for a single extra byte at a time, so the call
// never waits for bytes beyond the ones needed to complete the line.
func peekLineSlice(br *bufio.Reader) ([]byte, error) {
	scanned := 0 // number of leading buffered bytes already searched for '\n'
	for {
		// Require just one byte past the scanned region. This triggers a read
		// only when everything buffered has been searched already.
		head, err := br.Peek(scanned + 1)
		if err != nil {
			if errors.Is(err, io.EOF) {
				// Stream ended mid-line: hand back what there is.
				return append([]byte(nil), head...), io.EOF
			}
			// Includes bufio.ErrBufferFull: the line is longer than the buffer.
			return nil, err
		}

		// Inspect everything that is buffered right now; this cannot block.
		window, _ := br.Peek(br.Buffered())
		if i := bytes.IndexByte(window[scanned:], '\n'); i >= 0 {
			end := scanned + i
			// Remove trailing CR if present.
			if end > 0 && window[end-1] == '\r' {
				end--
			}
			return append([]byte(nil), window[:end]...), nil
		}
		scanned = len(window)
	}
}
// peekLineString is the string-returning variant of peekLineSlice: it yields
// the first line of br (line terminator removed) without consuming it.
// On any error, including io.EOF for an unterminated line, the result is the
// empty string together with that error.
func peekLineString(br *bufio.Reader) (string, error) {
	raw, err := peekLineSlice(br)
	if err == nil {
		return string(raw), nil
	}
	return "", err
}
// parseRequestLine splits an HTTP request line such as "GET /foo HTTP/1.1"
// at its first two spaces into method, request URI and protocol.
// Everything after the second space, further spaces included, ends up in
// proto. ok is false, and all parts are empty, when the line holds fewer
// than two spaces.
func parseRequestLine(line string) (method, requestURI, proto string, ok bool) {
	first := strings.IndexByte(line, ' ')
	if first < 0 {
		return "", "", "", false
	}
	tail := line[first+1:]

	second := strings.IndexByte(tail, ' ')
	if second < 0 {
		return "", "", "", false
	}
	return line[:first], tail[:second], tail[second+1:], true
}
// isWebSocket reports whether header looks like the first two bytes of a
// WebSocket frame sent by a client (RFC 6455, section 5.2). It is a heuristic
// over two bytes only and returns false for anything shorter.
//
// A frame is accepted when:
//   - the MASK bit is set (client-to-server frames are always masked);
//   - RSV1-3 are all zero;
//   - the opcode is defined: 0x0-0x2 (continuation, text, binary) or
//     0x8-0xA (close, ping, pong);
//   - for control frames (0x8-0xA), FIN is set and the 7-bit payload length
//     is at most 125, since control frames may be neither fragmented nor long.
//
// NOTE(review): frames using an extension that sets an RSV bit (for example
// compressed messages) are rejected by the RSV rule — confirm this is wanted.
func isWebSocket(header []byte) bool {
	if len(header) < 2 {
		return false
	}
	b0, b1 := header[0], header[1]

	// Frames from client to server must be masked.
	if b1&0x80 == 0 {
		return false
	}
	// Reserved bits must be clear.
	if b0&0x70 != 0 {
		return false
	}

	opcode := b0 & 0x0F
	switch {
	case opcode <= 0x2:
		// Data frame; any length encoding (including 126/127) is valid.
		return true
	case opcode >= 0x8 && opcode <= 0xA:
		// Control frame: must be final and carry at most 125 payload bytes.
		fin := b0&0x80 != 0
		payloadLen := b1 & 0x7F
		return fin && payloadLen <= 125
	default:
		// 0x3-0x7 and 0xB-0xF are reserved opcodes.
		return false
	}
}
// isTLSClientHello reports whether the next bytes in reader look like the
// header of a TLS handshake record. Only the first three bytes are peeked
// (nothing is consumed): the content type must be 0x16 (handshake), the major
// version 0x03 and the minor version between 0x01 and 0x04 inclusive.
// A peek failure, e.g. fewer than three bytes before the stream ends, yields
// false.
func isTLSClientHello(reader *bufio.Reader) bool {
	const (
		recordTypeHandshake = 0x16
		tlsMajor            = 0x03
		tlsMinorLow         = 0x01
		tlsMinorHigh        = 0x04
	)

	hdr, err := reader.Peek(3)
	if err != nil {
		return false
	}
	contentType, major, minor := hdr[0], hdr[1], hdr[2]

	return contentType == recordTypeHandshake &&
		major == tlsMajor &&
		minor >= tlsMinorLow && minor <= tlsMinorHigh
}