feat: add wan nexthop filtering

This commit is contained in:
SunBK201 2025-11-22 23:20:14 +08:00
parent 83235a79ff
commit df5c4d9efe

View File

@ -1,8 +1,12 @@
package netfilter
import (
"encoding/json"
"errors"
"fmt"
"log"
"log/slog"
"net"
"os/exec"
"os/user"
@ -39,7 +43,7 @@ var LAN_CIDRS = []string{
"127.0.0.0/8",
"169.254.0.0/16",
"172.16.0.0/12",
"192.168.1.0/24",
"192.168.0.0/16",
"224.0.0.0/4",
"240.0.0.0/4",
}
@ -95,6 +99,7 @@ var (
func init() {
initSkipGids()
initLanCidrs()
}
type Firewall struct {
@ -109,6 +114,7 @@ func (f *Firewall) Setup(cfg *config.Config) (err error) {
_ = f.Cleanup()
backend := detectFirewallBackend(cfg)
slog.Info("Setup firewall", slog.String("backend", backend))
slog.Info("Exempt LAN CIDRs", slog.String("cidrs", fmt.Sprintf("%v", LAN_CIDRS)))
switch backend {
case NFT:
if f.NftSetup == nil {
@ -284,6 +290,50 @@ func detectFirewallBackend(cfg *config.Config) string {
}
}
func getWanNexthops() ([]string, error) {
out, err := exec.Command("ubus", "call", "network.interface.wan", "status").Output()
if err != nil {
return nil, err
}
var result struct {
Route []struct {
NextHop string `json:"nexthop"`
} `json:"route"`
}
if err := json.Unmarshal(out, &result); err != nil {
log.Fatal(err)
}
if len(result.Route) == 0 {
return nil, errors.New("no route found for wan interface")
}
var nexthops []string
for _, route := range result.Route {
nexthops = append(nexthops, route.NextHop)
}
return nexthops, nil
}
func getLocalIPv4CIDRs() ([]string, error) {
var cidrs []string
addrs, err := net.InterfaceAddrs()
if err != nil {
return nil, err
}
for _, addr := range addrs {
ipNet, ok := addr.(*net.IPNet)
if !ok {
continue
}
ip := ipNet.IP
if ipv4 := ip.To4(); ipv4 != nil {
cidrs = append(cidrs, fmt.Sprintf("%s/32", ipv4.String()))
}
}
return cidrs, nil
}
func isCommandAvailable(cmd string) bool {
_, err := exec.LookPath(cmd)
return err == nil
@ -332,3 +382,51 @@ func initSkipGids() {
SIDECAR = OC
}
}
func initLanCidrs() {
// remove wan from lan cidrs
nexthops, err := getWanNexthops()
if err != nil {
return
}
var lanRanges []net.IPNet
for _, lan := range LAN_CIDRS {
_, ipNet, err := net.ParseCIDR(lan)
if err == nil {
lanRanges = append(lanRanges, *ipNet)
}
}
var wanIPs []net.IP
for _, nh := range nexthops {
if ip := net.ParseIP(nh); ip != nil {
wanIPs = append(wanIPs, ip)
}
}
remove := make(map[int]struct{})
for i, lanNet := range lanRanges {
for _, ip := range wanIPs {
if lanNet.Contains(ip) {
remove[i] = struct{}{}
break
}
}
}
var updatedCIDRs []string
for i, lanNet := range lanRanges {
if _, ok := remove[i]; !ok {
updatedCIDRs = append(updatedCIDRs, lanNet.String())
}
}
LAN_CIDRS = updatedCIDRs
localCIDRs, err := getLocalIPv4CIDRs()
if err != nil {
return
}
LAN_CIDRS = append(LAN_CIDRS, localCIDRs...)
}