fix: add mutex for thread-safe access to statistics records

This commit is contained in:
SunBK201 2025-11-20 13:59:42 +08:00
parent 37107ac08d
commit 0cf1aa5e79
4 changed files with 29 additions and 3 deletions

View File

@ -5,6 +5,7 @@ import (
"log/slog" "log/slog"
"os" "os"
"sort" "sort"
"sync"
"time" "time"
"github.com/sunbk201/ua3f/internal/sniff" "github.com/sunbk201/ua3f/internal/sniff"
@ -25,7 +26,10 @@ type ConnectionAction struct {
Record ConnectionRecord Record ConnectionRecord
} }
var connectionRecords = make(map[string]*ConnectionRecord) var (
connectionRecords = make(map[string]*ConnectionRecord)
connectionRecordsMu sync.RWMutex
)
// AddConnection adds or updates a connection record // AddConnection adds or updates a connection record
func AddConnection(record *ConnectionRecord) { func AddConnection(record *ConnectionRecord) {
@ -62,10 +66,12 @@ func dumpConnectionRecords() {
} }
}() }()
connectionRecordsMu.RLock()
var statList []ConnectionRecord var statList []ConnectionRecord
for _, record := range connectionRecords { for _, record := range connectionRecords {
statList = append(statList, *record) statList = append(statList, *record)
} }
connectionRecordsMu.RUnlock()
// Sort by start time (newest first) // Sort by start time (newest first)
sort.SliceStable(statList, func(i, j int) bool { sort.SliceStable(statList, func(i, j int) bool {

View File

@ -5,6 +5,7 @@ import (
"log/slog" "log/slog"
"os" "os"
"sort" "sort"
"sync"
) )
const passthroughStatsFile = "/var/log/ua3f/pass_stats" const passthroughStatsFile = "/var/log/ua3f/pass_stats"
@ -16,7 +17,10 @@ type PassThroughRecord struct {
Count int Count int
} }
var passThroughRecords = make(map[string]*PassThroughRecord) var (
passThroughRecords = make(map[string]*PassThroughRecord)
passThroughRecordsMu sync.RWMutex
)
func AddPassThroughRecord(record *PassThroughRecord) { func AddPassThroughRecord(record *PassThroughRecord) {
select { select {
@ -37,10 +41,13 @@ func dumpPassThroughRecords() {
} }
}() }()
passThroughRecordsMu.RLock()
var statList []PassThroughRecord var statList []PassThroughRecord
for _, record := range passThroughRecords { for _, record := range passThroughRecords {
statList = append(statList, *record) statList = append(statList, *record)
} }
passThroughRecordsMu.RUnlock()
sort.SliceStable(statList, func(i, j int) bool { sort.SliceStable(statList, func(i, j int) bool {
return statList[i].Count > statList[j].Count return statList[i].Count > statList[j].Count
}) })

View File

@ -5,6 +5,7 @@ import (
"log/slog" "log/slog"
"os" "os"
"sort" "sort"
"sync"
) )
const rewriteStatsFile = "/var/log/ua3f/rewrite_stats" const rewriteStatsFile = "/var/log/ua3f/rewrite_stats"
@ -16,7 +17,10 @@ type RewriteRecord struct {
MockedUA string MockedUA string
} }
var rewriteRecords = make(map[string]*RewriteRecord) var (
rewriteRecords = make(map[string]*RewriteRecord)
rewriteRecordsMu sync.RWMutex
)
func AddRewriteRecord(record *RewriteRecord) { func AddRewriteRecord(record *RewriteRecord) {
select { select {
@ -37,10 +41,13 @@ func dumpRewriteRecords() {
} }
}() }()
rewriteRecordsMu.RLock()
var statList []RewriteRecord var statList []RewriteRecord
for _, record := range rewriteRecords { for _, record := range rewriteRecords {
statList = append(statList, *record) statList = append(statList, *record)
} }
rewriteRecordsMu.RUnlock()
sort.SliceStable(statList, func(i, j int) bool { sort.SliceStable(statList, func(i, j int) bool {
return statList[i].Count > statList[j].Count return statList[i].Count > statList[j].Count
}) })

View File

@ -26,6 +26,7 @@ func StartRecorder() {
for { for {
select { select {
case record := <-rewriteRecordChan: case record := <-rewriteRecordChan:
rewriteRecordsMu.Lock()
if r, exists := rewriteRecords[record.Host]; exists { if r, exists := rewriteRecords[record.Host]; exists {
r.Count++ r.Count++
r.OriginalUA = record.OriginalUA r.OriginalUA = record.OriginalUA
@ -38,10 +39,12 @@ func StartRecorder() {
MockedUA: record.MockedUA, MockedUA: record.MockedUA,
} }
} }
rewriteRecordsMu.Unlock()
case record := <-passThroughRecordChan: case record := <-passThroughRecordChan:
if strings.HasPrefix(record.UA, "curl/") { if strings.HasPrefix(record.UA, "curl/") {
record.UA = "curl/*" record.UA = "curl/*"
} }
passThroughRecordsMu.Lock()
if r, exists := passThroughRecords[record.UA]; exists { if r, exists := passThroughRecords[record.UA]; exists {
r.Count++ r.Count++
r.DestAddr = record.DestAddr r.DestAddr = record.DestAddr
@ -54,7 +57,9 @@ func StartRecorder() {
Count: 1, Count: 1,
} }
} }
passThroughRecordsMu.Unlock()
case action := <-connectionActionChan: case action := <-connectionActionChan:
connectionRecordsMu.Lock()
switch action.Action { switch action.Action {
case Add: case Add:
if r, exists := connectionRecords[action.Key]; exists { if r, exists := connectionRecords[action.Key]; exists {
@ -70,6 +75,7 @@ func StartRecorder() {
case Remove: case Remove:
delete(connectionRecords, action.Key) delete(connectionRecords, action.Key)
} }
connectionRecordsMu.Unlock()
case <-ticker.C: case <-ticker.C:
dumpRewriteRecords() dumpRewriteRecords()
dumpPassThroughRecords() dumpPassThroughRecords()