From d403f2b8ae09bd1120897da58f7cf9abe4018b06 Mon Sep 17 00:00:00 2001 From: SunBK201 Date: Tue, 25 Nov 2025 16:52:20 +0800 Subject: [PATCH] feat: ensure nfqueue server graceful close --- src/internal/netfilter/nfqueue.go | 94 ++++++++++---------- src/internal/server/netlink/netlink_linux.go | 10 +-- src/internal/server/nfqueue/nfqueue_linux.go | 5 +- 3 files changed, 56 insertions(+), 53 deletions(-) diff --git a/src/internal/netfilter/nfqueue.go b/src/internal/netfilter/nfqueue.go index 10762de..bf7afd4 100644 --- a/src/internal/netfilter/nfqueue.go +++ b/src/internal/netfilter/nfqueue.go @@ -3,6 +3,7 @@ package netfilter import ( "context" "encoding/binary" + "errors" "fmt" "hash/fnv" "log/slog" @@ -26,6 +27,7 @@ type NfqueueServer struct { attrChans []chan *nfq.Attribute wg sync.WaitGroup Nf *nfq.Nfqueue + cancel context.CancelFunc } func (s *NfqueueServer) Start() error { @@ -80,14 +82,14 @@ func (s *NfqueueServer) Start() error { } ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + s.cancel = cancel // Initialize worker channels and start worker goroutines s.attrChans = make([]chan *nfq.Attribute, s.NumWorkers) for i := 0; i < s.NumWorkers; i++ { s.attrChans[i] = make(chan *nfq.Attribute, s.WorkerChanLen) s.wg.Add(1) - go s.worker(ctx, i, s.attrChans[i]) + go s.worker(i, s.attrChans[i]) } // Register callback function @@ -111,6 +113,8 @@ func (s *NfqueueServer) Start() error { if err != nil { slog.Error("nf.Con.SetReadBuffer", slog.Any("error", err)) } + } else if errors.Is(ctx.Err(), context.Canceled) { + slog.Info("Nfqueue context canceled, stopping nfqueue handler") } else { slog.Error("Error in nfqueue handler", slog.Any("error", e)) } @@ -118,57 +122,57 @@ func (s *NfqueueServer) Start() error { }, ) if err != nil { - // Close all worker channels - for i := 0; i < s.NumWorkers; i++ { - close(s.attrChans[i]) - } - s.wg.Wait() - return fmt.Errorf("failed to register nfqueue handler: %w", err) + return fmt.Errorf("nf.RegisterWithErrorFunc: %w", err) } - // Wait until context is done <-ctx.Done() - // Cleanup: close all worker channels and wait for workers to finish - for i := 0; i < s.NumWorkers; i++ { - close(s.attrChans[i]) - } - s.wg.Wait() + return nil } -// worker processes packets from its assigned channel -func (s *NfqueueServer) worker(ctx context.Context, workerID int, aChan <-chan *nfq.Attribute) { - defer s.wg.Done() - slog.Debug("Worker %d started", slog.Int("workerID", workerID)) - for { - select { - case <-ctx.Done(): - slog.Debug("Worker stopping", slog.Int("workerID", workerID)) - return - case a, ok := <-aChan: - if !ok { - slog.Debug("Worker channel closed", slog.Int("workerID", workerID)) - return - } - if ok := attributeSanityCheck(a); !ok { - if a.PacketID != nil { - _ = s.Nf.SetVerdict(*a.PacketID, nfq.NfAccept) - } - slog.Warn("Invalid nfq.Attribute received", slog.Int("workerID", workerID)) - return - } - packet, err := NewPacket(a) - if err != nil { - slog.Error("NewPacket", slog.Int("workerID", workerID), slog.Any("error", err)) - if a.PacketID != nil { - _ = s.Nf.SetVerdict(*a.PacketID, nfq.NfAccept) - } - continue - } - slog.Debug("Processing packet", slog.Int("workerID", workerID), slog.String("srcAddr", packet.SrcAddr), slog.String("dstAddr", packet.DstAddr)) - s.HandlePacket(packet) +func (s *NfqueueServer) Close() { + if s.cancel != nil { + s.cancel() + } + + for i := 0; i < len(s.attrChans); i++ { + if s.attrChans[i] != nil { + close(s.attrChans[i]) } } + + s.wg.Wait() + + if s.Nf != nil { + _ = s.Nf.Close() + } +} + +// worker processes packets from its assigned channel +func (s *NfqueueServer) worker(workerID int, aChan <-chan *nfq.Attribute) { + defer s.wg.Done() + defer slog.Info("Nfqueue worker stopped", slog.Int("workerID", workerID)) + slog.Info("Nfqueue worker started", slog.Int("workerID", workerID)) + + for a := range aChan { + if ok := attributeSanityCheck(a); !ok { + if a.PacketID != nil { + _ = s.Nf.SetVerdict(*a.PacketID, nfq.NfAccept) + } + slog.Warn("Invalid nfq.Attribute received", slog.Int("workerID", workerID)) + return + } + packet, err := NewPacket(a) + if err != nil { + slog.Error("NewPacket", slog.Int("workerID", workerID), slog.Any("error", err)) + if a.PacketID != nil { + _ = s.Nf.SetVerdict(*a.PacketID, nfq.NfAccept) + } + continue + } + slog.Debug("Processing packet", slog.Int("workerID", workerID), slog.String("srcAddr", packet.SrcAddr), slog.String("dstAddr", packet.DstAddr)) + s.HandlePacket(packet) + } } func (s *NfqueueServer) computeWorkerIndex(a *nfq.Attribute) int { diff --git a/src/internal/server/netlink/netlink_linux.go b/src/internal/server/netlink/netlink_linux.go index 4286a55..faeaee2 100644 --- a/src/internal/server/netlink/netlink_linux.go +++ b/src/internal/server/netlink/netlink_linux.go @@ -52,12 +52,10 @@ func (s *Server) Start() (err error) { return nil } -func (s *Server) Close() (err error) { - err = s.Firewall.Cleanup() - if err != nil { - return err - } - return nil +func (s *Server) Close() error { + err := s.Firewall.Cleanup() + s.nfqServer.Close() + return err } // handlePacket processes a single NFQUEUE packet diff --git a/src/internal/server/nfqueue/nfqueue_linux.go b/src/internal/server/nfqueue/nfqueue_linux.go index 87e194e..3a17763 100644 --- a/src/internal/server/nfqueue/nfqueue_linux.go +++ b/src/internal/server/nfqueue/nfqueue_linux.go @@ -66,8 +66,9 @@ func (s *Server) Start() (err error) { return s.nfqServer.Start() } -func (s *Server) Close() (err error) { - err = s.Firewall.Cleanup() +func (s *Server) Close() error { + err := s.Firewall.Cleanup() + s.nfqServer.Close() return err }