style: code tidy

This commit is contained in:
SunBK201 2025-11-01 00:30:57 +08:00
parent 703b10e4ab
commit f7085fd2df
3 changed files with 22 additions and 21 deletions

View File

@ -53,7 +53,7 @@ type Rewriter struct {
enablePartialReplace bool enablePartialReplace bool
uaRegex *regexp2.Regexp uaRegex *regexp2.Regexp
cache *expirable.LRU[string, string] Cache *expirable.LRU[string, string]
whitelist map[string]struct{} whitelist map[string]struct{}
} }
@ -78,7 +78,7 @@ func New(cfg *config.Config) (*Rewriter, error) {
pattern: cfg.UAPattern, pattern: cfg.UAPattern,
enablePartialReplace: cfg.EnablePartialReplace, enablePartialReplace: cfg.EnablePartialReplace,
uaRegex: uaRegex, uaRegex: uaRegex,
cache: cache, Cache: cache,
whitelist: whitelist, whitelist: whitelist,
}, nil }, nil
} }
@ -87,17 +87,9 @@ func New(cfg *config.Config) (*Rewriter, error) {
// - If target in LRU cache: pass-through (raw). // - If target in LRU cache: pass-through (raw).
// - Else if HTTP: rewrite UA (unless whitelisted or pattern not matched). // - Else if HTTP: rewrite UA (unless whitelisted or pattern not matched).
// - Else: mark target in LRU and pass-through. // - Else: mark target in LRU and pass-through.
func (r *Rewriter) ProxyHTTPOrRaw(dst net.Conn, src net.Conn, destAddr string) (err error) { func (r *Rewriter) ProxyHTTPOrRaw(dst net.Conn, src net.Conn, destAddr string, srcAddr string) (err error) {
srcAddr := src.RemoteAddr().String()
// Fast path: known pass-through
if r.cache.Contains(destAddr) {
log.LogDebugWithAddr(src.RemoteAddr().String(), destAddr, "LRU Relay Cache Hit, pass-through")
io.Copy(dst, src)
return nil
}
reader := bufio.NewReader(src) reader := bufio.NewReader(src)
defer func() { defer func() {
if err != nil { if err != nil {
log.LogDebugWithAddr(srcAddr, destAddr, fmt.Sprintf("ProxyHTTPOrRaw: %s", err.Error())) log.LogDebugWithAddr(srcAddr, destAddr, fmt.Sprintf("ProxyHTTPOrRaw: %s", err.Error()))
@ -106,7 +98,7 @@ func (r *Rewriter) ProxyHTTPOrRaw(dst net.Conn, src net.Conn, destAddr string) (
}() }()
if strings.HasSuffix(destAddr, "443") && isTLSClientHello(reader) { if strings.HasSuffix(destAddr, "443") && isTLSClientHello(reader) {
r.cache.Add(destAddr, destAddr) r.Cache.Add(destAddr, destAddr)
log.LogDebugWithAddr(srcAddr, destAddr, "TLS ClientHello detected") log.LogDebugWithAddr(srcAddr, destAddr, "TLS ClientHello detected")
return return
} }
@ -116,7 +108,7 @@ func (r *Rewriter) ProxyHTTPOrRaw(dst net.Conn, src net.Conn, destAddr string) (
return return
} }
if !isHTTP { if !isHTTP {
r.cache.Add(destAddr, destAddr) r.Cache.Add(destAddr, destAddr)
log.LogDebugWithAddr(srcAddr, destAddr, "Not HTTP, added to LRU Relay Cache") log.LogDebugWithAddr(srcAddr, destAddr, "Not HTTP, added to LRU Relay Cache")
return return
} }
@ -135,7 +127,7 @@ func (r *Rewriter) ProxyHTTPOrRaw(dst net.Conn, src net.Conn, destAddr string) (
if isWebSocket(h2) { if isWebSocket(h2) {
log.LogDebugWithAddr(srcAddr, destAddr, "WebSocket detected, pass-through") log.LogDebugWithAddr(srcAddr, destAddr, "WebSocket detected, pass-through")
} else { } else {
r.cache.Add(destAddr, destAddr) r.Cache.Add(destAddr, destAddr)
log.LogDebugWithAddr(srcAddr, destAddr, "Not HTTP, added to LRU Relay Cache") log.LogDebugWithAddr(srcAddr, destAddr, "Not HTTP, added to LRU Relay Cache")
} }
return return
@ -150,7 +142,7 @@ func (r *Rewriter) ProxyHTTPOrRaw(dst net.Conn, src net.Conn, destAddr string) (
// No UA header: pass-through after writing this first request // No UA header: pass-through after writing this first request
if originalUA == "" { if originalUA == "" {
r.cache.Add(destAddr, destAddr) r.Cache.Add(destAddr, destAddr)
log.LogDebugWithAddr(srcAddr, destAddr, "Not found User-Agent, Add LRU Relay Cache") log.LogDebugWithAddr(srcAddr, destAddr, "Not found User-Agent, Add LRU Relay Cache")
if err = req.Write(dst); err != nil { if err = req.Write(dst); err != nil {
err = fmt.Errorf("req.Write: %w", err) err = fmt.Errorf("req.Write: %w", err)
@ -175,7 +167,7 @@ func (r *Rewriter) ProxyHTTPOrRaw(dst net.Conn, src net.Conn, destAddr string) (
} }
if isWhitelist { if isWhitelist {
log.LogDebugWithAddr(srcAddr, destAddr, fmt.Sprintf("Hit User-Agent Whitelist: %s", originalUA)) log.LogDebugWithAddr(srcAddr, destAddr, fmt.Sprintf("Hit User-Agent Whitelist: %s", originalUA))
r.cache.Add(destAddr, destAddr) r.Cache.Add(destAddr, destAddr)
} }
statistics.AddPassThroughRecord(&statistics.PassThroughRecord{ statistics.AddPassThroughRecord(&statistics.PassThroughRecord{
Host: destAddr, Host: destAddr,

View File

@ -48,7 +48,7 @@ func TestNewRewriter(t *testing.T) {
assert.Equal(t, cfg.UAPattern, rewriter.pattern) assert.Equal(t, cfg.UAPattern, rewriter.pattern)
assert.Equal(t, cfg.EnablePartialReplace, rewriter.enablePartialReplace) assert.Equal(t, cfg.EnablePartialReplace, rewriter.enablePartialReplace)
assert.NotNil(t, rewriter.uaRegex) assert.NotNil(t, rewriter.uaRegex)
assert.NotNil(t, rewriter.cache) assert.NotNil(t, rewriter.Cache)
} }
func TestIsHTTP(t *testing.T) { func TestIsHTTP(t *testing.T) {
@ -83,7 +83,7 @@ func TestProxyHTTPOrRaw_HTTPRewrite(t *testing.T) {
dstBuf := &bytes.Buffer{} dstBuf := &bytes.Buffer{}
dst := &mockConn{Reader: nil, Writer: dstBuf} dst := &mockConn{Reader: nil, Writer: dstBuf}
r.ProxyHTTPOrRaw(dst, src, "example.com:80") r.ProxyHTTPOrRaw(dst, src, "example.com:80", "srcAddr")
out := dstBuf.String() out := dstBuf.String()
assert.Contains(t, out, "User-Agent: MockUA/1.0") assert.Contains(t, out, "User-Agent: MockUA/1.0")

View File

@ -6,6 +6,7 @@ import (
"net" "net"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/sunbk201/ua3f/internal/log"
"github.com/sunbk201/ua3f/internal/rewrite" "github.com/sunbk201/ua3f/internal/rewrite"
) )
@ -38,7 +39,7 @@ func CopyHalf(dst, src net.Conn) {
} }
// ProxyHalf runs the rewriter proxy on src->dst and then half-closes both sides. // ProxyHalf runs the rewriter proxy on src->dst and then half-closes both sides.
func ProxyHalf(dst, src net.Conn, rw *rewrite.Rewriter, destAddrPort string) { func ProxyHalf(dst, src net.Conn, rw *rewrite.Rewriter, destAddr string) {
defer func() { defer func() {
if tc, ok := dst.(*net.TCPConn); ok { if tc, ok := dst.(*net.TCPConn); ok {
_ = tc.CloseWrite() _ = tc.CloseWrite()
@ -51,7 +52,15 @@ func ProxyHalf(dst, src net.Conn, rw *rewrite.Rewriter, destAddrPort string) {
_ = src.Close() _ = src.Close()
} }
}() }()
_ = rw.ProxyHTTPOrRaw(dst, src, destAddrPort)
// Fast path: known pass-through
srcAddr := src.RemoteAddr().String()
if rw.Cache.Contains(destAddr) {
log.LogDebugWithAddr(srcAddr, destAddr, "LRU Relay Cache Hit, pass-through")
io.Copy(dst, src)
return
}
_ = rw.ProxyHTTPOrRaw(dst, src, destAddr, srcAddr)
} }
func GetConnFD(conn net.Conn) (fd int, err error) { func GetConnFD(conn net.Conn) (fd int, err error) {