feat: ensure nfqueue server graceful close

This commit is contained in:
SunBK201 2025-11-25 16:52:20 +08:00
parent 8abb282547
commit d403f2b8ae
3 changed files with 56 additions and 53 deletions

View File

@ -3,6 +3,7 @@ package netfilter
import (
"context"
"encoding/binary"
"errors"
"fmt"
"hash/fnv"
"log/slog"
@ -26,6 +27,7 @@ type NfqueueServer struct {
attrChans []chan *nfq.Attribute
wg sync.WaitGroup
Nf *nfq.Nfqueue
cancel context.CancelFunc
}
func (s *NfqueueServer) Start() error {
@ -80,14 +82,14 @@ func (s *NfqueueServer) Start() error {
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.cancel = cancel
// Initialize worker channels and start worker goroutines
s.attrChans = make([]chan *nfq.Attribute, s.NumWorkers)
for i := 0; i < s.NumWorkers; i++ {
s.attrChans[i] = make(chan *nfq.Attribute, s.WorkerChanLen)
s.wg.Add(1)
go s.worker(ctx, i, s.attrChans[i])
go s.worker(i, s.attrChans[i])
}
// Register callback function
@ -111,6 +113,8 @@ func (s *NfqueueServer) Start() error {
if err != nil {
slog.Error("nf.Con.SetReadBuffer", slog.Any("error", err))
}
} else if errors.Is(ctx.Err(), context.Canceled) {
slog.Info("Nfqueue context canceled, stopping nfqueue handler")
} else {
slog.Error("Error in nfqueue handler", slog.Any("error", e))
}
@ -118,38 +122,39 @@ func (s *NfqueueServer) Start() error {
},
)
if err != nil {
// Close all worker channels
for i := 0; i < s.NumWorkers; i++ {
close(s.attrChans[i])
}
s.wg.Wait()
return fmt.Errorf("failed to register nfqueue handler: %w", err)
return fmt.Errorf("nf.RegisterWithErrorFunc: %w", err)
}
// Wait until context is done
<-ctx.Done()
// Cleanup: close all worker channels and wait for workers to finish
for i := 0; i < s.NumWorkers; i++ {
close(s.attrChans[i])
}
s.wg.Wait()
return nil
}
// worker processes packets from its assigned channel
func (s *NfqueueServer) worker(ctx context.Context, workerID int, aChan <-chan *nfq.Attribute) {
defer s.wg.Done()
slog.Debug("Worker %d started", slog.Int("workerID", workerID))
for {
select {
case <-ctx.Done():
slog.Debug("Worker stopping", slog.Int("workerID", workerID))
return
case a, ok := <-aChan:
if !ok {
slog.Debug("Worker channel closed", slog.Int("workerID", workerID))
return
func (s *NfqueueServer) Close() {
if s.cancel != nil {
s.cancel()
}
for i := 0; i < len(s.attrChans); i++ {
if s.attrChans[i] != nil {
close(s.attrChans[i])
}
}
s.wg.Wait()
if s.Nf != nil {
_ = s.Nf.Close()
}
}
// worker processes packets from its assigned channel
func (s *NfqueueServer) worker(workerID int, aChan <-chan *nfq.Attribute) {
defer s.wg.Done()
defer slog.Info("Nfqueue worker stopped", slog.Int("workerID", workerID))
slog.Info("Nfqueue worker started", slog.Int("workerID", workerID))
for a := range aChan {
if ok := attributeSanityCheck(a); !ok {
if a.PacketID != nil {
_ = s.Nf.SetVerdict(*a.PacketID, nfq.NfAccept)
@ -169,7 +174,6 @@ func (s *NfqueueServer) worker(ctx context.Context, workerID int, aChan <-chan *
s.HandlePacket(packet)
}
}
}
func (s *NfqueueServer) computeWorkerIndex(a *nfq.Attribute) int {
var flowID uint32

View File

@ -52,13 +52,11 @@ func (s *Server) Start() (err error) {
return nil
}
func (s *Server) Close() (err error) {
err = s.Firewall.Cleanup()
if err != nil {
func (s *Server) Close() error {
err := s.Firewall.Cleanup()
s.nfqServer.Close()
return err
}
return nil
}
// handlePacket processes a single NFQUEUE packet
func (s *Server) handlePacket(packet *netfilter.Packet) {

View File

@ -66,8 +66,9 @@ func (s *Server) Start() (err error) {
return s.nfqServer.Start()
}
func (s *Server) Close() (err error) {
err = s.Firewall.Cleanup()
func (s *Server) Close() error {
err := s.Firewall.Cleanup()
s.nfqServer.Close()
return err
}