mirror of
https://github.com/SunBK201/UA3F.git
synced 2025-12-17 01:06:13 +00:00
353 lines
8.8 KiB
Go
353 lines
8.8 KiB
Go
package rewrite
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/dlclark/regexp2"
|
|
"github.com/hashicorp/golang-lru/v2/expirable"
|
|
"github.com/sirupsen/logrus"
|
|
|
|
"github.com/sunbk201/ua3f/internal/config"
|
|
"github.com/sunbk201/ua3f/internal/log"
|
|
"github.com/sunbk201/ua3f/internal/statistics"
|
|
)
|
|
|
|
const (
|
|
ErrUseClosedConn = "use of closed network connection"
|
|
ErrConnResetByPeer = "connection reset by peer"
|
|
ErrIOTimeout = "i/o timeout"
|
|
)
|
|
|
|
// HTTP methods used to detect HTTP by request line.
|
|
var httpMethods = map[string]struct{}{
|
|
"GET": {},
|
|
"POST": {},
|
|
"HEAD": {},
|
|
"PUT": {},
|
|
"PATCH": {},
|
|
"DELETE": {},
|
|
"OPTIONS": {},
|
|
"TRACE": {},
|
|
"CONNECT": {},
|
|
}
|
|
|
|
// Hardcoded whitelist of UAs that should be left untouched.
|
|
var defaultWhitelist = []string{
|
|
"MicroMessenger Client",
|
|
"ByteDancePcdn",
|
|
"Go-http-client/1.1",
|
|
}
|
|
|
|
// Rewriter encapsulates HTTP UA rewrite behavior and pass-through cache.
|
|
type Rewriter struct {
|
|
payloadUA string
|
|
pattern string
|
|
enablePartialReplace bool
|
|
|
|
uaRegex *regexp2.Regexp
|
|
cache *expirable.LRU[string, string]
|
|
whitelist map[string]struct{}
|
|
}
|
|
|
|
// New constructs a Rewriter from config. Compiles regex and allocates cache.
|
|
func New(cfg *config.Config) (*Rewriter, error) {
|
|
// UA pattern is compiled with case-insensitive prefix (?i)
|
|
pattern := "(?i)" + cfg.UAPattern
|
|
uaRegex, err := regexp2.Compile(pattern, regexp2.None)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
cache := expirable.NewLRU[string, string](300, nil, 10*time.Minute)
|
|
|
|
whitelist := make(map[string]struct{}, len(defaultWhitelist))
|
|
for _, s := range defaultWhitelist {
|
|
whitelist[s] = struct{}{}
|
|
}
|
|
|
|
return &Rewriter{
|
|
payloadUA: cfg.PayloadUA,
|
|
pattern: cfg.UAPattern,
|
|
enablePartialReplace: cfg.EnablePartialReplace,
|
|
uaRegex: uaRegex,
|
|
cache: cache,
|
|
whitelist: whitelist,
|
|
}, nil
|
|
}
|
|
|
|
// ProxyHTTPOrRaw reads traffic from src and writes to dst.
|
|
// - If target in LRU cache: pass-through (raw).
|
|
// - Else if HTTP: rewrite UA (unless whitelisted or pattern not matched).
|
|
// - Else: mark target in LRU and pass-through.
|
|
func (r *Rewriter) ProxyHTTPOrRaw(dst net.Conn, src net.Conn, destAddr string) (err error) {
|
|
srcAddr := src.RemoteAddr().String()
|
|
|
|
// Fast path: known pass-through
|
|
if r.cache.Contains(destAddr) {
|
|
log.LogDebugWithAddr(src.RemoteAddr().String(), destAddr, "LRU Relay Cache Hit, pass-through")
|
|
io.Copy(dst, src)
|
|
return nil
|
|
}
|
|
|
|
reader := bufio.NewReader(src)
|
|
defer func() {
|
|
if err != nil {
|
|
log.LogDebugWithAddr(srcAddr, destAddr, fmt.Sprintf("ProxyHTTPOrRaw: %s", err.Error()))
|
|
}
|
|
io.Copy(dst, reader)
|
|
}()
|
|
|
|
isHTTP, err := r.isHTTP(reader)
|
|
if err != nil {
|
|
err = fmt.Errorf("isHTTP: %w", err)
|
|
return
|
|
}
|
|
if !isHTTP {
|
|
r.cache.Add(destAddr, destAddr)
|
|
log.LogDebugWithAddr(srcAddr, destAddr, "Not HTTP, added to LRU Relay Cache")
|
|
return
|
|
}
|
|
|
|
var req *http.Request
|
|
|
|
// HTTP request loop (handles keep-alive)
|
|
for {
|
|
isHTTP, err = r.isHTTP(reader)
|
|
if err != nil {
|
|
err = fmt.Errorf("isHTTP: %w", err)
|
|
return
|
|
}
|
|
if !isHTTP {
|
|
h2, _ := reader.Peek(2) // ensure we have at least 2 bytes
|
|
if isWebSocket(h2) {
|
|
log.LogDebugWithAddr(srcAddr, destAddr, "WebSocket detected, pass-through")
|
|
} else {
|
|
r.cache.Add(destAddr, destAddr)
|
|
log.LogDebugWithAddr(srcAddr, destAddr, "Not HTTP, added to LRU Relay Cache")
|
|
}
|
|
return
|
|
}
|
|
req, err = http.ReadRequest(reader)
|
|
if err != nil {
|
|
err = fmt.Errorf("http.ReadRequest: %w", err)
|
|
return
|
|
}
|
|
|
|
originalUA := req.Header.Get("User-Agent")
|
|
|
|
// No UA header: pass-through after writing this first request
|
|
if originalUA == "" {
|
|
r.cache.Add(destAddr, destAddr)
|
|
log.LogDebugWithAddr(srcAddr, destAddr, "Not found User-Agent, Add LRU Relay Cache")
|
|
if err = req.Write(dst); err != nil {
|
|
err = fmt.Errorf("req.Write: %w", err)
|
|
}
|
|
return
|
|
}
|
|
|
|
isWhitelist := r.inWhitelist(originalUA)
|
|
matches := true
|
|
if r.pattern != "" {
|
|
matches, err = r.uaRegex.MatchString(originalUA)
|
|
if err != nil {
|
|
log.LogErrorWithAddr(srcAddr, destAddr, fmt.Sprintf("User-Agent Regex Pattern Match Error: %s", err.Error()))
|
|
matches = true
|
|
}
|
|
}
|
|
|
|
// If UA is whitelisted or does not match target pattern, write once then pass-through.
|
|
if isWhitelist || !matches {
|
|
if !matches {
|
|
log.LogDebugWithAddr(srcAddr, destAddr, fmt.Sprintf("Not Hit User-Agent Regex: %s", originalUA))
|
|
}
|
|
if isWhitelist {
|
|
log.LogDebugWithAddr(srcAddr, destAddr, fmt.Sprintf("Hit User-Agent Whitelist: %s", originalUA))
|
|
r.cache.Add(destAddr, destAddr)
|
|
}
|
|
statistics.AddPassThroughRecord(&statistics.PassThroughRecord{
|
|
Host: destAddr,
|
|
UA: originalUA,
|
|
})
|
|
if err = req.Write(dst); err != nil {
|
|
err = fmt.Errorf("req.Write: %w", err)
|
|
}
|
|
return
|
|
}
|
|
|
|
// Rewrite UA and forward the request (including body)
|
|
log.LogDebugWithAddr(srcAddr, destAddr, fmt.Sprintf("Hit User-Agent: %s", originalUA))
|
|
mockedUA := r.buildNewUA(originalUA)
|
|
req.Header.Set("User-Agent", mockedUA)
|
|
if err = req.Write(dst); err != nil {
|
|
err = fmt.Errorf("req.Write: %w", err)
|
|
return
|
|
}
|
|
|
|
statistics.AddRewriteRecord(&statistics.RewriteRecord{
|
|
Host: destAddr,
|
|
OriginalUA: originalUA,
|
|
MockedUA: mockedUA,
|
|
})
|
|
}
|
|
}
|
|
|
|
// isHTTP peeks the first few bytes and checks for a known HTTP method prefix.
|
|
func (r *Rewriter) isHTTP(reader *bufio.Reader) (bool, error) {
|
|
// Fast check: peek first word to see if it's a known HTTP method
|
|
const maxMethodLen = 7
|
|
hintSlice, err := reader.Peek(maxMethodLen)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
hint := string(hintSlice)
|
|
method, _, _ := strings.Cut(hint, " ")
|
|
if _, exists := httpMethods[method]; !exists {
|
|
return false, nil
|
|
}
|
|
|
|
// Detailed check: parse request line
|
|
line, err := peekLineString(reader)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
method, _, proto, ok := parseRequestLine(line)
|
|
if !ok {
|
|
return false, nil
|
|
}
|
|
if proto != "HTTP/1.1" && proto != "HTTP/1.0" {
|
|
return false, nil
|
|
}
|
|
if _, exists := httpMethods[method]; exists {
|
|
return true, nil
|
|
}
|
|
return false, nil
|
|
}
|
|
|
|
// buildNewUA returns either a partial replacement (regex) or full overwrite.
|
|
func (r *Rewriter) buildNewUA(originUA string) string {
|
|
if r.enablePartialReplace && r.uaRegex != nil && r.pattern != "" {
|
|
newUA, err := r.uaRegex.Replace(originUA, r.payloadUA, -1, -1)
|
|
if err != nil {
|
|
logrus.Errorf("User-Agent Replace Error: %s, use full overwrite", err.Error())
|
|
return r.payloadUA
|
|
}
|
|
return newUA
|
|
}
|
|
return r.payloadUA
|
|
}
|
|
|
|
func (r *Rewriter) inWhitelist(ua string) bool {
|
|
_, ok := r.whitelist[ua]
|
|
return ok
|
|
}
|
|
|
|
// peekLineSlice reads a line from bufio.Reader without consuming it.
|
|
// returns the line bytes (without CRLF) or error.
|
|
func peekLineSlice(br *bufio.Reader) ([]byte, error) {
|
|
const chunkSize = 256
|
|
var line []byte
|
|
|
|
offset := 0
|
|
for {
|
|
// Ensure there is data in the buffer
|
|
n := br.Buffered()
|
|
if n == 0 {
|
|
// No data in buffer, try to fill it
|
|
_, err := br.Peek(1)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
n = br.Buffered()
|
|
}
|
|
|
|
// Limit to chunkSize
|
|
if n > chunkSize {
|
|
n = chunkSize
|
|
}
|
|
|
|
buf, err := br.Peek(offset + n)
|
|
if err != nil && !errors.Is(err, bufio.ErrBufferFull) && !errors.Is(err, io.EOF) {
|
|
return nil, err
|
|
}
|
|
|
|
data := buf[offset:]
|
|
if i := bytes.IndexByte(data, '\n'); i >= 0 {
|
|
line = append(line, data[:i]...)
|
|
// Remove trailing CR if present
|
|
if len(line) > 0 && line[len(line)-1] == '\r' {
|
|
line = line[:len(line)-1]
|
|
}
|
|
return line, nil
|
|
}
|
|
|
|
line = append(line, data...)
|
|
offset += len(data)
|
|
|
|
// If EOF reached without finding a newline, return what we have
|
|
if errors.Is(err, io.EOF) {
|
|
return line, io.EOF
|
|
}
|
|
}
|
|
}
|
|
|
|
// peekLineString reads a line from bufio.Reader without consuming it.
|
|
// returns the line string (without CRLF) or error.
|
|
func peekLineString(br *bufio.Reader) (string, error) {
|
|
lineBytes, err := peekLineSlice(br)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return string(lineBytes), nil
|
|
}
|
|
|
|
// parseRequestLine parses "GET /foo HTTP/1.1" into its three parts.
|
|
func parseRequestLine(line string) (method, requestURI, proto string, ok bool) {
|
|
method, rest, ok1 := strings.Cut(line, " ")
|
|
requestURI, proto, ok2 := strings.Cut(rest, " ")
|
|
if !ok1 || !ok2 {
|
|
return "", "", "", false
|
|
}
|
|
return method, requestURI, proto, true
|
|
}
|
|
|
|
func isWebSocket(header []byte) bool {
|
|
if len(header) < 2 {
|
|
return false
|
|
}
|
|
|
|
b0 := header[0]
|
|
b1 := header[1]
|
|
|
|
rsv := b0 & 0x70 // RSV1-3
|
|
opcode := b0 & 0x0F // opcode
|
|
mask := b1 & 0x80 // MASK
|
|
|
|
// requested frames from client to server must be masked
|
|
if mask == 0 {
|
|
return false
|
|
}
|
|
// Control frames must have FIN set
|
|
if rsv != 0 {
|
|
return false
|
|
}
|
|
// opcode must be in valid range
|
|
if opcode > 0xA {
|
|
return false
|
|
}
|
|
// payload length
|
|
payloadLen := b1 & 0x7F
|
|
if payloadLen > 0 && payloadLen <= 125 {
|
|
return true
|
|
}
|
|
|
|
return true
|
|
}
|