mirror of
https://github.com/SunBK201/UA3F.git
synced 2025-12-16 08:44:29 +00:00
feat: ensure nfqueue server graceful close
This commit is contained in:
parent
8abb282547
commit
d403f2b8ae
@ -3,6 +3,7 @@ package netfilter
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"log/slog"
|
||||
@ -26,6 +27,7 @@ type NfqueueServer struct {
|
||||
attrChans []chan *nfq.Attribute
|
||||
wg sync.WaitGroup
|
||||
Nf *nfq.Nfqueue
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func (s *NfqueueServer) Start() error {
|
||||
@ -80,14 +82,14 @@ func (s *NfqueueServer) Start() error {
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
s.cancel = cancel
|
||||
|
||||
// Initialize worker channels and start worker goroutines
|
||||
s.attrChans = make([]chan *nfq.Attribute, s.NumWorkers)
|
||||
for i := 0; i < s.NumWorkers; i++ {
|
||||
s.attrChans[i] = make(chan *nfq.Attribute, s.WorkerChanLen)
|
||||
s.wg.Add(1)
|
||||
go s.worker(ctx, i, s.attrChans[i])
|
||||
go s.worker(i, s.attrChans[i])
|
||||
}
|
||||
|
||||
// Register callback function
|
||||
@ -111,6 +113,8 @@ func (s *NfqueueServer) Start() error {
|
||||
if err != nil {
|
||||
slog.Error("nf.Con.SetReadBuffer", slog.Any("error", err))
|
||||
}
|
||||
} else if errors.Is(ctx.Err(), context.Canceled) {
|
||||
slog.Info("Nfqueue context canceled, stopping nfqueue handler")
|
||||
} else {
|
||||
slog.Error("Error in nfqueue handler", slog.Any("error", e))
|
||||
}
|
||||
@ -118,57 +122,57 @@ func (s *NfqueueServer) Start() error {
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
// Close all worker channels
|
||||
for i := 0; i < s.NumWorkers; i++ {
|
||||
close(s.attrChans[i])
|
||||
}
|
||||
s.wg.Wait()
|
||||
return fmt.Errorf("failed to register nfqueue handler: %w", err)
|
||||
return fmt.Errorf("nf.RegisterWithErrorFunc: %w", err)
|
||||
}
|
||||
|
||||
// Wait until context is done
|
||||
<-ctx.Done()
|
||||
// Cleanup: close all worker channels and wait for workers to finish
|
||||
for i := 0; i < s.NumWorkers; i++ {
|
||||
close(s.attrChans[i])
|
||||
}
|
||||
s.wg.Wait()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// worker processes packets from its assigned channel
|
||||
func (s *NfqueueServer) worker(ctx context.Context, workerID int, aChan <-chan *nfq.Attribute) {
|
||||
defer s.wg.Done()
|
||||
slog.Debug("Worker %d started", slog.Int("workerID", workerID))
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
slog.Debug("Worker stopping", slog.Int("workerID", workerID))
|
||||
return
|
||||
case a, ok := <-aChan:
|
||||
if !ok {
|
||||
slog.Debug("Worker channel closed", slog.Int("workerID", workerID))
|
||||
return
|
||||
}
|
||||
if ok := attributeSanityCheck(a); !ok {
|
||||
if a.PacketID != nil {
|
||||
_ = s.Nf.SetVerdict(*a.PacketID, nfq.NfAccept)
|
||||
}
|
||||
slog.Warn("Invalid nfq.Attribute received", slog.Int("workerID", workerID))
|
||||
return
|
||||
}
|
||||
packet, err := NewPacket(a)
|
||||
if err != nil {
|
||||
slog.Error("NewPacket", slog.Int("workerID", workerID), slog.Any("error", err))
|
||||
if a.PacketID != nil {
|
||||
_ = s.Nf.SetVerdict(*a.PacketID, nfq.NfAccept)
|
||||
}
|
||||
continue
|
||||
}
|
||||
slog.Debug("Processing packet", slog.Int("workerID", workerID), slog.String("srcAddr", packet.SrcAddr), slog.String("dstAddr", packet.DstAddr))
|
||||
s.HandlePacket(packet)
|
||||
func (s *NfqueueServer) Close() {
|
||||
if s.cancel != nil {
|
||||
s.cancel()
|
||||
}
|
||||
|
||||
for i := 0; i < len(s.attrChans); i++ {
|
||||
if s.attrChans[i] != nil {
|
||||
close(s.attrChans[i])
|
||||
}
|
||||
}
|
||||
|
||||
s.wg.Wait()
|
||||
|
||||
if s.Nf != nil {
|
||||
_ = s.Nf.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// worker processes packets from its assigned channel
|
||||
func (s *NfqueueServer) worker(workerID int, aChan <-chan *nfq.Attribute) {
|
||||
defer s.wg.Done()
|
||||
defer slog.Info("Nfqueue worker stopped", slog.Int("workerID", workerID))
|
||||
slog.Info("Nfqueue worker started", slog.Int("workerID", workerID))
|
||||
|
||||
for a := range aChan {
|
||||
if ok := attributeSanityCheck(a); !ok {
|
||||
if a.PacketID != nil {
|
||||
_ = s.Nf.SetVerdict(*a.PacketID, nfq.NfAccept)
|
||||
}
|
||||
slog.Warn("Invalid nfq.Attribute received", slog.Int("workerID", workerID))
|
||||
return
|
||||
}
|
||||
packet, err := NewPacket(a)
|
||||
if err != nil {
|
||||
slog.Error("NewPacket", slog.Int("workerID", workerID), slog.Any("error", err))
|
||||
if a.PacketID != nil {
|
||||
_ = s.Nf.SetVerdict(*a.PacketID, nfq.NfAccept)
|
||||
}
|
||||
continue
|
||||
}
|
||||
slog.Debug("Processing packet", slog.Int("workerID", workerID), slog.String("srcAddr", packet.SrcAddr), slog.String("dstAddr", packet.DstAddr))
|
||||
s.HandlePacket(packet)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *NfqueueServer) computeWorkerIndex(a *nfq.Attribute) int {
|
||||
|
||||
@ -52,12 +52,10 @@ func (s *Server) Start() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) Close() (err error) {
|
||||
err = s.Firewall.Cleanup()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
func (s *Server) Close() error {
|
||||
err := s.Firewall.Cleanup()
|
||||
s.nfqServer.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
// handlePacket processes a single NFQUEUE packet
|
||||
|
||||
@ -66,8 +66,9 @@ func (s *Server) Start() (err error) {
|
||||
return s.nfqServer.Start()
|
||||
}
|
||||
|
||||
func (s *Server) Close() (err error) {
|
||||
err = s.Firewall.Cleanup()
|
||||
func (s *Server) Close() error {
|
||||
err := s.Firewall.Cleanup()
|
||||
s.nfqServer.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user