diff --git a/src/internal/netfilter/firewall.go b/src/internal/netfilter/firewall.go index 3aaccb3..e471529 100644 --- a/src/internal/netfilter/firewall.go +++ b/src/internal/netfilter/firewall.go @@ -1,8 +1,12 @@ package netfilter import ( + "encoding/json" + "errors" "fmt" + "log" "log/slog" + "net" "os/exec" "os/user" @@ -39,7 +43,7 @@ var LAN_CIDRS = []string{ "127.0.0.0/8", "169.254.0.0/16", "172.16.0.0/12", - "192.168.1.0/24", + "192.168.0.0/16", "224.0.0.0/4", "240.0.0.0/4", } @@ -95,6 +99,7 @@ var ( func init() { initSkipGids() + initLanCidrs() } type Firewall struct { @@ -109,6 +114,7 @@ func (f *Firewall) Setup(cfg *config.Config) (err error) { _ = f.Cleanup() backend := detectFirewallBackend(cfg) slog.Info("Setup firewall", slog.String("backend", backend)) + slog.Info("Exempt LAN CIDRs", slog.String("cidrs", fmt.Sprintf("%v", LAN_CIDRS))) switch backend { case NFT: if f.NftSetup == nil { @@ -284,6 +290,50 @@ func detectFirewallBackend(cfg *config.Config) string { } } +func getWanNexthops() ([]string, error) { + out, err := exec.Command("ubus", "call", "network.interface.wan", "status").Output() + if err != nil { + return nil, err + } + var result struct { + Route []struct { + NextHop string `json:"nexthop"` + } `json:"route"` + } + if err := json.Unmarshal(out, &result); err != nil { + log.Fatal(err) + } + if len(result.Route) == 0 { + return nil, errors.New("no route found for wan interface") + } + var nexthops []string + for _, route := range result.Route { + nexthops = append(nexthops, route.NextHop) + } + return nexthops, nil +} + +func getLocalIPv4CIDRs() ([]string, error) { + var cidrs []string + addrs, err := net.InterfaceAddrs() + if err != nil { + return nil, err + } + + for _, addr := range addrs { + ipNet, ok := addr.(*net.IPNet) + if !ok { + continue + } + ip := ipNet.IP + if ipv4 := ip.To4(); ipv4 != nil { + cidrs = append(cidrs, fmt.Sprintf("%s/32", ipv4.String())) + } + } + + return cidrs, nil +} + func isCommandAvailable(cmd string) bool { _, err := exec.LookPath(cmd) return err == nil @@ -332,3 +382,51 @@ func initSkipGids() { SIDECAR = OC } } + +func initLanCidrs() { + // remove wan from lan cidrs + nexthops, err := getWanNexthops() + if err != nil { + return + } + + var lanRanges []net.IPNet + for _, lan := range LAN_CIDRS { + _, ipNet, err := net.ParseCIDR(lan) + if err == nil { + lanRanges = append(lanRanges, *ipNet) + } + } + + var wanIPs []net.IP + for _, nh := range nexthops { + if ip := net.ParseIP(nh); ip != nil { + wanIPs = append(wanIPs, ip) + } + } + + remove := make(map[int]struct{}) + for i, lanNet := range lanRanges { + for _, ip := range wanIPs { + if lanNet.Contains(ip) { + remove[i] = struct{}{} + break + } + } + } + + var updatedCIDRs []string + for i, lanNet := range lanRanges { + if _, ok := remove[i]; !ok { + updatedCIDRs = append(updatedCIDRs, lanNet.String()) + } + } + + LAN_CIDRS = updatedCIDRs + + localCIDRs, err := getLocalIPv4CIDRs() + if err != nil { + return + } + LAN_CIDRS = append(LAN_CIDRS, localCIDRs...) +}