Files
gluetun/internal/firewall/nftables/input.go
T
Quentin McGaw cd9ba54b37 wip
2026-02-28 22:38:52 +00:00

171 lines
3.9 KiB
Go

package nftables
import (
"context"
"fmt"
"net/netip"
"github.com/google/nftables"
"github.com/google/nftables/expr"
)
// AcceptInputThroughInterface appends a rule to the input chain that accepts
// every packet whose input interface name is exactly intf.
// The rule is added and flushed in a single nftables batch while holding the
// firewall mutex.
func (f *Firewall) AcceptInputThroughInterface(_ context.Context, intf string) error {
	f.mutex.Lock()
	defer f.mutex.Unlock()

	conn, err := nftables.New()
	if err != nil {
		return fmt.Errorf("creating nftables connection: %w", err)
	}

	table, inputChain, _, _ := setupFilterWithBaseChains(conn)

	// The compared bytes include a trailing NUL, so the comparison covers the
	// whole name: an interface whose name merely starts with intf does not match.
	interfaceName := append([]byte(intf), 0)
	matchInterface := []expr.Any{
		&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
		&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: interfaceName},
	}
	accept := &expr.Verdict{Kind: expr.VerdictAccept}

	conn.AddRule(&nftables.Rule{
		Table: table,
		Chain: inputChain,
		Exprs: append(matchInterface, accept),
	})

	if err = conn.Flush(); err != nil {
		return fmt.Errorf("flushing: %w", err)
	}
	return nil
}
// AcceptInputToPort accepts incoming traffic on the specified port, for both TCP and UDP
// protocols, on the interface intf. If intf is empty or "*", the interface is not used as a filter.
// If remove is true, the rule is removed instead of added. This is used for port forwarding, with
// intf set to the VPN tunnel interface.
//
// Added rules are recorded in f.rules only after the nftables batch has been
// flushed successfully, so a failed flush does not alter f.rules for additions.
func (f *Firewall) AcceptInputToPort(_ context.Context, intf string, port uint16, remove bool) error {
	f.mutex.Lock()
	defer f.mutex.Unlock()

	conn, err := nftables.New()
	if err != nil {
		return fmt.Errorf("creating nftables connection: %w", err)
	}

	table, inputChain, _, _ := setupFilterWithBaseChains(conn)

	// Ports are compared in network byte order (big endian).
	portBytes := []byte{byte(port >> 8), byte(port)} //nolint:mnd
	const tcp, udp uint8 = 6, 17
	protocols := []uint8{tcp, udp}

	// Rules queued in this batch; they are committed to f.rules only once the
	// flush succeeds, instead of being appended eagerly and rolled back.
	addedRules := make([]*nftables.Rule, 0, len(protocols))

	for _, protocol := range protocols {
		const maxExprsLen = 7
		exprs := make([]expr.Any, 0, maxExprsLen)
		if intf != "" && intf != "*" {
			exprs = append(exprs,
				&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
				&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte(intf + "\x00")},
			)
		}
		exprs = append(exprs,
			// meta l4proto yields the transport protocol for both IPv4 and IPv6
			// packets. A raw load of network-header byte 9 only works for IPv4;
			// in an IPv6 header that byte belongs to the source address.
			&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
			&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{protocol}},
			// Destination port: 2 bytes at offset 2 of the TCP/UDP header.
			&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2}, //nolint:mnd
			&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: portBytes},
			&expr.Verdict{Kind: expr.VerdictAccept},
		)
		rule := &nftables.Rule{
			Table: table,
			Chain: inputChain,
			Exprs: exprs,
		}

		if !remove {
			conn.AddRule(rule)
			addedRules = append(addedRules, rule)
			continue
		}

		// NOTE(review): deleteRule is defined elsewhere; if it also prunes
		// f.rules, a flush failure below leaves f.rules out of sync with the
		// kernel for removals — confirm against its implementation.
		err = f.deleteRule(conn, rule)
		if err != nil {
			return fmt.Errorf("deleting rule: %w", err)
		}
	}

	err = conn.Flush()
	if err != nil {
		return fmt.Errorf("flushing: %w", err)
	}

	f.rules = append(f.rules, addedRules...)
	return nil
}
// AcceptInputToSubnet accepts incoming traffic whose destination address lies
// within subnet. If intf is empty or "*", the interface is not used as a
// filter; otherwise only packets arriving on intf are accepted.
//
// The rule checks the packet's address family first, so an IPv4 subnet never
// matches IPv6 packets and vice versa. It then masks the destination address
// with the prefix length before comparing it, so every address in the subnet
// matches, not just the prefix's own address.
// It returns an error if subnet is not a valid prefix.
func (f *Firewall) AcceptInputToSubnet(_ context.Context, intf string, subnet netip.Prefix) error {
	if !subnet.IsValid() {
		return fmt.Errorf("subnet %s is not valid", subnet)
	}
	// Zero the host bits so the comparison value is the network address.
	subnet = subnet.Masked()

	f.mutex.Lock()
	defer f.mutex.Unlock()

	conn, err := nftables.New()
	if err != nil {
		return fmt.Errorf("creating nftables connection: %w", err)
	}

	table, inputChain, _, _ := setupFilterWithBaseChains(conn)

	// Netfilter family values (NFPROTO_IPV4, NFPROTO_IPV6) and the byte offset
	// of the destination address within the IPv4 and IPv6 headers.
	const (
		nfprotoIPv4, nfprotoIPv6     byte   = 2, 10
		ipv4DstOffset, ipv6DstOffset uint32 = 16, 24
	)
	nfproto, dstOffset := nfprotoIPv4, ipv4DstOffset
	if !subnet.Addr().Is4() {
		nfproto, dstOffset = nfprotoIPv6, ipv6DstOffset
	}

	network := subnet.Addr().AsSlice()
	addrLen := uint32(len(network)) //nolint:gosec

	const maxExprsLen = 8
	exprs := make([]expr.Any, 0, maxExprsLen)
	if intf != "" && intf != "*" {
		exprs = append(exprs,
			&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
			&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte(intf + "\x00")},
		)
	}
	exprs = append(exprs,
		// Guard on the address family: the header offsets below are only
		// meaningful for packets of the subnet's family.
		&expr.Meta{Key: expr.MetaKeyNFPROTO, Register: 1},
		&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{nfproto}},
		&expr.Payload{
			DestRegister: 1,
			Base:         expr.PayloadBaseNetworkHeader,
			Offset:       dstOffset,
			Len:          addrLen,
		},
	)

	// For anything shorter than a host prefix, clear the host bits of the
	// loaded address before comparing it to the network address.
	if subnet.Bits() < len(network)*8 { //nolint:mnd
		mask := make([]byte, len(network))
		for bit := 0; bit < subnet.Bits(); bit++ {
			mask[bit/8] |= 0x80 >> (bit % 8) //nolint:mnd
		}
		exprs = append(exprs, &expr.Bitwise{
			SourceRegister: 1,
			DestRegister:   1,
			Len:            addrLen,
			Mask:           mask,
			Xor:            make([]byte, len(network)),
		})
	}

	exprs = append(exprs,
		&expr.Cmp{
			Op:       expr.CmpOpEq,
			Register: 1,
			Data:     network,
		},
		&expr.Verdict{Kind: expr.VerdictAccept},
	)

	rule := &nftables.Rule{
		Table: table,
		Chain: inputChain,
		Exprs: exprs,
	}
	conn.AddRule(rule)

	err = conn.Flush()
	if err != nil {
		return fmt.Errorf("flushing: %w", err)
	}
	return nil
}