From bef94bdc9d5fe5e17aa20e5379170cec5da7d12d Mon Sep 17 00:00:00 2001 From: SunBK201 Date: Sat, 6 Dec 2025 16:00:03 +0800 Subject: [PATCH] refactor: refactor statistics recording --- src/internal/rewrite/packet.go | 4 +- src/internal/rewrite/rewriter.go | 15 ++- src/internal/server/base/server.go | 18 +-- src/internal/server/http/http.go | 12 +- src/internal/server/nfqueue/nfqueue_linux.go | 5 +- src/internal/server/nfqueue/nfqueue_others.go | 4 +- .../server/redirect/redirect_linux.go | 6 +- .../server/redirect/redirect_others.go | 4 +- src/internal/server/server.go | 22 +++- src/internal/server/socks5/socks5.go | 6 +- src/internal/server/tproxy/tproxy_linux.go | 6 +- src/internal/server/tproxy/tproxy_others.go | 4 +- src/internal/statistics/conn.go | 93 +++++++++----- src/internal/statistics/pass.go | 72 ++++++++--- src/internal/statistics/rewrite.go | 67 +++++++--- src/internal/statistics/statistics.go | 118 +++++++----------- src/main.go | 11 +- 17 files changed, 284 insertions(+), 183 deletions(-) diff --git a/src/internal/rewrite/packet.go b/src/internal/rewrite/packet.go index dac437a..c69b0aa 100644 --- a/src/internal/rewrite/packet.go +++ b/src/internal/rewrite/packet.go @@ -43,7 +43,7 @@ func (r *Rewriter) buildReplacement(srcAddr, dstAddr string, originalUA string, newUA := r.buildUserAgent(originalUA) log.LogInfoWithAddr(srcAddr, dstAddr, fmt.Sprintf("Rewritten User-Agent: %s", newUA)) - statistics.AddRewriteRecord(&statistics.RewriteRecord{ + r.Recorder.AddRecord(&statistics.RewriteRecord{ Host: dstAddr, OriginalUA: originalUA, MockedUA: newUA, @@ -95,7 +95,7 @@ func (r *Rewriter) RewritePacketUserAgent(payload []byte, srcAddr, dstAddr strin // Check if should rewrite if !r.shouldRewriteUA(srcAddr, dstAddr, originalUA) { - statistics.AddPassThroughRecord(&statistics.PassThroughRecord{ + r.Recorder.AddRecord(&statistics.PassThroughRecord{ SrcAddr: srcAddr, DestAddr: dstAddr, UA: originalUA, diff --git a/src/internal/rewrite/rewriter.go b/src/internal/rewrite/rewriter.go index 4db95c7..3bef437 100644 --- a/src/internal/rewrite/rewriter.go +++ b/src/internal/rewrite/rewriter.go @@ -23,6 +23,8 @@ type Rewriter struct { uaRegex *regexp2.Regexp ruleEngine *rule.Engine whitelist []string + + Recorder *statistics.Recorder } type RewriteDecision struct { @@ -41,7 +43,7 @@ func (d *RewriteDecision) ShouldRewrite() bool { } // New constructs a Rewriter from config. Compiles regex and allocates cache. -func New(cfg *config.Config) (*Rewriter, error) { +func New(cfg *config.Config, recorder *statistics.Recorder) (*Rewriter, error) { // UA pattern is compiled with case-insensitive prefix (?i) pattern := "(?i)" + cfg.UARegex uaRegex, err := regexp2.Compile(pattern, regexp2.None) @@ -71,6 +73,7 @@ func New(cfg *config.Config) (*Rewriter, error) { "Go-http-client/1.1", "ByteDancePcdn", }, + Recorder: recorder, }, nil } @@ -110,7 +113,7 @@ func (r *Rewriter) EvaluateRewriteDecision(req *http.Request, srcAddr, destAddr // DIRECT if r.rewriteMode == config.RewriteModeDirect { log.LogDebugWithAddr(srcAddr, destAddr, "Direct forward mode, skip rewriting") - statistics.AddPassThroughRecord(&statistics.PassThroughRecord{ + r.Recorder.AddRecord(&statistics.PassThroughRecord{ SrcAddr: srcAddr, DestAddr: destAddr, UA: originalUA, @@ -127,7 +130,7 @@ func (r *Rewriter) EvaluateRewriteDecision(req *http.Request, srcAddr, destAddr // no match rule, direct forward if matchedRule == nil { log.LogDebugWithAddr(srcAddr, destAddr, "No rule matched, direct forward") - statistics.AddPassThroughRecord(&statistics.PassThroughRecord{ + r.Recorder.AddRecord(&statistics.PassThroughRecord{ SrcAddr: srcAddr, DestAddr: destAddr, UA: originalUA, @@ -149,7 +152,7 @@ func (r *Rewriter) EvaluateRewriteDecision(req *http.Request, srcAddr, destAddr // DIRECT if matchedRule.Action == rule.ActionDirect { log.LogDebugWithAddr(srcAddr, destAddr, "Rule matched: DIRECT action, skip rewriting") - statistics.AddPassThroughRecord(&statistics.PassThroughRecord{ + r.Recorder.AddRecord(&statistics.PassThroughRecord{ SrcAddr: srcAddr, DestAddr: destAddr, UA: originalUA, @@ -195,7 +198,7 @@ func (r *Rewriter) EvaluateRewriteDecision(req *http.Request, srcAddr, destAddr hit := !isWhitelist && matches if !hit { - statistics.AddPassThroughRecord(&statistics.PassThroughRecord{ + r.Recorder.AddRecord(&statistics.PassThroughRecord{ SrcAddr: srcAddr, DestAddr: destAddr, UA: originalUA, @@ -233,7 +236,7 @@ func (r *Rewriter) Rewrite(req *http.Request, srcAddr string, destAddr string, d log.LogInfoWithAddr(srcAddr, destAddr, fmt.Sprintf("Rewrite %s from (%s) to (%s)", headerName, originalValue, rewritedValue)) - statistics.AddRewriteRecord(&statistics.RewriteRecord{ + r.Recorder.AddRecord(&statistics.RewriteRecord{ Host: destAddr, OriginalUA: originalValue, MockedUA: rewritedValue, diff --git a/src/internal/server/base/server.go b/src/internal/server/base/server.go index 5cf64ce..de88a59 100644 --- a/src/internal/server/base/server.go +++ b/src/internal/server/base/server.go @@ -20,18 +20,20 @@ import ( type Server struct { Cfg *config.Config Rewriter *rewrite.Rewriter + Recorder *statistics.Recorder Cache *expirable.LRU[string, struct{}] } func (s *Server) ServeConnLink(connLink *ConnLink) { slog.Info(fmt.Sprintf("New connection link: %s <-> %s", connLink.LAddr, connLink.RAddr), "ConnLink", connLink) - statistics.AddConnection(&statistics.ConnectionRecord{ + record := &statistics.ConnectionRecord{ Protocol: sniff.TCP, SrcAddr: connLink.LAddr, DestAddr: connLink.RAddr, StartTime: time.Now(), - }) - defer statistics.RemoveConnection(connLink.LAddr, connLink.RAddr) + } + s.Recorder.AddRecord(record) + defer s.Recorder.RemoveRecord(record) defer slog.Info(fmt.Sprintf("Connection link closed: %s <-> %s", connLink.LAddr, connLink.RAddr), "ConnLink", connLink) go connLink.CopyRL() @@ -60,7 +62,7 @@ func (s *Server) ProcessLR(c *ConnLink) (err error) { if isTLS, _ := sniff.SniffTLS(reader); isTLS { s.Cache.Add(c.RAddr, struct{}{}) c.LogInfo("TLS client hello detected") - statistics.AddConnection(&statistics.ConnectionRecord{ + s.Recorder.AddRecord(&statistics.ConnectionRecord{ Protocol: sniff.HTTPS, SrcAddr: c.LAddr, DestAddr: c.RAddr, @@ -79,7 +81,7 @@ func (s *Server) ProcessLR(c *ConnLink) (err error) { s.Cache.Add(c.RAddr, struct{}{}) c.LogInfo("Sniff first request is not http, switch to direct forward") if isTLS, _ := sniff.SniffTLS(reader); isTLS { - statistics.AddConnection(&statistics.ConnectionRecord{ + s.Recorder.AddRecord(&statistics.ConnectionRecord{ Protocol: sniff.TLS, SrcAddr: c.LAddr, DestAddr: c.RAddr, @@ -88,7 +90,7 @@ func (s *Server) ProcessLR(c *ConnLink) (err error) { return } - statistics.AddConnection(&statistics.ConnectionRecord{ + s.Recorder.AddRecord(&statistics.ConnectionRecord{ Protocol: sniff.HTTP, SrcAddr: c.LAddr, DestAddr: c.RAddr, @@ -99,7 +101,7 @@ func (s *Server) ProcessLR(c *ConnLink) (err error) { for { if isHTTP, err = sniff.SniffHTTPFast(reader); err != nil { err = fmt.Errorf("sniff.SniffHTTPFast: %w", err) - statistics.AddConnection( + s.Recorder.AddRecord( &statistics.ConnectionRecord{ Protocol: sniff.TCP, SrcAddr: c.LAddr, @@ -137,7 +139,7 @@ func (s *Server) ProcessLR(c *ConnLink) (err error) { if req.Header.Get("Upgrade") == "websocket" && req.Header.Get("Connection") == "Upgrade" { c.LogInfo("websocket upgrade detected, switch to direct forward") - statistics.AddConnection(&statistics.ConnectionRecord{ + s.Recorder.AddRecord(&statistics.ConnectionRecord{ Protocol: sniff.WebSocket, SrcAddr: c.LAddr, DestAddr: c.RAddr, diff --git a/src/internal/server/http/http.go b/src/internal/server/http/http.go index 3ce2390..059021a 100644 --- a/src/internal/server/http/http.go +++ b/src/internal/server/http/http.go @@ -21,17 +21,19 @@ type Server struct { base.Server } -func New(cfg *config.Config, rw *rewrite.Rewriter) *Server { +func New(cfg *config.Config, rw *rewrite.Rewriter, rc *statistics.Recorder) *Server { return &Server{ Server: base.Server{ Cfg: cfg, Rewriter: rw, + Recorder: rc, Cache: expirable.NewLRU[string, struct{}](1024, nil, 30*time.Minute), }, } } func (s *Server) Start() (err error) { + s.Recorder.Start() server := &http.Server{ Addr: s.Cfg.ListenAddr, Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { @@ -60,13 +62,15 @@ func (s *Server) handleHTTP(w http.ResponseWriter, req *http.Request) { destPort = "80" } destAddr := fmt.Sprintf("%s:%s", req.URL.Hostname(), destPort) - statistics.AddConnection(&statistics.ConnectionRecord{ + + record := &statistics.ConnectionRecord{ Protocol: sniff.HTTP, SrcAddr: req.RemoteAddr, DestAddr: destAddr, StartTime: time.Now(), - }) - defer statistics.RemoveConnection(req.RemoteAddr, destAddr) + } + s.Recorder.AddRecord(record) + defer s.Recorder.RemoveRecord(record) slog.Info("HTTP proxy request", slog.String("srcAddr", req.RemoteAddr), slog.String("destAddr", destAddr)) diff --git a/src/internal/server/nfqueue/nfqueue_linux.go b/src/internal/server/nfqueue/nfqueue_linux.go index ea56186..42d07a6 100644 --- a/src/internal/server/nfqueue/nfqueue_linux.go +++ b/src/internal/server/nfqueue/nfqueue_linux.go @@ -16,6 +16,7 @@ import ( "github.com/sunbk201/ua3f/internal/netfilter" "github.com/sunbk201/ua3f/internal/rewrite" "github.com/sunbk201/ua3f/internal/server/base" + "github.com/sunbk201/ua3f/internal/statistics" ) type Server struct { @@ -28,11 +29,12 @@ type Server struct { NotHTTPCtMark uint32 } -func New(cfg *config.Config, rw *rewrite.Rewriter) *Server { +func New(cfg *config.Config, rw *rewrite.Rewriter, rc *statistics.Recorder) *Server { s := &Server{ Server: base.Server{ Cfg: cfg, Rewriter: rw, + Recorder: rc, Cache: expirable.NewLRU[string, struct{}](1024, nil, 30*time.Minute), }, SniffCtMarkLower: 10201, @@ -63,6 +65,7 @@ func (s *Server) Start() (err error) { slog.Error("s.Firewall.Setup", slog.Any("error", err)) return err } + s.Recorder.Start() return s.nfqServer.Start() } diff --git a/src/internal/server/nfqueue/nfqueue_others.go b/src/internal/server/nfqueue/nfqueue_others.go index 23a01e5..ce9b622 100644 --- a/src/internal/server/nfqueue/nfqueue_others.go +++ b/src/internal/server/nfqueue/nfqueue_others.go @@ -8,17 +8,19 @@ import ( "github.com/sunbk201/ua3f/internal/config" "github.com/sunbk201/ua3f/internal/rewrite" "github.com/sunbk201/ua3f/internal/server/base" + "github.com/sunbk201/ua3f/internal/statistics" ) type Server struct { base.Server } -func New(cfg *config.Config, rw *rewrite.Rewriter) *Server { +func New(cfg *config.Config, rw *rewrite.Rewriter, rc *statistics.Recorder) *Server { s := &Server{ Server: base.Server{ Cfg: cfg, Rewriter: rw, + Recorder: rc, }, } return s diff --git a/src/internal/server/redirect/redirect_linux.go b/src/internal/server/redirect/redirect_linux.go index a084b9e..70b24da 100644 --- a/src/internal/server/redirect/redirect_linux.go +++ b/src/internal/server/redirect/redirect_linux.go @@ -15,6 +15,7 @@ import ( "github.com/sunbk201/ua3f/internal/netfilter" "github.com/sunbk201/ua3f/internal/rewrite" "github.com/sunbk201/ua3f/internal/server/base" + "github.com/sunbk201/ua3f/internal/statistics" "sigs.k8s.io/knftables" ) @@ -25,11 +26,12 @@ type Server struct { so_mark int } -func New(cfg *config.Config, rw *rewrite.Rewriter) *Server { +func New(cfg *config.Config, rw *rewrite.Rewriter, rc *statistics.Recorder) *Server { s := &Server{ Server: base.Server{ Cfg: cfg, Rewriter: rw, + Recorder: rc, Cache: expirable.NewLRU[string, struct{}](1024, nil, 30*time.Minute), }, so_mark: netfilter.SO_MARK, @@ -57,6 +59,8 @@ func (s *Server) Start() (err error) { return fmt.Errorf("net.Listen: %w", err) } + s.Recorder.Start() + go func() { var client net.Conn for { diff --git a/src/internal/server/redirect/redirect_others.go b/src/internal/server/redirect/redirect_others.go index ae23387..7971dd4 100644 --- a/src/internal/server/redirect/redirect_others.go +++ b/src/internal/server/redirect/redirect_others.go @@ -9,17 +9,19 @@ import ( "github.com/sunbk201/ua3f/internal/config" "github.com/sunbk201/ua3f/internal/rewrite" "github.com/sunbk201/ua3f/internal/server/base" + "github.com/sunbk201/ua3f/internal/statistics" ) type Server struct { base.Server } -func New(cfg *config.Config, rw *rewrite.Rewriter) *Server { +func New(cfg *config.Config, rw *rewrite.Rewriter, rc *statistics.Recorder) *Server { return &Server{ Server: base.Server{ Cfg: cfg, Rewriter: rw, + Recorder: rc, }, } } diff --git a/src/internal/server/server.go b/src/internal/server/server.go index ff11586..9398c43 100644 --- a/src/internal/server/server.go +++ b/src/internal/server/server.go @@ -2,6 +2,7 @@ package server import ( "fmt" + "log/slog" "github.com/sunbk201/ua3f/internal/config" "github.com/sunbk201/ua3f/internal/rewrite" @@ -10,6 +11,7 @@ import ( "github.com/sunbk201/ua3f/internal/server/redirect" "github.com/sunbk201/ua3f/internal/server/socks5" "github.com/sunbk201/ua3f/internal/server/tproxy" + "github.com/sunbk201/ua3f/internal/statistics" ) type ServerMode string @@ -27,18 +29,26 @@ type Server interface { Close() error } -func NewServer(cfg *config.Config, rw *rewrite.Rewriter) (Server, error) { +func NewServer(cfg *config.Config) (Server, error) { + rc := statistics.New() + + rw, err := rewrite.New(cfg, rc) + if err != nil { + slog.Error("rewrite.New", slog.Any("error", err)) + return nil, err + } + switch cfg.ServerMode { case config.ServerModeHTTP: - return http.New(cfg, rw), nil + return http.New(cfg, rw, rc), nil case config.ServerModeSocks5: - return socks5.New(cfg, rw), nil + return socks5.New(cfg, rw, rc), nil case config.ServerModeTProxy: - return tproxy.New(cfg, rw), nil + return tproxy.New(cfg, rw, rc), nil case config.ServerModeRedirect: - return redirect.New(cfg, rw), nil + return redirect.New(cfg, rw, rc), nil case config.ServerModeNFQueue: - return nfqueue.New(cfg, rw), nil + return nfqueue.New(cfg, rw, rc), nil default: return nil, fmt.Errorf("NewServer unknown server mode: %s", cfg.ServerMode) } diff --git a/src/internal/server/socks5/socks5.go b/src/internal/server/socks5/socks5.go index 3e6de76..3f1b5c4 100644 --- a/src/internal/server/socks5/socks5.go +++ b/src/internal/server/socks5/socks5.go @@ -14,6 +14,7 @@ import ( "github.com/sunbk201/ua3f/internal/config" "github.com/sunbk201/ua3f/internal/rewrite" "github.com/sunbk201/ua3f/internal/server/base" + "github.com/sunbk201/ua3f/internal/statistics" ) type Server struct { @@ -21,11 +22,12 @@ type Server struct { listener net.Listener } -func New(cfg *config.Config, rw *rewrite.Rewriter) *Server { +func New(cfg *config.Config, rw *rewrite.Rewriter, rc *statistics.Recorder) *Server { return &Server{ Server: base.Server{ Cfg: cfg, Rewriter: rw, + Recorder: rc, Cache: expirable.NewLRU[string, struct{}](1024, nil, 30*time.Minute), }, } @@ -43,6 +45,8 @@ func (s *Server) Start() (err error) { return fmt.Errorf("net.Listen: %w", err) } + s.Recorder.Start() + go func() { var client net.Conn for { diff --git a/src/internal/server/tproxy/tproxy_linux.go b/src/internal/server/tproxy/tproxy_linux.go index edfe77d..871bd90 100644 --- a/src/internal/server/tproxy/tproxy_linux.go +++ b/src/internal/server/tproxy/tproxy_linux.go @@ -19,6 +19,7 @@ import ( "github.com/sunbk201/ua3f/internal/netfilter" "github.com/sunbk201/ua3f/internal/rewrite" "github.com/sunbk201/ua3f/internal/server/base" + "github.com/sunbk201/ua3f/internal/statistics" ) type Server struct { @@ -31,11 +32,12 @@ type Server struct { ignoreMark []string } -func New(cfg *config.Config, rw *rewrite.Rewriter) *Server { +func New(cfg *config.Config, rw *rewrite.Rewriter, rc *statistics.Recorder) *Server { s := &Server{ Server: base.Server{ Cfg: cfg, Rewriter: rw, + Recorder: rc, Cache: expirable.NewLRU[string, struct{}](1024, nil, 30*time.Minute), }, so_mark: netfilter.SO_MARK, @@ -88,6 +90,8 @@ func (s *Server) Start() error { return fmt.Errorf("net.Listen: %w", err) } + s.Recorder.Start() + go func() { var client net.Conn for { diff --git a/src/internal/server/tproxy/tproxy_others.go b/src/internal/server/tproxy/tproxy_others.go index 917bf7e..a85ff2b 100644 --- a/src/internal/server/tproxy/tproxy_others.go +++ b/src/internal/server/tproxy/tproxy_others.go @@ -9,17 +9,19 @@ import ( "github.com/sunbk201/ua3f/internal/config" "github.com/sunbk201/ua3f/internal/rewrite" "github.com/sunbk201/ua3f/internal/server/base" + "github.com/sunbk201/ua3f/internal/statistics" ) type Server struct { base.Server } -func New(cfg *config.Config, rw *rewrite.Rewriter) *Server { +func New(cfg *config.Config, rw *rewrite.Rewriter, rc *statistics.Recorder) *Server { return &Server{ Server: base.Server{ Cfg: cfg, Rewriter: rw, + Recorder: rc, }, } } diff --git a/src/internal/statistics/conn.go b/src/internal/statistics/conn.go index a08130d..0a21321 100644 --- a/src/internal/statistics/conn.go +++ b/src/internal/statistics/conn.go @@ -11,7 +11,13 @@ import ( "github.com/sunbk201/ua3f/internal/sniff" ) -const connStatsFile = "/var/log/ua3f/conn_stats" +type ConnectionRecordList struct { + recordAddChan chan *ConnectionRecord + recordRemoveChan chan *ConnectionRecord + records map[string]*ConnectionRecord + mu sync.RWMutex + dumpFile string +} type ConnectionRecord struct { Protocol sniff.Protocol @@ -20,42 +26,61 @@ type ConnectionRecord struct { StartTime time.Time } -type ConnectionAction struct { - Action Action - Key string - Record ConnectionRecord -} - -var ( - connectionRecords = make(map[string]*ConnectionRecord) - connectionRecordsMu sync.RWMutex -) - -// AddConnection adds or updates a connection record -func AddConnection(record *ConnectionRecord) { - select { - case connectionActionChan <- ConnectionAction{ - Action: Add, - Key: fmt.Sprintf("%s-%s", record.SrcAddr, record.DestAddr), - Record: *record, - }: - default: +func NewConnectionRecordList(dumpFile string) *ConnectionRecordList { + return &ConnectionRecordList{ + recordAddChan: make(chan *ConnectionRecord, 500), + recordRemoveChan: make(chan *ConnectionRecord, 500), + records: make(map[string]*ConnectionRecord, 500), + mu: sync.RWMutex{}, + dumpFile: dumpFile, } } -// RemoveConnection removes a connection record -func RemoveConnection(srcAddr, destAddr string) { - select { - case connectionActionChan <- ConnectionAction{ - Action: Remove, - Key: fmt.Sprintf("%s-%s", srcAddr, destAddr), - }: - default: +func (l *ConnectionRecordList) Run() { + go func() { + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + + for { + select { + case record := <-l.recordAddChan: + l.Add(record) + case record := <-l.recordRemoveChan: + l.Remove(record) + case <-ticker.C: + l.Dump() + } + } + }() +} + +func (l *ConnectionRecordList) Add(record *ConnectionRecord) { + l.mu.Lock() + defer l.mu.Unlock() + + key := fmt.Sprintf("%s-%s", record.SrcAddr, record.DestAddr) + if r, exists := l.records[key]; exists { + r.Protocol = record.Protocol + } else { + l.records[key] = &ConnectionRecord{ + Protocol: record.Protocol, + SrcAddr: record.SrcAddr, + DestAddr: record.DestAddr, + StartTime: record.StartTime, + } } } -func dumpConnectionRecords() { - f, err := os.Create(connStatsFile) +func (l *ConnectionRecordList) Remove(record *ConnectionRecord) { + l.mu.Lock() + defer l.mu.Unlock() + + key := fmt.Sprintf("%s-%s", record.SrcAddr, record.DestAddr) + delete(l.records, key) +} + +func (l *ConnectionRecordList) Dump() { + f, err := os.Create(l.dumpFile) if err != nil { slog.Error("os.Create", slog.Any("error", err)) return @@ -66,12 +91,12 @@ func dumpConnectionRecords() { } }() - connectionRecordsMu.RLock() + l.mu.RLock() var statList []ConnectionRecord - for _, record := range connectionRecords { + for _, record := range l.records { statList = append(statList, *record) } - connectionRecordsMu.RUnlock() + l.mu.RUnlock() // Sort by start time (newest first) sort.SliceStable(statList, func(i, j int) bool { diff --git a/src/internal/statistics/pass.go b/src/internal/statistics/pass.go index f4e4713..e49aa20 100644 --- a/src/internal/statistics/pass.go +++ b/src/internal/statistics/pass.go @@ -5,10 +5,17 @@ import ( "log/slog" "os" "sort" + "strings" "sync" + "time" ) -const passthroughStatsFile = "/var/log/ua3f/pass_stats" +type PassThroughRecordList struct { + recordAddChan chan *PassThroughRecord + records map[string]*PassThroughRecord + mu sync.RWMutex + dumpFile string +} type PassThroughRecord struct { SrcAddr string @@ -17,20 +24,55 @@ type PassThroughRecord struct { Count int } -var ( - passThroughRecords = make(map[string]*PassThroughRecord) - passThroughRecordsMu sync.RWMutex -) - -func AddPassThroughRecord(record *PassThroughRecord) { - select { - case passThroughRecordChan <- *record: - default: +func NewPassThroughRecordList(dumpFile string) *PassThroughRecordList { + return &PassThroughRecordList{ + recordAddChan: make(chan *PassThroughRecord, 500), + records: make(map[string]*PassThroughRecord, 500), + mu: sync.RWMutex{}, + dumpFile: dumpFile, } } -func dumpPassThroughRecords() { - f, err := os.Create(passthroughStatsFile) +func (l *PassThroughRecordList) Run() { + go func() { + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + + for { + select { + case record := <-l.recordAddChan: + l.Add(record) + case <-ticker.C: + l.Dump() + } + } + }() +} + +func (l *PassThroughRecordList) Add(record *PassThroughRecord) { + if strings.HasPrefix(record.UA, "curl/") { + record.UA = "curl/*" + } + + l.mu.Lock() + defer l.mu.Unlock() + + if r, exists := l.records[record.UA]; exists { + r.Count++ + r.SrcAddr = record.SrcAddr + r.DestAddr = record.DestAddr + } else { + l.records[record.UA] = &PassThroughRecord{ + SrcAddr: record.SrcAddr, + DestAddr: record.DestAddr, + UA: record.UA, + Count: 1, + } + } +} + +func (l *PassThroughRecordList) Dump() { + f, err := os.Create(l.dumpFile) if err != nil { slog.Error("os.Create", slog.Any("error", err)) return @@ -41,12 +83,12 @@ func dumpPassThroughRecords() { } }() - passThroughRecordsMu.RLock() + l.mu.RLock() var statList []PassThroughRecord - for _, record := range passThroughRecords { + for _, record := range l.records { statList = append(statList, *record) } - passThroughRecordsMu.RUnlock() + l.mu.RUnlock() sort.SliceStable(statList, func(i, j int) bool { return statList[i].Count > statList[j].Count diff --git a/src/internal/statistics/rewrite.go b/src/internal/statistics/rewrite.go index 50c17f8..879d128 100644 --- a/src/internal/statistics/rewrite.go +++ b/src/internal/statistics/rewrite.go @@ -6,9 +6,15 @@ import ( "os" "sort" "sync" + "time" ) -const rewriteStatsFile = "/var/log/ua3f/rewrite_stats" +type RewriteRecordList struct { + recordAddChan chan *RewriteRecord + records map[string]*RewriteRecord + mu sync.RWMutex + dumpFile string +} type RewriteRecord struct { Host string @@ -17,20 +23,51 @@ type RewriteRecord struct { MockedUA string } -var ( - rewriteRecords = make(map[string]*RewriteRecord) - rewriteRecordsMu sync.RWMutex -) - -func AddRewriteRecord(record *RewriteRecord) { - select { - case rewriteRecordChan <- *record: - default: +func NewRewriteRecordList(dumpFile string) *RewriteRecordList { + return &RewriteRecordList{ + recordAddChan: make(chan *RewriteRecord, 500), + records: make(map[string]*RewriteRecord, 500), + mu: sync.RWMutex{}, + dumpFile: dumpFile, } } -func dumpRewriteRecords() { - f, err := os.Create(rewriteStatsFile) +func (l *RewriteRecordList) Run() { + go func() { + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + + for { + select { + case record := <-l.recordAddChan: + l.Add(record) + case <-ticker.C: + l.Dump() + } + } + }() +} + +func (l *RewriteRecordList) Add(record *RewriteRecord) { + l.mu.Lock() + defer l.mu.Unlock() + + if r, exists := l.records[record.Host]; exists { + r.Count++ + r.OriginalUA = record.OriginalUA + r.MockedUA = record.MockedUA + } else { + l.records[record.Host] = &RewriteRecord{ + Host: record.Host, + Count: 1, + OriginalUA: record.OriginalUA, + MockedUA: record.MockedUA, + } + } +} + +func (l *RewriteRecordList) Dump() { + f, err := os.Create(l.dumpFile) if err != nil { slog.Error("os.Create", slog.Any("error", err)) return @@ -41,12 +78,12 @@ func dumpRewriteRecords() { } }() - rewriteRecordsMu.RLock() + l.mu.RLock() var statList []RewriteRecord - for _, record := range rewriteRecords { + for _, record := range l.records { statList = append(statList, *record) } - rewriteRecordsMu.RUnlock() + l.mu.RUnlock() sort.SliceStable(statList, func(i, j int) bool { return statList[i].Count > statList[j].Count diff --git a/src/internal/statistics/statistics.go b/src/internal/statistics/statistics.go index 58721af..ffdcc5a 100644 --- a/src/internal/statistics/statistics.go +++ b/src/internal/statistics/statistics.go @@ -1,85 +1,51 @@ package statistics -import ( - "strings" - "time" -) +type Recorder struct { + RewriteRecordList *RewriteRecordList + PassThroughRecordList *PassThroughRecordList + ConnectionRecordList *ConnectionRecordList +} -var ( - rewriteRecordChan = make(chan RewriteRecord, 2000) - passThroughRecordChan = make(chan PassThroughRecord, 2000) - connectionActionChan = make(chan ConnectionAction, 2000) -) +func New() *Recorder { + return &Recorder{ + RewriteRecordList: NewRewriteRecordList("/var/log/ua3f/rewrite_stats"), + PassThroughRecordList: NewPassThroughRecordList("/var/log/ua3f/pass_stats"), + ConnectionRecordList: NewConnectionRecordList("/var/log/ua3f/conn_stats"), + } +} -// Actions for recording connection statistics -type Action int +func (r *Recorder) Start() { + r.RewriteRecordList.Run() + r.PassThroughRecordList.Run() + r.ConnectionRecordList.Run() +} -const ( - Add Action = iota - Remove -) - -func StartRecorder() { - ticker := time.NewTicker(5 * time.Second) - defer ticker.Stop() - - for { +func (r *Recorder) AddRecord(record any) { + switch rec := record.(type) { + case *RewriteRecord: select { - case record := <-rewriteRecordChan: - rewriteRecordsMu.Lock() - if r, exists := rewriteRecords[record.Host]; exists { - r.Count++ - r.OriginalUA = record.OriginalUA - r.MockedUA = record.MockedUA - } else { - rewriteRecords[record.Host] = &RewriteRecord{ - Host: record.Host, - Count: 1, - OriginalUA: record.OriginalUA, - MockedUA: record.MockedUA, - } - } - rewriteRecordsMu.Unlock() - case record := <-passThroughRecordChan: - if strings.HasPrefix(record.UA, "curl/") { - record.UA = "curl/*" - } - passThroughRecordsMu.Lock() - if r, exists := passThroughRecords[record.UA]; exists { - r.Count++ - r.DestAddr = record.DestAddr - r.SrcAddr = record.SrcAddr - } else { - passThroughRecords[record.UA] = &PassThroughRecord{ - SrcAddr: record.SrcAddr, - DestAddr: record.DestAddr, - UA: record.UA, - Count: 1, - } - } - passThroughRecordsMu.Unlock() - case action := <-connectionActionChan: - connectionRecordsMu.Lock() - switch action.Action { - case Add: - if r, exists := connectionRecords[action.Key]; exists { - r.Protocol = action.Record.Protocol - } else { - connectionRecords[action.Key] = &ConnectionRecord{ - Protocol: action.Record.Protocol, - SrcAddr: action.Record.SrcAddr, - DestAddr: action.Record.DestAddr, - StartTime: action.Record.StartTime, - } - } - case Remove: - delete(connectionRecords, action.Key) - } - connectionRecordsMu.Unlock() - case <-ticker.C: - dumpRewriteRecords() - dumpPassThroughRecords() - dumpConnectionRecords() + case r.RewriteRecordList.recordAddChan <- rec: + default: + } + case *PassThroughRecord: + select { + case r.PassThroughRecordList.recordAddChan <- rec: + default: + } + case *ConnectionRecord: + select { + case r.ConnectionRecordList.recordAddChan <- rec: + default: + } + } +} + +func (r *Recorder) RemoveRecord(record any) { + switch rec := record.(type) { + case *ConnectionRecord: + select { + case r.ConnectionRecordList.recordRemoveChan <- rec: + default: } } } diff --git a/src/main.go b/src/main.go index ed6c0af..950765d 100644 --- a/src/main.go +++ b/src/main.go @@ -9,11 +9,9 @@ import ( "github.com/sunbk201/ua3f/internal/config" "github.com/sunbk201/ua3f/internal/log" - "github.com/sunbk201/ua3f/internal/rewrite" "github.com/sunbk201/ua3f/internal/server" "github.com/sunbk201/ua3f/internal/server/desync" "github.com/sunbk201/ua3f/internal/server/netlink" - "github.com/sunbk201/ua3f/internal/statistics" "github.com/sunbk201/ua3f/internal/usergroup" ) @@ -39,12 +37,6 @@ func main() { return } - rw, err := rewrite.New(cfg) - if err != nil { - slog.Error("rewrite.New", slog.Any("error", err)) - return - } - helper := netlink.New(cfg) addShutdown("helper.Close", helper.Close) if err := helper.Start(); err != nil { @@ -63,13 +55,12 @@ func main() { } } - srv, err := server.NewServer(cfg, rw) + srv, err := server.NewServer(cfg) if err != nil { slog.Error("server.NewServer", slog.Any("error", err)) shutdown() return } - go statistics.StartRecorder() addShutdown("srv.Close", srv.Close) if err := srv.Start(); err != nil { slog.Error("srv.Start", slog.Any("error", err))