feat: ensure nfqueue server graceful close

This commit is contained in:
SunBK201 2025-11-25 16:52:20 +08:00
parent 8abb282547
commit d403f2b8ae
3 changed files with 56 additions and 53 deletions

View File

@ -3,6 +3,7 @@ package netfilter
import (
"context"
"encoding/binary"
"errors"
"fmt"
"hash/fnv"
"log/slog"
@ -26,6 +27,7 @@ type NfqueueServer struct {
attrChans []chan *nfq.Attribute
wg sync.WaitGroup
Nf *nfq.Nfqueue
cancel context.CancelFunc
}
func (s *NfqueueServer) Start() error {
@ -80,14 +82,14 @@ func (s *NfqueueServer) Start() error {
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.cancel = cancel
// Initialize worker channels and start worker goroutines
s.attrChans = make([]chan *nfq.Attribute, s.NumWorkers)
for i := 0; i < s.NumWorkers; i++ {
s.attrChans[i] = make(chan *nfq.Attribute, s.WorkerChanLen)
s.wg.Add(1)
go s.worker(ctx, i, s.attrChans[i])
go s.worker(i, s.attrChans[i])
}
// Register callback function
@ -111,6 +113,8 @@ func (s *NfqueueServer) Start() error {
if err != nil {
slog.Error("nf.Con.SetReadBuffer", slog.Any("error", err))
}
} else if errors.Is(ctx.Err(), context.Canceled) {
slog.Info("Nfqueue context canceled, stopping nfqueue handler")
} else {
slog.Error("Error in nfqueue handler", slog.Any("error", e))
}
@ -118,57 +122,57 @@ func (s *NfqueueServer) Start() error {
},
)
if err != nil {
// Close all worker channels
for i := 0; i < s.NumWorkers; i++ {
close(s.attrChans[i])
}
s.wg.Wait()
return fmt.Errorf("failed to register nfqueue handler: %w", err)
return fmt.Errorf("nf.RegisterWithErrorFunc: %w", err)
}
// Wait until context is done
<-ctx.Done()
// Cleanup: close all worker channels and wait for workers to finish
for i := 0; i < s.NumWorkers; i++ {
close(s.attrChans[i])
}
s.wg.Wait()
return nil
}
// worker processes packets from its assigned channel
func (s *NfqueueServer) worker(ctx context.Context, workerID int, aChan <-chan *nfq.Attribute) {
defer s.wg.Done()
slog.Debug("Worker %d started", slog.Int("workerID", workerID))
for {
select {
case <-ctx.Done():
slog.Debug("Worker stopping", slog.Int("workerID", workerID))
return
case a, ok := <-aChan:
if !ok {
slog.Debug("Worker channel closed", slog.Int("workerID", workerID))
return
}
if ok := attributeSanityCheck(a); !ok {
if a.PacketID != nil {
_ = s.Nf.SetVerdict(*a.PacketID, nfq.NfAccept)
}
slog.Warn("Invalid nfq.Attribute received", slog.Int("workerID", workerID))
return
}
packet, err := NewPacket(a)
if err != nil {
slog.Error("NewPacket", slog.Int("workerID", workerID), slog.Any("error", err))
if a.PacketID != nil {
_ = s.Nf.SetVerdict(*a.PacketID, nfq.NfAccept)
}
continue
}
slog.Debug("Processing packet", slog.Int("workerID", workerID), slog.String("srcAddr", packet.SrcAddr), slog.String("dstAddr", packet.DstAddr))
s.HandlePacket(packet)
func (s *NfqueueServer) Close() {
if s.cancel != nil {
s.cancel()
}
for i := 0; i < len(s.attrChans); i++ {
if s.attrChans[i] != nil {
close(s.attrChans[i])
}
}
s.wg.Wait()
if s.Nf != nil {
_ = s.Nf.Close()
}
}
// worker processes packets from its assigned channel
func (s *NfqueueServer) worker(workerID int, aChan <-chan *nfq.Attribute) {
defer s.wg.Done()
defer slog.Info("Nfqueue worker stopped", slog.Int("workerID", workerID))
slog.Info("Nfqueue worker started", slog.Int("workerID", workerID))
for a := range aChan {
if ok := attributeSanityCheck(a); !ok {
if a.PacketID != nil {
_ = s.Nf.SetVerdict(*a.PacketID, nfq.NfAccept)
}
slog.Warn("Invalid nfq.Attribute received", slog.Int("workerID", workerID))
return
}
packet, err := NewPacket(a)
if err != nil {
slog.Error("NewPacket", slog.Int("workerID", workerID), slog.Any("error", err))
if a.PacketID != nil {
_ = s.Nf.SetVerdict(*a.PacketID, nfq.NfAccept)
}
continue
}
slog.Debug("Processing packet", slog.Int("workerID", workerID), slog.String("srcAddr", packet.SrcAddr), slog.String("dstAddr", packet.DstAddr))
s.HandlePacket(packet)
}
}
func (s *NfqueueServer) computeWorkerIndex(a *nfq.Attribute) int {

View File

@ -52,12 +52,10 @@ func (s *Server) Start() (err error) {
return nil
}
func (s *Server) Close() (err error) {
err = s.Firewall.Cleanup()
if err != nil {
return err
}
return nil
func (s *Server) Close() error {
err := s.Firewall.Cleanup()
s.nfqServer.Close()
return err
}
// handlePacket processes a single NFQUEUE packet

View File

@ -66,8 +66,9 @@ func (s *Server) Start() (err error) {
return s.nfqServer.Start()
}
func (s *Server) Close() (err error) {
err = s.Firewall.Cleanup()
func (s *Server) Close() error {
err := s.Firewall.Cleanup()
s.nfqServer.Close()
return err
}