feat: ensure nfqueue server graceful close

This commit is contained in:
SunBK201 2025-11-25 16:52:20 +08:00
parent 8abb282547
commit d403f2b8ae
3 changed files with 56 additions and 53 deletions

View File

@ -3,6 +3,7 @@ package netfilter
import ( import (
"context" "context"
"encoding/binary" "encoding/binary"
"errors"
"fmt" "fmt"
"hash/fnv" "hash/fnv"
"log/slog" "log/slog"
@ -26,6 +27,7 @@ type NfqueueServer struct {
attrChans []chan *nfq.Attribute attrChans []chan *nfq.Attribute
wg sync.WaitGroup wg sync.WaitGroup
Nf *nfq.Nfqueue Nf *nfq.Nfqueue
cancel context.CancelFunc
} }
func (s *NfqueueServer) Start() error { func (s *NfqueueServer) Start() error {
@ -80,14 +82,14 @@ func (s *NfqueueServer) Start() error {
} }
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() s.cancel = cancel
// Initialize worker channels and start worker goroutines // Initialize worker channels and start worker goroutines
s.attrChans = make([]chan *nfq.Attribute, s.NumWorkers) s.attrChans = make([]chan *nfq.Attribute, s.NumWorkers)
for i := 0; i < s.NumWorkers; i++ { for i := 0; i < s.NumWorkers; i++ {
s.attrChans[i] = make(chan *nfq.Attribute, s.WorkerChanLen) s.attrChans[i] = make(chan *nfq.Attribute, s.WorkerChanLen)
s.wg.Add(1) s.wg.Add(1)
go s.worker(ctx, i, s.attrChans[i]) go s.worker(i, s.attrChans[i])
} }
// Register callback function // Register callback function
@ -111,6 +113,8 @@ func (s *NfqueueServer) Start() error {
if err != nil { if err != nil {
slog.Error("nf.Con.SetReadBuffer", slog.Any("error", err)) slog.Error("nf.Con.SetReadBuffer", slog.Any("error", err))
} }
} else if errors.Is(ctx.Err(), context.Canceled) {
slog.Info("Nfqueue context canceled, stopping nfqueue handler")
} else { } else {
slog.Error("Error in nfqueue handler", slog.Any("error", e)) slog.Error("Error in nfqueue handler", slog.Any("error", e))
} }
@ -118,38 +122,39 @@ func (s *NfqueueServer) Start() error {
}, },
) )
if err != nil { if err != nil {
// Close all worker channels return fmt.Errorf("nf.RegisterWithErrorFunc: %w", err)
for i := 0; i < s.NumWorkers; i++ {
close(s.attrChans[i])
}
s.wg.Wait()
return fmt.Errorf("failed to register nfqueue handler: %w", err)
} }
// Wait until context is done
<-ctx.Done() <-ctx.Done()
// Cleanup: close all worker channels and wait for workers to finish
for i := 0; i < s.NumWorkers; i++ {
close(s.attrChans[i])
}
s.wg.Wait()
return nil return nil
} }
// worker processes packets from its assigned channel func (s *NfqueueServer) Close() {
func (s *NfqueueServer) worker(ctx context.Context, workerID int, aChan <-chan *nfq.Attribute) { if s.cancel != nil {
defer s.wg.Done() s.cancel()
slog.Debug("Worker %d started", slog.Int("workerID", workerID))
for {
select {
case <-ctx.Done():
slog.Debug("Worker stopping", slog.Int("workerID", workerID))
return
case a, ok := <-aChan:
if !ok {
slog.Debug("Worker channel closed", slog.Int("workerID", workerID))
return
} }
for i := 0; i < len(s.attrChans); i++ {
if s.attrChans[i] != nil {
close(s.attrChans[i])
}
}
s.wg.Wait()
if s.Nf != nil {
_ = s.Nf.Close()
}
}
// worker processes packets from its assigned channel
func (s *NfqueueServer) worker(workerID int, aChan <-chan *nfq.Attribute) {
defer s.wg.Done()
defer slog.Info("Nfqueue worker stopped", slog.Int("workerID", workerID))
slog.Info("Nfqueue worker started", slog.Int("workerID", workerID))
for a := range aChan {
if ok := attributeSanityCheck(a); !ok { if ok := attributeSanityCheck(a); !ok {
if a.PacketID != nil { if a.PacketID != nil {
_ = s.Nf.SetVerdict(*a.PacketID, nfq.NfAccept) _ = s.Nf.SetVerdict(*a.PacketID, nfq.NfAccept)
@ -168,7 +173,6 @@ func (s *NfqueueServer) worker(ctx context.Context, workerID int, aChan <-chan *
slog.Debug("Processing packet", slog.Int("workerID", workerID), slog.String("srcAddr", packet.SrcAddr), slog.String("dstAddr", packet.DstAddr)) slog.Debug("Processing packet", slog.Int("workerID", workerID), slog.String("srcAddr", packet.SrcAddr), slog.String("dstAddr", packet.DstAddr))
s.HandlePacket(packet) s.HandlePacket(packet)
} }
}
} }
func (s *NfqueueServer) computeWorkerIndex(a *nfq.Attribute) int { func (s *NfqueueServer) computeWorkerIndex(a *nfq.Attribute) int {

View File

@ -52,12 +52,10 @@ func (s *Server) Start() (err error) {
return nil return nil
} }
func (s *Server) Close() (err error) { func (s *Server) Close() error {
err = s.Firewall.Cleanup() err := s.Firewall.Cleanup()
if err != nil { s.nfqServer.Close()
return err return err
}
return nil
} }
// handlePacket processes a single NFQUEUE packet // handlePacket processes a single NFQUEUE packet

View File

@ -66,8 +66,9 @@ func (s *Server) Start() (err error) {
return s.nfqServer.Start() return s.nfqServer.Start()
} }
func (s *Server) Close() (err error) { func (s *Server) Close() error {
err = s.Firewall.Cleanup() err := s.Firewall.Cleanup()
s.nfqServer.Close()
return err return err
} }