From a37354426b786599d2418c87b1827c429d0d2457 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Thu, 26 Feb 2026 15:53:07 +0000 Subject: [PATCH] Fallback to accepting only NEW output public traffic if conntrack netlink isn't supported --- internal/firewall/enable.go | 10 +-- internal/firewall/flush.go | 27 +++++++ internal/firewall/interfaces.go | 4 +- internal/firewall/iptables/firewall.go | 5 ++ internal/firewall/iptables/iptables.go | 73 +++++++++++++++++ internal/firewall/iptables/kernel.go | 47 +++++++++++ internal/firewall/iptables/list.go | 35 +++++++-- internal/firewall/iptables/parse.go | 96 ++++++++++++++++++++--- internal/firewall/iptables/parse_test.go | 10 +-- internal/firewall/iptables/tcp.go | 2 +- internal/netlink/conntrack_linux.go | 7 ++ internal/netlink/conntrack_unspecified.go | 4 + internal/netlink/netlink.go | 12 ++- internal/pmtud/pmtud.go | 2 +- internal/pmtud/tcp/helpers_test.go | 2 +- internal/pmtud/tcp/mss.go | 2 +- 16 files changed, 302 insertions(+), 36 deletions(-) create mode 100644 internal/firewall/flush.go create mode 100644 internal/firewall/iptables/kernel.go diff --git a/internal/firewall/enable.go b/internal/firewall/enable.go index 28328e26..71fa67c3 100644 --- a/internal/firewall/enable.go +++ b/internal/firewall/enable.go @@ -69,6 +69,11 @@ func (c *Config) enable(ctx context.Context) (err error) { return err } + err = c.flushExistingConnections(ctx) + if err != nil { + return fmt.Errorf("flushing existing connections: %w", err) + } + if err = c.impl.AcceptEstablishedRelatedTraffic(ctx); err != nil { return err } @@ -121,11 +126,6 @@ func (c *Config) enable(ctx context.Context) (err error) { return fmt.Errorf("running user defined post firewall rules: %w", err) } - err = c.netlinker.FlushConntrack() - if err != nil { - c.logger.Warn("flushing conntrack failed: " + err.Error()) - } - return nil } diff --git a/internal/firewall/flush.go b/internal/firewall/flush.go new file mode 100644 index 00000000..010927d1 --- /dev/null +++ b/internal/firewall/flush.go @@ -0,0 +1,27 @@ +package firewall + +import ( + "context" + "errors" + "fmt" + + "github.com/qdm12/gluetun/internal/netlink" +) + +// Note remove is a no-op if conntrack netlink is supported by the kernel. +func (c *Config) flushExistingConnections(ctx context.Context) error { + err := c.netlinker.FlushConntrack() + switch { + case err == nil: + return nil + case errors.Is(err, netlink.ErrConntrackNetlinkNotSupported): + c.logger.Debugf("falling back to marking and filtering unmarked packets because flush conntrack failed: %s", err) + err = c.impl.AcceptOutputPublicOnlyNewTraffic(ctx) + if err != nil { + return fmt.Errorf("accepting only new output public traffic: %w", err) + } + return nil + default: + return fmt.Errorf("flushing conntrack: %w", err) + } +} diff --git a/internal/firewall/interfaces.go b/internal/firewall/interfaces.go index 3352b3bb..74a2afc3 100644 --- a/internal/firewall/interfaces.go +++ b/internal/firewall/interfaces.go @@ -14,6 +14,7 @@ type CmdRunner interface { type Logger interface { Debug(s string) + Debugf(format string, args ...any) Info(s string) Warn(s string) Error(s string) @@ -25,8 +26,9 @@ type Netlinker interface { type firewallImpl interface { //nolint:interfacebloat SaveAndRestore(ctx context.Context) (restore func(context.Context), err error) - AcceptEstablishedRelatedTraffic(ctx context.Context) error + AcceptOutputPublicOnlyNewTraffic(ctx context.Context) error AcceptInputThroughInterface(ctx context.Context, intf string) error + AcceptEstablishedRelatedTraffic(ctx context.Context) error AcceptInputToPort(ctx context.Context, intf string, port uint16, remove bool) error AcceptInputToSubnet(ctx context.Context, intf string, subnet netip.Prefix) error AcceptIpv6MulticastOutput(ctx context.Context, intf string) error diff --git a/internal/firewall/iptables/firewall.go b/internal/firewall/iptables/firewall.go index aeedae63..c1e8a59f 100644 --- a/internal/firewall/iptables/firewall.go +++ b/internal/firewall/iptables/firewall.go @@ -2,9 +2,12 @@ package iptables import ( "context" + "errors" "sync" ) +var ErrKernelModuleMissing = errors.New("kernel module is missing for this operation") + type Config struct { runner CmdRunner logger Logger @@ -14,6 +17,7 @@ type Config struct { // Fixed state ipTables string ip6Tables string + modules kernelModules } func New(ctx context.Context, runner CmdRunner, logger Logger) (*Config, error) { @@ -32,5 +36,6 @@ func New(ctx context.Context, runner CmdRunner, logger Logger) (*Config, error) logger: logger, ipTables: iptables, ip6Tables: ip6tables, + modules: newKernelModules(), }, nil } diff --git a/internal/firewall/iptables/iptables.go b/internal/firewall/iptables/iptables.go index 486e9fab..8b873804 100644 --- a/internal/firewall/iptables/iptables.go +++ b/internal/firewall/iptables/iptables.go @@ -141,12 +141,85 @@ func (c *Config) AcceptOutputThroughInterface(ctx context.Context, intf string, } func (c *Config) AcceptEstablishedRelatedTraffic(ctx context.Context) error { + if !c.modules.nfConntrack.ok { + return fmt.Errorf("%w: %s", ErrKernelModuleMissing, c.modules.nfConntrack.name) + } return c.runMixedIptablesInstructions(ctx, []string{ "--append OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", "--append INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", }) } +// AcceptOutputPublicOnlyNewTraffic adds rules to mark new output connections, and to accept +// established or related packets with this mark only. This effectively forces +// previously established or related traffic to be blocked. +// If remove is true, the rules are removed instead of appended. +// If the relevant kernel modules (nf_conntrack, xt_conntrack and xt_connmark) +// are not available, it returns an error indicating which kernel module is missing. +func (c *Config) AcceptOutputPublicOnlyNewTraffic(ctx context.Context) error { + err := checkKernelModulesAreOK(c.modules.nfConntrack, c.modules.xtConntrack, c.modules.xtConnmark) + if err != nil { + return fmt.Errorf("checking kernel modules: %w", err) + } + + ipv4PrivatePrefixes := []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/8"), + netip.MustParsePrefix("172.16.0.0/12"), + netip.MustParsePrefix("192.168.0.0/16"), + netip.MustParsePrefix("127.0.0.0/8"), + } + ipv6PrivatePrefixes := []netip.Prefix{ + netip.MustParsePrefix("fc00::/7"), + netip.MustParsePrefix("fe80::/10"), + netip.MustParsePrefix("::1/128"), + } + var ipv4Instructions, ipv6Instructions []string //nolint:prealloc + appendToBoth := func(instruction string) { + ipv4Instructions = append(ipv4Instructions, instruction) + ipv6Instructions = append(ipv6Instructions, instruction) + } + appendToBoth("-N PUBLIC_ONLY") + for _, prefix := range ipv4PrivatePrefixes { + ipv4Instructions = append(ipv4Instructions, fmt.Sprintf( + "-A PUBLIC_ONLY -d %s -j RETURN", prefix)) + } + for _, prefix := range ipv6PrivatePrefixes { + ipv6Instructions = append(ipv6Instructions, fmt.Sprintf( + "-A PUBLIC_ONLY -d %s -j RETURN", prefix)) + } + // Mark new connections with mark 0x567 + appendToBoth("-A PUBLIC_ONLY -m conntrack --ctstate NEW -j CONNMARK --set-mark 0x567") + // Drop related/established connections that made it through; marked connections would + // be directly accepted by the first rule in the OUTPUT chain (see below) + appendToBoth("-A PUBLIC_ONLY -m conntrack --ctstate RELATED,ESTABLISHED -j DROP") + // Set the PUBLIC_ONLY chain as the second rule in the OUTPUT chain, so that it is evaluated + // after the accept rule below, for performance reasons. + appendToBoth("-I OUTPUT -j PUBLIC_ONLY") + appendToBoth("-I OUTPUT -m conntrack --ctstate RELATED,ESTABLISHED -m connmark --mark 0x567 -j ACCEPT") + + c.iptablesMutex.Lock() + c.ip6tablesMutex.Lock() + defer c.iptablesMutex.Unlock() + defer c.ip6tablesMutex.Unlock() + + restore, err := c.saveAndRestore(ctx) + if err != nil { + return err + } + + err = c.runIptablesInstructionsNoSave(ctx, ipv4Instructions) + if err != nil { + restore(ctx) + return err + } + err = c.runIP6tablesInstructionsNoSave(ctx, ipv6Instructions) + if err != nil { + restore(ctx) + return err + } + return nil +} + func (c *Config) AcceptOutputTrafficToVPN(ctx context.Context, defaultInterface string, connection models.Connection, remove bool, ) error { diff --git a/internal/firewall/iptables/kernel.go b/internal/firewall/iptables/kernel.go new file mode 100644 index 00000000..5b506e42 --- /dev/null +++ b/internal/firewall/iptables/kernel.go @@ -0,0 +1,47 @@ +package iptables + +import ( + "fmt" + "strings" + + "github.com/qdm12/gluetun/internal/mod" +) + +type kernelModules struct { + nfConntrack kernelModule + xtConnmark kernelModule + xtConntrack kernelModule +} + +type kernelModule struct { + name string + ok bool +} + +func newKernelModules() kernelModules { + var m kernelModules + nameToFieldPtr := map[string]*kernelModule{ + "nf_conntrack_netlink": &m.nfConntrack, + "xt_connmark": &m.xtConnmark, + "xt_conntrack": &m.xtConntrack, + } + for name, fieldPtr := range nameToFieldPtr { + fieldPtr.name = name + err := mod.Probe(name) + fieldPtr.ok = err == nil + } + return m +} + +func checkKernelModulesAreOK(modules ...kernelModule) error { + missing := make([]string, 0, len(modules)) + for _, module := range modules { + if !module.ok { + missing = append(missing, module.name) + } + } + if len(missing) > 0 { + return fmt.Errorf("%w: %s", ErrKernelModuleMissing, strings.Join(missing, ", ")) + } + return nil +} diff --git a/internal/firewall/iptables/list.go b/internal/firewall/iptables/list.go index 49f855fe..38d918c5 100644 --- a/internal/firewall/iptables/list.go +++ b/internal/firewall/iptables/list.go @@ -33,6 +33,8 @@ type chainRule struct { ctstate []string // for example ["RELATED","ESTABLISHED"]. Can be empty. tcpFlags tcpFlags mark mark + connMark mark + setMark uint } type mark struct { @@ -293,6 +295,29 @@ func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err } rule.mark = mark i += consumed + case "connmark": + i++ + connMark, consumed, err := parseMark(optionalFields[i:]) + if err != nil { + return fmt.Errorf("parsing connmark: %w", err) + } + rule.connMark = connMark + i += consumed + case "CONNMARK": + i++ + switch optionalFields[i] { + case "set": + i++ + value, err := parseAny32bNumber(optionalFields[i]) + if err != nil { + return fmt.Errorf("parsing CONNMARK set value: %w", err) + } + rule.setMark = value + i++ + default: + return fmt.Errorf("%w: unexpected %q after CONNMARK", + ErrChainRuleMalformed, optionalFields[i]) + } default: return fmt.Errorf("%w: unexpected optional field: %s", ErrChainRuleMalformed, optionalFields[i]) @@ -422,8 +447,6 @@ func parsePortsCSV(s string) (ports []uint16, err error) { return ports, nil } -var errMarkValueMalformed = errors.New("mark value is malformed") - func parseMark(optionalFields []string) (m mark, consumed int, err error) { switch optionalFields[consumed] { case "match": @@ -433,13 +456,11 @@ func parseMark(optionalFields []string) (m mark, consumed int, err error) { consumed++ } - const base = 0 // auto-detect - const bits = 32 - value, err := strconv.ParseUint(optionalFields[consumed], base, bits) + value, err := parseAny32bNumber(optionalFields[consumed]) if err != nil { - return mark{}, 0, fmt.Errorf("%w: %s", errMarkValueMalformed, optionalFields[consumed]) + return mark{}, 0, fmt.Errorf("value malformed: %w", err) } - m.value = uint(value) + m.value = value consumed++ default: return mark{}, 0, fmt.Errorf("%w: unexpected mark mode field: %s", diff --git a/internal/firewall/iptables/parse.go b/internal/firewall/iptables/parse.go index a18d7714..fda21aaf 100644 --- a/internal/firewall/iptables/parse.go +++ b/internal/firewall/iptables/parse.go @@ -9,9 +9,19 @@ import ( "strings" ) +type operation uint8 + +const ( + opNone operation = iota + opAppend + opDelete + opInsert + opReplace +) + type iptablesInstruction struct { table string // defaults to "filter", and can be "nat" for example. - append bool + operation operation chain string // for example INPUT, PREROUTING. Cannot be empty. target string // for example ACCEPT. Can be empty. protocol string // "tcp" or "udp" or "" for all protocols. @@ -25,6 +35,8 @@ type iptablesInstruction struct { ctstate []string // if empty, there is no ctstate tcpFlags tcpFlags mark mark + connMark mark + setMark uint // only used for jump CONNMARK --set-mark } func (i *iptablesInstruction) setDefaults() { @@ -65,6 +77,10 @@ func (i *iptablesInstruction) equalToRule(table, chain string, rule chainRule) ( return false case i.mark != rule.mark: return false + case i.connMark != rule.connMark: + return false + case i.setMark != rule.setMark: + return false default: return true } @@ -113,13 +129,20 @@ func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (co case "-t", "--table": instruction.table = value case "-D", "--delete": - instruction.append = false + instruction.operation = opDelete instruction.chain = value case "-A", "--append": - instruction.append = true + instruction.operation = opAppend + instruction.chain = value + case "-I", "--insert": + instruction.operation = opInsert instruction.chain = value case "-j", "--jump": - instruction.target = value + subConsumed, err := parseJumpFlag(fields[1:], instruction) + if err != nil { + return 0, fmt.Errorf("parsing jump flag: %w", err) + } + consumed += subConsumed case "-p", "--protocol": instruction.protocol = value case "-m", "--match": @@ -128,13 +151,11 @@ func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (co return 0, fmt.Errorf("parsing match module: %w", err) } case "--mark": - const base = 0 // auto-detect - const bits = 32 - value, err := strconv.ParseUint(value, base, bits) + n, err := parseAny32bNumber(value) if err != nil { - return 0, fmt.Errorf("parsing mark value %q: %w", fields[2], err) + return 0, fmt.Errorf("parsing mark value %q: %w", value, err) } - instruction.mark.value = uint(value) + instruction.mark.value = n case "-i", "--in-interface": instruction.inputInterface = value case "-o", "--out-interface": @@ -182,7 +203,7 @@ func preCheckInstructionFields(fields []string) (consumed int, err error) { flag := fields[0] // All flags use one value after the flag, except the following: switch flag { - case "--tcp-flags": // -m can have 1 or 2 values + case "--tcp-flags": const expected = 3 if len(fields) < expected { return 0, fmt.Errorf("%w: flag %q requires at least 2 values, but got %s", @@ -199,6 +220,34 @@ func preCheckInstructionFields(fields []string) (consumed int, err error) { } } +func parseJumpFlag(fields []string, instruction *iptablesInstruction) (consumed int, err error) { + instruction.target = fields[0] + // consumed in the caller already takes fields[0] into account + if instruction.target != "CONNMARK" { + return consumed, nil + } + // consumed already accounts for the "CONNMARK" value + const expectedFields = 3 + if len(fields) < expectedFields { + return 0, fmt.Errorf("%w: jump CONNMARK requires at least two additional values", + ErrIptablesCommandMalformed) + } + switch fields[1] { + case "--set-mark": + n, err := parseAny32bNumber(fields[2]) + if err != nil { + return 0, fmt.Errorf("parsing connmark mark value %q: %w", fields[2], err) + } + consumed++ + instruction.setMark = n + default: + return consumed, fmt.Errorf("%w: unsupported jump CONNMARK with value: %s", + ErrIptablesCommandMalformed, fields[1]) + } + consumed++ + return consumed, nil +} + func parseIPPrefix(value string) (prefix netip.Prefix, err error) { slashIndex := strings.Index(value, "/") if slashIndex >= 0 { @@ -221,6 +270,13 @@ func parsePort(value string) (port uint16, err error) { return uint16(portValue), nil } +func parseAny32bNumber(mark string) (value uint, err error) { + const base = 0 // auto-detect + const bits = 32 + n, err := strconv.ParseUint(mark, base, bits) + return uint(n), err +} + func parseMatchModule(fields []string, instruction *iptablesInstruction) ( consumed int, err error, ) { @@ -234,14 +290,30 @@ func parseMatchModule(fields []string, instruction *iptablesInstruction) ( // parse it twice. case "mark": consumed++ - switch fields[consumed] { - case "!": + switch { + case len(fields[consumed:]) == 0 || strings.HasPrefix(fields[consumed], "-"): + // end or another flag + return consumed, nil + case fields[consumed] == "!": consumed++ instruction.mark.invert = true default: return consumed, fmt.Errorf("%w: unsupported match mark with value: %s", ErrIptablesCommandMalformed, fields[2]) } + case "connmark": + consumed++ + switch { + case len(fields[consumed:]) == 0 || strings.HasPrefix(fields[consumed], "-"): + // end or another flag + return consumed, nil + case fields[consumed] == "!": + consumed++ + instruction.connMark.invert = true + default: + return consumed, fmt.Errorf("%w: unsupported match connmark with value: %s", + ErrIptablesCommandMalformed, fields[2]) + } default: return 0, fmt.Errorf("%w: unknown match value: %s", ErrIptablesCommandMalformed, fields[consumed]) diff --git a/internal/firewall/iptables/parse_test.go b/internal/firewall/iptables/parse_test.go index 51c5ab7c..cdf66b83 100644 --- a/internal/firewall/iptables/parse_test.go +++ b/internal/firewall/iptables/parse_test.go @@ -33,9 +33,9 @@ func Test_parseIptablesInstruction(t *testing.T) { "one_pair": { s: "-A INPUT", instruction: iptablesInstruction{ - table: "filter", - chain: "INPUT", - append: true, + table: "filter", + chain: "INPUT", + operation: opAppend, }, }, "instruction_A": { @@ -43,7 +43,7 @@ func Test_parseIptablesInstruction(t *testing.T) { instruction: iptablesInstruction{ table: "filter", chain: "INPUT", - append: true, + operation: opAppend, inputInterface: "tun0", protocol: "tcp", source: netip.MustParsePrefix("1.2.3.4/32"), @@ -57,7 +57,7 @@ func Test_parseIptablesInstruction(t *testing.T) { instruction: iptablesInstruction{ table: "nat", chain: "PREROUTING", - append: false, + operation: opDelete, inputInterface: "tun0", protocol: "tcp", destinationPort: 43716, diff --git a/internal/firewall/iptables/tcp.go b/internal/firewall/iptables/tcp.go index 77e5c5f2..99dbde36 100644 --- a/internal/firewall/iptables/tcp.go +++ b/internal/firewall/iptables/tcp.go @@ -64,7 +64,7 @@ func parseTCPFlag(s string) (tcpFlag, error) { return 0, fmt.Errorf("%w: %s", errTCPFlagUnknown, s) } -var ErrMarkMatchModuleMissing = errors.New("kernel is missing the mark module libxt_mark.so") +var ErrMarkMatchModuleMissing = errors.New("libxt_mark.so module is missing") // TempDropOutputTCPRST temporarily drops outgoing TCP RST packets to the specified address and port, // for any TCP packets not marked with the excludeMark given. diff --git a/internal/netlink/conntrack_linux.go b/internal/netlink/conntrack_linux.go index 868e25e8..08832077 100644 --- a/internal/netlink/conntrack_linux.go +++ b/internal/netlink/conntrack_linux.go @@ -1,6 +1,7 @@ package netlink import ( + "errors" "fmt" "github.com/mdlayher/netlink" @@ -8,7 +9,13 @@ import ( "golang.org/x/sys/unix" ) +var ErrConntrackNetlinkNotSupported = errors.New("nf_conntrack_netlink is not supported by the kernel") + func (n *NetLink) FlushConntrack() error { + if !n.conntrackNetlink { + return fmt.Errorf("%w", ErrConntrackNetlinkNotSupported) + } + conn, err := netfilter.Dial(nil) if err != nil { return fmt.Errorf("dialing netfilter: %w", err) diff --git a/internal/netlink/conntrack_unspecified.go b/internal/netlink/conntrack_unspecified.go index d2652b26..5ed2a24a 100644 --- a/internal/netlink/conntrack_unspecified.go +++ b/internal/netlink/conntrack_unspecified.go @@ -2,6 +2,10 @@ package netlink +import "errors" + +var ErrConntrackNetlinkNotSupported = errors.New("error not implemented") + func (n *NetLink) FlushConntrack() error { panic("not implemented") } diff --git a/internal/netlink/netlink.go b/internal/netlink/netlink.go index 9a26ab08..bc295082 100644 --- a/internal/netlink/netlink.go +++ b/internal/netlink/netlink.go @@ -1,14 +1,22 @@ package netlink -import "github.com/qdm12/log" +import ( + "github.com/qdm12/gluetun/internal/mod" + "github.com/qdm12/log" +) type NetLink struct { debugLogger DebugLogger + + // Fixed state + conntrackNetlink bool } func New(debugLogger DebugLogger) *NetLink { + conntrackNetlink := mod.Probe("nf_conntrack_netlink") == nil return &NetLink{ - debugLogger: debugLogger, + debugLogger: debugLogger, + conntrackNetlink: conntrackNetlink, } } diff --git a/internal/pmtud/pmtud.go b/internal/pmtud/pmtud.go index 4505c0e9..63fe6715 100644 --- a/internal/pmtud/pmtud.go +++ b/internal/pmtud/pmtud.go @@ -74,7 +74,7 @@ func PathMTUDiscover(ctx context.Context, icmpAddrs []netip.Addr, tcpAddrs []net } mtu, err = tcp.PathMTUDiscover(ctx, tcpAddrs, minMTU, maxPossibleMTU, tryTimeout, fw, logger) if err != nil { - if errors.Is(err, iptables.ErrMarkMatchModuleMissing) { + if errors.Is(err, iptables.ErrKernelModuleMissing) { logger.Debugf("aborting TCP path MTU discovery: %s", err) if icmpSuccess { return maxPossibleMTU, nil // only rely on ICMP PMTUD results diff --git a/internal/pmtud/tcp/helpers_test.go b/internal/pmtud/tcp/helpers_test.go index e5a21e05..e0ad925d 100644 --- a/internal/pmtud/tcp/helpers_test.go +++ b/internal/pmtud/tcp/helpers_test.go @@ -35,7 +35,7 @@ func getFirewall(t *testing.T) *firewall.Config { noopLogger := &noopLogger{} cmder := command.New() var err error - testFirewall, err = firewall.NewConfig(t.Context(), noopLogger, cmder, nil, nil) + testFirewall, err = firewall.NewConfig(t.Context(), noopLogger, cmder, nil, nil, nil) if errors.Is(err, iptables.ErrNotSupported) { t.Skip("iptables not installed, skipping TCP PMTUD tests") } diff --git a/internal/pmtud/tcp/mss.go b/internal/pmtud/tcp/mss.go index bedc2a5d..f92467aa 100644 --- a/internal/pmtud/tcp/mss.go +++ b/internal/pmtud/tcp/mss.go @@ -43,7 +43,7 @@ func findHighestMSSDestination(ctx context.Context, familyToFD map[int]fileDescr if result.err != nil { switch { case err != nil: // error already occurred for another findMSS goroutine - case errors.Is(result.err, iptables.ErrMarkMatchModuleMissing): + case errors.Is(result.err, iptables.ErrKernelModuleMissing): err = fmt.Errorf("finding MSS for %s: %w", result.dst, result.err) case dst.Addr().Is6() && errors.Is(result.err, ip.ErrNetworkUnreachable): // silently discard IPv6 network unreachable errors since they are common