mirror of
https://github.com/SunBK201/UA3F.git
synced 2025-12-16 16:57:08 +00:00
refactor: streamline rewriting
This commit is contained in:
parent
57a344e563
commit
ab0cf0efdc
@ -2,8 +2,6 @@ package rewrite
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
@ -17,44 +15,19 @@ import (
|
|||||||
|
|
||||||
"github.com/sunbk201/ua3f/internal/config"
|
"github.com/sunbk201/ua3f/internal/config"
|
||||||
"github.com/sunbk201/ua3f/internal/log"
|
"github.com/sunbk201/ua3f/internal/log"
|
||||||
|
"github.com/sunbk201/ua3f/internal/sniff"
|
||||||
"github.com/sunbk201/ua3f/internal/statistics"
|
"github.com/sunbk201/ua3f/internal/statistics"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
ErrUseClosedConn = "use of closed network connection"
|
|
||||||
ErrConnResetByPeer = "connection reset by peer"
|
|
||||||
ErrIOTimeout = "i/o timeout"
|
|
||||||
)
|
|
||||||
|
|
||||||
// HTTP methods used to detect HTTP by request line.
|
|
||||||
var httpMethods = map[string]struct{}{
|
|
||||||
"GET": {},
|
|
||||||
"POST": {},
|
|
||||||
"HEAD": {},
|
|
||||||
"PUT": {},
|
|
||||||
"PATCH": {},
|
|
||||||
"DELETE": {},
|
|
||||||
"OPTIONS": {},
|
|
||||||
"TRACE": {},
|
|
||||||
"CONNECT": {},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Hardcoded whitelist of UAs that should be left untouched.
|
|
||||||
var defaultWhitelist = []string{
|
|
||||||
"MicroMessenger Client",
|
|
||||||
"ByteDancePcdn",
|
|
||||||
"Go-http-client/1.1",
|
|
||||||
}
|
|
||||||
|
|
||||||
// Rewriter encapsulates HTTP UA rewrite behavior and pass-through cache.
|
// Rewriter encapsulates HTTP UA rewrite behavior and pass-through cache.
|
||||||
type Rewriter struct {
|
type Rewriter struct {
|
||||||
payloadUA string
|
payloadUA string
|
||||||
pattern string
|
pattern string
|
||||||
enablePartialReplace bool
|
partialReplace bool
|
||||||
|
|
||||||
uaRegex *regexp2.Regexp
|
uaRegex *regexp2.Regexp
|
||||||
Cache *expirable.LRU[string, string]
|
whitelist []string
|
||||||
whitelist map[string]struct{}
|
Cache *expirable.LRU[string, struct{}]
|
||||||
}
|
}
|
||||||
|
|
||||||
// New constructs a Rewriter from config. Compiles regex and allocates cache.
|
// New constructs a Rewriter from config. Compiles regex and allocates cache.
|
||||||
@ -66,167 +39,33 @@ func New(cfg *config.Config) (*Rewriter, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
cache := expirable.NewLRU[string, string](1024, nil, 10*time.Minute)
|
|
||||||
|
|
||||||
whitelist := make(map[string]struct{}, len(defaultWhitelist))
|
|
||||||
for _, s := range defaultWhitelist {
|
|
||||||
whitelist[s] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &Rewriter{
|
return &Rewriter{
|
||||||
payloadUA: cfg.PayloadUA,
|
payloadUA: cfg.PayloadUA,
|
||||||
pattern: cfg.UAPattern,
|
pattern: cfg.UAPattern,
|
||||||
enablePartialReplace: cfg.EnablePartialReplace,
|
partialReplace: cfg.EnablePartialReplace,
|
||||||
uaRegex: uaRegex,
|
uaRegex: uaRegex,
|
||||||
Cache: cache,
|
Cache: expirable.NewLRU[string, struct{}](1024, nil, 30*time.Minute),
|
||||||
whitelist: whitelist,
|
whitelist: []string{
|
||||||
|
"MicroMessenger Client",
|
||||||
|
"Bilibili Freedoooooom/MarkII",
|
||||||
|
"Go-http-client/1.1",
|
||||||
|
"ByteDancePcdn",
|
||||||
|
},
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RewriteAndForward rewrites the User-Agent header if needed and forwards the request.
|
func (r *Rewriter) inWhitelist(ua string) bool {
|
||||||
// Returns pass=true if the request has been forwarded as-is (no rewrite).
|
for _, w := range r.whitelist {
|
||||||
func (r *Rewriter) RewriteAndForward(dst net.Conn, req *http.Request, destAddr string, srcAddr string) (pass bool, err error) {
|
if w == ua {
|
||||||
|
return true
|
||||||
originalUA := req.Header.Get("User-Agent")
|
|
||||||
|
|
||||||
// No UA header: pass-through after writing this first request
|
|
||||||
if originalUA == "" {
|
|
||||||
r.Cache.Add(destAddr, destAddr)
|
|
||||||
log.LogDebugWithAddr(srcAddr, destAddr, "Not found User-Agent, Add LRU Relay Cache")
|
|
||||||
if err = req.Write(dst); err != nil {
|
|
||||||
err = fmt.Errorf("req.Write: %w", err)
|
|
||||||
}
|
|
||||||
pass = true
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
log.LogInfoWithAddr(srcAddr, destAddr, fmt.Sprintf("Original User-Agent: %s", originalUA))
|
|
||||||
|
|
||||||
isWhitelist := r.inWhitelist(originalUA)
|
|
||||||
matches := true
|
|
||||||
if r.pattern != "" {
|
|
||||||
matches, err = r.uaRegex.MatchString(originalUA)
|
|
||||||
if err != nil {
|
|
||||||
log.LogErrorWithAddr(srcAddr, destAddr, fmt.Sprintf("User-Agent Regex Pattern Match Error: %s", err.Error()))
|
|
||||||
matches = true
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return false
|
||||||
// If UA is whitelisted or does not match target pattern, write once then pass-through.
|
|
||||||
if isWhitelist || !matches {
|
|
||||||
if !matches {
|
|
||||||
log.LogDebugWithAddr(srcAddr, destAddr, fmt.Sprintf("Not Hit User-Agent Regex: %s", originalUA))
|
|
||||||
}
|
|
||||||
if isWhitelist {
|
|
||||||
log.LogInfoWithAddr(srcAddr, destAddr, fmt.Sprintf("Hit User-Agent Whitelist: %s", originalUA))
|
|
||||||
r.Cache.Add(destAddr, destAddr)
|
|
||||||
}
|
|
||||||
statistics.AddPassThroughRecord(&statistics.PassThroughRecord{
|
|
||||||
Host: destAddr,
|
|
||||||
UA: originalUA,
|
|
||||||
})
|
|
||||||
if err = req.Write(dst); err != nil {
|
|
||||||
err = fmt.Errorf("req.Write: %w", err)
|
|
||||||
}
|
|
||||||
pass = true
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Rewrite UA and forward the request (including body)
|
|
||||||
rewritedUA := r.buildNewUA(originalUA)
|
|
||||||
log.LogInfoWithAddr(srcAddr, destAddr, fmt.Sprintf("Rewrite User-Agent from (%s) to (%s)", originalUA, rewritedUA))
|
|
||||||
req.Header.Set("User-Agent", rewritedUA)
|
|
||||||
if err = req.Write(dst); err != nil {
|
|
||||||
err = fmt.Errorf("req.Write: %w", err)
|
|
||||||
pass = true
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
statistics.AddRewriteRecord(&statistics.RewriteRecord{
|
|
||||||
Host: destAddr,
|
|
||||||
OriginalUA: originalUA,
|
|
||||||
MockedUA: rewritedUA,
|
|
||||||
})
|
|
||||||
|
|
||||||
return false, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProxyHTTPOrRaw reads traffic from src and writes to dst.
|
// buildUserAgent returns either a partial replacement (regex) or full overwrite.
|
||||||
// - If target in LRU cache: pass-through (raw).
|
func (r *Rewriter) buildUserAgent(originUA string) string {
|
||||||
// - Else if HTTP: rewrite UA (unless whitelisted or pattern not matched).
|
if r.partialReplace && r.uaRegex != nil && r.pattern != "" {
|
||||||
// - Else: mark target in LRU and pass-through.
|
|
||||||
func (r *Rewriter) ProxyHTTPOrRaw(dst net.Conn, src net.Conn, destAddr string, srcAddr string) (err error) {
|
|
||||||
reader := bufio.NewReader(src)
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
if err != nil {
|
|
||||||
log.LogDebugWithAddr(srcAddr, destAddr, fmt.Sprintf("ProxyHTTPOrRaw: %s", err.Error()))
|
|
||||||
}
|
|
||||||
io.Copy(dst, reader)
|
|
||||||
}()
|
|
||||||
|
|
||||||
if strings.HasSuffix(destAddr, "443") && isTLSClientHello(reader) {
|
|
||||||
r.Cache.Add(destAddr, destAddr)
|
|
||||||
log.LogInfoWithAddr(srcAddr, destAddr, "tls client hello detected, pass forward")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
isHTTP, err := r.isHTTP(reader)
|
|
||||||
if err != nil {
|
|
||||||
err = fmt.Errorf("isHTTP: %w", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !isHTTP {
|
|
||||||
r.Cache.Add(destAddr, destAddr)
|
|
||||||
log.LogInfoWithAddr(srcAddr, destAddr, "Not HTTP, added to LRU Relay Cache")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var req *http.Request
|
|
||||||
var pass bool
|
|
||||||
|
|
||||||
// HTTP request loop (handles keep-alive)
|
|
||||||
for {
|
|
||||||
if isHTTP, err = r.isHTTP(reader); err != nil {
|
|
||||||
err = fmt.Errorf("isHTTP: %w", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !isHTTP {
|
|
||||||
h2, _ := reader.Peek(2) // ensure we have at least 2 bytes
|
|
||||||
if isWebSocket(h2) {
|
|
||||||
log.LogInfoWithAddr(srcAddr, destAddr, "WebSocket detected, pass-through")
|
|
||||||
} else {
|
|
||||||
r.Cache.Add(destAddr, destAddr)
|
|
||||||
log.LogInfoWithAddr(srcAddr, destAddr, "Not HTTP, added to LRU Relay Cache")
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if req, err = http.ReadRequest(reader); err != nil {
|
|
||||||
err = fmt.Errorf("http.ReadRequest: %w", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if pass, err = r.RewriteAndForward(dst, req, destAddr, srcAddr); pass {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// isHTTP peeks the first few bytes and checks for a known HTTP method prefix.
|
|
||||||
func (r *Rewriter) isHTTP(reader *bufio.Reader) (bool, error) {
|
|
||||||
// Fast check: peek first word to see if it's a known HTTP method
|
|
||||||
const maxMethodLen = 7
|
|
||||||
hintSlice, err := reader.Peek(maxMethodLen)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
hint := string(hintSlice)
|
|
||||||
method, _, _ := strings.Cut(hint, " ")
|
|
||||||
_, exists := httpMethods[method]
|
|
||||||
return exists, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildNewUA returns either a partial replacement (regex) or full overwrite.
|
|
||||||
func (r *Rewriter) buildNewUA(originUA string) string {
|
|
||||||
if r.enablePartialReplace && r.uaRegex != nil && r.pattern != "" {
|
|
||||||
newUA, err := r.uaRegex.Replace(originUA, r.payloadUA, -1, -1)
|
newUA, err := r.uaRegex.Replace(originUA, r.payloadUA, -1, -1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.Errorf("User-Agent Replace Error: %s, use full overwrite", err.Error())
|
logrus.Errorf("User-Agent Replace Error: %s, use full overwrite", err.Error())
|
||||||
@ -237,131 +76,117 @@ func (r *Rewriter) buildNewUA(originUA string) string {
|
|||||||
return r.payloadUA
|
return r.payloadUA
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Rewriter) inWhitelist(ua string) bool {
|
func (r *Rewriter) ShouldRewrite(req *http.Request, srcAddr, destAddr string) bool {
|
||||||
_, ok := r.whitelist[ua]
|
originalUA := req.Header.Get("User-Agent")
|
||||||
return ok
|
log.LogInfoWithAddr(srcAddr, destAddr, fmt.Sprintf("Original User-Agent: (%s)", originalUA))
|
||||||
|
|
||||||
|
var err error
|
||||||
|
matches := true
|
||||||
|
isWhitelist := r.inWhitelist(originalUA)
|
||||||
|
|
||||||
|
if !isWhitelist && r.pattern != "" {
|
||||||
|
matches, err = r.uaRegex.MatchString(originalUA)
|
||||||
|
if err != nil {
|
||||||
|
log.LogErrorWithAddr(srcAddr, destAddr, fmt.Sprintf("User-Agent Regex Match Error: %s", err.Error()))
|
||||||
|
matches = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if isWhitelist {
|
||||||
|
log.LogInfoWithAddr(srcAddr, destAddr, fmt.Sprintf("Hit User-Agent Whitelist: %s", originalUA))
|
||||||
|
r.Cache.Add(destAddr, struct{}{})
|
||||||
|
}
|
||||||
|
if !matches {
|
||||||
|
log.LogDebugWithAddr(srcAddr, destAddr, fmt.Sprintf("Not Hit User-Agent Regex: %s", originalUA))
|
||||||
|
}
|
||||||
|
|
||||||
|
hit := !isWhitelist && matches
|
||||||
|
if !hit {
|
||||||
|
statistics.AddPassThroughRecord(&statistics.PassThroughRecord{
|
||||||
|
Host: destAddr,
|
||||||
|
UA: originalUA,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return hit
|
||||||
}
|
}
|
||||||
|
|
||||||
// peekLineSlice reads a line from bufio.Reader without consuming it.
|
func (r *Rewriter) Rewrite(req *http.Request, srcAddr string, destAddr string) *http.Request {
|
||||||
// returns the line bytes (without CRLF) or error.
|
originalUA := req.Header.Get("User-Agent")
|
||||||
func peekLineSlice(br *bufio.Reader) ([]byte, error) {
|
rewritedUA := r.buildUserAgent(originalUA)
|
||||||
const chunkSize = 256
|
req.Header.Set("User-Agent", rewritedUA)
|
||||||
var line []byte
|
|
||||||
|
log.LogInfoWithAddr(srcAddr, destAddr, fmt.Sprintf("Rewrite User-Agent from (%s) to (%s)", originalUA, rewritedUA))
|
||||||
|
|
||||||
|
statistics.AddRewriteRecord(&statistics.RewriteRecord{
|
||||||
|
Host: destAddr,
|
||||||
|
OriginalUA: originalUA,
|
||||||
|
MockedUA: rewritedUA,
|
||||||
|
})
|
||||||
|
return req
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Rewriter) Forward(dst net.Conn, req *http.Request) error {
|
||||||
|
if err := req.Write(dst); err != nil {
|
||||||
|
return fmt.Errorf("req.Write: %w", err)
|
||||||
|
}
|
||||||
|
req.Body.Close()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process handles the proxying with UA rewriting logic.
|
||||||
|
func (r *Rewriter) Process(dst net.Conn, src net.Conn, destAddr string, srcAddr string) (err error) {
|
||||||
|
reader := bufio.NewReader(src)
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
log.LogDebugWithAddr(srcAddr, destAddr, fmt.Sprintf("Process: %s", err.Error()))
|
||||||
|
}
|
||||||
|
io.Copy(dst, reader)
|
||||||
|
}()
|
||||||
|
|
||||||
|
if strings.HasSuffix(destAddr, "443") && sniff.SniffTLSClientHello(reader) {
|
||||||
|
r.Cache.Add(destAddr, struct{}{})
|
||||||
|
log.LogInfoWithAddr(srcAddr, destAddr, "tls client hello detected, pass forward")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var isHTTP bool
|
||||||
|
|
||||||
|
if isHTTP, err = sniff.SniffHTTP(reader); err != nil {
|
||||||
|
err = fmt.Errorf("sniff.SniffHTTP: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !isHTTP {
|
||||||
|
r.Cache.Add(destAddr, struct{}{})
|
||||||
|
log.LogInfoWithAddr(srcAddr, destAddr, "Not HTTP, added to cache")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req *http.Request
|
||||||
|
|
||||||
offset := 0
|
|
||||||
for {
|
for {
|
||||||
// Ensure there is data in the buffer
|
if isHTTP, err = sniff.SniffHTTPFast(reader); err != nil {
|
||||||
n := br.Buffered()
|
err = fmt.Errorf("isHTTP: %w", err)
|
||||||
if n == 0 {
|
return
|
||||||
// No data in buffer, try to fill it
|
|
||||||
_, err := br.Peek(1)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
n = br.Buffered()
|
|
||||||
}
|
}
|
||||||
|
if !isHTTP {
|
||||||
// Limit to chunkSize
|
r.Cache.Add(destAddr, struct{}{})
|
||||||
if n > chunkSize {
|
log.LogInfoWithAddr(srcAddr, destAddr, "Not HTTP, added to LRU Relay Cache")
|
||||||
n = chunkSize
|
return
|
||||||
}
|
}
|
||||||
|
if req, err = http.ReadRequest(reader); err != nil {
|
||||||
buf, err := br.Peek(offset + n)
|
err = fmt.Errorf("http.ReadRequest: %w", err)
|
||||||
if err != nil && !errors.Is(err, bufio.ErrBufferFull) && !errors.Is(err, io.EOF) {
|
return
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
if r.ShouldRewrite(req, srcAddr, destAddr) {
|
||||||
data := buf[offset:]
|
req = r.Rewrite(req, srcAddr, destAddr)
|
||||||
if i := bytes.IndexByte(data, '\n'); i >= 0 {
|
|
||||||
line = append(line, data[:i]...)
|
|
||||||
// Remove trailing CR if present
|
|
||||||
if len(line) > 0 && line[len(line)-1] == '\r' {
|
|
||||||
line = line[:len(line)-1]
|
|
||||||
}
|
|
||||||
return line, nil
|
|
||||||
}
|
}
|
||||||
|
if err = r.Forward(dst, req); err != nil {
|
||||||
line = append(line, data...)
|
err = fmt.Errorf("r.forward: %w", err)
|
||||||
offset += len(data)
|
return
|
||||||
|
}
|
||||||
// If EOF reached without finding a newline, return what we have
|
if req.Header.Get("Upgrade") == "websocket" && req.Header.Get("Connection") == "Upgrade" {
|
||||||
if errors.Is(err, io.EOF) {
|
log.LogInfoWithAddr(srcAddr, destAddr, "WebSocket Upgrade detected, switching to raw proxy")
|
||||||
return line, io.EOF
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// peekLineString reads a line from bufio.Reader without consuming it.
|
|
||||||
// returns the line string (without CRLF) or error.
|
|
||||||
func peekLineString(br *bufio.Reader) (string, error) {
|
|
||||||
lineBytes, err := peekLineSlice(br)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return string(lineBytes), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseRequestLine parses "GET /foo HTTP/1.1" into its three parts.
|
|
||||||
func parseRequestLine(line string) (method, requestURI, proto string, ok bool) {
|
|
||||||
method, rest, ok1 := strings.Cut(line, " ")
|
|
||||||
requestURI, proto, ok2 := strings.Cut(rest, " ")
|
|
||||||
if !ok1 || !ok2 {
|
|
||||||
return "", "", "", false
|
|
||||||
}
|
|
||||||
return method, requestURI, proto, true
|
|
||||||
}
|
|
||||||
|
|
||||||
func isWebSocket(header []byte) bool {
|
|
||||||
if len(header) < 2 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
b0 := header[0]
|
|
||||||
b1 := header[1]
|
|
||||||
|
|
||||||
rsv := b0 & 0x70 // RSV1-3
|
|
||||||
opcode := b0 & 0x0F // opcode
|
|
||||||
mask := b1 & 0x80 // MASK
|
|
||||||
|
|
||||||
// requested frames from client to server must be masked
|
|
||||||
if mask == 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
// Control frames must have FIN set
|
|
||||||
if rsv != 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
// opcode must be in valid range
|
|
||||||
if opcode > 0xA {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
// payload length
|
|
||||||
payloadLen := b1 & 0x7F
|
|
||||||
if payloadLen > 0 && payloadLen <= 125 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func isTLSClientHello(reader *bufio.Reader) bool {
|
|
||||||
header, err := reader.Peek(3)
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
// TLS record type 0x16 = Handshake
|
|
||||||
if header[0] != 0x16 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
// TLS version
|
|
||||||
versionMajor := header[1]
|
|
||||||
versionMinor := header[2]
|
|
||||||
if versionMajor != 0x03 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if versionMinor < 0x01 || versionMinor > 0x04 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
package rewrite
|
package rewrite
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
|
||||||
"bytes"
|
"bytes"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
@ -46,35 +45,11 @@ func TestNewRewriter(t *testing.T) {
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, cfg.PayloadUA, rewriter.payloadUA)
|
assert.Equal(t, cfg.PayloadUA, rewriter.payloadUA)
|
||||||
assert.Equal(t, cfg.UAPattern, rewriter.pattern)
|
assert.Equal(t, cfg.UAPattern, rewriter.pattern)
|
||||||
assert.Equal(t, cfg.EnablePartialReplace, rewriter.enablePartialReplace)
|
assert.Equal(t, cfg.EnablePartialReplace, rewriter.partialReplace)
|
||||||
assert.NotNil(t, rewriter.uaRegex)
|
assert.NotNil(t, rewriter.uaRegex)
|
||||||
assert.NotNil(t, rewriter.Cache)
|
assert.NotNil(t, rewriter.Cache)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIsHTTP(t *testing.T) {
|
|
||||||
r := newTestRewriter(t)
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
input string
|
|
||||||
expected bool
|
|
||||||
}{
|
|
||||||
{"HTTP Get", "GET / HTTP/1.1\r\n", true},
|
|
||||||
{"HTTP Post", "POST /test HTTP/1.1\r\n", true},
|
|
||||||
{"HTTP Connect", "CONNECT example.com:443 HTTP/1.1\r\n", true},
|
|
||||||
{"Not HTTP", "HELLO WORLD\r\n", false},
|
|
||||||
{"Not HTTP", "SSH-2.0-OpenSSH_8.4\r\n", false},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
reader := bufio.NewReader(strings.NewReader(tt.input))
|
|
||||||
isHTTP, _ := r.isHTTP(reader)
|
|
||||||
assert.Equal(t, tt.expected, isHTTP)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestProxyHTTPOrRaw_HTTPRewrite(t *testing.T) {
|
func TestProxyHTTPOrRaw_HTTPRewrite(t *testing.T) {
|
||||||
r := newTestRewriter(t)
|
r := newTestRewriter(t)
|
||||||
|
|
||||||
@ -83,7 +58,7 @@ func TestProxyHTTPOrRaw_HTTPRewrite(t *testing.T) {
|
|||||||
dstBuf := &bytes.Buffer{}
|
dstBuf := &bytes.Buffer{}
|
||||||
dst := &mockConn{Reader: nil, Writer: dstBuf}
|
dst := &mockConn{Reader: nil, Writer: dstBuf}
|
||||||
|
|
||||||
r.ProxyHTTPOrRaw(dst, src, "example.com:80", "srcAddr")
|
r.Process(dst, src, "example.com:80", "srcAddr")
|
||||||
|
|
||||||
out := dstBuf.String()
|
out := dstBuf.String()
|
||||||
assert.Contains(t, out, "User-Agent: MockUA/1.0")
|
assert.Contains(t, out, "User-Agent: MockUA/1.0")
|
||||||
|
|||||||
@ -58,7 +58,7 @@ func (s *Server) handleHTTP(w http.ResponseWriter, req *http.Request) {
|
|||||||
}
|
}
|
||||||
defer target.Close()
|
defer target.Close()
|
||||||
|
|
||||||
_, err = s.rw.RewriteAndForward(target, req, req.Host, req.RemoteAddr)
|
err = s.rewriteAndForward(target, req, req.Host, req.RemoteAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusServiceUnavailable)
|
http.Error(w, err.Error(), http.StatusServiceUnavailable)
|
||||||
return
|
return
|
||||||
@ -101,6 +101,18 @@ func (s *Server) handleTunneling(w http.ResponseWriter, req *http.Request) {
|
|||||||
s.ForwardTCP(client, dest, destAddr)
|
s.ForwardTCP(client, dest, destAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) rewriteAndForward(target net.Conn, req *http.Request, dstAddr, srcAddr string) (err error) {
|
||||||
|
rw := s.rw
|
||||||
|
if rw.ShouldRewrite(req, srcAddr, dstAddr) {
|
||||||
|
req = rw.Rewrite(req, srcAddr, dstAddr)
|
||||||
|
}
|
||||||
|
if err = rw.Forward(target, req); err != nil {
|
||||||
|
err = fmt.Errorf("r.forward: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) HandleClient(client net.Conn) {
|
func (s *Server) HandleClient(client net.Conn) {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -60,7 +60,7 @@ func ProxyHalf(dst, src net.Conn, rw *rewrite.Rewriter, destAddr string) {
|
|||||||
io.Copy(dst, src)
|
io.Copy(dst, src)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
_ = rw.ProxyHTTPOrRaw(dst, src, destAddr, srcAddr)
|
_ = rw.Process(dst, src, destAddr, srcAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetConnFD(conn net.Conn) (fd int, err error) {
|
func GetConnFD(conn net.Conn) (fd int, err error) {
|
||||||
|
|||||||
75
src/internal/sniff/common.go
Normal file
75
src/internal/sniff/common.go
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
package sniff
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrPeekTimeout = errors.New("peek timeout")
|
||||||
|
|
||||||
|
// peekLineSlice reads a line from bufio.Reader without consuming it.
|
||||||
|
// returns the line bytes (without CRLF) or error.
|
||||||
|
func peekLineSlice(br *bufio.Reader, maxSize int) ([]byte, error) {
|
||||||
|
var line []byte
|
||||||
|
|
||||||
|
peekSize := maxSize
|
||||||
|
if peekSize == 0 {
|
||||||
|
return nil, io.EOF
|
||||||
|
}
|
||||||
|
if buffered := br.Buffered(); buffered < peekSize {
|
||||||
|
peekSize = buffered
|
||||||
|
}
|
||||||
|
|
||||||
|
buf, err := br.Peek(peekSize)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if i := bytes.IndexByte(buf, '\n'); i >= 0 {
|
||||||
|
line = append(line, buf[:i]...)
|
||||||
|
// Remove trailing CR if present
|
||||||
|
if len(line) > 0 && line[len(line)-1] == '\r' {
|
||||||
|
line = line[:len(line)-1]
|
||||||
|
}
|
||||||
|
return line, nil
|
||||||
|
}
|
||||||
|
return nil, io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
// peekLineString reads a line from bufio.Reader without consuming it.
|
||||||
|
// returns the line string (without CRLF) or error.
|
||||||
|
func peekLineString(br *bufio.Reader, maxSize int) (string, error) {
|
||||||
|
lineBytes, err := peekLineSlice(br, maxSize)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return string(lineBytes), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PeekWithTimeout peeks n bytes from bufio.Reader with a timeout.
|
||||||
|
func PeekWithTimeout(r *bufio.Reader, n int, timeout time.Duration) ([]byte, error) {
|
||||||
|
if buffered := r.Buffered(); buffered >= n {
|
||||||
|
data, err := r.Peek(n)
|
||||||
|
return data, err
|
||||||
|
}
|
||||||
|
type result struct {
|
||||||
|
data []byte
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
ch := make(chan result, 1)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
data, err := r.Peek(n)
|
||||||
|
ch <- result{data, err}
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case res := <-ch:
|
||||||
|
return res.data, res.err
|
||||||
|
case <-time.After(timeout):
|
||||||
|
return nil, ErrPeekTimeout
|
||||||
|
}
|
||||||
|
}
|
||||||
75
src/internal/sniff/http.go
Normal file
75
src/internal/sniff/http.go
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
package sniff
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HTTP methods used to detect HTTP by request line.
|
||||||
|
var methods = [...]string{"GET", "POST", "HEAD", "CONNECT", "PUT", "DELETE", "OPTIONS", "PATCH", "TRACE"}
|
||||||
|
|
||||||
|
// parseRequestLine parses "GET /foo HTTP/1.1" into its three parts.
|
||||||
|
func parseRequestLine(line string) (method, requestURI, proto string, ok bool) {
|
||||||
|
method, rest, ok1 := strings.Cut(line, " ")
|
||||||
|
requestURI, proto, ok2 := strings.Cut(rest, " ")
|
||||||
|
if !ok1 || !ok2 {
|
||||||
|
return "", "", "", false
|
||||||
|
}
|
||||||
|
return method, requestURI, proto, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// beginWithHTTPMethod peeks the first few bytes to check for known HTTP method prefixes.
|
||||||
|
func beginWithHTTPMethod(reader *bufio.Reader) (bool, error) {
|
||||||
|
const maxMethodLen = 7
|
||||||
|
const minMethodLen = 3
|
||||||
|
var hint []byte
|
||||||
|
hint, err := PeekWithTimeout(reader, maxMethodLen, 3*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
if err != ErrPeekTimeout {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
hint, err = PeekWithTimeout(reader, minMethodLen+1, time.Second)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
method, _, _ := strings.Cut(string(hint), " ")
|
||||||
|
for _, m := range methods {
|
||||||
|
if method == m {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SniffHTTP peeks the first few bytes and checks for a known HTTP method prefix.
|
||||||
|
func SniffHTTP(reader *bufio.Reader) (bool, error) {
|
||||||
|
// Fast check: peek first word to see if it's a known HTTP method
|
||||||
|
beginHTTP, err := beginWithHTTPMethod(reader)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Detailed check: parse request line
|
||||||
|
line, err := peekLineString(reader, 128)
|
||||||
|
if err != nil {
|
||||||
|
return beginHTTP, nil
|
||||||
|
}
|
||||||
|
_, _, proto, ok := parseRequestLine(line)
|
||||||
|
if !ok {
|
||||||
|
return beginHTTP, nil
|
||||||
|
}
|
||||||
|
if proto != "HTTP/1.1" && proto != "HTTP/1.0" {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return beginHTTP, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func SniffHTTPFast(reader *bufio.Reader) (bool, error) {
|
||||||
|
beginHTTP, err := beginWithHTTPMethod(reader)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return beginHTTP, nil
|
||||||
|
}
|
||||||
25
src/internal/sniff/tls.go
Normal file
25
src/internal/sniff/tls.go
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
package sniff
|
||||||
|
|
||||||
|
import "bufio"
|
||||||
|
|
||||||
|
func SniffTLSClientHello(reader *bufio.Reader) bool {
|
||||||
|
header, err := reader.Peek(3)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// TLS record type 0x16 = Handshake
|
||||||
|
if header[0] != 0x16 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// TLS version
|
||||||
|
versionMajor := header[1]
|
||||||
|
versionMinor := header[2]
|
||||||
|
if versionMajor != 0x03 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if versionMinor < 0x01 || versionMinor > 0x04 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
34
src/internal/sniff/websocket.go
Normal file
34
src/internal/sniff/websocket.go
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
package sniff
|
||||||
|
|
||||||
|
func SniffWebSocket(header []byte) bool {
|
||||||
|
if len(header) < 2 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
b0 := header[0]
|
||||||
|
b1 := header[1]
|
||||||
|
|
||||||
|
rsv := b0 & 0x70 // RSV1-3
|
||||||
|
opcode := b0 & 0x0F // opcode
|
||||||
|
mask := b1 & 0x80 // MASK
|
||||||
|
|
||||||
|
// requested frames from client to server must be masked
|
||||||
|
if mask == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// Control frames must have FIN set
|
||||||
|
if rsv != 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// opcode must be in valid range
|
||||||
|
if opcode > 0xA {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// payload length
|
||||||
|
payloadLen := b1 & 0x7F
|
||||||
|
if payloadLen > 0 && payloadLen <= 125 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue
Block a user