mirror of
https://github.com/SunBK201/UA3F.git
synced 2025-12-16 16:57:08 +00:00
refactor: improve HTTP method detection in rewriter
This commit is contained in:
parent
aabbd76009
commit
963e5bc775
@ -10,4 +10,9 @@ require (
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
||||
)
|
||||
|
||||
require github.com/stretchr/testify v1.8.2 // indirect
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/stretchr/testify v1.11.1
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
@ -10,16 +10,13 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN
|
||||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
|
||||
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
|
||||
golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
|
||||
|
||||
@ -72,3 +72,19 @@ func LogHeader(version string, cfg *config.Config) {
|
||||
logrus.Infof("Enable Partial Replace: %v", cfg.EnablePartialReplace)
|
||||
logrus.Infof("Log level: %s", cfg.LogLevel)
|
||||
}
|
||||
|
||||
func LogDebugWithAddr(src string, dest string, msg string) {
|
||||
logrus.Debugf("[%s -> %s] %s", src, dest, msg)
|
||||
}
|
||||
|
||||
func LogInfoWithAddr(src string, dest string, msg string) {
|
||||
logrus.Infof("[%s -> %s] %s", src, dest, msg)
|
||||
}
|
||||
|
||||
func LogWarnWithAddr(src string, dest string, msg string) {
|
||||
logrus.Warnf("[%s -> %s] %s", src, dest, msg)
|
||||
}
|
||||
|
||||
func LogErrorWithAddr(src string, dest string, msg string) {
|
||||
logrus.Errorf("[%s -> %s] %s", src, dest, msg)
|
||||
}
|
||||
|
||||
@ -2,7 +2,9 @@ package rewrite
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
@ -14,6 +16,7 @@ import (
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/sunbk201/ua3f/internal/config"
|
||||
"github.com/sunbk201/ua3f/internal/log"
|
||||
"github.com/sunbk201/ua3f/internal/statistics"
|
||||
)
|
||||
|
||||
@ -24,7 +27,17 @@ const (
|
||||
)
|
||||
|
||||
// HTTP methods used to detect HTTP by request line.
|
||||
var httpMethods = []string{"GET", "POST", "HEAD", "PUT", "DELETE", "OPTIONS", "TRACE", "CONNECT"}
|
||||
var httpMethods = map[string]struct{}{
|
||||
"GET": {},
|
||||
"POST": {},
|
||||
"HEAD": {},
|
||||
"PUT": {},
|
||||
"PATCH": {},
|
||||
"DELETE": {},
|
||||
"OPTIONS": {},
|
||||
"TRACE": {},
|
||||
"CONNECT": {},
|
||||
}
|
||||
|
||||
// Hardcoded whitelist of UAs that should be left untouched.
|
||||
var defaultWhitelist = []string{
|
||||
@ -74,56 +87,70 @@ func New(cfg *config.Config) (*Rewriter, error) {
|
||||
// - If target in LRU cache: pass-through (raw).
|
||||
// - Else if HTTP: rewrite UA (unless whitelisted or pattern not matched).
|
||||
// - Else: mark target in LRU and pass-through.
|
||||
func (r *Rewriter) ProxyHTTPOrRaw(dst net.Conn, src net.Conn, destAddrPort string) error {
|
||||
func (r *Rewriter) ProxyHTTPOrRaw(dst net.Conn, src net.Conn, destAddr string) (err error) {
|
||||
srcAddr := src.RemoteAddr().String()
|
||||
|
||||
// Fast path: known pass-through
|
||||
if r.cache.Contains(destAddrPort) {
|
||||
logrus.Debugf("Hit LRU Relay Cache: %s", destAddrPort)
|
||||
_, _ = io.Copy(dst, src)
|
||||
if r.cache.Contains(destAddr) {
|
||||
log.LogDebugWithAddr(src.RemoteAddr().String(), destAddr, "LRU Relay Cache Hit, pass-through")
|
||||
io.Copy(dst, src)
|
||||
return nil
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(src)
|
||||
defer func() {
|
||||
if err != nil {
|
||||
log.LogDebugWithAddr(srcAddr, destAddr, fmt.Sprintf("ProxyHTTPOrRaw Error: %s", err.Error()))
|
||||
}
|
||||
io.Copy(dst, reader)
|
||||
}()
|
||||
|
||||
isHTTP, err := r.isHTTP(reader)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), ErrUseClosedConn) {
|
||||
logrus.Warnf("[%s] isHTTP error: %s", destAddrPort, err.Error())
|
||||
return err
|
||||
}
|
||||
// Other read errors terminate the direction.
|
||||
return err
|
||||
err = fmt.Errorf("isHTTP: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
if !isHTTP {
|
||||
r.cache.Add(destAddrPort, destAddrPort)
|
||||
logrus.Debugf("Not HTTP, Add LRU Relay Cache: %s, Cache Len: %d", destAddrPort, r.cache.Len())
|
||||
_, _ = io.Copy(dst, reader)
|
||||
return nil
|
||||
r.cache.Add(destAddr, destAddr)
|
||||
log.LogDebugWithAddr(srcAddr, destAddr, "Not HTTP, added to LRU Relay Cache")
|
||||
return
|
||||
}
|
||||
|
||||
srcAddr := src.RemoteAddr().String()
|
||||
var req *http.Request
|
||||
|
||||
// HTTP request loop (handles keep-alive)
|
||||
for {
|
||||
req, err := http.ReadRequest(reader)
|
||||
isHTTP, err = r.isHTTP(reader)
|
||||
if err != nil {
|
||||
r.logReadErr(destAddrPort, src, err)
|
||||
return err
|
||||
err = fmt.Errorf("isHTTP: %w", err)
|
||||
return
|
||||
}
|
||||
if !isHTTP {
|
||||
h2, _ := reader.Peek(2) // ensure we have at least 2 bytes
|
||||
if isWebSocket(h2) {
|
||||
log.LogDebugWithAddr(srcAddr, destAddr, "WebSocket detected, pass-through")
|
||||
} else {
|
||||
r.cache.Add(destAddr, destAddr)
|
||||
log.LogDebugWithAddr(srcAddr, destAddr, "Not HTTP, added to LRU Relay Cache")
|
||||
}
|
||||
return
|
||||
}
|
||||
req, err = http.ReadRequest(reader)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("http.ReadRequest: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
originalUA := req.Header.Get("User-Agent")
|
||||
|
||||
// No UA header: pass-through after writing this first request
|
||||
if originalUA == "" {
|
||||
r.cache.Add(destAddrPort, destAddrPort)
|
||||
logrus.Debugf("[%s] Not found User-Agent, Add LRU Relay Cache, Cache Len: %d",
|
||||
destAddrPort, r.cache.Len())
|
||||
if err := req.Write(dst); err != nil {
|
||||
logrus.Errorf("[%s][%s] write error: %s", destAddrPort, srcAddr, err.Error())
|
||||
return err
|
||||
r.cache.Add(destAddr, destAddr)
|
||||
log.LogDebugWithAddr(srcAddr, destAddr, "Not found User-Agent, Add LRU Relay Cache")
|
||||
if err = req.Write(dst); err != nil {
|
||||
err = fmt.Errorf("req.Write: %w", err)
|
||||
}
|
||||
_, _ = io.Copy(dst, reader)
|
||||
return nil
|
||||
return
|
||||
}
|
||||
|
||||
isWhitelist := r.inWhitelist(originalUA)
|
||||
@ -131,8 +158,7 @@ func (r *Rewriter) ProxyHTTPOrRaw(dst net.Conn, src net.Conn, destAddrPort strin
|
||||
if r.pattern != "" {
|
||||
matches, err = r.uaRegex.MatchString(originalUA)
|
||||
if err != nil {
|
||||
logrus.Errorf("[%s][%s] User-Agent Regex Pattern Match Error: %s",
|
||||
destAddrPort, srcAddr, err.Error())
|
||||
log.LogErrorWithAddr(srcAddr, destAddr, fmt.Sprintf("User-Agent Regex Pattern Match Error: %s", err.Error()))
|
||||
matches = true
|
||||
}
|
||||
}
|
||||
@ -140,38 +166,33 @@ func (r *Rewriter) ProxyHTTPOrRaw(dst net.Conn, src net.Conn, destAddrPort strin
|
||||
// If UA is whitelisted or does not match target pattern, write once then pass-through.
|
||||
if isWhitelist || !matches {
|
||||
if !matches {
|
||||
logrus.Debugf("[%s][%s] Not Hit User-Agent Pattern: %s",
|
||||
destAddrPort, srcAddr, originalUA)
|
||||
log.LogDebugWithAddr(srcAddr, destAddr, fmt.Sprintf("Not Hit User-Agent Regex: %s", originalUA))
|
||||
}
|
||||
if isWhitelist {
|
||||
logrus.Debugf("[%s][%s] Hit User-Agent Whitelist: %s, Add LRU Relay Cache, Cache Len: %d",
|
||||
destAddrPort, srcAddr, originalUA, r.cache.Len())
|
||||
r.cache.Add(destAddrPort, destAddrPort)
|
||||
log.LogDebugWithAddr(srcAddr, destAddr, fmt.Sprintf("Hit User-Agent Whitelist: %s", originalUA))
|
||||
r.cache.Add(destAddr, destAddr)
|
||||
}
|
||||
statistics.AddPassThroughRecord(&statistics.PassThroughRecord{
|
||||
Host: destAddrPort,
|
||||
Host: destAddr,
|
||||
UA: originalUA,
|
||||
})
|
||||
if err := req.Write(dst); err != nil {
|
||||
logrus.Errorf("[%s][%s] write error: %s", destAddrPort, srcAddr, err.Error())
|
||||
return err
|
||||
if err = req.Write(dst); err != nil {
|
||||
err = fmt.Errorf("req.Write: %w", err)
|
||||
}
|
||||
_, _ = io.Copy(dst, reader)
|
||||
return nil
|
||||
return
|
||||
}
|
||||
|
||||
// Rewrite UA and forward the request (including body)
|
||||
logrus.Debugf("[%s][%s] Hit User-Agent: %s", destAddrPort, srcAddr, originalUA)
|
||||
log.LogDebugWithAddr(srcAddr, destAddr, fmt.Sprintf("Hit User-Agent: %s", originalUA))
|
||||
mockedUA := r.buildNewUA(originalUA)
|
||||
req.Header.Set("User-Agent", mockedUA)
|
||||
if err := req.Write(dst); err != nil {
|
||||
logrus.Errorf("[%s][%s] write error after replace user-agent: %s",
|
||||
destAddrPort, srcAddr, err.Error())
|
||||
return err
|
||||
if err = req.Write(dst); err != nil {
|
||||
err = fmt.Errorf("req.Write: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
statistics.AddRewriteRecord(&statistics.RewriteRecord{
|
||||
Host: destAddrPort,
|
||||
Host: destAddr,
|
||||
OriginalUA: originalUA,
|
||||
MockedUA: mockedUA,
|
||||
})
|
||||
@ -180,20 +201,32 @@ func (r *Rewriter) ProxyHTTPOrRaw(dst net.Conn, src net.Conn, destAddrPort strin
|
||||
|
||||
// isHTTP peeks the first few bytes and checks for a known HTTP method prefix.
|
||||
func (r *Rewriter) isHTTP(reader *bufio.Reader) (bool, error) {
|
||||
buf, err := reader.Peek(7)
|
||||
// Fast check: peek first word to see if it's a known HTTP method
|
||||
const maxMethodLen = 7
|
||||
hintSlice, err := reader.Peek(maxMethodLen)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "EOF") {
|
||||
logrus.Debugf("Peek EOF: %s", err.Error())
|
||||
} else {
|
||||
logrus.Errorf("Peek error: %s", err.Error())
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
hint := string(buf)
|
||||
for _, m := range httpMethods {
|
||||
if strings.HasPrefix(hint, m) {
|
||||
return true, nil
|
||||
}
|
||||
hint := string(hintSlice)
|
||||
method, _, _ := strings.Cut(hint, " ")
|
||||
if _, exists := httpMethods[method]; !exists {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Detailed check: parse request line
|
||||
line, err := peekLineString(reader)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
method, _, proto, ok := parseRequestLine(line)
|
||||
if !ok {
|
||||
return false, nil
|
||||
}
|
||||
if proto != "HTTP/1.1" && proto != "HTTP/1.0" {
|
||||
return false, nil
|
||||
}
|
||||
if _, exists := httpMethods[method]; exists {
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
@ -216,18 +249,104 @@ func (r *Rewriter) inWhitelist(ua string) bool {
|
||||
return ok
|
||||
}
|
||||
|
||||
func (r *Rewriter) logReadErr(destAddrPort string, src net.Conn, err error) {
|
||||
remote := src.RemoteAddr().String()
|
||||
switch {
|
||||
case errors.Is(err, io.EOF):
|
||||
logrus.Debugf("[%s][%s] read EOF in first phase", destAddrPort, remote)
|
||||
case strings.Contains(err.Error(), ErrUseClosedConn):
|
||||
logrus.Debugf("[%s][%s] read closed in first phase: %s", destAddrPort, remote, err.Error())
|
||||
case strings.Contains(err.Error(), ErrConnResetByPeer):
|
||||
logrus.Debugf("[%s][%s] read reset in first phase: %s", destAddrPort, remote, err.Error())
|
||||
case strings.Contains(err.Error(), ErrIOTimeout):
|
||||
logrus.Debugf("[%s][%s] read timeout in first phase: %s", destAddrPort, remote, err.Error())
|
||||
default:
|
||||
logrus.Errorf("[%s][%s] read error in first phase: %s", destAddrPort, remote, err.Error())
|
||||
// peekLineSlice reads a line from bufio.Reader without consuming it.
|
||||
// returns the line bytes (without CRLF) or error.
|
||||
func peekLineSlice(br *bufio.Reader) ([]byte, error) {
|
||||
const chunkSize = 256
|
||||
var line []byte
|
||||
|
||||
offset := 0
|
||||
for {
|
||||
// Ensure there is data in the buffer
|
||||
n := br.Buffered()
|
||||
if n == 0 {
|
||||
// No data in buffer, try to fill it
|
||||
_, err := br.Peek(1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
n = br.Buffered()
|
||||
}
|
||||
|
||||
// Limit to chunkSize
|
||||
if n > chunkSize {
|
||||
n = chunkSize
|
||||
}
|
||||
|
||||
buf, err := br.Peek(offset + n)
|
||||
if err != nil && !errors.Is(err, bufio.ErrBufferFull) && !errors.Is(err, io.EOF) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
data := buf[offset:]
|
||||
if i := bytes.IndexByte(data, '\n'); i >= 0 {
|
||||
line = append(line, data[:i]...)
|
||||
// Remove trailing CR if present
|
||||
if len(line) > 0 && line[len(line)-1] == '\r' {
|
||||
line = line[:len(line)-1]
|
||||
}
|
||||
return line, nil
|
||||
}
|
||||
|
||||
line = append(line, data...)
|
||||
offset += len(data)
|
||||
|
||||
// If EOF reached without finding a newline, return what we have
|
||||
if errors.Is(err, io.EOF) {
|
||||
return line, io.EOF
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// peekLineString reads a line from bufio.Reader without consuming it.
|
||||
// returns the line string (without CRLF) or error.
|
||||
func peekLineString(br *bufio.Reader) (string, error) {
|
||||
lineBytes, err := peekLineSlice(br)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(lineBytes), nil
|
||||
}
|
||||
|
||||
// parseRequestLine parses "GET /foo HTTP/1.1" into its three parts.
|
||||
func parseRequestLine(line string) (method, requestURI, proto string, ok bool) {
|
||||
method, rest, ok1 := strings.Cut(line, " ")
|
||||
requestURI, proto, ok2 := strings.Cut(rest, " ")
|
||||
if !ok1 || !ok2 {
|
||||
return "", "", "", false
|
||||
}
|
||||
return method, requestURI, proto, true
|
||||
}
|
||||
|
||||
func isWebSocket(header []byte) bool {
|
||||
if len(header) < 2 {
|
||||
return false
|
||||
}
|
||||
|
||||
b0 := header[0]
|
||||
b1 := header[1]
|
||||
|
||||
rsv := b0 & 0x70 // RSV1-3
|
||||
opcode := b0 & 0x0F // opcode
|
||||
mask := b1 & 0x80 // MASK
|
||||
|
||||
// requested frames from client to server must be masked
|
||||
if mask == 0 {
|
||||
return false
|
||||
}
|
||||
// Control frames must have FIN set
|
||||
if rsv != 0 {
|
||||
return false
|
||||
}
|
||||
// opcode must be in valid range
|
||||
if opcode > 0xA {
|
||||
return false
|
||||
}
|
||||
// payload length
|
||||
payloadLen := b1 & 0x7F
|
||||
if payloadLen > 0 && payloadLen <= 125 {
|
||||
return true
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user