From cd9ba54b3740f8884df534fab162df858474d90c Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Sat, 28 Feb 2026 22:38:52 +0000 Subject: [PATCH] wip --- go.mod | 4 +- go.sum | 10 +- internal/firewall/nftables/atomic.go | 99 +++++++++++++ internal/firewall/nftables/basechains.go | 50 +++++++ internal/firewall/nftables/conntrack.go | 61 ++++++++ internal/firewall/nftables/delete.go | 27 ++++ internal/firewall/nftables/filter.go | 38 +++++ internal/firewall/nftables/firewall.go | 22 +++ internal/firewall/nftables/input.go | 170 +++++++++++++++++++++++ internal/firewall/nftables/interfaces.go | 5 + internal/firewall/nftables/output.go | 78 +++++++++++ internal/firewall/nftables/support.go | 12 ++ 12 files changed, 570 insertions(+), 6 deletions(-) create mode 100644 internal/firewall/nftables/atomic.go create mode 100644 internal/firewall/nftables/basechains.go create mode 100644 internal/firewall/nftables/conntrack.go create mode 100644 internal/firewall/nftables/delete.go create mode 100644 internal/firewall/nftables/filter.go create mode 100644 internal/firewall/nftables/firewall.go create mode 100644 internal/firewall/nftables/input.go create mode 100644 internal/firewall/nftables/interfaces.go create mode 100644 internal/firewall/nftables/output.go create mode 100644 internal/firewall/nftables/support.go diff --git a/go.mod b/go.mod index 5bf281ce..ca2ccd1c 100644 --- a/go.mod +++ b/go.mod @@ -7,11 +7,12 @@ require ( github.com/breml/rootcerts v0.3.4 github.com/fatih/color v1.18.0 github.com/golang/mock v1.6.0 + github.com/google/nftables v0.3.0 github.com/jsimonetti/rtnetlink v1.4.2 github.com/klauspost/compress v1.18.1 github.com/klauspost/pgzip v1.2.6 github.com/mdlayher/genetlink v1.3.2 - github.com/mdlayher/netlink v1.7.2 + github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 github.com/pelletier/go-toml/v2 v2.2.4 github.com/qdm12/dns/v2 v2.0.0-rc9.0.20260216151239-36b3306f2205 github.com/qdm12/gosettings v0.4.4 @@ -42,7 +43,6 @@ require ( github.com/cronokirby/saferith v0.33.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/google/go-cmp v0.7.0 // indirect - github.com/josharian/native v1.1.0 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mdlayher/socket v0.5.1 // indirect diff --git a/go.sum b/go.sum index b30ac9ab..c0a1741e 100644 --- a/go.sum +++ b/go.sum @@ -30,8 +30,8 @@ github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4= github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= -github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA= -github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w= +github.com/google/nftables v0.3.0 h1:bkyZ0cbpVeMHXOrtlFc8ISmfVqq5gPJukoYieyVmITg= +github.com/google/nftables v0.3.0/go.mod h1:BCp9FsrbF1Fn/Yu6CLUc9GGZFw/+hsxfluNXXmxBfRM= github.com/jsimonetti/rtnetlink v1.4.2 h1:Df9w9TZ3npHTyDn0Ev9e1uzmN2odmXd0QX+J5GTEn90= github.com/jsimonetti/rtnetlink v1.4.2/go.mod h1:92s6LJdE+1iOrw+F2/RO7LYI2Qd8pPpFNNUYW06gcoM= github.com/klauspost/compress v1.18.1 h1:bcSGx7UbpBqMChDtsF28Lw6v/G94LPrrbMbdC3JH2co= @@ -49,8 +49,8 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw= github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o= -github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g= -github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw= +github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 h1:A1Cq6Ysb0GM0tpKMbdCXCIfBclan4oHk1Jb+Hrejirg= +github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42/go.mod h1:BB4YCPDOzfy7FniQ/lxuYQ3dgmM2cZumHbK8RpTjN2o= github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos= github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ= github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ= @@ -99,6 +99,8 @@ github.com/ti-mo/netfilter v0.5.3 h1:ikzduvnaUMwre5bhbNwWOd6bjqLMVb33vv0XXbK0xGQ github.com/ti-mo/netfilter v0.5.3/go.mod h1:08SyBCg6hu1qyQk4s3DjjJKNrm3RTb32nm6AzyT972E= github.com/ulikunitz/xz v0.5.15 h1:9DNdB5s+SgV3bQ2ApL10xRc35ck0DuIX/isZvIk+ubY= github.com/ulikunitz/xz v0.5.15/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14= +github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= +github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a h1:fZHgsYlfvtyqToslyjUt3VOPF4J7aK/3MPcK7xp3PDk= github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a/go.mod h1:ul22v+Nro/R083muKhosV54bj5niojjWZvU8xrevuH4= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= diff --git a/internal/firewall/nftables/atomic.go b/internal/firewall/nftables/atomic.go new file mode 100644 index 00000000..cd27b06f --- /dev/null +++ b/internal/firewall/nftables/atomic.go @@ -0,0 +1,99 @@ +package nftables + +import ( + "context" + "fmt" + + "github.com/google/nftables" +) + +// SaveAndRestore saves the current nftables tree and returns a restore function that +// can be called to restore the saved tree. +func (f *Firewall) SaveAndRestore(_ context.Context) (restore func(context.Context), err error) { + f.mutex.Lock() + defer f.mutex.Unlock() + + conn, err := nftables.New() + if err != nil { + return nil, fmt.Errorf("creating nftables connection: %w", err) + } + tables, err := saveTables(conn) + if err != nil { + return nil, fmt.Errorf("saving nftables state: %w", err) + } + return func(_ context.Context) { + conn, err := nftables.New() + if err != nil { + f.logger.Warnf("creating nftables connection for restore: %s", err) + return + } + err = restoreTables(conn, tables) + if err != nil { + f.logger.Warnf("restoring nftables state: %s", err) + } + }, nil +} + +type savedTable struct { + table *nftables.Table + chains []savedChain +} + +type savedChain struct { + chain *nftables.Chain + rules []*nftables.Rule +} + +func saveTables(conn *nftables.Conn) ([]savedTable, error) { + tables, err := conn.ListTables() + if err != nil { + return nil, err + } + + savedTables := make([]savedTable, len(tables)) + for i, table := range tables { + savedTables[i].table = table + + chains, err := conn.ListChains() + if err != nil { + return nil, err + } + + for _, chain := range chains { + if chain.Table.Name != table.Name || + chain.Table.Family != table.Family { + continue + } + rules, err := conn.GetRules(table, chain) + if err != nil { + return nil, fmt.Errorf("getting rules for chain %s in table %s: %w", chain.Name, table.Name, err) + } + savedChain := savedChain{chain: chain, rules: rules} + savedTables[i].chains = append(savedTables[i].chains, savedChain) + } + } + + return savedTables, nil +} + +func restoreTables(conn *nftables.Conn, savedTables []savedTable) error { + conn.FlushRuleset() + + for _, savedTable := range savedTables { + table := conn.AddTable(savedTable.table) + for _, savedChain := range savedTable.chains { + // Make the [nftables.Chain.Table] points to the new [nftables.Table] + // created in this connection. + savedChain.chain.Table = table + savedChain.chain = conn.AddChain(savedChain.chain) + + for _, rule := range savedChain.rules { + rule.Table = table + rule.Chain = savedChain.chain + conn.AddRule(rule) + } + } + } + + return conn.Flush() +} diff --git a/internal/firewall/nftables/basechains.go b/internal/firewall/nftables/basechains.go new file mode 100644 index 00000000..13646aa1 --- /dev/null +++ b/internal/firewall/nftables/basechains.go @@ -0,0 +1,50 @@ +package nftables + +import ( + "context" + "errors" + "fmt" + "strings" + + "github.com/google/nftables" +) + +var ErrPolicyUnknown = errors.New("unknown policy") + +// SetBaseChainsPolicy sets the policy of all the base chains (INPUT, FORWARD, or OUTPUT) +// for the filter table to the given policy (accept or drop). +func (f *Firewall) SetBaseChainsPolicy(_ context.Context, policy string) error { + f.mutex.Lock() + defer f.mutex.Unlock() + + var chainPolicy nftables.ChainPolicy + switch strings.ToLower(policy) { + case "accept": + chainPolicy = nftables.ChainPolicyAccept + case "drop": + chainPolicy = nftables.ChainPolicyDrop + default: + return fmt.Errorf("%w: %s", ErrPolicyUnknown, policy) + } + + conn, err := nftables.New() + if err != nil { + return fmt.Errorf("creating nftables connection: %w", err) + } + + _, inputChain, forwardChain, outputChain := setupFilterWithBaseChains(conn) + inputChain.Policy = &chainPolicy + forwardChain.Policy = &chainPolicy + outputChain.Policy = &chainPolicy + + conn.AddChain(inputChain) + conn.AddChain(forwardChain) + conn.AddChain(outputChain) + + err = conn.Flush() + if err != nil { + return fmt.Errorf("flushing nftables changes: %w", err) + } + + return nil +} diff --git a/internal/firewall/nftables/conntrack.go b/internal/firewall/nftables/conntrack.go new file mode 100644 index 00000000..224fea92 --- /dev/null +++ b/internal/firewall/nftables/conntrack.go @@ -0,0 +1,61 @@ +package nftables + +import ( + "context" + "fmt" + + "github.com/google/nftables" + "github.com/google/nftables/expr" +) + +func (f *Firewall) AcceptEstablishedRelatedTraffic(_ context.Context) error { + f.mutex.Lock() + defer f.mutex.Unlock() + + conn, err := nftables.New() + if err != nil { + return fmt.Errorf("creating nftables connection: %w", err) + } + + table, inputChain, _, outputChain := setupFilterWithBaseChains(conn) + + ctStateExprs := []expr.Any{ + &expr.Ct{ + Key: expr.CtKeySTATE, + Register: 1, + }, + &expr.Bitwise{ + SourceRegister: 1, + DestRegister: 1, + Len: 4, //nolint:mnd + Mask: []byte{byte(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED), 0x00, 0x00, 0x00}, + Xor: []byte{0x00, 0x00, 0x00, 0x00}, + }, + &expr.Cmp{ + Op: expr.CmpOpNeq, + Register: 1, + Data: []byte{0x00, 0x00, 0x00, 0x00}, + }, + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + } + + conn.AddRule(&nftables.Rule{ + Table: table, + Chain: inputChain, + Exprs: ctStateExprs, + }) + + conn.AddRule(&nftables.Rule{ + Table: table, + Chain: outputChain, + Exprs: ctStateExprs, + }) + + if err := conn.Flush(); err != nil { + return fmt.Errorf("flushing: %w", err) + } + + return nil +} diff --git a/internal/firewall/nftables/delete.go b/internal/firewall/nftables/delete.go new file mode 100644 index 00000000..003632b0 --- /dev/null +++ b/internal/firewall/nftables/delete.go @@ -0,0 +1,27 @@ +package nftables + +import ( + "errors" + "fmt" + "reflect" + + "github.com/google/nftables" +) + +var errRuleToDeleteNotFound = errors.New("rule not found for removal") + +func (f *Firewall) deleteRule(conn *nftables.Conn, rule *nftables.Rule) error { + for i, existing := range f.rules { + if !reflect.DeepEqual(existing, rule) { + continue + } + err := conn.DelRule(existing) + if err != nil { + return fmt.Errorf("deleting rule: %w", err) + } + f.rules[i], f.rules[len(f.rules)-1] = f.rules[len(f.rules)-1], f.rules[i] + f.rules = f.rules[:len(f.rules)-1] + return nil + } + return fmt.Errorf("%w: %#v", errRuleToDeleteNotFound, rule) +} diff --git a/internal/firewall/nftables/filter.go b/internal/firewall/nftables/filter.go new file mode 100644 index 00000000..b1fa6037 --- /dev/null +++ b/internal/firewall/nftables/filter.go @@ -0,0 +1,38 @@ +package nftables + +import "github.com/google/nftables" + +func setupFilterWithBaseChains(conn *nftables.Conn) (table *nftables.Table, + inputChain, forwardChain, outputChain *nftables.Chain, +) { + table = conn.AddTable(&nftables.Table{ + Family: nftables.TableFamilyINet, + Name: "filter", + }) + + inputChain = conn.AddChain(&nftables.Chain{ + Name: "input", + Table: table, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookInput, + Priority: nftables.ChainPriorityFilter, + }) + + forwardChain = conn.AddChain(&nftables.Chain{ + Name: "forward", + Table: table, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookForward, + Priority: nftables.ChainPriorityFilter, + }) + + outputChain = conn.AddChain(&nftables.Chain{ + Name: "output", + Table: table, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookOutput, + Priority: nftables.ChainPriorityFilter, + }) + + return table, inputChain, forwardChain, outputChain +} diff --git a/internal/firewall/nftables/firewall.go b/internal/firewall/nftables/firewall.go new file mode 100644 index 00000000..a509c1f5 --- /dev/null +++ b/internal/firewall/nftables/firewall.go @@ -0,0 +1,22 @@ +package nftables + +import ( + "sync" + + "github.com/google/nftables" +) + +type Firewall struct { + logger Logger + + // rules are only rules added and tracked for later removal. + // Not all rules added are tracked for removal. + rules []*nftables.Rule + mutex sync.Mutex +} + +func New(logger Logger) *Firewall { + return &Firewall{ + logger: logger, + } +} diff --git a/internal/firewall/nftables/input.go b/internal/firewall/nftables/input.go new file mode 100644 index 00000000..3f74fa2a --- /dev/null +++ b/internal/firewall/nftables/input.go @@ -0,0 +1,170 @@ +package nftables + +import ( + "context" + "fmt" + "net/netip" + + "github.com/google/nftables" + "github.com/google/nftables/expr" +) + +func (f *Firewall) AcceptInputThroughInterface(_ context.Context, intf string) error { + f.mutex.Lock() + defer f.mutex.Unlock() + + conn, err := nftables.New() + if err != nil { + return fmt.Errorf("creating nftables connection: %w", err) + } + + table, inputChain, _, _ := setupFilterWithBaseChains(conn) + + rule := &nftables.Rule{ + Table: table, + Chain: inputChain, + Exprs: []expr.Any{ + &expr.Meta{ + Key: expr.MetaKeyIIFNAME, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: []byte(intf + "\x00"), + }, + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + }, + } + + conn.AddRule(rule) + + err = conn.Flush() + if err != nil { + return fmt.Errorf("flushing: %w", err) + } + + return nil +} + +// AcceptInputToPort accepts incoming traffic on the specified port, for both TCP and UDP +// protocols, on the interface intf. If intf is empty or "*", the interface is not used as a filter. +// If remove is true, the rule is removed instead of added. This is used for port forwarding, with +// intf set to the VPN tunnel interface. +func (f *Firewall) AcceptInputToPort(_ context.Context, intf string, port uint16, remove bool) error { + f.mutex.Lock() + defer f.mutex.Unlock() + + conn, err := nftables.New() + if err != nil { + return fmt.Errorf("creating nftables connection: %w", err) + } + + table, inputChain, _, _ := setupFilterWithBaseChains(conn) + portBytes := []byte{byte(port >> 8), byte(port)} //nolint:mnd + const tcp, udp uint8 = 6, 17 + protocols := []uint8{tcp, udp} + + for _, protocol := range protocols { + const maxExprsLen = 7 + exprs := make([]expr.Any, 0, maxExprsLen) + if intf != "" && intf != "*" { + exprs = append(exprs, + &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte(intf + "\x00")}, + ) + } + exprs = append(exprs, + &expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 9, Len: 1}, //nolint:mnd + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{protocol}}, + &expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2}, //nolint:mnd + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: portBytes}, + &expr.Verdict{Kind: expr.VerdictAccept}, + ) + + rule := &nftables.Rule{ + Table: table, + Chain: inputChain, + Exprs: exprs, + } + + if !remove { + conn.AddRule(rule) + f.rules = append(f.rules, rule) + continue + } + err = f.deleteRule(conn, rule) + if err != nil { + return fmt.Errorf("deleting rule: %w", err) + } + } + + err = conn.Flush() + if err != nil { + f.rules = f.rules[:len(f.rules)-len(protocols)] + return fmt.Errorf("flushing: %w", err) + } + + return nil +} + +func (f *Firewall) AcceptInputToSubnet(_ context.Context, intf string, subnet netip.Prefix) error { + f.mutex.Lock() + defer f.mutex.Unlock() + + conn, err := nftables.New() + if err != nil { + return fmt.Errorf("creating nftables connection: %w", err) + } + + table, inputChain, _, _ := setupFilterWithBaseChains(conn) + + const maxExprsLen = 5 + exprs := make([]expr.Any, 0, maxExprsLen) + + if intf != "" && intf != "*" { + exprs = append(exprs, + &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte(intf + "\x00")}, + ) + } + + var payloadOffset uint32 + if subnet.Addr().Is4() { + payloadOffset = 16 + } else { + payloadOffset = 24 + } + + exprs = append(exprs, + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseNetworkHeader, + Offset: payloadOffset, + Len: uint32(len(subnet.Addr().AsSlice())), //nolint:gosec + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: subnet.Addr().AsSlice(), + }, + &expr.Verdict{Kind: expr.VerdictAccept}, + ) + + rule := &nftables.Rule{ + Table: table, + Chain: inputChain, + Exprs: exprs, + } + + conn.AddRule(rule) + + err = conn.Flush() + if err != nil { + return fmt.Errorf("flushing: %w", err) + } + + return nil +} diff --git a/internal/firewall/nftables/interfaces.go b/internal/firewall/nftables/interfaces.go new file mode 100644 index 00000000..03f9974a --- /dev/null +++ b/internal/firewall/nftables/interfaces.go @@ -0,0 +1,5 @@ +package nftables + +type Logger interface { + Warnf(format string, args ...any) +} diff --git a/internal/firewall/nftables/output.go b/internal/firewall/nftables/output.go new file mode 100644 index 00000000..7d71f9c4 --- /dev/null +++ b/internal/firewall/nftables/output.go @@ -0,0 +1,78 @@ +package nftables + +import ( + "context" + "fmt" + + "github.com/google/nftables" + "github.com/google/nftables/expr" +) + +func (f *Firewall) AcceptIpv6MulticastOutput(_ context.Context, intf string) error { + f.mutex.Lock() + defer f.mutex.Unlock() + + conn, err := nftables.New() + if err != nil { + return fmt.Errorf("creating nftables connection: %w", err) + } + + table, _, _, outputChain := setupFilterWithBaseChains(conn) + + const maxExprsLen = 6 + exprs := make([]expr.Any, 0, maxExprsLen) + + if intf != "" && intf != "*" { + exprs = append(exprs, + &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte(intf + "\x00")}, + ) + } + + // ff02::1:ff00:0/104 mask is 13 bytes of 0xff + mask := []byte{ + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, + } //nolint:mnd + addr := []byte{ + 0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x01, 0xff, 0x00, 0x00, 0x00, + } //nolint:mnd + + exprs = append(exprs, + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseNetworkHeader, + Offset: 24, // IPv6 Destination Address offset //nolint:mnd + Len: 16, //nolint:mnd + }, + &expr.Bitwise{ + SourceRegister: 1, + DestRegister: 1, + Len: 16, //nolint:mnd + Mask: mask, + Xor: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, //nolint:mnd + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: addr, + }, + &expr.Verdict{Kind: expr.VerdictAccept}, + ) + + rule := &nftables.Rule{ + Table: table, + Chain: outputChain, + Exprs: exprs, + } + + conn.AddRule(rule) + + err = conn.Flush() + if err != nil { + return fmt.Errorf("flushing: %w", err) + } + + return nil +} diff --git a/internal/firewall/nftables/support.go b/internal/firewall/nftables/support.go new file mode 100644 index 00000000..eb706c01 --- /dev/null +++ b/internal/firewall/nftables/support.go @@ -0,0 +1,12 @@ +package nftables + +import "github.com/google/nftables" + +func IsSupported() bool { + conn, err := nftables.New() + if err != nil { + return false + } + _, err = conn.ListTable("filter") + return err == nil +}