Compare commits

..

6 Commits

Author SHA1 Message Date
Quentin McGaw cd9ba54b37 wip 2026-02-28 22:38:52 +00:00
Quentin McGaw 781e74f77a chore: merge iptables SetIPv4AllPolicies and SetIPv6AllPolicies together 2026-02-28 15:25:15 +00:00
Quentin McGaw fa0941a529 add nftables to dev container 2026-02-28 15:24:37 +00:00
Quentin McGaw e87d915f15 chore(firewall/iptables): modprobe and cache support for xt_mark and nf_tables 2026-02-28 15:23:30 +00:00
Quentin McGaw ec24ffdfd8 hotfix(firewall): save and restore behavior fixed
- restore if IPv4 set all policies fails
- fix deadlock when using iptables custom rules
- fix setting ipv6 rules when running runMixedIptablesInstruction
2026-02-28 14:37:58 +00:00
dependabot[bot] b9d49e0661 Chore(deps): Bump github.com/breml/rootcerts from 0.3.3 to 0.3.4 (#3128) 2026-02-27 02:16:31 +01:00
32 changed files with 661 additions and 444 deletions
+1 -1
View File
@@ -1,2 +1,2 @@
FROM ghcr.io/qdm12/godevcontainer:v0.21-alpine
RUN apk add wireguard-tools htop openssl tcpdump iptables
RUN apk add wireguard-tools htop openssl tcpdump iptables nftables
+5 -1
View File
@@ -227,7 +227,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
firewallLogger.Patch(log.SetLevel(log.LevelDebug))
}
firewallConf, err := firewall.NewConfig(ctx, firewallLogger, cmder,
netLinker, defaultRoutes, localNetworks)
defaultRoutes, localNetworks)
if err != nil {
return err
}
@@ -237,6 +237,10 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
if err != nil {
return err
}
err = netLinker.FlushConntrack()
if err != nil {
logger.Warnf("flushing conntrack failed: %s", err)
}
}
// TODO run this in a loop or in openvpn to reload from file without restarting
+3 -3
View File
@@ -4,14 +4,15 @@ go 1.25.0
require (
github.com/ProtonMail/go-srp v0.0.7
github.com/breml/rootcerts v0.3.3
github.com/breml/rootcerts v0.3.4
github.com/fatih/color v1.18.0
github.com/golang/mock v1.6.0
github.com/google/nftables v0.3.0
github.com/jsimonetti/rtnetlink v1.4.2
github.com/klauspost/compress v1.18.1
github.com/klauspost/pgzip v1.2.6
github.com/mdlayher/genetlink v1.3.2
github.com/mdlayher/netlink v1.7.2
github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42
github.com/pelletier/go-toml/v2 v2.2.4
github.com/qdm12/dns/v2 v2.0.0-rc9.0.20260216151239-36b3306f2205
github.com/qdm12/gosettings v0.4.4
@@ -42,7 +43,6 @@ require (
github.com/cronokirby/saferith v0.33.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/google/go-cmp v0.7.0 // indirect
github.com/josharian/native v1.1.0 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mdlayher/socket v0.5.1 // indirect
+8 -6
View File
@@ -8,8 +8,8 @@ github.com/ProtonMail/go-srp v0.0.7 h1:Sos3Qk+th4tQR64vsxGIxYpN3rdnG9Wf9K4ZloC1J
github.com/ProtonMail/go-srp v0.0.7/go.mod h1:giCp+7qRnMIcCvI6V6U3S1lDDXDQYx2ewJ6F/9wdlJk=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/breml/rootcerts v0.3.3 h1://GnaRtQ/9BY2+GtMk2wtWxVdCRysiaPr5/xBwl7NKw=
github.com/breml/rootcerts v0.3.3/go.mod h1:S/PKh+4d1HUn4HQovEB8hPJZO6pUZYrIhmXBhsegfXw=
github.com/breml/rootcerts v0.3.4 h1:9i7WNl/ctd9OEAOaTfLy//Wrlfxq/tRQ7v4okYFN9Ys=
github.com/breml/rootcerts v0.3.4/go.mod h1:S/PKh+4d1HUn4HQovEB8hPJZO6pUZYrIhmXBhsegfXw=
github.com/bwesterb/go-ristretto v1.2.0/go.mod h1:fUIoIZaG73pV5biE2Blr2xEzDoMj7NFEuV9ekS419A0=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
@@ -30,8 +30,8 @@ github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA=
github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
github.com/google/nftables v0.3.0 h1:bkyZ0cbpVeMHXOrtlFc8ISmfVqq5gPJukoYieyVmITg=
github.com/google/nftables v0.3.0/go.mod h1:BCp9FsrbF1Fn/Yu6CLUc9GGZFw/+hsxfluNXXmxBfRM=
github.com/jsimonetti/rtnetlink v1.4.2 h1:Df9w9TZ3npHTyDn0Ev9e1uzmN2odmXd0QX+J5GTEn90=
github.com/jsimonetti/rtnetlink v1.4.2/go.mod h1:92s6LJdE+1iOrw+F2/RO7LYI2Qd8pPpFNNUYW06gcoM=
github.com/klauspost/compress v1.18.1 h1:bcSGx7UbpBqMChDtsF28Lw6v/G94LPrrbMbdC3JH2co=
@@ -49,8 +49,8 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw=
github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o=
github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g=
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 h1:A1Cq6Ysb0GM0tpKMbdCXCIfBclan4oHk1Jb+Hrejirg=
github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42/go.mod h1:BB4YCPDOzfy7FniQ/lxuYQ3dgmM2cZumHbK8RpTjN2o=
github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos=
github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ=
github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ=
@@ -99,6 +99,8 @@ github.com/ti-mo/netfilter v0.5.3 h1:ikzduvnaUMwre5bhbNwWOd6bjqLMVb33vv0XXbK0xGQ
github.com/ti-mo/netfilter v0.5.3/go.mod h1:08SyBCg6hu1qyQk4s3DjjJKNrm3RTb32nm6AzyT972E=
github.com/ulikunitz/xz v0.5.15 h1:9DNdB5s+SgV3bQ2ApL10xRc35ck0DuIX/isZvIk+ubY=
github.com/ulikunitz/xz v0.5.15/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14=
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a h1:fZHgsYlfvtyqToslyjUt3VOPF4J7aK/3MPcK7xp3PDk=
github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a/go.mod h1:ul22v+Nro/R083muKhosV54bj5niojjWZvU8xrevuH4=
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
+4 -13
View File
@@ -45,20 +45,16 @@ func (c *Config) enable(ctx context.Context) (err error) {
return fmt.Errorf("saving firewall rules: %w", err)
}
if err = c.impl.SetIPv4AllPolicies(ctx, "DROP"); err != nil {
return err
}
if err = c.impl.SetIPv6AllPolicies(ctx, "DROP"); err != nil {
return err
}
defer func() {
if err != nil {
c.restore(context.Background())
}
}()
if err = c.impl.SetBaseChainsPolicy(ctx, "DROP"); err != nil {
return err
}
// Loopback traffic
if err = c.impl.AcceptInputThroughInterface(ctx, "lo"); err != nil {
return err
@@ -69,11 +65,6 @@ func (c *Config) enable(ctx context.Context) (err error) {
return err
}
err = c.flushExistingConnections(ctx)
if err != nil {
return fmt.Errorf("flushing existing connections: %w", err)
}
if err = c.impl.AcceptEstablishedRelatedTraffic(ctx); err != nil {
return err
}
+2 -4
View File
@@ -13,7 +13,6 @@ import (
type Config struct {
runner CmdRunner
netlinker Netlinker
logger Logger
defaultRoutes []routing.DefaultRoute
localNetworks []routing.LocalNetwork
@@ -36,8 +35,8 @@ type Config struct {
// NewConfig creates a new Config instance and returns an error
// if no iptables implementation is available.
func NewConfig(ctx context.Context, logger Logger,
runner CmdRunner, netlinker Netlinker,
defaultRoutes []routing.DefaultRoute, localNetworks []routing.LocalNetwork,
runner CmdRunner, defaultRoutes []routing.DefaultRoute,
localNetworks []routing.LocalNetwork,
) (config *Config, err error) {
impl, err := iptables.New(ctx, runner, logger)
if err != nil {
@@ -46,7 +45,6 @@ func NewConfig(ctx context.Context, logger Logger,
return &Config{
runner: runner,
netlinker: netlinker,
logger: logger,
allowedInputPorts: make(map[uint16]map[string]struct{}),
// Obtained from routing
-74
View File
@@ -1,74 +0,0 @@
package firewall
import (
"context"
"errors"
"fmt"
"time"
"github.com/qdm12/gluetun/internal/firewall/iptables"
"github.com/qdm12/gluetun/internal/netlink"
)
func (c *Config) flushExistingConnections(ctx context.Context) error {
tries := []struct {
name string
f func(ctx context.Context) error
}{
{name: "flushing conntrack", f: func(_ context.Context) error {
return c.netlinker.FlushConntrack()
}},
{name: "marking and filtering unmarked packets", f: c.impl.AcceptOutputPublicOnlyNewTraffic},
{name: "rejecting connections for one second", f: c.rejectOutputTrafficTemporarily},
{name: "dropping connections for one second", f: c.dropOutputTrafficTemporarily},
}
errs := make([]error, 0, len(tries))
for i, try := range tries {
if i > 0 {
c.logger.Debugf("falling back to %s because %s failed: %s", try.name, tries[i-1].name, errs[i-1])
}
err := try.f(ctx)
if err == nil {
return nil
}
err = fmt.Errorf("%s: %w", try.name, err)
if !errors.Is(err, iptables.ErrKernelModuleMissing) && !errors.Is(err, netlink.ErrConntrackNetlinkNotSupported) {
return err
}
errs = append(errs, err)
}
return fmt.Errorf("all tries failed: %v", errs) //nolint:err113
}
func (c *Config) rejectOutputTrafficTemporarily(ctx context.Context) error {
return setupThenRevert(ctx, c.impl.RejectOutputPublicTraffic)
}
func (c *Config) dropOutputTrafficTemporarily(ctx context.Context) error {
return setupThenRevert(ctx, c.impl.DropOutputPublicTraffic)
}
// setupThenRevert is a helper function to run a setup function that takes a remove boolean argument,
// and then run the same function with remove set to true after one second or when the context is canceled,
// whichever comes first.
func setupThenRevert(ctx context.Context, f func(ctx context.Context, remove bool) error) error {
remove := false
err := f(ctx, remove)
if err != nil {
return fmt.Errorf("setting up: %w", err)
}
timer := time.NewTimer(time.Second)
select {
case <-timer.C:
case <-ctx.Done():
timer.Stop()
}
remove = true
// Use [context.Background] to make sure this is removed, even if the context
// passed to this function is canceled.
err = f(context.Background(), remove)
if err != nil {
return fmt.Errorf("reverting: %w", err)
}
return nil
}
+2 -11
View File
@@ -14,23 +14,15 @@ type CmdRunner interface {
type Logger interface {
Debug(s string)
Debugf(format string, args ...any)
Info(s string)
Warn(s string)
Error(s string)
}
type Netlinker interface {
FlushConntrack() error
}
type firewallImpl interface { //nolint:interfacebloat
SaveAndRestore(ctx context.Context) (restore func(context.Context), err error)
AcceptOutputPublicOnlyNewTraffic(ctx context.Context) error
RejectOutputPublicTraffic(ctx context.Context, remove bool) error
DropOutputPublicTraffic(ctx context.Context, remove bool) error
AcceptInputThroughInterface(ctx context.Context, intf string) error
AcceptEstablishedRelatedTraffic(ctx context.Context) error
AcceptInputThroughInterface(ctx context.Context, intf string) error
AcceptInputToPort(ctx context.Context, intf string, port uint16, remove bool) error
AcceptInputToSubnet(ctx context.Context, intf string, subnet netip.Prefix) error
AcceptIpv6MulticastOutput(ctx context.Context, intf string) error
@@ -42,8 +34,7 @@ type firewallImpl interface { //nolint:interfacebloat
RedirectPort(ctx context.Context, intf string, sourcePort,
destinationPort uint16, remove bool) error
RunUserPostRules(ctx context.Context, customRulesPath string) error
SetIPv4AllPolicies(ctx context.Context, policy string) error
SetIPv6AllPolicies(ctx context.Context, policy string) error
SetBaseChainsPolicy(ctx context.Context, policy string) error
TempDropOutputTCPRST(ctx context.Context, src, dst netip.AddrPort, excludeMark int) (
revert func(ctx context.Context) error, err error)
Version(ctx context.Context) (version string, err error)
+15 -3
View File
@@ -2,11 +2,10 @@ package iptables
import (
"context"
"errors"
"sync"
)
var ErrKernelModuleMissing = errors.New("kernel module is missing for this operation")
"github.com/qdm12/gluetun/internal/mod"
)
type Config struct {
runner CmdRunner
@@ -17,6 +16,8 @@ type Config struct {
// Fixed state
ipTables string
ip6Tables string
nftables bool
xtMark bool
}
func New(ctx context.Context, runner CmdRunner, logger Logger) (*Config, error) {
@@ -30,10 +31,21 @@ func New(ctx context.Context, runner CmdRunner, logger Logger) (*Config, error)
return nil, err
}
modules := map[string]bool{
"xt_mark": false,
"nf_tables": false,
}
for module := range modules {
err := mod.Probe(module)
modules[module] = err == nil
}
return &Config{
runner: runner,
logger: logger,
ipTables: iptables,
ip6Tables: ip6tables,
nftables: modules["nf_tables"],
xtMark: modules["xt_mark"],
}, nil
}
-18
View File
@@ -76,26 +76,8 @@ func (c *Config) runIP6tablesInstructionNoSave(ctx context.Context, instruction
cmd := exec.CommandContext(ctx, c.ip6Tables, flags...) // #nosec G204
c.logger.Debug(cmd.String())
if output, err := c.runner.Run(cmd); err != nil {
if strings.Contains(output, "missing kernel module") {
err = ErrKernelModuleMissing
}
return fmt.Errorf("command failed: \"%s %s\": %s: %w",
c.ip6Tables, instruction, output, err)
}
return nil
}
var ErrPolicyNotValid = errors.New("policy is not valid")
func (c *Config) SetIPv6AllPolicies(ctx context.Context, policy string) error {
switch policy {
case "ACCEPT", "DROP":
default:
return fmt.Errorf("%w: %s", ErrPolicyNotValid, policy)
}
return c.runIP6tablesInstructions(ctx, []string{
"--policy INPUT " + policy,
"--policy OUTPUT " + policy,
"--policy FORWARD " + policy,
})
}
+5 -137
View File
@@ -92,22 +92,20 @@ func (c *Config) runIptablesInstructionNoSave(ctx context.Context, instruction s
cmd := exec.CommandContext(ctx, c.ipTables, flags...) // #nosec G204
c.logger.Debug(cmd.String())
if output, err := c.runner.Run(cmd); err != nil {
if strings.Contains(output, "missing kernel module") {
err = ErrKernelModuleMissing
}
return fmt.Errorf("command failed: \"%s %s\": %s: %w",
c.ipTables, instruction, output, err)
}
return nil
}
func (c *Config) SetIPv4AllPolicies(ctx context.Context, policy string) error {
func (c *Config) SetBaseChainsPolicy(ctx context.Context, policy string) error {
policy = strings.ToUpper(policy)
switch policy {
case "ACCEPT", "DROP":
default:
return fmt.Errorf("%w: %s", ErrPolicyUnknown, policy)
}
return c.runIptablesInstructions(ctx, []string{
return c.runMixedIptablesInstructions(ctx, []string{
"--policy INPUT " + policy,
"--policy OUTPUT " + policy,
"--policy FORWARD " + policy,
@@ -150,136 +148,6 @@ func (c *Config) AcceptEstablishedRelatedTraffic(ctx context.Context) error {
})
}
// AcceptOutputPublicOnlyNewTraffic adds rules to mark new output connections, and to accept
// established or related packets with this mark only. This effectively forces
// previously established or related traffic to be blocked.
// If remove is true, the rules are removed instead of appended.
// If the relevant kernel modules are not available, it returns an error indicating
// which kernel module is missing.
func (c *Config) AcceptOutputPublicOnlyNewTraffic(ctx context.Context) error {
ipv4Instructions, ipv6Instructions := makeCreatePublicIPChainInstructions()
appendToBoth := func(instruction string) {
ipv4Instructions = append(ipv4Instructions, instruction)
ipv6Instructions = append(ipv6Instructions, instruction)
}
// Mark new connections with mark 0x567
appendToBoth("-A PUBLIC_ONLY -m conntrack --ctstate NEW -j CONNMARK --set-mark 0x567")
// Drop related/established connections that made it through; marked connections would
// be directly accepted by the first rule in the OUTPUT chain (see below)
appendToBoth("-A PUBLIC_ONLY -m conntrack --ctstate RELATED,ESTABLISHED -j DROP")
// Set the PUBLIC_ONLY chain as the second rule in the OUTPUT chain, so that it is evaluated
// after the accept rule below, for performance reasons.
appendToBoth("-I OUTPUT -j PUBLIC_ONLY")
appendToBoth("-I OUTPUT -m conntrack --ctstate RELATED,ESTABLISHED -m connmark --mark 0x567 -j ACCEPT")
c.iptablesMutex.Lock()
c.ip6tablesMutex.Lock()
defer c.iptablesMutex.Unlock()
defer c.ip6tablesMutex.Unlock()
restore, err := c.saveAndRestore(ctx)
if err != nil {
return err
}
err = c.runIptablesInstructionsNoSave(ctx, ipv4Instructions)
if err != nil {
restore(ctx)
return err
}
err = c.runIP6tablesInstructionsNoSave(ctx, ipv6Instructions)
if err != nil {
restore(ctx)
return err
}
return nil
}
func (c *Config) RejectOutputPublicTraffic(ctx context.Context, remove bool) error {
return c.targetOutputPublicTraffic(ctx, "REJECT", remove)
}
func (c *Config) DropOutputPublicTraffic(ctx context.Context, remove bool) error {
return c.targetOutputPublicTraffic(ctx, "DROP", remove)
}
func (c *Config) targetOutputPublicTraffic(ctx context.Context, target string, remove bool) error {
removeInstructions := []string{
"-D OUTPUT -j PUBLIC_ONLY",
"-F PUBLIC_ONLY",
"-X PUBLIC_ONLY",
}
if remove {
return c.runMixedIptablesInstructions(ctx, removeInstructions)
}
ipv4Instructions, ipv6Instructions := makeCreatePublicIPChainInstructions()
appendToBoth := func(instruction string) {
ipv4Instructions = append(ipv4Instructions, instruction)
ipv6Instructions = append(ipv6Instructions, instruction)
}
if target == "REJECT" {
// Block TCP by sending back TCP RST packets.
appendToBoth("-A PUBLIC_ONLY -p tcp -m conntrack --ctstate RELATED,ESTABLISHED " +
"-j REJECT --reject-with tcp-reset")
// Block UDP and ICMP, sending back ICMP port unreachable.
appendToBoth("-A PUBLIC_ONLY -m conntrack --ctstate RELATED,ESTABLISHED -j REJECT")
} else {
appendToBoth("-A PUBLIC_ONLY -m conntrack --ctstate RELATED,ESTABLISHED -j " + target)
}
appendToBoth("-I OUTPUT -j PUBLIC_ONLY")
err := c.runIptablesInstructions(ctx, ipv4Instructions)
if err != nil {
if strings.Contains(err.Error(), " support") {
return fmt.Errorf("%w: %w", ErrKernelModuleMissing, err)
}
}
err = c.runIP6tablesInstructions(ctx, ipv6Instructions)
if err != nil {
_ = c.runIptablesInstructions(ctx, removeInstructions)
if strings.Contains(err.Error(), " support") {
return fmt.Errorf("%w: %w", ErrKernelModuleMissing, err)
}
return err
}
return nil
}
func makeCreatePublicIPChainInstructions() (ipv4Instructions, ipv6Instructions []string) {
ipv4PrivatePrefixes := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("172.16.0.0/12"),
netip.MustParsePrefix("192.168.0.0/16"),
netip.MustParsePrefix("127.0.0.0/8"),
}
ipv6PrivatePrefixes := []netip.Prefix{
netip.MustParsePrefix("fc00::/7"),
netip.MustParsePrefix("fe80::/10"),
netip.MustParsePrefix("::1/128"),
}
ipv4Instructions = append(ipv4Instructions, "-N PUBLIC_ONLY")
ipv6Instructions = append(ipv6Instructions, "-N PUBLIC_ONLY")
for _, prefix := range ipv4PrivatePrefixes {
ipv4Instructions = append(ipv4Instructions, fmt.Sprintf(
"-A PUBLIC_ONLY -d %s -j RETURN", prefix))
}
for _, prefix := range ipv6PrivatePrefixes {
ipv6Instructions = append(ipv6Instructions, fmt.Sprintf(
"-A PUBLIC_ONLY -d %s -j RETURN", prefix))
}
return ipv4Instructions, ipv6Instructions
}
func (c *Config) AcceptOutputTrafficToVPN(ctx context.Context,
defaultInterface string, connection models.Connection, remove bool,
) error {
@@ -470,11 +338,11 @@ func (c *Config) RunUserPostRules(ctx context.Context, filepath string) error {
switch {
case ipv4:
err = c.runIptablesInstruction(ctx, rule)
err = c.runIptablesInstructionNoSave(ctx, rule)
case c.ip6Tables == "":
err = fmt.Errorf("running user ip6tables rule: %w", ErrNeedIP6Tables)
default: // ipv6
err = c.runIP6tablesInstruction(ctx, rule)
err = c.runIP6tablesInstructionNoSave(ctx, rule)
}
if err != nil {
restore(ctx)
+1 -1
View File
@@ -34,7 +34,7 @@ func (c *Config) runMixedIptablesInstruction(ctx context.Context, instruction st
if err != nil {
return err
}
err = c.runIptablesInstructionNoSave(ctx, instruction)
err = c.runMixedIptablesInstructionNoSave(ctx, instruction)
if err != nil {
restore(ctx)
}
+11 -33
View File
@@ -33,9 +33,6 @@ type chainRule struct {
ctstate []string // for example ["RELATED","ESTABLISHED"]. Can be empty.
tcpFlags tcpFlags
mark mark
connMark mark
setMark uint
rejectWith string // for example "tcp-reset", only used for REJECT targets
}
type mark struct {
@@ -222,6 +219,10 @@ func parseChainRuleField(fieldIndex int, field string, rule *chainRule) (err err
return fmt.Errorf("parsing bytes: %w", err)
}
case targetIndex:
err = checkTarget(field)
if err != nil {
return fmt.Errorf("checking target: %w", err)
}
rule.target = field
case protocolIndex:
rule.protocol, err = parseProtocol(field)
@@ -292,33 +293,6 @@ func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err
}
rule.mark = mark
i += consumed
case "reject-with":
i++
rule.rejectWith = optionalFields[i] // for example "tcp-reset"
i++
case "connmark":
i++
connMark, consumed, err := parseMark(optionalFields[i:])
if err != nil {
return fmt.Errorf("parsing connmark: %w", err)
}
rule.connMark = connMark
i += consumed
case "CONNMARK":
i++
switch optionalFields[i] {
case "set":
i++
value, err := parseAny32bNumber(optionalFields[i])
if err != nil {
return fmt.Errorf("parsing CONNMARK set value: %w", err)
}
rule.setMark = value
i++
default:
return fmt.Errorf("%w: unexpected %q after CONNMARK",
ErrChainRuleMalformed, optionalFields[i])
}
default:
return fmt.Errorf("%w: unexpected optional field: %s",
ErrChainRuleMalformed, optionalFields[i])
@@ -448,6 +422,8 @@ func parsePortsCSV(s string) (ports []uint16, err error) {
return ports, nil
}
var errMarkValueMalformed = errors.New("mark value is malformed")
func parseMark(optionalFields []string) (m mark, consumed int, err error) {
switch optionalFields[consumed] {
case "match":
@@ -457,11 +433,13 @@ func parseMark(optionalFields []string) (m mark, consumed int, err error) {
consumed++
}
value, err := parseAny32bNumber(optionalFields[consumed])
const base = 0 // auto-detect
const bits = 32
value, err := strconv.ParseUint(optionalFields[consumed], base, bits)
if err != nil {
return mark{}, 0, fmt.Errorf("value malformed: %w", err)
return mark{}, 0, fmt.Errorf("%w: %s", errMarkValueMalformed, optionalFields[consumed])
}
m.value = value
m.value = uint(value)
consumed++
default:
return mark{}, 0, fmt.Errorf("%w: unexpected mark mode field: %s",
+12 -89
View File
@@ -9,19 +9,9 @@ import (
"strings"
)
type operation uint8
const (
opNone operation = iota
opAppend
opDelete
opInsert
opReplace
)
type iptablesInstruction struct {
table string // defaults to "filter", and can be "nat" for example.
operation operation
append bool
chain string // for example INPUT, PREROUTING. Cannot be empty.
target string // for example ACCEPT. Can be empty.
protocol string // "tcp" or "udp" or "" for all protocols.
@@ -35,9 +25,6 @@ type iptablesInstruction struct {
ctstate []string // if empty, there is no ctstate
tcpFlags tcpFlags
mark mark
connMark mark
setMark uint // only used for jump CONNMARK --set-mark
rejectWith string // only used for REJECT targets
}
func (i *iptablesInstruction) setDefaults() {
@@ -78,12 +65,6 @@ func (i *iptablesInstruction) equalToRule(table, chain string, rule chainRule) (
return false
case i.mark != rule.mark:
return false
case i.connMark != rule.connMark:
return false
case i.setMark != rule.setMark:
return false
case i.rejectWith != rule.rejectWith:
return false
default:
return true
}
@@ -132,20 +113,13 @@ func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (co
case "-t", "--table":
instruction.table = value
case "-D", "--delete":
instruction.operation = opDelete
instruction.append = false
instruction.chain = value
case "-A", "--append":
instruction.operation = opAppend
instruction.chain = value
case "-I", "--insert":
instruction.operation = opInsert
instruction.append = true
instruction.chain = value
case "-j", "--jump":
subConsumed, err := parseJumpFlag(fields[1:], instruction)
if err != nil {
return 0, fmt.Errorf("parsing jump flag: %w", err)
}
consumed += subConsumed
instruction.target = value
case "-p", "--protocol":
instruction.protocol = value
case "-m", "--match":
@@ -154,11 +128,13 @@ func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (co
return 0, fmt.Errorf("parsing match module: %w", err)
}
case "--mark":
n, err := parseAny32bNumber(value)
const base = 0 // auto-detect
const bits = 32
value, err := strconv.ParseUint(value, base, bits)
if err != nil {
return 0, fmt.Errorf("parsing mark value %q: %w", value, err)
return 0, fmt.Errorf("parsing mark value %q: %w", fields[2], err)
}
instruction.mark.value = n
instruction.mark.value = uint(value)
case "-i", "--in-interface":
instruction.inputInterface = value
case "-o", "--out-interface":
@@ -196,8 +172,6 @@ func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (co
if err != nil {
return 0, fmt.Errorf("parsing TCP flags: %w", err)
}
case "--reject-with":
instruction.rejectWith = value // for example "tcp-reset"
default:
return 0, fmt.Errorf("%w: unknown key %q", ErrIptablesCommandMalformed, flag)
}
@@ -208,7 +182,7 @@ func preCheckInstructionFields(fields []string) (consumed int, err error) {
flag := fields[0]
// All flags use one value after the flag, except the following:
switch flag {
case "--tcp-flags":
case "--tcp-flags": // -m can have 1 or 2 values
const expected = 3
if len(fields) < expected {
return 0, fmt.Errorf("%w: flag %q requires at least 2 values, but got %s",
@@ -225,34 +199,6 @@ func preCheckInstructionFields(fields []string) (consumed int, err error) {
}
}
func parseJumpFlag(fields []string, instruction *iptablesInstruction) (consumed int, err error) {
instruction.target = fields[0]
// consumed in the caller already takes fields[0] into account
if instruction.target != "CONNMARK" {
return consumed, nil
}
// consumed already accounts for the "CONNMARK" value
const expectedFields = 3
if len(fields) < expectedFields {
return 0, fmt.Errorf("%w: jump CONNMARK requires at least two additional values",
ErrIptablesCommandMalformed)
}
switch fields[1] {
case "--set-mark":
n, err := parseAny32bNumber(fields[2])
if err != nil {
return 0, fmt.Errorf("parsing connmark mark value %q: %w", fields[2], err)
}
consumed++
instruction.setMark = n
default:
return consumed, fmt.Errorf("%w: unsupported jump CONNMARK with value: %s",
ErrIptablesCommandMalformed, fields[1])
}
consumed++
return consumed, nil
}
func parseIPPrefix(value string) (prefix netip.Prefix, err error) {
slashIndex := strings.Index(value, "/")
if slashIndex >= 0 {
@@ -275,13 +221,6 @@ func parsePort(value string) (port uint16, err error) {
return uint16(portValue), nil
}
func parseAny32bNumber(mark string) (value uint, err error) {
const base = 0 // auto-detect
const bits = 32
n, err := strconv.ParseUint(mark, base, bits)
return uint(n), err
}
func parseMatchModule(fields []string, instruction *iptablesInstruction) (
consumed int, err error,
) {
@@ -295,30 +234,14 @@ func parseMatchModule(fields []string, instruction *iptablesInstruction) (
// parse it twice.
case "mark":
consumed++
switch {
case len(fields[consumed:]) == 0 || strings.HasPrefix(fields[consumed], "-"):
// end or another flag
return consumed, nil
case fields[consumed] == "!":
switch fields[consumed] {
case "!":
consumed++
instruction.mark.invert = true
default:
return consumed, fmt.Errorf("%w: unsupported match mark with value: %s",
ErrIptablesCommandMalformed, fields[2])
}
case "connmark":
consumed++
switch {
case len(fields[consumed:]) == 0 || strings.HasPrefix(fields[consumed], "-"):
// end or another flag
return consumed, nil
case fields[consumed] == "!":
consumed++
instruction.connMark.invert = true
default:
return consumed, fmt.Errorf("%w: unsupported match connmark with value: %s",
ErrIptablesCommandMalformed, fields[2])
}
default:
return 0, fmt.Errorf("%w: unknown match value: %s",
ErrIptablesCommandMalformed, fields[consumed])
+5 -5
View File
@@ -33,9 +33,9 @@ func Test_parseIptablesInstruction(t *testing.T) {
"one_pair": {
s: "-A INPUT",
instruction: iptablesInstruction{
table: "filter",
chain: "INPUT",
operation: opAppend,
table: "filter",
chain: "INPUT",
append: true,
},
},
"instruction_A": {
@@ -43,7 +43,7 @@ func Test_parseIptablesInstruction(t *testing.T) {
instruction: iptablesInstruction{
table: "filter",
chain: "INPUT",
operation: opAppend,
append: true,
inputInterface: "tun0",
protocol: "tcp",
source: netip.MustParsePrefix("1.2.3.4/32"),
@@ -57,7 +57,7 @@ func Test_parseIptablesInstruction(t *testing.T) {
instruction: iptablesInstruction{
table: "nat",
chain: "PREROUTING",
operation: opDelete,
append: false,
inputInterface: "tun0",
protocol: "tcp",
destinationPort: 43716,
+2 -4
View File
@@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"net/netip"
"os"
)
type tcpFlags struct {
@@ -64,7 +63,7 @@ func parseTCPFlag(s string) (tcpFlag, error) {
return 0, fmt.Errorf("%w: %s", errTCPFlagUnknown, s)
}
var ErrMarkMatchModuleMissing = errors.New("libxt_mark.so module is missing")
var ErrMarkMatchModuleMissing = errors.New("kernel is missing the mark module libxt_mark.so")
// TempDropOutputTCPRST temporarily drops outgoing TCP RST packets to the specified address and port,
// for any TCP packets not marked with the excludeMark given.
@@ -74,8 +73,7 @@ func (c *Config) TempDropOutputTCPRST(ctx context.Context,
src, dst netip.AddrPort, excludeMark int) (
revert func(ctx context.Context) error, err error,
) {
_, err = os.Stat("/usr/lib/xtables/libxt_mark.so")
if err != nil && errors.Is(err, os.ErrNotExist) {
if !c.nftables && !c.xtMark {
return nil, fmt.Errorf("%w", ErrMarkMatchModuleMissing)
}
+99
View File
@@ -0,0 +1,99 @@
package nftables
import (
"context"
"fmt"
"github.com/google/nftables"
)
// SaveAndRestore saves the current nftables tree and returns a restore function that
// can be called to restore the saved tree.
func (f *Firewall) SaveAndRestore(_ context.Context) (restore func(context.Context), err error) {
f.mutex.Lock()
defer f.mutex.Unlock()
conn, err := nftables.New()
if err != nil {
return nil, fmt.Errorf("creating nftables connection: %w", err)
}
tables, err := saveTables(conn)
if err != nil {
return nil, fmt.Errorf("saving nftables state: %w", err)
}
return func(_ context.Context) {
conn, err := nftables.New()
if err != nil {
f.logger.Warnf("creating nftables connection for restore: %s", err)
return
}
err = restoreTables(conn, tables)
if err != nil {
f.logger.Warnf("restoring nftables state: %s", err)
}
}, nil
}
type savedTable struct {
table *nftables.Table
chains []savedChain
}
type savedChain struct {
chain *nftables.Chain
rules []*nftables.Rule
}
func saveTables(conn *nftables.Conn) ([]savedTable, error) {
tables, err := conn.ListTables()
if err != nil {
return nil, err
}
savedTables := make([]savedTable, len(tables))
for i, table := range tables {
savedTables[i].table = table
chains, err := conn.ListChains()
if err != nil {
return nil, err
}
for _, chain := range chains {
if chain.Table.Name != table.Name ||
chain.Table.Family != table.Family {
continue
}
rules, err := conn.GetRules(table, chain)
if err != nil {
return nil, fmt.Errorf("getting rules for chain %s in table %s: %w", chain.Name, table.Name, err)
}
savedChain := savedChain{chain: chain, rules: rules}
savedTables[i].chains = append(savedTables[i].chains, savedChain)
}
}
return savedTables, nil
}
func restoreTables(conn *nftables.Conn, savedTables []savedTable) error {
conn.FlushRuleset()
for _, savedTable := range savedTables {
table := conn.AddTable(savedTable.table)
for _, savedChain := range savedTable.chains {
// Make the [nftables.Chain.Table] points to the new [nftables.Table]
// created in this connection.
savedChain.chain.Table = table
savedChain.chain = conn.AddChain(savedChain.chain)
for _, rule := range savedChain.rules {
rule.Table = table
rule.Chain = savedChain.chain
conn.AddRule(rule)
}
}
}
return conn.Flush()
}
+50
View File
@@ -0,0 +1,50 @@
package nftables
import (
"context"
"errors"
"fmt"
"strings"
"github.com/google/nftables"
)
var ErrPolicyUnknown = errors.New("unknown policy")
// SetBaseChainsPolicy sets the policy of all the base chains (INPUT, FORWARD, or OUTPUT)
// for the filter table to the given policy (accept or drop).
func (f *Firewall) SetBaseChainsPolicy(_ context.Context, policy string) error {
f.mutex.Lock()
defer f.mutex.Unlock()
var chainPolicy nftables.ChainPolicy
switch strings.ToLower(policy) {
case "accept":
chainPolicy = nftables.ChainPolicyAccept
case "drop":
chainPolicy = nftables.ChainPolicyDrop
default:
return fmt.Errorf("%w: %s", ErrPolicyUnknown, policy)
}
conn, err := nftables.New()
if err != nil {
return fmt.Errorf("creating nftables connection: %w", err)
}
_, inputChain, forwardChain, outputChain := setupFilterWithBaseChains(conn)
inputChain.Policy = &chainPolicy
forwardChain.Policy = &chainPolicy
outputChain.Policy = &chainPolicy
conn.AddChain(inputChain)
conn.AddChain(forwardChain)
conn.AddChain(outputChain)
err = conn.Flush()
if err != nil {
return fmt.Errorf("flushing nftables changes: %w", err)
}
return nil
}
+61
View File
@@ -0,0 +1,61 @@
package nftables
import (
"context"
"fmt"
"github.com/google/nftables"
"github.com/google/nftables/expr"
)
func (f *Firewall) AcceptEstablishedRelatedTraffic(_ context.Context) error {
f.mutex.Lock()
defer f.mutex.Unlock()
conn, err := nftables.New()
if err != nil {
return fmt.Errorf("creating nftables connection: %w", err)
}
table, inputChain, _, outputChain := setupFilterWithBaseChains(conn)
ctStateExprs := []expr.Any{
&expr.Ct{
Key: expr.CtKeySTATE,
Register: 1,
},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4, //nolint:mnd
Mask: []byte{byte(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED), 0x00, 0x00, 0x00},
Xor: []byte{0x00, 0x00, 0x00, 0x00},
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0x00, 0x00, 0x00, 0x00},
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
conn.AddRule(&nftables.Rule{
Table: table,
Chain: inputChain,
Exprs: ctStateExprs,
})
conn.AddRule(&nftables.Rule{
Table: table,
Chain: outputChain,
Exprs: ctStateExprs,
})
if err := conn.Flush(); err != nil {
return fmt.Errorf("flushing: %w", err)
}
return nil
}
+27
View File
@@ -0,0 +1,27 @@
package nftables
import (
"errors"
"fmt"
"reflect"
"github.com/google/nftables"
)
var errRuleToDeleteNotFound = errors.New("rule not found for removal")
func (f *Firewall) deleteRule(conn *nftables.Conn, rule *nftables.Rule) error {
for i, existing := range f.rules {
if !reflect.DeepEqual(existing, rule) {
continue
}
err := conn.DelRule(existing)
if err != nil {
return fmt.Errorf("deleting rule: %w", err)
}
f.rules[i], f.rules[len(f.rules)-1] = f.rules[len(f.rules)-1], f.rules[i]
f.rules = f.rules[:len(f.rules)-1]
return nil
}
return fmt.Errorf("%w: %#v", errRuleToDeleteNotFound, rule)
}
+38
View File
@@ -0,0 +1,38 @@
package nftables
import "github.com/google/nftables"
func setupFilterWithBaseChains(conn *nftables.Conn) (table *nftables.Table,
inputChain, forwardChain, outputChain *nftables.Chain,
) {
table = conn.AddTable(&nftables.Table{
Family: nftables.TableFamilyINet,
Name: "filter",
})
inputChain = conn.AddChain(&nftables.Chain{
Name: "input",
Table: table,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookInput,
Priority: nftables.ChainPriorityFilter,
})
forwardChain = conn.AddChain(&nftables.Chain{
Name: "forward",
Table: table,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
})
outputChain = conn.AddChain(&nftables.Chain{
Name: "output",
Table: table,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookOutput,
Priority: nftables.ChainPriorityFilter,
})
return table, inputChain, forwardChain, outputChain
}
+22
View File
@@ -0,0 +1,22 @@
package nftables
import (
"sync"
"github.com/google/nftables"
)
type Firewall struct {
logger Logger
// rules are only rules added and tracked for later removal.
// Not all rules added are tracked for removal.
rules []*nftables.Rule
mutex sync.Mutex
}
func New(logger Logger) *Firewall {
return &Firewall{
logger: logger,
}
}
+170
View File
@@ -0,0 +1,170 @@
package nftables
import (
"context"
"fmt"
"net/netip"
"github.com/google/nftables"
"github.com/google/nftables/expr"
)
func (f *Firewall) AcceptInputThroughInterface(_ context.Context, intf string) error {
f.mutex.Lock()
defer f.mutex.Unlock()
conn, err := nftables.New()
if err != nil {
return fmt.Errorf("creating nftables connection: %w", err)
}
table, inputChain, _, _ := setupFilterWithBaseChains(conn)
rule := &nftables.Rule{
Table: table,
Chain: inputChain,
Exprs: []expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte(intf + "\x00"),
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
}
conn.AddRule(rule)
err = conn.Flush()
if err != nil {
return fmt.Errorf("flushing: %w", err)
}
return nil
}
// AcceptInputToPort accepts incoming traffic on the specified port, for both TCP and UDP
// protocols, on the interface intf. If intf is empty or "*", the interface is not used as a filter.
// If remove is true, the rule is removed instead of added. This is used for port forwarding, with
// intf set to the VPN tunnel interface.
func (f *Firewall) AcceptInputToPort(_ context.Context, intf string, port uint16, remove bool) error {
f.mutex.Lock()
defer f.mutex.Unlock()
conn, err := nftables.New()
if err != nil {
return fmt.Errorf("creating nftables connection: %w", err)
}
table, inputChain, _, _ := setupFilterWithBaseChains(conn)
portBytes := []byte{byte(port >> 8), byte(port)} //nolint:mnd
const tcp, udp uint8 = 6, 17
protocols := []uint8{tcp, udp}
for _, protocol := range protocols {
const maxExprsLen = 7
exprs := make([]expr.Any, 0, maxExprsLen)
if intf != "" && intf != "*" {
exprs = append(exprs,
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte(intf + "\x00")},
)
}
exprs = append(exprs,
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 9, Len: 1}, //nolint:mnd
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{protocol}},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2}, //nolint:mnd
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: portBytes},
&expr.Verdict{Kind: expr.VerdictAccept},
)
rule := &nftables.Rule{
Table: table,
Chain: inputChain,
Exprs: exprs,
}
if !remove {
conn.AddRule(rule)
f.rules = append(f.rules, rule)
continue
}
err = f.deleteRule(conn, rule)
if err != nil {
return fmt.Errorf("deleting rule: %w", err)
}
}
err = conn.Flush()
if err != nil {
f.rules = f.rules[:len(f.rules)-len(protocols)]
return fmt.Errorf("flushing: %w", err)
}
return nil
}
func (f *Firewall) AcceptInputToSubnet(_ context.Context, intf string, subnet netip.Prefix) error {
f.mutex.Lock()
defer f.mutex.Unlock()
conn, err := nftables.New()
if err != nil {
return fmt.Errorf("creating nftables connection: %w", err)
}
table, inputChain, _, _ := setupFilterWithBaseChains(conn)
const maxExprsLen = 5
exprs := make([]expr.Any, 0, maxExprsLen)
if intf != "" && intf != "*" {
exprs = append(exprs,
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte(intf + "\x00")},
)
}
var payloadOffset uint32
if subnet.Addr().Is4() {
payloadOffset = 16
} else {
payloadOffset = 24
}
exprs = append(exprs,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: payloadOffset,
Len: uint32(len(subnet.Addr().AsSlice())), //nolint:gosec
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: subnet.Addr().AsSlice(),
},
&expr.Verdict{Kind: expr.VerdictAccept},
)
rule := &nftables.Rule{
Table: table,
Chain: inputChain,
Exprs: exprs,
}
conn.AddRule(rule)
err = conn.Flush()
if err != nil {
return fmt.Errorf("flushing: %w", err)
}
return nil
}
+5
View File
@@ -0,0 +1,5 @@
package nftables
type Logger interface {
Warnf(format string, args ...any)
}
+78
View File
@@ -0,0 +1,78 @@
package nftables
import (
"context"
"fmt"
"github.com/google/nftables"
"github.com/google/nftables/expr"
)
func (f *Firewall) AcceptIpv6MulticastOutput(_ context.Context, intf string) error {
f.mutex.Lock()
defer f.mutex.Unlock()
conn, err := nftables.New()
if err != nil {
return fmt.Errorf("creating nftables connection: %w", err)
}
table, _, _, outputChain := setupFilterWithBaseChains(conn)
const maxExprsLen = 6
exprs := make([]expr.Any, 0, maxExprsLen)
if intf != "" && intf != "*" {
exprs = append(exprs,
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte(intf + "\x00")},
)
}
// ff02::1:ff00:0/104 mask is 13 bytes of 0xff
mask := []byte{
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00,
} //nolint:mnd
addr := []byte{
0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x01, 0xff, 0x00, 0x00, 0x00,
} //nolint:mnd
exprs = append(exprs,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 24, // IPv6 Destination Address offset //nolint:mnd
Len: 16, //nolint:mnd
},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 16, //nolint:mnd
Mask: mask,
Xor: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, //nolint:mnd
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: addr,
},
&expr.Verdict{Kind: expr.VerdictAccept},
)
rule := &nftables.Rule{
Table: table,
Chain: outputChain,
Exprs: exprs,
}
conn.AddRule(rule)
err = conn.Flush()
if err != nil {
return fmt.Errorf("flushing: %w", err)
}
return nil
}
+12
View File
@@ -0,0 +1,12 @@
package nftables
import "github.com/google/nftables"
func IsSupported() bool {
conn, err := nftables.New()
if err != nil {
return false
}
_, err = conn.ListTable("filter")
return err == nil
}
+18 -24
View File
@@ -1,44 +1,38 @@
package netlink
import (
"errors"
"fmt"
"github.com/mdlayher/netlink"
"github.com/ti-mo/netfilter"
"golang.org/x/sys/unix"
)
var ErrConntrackNetlinkNotSupported = errors.New("nf_conntrack_netlink is not supported by the kernel")
func (n *NetLink) FlushConntrack() error {
conn, err := netfilter.Dial(nil)
if err != nil {
if !n.conntrackNetlink {
err = fmt.Errorf("%w: %w", err, ErrConntrackNetlinkNotSupported)
}
return fmt.Errorf("dialing netfilter: %w", err)
}
defer conn.Close()
const ipCtnlMsgCtDelete = netfilter.MessageType(2)
header := netfilter.Header{
SubsystemID: netfilter.NFSubsysCTNetlink,
MessageType: ipCtnlMsgCtDelete,
Family: unix.AF_UNSPEC,
Flags: netlink.Request | netlink.Acknowledge,
}
request, err := netfilter.MarshalNetlink(header, nil)
if err != nil {
return fmt.Errorf("encoding netlink request: %w", err)
}
_, err = conn.Query(request)
if err != nil {
if !n.conntrackNetlink {
err = fmt.Errorf("%w: %w", err, ErrConntrackNetlinkNotSupported)
families := [...]netfilter.ProtoFamily{netfilter.ProtoIPv4, netfilter.ProtoIPv6}
for _, family := range families {
const IPCtnlMsgCtDelete = 2
request, err := netfilter.MarshalNetlink(
netfilter.Header{
SubsystemID: netfilter.NFSubsysCTNetlink,
MessageType: netfilter.MessageType(IPCtnlMsgCtDelete),
Family: family,
Flags: netlink.Request | netlink.Acknowledge,
},
nil)
if err != nil {
return fmt.Errorf("encoding netlink request: %w", err)
}
_, err = conn.Query(request)
if err != nil {
return fmt.Errorf("querying netlink request: %w", err)
}
return fmt.Errorf("querying netlink request: %w", err)
}
return nil
}
@@ -2,10 +2,6 @@
package netlink
import "errors"
var ErrConntrackNetlinkNotSupported = errors.New("error not implemented")
func (n *NetLink) FlushConntrack() error {
panic("not implemented")
}
+2 -10
View File
@@ -1,22 +1,14 @@
package netlink
import (
"github.com/qdm12/gluetun/internal/mod"
"github.com/qdm12/log"
)
import "github.com/qdm12/log"
type NetLink struct {
debugLogger DebugLogger
// Fixed state
conntrackNetlink bool
}
func New(debugLogger DebugLogger) *NetLink {
conntrackNetlink := mod.Probe("nf_conntrack_netlink") == nil
return &NetLink{
debugLogger: debugLogger,
conntrackNetlink: conntrackNetlink,
debugLogger: debugLogger,
}
}
+1 -1
View File
@@ -74,7 +74,7 @@ func PathMTUDiscover(ctx context.Context, icmpAddrs []netip.Addr, tcpAddrs []net
}
mtu, err = tcp.PathMTUDiscover(ctx, tcpAddrs, minMTU, maxPossibleMTU, tryTimeout, fw, logger)
if err != nil {
if errors.Is(err, iptables.ErrKernelModuleMissing) {
if errors.Is(err, iptables.ErrMarkMatchModuleMissing) {
logger.Debugf("aborting TCP path MTU discovery: %s", err)
if icmpSuccess {
return maxPossibleMTU, nil // only rely on ICMP PMTUD results
+1 -1
View File
@@ -35,7 +35,7 @@ func getFirewall(t *testing.T) *firewall.Config {
noopLogger := &noopLogger{}
cmder := command.New()
var err error
testFirewall, err = firewall.NewConfig(t.Context(), noopLogger, cmder, nil, nil, nil)
testFirewall, err = firewall.NewConfig(t.Context(), noopLogger, cmder, nil, nil)
if errors.Is(err, iptables.ErrNotSupported) {
t.Skip("iptables not installed, skipping TCP PMTUD tests")
}
+1 -1
View File
@@ -43,7 +43,7 @@ func findHighestMSSDestination(ctx context.Context, familyToFD map[int]fileDescr
if result.err != nil {
switch {
case err != nil: // error already occurred for another findMSS goroutine
case errors.Is(result.err, iptables.ErrKernelModuleMissing):
case errors.Is(result.err, iptables.ErrMarkMatchModuleMissing):
err = fmt.Errorf("finding MSS for %s: %w", result.dst, result.err)
case dst.Addr().Is6() && errors.Is(result.err, ip.ErrNetworkUnreachable):
// silently discard IPv6 network unreachable errors since they are common