diff --git a/src/internal/netfilter/firewall.go b/src/internal/netfilter/firewall.go index 5843c4d..ed7d57e 100644 --- a/src/internal/netfilter/firewall.go +++ b/src/internal/netfilter/firewall.go @@ -13,7 +13,6 @@ import ( "strings" "github.com/coreos/go-iptables/iptables" - "github.com/gonetx/ipset" "github.com/sunbk201/ua3f/internal/config" "sigs.k8s.io/knftables" ) @@ -25,6 +24,7 @@ const ( const ( LANSET = "UA3F_LAN" + SKIP_IPSET = "UA3F_SKIP_IPSET" SKIP_PORTS = "22,51080,51090" FAKEIP_RANGE = "198.18.0.0/16,198.18.0.1/15,28.0.0.1/8" HELPER_QUEUE = 10301 @@ -63,58 +63,18 @@ var LAN6_CIDRS = []string{ "ff00::/8", } -var ( - IptRuleIgnoreBrLAN = []string{ - "!", "-i", "br-lan", - "-j", "RETURN", - } - IptRuleIgnoreReply = []string{ - "-m", "conntrack", - "--ctdir", "REPLY", - "-j", "RETURN", - } - IptRuleIgnoreLAN = []string{ - "-m", "set", - "--match-set", LANSET, "dst", - "-j", "RETURN", - } - IptRuleIgnorePorts = []string{ - "-p", "tcp", - "-m", "multiport", - "--dports", SKIP_PORTS, - "-j", "RETURN", - } -) -var ( - NftRuleIgnoreNotTCP = knftables.Concat( - "meta l4proto != tcp", - "return", - ) - NftRuleIgnoreNotBrLAN = knftables.Concat( - "iifname != \"br-lan\"", - "return", - ) - NftRuleIgnoreReply = knftables.Concat( - "ct direction reply", - "return", - ) - NftRuleIgnoreLAN = knftables.Concat( - fmt.Sprintf("ip daddr @%s", LANSET), - "return", - ) - NftRuleIgnoreLAN6 = knftables.Concat( - fmt.Sprintf("ip6 daddr @%s", LANSET+"_6"), - "return", - ) - NftRuleIgnorePorts = knftables.Concat( - fmt.Sprintf("tcp dport { %s }", SKIP_PORTS), - "return", - ) - NftRuleIgnoreFakeIP = knftables.Concat( - fmt.Sprintf("ip daddr { %s }", FAKEIP_RANGE), - "return", - ) -) +var SKIP_DOMAINS = []string{ + "st.dl.eccdnx.com", + "st.dl.bscstorage.net", + "st.dl.pinyuncloud.com", + "dl.steam.clngaa.com", + "cdn-ws.content.steamchina.com", + "cdn-qc.content.steamchina.com", + "cdn-ali.content.steamchina.com", + "xz.pphimalayanrt.com", + "lv.queniujq.cn", + "alibaba.cdn.steampipe.steamcontent.com", +} func init() { initSkipGids() @@ -125,8 +85,10 @@ type Firewall struct { Nftable *knftables.Table NftSetup func() error NftCleanup func() error + NftWatch func() IptSetup func() error IptCleanup func() error + IptWatch func() } func (f *Firewall) Setup(cfg *config.Config) (err error) { @@ -140,11 +102,17 @@ func (f *Firewall) Setup(cfg *config.Config) (err error) { return fmt.Errorf("nftables setup function is nil") } err = f.NftSetup() + if f.NftWatch != nil { + f.NftWatch() + } case IPT: if f.IptSetup == nil { return fmt.Errorf("iptables setup function is nil") } err = f.IptSetup() + if f.IptWatch != nil { + f.IptWatch() + } default: err = fmt.Errorf("unsupported or no firewall backend: %s", backend) } @@ -166,102 +134,6 @@ func (f *Firewall) Cleanup() error { return nil } -func (f *Firewall) DumpNFTables() { - cmd := exec.Command("nft", "--handle", "list", "ruleset") - output, err := cmd.CombinedOutput() - if err != nil { - return - } - slog.Info("nftables ruleset:\n" + string(output)) -} - -func (f *Firewall) DumpIPTables() { - var tables = []string{"filter", "nat", "mangle", "raw"} - for _, table := range tables { - cmd := exec.Command("iptables", "-t", table, "-L", "-v", "-n", "--line-numbers") - output, err := cmd.CombinedOutput() - if err != nil { - continue - } - slog.Debug(fmt.Sprintf("iptables table(%s):\n%s", table, string(output))) - } -} - -func (f *Firewall) NftSetLanIP(tx *knftables.Transaction, table *knftables.Table) { - ipset := &knftables.Set{ - Name: LANSET, - Table: table.Name, - Family: table.Family, - Type: "ipv4_addr", - Flags: []knftables.SetFlag{ - knftables.IntervalFlag, - }, - AutoMerge: knftables.PtrTo(true), - } - tx.Add(ipset) - - for _, cidr := range LAN_CIDRS { - iplan := &knftables.Element{ - Table: table.Name, - Family: table.Family, - Set: ipset.Name, - Key: []string{cidr}, - } - tx.Add(iplan) - } -} - -func (f *Firewall) NftSetLanIP6(tx *knftables.Transaction, table *knftables.Table) { - ipset := &knftables.Set{ - Name: LANSET + "_6", - Table: table.Name, - Family: table.Family, - Type: "ipv6_addr", - Flags: []knftables.SetFlag{ - knftables.IntervalFlag, - }, - AutoMerge: knftables.PtrTo(true), - } - tx.Add(ipset) - - for _, cidr := range LAN6_CIDRS { - ip6lan := &knftables.Element{ - Table: table.Name, - Family: table.Family, - Set: ipset.Name, - Key: []string{cidr}, - } - tx.Add(ip6lan) - } -} - -func (f *Firewall) IptSetLanIP() error { - if err := ipset.Check(); err != nil { - return err - } - set, err := ipset.New( - LANSET, - ipset.HashNet, - ipset.Exist(false), - ipset.Family(ipset.Inet), - ) - if err != nil { - return err - } - - for _, cidr := range LAN_CIDRS { - err := set.Add(cidr) - if err != nil { - return err - } - } - return nil -} - -func (f *Firewall) IptDeleteLanIP() error { - return ipset.Destroy(LANSET) -} - func (f *Firewall) AddTproxyRoute(fwmark, routeTable string) error { sysctlCmds := [][]string{ {"-w", "net.bridge.bridge-nf-call-iptables=0"}, @@ -340,6 +212,27 @@ func detectFirewallBackend(cfg *config.Config) string { } } +func (f *Firewall) resolveDomains(domains []string) (v4 []string, v6 []string) { + var ipv4Addrs []string + var ipv6Addrs []string + + for _, domain := range domains { + ips, err := net.LookupIP(domain) + if err != nil { + slog.Warn("net.LookupIP", slog.String("domain", domain), slog.Any("error", err)) + continue + } + for _, ip := range ips { + if ipv4 := ip.To4(); ipv4 != nil { + ipv4Addrs = append(ipv4Addrs, ipv4.String()) + } else if ipv6 := ip.To16(); ipv6 != nil { + ipv6Addrs = append(ipv6Addrs, ipv6.String()) + } + } + } + return ipv4Addrs, ipv6Addrs +} + func getWanNexthops() ([]string, error) { out, err := exec.Command("ubus", "call", "network.interface.wan", "status").Output() if err != nil { diff --git a/src/internal/netfilter/iptables.go b/src/internal/netfilter/iptables.go new file mode 100644 index 0000000..7936059 --- /dev/null +++ b/src/internal/netfilter/iptables.go @@ -0,0 +1,151 @@ +package netfilter + +import ( + "fmt" + "log/slog" + "os/exec" + "time" + + "github.com/gonetx/ipset" +) + +var ( + IptRuleIgnoreBrLAN = []string{ + "!", "-i", "br-lan", + "-j", "RETURN", + } + IptRuleIgnoreReply = []string{ + "-m", "conntrack", + "--ctdir", "REPLY", + "-j", "RETURN", + } + IptRuleIgnoreLAN = []string{ + "-m", "set", + "--match-set", LANSET, "dst", + "-j", "RETURN", + } + IptRuleIgnoreIP = []string{ + "-m", "set", + "--match-set", SKIP_IPSET, "dst", + "-j", "RETURN", + } + IptRuleIgnorePorts = []string{ + "-p", "tcp", + "-m", "multiport", + "--dports", SKIP_PORTS, + "-j", "RETURN", + } +) + +func (f *Firewall) DumpIPTables() { + var tables = []string{"filter", "nat", "mangle", "raw"} + for _, table := range tables { + cmd := exec.Command("iptables", "-t", table, "-L", "-v", "-n", "--line-numbers") + output, err := cmd.CombinedOutput() + if err != nil { + continue + } + slog.Debug(fmt.Sprintf("iptables table(%s):\n%s", table, string(output))) + } +} + +func (f *Firewall) IptSetLanIP() error { + if err := ipset.Check(); err != nil { + return err + } + set, err := ipset.New( + LANSET, + ipset.HashNet, + ipset.Exist(false), + ipset.Family(ipset.Inet), + ) + if err != nil { + return err + } + + for _, cidr := range LAN_CIDRS { + err := set.Add(cidr) + if err != nil { + return err + } + } + return nil +} + +func (f *Firewall) IptDeleteLanIP() error { + return ipset.Destroy(LANSET) +} + +func (f *Firewall) IptSetSkipIP() error { + if err := ipset.Check(); err != nil { + return err + } + _, err := ipset.New( + SKIP_IPSET, + ipset.HashIp, + ipset.Family(ipset.Inet), + ipset.Timeout(time.Hour), + ipset.Exist(false), + ) + if err != nil { + return err + } + + _ = f.IptAddSkipDomains() + + return nil +} + +func (f *Firewall) IptDeleteSkipIP() error { + return ipset.Destroy(SKIP_IPSET) +} + +func (f *Firewall) IptAddSkipIP(ip string) error { + if err := ipset.Check(); err != nil { + return err + } + set, err := ipset.New( + SKIP_IPSET, + ipset.HashIp, + ipset.Family(ipset.Inet), + ipset.Timeout(time.Hour), + ipset.Exist(true), + ) + if err != nil { + return err + } + + if err := set.Add(ip); err != nil { + return err + } + return nil +} + +func (f *Firewall) IptAddSkipDomains() error { + if err := ipset.Check(); err != nil { + return err + } + set, err := ipset.New( + SKIP_IPSET, + ipset.HashIp, + ipset.Family(ipset.Inet), + ipset.Timeout(time.Hour), + ipset.Exist(true), + ) + if err != nil { + return err + } + + v4Addrs, v6Addrs := f.resolveDomains(SKIP_DOMAINS) + for _, addr := range v4Addrs { + if err := set.Add(addr); err != nil { + return err + } + } + for _, addr := range v6Addrs { + if err := set.Add(addr); err != nil { + return err + } + } + return nil +} diff --git a/src/internal/netfilter/nftables.go b/src/internal/netfilter/nftables.go new file mode 100644 index 0000000..e227aee --- /dev/null +++ b/src/internal/netfilter/nftables.go @@ -0,0 +1,199 @@ +package netfilter + +import ( + "context" + "fmt" + "log/slog" + "os/exec" + "time" + + "sigs.k8s.io/knftables" +) + +var ( + NftRuleIgnoreNotTCP = knftables.Concat( + "meta l4proto != tcp", + "return", + ) + NftRuleIgnoreNotBrLAN = knftables.Concat( + "iifname != \"br-lan\"", + "return", + ) + NftRuleIgnoreReply = knftables.Concat( + "ct direction reply", + "return", + ) + NftRuleIgnoreLAN = knftables.Concat( + fmt.Sprintf("ip daddr @%s", LANSET), + "return", + ) + NftRuleIgnoreLAN6 = knftables.Concat( + fmt.Sprintf("ip6 daddr @%s", LANSET+"_6"), + "return", + ) + NftRuleIgnoreIP = knftables.Concat( + fmt.Sprintf("ip daddr @%s", SKIP_IPSET), + "return", + ) + NftRuleIgnoreIP6 = knftables.Concat( + fmt.Sprintf("ip6 daddr @%s", SKIP_IPSET+"_6"), + "return", + ) + NftRuleIgnorePorts = knftables.Concat( + fmt.Sprintf("tcp dport { %s }", SKIP_PORTS), + "return", + ) + NftRuleIgnoreFakeIP = knftables.Concat( + fmt.Sprintf("ip daddr { %s }", FAKEIP_RANGE), + "return", + ) +) + +func (f *Firewall) DumpNFTables() { + cmd := exec.Command("nft", "--handle", "list", "ruleset") + output, err := cmd.CombinedOutput() + if err != nil { + return + } + slog.Info("nftables ruleset:\n" + string(output)) +} + +func (f *Firewall) NftSetLanIP(tx *knftables.Transaction, table *knftables.Table) { + ipset := &knftables.Set{ + Name: LANSET, + Table: table.Name, + Family: table.Family, + Type: "ipv4_addr", + Flags: []knftables.SetFlag{ + knftables.IntervalFlag, + }, + AutoMerge: knftables.PtrTo(true), + } + tx.Add(ipset) + + for _, cidr := range LAN_CIDRS { + iplan := &knftables.Element{ + Table: table.Name, + Family: table.Family, + Set: ipset.Name, + Key: []string{cidr}, + } + tx.Add(iplan) + } +} + +func (f *Firewall) NftSetLanIP6(tx *knftables.Transaction, table *knftables.Table) { + ipset := &knftables.Set{ + Name: LANSET + "_6", + Table: table.Name, + Family: table.Family, + Type: "ipv6_addr", + Flags: []knftables.SetFlag{ + knftables.IntervalFlag, + }, + AutoMerge: knftables.PtrTo(true), + } + tx.Add(ipset) + + for _, cidr := range LAN6_CIDRS { + ip6lan := &knftables.Element{ + Table: table.Name, + Family: table.Family, + Set: ipset.Name, + Key: []string{cidr}, + } + tx.Add(ip6lan) + } +} + +func (f *Firewall) NftSetSkipIP(tx *knftables.Transaction, table *knftables.Table) { + ipset := &knftables.Set{ + Name: SKIP_IPSET, + Table: table.Name, + Family: table.Family, + Type: "ipv4_addr", + Flags: []knftables.SetFlag{ + knftables.TimeoutFlag, + }, + Timeout: knftables.PtrTo(3600 * time.Second), + } + tx.Add(ipset) +} + +func (f *Firewall) NftSetSkipIP6(tx *knftables.Transaction, table *knftables.Table) { + ipset := &knftables.Set{ + Name: SKIP_IPSET + "_6", + Table: table.Name, + Family: table.Family, + Type: "ipv6_addr", + Flags: []knftables.SetFlag{ + knftables.TimeoutFlag, + }, + Timeout: knftables.PtrTo(3600 * time.Second), + } + tx.Add(ipset) +} + +func (f *Firewall) NftAddSkipIP(table *knftables.Table, addrs []string) error { + nft, err := knftables.New(table.Family, table.Name) + if err != nil { + return err + } + + tx := nft.NewTransaction() + for _, addr := range addrs { + element := &knftables.Element{ + Table: table.Name, + Family: table.Family, + Set: SKIP_IPSET, + Key: []string{addr}, + } + tx.Add(element) + } + + if err := nft.Run(context.TODO(), tx); err != nil { + return err + } + return nil +} + +func (f *Firewall) NftAddSkipIP6(table *knftables.Table, addrs []string) error { + nft, err := knftables.New(table.Family, table.Name) + if err != nil { + return err + } + + tx := nft.NewTransaction() + for _, addr := range addrs { + element := &knftables.Element{ + Table: table.Name, + Family: table.Family, + Set: SKIP_IPSET + "_6", + Key: []string{addr}, + } + tx.Add(element) + } + + if err := nft.Run(context.TODO(), tx); err != nil { + return err + } + return nil +} + +func (f *Firewall) NftAddSkipDomains() error { + v4Addrs, v6Addrs := f.resolveDomains(SKIP_DOMAINS) + + if len(v4Addrs) > 0 { + if err := f.NftAddSkipIP(f.Nftable, v4Addrs); err != nil { + slog.Warn("f.NftAddSkipIP", slog.Any("error", err)) + return err + } + } + if len(v6Addrs) > 0 { + if err := f.NftAddSkipIP6(f.Nftable, v6Addrs); err != nil { + slog.Warn("f.NftAddSkipIP6", slog.Any("error", err)) + return err + } + } + return nil +} diff --git a/src/internal/rewrite/packet.go b/src/internal/rewrite/packet.go index c69b0aa..20b66c8 100644 --- a/src/internal/rewrite/packet.go +++ b/src/internal/rewrite/packet.go @@ -13,6 +13,7 @@ type RewriteResult struct { Modified bool // Whether the packet was modified HasUA bool // Whether User-Agent was found InCache bool // Whether destination address is in cache + NeedSkip bool } // shouldRewriteUA determines if the User-Agent should be rewritten @@ -66,18 +67,18 @@ func (r *Rewriter) buildReplacement(srcAddr, dstAddr string, originalUA string, // RewritePacketUserAgent rewrites User-Agent in a raw packet payload in-place // Returns metadata about the operation -func (r *Rewriter) RewritePacketUserAgent(payload []byte, srcAddr, dstAddr string) (hasUA, modified bool) { +func (r *Rewriter) RewritePacketUserAgent(payload []byte, srcAddr, dstAddr string) (hasUA, modified, skip bool) { // Find all User-Agent positions positions, unterm := findUserAgentInPayload(payload) if unterm { log.LogInfoWithAddr(srcAddr, dstAddr, "Unterminated User-Agent found, not rewriting") - return true, false + return true, false, false } if len(positions) == 0 { log.LogDebugWithAddr(srcAddr, dstAddr, "No User-Agent found in payload") - return false, false + return false, false, false } // Replace each User-Agent value in-place @@ -93,6 +94,10 @@ func (r *Rewriter) RewritePacketUserAgent(payload []byte, srcAddr, dstAddr strin log.LogInfoWithAddr(srcAddr, dstAddr, fmt.Sprintf("Original User-Agent: %s", originalUA)) + if originalUA == "Valve/Steam HTTP Client 1.0" { + return true, false, true + } + // Check if should rewrite if !r.shouldRewriteUA(srcAddr, dstAddr, originalUA) { r.Recorder.AddRecord(&statistics.PassThroughRecord{ @@ -100,7 +105,7 @@ func (r *Rewriter) RewritePacketUserAgent(payload []byte, srcAddr, dstAddr strin DestAddr: dstAddr, UA: originalUA, }) - return true, false + return true, false, false } // Build replacement with regex matching @@ -110,14 +115,15 @@ func (r *Rewriter) RewritePacketUserAgent(payload []byte, srcAddr, dstAddr strin modified = true } } - return true, modified + return true, modified, false } // RewriteTCP rewrites the TCP packet's User-Agent if applicable func (r *Rewriter) RewriteTCP(tcp *layers.TCP, srcAddr, dstAddr string) *RewriteResult { - hasUA, modified := r.RewritePacketUserAgent(tcp.Payload, srcAddr, dstAddr) + hasUA, modified, skip := r.RewritePacketUserAgent(tcp.Payload, srcAddr, dstAddr) return &RewriteResult{ Modified: modified, HasUA: hasUA, + NeedSkip: skip, } } diff --git a/src/internal/rewrite/rewriter.go b/src/internal/rewrite/rewriter.go index 3bef437..1e2bd67 100644 --- a/src/internal/rewrite/rewriter.go +++ b/src/internal/rewrite/rewriter.go @@ -31,10 +31,11 @@ type RewriteDecision struct { Action rule.Action MatchedRule *rule.Rule NeedCache bool + NeedSkip bool } func (d *RewriteDecision) ShouldRewrite() bool { - if d.NeedCache { + if d.NeedCache || d.NeedSkip { return false } return d.Action == rule.ActionReplace || @@ -191,6 +192,9 @@ func (r *Rewriter) EvaluateRewriteDecision(req *http.Request, srcAddr, destAddr if isWhitelist { log.LogInfoWithAddr(srcAddr, destAddr, fmt.Sprintf("Hit User-Agent whitelist: %s, add to cache", originalUA)) decision.NeedCache = true + if originalUA == "Valve/Steam HTTP Client 1.0" { + decision.NeedSkip = true + } } if !matches { log.LogDebugWithAddr(srcAddr, destAddr, fmt.Sprintf("Not hit User-Agent regex: %s", originalUA)) diff --git a/src/internal/server/base/connlink.go b/src/internal/server/base/connlink.go index 861be27..9289f07 100644 --- a/src/internal/server/base/connlink.go +++ b/src/internal/server/base/connlink.go @@ -10,10 +10,11 @@ import ( var one = make([]byte, 1) type ConnLink struct { - LConn net.Conn - RConn net.Conn - LAddr string - RAddr string + LConn net.Conn + RConn net.Conn + LAddr string + RAddr string + Skipped bool } func (c *ConnLink) CopyLR() { diff --git a/src/internal/server/base/packet.go b/src/internal/server/base/packet.go index e0d5128..9c9f489 100644 --- a/src/internal/server/base/packet.go +++ b/src/internal/server/base/packet.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "fmt" "log/slog" + "net" nfq "github.com/florianl/go-nfqueue/v2" "github.com/google/gopacket" @@ -17,6 +18,7 @@ type Packet struct { TCP *layers.TCP SrcAddr string DstAddr string + DstIP net.IP IsIPv6 bool } @@ -57,10 +59,12 @@ func NewPacket(a *nfq.Attribute) (packet *Packet, err error) { ip6 := packet.NetworkLayer.(*layers.IPv6) packet.SrcAddr = fmt.Sprintf("%s:%d", ip6.SrcIP.String(), packet.TCP.SrcPort) packet.DstAddr = fmt.Sprintf("%s:%d", ip6.DstIP.String(), packet.TCP.DstPort) + packet.DstIP = ip6.DstIP } else { ip4 := packet.NetworkLayer.(*layers.IPv4) packet.SrcAddr = fmt.Sprintf("%s:%d", ip4.SrcIP.String(), packet.TCP.SrcPort) packet.DstAddr = fmt.Sprintf("%s:%d", ip4.DstIP.String(), packet.TCP.DstPort) + packet.DstIP = ip4.DstIP } return } diff --git a/src/internal/server/base/server.go b/src/internal/server/base/server.go index de88a59..ee2829b 100644 --- a/src/internal/server/base/server.go +++ b/src/internal/server/base/server.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "log/slog" + "net" "net/http" "strings" "time" @@ -18,10 +19,11 @@ import ( ) type Server struct { - Cfg *config.Config - Rewriter *rewrite.Rewriter - Recorder *statistics.Recorder - Cache *expirable.LRU[string, struct{}] + Cfg *config.Config + Rewriter *rewrite.Rewriter + Recorder *statistics.Recorder + Cache *expirable.LRU[string, struct{}] + SkipIpChan chan *net.IP } func (s *Server) ServeConnLink(connLink *ConnLink) { @@ -52,6 +54,10 @@ func (s *Server) ProcessLR(c *ConnLink) (err error) { if err != nil { c.LogDebugf("ProcessLR: %s", err.Error()) } + if c.Skipped { + _ = c.CloseLR() + return + } if _, err = io.CopyBuffer(c.RConn, reader, one); err != nil { c.LogWarnf("Process io.CopyBuffer: %v", err) } @@ -114,6 +120,7 @@ func (s *Server) ProcessLR(c *ConnLink) (err error) { c.LogWarn("sniff subsequent request is not http, switch to direct forward") return } + if req, err = http.ReadRequest(reader); err != nil { err = fmt.Errorf("http.ReadRequest: %w", err) return @@ -128,6 +135,13 @@ func (s *Server) ProcessLR(c *ConnLink) (err error) { if decision.NeedCache { s.Cache.Add(c.RAddr, struct{}{}) } + if !c.Skipped && decision.NeedSkip && s.SkipIpChan != nil { + select { + case s.SkipIpChan <- &c.RConn.RemoteAddr().(*net.TCPAddr).IP: + c.Skipped = true + default: + } + } if decision.ShouldRewrite() { req = s.Rewriter.Rewrite(req, c.LAddr, c.RAddr, decision) @@ -146,5 +160,9 @@ func (s *Server) ProcessLR(c *ConnLink) (err error) { }) return } + + if c.Skipped { + return + } } } diff --git a/src/internal/server/nfqueue/iptables.go b/src/internal/server/nfqueue/iptables.go index 5dba699..190fd06 100644 --- a/src/internal/server/nfqueue/iptables.go +++ b/src/internal/server/nfqueue/iptables.go @@ -5,6 +5,7 @@ package nfqueue import ( "strconv" "strings" + "time" "github.com/coreos/go-iptables/iptables" "github.com/sunbk201/ua3f/internal/netfilter" @@ -31,6 +32,10 @@ func (s *Server) iptSetup() error { if err != nil { return err } + err = s.IptSetSkipIP() + if err != nil { + return err + } err = ipt.NewChain(table, chain) if err != nil { @@ -63,9 +68,26 @@ func (s *Server) iptCleanup() error { ipt.Delete(table, jumpPoint, JumpChain...) ipt.ClearAndDeleteChain(table, chain) s.IptDeleteLanIP() + s.IptDeleteSkipIP() return nil } +func (s *Server) IptWatch() { + go func() { + ticker := time.NewTicker(10 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + s.IptAddSkipDomains() + case ip := <-s.SkipIpChan: + s.IptAddSkipIP(ip.String()) + } + } + }() +} + func (s *Server) IptSetNfqueue(ipt *iptables.IPTables) error { err := ipt.Append(table, chain, netfilter.IptRuleIgnoreReply...) if err != nil { @@ -75,6 +97,10 @@ func (s *Server) IptSetNfqueue(ipt *iptables.IPTables) error { if err != nil { return err } + err = ipt.Append(table, chain, netfilter.IptRuleIgnoreIP...) + if err != nil { + return err + } err = ipt.Append(table, chain, netfilter.IptRuleIgnorePorts...) if err != nil { return err diff --git a/src/internal/server/nfqueue/nfqueue_linux.go b/src/internal/server/nfqueue/nfqueue_linux.go index 94b2755..67bad8f 100644 --- a/src/internal/server/nfqueue/nfqueue_linux.go +++ b/src/internal/server/nfqueue/nfqueue_linux.go @@ -5,6 +5,7 @@ package nfqueue import ( "fmt" "log/slog" + "net" "time" nfq "github.com/florianl/go-nfqueue/v2" @@ -32,10 +33,11 @@ type Server struct { func New(cfg *config.Config, rw *rewrite.Rewriter, rc *statistics.Recorder) *Server { s := &Server{ Server: base.Server{ - Cfg: cfg, - Rewriter: rw, - Recorder: rc, - Cache: expirable.NewLRU[string, struct{}](1024, nil, 30*time.Minute), + Cfg: cfg, + Rewriter: rw, + Recorder: rc, + Cache: expirable.NewLRU[string, struct{}](1024, nil, 30*time.Minute), + SkipIpChan: make(chan *net.IP, 512), }, SniffCtMarkLower: 10201, SniffCtMarkUpper: 10216, @@ -53,8 +55,10 @@ func New(cfg *config.Config, rw *rewrite.Rewriter, rc *statistics.Recorder) *Ser }, NftSetup: s.nftSetup, NftCleanup: s.nftCleanup, + NftWatch: s.NftWatch, IptSetup: s.iptSetup, IptCleanup: s.iptCleanup, + IptWatch: s.IptWatch, } return s } @@ -87,6 +91,12 @@ func (s *Server) handlePacket(packet *base.Packet) { return } result := s.Rewriter.RewriteTCP(packet.TCP, packet.SrcAddr, packet.DstAddr) + if result.NeedSkip { + select { + case s.SkipIpChan <- &packet.DstIP: + default: + } + } s.sendVerdict(packet, result) } @@ -130,6 +140,10 @@ func (s *Server) sendVerdict(packet *base.Packet, result *rewrite.RewriteResult) } func (s *Server) getNextMark(packet *base.Packet, result *rewrite.RewriteResult) (setMark bool, mark uint32) { + if result.NeedSkip { + return true, s.NotHTTPCtMark + } + mark, found := packet.GetCtMark() if !found { return true, s.SniffCtMarkLower diff --git a/src/internal/server/nfqueue/nftables.go b/src/internal/server/nfqueue/nftables.go index 6a9f259..69b66b5 100644 --- a/src/internal/server/nfqueue/nftables.go +++ b/src/internal/server/nfqueue/nftables.go @@ -5,6 +5,7 @@ package nfqueue import ( "context" "fmt" + "time" "github.com/sunbk201/ua3f/internal/netfilter" "sigs.k8s.io/knftables" @@ -21,6 +22,8 @@ func (s *Server) nftSetup() error { s.NftSetLanIP(tx, s.Nftable) s.NftSetLanIP6(tx, s.Nftable) + s.NftSetSkipIP(tx, s.Nftable) + s.NftSetSkipIP6(tx, s.Nftable) s.NftSetNfqueue(tx, s.Nftable) if err := nft.Run(context.TODO(), tx); err != nil { @@ -44,6 +47,28 @@ func (s *Server) nftCleanup() error { return nil } +func (s *Server) NftWatch() { + go func() { + _ = s.NftAddSkipDomains() + + ticker := time.NewTicker(10 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + _ = s.NftAddSkipDomains() + case ip := <-s.SkipIpChan: + if ip.To4() != nil { + s.NftAddSkipIP(s.Nftable, []string{ip.String()}) + } else { + s.NftAddSkipIP6(s.Nftable, []string{ip.String()}) + } + } + } + }() +} + func (s *Server) NftSetNfqueue(tx *knftables.Transaction, table *knftables.Table) { chain := &knftables.Chain{ Name: "POSTROUTING", @@ -74,6 +99,16 @@ func (s *Server) NftSetNfqueue(tx *knftables.Transaction, table *knftables.Table Rule: netfilter.NftRuleIgnoreLAN6, }) + tx.Add(&knftables.Rule{ + Chain: chain.Name, + Rule: netfilter.NftRuleIgnoreIP, + }) + + tx.Add(&knftables.Rule{ + Chain: chain.Name, + Rule: netfilter.NftRuleIgnoreIP6, + }) + tx.Add(&knftables.Rule{ Chain: chain.Name, Rule: netfilter.NftRuleIgnorePorts, diff --git a/src/internal/server/redirect/iptables.go b/src/internal/server/redirect/iptables.go index 4ba765a..bc7abbb 100644 --- a/src/internal/server/redirect/iptables.go +++ b/src/internal/server/redirect/iptables.go @@ -4,6 +4,7 @@ package redirect import ( "strconv" + "time" "github.com/coreos/go-iptables/iptables" "github.com/sunbk201/ua3f/internal/netfilter" @@ -30,6 +31,10 @@ func (s *Server) iptSetup() error { if err != nil { return err } + err = s.IptSetSkipIP() + if err != nil { + return err + } err = ipt.NewChain(table, chain) if err != nil { @@ -56,9 +61,26 @@ func (s *Server) iptCleanup() error { ipt.Delete(table, jumpPoint, JumpChain...) ipt.ClearAndDeleteChain(table, chain) s.IptDeleteLanIP() + s.IptDeleteSkipIP() return nil } +func (s *Server) IptWatch() { + go func() { + ticker := time.NewTicker(10 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + s.IptAddSkipDomains() + case ip := <-s.SkipIpChan: + s.IptAddSkipIP(ip.String()) + } + } + }() +} + func (s *Server) IptSetRedirect(ipt *iptables.IPTables) error { err := ipt.Append(table, chain, netfilter.IptRuleIgnoreBrLAN...) if err != nil { @@ -72,6 +94,10 @@ func (s *Server) IptSetRedirect(ipt *iptables.IPTables) error { if err != nil { return err } + err = ipt.Append(table, chain, netfilter.IptRuleIgnoreIP...) + if err != nil { + return err + } err = ipt.Append(table, chain, netfilter.IptRuleIgnorePorts...) if err != nil { return err diff --git a/src/internal/server/redirect/nftables.go b/src/internal/server/redirect/nftables.go index c8c402d..00f0260 100644 --- a/src/internal/server/redirect/nftables.go +++ b/src/internal/server/redirect/nftables.go @@ -5,6 +5,7 @@ package redirect import ( "context" "fmt" + "time" "github.com/sunbk201/ua3f/internal/netfilter" "sigs.k8s.io/knftables" @@ -21,6 +22,8 @@ func (s *Server) nftSetup() error { s.NftSetLanIP(tx, s.Nftable) s.NftSetLanIP6(tx, s.Nftable) + s.NftSetSkipIP(tx, s.Nftable) + s.NftSetSkipIP6(tx, s.Nftable) s.NftSetRedirect(tx, s.Nftable) if err := nft.Run(context.TODO(), tx); err != nil { @@ -44,6 +47,28 @@ func (s *Server) nftCleanup() error { return nil } +func (s *Server) NftWatch() { + go func() { + _ = s.NftAddSkipDomains() + + ticker := time.NewTicker(10 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + _ = s.NftAddSkipDomains() + case ip := <-s.SkipIpChan: + if ip.To4() != nil { + s.NftAddSkipIP(s.Nftable, []string{ip.String()}) + } else { + s.NftAddSkipIP6(s.Nftable, []string{ip.String()}) + } + } + } + }() +} + func (s *Server) NftSetRedirect(tx *knftables.Transaction, table *knftables.Table) { chain := &knftables.Chain{ Name: "PREROUTING", @@ -79,6 +104,16 @@ func (s *Server) NftSetRedirect(tx *knftables.Transaction, table *knftables.Tabl Rule: netfilter.NftRuleIgnoreLAN6, }) + tx.Add(&knftables.Rule{ + Chain: chain.Name, + Rule: netfilter.NftRuleIgnoreIP, + }) + + tx.Add(&knftables.Rule{ + Chain: chain.Name, + Rule: netfilter.NftRuleIgnoreIP6, + }) + tx.Add(&knftables.Rule{ Chain: chain.Name, Rule: netfilter.NftRuleIgnorePorts, diff --git a/src/internal/server/redirect/redirect_linux.go b/src/internal/server/redirect/redirect_linux.go index 70b24da..b91a396 100644 --- a/src/internal/server/redirect/redirect_linux.go +++ b/src/internal/server/redirect/redirect_linux.go @@ -29,10 +29,11 @@ type Server struct { func New(cfg *config.Config, rw *rewrite.Rewriter, rc *statistics.Recorder) *Server { s := &Server{ Server: base.Server{ - Cfg: cfg, - Rewriter: rw, - Recorder: rc, - Cache: expirable.NewLRU[string, struct{}](1024, nil, 30*time.Minute), + Cfg: cfg, + Rewriter: rw, + Recorder: rc, + Cache: expirable.NewLRU[string, struct{}](1024, nil, 30*time.Minute), + SkipIpChan: make(chan *net.IP, 512), }, so_mark: netfilter.SO_MARK, } @@ -43,8 +44,10 @@ func New(cfg *config.Config, rw *rewrite.Rewriter, rc *statistics.Recorder) *Ser }, NftSetup: s.nftSetup, NftCleanup: s.nftCleanup, + NftWatch: s.NftWatch, IptSetup: s.iptSetup, IptCleanup: s.iptCleanup, + IptWatch: s.IptWatch, } return s } diff --git a/src/internal/server/tproxy/iptables.go b/src/internal/server/tproxy/iptables.go index 6bdfcd6..e5ac5ad 100644 --- a/src/internal/server/tproxy/iptables.go +++ b/src/internal/server/tproxy/iptables.go @@ -6,6 +6,7 @@ import ( "strconv" "strings" "syscall" + "time" "github.com/coreos/go-iptables/iptables" "github.com/sunbk201/ua3f/internal/netfilter" @@ -57,6 +58,10 @@ func (s *Server) iptSetup() error { if err != nil { return err } + err = s.IptSetSkipIP() + if err != nil { + return err + } if netfilter.SIDECAR == netfilter.SC { err = ipt.NewChain(table, chainSidecar) @@ -112,6 +117,22 @@ func (s *Server) iptCleanup() error { return nil } +func (s *Server) IptWatch() { + go func() { + ticker := time.NewTicker(10 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + s.IptAddSkipDomains() + case ip := <-s.SkipIpChan: + s.IptAddSkipIP(ip.String()) + } + } + }() +} + func (s *Server) IptSetTproxy(ipt *iptables.IPTables) error { if netfilter.SIDECAR == netfilter.SC { var RuleSidecar = []string{ @@ -138,6 +159,10 @@ func (s *Server) IptSetTproxy(ipt *iptables.IPTables) error { if err != nil { return err } + err = ipt.Append(table, chainPre, netfilter.IptRuleIgnoreIP...) + if err != nil { + return err + } err = ipt.Append(table, chainPre, netfilter.IptRuleIgnorePorts...) if err != nil { return err @@ -210,6 +235,10 @@ func (s *Server) IptSetTproxy(ipt *iptables.IPTables) error { if err != nil { return err } + err = ipt.Append(table, chainOut, netfilter.IptRuleIgnoreIP...) + if err != nil { + return err + } err = ipt.Append(table, chainOut, netfilter.IptRuleIgnorePorts...) if err != nil { return err diff --git a/src/internal/server/tproxy/nftables.go b/src/internal/server/tproxy/nftables.go index 42e20c2..61320a4 100644 --- a/src/internal/server/tproxy/nftables.go +++ b/src/internal/server/tproxy/nftables.go @@ -6,6 +6,7 @@ import ( "context" "fmt" "syscall" + "time" "github.com/sunbk201/ua3f/internal/netfilter" "sigs.k8s.io/knftables" @@ -27,6 +28,8 @@ func (s *Server) nftSetup() error { s.NftSetLanIP(tx, s.Nftable) s.NftSetLanIP6(tx, s.Nftable) + s.NftSetSkipIP(tx, s.Nftable) + s.NftSetSkipIP6(tx, s.Nftable) s.NftSetTproxy(tx, s.Nftable) if err := nft.Run(context.TODO(), tx); err != nil { @@ -51,6 +54,28 @@ func (s *Server) nftCleanup() error { return nil } +func (s *Server) NftWatch() { + go func() { + _ = s.NftAddSkipDomains() + + ticker := time.NewTicker(10 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + _ = s.NftAddSkipDomains() + case ip := <-s.SkipIpChan: + if ip.To4() != nil { + s.NftAddSkipIP(s.Nftable, []string{ip.String()}) + } else { + s.NftAddSkipIP6(s.Nftable, []string{ip.String()}) + } + } + } + }() +} + func (s *Server) NftSetTproxy(tx *knftables.Transaction, table *knftables.Table) { if netfilter.SIDECAR == netfilter.SC { sidecar := &knftables.Chain{ @@ -107,6 +132,16 @@ func (s *Server) NftSetTproxy(tx *knftables.Transaction, table *knftables.Table) Rule: netfilter.NftRuleIgnoreLAN6, }) + tx.Add(&knftables.Rule{ + Chain: prerouting.Name, + Rule: netfilter.NftRuleIgnoreIP, + }) + + tx.Add(&knftables.Rule{ + Chain: prerouting.Name, + Rule: netfilter.NftRuleIgnoreIP6, + }) + tx.Add(&knftables.Rule{ Chain: prerouting.Name, Rule: netfilter.NftRuleIgnorePorts, @@ -186,6 +221,16 @@ func (s *Server) NftSetTproxy(tx *knftables.Transaction, table *knftables.Table) Rule: netfilter.NftRuleIgnoreLAN6, }) + tx.Add(&knftables.Rule{ + Chain: output.Name, + Rule: netfilter.NftRuleIgnoreIP, + }) + + tx.Add(&knftables.Rule{ + Chain: output.Name, + Rule: netfilter.NftRuleIgnoreIP6, + }) + tx.Add(&knftables.Rule{ Chain: output.Name, Rule: netfilter.NftRuleIgnorePorts, diff --git a/src/internal/server/tproxy/tproxy_linux.go b/src/internal/server/tproxy/tproxy_linux.go index 871bd90..e15a49b 100644 --- a/src/internal/server/tproxy/tproxy_linux.go +++ b/src/internal/server/tproxy/tproxy_linux.go @@ -35,10 +35,11 @@ type Server struct { func New(cfg *config.Config, rw *rewrite.Rewriter, rc *statistics.Recorder) *Server { s := &Server{ Server: base.Server{ - Cfg: cfg, - Rewriter: rw, - Recorder: rc, - Cache: expirable.NewLRU[string, struct{}](1024, nil, 30*time.Minute), + Cfg: cfg, + Rewriter: rw, + Recorder: rc, + Cache: expirable.NewLRU[string, struct{}](1024, nil, 30*time.Minute), + SkipIpChan: make(chan *net.IP, 512), }, so_mark: netfilter.SO_MARK, tproxyFwMark: "0x1c9", @@ -55,8 +56,10 @@ func New(cfg *config.Config, rw *rewrite.Rewriter, rc *statistics.Recorder) *Ser }, NftSetup: s.nftSetup, NftCleanup: s.nftCleanup, + NftWatch: s.NftWatch, IptSetup: s.iptSetup, IptCleanup: s.iptCleanup, + IptWatch: s.IptWatch, } return s }