From 0cf1aa5e796b7ae75e1c009603cde0054f115cab Mon Sep 17 00:00:00 2001 From: SunBK201 Date: Thu, 20 Nov 2025 13:59:42 +0800 Subject: [PATCH] fix: add mutex for thread-safe access to statistics records --- src/internal/statistics/conn.go | 8 +++++++- src/internal/statistics/pass.go | 9 ++++++++- src/internal/statistics/rewrite.go | 9 ++++++++- src/internal/statistics/statistics.go | 6 ++++++ 4 files changed, 29 insertions(+), 3 deletions(-) diff --git a/src/internal/statistics/conn.go b/src/internal/statistics/conn.go index bc6bf6e..a08130d 100644 --- a/src/internal/statistics/conn.go +++ b/src/internal/statistics/conn.go @@ -5,6 +5,7 @@ import ( "log/slog" "os" "sort" + "sync" "time" "github.com/sunbk201/ua3f/internal/sniff" @@ -25,7 +26,10 @@ type ConnectionAction struct { Record ConnectionRecord } -var connectionRecords = make(map[string]*ConnectionRecord) +var ( + connectionRecords = make(map[string]*ConnectionRecord) + connectionRecordsMu sync.RWMutex +) // AddConnection adds or updates a connection record func AddConnection(record *ConnectionRecord) { @@ -62,10 +66,12 @@ func dumpConnectionRecords() { } }() + connectionRecordsMu.RLock() var statList []ConnectionRecord for _, record := range connectionRecords { statList = append(statList, *record) } + connectionRecordsMu.RUnlock() // Sort by start time (newest first) sort.SliceStable(statList, func(i, j int) bool { diff --git a/src/internal/statistics/pass.go b/src/internal/statistics/pass.go index a043501..f4e4713 100644 --- a/src/internal/statistics/pass.go +++ b/src/internal/statistics/pass.go @@ -5,6 +5,7 @@ import ( "log/slog" "os" "sort" + "sync" ) const passthroughStatsFile = "/var/log/ua3f/pass_stats" @@ -16,7 +17,10 @@ type PassThroughRecord struct { Count int } -var passThroughRecords = make(map[string]*PassThroughRecord) +var ( + passThroughRecords = make(map[string]*PassThroughRecord) + passThroughRecordsMu sync.RWMutex +) func AddPassThroughRecord(record *PassThroughRecord) { select { @@ -37,10 +41,13 @@ func dumpPassThroughRecords() { } }() + passThroughRecordsMu.RLock() var statList []PassThroughRecord for _, record := range passThroughRecords { statList = append(statList, *record) } + passThroughRecordsMu.RUnlock() + sort.SliceStable(statList, func(i, j int) bool { return statList[i].Count > statList[j].Count }) diff --git a/src/internal/statistics/rewrite.go b/src/internal/statistics/rewrite.go index d44486d..50c17f8 100644 --- a/src/internal/statistics/rewrite.go +++ b/src/internal/statistics/rewrite.go @@ -5,6 +5,7 @@ import ( "log/slog" "os" "sort" + "sync" ) const rewriteStatsFile = "/var/log/ua3f/rewrite_stats" @@ -16,7 +17,10 @@ type RewriteRecord struct { MockedUA string } -var rewriteRecords = make(map[string]*RewriteRecord) +var ( + rewriteRecords = make(map[string]*RewriteRecord) + rewriteRecordsMu sync.RWMutex +) func AddRewriteRecord(record *RewriteRecord) { select { @@ -37,10 +41,13 @@ func dumpRewriteRecords() { } }() + rewriteRecordsMu.RLock() var statList []RewriteRecord for _, record := range rewriteRecords { statList = append(statList, *record) } + rewriteRecordsMu.RUnlock() + sort.SliceStable(statList, func(i, j int) bool { return statList[i].Count > statList[j].Count }) diff --git a/src/internal/statistics/statistics.go b/src/internal/statistics/statistics.go index 1beabfc..58721af 100644 --- a/src/internal/statistics/statistics.go +++ b/src/internal/statistics/statistics.go @@ -26,6 +26,7 @@ func StartRecorder() { for { select { case record := <-rewriteRecordChan: + rewriteRecordsMu.Lock() if r, exists := rewriteRecords[record.Host]; exists { r.Count++ r.OriginalUA = record.OriginalUA @@ -38,10 +39,12 @@ func StartRecorder() { MockedUA: record.MockedUA, } } + rewriteRecordsMu.Unlock() case record := <-passThroughRecordChan: if strings.HasPrefix(record.UA, "curl/") { record.UA = "curl/*" } + passThroughRecordsMu.Lock() if r, exists := passThroughRecords[record.UA]; exists { r.Count++ r.DestAddr = record.DestAddr @@ -54,7 +57,9 @@ func StartRecorder() { Count: 1, } } + passThroughRecordsMu.Unlock() case action := <-connectionActionChan: + connectionRecordsMu.Lock() switch action.Action { case Add: if r, exists := connectionRecords[action.Key]; exists { @@ -70,6 +75,7 @@ func StartRecorder() { case Remove: delete(connectionRecords, action.Key) } + connectionRecordsMu.Unlock() case <-ticker.C: dumpRewriteRecords() dumpPassThroughRecords()