mirror of
https://github.com/SunBK201/UA3F.git
synced 2025-12-16 08:44:29 +00:00
refactor: streamline rewriting
This commit is contained in:
parent
3e2e3aa7e7
commit
0a01c9f04f
@ -2,8 +2,6 @@ package rewrite
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
@ -17,44 +15,19 @@ import (
|
||||
|
||||
"github.com/sunbk201/ua3f/internal/config"
|
||||
"github.com/sunbk201/ua3f/internal/log"
|
||||
"github.com/sunbk201/ua3f/internal/sniff"
|
||||
"github.com/sunbk201/ua3f/internal/statistics"
|
||||
)
|
||||
|
||||
const (
|
||||
ErrUseClosedConn = "use of closed network connection"
|
||||
ErrConnResetByPeer = "connection reset by peer"
|
||||
ErrIOTimeout = "i/o timeout"
|
||||
)
|
||||
|
||||
// HTTP methods used to detect HTTP by request line.
|
||||
var httpMethods = map[string]struct{}{
|
||||
"GET": {},
|
||||
"POST": {},
|
||||
"HEAD": {},
|
||||
"PUT": {},
|
||||
"PATCH": {},
|
||||
"DELETE": {},
|
||||
"OPTIONS": {},
|
||||
"TRACE": {},
|
||||
"CONNECT": {},
|
||||
}
|
||||
|
||||
// Hardcoded whitelist of UAs that should be left untouched.
|
||||
var defaultWhitelist = []string{
|
||||
"MicroMessenger Client",
|
||||
"ByteDancePcdn",
|
||||
"Go-http-client/1.1",
|
||||
}
|
||||
|
||||
// Rewriter encapsulates HTTP UA rewrite behavior and pass-through cache.
|
||||
type Rewriter struct {
|
||||
payloadUA string
|
||||
pattern string
|
||||
enablePartialReplace bool
|
||||
payloadUA string
|
||||
pattern string
|
||||
partialReplace bool
|
||||
|
||||
uaRegex *regexp2.Regexp
|
||||
Cache *expirable.LRU[string, string]
|
||||
whitelist map[string]struct{}
|
||||
whitelist []string
|
||||
Cache *expirable.LRU[string, struct{}]
|
||||
}
|
||||
|
||||
// New constructs a Rewriter from config. Compiles regex and allocates cache.
|
||||
@ -66,167 +39,33 @@ func New(cfg *config.Config) (*Rewriter, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cache := expirable.NewLRU[string, string](1024, nil, 10*time.Minute)
|
||||
|
||||
whitelist := make(map[string]struct{}, len(defaultWhitelist))
|
||||
for _, s := range defaultWhitelist {
|
||||
whitelist[s] = struct{}{}
|
||||
}
|
||||
|
||||
return &Rewriter{
|
||||
payloadUA: cfg.PayloadUA,
|
||||
pattern: cfg.UAPattern,
|
||||
enablePartialReplace: cfg.EnablePartialReplace,
|
||||
uaRegex: uaRegex,
|
||||
Cache: cache,
|
||||
whitelist: whitelist,
|
||||
payloadUA: cfg.PayloadUA,
|
||||
pattern: cfg.UAPattern,
|
||||
partialReplace: cfg.EnablePartialReplace,
|
||||
uaRegex: uaRegex,
|
||||
Cache: expirable.NewLRU[string, struct{}](1024, nil, 30*time.Minute),
|
||||
whitelist: []string{
|
||||
"MicroMessenger Client",
|
||||
"Bilibili Freedoooooom/MarkII",
|
||||
"Go-http-client/1.1",
|
||||
"ByteDancePcdn",
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RewriteAndForward rewrites the User-Agent header if needed and forwards the request.
|
||||
// Returns pass=true if the request has been forwarded as-is (no rewrite).
|
||||
func (r *Rewriter) RewriteAndForward(dst net.Conn, req *http.Request, destAddr string, srcAddr string) (pass bool, err error) {
|
||||
|
||||
originalUA := req.Header.Get("User-Agent")
|
||||
|
||||
// No UA header: pass-through after writing this first request
|
||||
if originalUA == "" {
|
||||
r.Cache.Add(destAddr, destAddr)
|
||||
log.LogDebugWithAddr(srcAddr, destAddr, "Not found User-Agent, Add LRU Relay Cache")
|
||||
if err = req.Write(dst); err != nil {
|
||||
err = fmt.Errorf("req.Write: %w", err)
|
||||
}
|
||||
pass = true
|
||||
return
|
||||
}
|
||||
|
||||
log.LogInfoWithAddr(srcAddr, destAddr, fmt.Sprintf("Original User-Agent: %s", originalUA))
|
||||
|
||||
isWhitelist := r.inWhitelist(originalUA)
|
||||
matches := true
|
||||
if r.pattern != "" {
|
||||
matches, err = r.uaRegex.MatchString(originalUA)
|
||||
if err != nil {
|
||||
log.LogErrorWithAddr(srcAddr, destAddr, fmt.Sprintf("User-Agent Regex Pattern Match Error: %s", err.Error()))
|
||||
matches = true
|
||||
func (r *Rewriter) inWhitelist(ua string) bool {
|
||||
for _, w := range r.whitelist {
|
||||
if w == ua {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// If UA is whitelisted or does not match target pattern, write once then pass-through.
|
||||
if isWhitelist || !matches {
|
||||
if !matches {
|
||||
log.LogDebugWithAddr(srcAddr, destAddr, fmt.Sprintf("Not Hit User-Agent Regex: %s", originalUA))
|
||||
}
|
||||
if isWhitelist {
|
||||
log.LogInfoWithAddr(srcAddr, destAddr, fmt.Sprintf("Hit User-Agent Whitelist: %s", originalUA))
|
||||
r.Cache.Add(destAddr, destAddr)
|
||||
}
|
||||
statistics.AddPassThroughRecord(&statistics.PassThroughRecord{
|
||||
Host: destAddr,
|
||||
UA: originalUA,
|
||||
})
|
||||
if err = req.Write(dst); err != nil {
|
||||
err = fmt.Errorf("req.Write: %w", err)
|
||||
}
|
||||
pass = true
|
||||
return
|
||||
}
|
||||
|
||||
// Rewrite UA and forward the request (including body)
|
||||
rewritedUA := r.buildNewUA(originalUA)
|
||||
log.LogInfoWithAddr(srcAddr, destAddr, fmt.Sprintf("Rewrite User-Agent from (%s) to (%s)", originalUA, rewritedUA))
|
||||
req.Header.Set("User-Agent", rewritedUA)
|
||||
if err = req.Write(dst); err != nil {
|
||||
err = fmt.Errorf("req.Write: %w", err)
|
||||
pass = true
|
||||
return
|
||||
}
|
||||
|
||||
statistics.AddRewriteRecord(&statistics.RewriteRecord{
|
||||
Host: destAddr,
|
||||
OriginalUA: originalUA,
|
||||
MockedUA: rewritedUA,
|
||||
})
|
||||
|
||||
return false, nil
|
||||
return false
|
||||
}
|
||||
|
||||
// ProxyHTTPOrRaw reads traffic from src and writes to dst.
|
||||
// - If target in LRU cache: pass-through (raw).
|
||||
// - Else if HTTP: rewrite UA (unless whitelisted or pattern not matched).
|
||||
// - Else: mark target in LRU and pass-through.
|
||||
func (r *Rewriter) ProxyHTTPOrRaw(dst net.Conn, src net.Conn, destAddr string, srcAddr string) (err error) {
|
||||
reader := bufio.NewReader(src)
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
log.LogDebugWithAddr(srcAddr, destAddr, fmt.Sprintf("ProxyHTTPOrRaw: %s", err.Error()))
|
||||
}
|
||||
io.Copy(dst, reader)
|
||||
}()
|
||||
|
||||
if strings.HasSuffix(destAddr, "443") && isTLSClientHello(reader) {
|
||||
r.Cache.Add(destAddr, destAddr)
|
||||
log.LogInfoWithAddr(srcAddr, destAddr, "tls client hello detected, pass forward")
|
||||
return
|
||||
}
|
||||
isHTTP, err := r.isHTTP(reader)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("isHTTP: %w", err)
|
||||
return
|
||||
}
|
||||
if !isHTTP {
|
||||
r.Cache.Add(destAddr, destAddr)
|
||||
log.LogInfoWithAddr(srcAddr, destAddr, "Not HTTP, added to LRU Relay Cache")
|
||||
return
|
||||
}
|
||||
|
||||
var req *http.Request
|
||||
var pass bool
|
||||
|
||||
// HTTP request loop (handles keep-alive)
|
||||
for {
|
||||
if isHTTP, err = r.isHTTP(reader); err != nil {
|
||||
err = fmt.Errorf("isHTTP: %w", err)
|
||||
return
|
||||
}
|
||||
if !isHTTP {
|
||||
h2, _ := reader.Peek(2) // ensure we have at least 2 bytes
|
||||
if isWebSocket(h2) {
|
||||
log.LogInfoWithAddr(srcAddr, destAddr, "WebSocket detected, pass-through")
|
||||
} else {
|
||||
r.Cache.Add(destAddr, destAddr)
|
||||
log.LogInfoWithAddr(srcAddr, destAddr, "Not HTTP, added to LRU Relay Cache")
|
||||
}
|
||||
return
|
||||
}
|
||||
if req, err = http.ReadRequest(reader); err != nil {
|
||||
err = fmt.Errorf("http.ReadRequest: %w", err)
|
||||
return
|
||||
}
|
||||
if pass, err = r.RewriteAndForward(dst, req, destAddr, srcAddr); pass {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// isHTTP peeks the first few bytes and checks for a known HTTP method prefix.
|
||||
func (r *Rewriter) isHTTP(reader *bufio.Reader) (bool, error) {
|
||||
// Fast check: peek first word to see if it's a known HTTP method
|
||||
const maxMethodLen = 7
|
||||
hintSlice, err := reader.Peek(maxMethodLen)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
hint := string(hintSlice)
|
||||
method, _, _ := strings.Cut(hint, " ")
|
||||
_, exists := httpMethods[method]
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
// buildNewUA returns either a partial replacement (regex) or full overwrite.
|
||||
func (r *Rewriter) buildNewUA(originUA string) string {
|
||||
if r.enablePartialReplace && r.uaRegex != nil && r.pattern != "" {
|
||||
// buildUserAgent returns either a partial replacement (regex) or full overwrite.
|
||||
func (r *Rewriter) buildUserAgent(originUA string) string {
|
||||
if r.partialReplace && r.uaRegex != nil && r.pattern != "" {
|
||||
newUA, err := r.uaRegex.Replace(originUA, r.payloadUA, -1, -1)
|
||||
if err != nil {
|
||||
logrus.Errorf("User-Agent Replace Error: %s, use full overwrite", err.Error())
|
||||
@ -237,131 +76,117 @@ func (r *Rewriter) buildNewUA(originUA string) string {
|
||||
return r.payloadUA
|
||||
}
|
||||
|
||||
func (r *Rewriter) inWhitelist(ua string) bool {
|
||||
_, ok := r.whitelist[ua]
|
||||
return ok
|
||||
func (r *Rewriter) ShouldRewrite(req *http.Request, srcAddr, destAddr string) bool {
|
||||
originalUA := req.Header.Get("User-Agent")
|
||||
log.LogInfoWithAddr(srcAddr, destAddr, fmt.Sprintf("Original User-Agent: (%s)", originalUA))
|
||||
|
||||
var err error
|
||||
matches := true
|
||||
isWhitelist := r.inWhitelist(originalUA)
|
||||
|
||||
if !isWhitelist && r.pattern != "" {
|
||||
matches, err = r.uaRegex.MatchString(originalUA)
|
||||
if err != nil {
|
||||
log.LogErrorWithAddr(srcAddr, destAddr, fmt.Sprintf("User-Agent Regex Match Error: %s", err.Error()))
|
||||
matches = true
|
||||
}
|
||||
}
|
||||
if isWhitelist {
|
||||
log.LogInfoWithAddr(srcAddr, destAddr, fmt.Sprintf("Hit User-Agent Whitelist: %s", originalUA))
|
||||
r.Cache.Add(destAddr, struct{}{})
|
||||
}
|
||||
if !matches {
|
||||
log.LogDebugWithAddr(srcAddr, destAddr, fmt.Sprintf("Not Hit User-Agent Regex: %s", originalUA))
|
||||
}
|
||||
|
||||
hit := !isWhitelist && matches
|
||||
if !hit {
|
||||
statistics.AddPassThroughRecord(&statistics.PassThroughRecord{
|
||||
Host: destAddr,
|
||||
UA: originalUA,
|
||||
})
|
||||
}
|
||||
return hit
|
||||
}
|
||||
|
||||
// peekLineSlice reads a line from bufio.Reader without consuming it.
|
||||
// returns the line bytes (without CRLF) or error.
|
||||
func peekLineSlice(br *bufio.Reader) ([]byte, error) {
|
||||
const chunkSize = 256
|
||||
var line []byte
|
||||
func (r *Rewriter) Rewrite(req *http.Request, srcAddr string, destAddr string) *http.Request {
|
||||
originalUA := req.Header.Get("User-Agent")
|
||||
rewritedUA := r.buildUserAgent(originalUA)
|
||||
req.Header.Set("User-Agent", rewritedUA)
|
||||
|
||||
log.LogInfoWithAddr(srcAddr, destAddr, fmt.Sprintf("Rewrite User-Agent from (%s) to (%s)", originalUA, rewritedUA))
|
||||
|
||||
statistics.AddRewriteRecord(&statistics.RewriteRecord{
|
||||
Host: destAddr,
|
||||
OriginalUA: originalUA,
|
||||
MockedUA: rewritedUA,
|
||||
})
|
||||
return req
|
||||
}
|
||||
|
||||
func (r *Rewriter) Forward(dst net.Conn, req *http.Request) error {
|
||||
if err := req.Write(dst); err != nil {
|
||||
return fmt.Errorf("req.Write: %w", err)
|
||||
}
|
||||
req.Body.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Process handles the proxying with UA rewriting logic.
|
||||
func (r *Rewriter) Process(dst net.Conn, src net.Conn, destAddr string, srcAddr string) (err error) {
|
||||
reader := bufio.NewReader(src)
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
log.LogDebugWithAddr(srcAddr, destAddr, fmt.Sprintf("Process: %s", err.Error()))
|
||||
}
|
||||
io.Copy(dst, reader)
|
||||
}()
|
||||
|
||||
if strings.HasSuffix(destAddr, "443") && sniff.SniffTLSClientHello(reader) {
|
||||
r.Cache.Add(destAddr, struct{}{})
|
||||
log.LogInfoWithAddr(srcAddr, destAddr, "tls client hello detected, pass forward")
|
||||
return
|
||||
}
|
||||
|
||||
var isHTTP bool
|
||||
|
||||
if isHTTP, err = sniff.SniffHTTP(reader); err != nil {
|
||||
err = fmt.Errorf("sniff.SniffHTTP: %w", err)
|
||||
return
|
||||
}
|
||||
if !isHTTP {
|
||||
r.Cache.Add(destAddr, struct{}{})
|
||||
log.LogInfoWithAddr(srcAddr, destAddr, "Not HTTP, added to cache")
|
||||
return
|
||||
}
|
||||
|
||||
var req *http.Request
|
||||
|
||||
offset := 0
|
||||
for {
|
||||
// Ensure there is data in the buffer
|
||||
n := br.Buffered()
|
||||
if n == 0 {
|
||||
// No data in buffer, try to fill it
|
||||
_, err := br.Peek(1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
n = br.Buffered()
|
||||
if isHTTP, err = sniff.SniffHTTPFast(reader); err != nil {
|
||||
err = fmt.Errorf("isHTTP: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Limit to chunkSize
|
||||
if n > chunkSize {
|
||||
n = chunkSize
|
||||
if !isHTTP {
|
||||
r.Cache.Add(destAddr, struct{}{})
|
||||
log.LogInfoWithAddr(srcAddr, destAddr, "Not HTTP, added to LRU Relay Cache")
|
||||
return
|
||||
}
|
||||
|
||||
buf, err := br.Peek(offset + n)
|
||||
if err != nil && !errors.Is(err, bufio.ErrBufferFull) && !errors.Is(err, io.EOF) {
|
||||
return nil, err
|
||||
if req, err = http.ReadRequest(reader); err != nil {
|
||||
err = fmt.Errorf("http.ReadRequest: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
data := buf[offset:]
|
||||
if i := bytes.IndexByte(data, '\n'); i >= 0 {
|
||||
line = append(line, data[:i]...)
|
||||
// Remove trailing CR if present
|
||||
if len(line) > 0 && line[len(line)-1] == '\r' {
|
||||
line = line[:len(line)-1]
|
||||
}
|
||||
return line, nil
|
||||
if r.ShouldRewrite(req, srcAddr, destAddr) {
|
||||
req = r.Rewrite(req, srcAddr, destAddr)
|
||||
}
|
||||
|
||||
line = append(line, data...)
|
||||
offset += len(data)
|
||||
|
||||
// If EOF reached without finding a newline, return what we have
|
||||
if errors.Is(err, io.EOF) {
|
||||
return line, io.EOF
|
||||
if err = r.Forward(dst, req); err != nil {
|
||||
err = fmt.Errorf("r.forward: %w", err)
|
||||
return
|
||||
}
|
||||
if req.Header.Get("Upgrade") == "websocket" && req.Header.Get("Connection") == "Upgrade" {
|
||||
log.LogInfoWithAddr(srcAddr, destAddr, "WebSocket Upgrade detected, switching to raw proxy")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// peekLineString reads a line from bufio.Reader without consuming it.
|
||||
// returns the line string (without CRLF) or error.
|
||||
func peekLineString(br *bufio.Reader) (string, error) {
|
||||
lineBytes, err := peekLineSlice(br)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(lineBytes), nil
|
||||
}
|
||||
|
||||
// parseRequestLine parses "GET /foo HTTP/1.1" into its three parts.
|
||||
func parseRequestLine(line string) (method, requestURI, proto string, ok bool) {
|
||||
method, rest, ok1 := strings.Cut(line, " ")
|
||||
requestURI, proto, ok2 := strings.Cut(rest, " ")
|
||||
if !ok1 || !ok2 {
|
||||
return "", "", "", false
|
||||
}
|
||||
return method, requestURI, proto, true
|
||||
}
|
||||
|
||||
func isWebSocket(header []byte) bool {
|
||||
if len(header) < 2 {
|
||||
return false
|
||||
}
|
||||
|
||||
b0 := header[0]
|
||||
b1 := header[1]
|
||||
|
||||
rsv := b0 & 0x70 // RSV1-3
|
||||
opcode := b0 & 0x0F // opcode
|
||||
mask := b1 & 0x80 // MASK
|
||||
|
||||
// requested frames from client to server must be masked
|
||||
if mask == 0 {
|
||||
return false
|
||||
}
|
||||
// Control frames must have FIN set
|
||||
if rsv != 0 {
|
||||
return false
|
||||
}
|
||||
// opcode must be in valid range
|
||||
if opcode > 0xA {
|
||||
return false
|
||||
}
|
||||
// payload length
|
||||
payloadLen := b1 & 0x7F
|
||||
if payloadLen > 0 && payloadLen <= 125 {
|
||||
return true
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func isTLSClientHello(reader *bufio.Reader) bool {
|
||||
header, err := reader.Peek(3)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
// TLS record type 0x16 = Handshake
|
||||
if header[0] != 0x16 {
|
||||
return false
|
||||
}
|
||||
// TLS version
|
||||
versionMajor := header[1]
|
||||
versionMinor := header[2]
|
||||
if versionMajor != 0x03 {
|
||||
return false
|
||||
}
|
||||
if versionMinor < 0x01 || versionMinor > 0x04 {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
package rewrite
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"io"
|
||||
"net"
|
||||
@ -46,35 +45,11 @@ func TestNewRewriter(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, cfg.PayloadUA, rewriter.payloadUA)
|
||||
assert.Equal(t, cfg.UAPattern, rewriter.pattern)
|
||||
assert.Equal(t, cfg.EnablePartialReplace, rewriter.enablePartialReplace)
|
||||
assert.Equal(t, cfg.EnablePartialReplace, rewriter.partialReplace)
|
||||
assert.NotNil(t, rewriter.uaRegex)
|
||||
assert.NotNil(t, rewriter.Cache)
|
||||
}
|
||||
|
||||
func TestIsHTTP(t *testing.T) {
|
||||
r := newTestRewriter(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected bool
|
||||
}{
|
||||
{"HTTP Get", "GET / HTTP/1.1\r\n", true},
|
||||
{"HTTP Post", "POST /test HTTP/1.1\r\n", true},
|
||||
{"HTTP Connect", "CONNECT example.com:443 HTTP/1.1\r\n", true},
|
||||
{"Not HTTP", "HELLO WORLD\r\n", false},
|
||||
{"Not HTTP", "SSH-2.0-OpenSSH_8.4\r\n", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reader := bufio.NewReader(strings.NewReader(tt.input))
|
||||
isHTTP, _ := r.isHTTP(reader)
|
||||
assert.Equal(t, tt.expected, isHTTP)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyHTTPOrRaw_HTTPRewrite(t *testing.T) {
|
||||
r := newTestRewriter(t)
|
||||
|
||||
@ -83,7 +58,7 @@ func TestProxyHTTPOrRaw_HTTPRewrite(t *testing.T) {
|
||||
dstBuf := &bytes.Buffer{}
|
||||
dst := &mockConn{Reader: nil, Writer: dstBuf}
|
||||
|
||||
r.ProxyHTTPOrRaw(dst, src, "example.com:80", "srcAddr")
|
||||
r.Process(dst, src, "example.com:80", "srcAddr")
|
||||
|
||||
out := dstBuf.String()
|
||||
assert.Contains(t, out, "User-Agent: MockUA/1.0")
|
||||
|
||||
@ -58,7 +58,7 @@ func (s *Server) handleHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
}
|
||||
defer target.Close()
|
||||
|
||||
_, err = s.rw.RewriteAndForward(target, req, req.Host, req.RemoteAddr)
|
||||
err = s.rewriteAndForward(target, req, req.Host, req.RemoteAddr)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusServiceUnavailable)
|
||||
return
|
||||
@ -101,6 +101,18 @@ func (s *Server) handleTunneling(w http.ResponseWriter, req *http.Request) {
|
||||
s.ForwardTCP(client, dest, destAddr)
|
||||
}
|
||||
|
||||
func (s *Server) rewriteAndForward(target net.Conn, req *http.Request, dstAddr, srcAddr string) (err error) {
|
||||
rw := s.rw
|
||||
if rw.ShouldRewrite(req, srcAddr, dstAddr) {
|
||||
req = rw.Rewrite(req, srcAddr, dstAddr)
|
||||
}
|
||||
if err = rw.Forward(target, req); err != nil {
|
||||
err = fmt.Errorf("r.forward: %w", err)
|
||||
return
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) HandleClient(client net.Conn) {
|
||||
}
|
||||
|
||||
|
||||
@ -60,7 +60,7 @@ func ProxyHalf(dst, src net.Conn, rw *rewrite.Rewriter, destAddr string) {
|
||||
io.Copy(dst, src)
|
||||
return
|
||||
}
|
||||
_ = rw.ProxyHTTPOrRaw(dst, src, destAddr, srcAddr)
|
||||
_ = rw.Process(dst, src, destAddr, srcAddr)
|
||||
}
|
||||
|
||||
func GetConnFD(conn net.Conn) (fd int, err error) {
|
||||
|
||||
75
src/internal/sniff/common.go
Normal file
75
src/internal/sniff/common.go
Normal file
@ -0,0 +1,75 @@
|
||||
package sniff
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"time"
|
||||
)
|
||||
|
||||
var ErrPeekTimeout = errors.New("peek timeout")
|
||||
|
||||
// peekLineSlice reads a line from bufio.Reader without consuming it.
|
||||
// returns the line bytes (without CRLF) or error.
|
||||
func peekLineSlice(br *bufio.Reader, maxSize int) ([]byte, error) {
|
||||
var line []byte
|
||||
|
||||
peekSize := maxSize
|
||||
if peekSize == 0 {
|
||||
return nil, io.EOF
|
||||
}
|
||||
if buffered := br.Buffered(); buffered < peekSize {
|
||||
peekSize = buffered
|
||||
}
|
||||
|
||||
buf, err := br.Peek(peekSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if i := bytes.IndexByte(buf, '\n'); i >= 0 {
|
||||
line = append(line, buf[:i]...)
|
||||
// Remove trailing CR if present
|
||||
if len(line) > 0 && line[len(line)-1] == '\r' {
|
||||
line = line[:len(line)-1]
|
||||
}
|
||||
return line, nil
|
||||
}
|
||||
return nil, io.EOF
|
||||
}
|
||||
|
||||
// peekLineString reads a line from bufio.Reader without consuming it.
|
||||
// returns the line string (without CRLF) or error.
|
||||
func peekLineString(br *bufio.Reader, maxSize int) (string, error) {
|
||||
lineBytes, err := peekLineSlice(br, maxSize)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(lineBytes), nil
|
||||
}
|
||||
|
||||
// PeekWithTimeout peeks n bytes from bufio.Reader with a timeout.
|
||||
func PeekWithTimeout(r *bufio.Reader, n int, timeout time.Duration) ([]byte, error) {
|
||||
if buffered := r.Buffered(); buffered >= n {
|
||||
data, err := r.Peek(n)
|
||||
return data, err
|
||||
}
|
||||
type result struct {
|
||||
data []byte
|
||||
err error
|
||||
}
|
||||
ch := make(chan result, 1)
|
||||
|
||||
go func() {
|
||||
data, err := r.Peek(n)
|
||||
ch <- result{data, err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case res := <-ch:
|
||||
return res.data, res.err
|
||||
case <-time.After(timeout):
|
||||
return nil, ErrPeekTimeout
|
||||
}
|
||||
}
|
||||
75
src/internal/sniff/http.go
Normal file
75
src/internal/sniff/http.go
Normal file
@ -0,0 +1,75 @@
|
||||
package sniff
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HTTP methods used to detect HTTP by request line.
|
||||
var methods = [...]string{"GET", "POST", "HEAD", "CONNECT", "PUT", "DELETE", "OPTIONS", "PATCH", "TRACE"}
|
||||
|
||||
// parseRequestLine parses "GET /foo HTTP/1.1" into its three parts.
|
||||
func parseRequestLine(line string) (method, requestURI, proto string, ok bool) {
|
||||
method, rest, ok1 := strings.Cut(line, " ")
|
||||
requestURI, proto, ok2 := strings.Cut(rest, " ")
|
||||
if !ok1 || !ok2 {
|
||||
return "", "", "", false
|
||||
}
|
||||
return method, requestURI, proto, true
|
||||
}
|
||||
|
||||
// beginWithHTTPMethod peeks the first few bytes to check for known HTTP method prefixes.
|
||||
func beginWithHTTPMethod(reader *bufio.Reader) (bool, error) {
|
||||
const maxMethodLen = 7
|
||||
const minMethodLen = 3
|
||||
var hint []byte
|
||||
hint, err := PeekWithTimeout(reader, maxMethodLen, 3*time.Second)
|
||||
if err != nil {
|
||||
if err != ErrPeekTimeout {
|
||||
return false, err
|
||||
}
|
||||
hint, err = PeekWithTimeout(reader, minMethodLen+1, time.Second)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
method, _, _ := strings.Cut(string(hint), " ")
|
||||
for _, m := range methods {
|
||||
if method == m {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// SniffHTTP peeks the first few bytes and checks for a known HTTP method prefix.
|
||||
func SniffHTTP(reader *bufio.Reader) (bool, error) {
|
||||
// Fast check: peek first word to see if it's a known HTTP method
|
||||
beginHTTP, err := beginWithHTTPMethod(reader)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Detailed check: parse request line
|
||||
line, err := peekLineString(reader, 128)
|
||||
if err != nil {
|
||||
return beginHTTP, nil
|
||||
}
|
||||
_, _, proto, ok := parseRequestLine(line)
|
||||
if !ok {
|
||||
return beginHTTP, nil
|
||||
}
|
||||
if proto != "HTTP/1.1" && proto != "HTTP/1.0" {
|
||||
return false, nil
|
||||
}
|
||||
return beginHTTP, nil
|
||||
}
|
||||
|
||||
func SniffHTTPFast(reader *bufio.Reader) (bool, error) {
|
||||
beginHTTP, err := beginWithHTTPMethod(reader)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return beginHTTP, nil
|
||||
}
|
||||
25
src/internal/sniff/tls.go
Normal file
25
src/internal/sniff/tls.go
Normal file
@ -0,0 +1,25 @@
|
||||
package sniff
|
||||
|
||||
import "bufio"
|
||||
|
||||
func SniffTLSClientHello(reader *bufio.Reader) bool {
|
||||
header, err := reader.Peek(3)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
// TLS record type 0x16 = Handshake
|
||||
if header[0] != 0x16 {
|
||||
return false
|
||||
}
|
||||
// TLS version
|
||||
versionMajor := header[1]
|
||||
versionMinor := header[2]
|
||||
if versionMajor != 0x03 {
|
||||
return false
|
||||
}
|
||||
if versionMinor < 0x01 || versionMinor > 0x04 {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
34
src/internal/sniff/websocket.go
Normal file
34
src/internal/sniff/websocket.go
Normal file
@ -0,0 +1,34 @@
|
||||
package sniff
|
||||
|
||||
func SniffWebSocket(header []byte) bool {
|
||||
if len(header) < 2 {
|
||||
return false
|
||||
}
|
||||
|
||||
b0 := header[0]
|
||||
b1 := header[1]
|
||||
|
||||
rsv := b0 & 0x70 // RSV1-3
|
||||
opcode := b0 & 0x0F // opcode
|
||||
mask := b1 & 0x80 // MASK
|
||||
|
||||
// requested frames from client to server must be masked
|
||||
if mask == 0 {
|
||||
return false
|
||||
}
|
||||
// Control frames must have FIN set
|
||||
if rsv != 0 {
|
||||
return false
|
||||
}
|
||||
// opcode must be in valid range
|
||||
if opcode > 0xA {
|
||||
return false
|
||||
}
|
||||
// payload length
|
||||
payloadLen := b1 & 0x7F
|
||||
if payloadLen > 0 && payloadLen <= 125 {
|
||||
return true
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user