mirror of
https://github.com/SunBK201/UA3F.git
synced 2025-12-16 16:57:08 +00:00
368 lines
9.3 KiB
Go
368 lines
9.3 KiB
Go
package rewrite
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/dlclark/regexp2"
|
|
"github.com/hashicorp/golang-lru/v2/expirable"
|
|
"github.com/sirupsen/logrus"
|
|
|
|
"github.com/sunbk201/ua3f/internal/config"
|
|
"github.com/sunbk201/ua3f/internal/log"
|
|
"github.com/sunbk201/ua3f/internal/statistics"
|
|
)
|
|
|
|
// Message fragments of common network errors. These are plain string
// constants, not error values, and none of them is referenced in this part of
// the file; presumably callers match them against err.Error() — verify at the
// call sites.
const (
	// ErrUseClosedConn is the text reported when a closed connection is used.
	ErrUseClosedConn = "use of closed network connection"
	// ErrConnResetByPeer is the text reported when the peer resets the connection.
	ErrConnResetByPeer = "connection reset by peer"
	// ErrIOTimeout is the text reported when a network deadline expires.
	ErrIOTimeout = "i/o timeout"
)
|
|
|
|
// httpMethods is the set of request-method tokens that isHTTP accepts as the
// first word of a stream when deciding whether the stream carries HTTP.
// Lookups are case-sensitive.
var httpMethods = map[string]struct{}{
	"GET":     {},
	"POST":    {},
	"HEAD":    {},
	"PUT":     {},
	"PATCH":   {},
	"DELETE":  {},
	"OPTIONS": {},
	"TRACE":   {},
	"CONNECT": {},
}
|
|
|
|
// defaultWhitelist lists the User-Agent values that are never rewritten.
// New copies these entries into each Rewriter's whitelist set, where they are
// compared against the complete User-Agent string (exact match).
var defaultWhitelist = []string{
	"MicroMessenger Client",
	"ByteDancePcdn",
	"Go-http-client/1.1",
}
|
|
|
|
// Rewriter encapsulates HTTP UA rewrite behavior and pass-through cache.
|
|
type Rewriter struct {
|
|
payloadUA string
|
|
pattern string
|
|
enablePartialReplace bool
|
|
|
|
uaRegex *regexp2.Regexp
|
|
Cache *expirable.LRU[string, string]
|
|
whitelist map[string]struct{}
|
|
}
|
|
|
|
// New constructs a Rewriter from config. Compiles regex and allocates cache.
|
|
func New(cfg *config.Config) (*Rewriter, error) {
|
|
// UA pattern is compiled with case-insensitive prefix (?i)
|
|
pattern := "(?i)" + cfg.UAPattern
|
|
uaRegex, err := regexp2.Compile(pattern, regexp2.None)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
cache := expirable.NewLRU[string, string](1024, nil, 10*time.Minute)
|
|
|
|
whitelist := make(map[string]struct{}, len(defaultWhitelist))
|
|
for _, s := range defaultWhitelist {
|
|
whitelist[s] = struct{}{}
|
|
}
|
|
|
|
return &Rewriter{
|
|
payloadUA: cfg.PayloadUA,
|
|
pattern: cfg.UAPattern,
|
|
enablePartialReplace: cfg.EnablePartialReplace,
|
|
uaRegex: uaRegex,
|
|
Cache: cache,
|
|
whitelist: whitelist,
|
|
}, nil
|
|
}
|
|
|
|
// RewriteAndForward rewrites the User-Agent header if needed and forwards the request.
|
|
// Returns pass=true if the request has been forwarded as-is (no rewrite).
|
|
func (r *Rewriter) RewriteAndForward(dst net.Conn, req *http.Request, destAddr string, srcAddr string) (pass bool, err error) {
|
|
|
|
originalUA := req.Header.Get("User-Agent")
|
|
|
|
// No UA header: pass-through after writing this first request
|
|
if originalUA == "" {
|
|
r.Cache.Add(destAddr, destAddr)
|
|
log.LogDebugWithAddr(srcAddr, destAddr, "Not found User-Agent, Add LRU Relay Cache")
|
|
if err = req.Write(dst); err != nil {
|
|
err = fmt.Errorf("req.Write: %w", err)
|
|
}
|
|
pass = true
|
|
return
|
|
}
|
|
|
|
log.LogInfoWithAddr(srcAddr, destAddr, fmt.Sprintf("Original User-Agent: %s", originalUA))
|
|
|
|
isWhitelist := r.inWhitelist(originalUA)
|
|
matches := true
|
|
if r.pattern != "" {
|
|
matches, err = r.uaRegex.MatchString(originalUA)
|
|
if err != nil {
|
|
log.LogErrorWithAddr(srcAddr, destAddr, fmt.Sprintf("User-Agent Regex Pattern Match Error: %s", err.Error()))
|
|
matches = true
|
|
}
|
|
}
|
|
|
|
// If UA is whitelisted or does not match target pattern, write once then pass-through.
|
|
if isWhitelist || !matches {
|
|
if !matches {
|
|
log.LogDebugWithAddr(srcAddr, destAddr, fmt.Sprintf("Not Hit User-Agent Regex: %s", originalUA))
|
|
}
|
|
if isWhitelist {
|
|
log.LogInfoWithAddr(srcAddr, destAddr, fmt.Sprintf("Hit User-Agent Whitelist: %s", originalUA))
|
|
r.Cache.Add(destAddr, destAddr)
|
|
}
|
|
statistics.AddPassThroughRecord(&statistics.PassThroughRecord{
|
|
Host: destAddr,
|
|
UA: originalUA,
|
|
})
|
|
if err = req.Write(dst); err != nil {
|
|
err = fmt.Errorf("req.Write: %w", err)
|
|
}
|
|
pass = true
|
|
return
|
|
}
|
|
|
|
// Rewrite UA and forward the request (including body)
|
|
rewritedUA := r.buildNewUA(originalUA)
|
|
log.LogInfoWithAddr(srcAddr, destAddr, fmt.Sprintf("Rewrite User-Agent from (%s) to (%s)", originalUA, rewritedUA))
|
|
req.Header.Set("User-Agent", rewritedUA)
|
|
if err = req.Write(dst); err != nil {
|
|
err = fmt.Errorf("req.Write: %w", err)
|
|
pass = true
|
|
return
|
|
}
|
|
|
|
statistics.AddRewriteRecord(&statistics.RewriteRecord{
|
|
Host: destAddr,
|
|
OriginalUA: originalUA,
|
|
MockedUA: rewritedUA,
|
|
})
|
|
|
|
return false, nil
|
|
}
|
|
|
|
// ProxyHTTPOrRaw reads traffic from src and writes to dst.
|
|
// - If target in LRU cache: pass-through (raw).
|
|
// - Else if HTTP: rewrite UA (unless whitelisted or pattern not matched).
|
|
// - Else: mark target in LRU and pass-through.
|
|
func (r *Rewriter) ProxyHTTPOrRaw(dst net.Conn, src net.Conn, destAddr string, srcAddr string) (err error) {
|
|
reader := bufio.NewReader(src)
|
|
|
|
defer func() {
|
|
if err != nil {
|
|
log.LogDebugWithAddr(srcAddr, destAddr, fmt.Sprintf("ProxyHTTPOrRaw: %s", err.Error()))
|
|
}
|
|
io.Copy(dst, reader)
|
|
}()
|
|
|
|
if strings.HasSuffix(destAddr, "443") && isTLSClientHello(reader) {
|
|
r.Cache.Add(destAddr, destAddr)
|
|
log.LogInfoWithAddr(srcAddr, destAddr, "tls client hello detected, pass forward")
|
|
return
|
|
}
|
|
isHTTP, err := r.isHTTP(reader)
|
|
if err != nil {
|
|
err = fmt.Errorf("isHTTP: %w", err)
|
|
return
|
|
}
|
|
if !isHTTP {
|
|
r.Cache.Add(destAddr, destAddr)
|
|
log.LogInfoWithAddr(srcAddr, destAddr, "Not HTTP, added to LRU Relay Cache")
|
|
return
|
|
}
|
|
|
|
var req *http.Request
|
|
var pass bool
|
|
|
|
// HTTP request loop (handles keep-alive)
|
|
for {
|
|
if isHTTP, err = r.isHTTP(reader); err != nil {
|
|
err = fmt.Errorf("isHTTP: %w", err)
|
|
return
|
|
}
|
|
if !isHTTP {
|
|
h2, _ := reader.Peek(2) // ensure we have at least 2 bytes
|
|
if isWebSocket(h2) {
|
|
log.LogInfoWithAddr(srcAddr, destAddr, "WebSocket detected, pass-through")
|
|
} else {
|
|
r.Cache.Add(destAddr, destAddr)
|
|
log.LogInfoWithAddr(srcAddr, destAddr, "Not HTTP, added to LRU Relay Cache")
|
|
}
|
|
return
|
|
}
|
|
if req, err = http.ReadRequest(reader); err != nil {
|
|
err = fmt.Errorf("http.ReadRequest: %w", err)
|
|
return
|
|
}
|
|
if pass, err = r.RewriteAndForward(dst, req, destAddr, srcAddr); pass {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// isHTTP peeks the first few bytes and checks for a known HTTP method prefix.
|
|
func (r *Rewriter) isHTTP(reader *bufio.Reader) (bool, error) {
|
|
// Fast check: peek first word to see if it's a known HTTP method
|
|
const maxMethodLen = 7
|
|
hintSlice, err := reader.Peek(maxMethodLen)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
hint := string(hintSlice)
|
|
method, _, _ := strings.Cut(hint, " ")
|
|
_, exists := httpMethods[method]
|
|
return exists, nil
|
|
}
|
|
|
|
// buildNewUA returns either a partial replacement (regex) or full overwrite.
|
|
func (r *Rewriter) buildNewUA(originUA string) string {
|
|
if r.enablePartialReplace && r.uaRegex != nil && r.pattern != "" {
|
|
newUA, err := r.uaRegex.Replace(originUA, r.payloadUA, -1, -1)
|
|
if err != nil {
|
|
logrus.Errorf("User-Agent Replace Error: %s, use full overwrite", err.Error())
|
|
return r.payloadUA
|
|
}
|
|
return newUA
|
|
}
|
|
return r.payloadUA
|
|
}
|
|
|
|
func (r *Rewriter) inWhitelist(ua string) bool {
|
|
_, ok := r.whitelist[ua]
|
|
return ok
|
|
}
|
|
|
|
// peekLineSlice returns the next line available from br without consuming any
// input. The result is a copy (it does not alias br's buffer) and excludes the
// terminating "\n" and one "\r" directly before it.
//
// Data that is already buffered is scanned first; more input is requested
// only when no newline has been found yet, and then one byte at a time, so
// the call never waits for more bytes than it needs.
//
// Errors:
//   - io.EOF: the stream ended before a newline; the bytes seen so far are
//     returned with it (nil if there were none).
//   - bufio.ErrBufferFull: the line does not fit in br's buffer; the buffered
//     prefix is returned with it.
//   - any other read error is returned with a nil slice.
func peekLineSlice(br *bufio.Reader) ([]byte, error) {
	// scanned counts the leading buffered bytes known to contain no '\n'.
	scanned := 0

	for {
		// Take everything that is buffered; if all of it was scanned already,
		// ask for exactly one more byte to trigger a refill.
		want := br.Buffered()
		if want <= scanned {
			want = scanned + 1
		}
		buf, err := br.Peek(want)

		// Peek returns whatever is available even when it reports an error,
		// so the fresh bytes are searched before the error is considered.
		if i := bytes.IndexByte(buf[scanned:], '\n'); i >= 0 {
			line := bytes.TrimSuffix(buf[:scanned+i], []byte{'\r'})
			return append([]byte(nil), line...), nil
		}
		scanned = len(buf)

		switch {
		case err == nil:
			// No newline yet; keep looking.
		case errors.Is(err, io.EOF):
			return append([]byte(nil), buf...), io.EOF
		case errors.Is(err, bufio.ErrBufferFull):
			// The buffer cannot grow, so another Peek would never make
			// progress; report the oversized line instead of spinning.
			return append([]byte(nil), buf...), bufio.ErrBufferFull
		default:
			return nil, err
		}
	}
}
|
|
|
|
// peekLineString reads a line from bufio.Reader without consuming it.
|
|
// returns the line string (without CRLF) or error.
|
|
func peekLineString(br *bufio.Reader) (string, error) {
|
|
lineBytes, err := peekLineSlice(br)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return string(lineBytes), nil
|
|
}
|
|
|
|
// parseRequestLine splits an HTTP request line such as "GET /foo HTTP/1.1"
// at its first two spaces into method, request URI and protocol. Everything
// after the second space, further spaces included, belongs to proto.
// ok is false, and all strings are empty, when line has fewer than two spaces.
func parseRequestLine(line string) (method, requestURI, proto string, ok bool) {
	fields := strings.SplitN(line, " ", 3)
	if len(fields) != 3 {
		return "", "", "", false
	}
	return fields[0], fields[1], fields[2], true
}
|
|
|
|
// isWebSocket reports whether header, the first two bytes of a frame, is a
// plausible start of a client-to-server WebSocket frame (RFC 6455 §5.2).
//
// The frame is accepted only when all of the following hold:
//   - the MASK bit is set (client frames must be masked);
//   - RSV1-3 are zero; frames of an extension that uses these bits are
//     therefore not recognised;
//   - the opcode is a defined one: 0x0-0x2 (continuation, text, binary) or
//     0x8-0xA (close, ping, pong). Reserved opcodes are rejected;
//   - a control frame (0x8-0xA) has FIN set and a payload length of at most
//     125, since control frames may be neither fragmented nor extended.
//
// Fewer than two bytes always yields false.
func isWebSocket(header []byte) bool {
	if len(header) < 2 {
		return false
	}
	b0, b1 := header[0], header[1]

	fin := b0&0x80 != 0
	rsv := b0 & 0x70    // RSV1-3
	opcode := b0 & 0x0F // opcode
	masked := b1&0x80 != 0
	payloadLen := b1 & 0x7F // 126 and 127 announce an extended length field

	// Frames sent from client to server must be masked.
	if !masked {
		return false
	}
	// Reserved bits must be zero unless an extension was negotiated.
	if rsv != 0 {
		return false
	}

	switch {
	case opcode <= 0x2:
		// Data frames: any FIN value and any length encoding is valid.
		return true
	case opcode >= 0x8 && opcode <= 0xA:
		// Control frames must have FIN set and carry at most 125 bytes.
		return fin && payloadLen <= 125
	default:
		// 0x3-0x7 and 0xB-0xF are reserved.
		return false
	}
}
|
|
|
|
// isTLSClientHello reports whether the next three bytes of reader look like
// the start of a TLS handshake record: content type 0x16 followed by a record
// version of 3.1 through 3.4. Only the record header is inspected — the
// handshake message type is not — and nothing is consumed from reader.
// It returns false when three bytes cannot be peeked.
func isTLSClientHello(reader *bufio.Reader) bool {
	hdr, err := reader.Peek(3)
	if err != nil {
		return false
	}

	contentType, major, minor := hdr[0], hdr[1], hdr[2]

	const handshake = 0x16 // TLS record type: Handshake
	return contentType == handshake &&
		major == 0x03 &&
		minor >= 0x01 && minor <= 0x04
}
|