mirror of
https://github.com/SunBK201/UA3F.git
synced 2025-12-19 18:26:12 +00:00
feat: ensure nfqueue server graceful close
This commit is contained in:
parent
8abb282547
commit
d403f2b8ae
@ -3,6 +3,7 @@ package netfilter
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"hash/fnv"
|
"hash/fnv"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
@ -26,6 +27,7 @@ type NfqueueServer struct {
|
|||||||
attrChans []chan *nfq.Attribute
|
attrChans []chan *nfq.Attribute
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
Nf *nfq.Nfqueue
|
Nf *nfq.Nfqueue
|
||||||
|
cancel context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *NfqueueServer) Start() error {
|
func (s *NfqueueServer) Start() error {
|
||||||
@ -80,14 +82,14 @@ func (s *NfqueueServer) Start() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
s.cancel = cancel
|
||||||
|
|
||||||
// Initialize worker channels and start worker goroutines
|
// Initialize worker channels and start worker goroutines
|
||||||
s.attrChans = make([]chan *nfq.Attribute, s.NumWorkers)
|
s.attrChans = make([]chan *nfq.Attribute, s.NumWorkers)
|
||||||
for i := 0; i < s.NumWorkers; i++ {
|
for i := 0; i < s.NumWorkers; i++ {
|
||||||
s.attrChans[i] = make(chan *nfq.Attribute, s.WorkerChanLen)
|
s.attrChans[i] = make(chan *nfq.Attribute, s.WorkerChanLen)
|
||||||
s.wg.Add(1)
|
s.wg.Add(1)
|
||||||
go s.worker(ctx, i, s.attrChans[i])
|
go s.worker(i, s.attrChans[i])
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register callback function
|
// Register callback function
|
||||||
@ -111,6 +113,8 @@ func (s *NfqueueServer) Start() error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("nf.Con.SetReadBuffer", slog.Any("error", err))
|
slog.Error("nf.Con.SetReadBuffer", slog.Any("error", err))
|
||||||
}
|
}
|
||||||
|
} else if errors.Is(ctx.Err(), context.Canceled) {
|
||||||
|
slog.Info("Nfqueue context canceled, stopping nfqueue handler")
|
||||||
} else {
|
} else {
|
||||||
slog.Error("Error in nfqueue handler", slog.Any("error", e))
|
slog.Error("Error in nfqueue handler", slog.Any("error", e))
|
||||||
}
|
}
|
||||||
@ -118,38 +122,39 @@ func (s *NfqueueServer) Start() error {
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Close all worker channels
|
return fmt.Errorf("nf.RegisterWithErrorFunc: %w", err)
|
||||||
for i := 0; i < s.NumWorkers; i++ {
|
|
||||||
close(s.attrChans[i])
|
|
||||||
}
|
|
||||||
s.wg.Wait()
|
|
||||||
return fmt.Errorf("failed to register nfqueue handler: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wait until context is done
|
|
||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
// Cleanup: close all worker channels and wait for workers to finish
|
|
||||||
for i := 0; i < s.NumWorkers; i++ {
|
|
||||||
close(s.attrChans[i])
|
|
||||||
}
|
|
||||||
s.wg.Wait()
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// worker processes packets from its assigned channel
|
func (s *NfqueueServer) Close() {
|
||||||
func (s *NfqueueServer) worker(ctx context.Context, workerID int, aChan <-chan *nfq.Attribute) {
|
if s.cancel != nil {
|
||||||
defer s.wg.Done()
|
s.cancel()
|
||||||
slog.Debug("Worker %d started", slog.Int("workerID", workerID))
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
slog.Debug("Worker stopping", slog.Int("workerID", workerID))
|
|
||||||
return
|
|
||||||
case a, ok := <-aChan:
|
|
||||||
if !ok {
|
|
||||||
slog.Debug("Worker channel closed", slog.Int("workerID", workerID))
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for i := 0; i < len(s.attrChans); i++ {
|
||||||
|
if s.attrChans[i] != nil {
|
||||||
|
close(s.attrChans[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.wg.Wait()
|
||||||
|
|
||||||
|
if s.Nf != nil {
|
||||||
|
_ = s.Nf.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// worker processes packets from its assigned channel
|
||||||
|
func (s *NfqueueServer) worker(workerID int, aChan <-chan *nfq.Attribute) {
|
||||||
|
defer s.wg.Done()
|
||||||
|
defer slog.Info("Nfqueue worker stopped", slog.Int("workerID", workerID))
|
||||||
|
slog.Info("Nfqueue worker started", slog.Int("workerID", workerID))
|
||||||
|
|
||||||
|
for a := range aChan {
|
||||||
if ok := attributeSanityCheck(a); !ok {
|
if ok := attributeSanityCheck(a); !ok {
|
||||||
if a.PacketID != nil {
|
if a.PacketID != nil {
|
||||||
_ = s.Nf.SetVerdict(*a.PacketID, nfq.NfAccept)
|
_ = s.Nf.SetVerdict(*a.PacketID, nfq.NfAccept)
|
||||||
@ -168,7 +173,6 @@ func (s *NfqueueServer) worker(ctx context.Context, workerID int, aChan <-chan *
|
|||||||
slog.Debug("Processing packet", slog.Int("workerID", workerID), slog.String("srcAddr", packet.SrcAddr), slog.String("dstAddr", packet.DstAddr))
|
slog.Debug("Processing packet", slog.Int("workerID", workerID), slog.String("srcAddr", packet.SrcAddr), slog.String("dstAddr", packet.DstAddr))
|
||||||
s.HandlePacket(packet)
|
s.HandlePacket(packet)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *NfqueueServer) computeWorkerIndex(a *nfq.Attribute) int {
|
func (s *NfqueueServer) computeWorkerIndex(a *nfq.Attribute) int {
|
||||||
|
|||||||
@ -52,12 +52,10 @@ func (s *Server) Start() (err error) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) Close() (err error) {
|
func (s *Server) Close() error {
|
||||||
err = s.Firewall.Cleanup()
|
err := s.Firewall.Cleanup()
|
||||||
if err != nil {
|
s.nfqServer.Close()
|
||||||
return err
|
return err
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// handlePacket processes a single NFQUEUE packet
|
// handlePacket processes a single NFQUEUE packet
|
||||||
|
|||||||
@ -66,8 +66,9 @@ func (s *Server) Start() (err error) {
|
|||||||
return s.nfqServer.Start()
|
return s.nfqServer.Start()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) Close() (err error) {
|
func (s *Server) Close() error {
|
||||||
err = s.Firewall.Cleanup()
|
err := s.Firewall.Cleanup()
|
||||||
|
s.nfqServer.Close()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user