Fallback to accepting only NEW output public traffic if conntrack netlink isn't supported

This commit is contained in:
Quentin McGaw
2026-02-26 15:53:07 +00:00
parent dfac2b2f1a
commit a37354426b
16 changed files with 302 additions and 36 deletions
+5 -5
View File
@@ -69,6 +69,11 @@ func (c *Config) enable(ctx context.Context) (err error) {
return err
}
err = c.flushExistingConnections(ctx)
if err != nil {
return fmt.Errorf("flushing existing connections: %w", err)
}
if err = c.impl.AcceptEstablishedRelatedTraffic(ctx); err != nil {
return err
}
@@ -121,11 +126,6 @@ func (c *Config) enable(ctx context.Context) (err error) {
return fmt.Errorf("running user defined post firewall rules: %w", err)
}
err = c.netlinker.FlushConntrack()
if err != nil {
c.logger.Warn("flushing conntrack failed: " + err.Error())
}
return nil
}
+27
View File
@@ -0,0 +1,27 @@
package firewall
import (
"context"
"errors"
"fmt"
"github.com/qdm12/gluetun/internal/netlink"
)
// Note remove is a no-op if conntrack netlink is supported by the kernel.
func (c *Config) flushExistingConnections(ctx context.Context) error {
err := c.netlinker.FlushConntrack()
switch {
case err == nil:
return nil
case errors.Is(err, netlink.ErrConntrackNetlinkNotSupported):
c.logger.Debugf("falling back to marking and filtering unmarked packets because flush conntrack failed: %s", err)
err = c.impl.AcceptOutputPublicOnlyNewTraffic(ctx)
if err != nil {
return fmt.Errorf("accepting only new output public traffic: %w", err)
}
return nil
default:
return fmt.Errorf("flushing conntrack: %w", err)
}
}
+3 -1
View File
@@ -14,6 +14,7 @@ type CmdRunner interface {
type Logger interface {
Debug(s string)
Debugf(format string, args ...any)
Info(s string)
Warn(s string)
Error(s string)
@@ -25,8 +26,9 @@ type Netlinker interface {
type firewallImpl interface { //nolint:interfacebloat
SaveAndRestore(ctx context.Context) (restore func(context.Context), err error)
AcceptEstablishedRelatedTraffic(ctx context.Context) error
AcceptOutputPublicOnlyNewTraffic(ctx context.Context) error
AcceptInputThroughInterface(ctx context.Context, intf string) error
AcceptEstablishedRelatedTraffic(ctx context.Context) error
AcceptInputToPort(ctx context.Context, intf string, port uint16, remove bool) error
AcceptInputToSubnet(ctx context.Context, intf string, subnet netip.Prefix) error
AcceptIpv6MulticastOutput(ctx context.Context, intf string) error
+5
View File
@@ -2,9 +2,12 @@ package iptables
import (
"context"
"errors"
"sync"
)
var ErrKernelModuleMissing = errors.New("kernel module is missing for this operation")
type Config struct {
runner CmdRunner
logger Logger
@@ -14,6 +17,7 @@ type Config struct {
// Fixed state
ipTables string
ip6Tables string
modules kernelModules
}
func New(ctx context.Context, runner CmdRunner, logger Logger) (*Config, error) {
@@ -32,5 +36,6 @@ func New(ctx context.Context, runner CmdRunner, logger Logger) (*Config, error)
logger: logger,
ipTables: iptables,
ip6Tables: ip6tables,
modules: newKernelModules(),
}, nil
}
+73
View File
@@ -141,12 +141,85 @@ func (c *Config) AcceptOutputThroughInterface(ctx context.Context, intf string,
}
func (c *Config) AcceptEstablishedRelatedTraffic(ctx context.Context) error {
if !c.modules.nfConntrack.ok {
return fmt.Errorf("%w: %s", ErrKernelModuleMissing, c.modules.nfConntrack.name)
}
return c.runMixedIptablesInstructions(ctx, []string{
"--append OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT",
"--append INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT",
})
}
// AcceptOutputPublicOnlyNewTraffic adds rules to mark new output connections, and to accept
// established or related packets with this mark only. This effectively forces
// previously established or related traffic to be blocked.
// If remove is true, the rules are removed instead of appended.
// If the relevant kernel modules (nf_conntrack, xt_conntrack and xt_connmark)
// are not available, it returns an error indicating which kernel module is missing.
func (c *Config) AcceptOutputPublicOnlyNewTraffic(ctx context.Context) error {
err := checkKernelModulesAreOK(c.modules.nfConntrack, c.modules.xtConntrack, c.modules.xtConnmark)
if err != nil {
return fmt.Errorf("checking kernel modules: %w", err)
}
ipv4PrivatePrefixes := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("172.16.0.0/12"),
netip.MustParsePrefix("192.168.0.0/16"),
netip.MustParsePrefix("127.0.0.0/8"),
}
ipv6PrivatePrefixes := []netip.Prefix{
netip.MustParsePrefix("fc00::/7"),
netip.MustParsePrefix("fe80::/10"),
netip.MustParsePrefix("::1/128"),
}
var ipv4Instructions, ipv6Instructions []string //nolint:prealloc
appendToBoth := func(instruction string) {
ipv4Instructions = append(ipv4Instructions, instruction)
ipv6Instructions = append(ipv6Instructions, instruction)
}
appendToBoth("-N PUBLIC_ONLY")
for _, prefix := range ipv4PrivatePrefixes {
ipv4Instructions = append(ipv4Instructions, fmt.Sprintf(
"-A PUBLIC_ONLY -d %s -j RETURN", prefix))
}
for _, prefix := range ipv6PrivatePrefixes {
ipv6Instructions = append(ipv6Instructions, fmt.Sprintf(
"-A PUBLIC_ONLY -d %s -j RETURN", prefix))
}
// Mark new connections with mark 0x567
appendToBoth("-A PUBLIC_ONLY -m conntrack --ctstate NEW -j CONNMARK --set-mark 0x567")
// Drop related/established connections that made it through; marked connections would
// be directly accepted by the first rule in the OUTPUT chain (see below)
appendToBoth("-A PUBLIC_ONLY -m conntrack --ctstate RELATED,ESTABLISHED -j DROP")
// Set the PUBLIC_ONLY chain as the second rule in the OUTPUT chain, so that it is evaluated
// after the accept rule below, for performance reasons.
appendToBoth("-I OUTPUT -j PUBLIC_ONLY")
appendToBoth("-I OUTPUT -m conntrack --ctstate RELATED,ESTABLISHED -m connmark --mark 0x567 -j ACCEPT")
c.iptablesMutex.Lock()
c.ip6tablesMutex.Lock()
defer c.iptablesMutex.Unlock()
defer c.ip6tablesMutex.Unlock()
restore, err := c.saveAndRestore(ctx)
if err != nil {
return err
}
err = c.runIptablesInstructionsNoSave(ctx, ipv4Instructions)
if err != nil {
restore(ctx)
return err
}
err = c.runIP6tablesInstructionsNoSave(ctx, ipv6Instructions)
if err != nil {
restore(ctx)
return err
}
return nil
}
func (c *Config) AcceptOutputTrafficToVPN(ctx context.Context,
defaultInterface string, connection models.Connection, remove bool,
) error {
+47
View File
@@ -0,0 +1,47 @@
package iptables
import (
"fmt"
"strings"
"github.com/qdm12/gluetun/internal/mod"
)
type kernelModules struct {
nfConntrack kernelModule
xtConnmark kernelModule
xtConntrack kernelModule
}
type kernelModule struct {
name string
ok bool
}
func newKernelModules() kernelModules {
var m kernelModules
nameToFieldPtr := map[string]*kernelModule{
"nf_conntrack_netlink": &m.nfConntrack,
"xt_connmark": &m.xtConnmark,
"xt_conntrack": &m.xtConntrack,
}
for name, fieldPtr := range nameToFieldPtr {
fieldPtr.name = name
err := mod.Probe(name)
fieldPtr.ok = err == nil
}
return m
}
func checkKernelModulesAreOK(modules ...kernelModule) error {
missing := make([]string, 0, len(modules))
for _, module := range modules {
if !module.ok {
missing = append(missing, module.name)
}
}
if len(missing) > 0 {
return fmt.Errorf("%w: %s", ErrKernelModuleMissing, strings.Join(missing, ", "))
}
return nil
}
+28 -7
View File
@@ -33,6 +33,8 @@ type chainRule struct {
ctstate []string // for example ["RELATED","ESTABLISHED"]. Can be empty.
tcpFlags tcpFlags
mark mark
connMark mark
setMark uint
}
type mark struct {
@@ -293,6 +295,29 @@ func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err
}
rule.mark = mark
i += consumed
case "connmark":
i++
connMark, consumed, err := parseMark(optionalFields[i:])
if err != nil {
return fmt.Errorf("parsing connmark: %w", err)
}
rule.connMark = connMark
i += consumed
case "CONNMARK":
i++
switch optionalFields[i] {
case "set":
i++
value, err := parseAny32bNumber(optionalFields[i])
if err != nil {
return fmt.Errorf("parsing CONNMARK set value: %w", err)
}
rule.setMark = value
i++
default:
return fmt.Errorf("%w: unexpected %q after CONNMARK",
ErrChainRuleMalformed, optionalFields[i])
}
default:
return fmt.Errorf("%w: unexpected optional field: %s",
ErrChainRuleMalformed, optionalFields[i])
@@ -422,8 +447,6 @@ func parsePortsCSV(s string) (ports []uint16, err error) {
return ports, nil
}
var errMarkValueMalformed = errors.New("mark value is malformed")
func parseMark(optionalFields []string) (m mark, consumed int, err error) {
switch optionalFields[consumed] {
case "match":
@@ -433,13 +456,11 @@ func parseMark(optionalFields []string) (m mark, consumed int, err error) {
consumed++
}
const base = 0 // auto-detect
const bits = 32
value, err := strconv.ParseUint(optionalFields[consumed], base, bits)
value, err := parseAny32bNumber(optionalFields[consumed])
if err != nil {
return mark{}, 0, fmt.Errorf("%w: %s", errMarkValueMalformed, optionalFields[consumed])
return mark{}, 0, fmt.Errorf("value malformed: %w", err)
}
m.value = uint(value)
m.value = value
consumed++
default:
return mark{}, 0, fmt.Errorf("%w: unexpected mark mode field: %s",
+84 -12
View File
@@ -9,9 +9,19 @@ import (
"strings"
)
type operation uint8
const (
opNone operation = iota
opAppend
opDelete
opInsert
opReplace
)
type iptablesInstruction struct {
table string // defaults to "filter", and can be "nat" for example.
append bool
operation operation
chain string // for example INPUT, PREROUTING. Cannot be empty.
target string // for example ACCEPT. Can be empty.
protocol string // "tcp" or "udp" or "" for all protocols.
@@ -25,6 +35,8 @@ type iptablesInstruction struct {
ctstate []string // if empty, there is no ctstate
tcpFlags tcpFlags
mark mark
connMark mark
setMark uint // only used for jump CONNMARK --set-mark
}
func (i *iptablesInstruction) setDefaults() {
@@ -65,6 +77,10 @@ func (i *iptablesInstruction) equalToRule(table, chain string, rule chainRule) (
return false
case i.mark != rule.mark:
return false
case i.connMark != rule.connMark:
return false
case i.setMark != rule.setMark:
return false
default:
return true
}
@@ -113,13 +129,20 @@ func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (co
case "-t", "--table":
instruction.table = value
case "-D", "--delete":
instruction.append = false
instruction.operation = opDelete
instruction.chain = value
case "-A", "--append":
instruction.append = true
instruction.operation = opAppend
instruction.chain = value
case "-I", "--insert":
instruction.operation = opInsert
instruction.chain = value
case "-j", "--jump":
instruction.target = value
subConsumed, err := parseJumpFlag(fields[1:], instruction)
if err != nil {
return 0, fmt.Errorf("parsing jump flag: %w", err)
}
consumed += subConsumed
case "-p", "--protocol":
instruction.protocol = value
case "-m", "--match":
@@ -128,13 +151,11 @@ func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (co
return 0, fmt.Errorf("parsing match module: %w", err)
}
case "--mark":
const base = 0 // auto-detect
const bits = 32
value, err := strconv.ParseUint(value, base, bits)
n, err := parseAny32bNumber(value)
if err != nil {
return 0, fmt.Errorf("parsing mark value %q: %w", fields[2], err)
return 0, fmt.Errorf("parsing mark value %q: %w", value, err)
}
instruction.mark.value = uint(value)
instruction.mark.value = n
case "-i", "--in-interface":
instruction.inputInterface = value
case "-o", "--out-interface":
@@ -182,7 +203,7 @@ func preCheckInstructionFields(fields []string) (consumed int, err error) {
flag := fields[0]
// All flags use one value after the flag, except the following:
switch flag {
case "--tcp-flags": // -m can have 1 or 2 values
case "--tcp-flags":
const expected = 3
if len(fields) < expected {
return 0, fmt.Errorf("%w: flag %q requires at least 2 values, but got %s",
@@ -199,6 +220,34 @@ func preCheckInstructionFields(fields []string) (consumed int, err error) {
}
}
func parseJumpFlag(fields []string, instruction *iptablesInstruction) (consumed int, err error) {
instruction.target = fields[0]
// consumed in the caller already takes fields[0] into account
if instruction.target != "CONNMARK" {
return consumed, nil
}
// consumed already accounts for the "CONNMARK" value
const expectedFields = 3
if len(fields) < expectedFields {
return 0, fmt.Errorf("%w: jump CONNMARK requires at least two additional values",
ErrIptablesCommandMalformed)
}
switch fields[1] {
case "--set-mark":
n, err := parseAny32bNumber(fields[2])
if err != nil {
return 0, fmt.Errorf("parsing connmark mark value %q: %w", fields[2], err)
}
consumed++
instruction.setMark = n
default:
return consumed, fmt.Errorf("%w: unsupported jump CONNMARK with value: %s",
ErrIptablesCommandMalformed, fields[1])
}
consumed++
return consumed, nil
}
func parseIPPrefix(value string) (prefix netip.Prefix, err error) {
slashIndex := strings.Index(value, "/")
if slashIndex >= 0 {
@@ -221,6 +270,13 @@ func parsePort(value string) (port uint16, err error) {
return uint16(portValue), nil
}
func parseAny32bNumber(mark string) (value uint, err error) {
const base = 0 // auto-detect
const bits = 32
n, err := strconv.ParseUint(mark, base, bits)
return uint(n), err
}
func parseMatchModule(fields []string, instruction *iptablesInstruction) (
consumed int, err error,
) {
@@ -234,14 +290,30 @@ func parseMatchModule(fields []string, instruction *iptablesInstruction) (
// parse it twice.
case "mark":
consumed++
switch fields[consumed] {
case "!":
switch {
case len(fields[consumed:]) == 0 || strings.HasPrefix(fields[consumed], "-"):
// end or another flag
return consumed, nil
case fields[consumed] == "!":
consumed++
instruction.mark.invert = true
default:
return consumed, fmt.Errorf("%w: unsupported match mark with value: %s",
ErrIptablesCommandMalformed, fields[2])
}
case "connmark":
consumed++
switch {
case len(fields[consumed:]) == 0 || strings.HasPrefix(fields[consumed], "-"):
// end or another flag
return consumed, nil
case fields[consumed] == "!":
consumed++
instruction.connMark.invert = true
default:
return consumed, fmt.Errorf("%w: unsupported match connmark with value: %s",
ErrIptablesCommandMalformed, fields[2])
}
default:
return 0, fmt.Errorf("%w: unknown match value: %s",
ErrIptablesCommandMalformed, fields[consumed])
+5 -5
View File
@@ -33,9 +33,9 @@ func Test_parseIptablesInstruction(t *testing.T) {
"one_pair": {
s: "-A INPUT",
instruction: iptablesInstruction{
table: "filter",
chain: "INPUT",
append: true,
table: "filter",
chain: "INPUT",
operation: opAppend,
},
},
"instruction_A": {
@@ -43,7 +43,7 @@ func Test_parseIptablesInstruction(t *testing.T) {
instruction: iptablesInstruction{
table: "filter",
chain: "INPUT",
append: true,
operation: opAppend,
inputInterface: "tun0",
protocol: "tcp",
source: netip.MustParsePrefix("1.2.3.4/32"),
@@ -57,7 +57,7 @@ func Test_parseIptablesInstruction(t *testing.T) {
instruction: iptablesInstruction{
table: "nat",
chain: "PREROUTING",
append: false,
operation: opDelete,
inputInterface: "tun0",
protocol: "tcp",
destinationPort: 43716,
+1 -1
View File
@@ -64,7 +64,7 @@ func parseTCPFlag(s string) (tcpFlag, error) {
return 0, fmt.Errorf("%w: %s", errTCPFlagUnknown, s)
}
var ErrMarkMatchModuleMissing = errors.New("kernel is missing the mark module libxt_mark.so")
var ErrMarkMatchModuleMissing = errors.New("libxt_mark.so module is missing")
// TempDropOutputTCPRST temporarily drops outgoing TCP RST packets to the specified address and port,
// for any TCP packets not marked with the excludeMark given.
+7
View File
@@ -1,6 +1,7 @@
package netlink
import (
"errors"
"fmt"
"github.com/mdlayher/netlink"
@@ -8,7 +9,13 @@ import (
"golang.org/x/sys/unix"
)
var ErrConntrackNetlinkNotSupported = errors.New("nf_conntrack_netlink is not supported by the kernel")
func (n *NetLink) FlushConntrack() error {
if !n.conntrackNetlink {
return fmt.Errorf("%w", ErrConntrackNetlinkNotSupported)
}
conn, err := netfilter.Dial(nil)
if err != nil {
return fmt.Errorf("dialing netfilter: %w", err)
@@ -2,6 +2,10 @@
package netlink
import "errors"
var ErrConntrackNetlinkNotSupported = errors.New("error not implemented")
func (n *NetLink) FlushConntrack() error {
panic("not implemented")
}
+10 -2
View File
@@ -1,14 +1,22 @@
package netlink
import "github.com/qdm12/log"
import (
"github.com/qdm12/gluetun/internal/mod"
"github.com/qdm12/log"
)
type NetLink struct {
debugLogger DebugLogger
// Fixed state
conntrackNetlink bool
}
func New(debugLogger DebugLogger) *NetLink {
conntrackNetlink := mod.Probe("nf_conntrack_netlink") == nil
return &NetLink{
debugLogger: debugLogger,
debugLogger: debugLogger,
conntrackNetlink: conntrackNetlink,
}
}
+1 -1
View File
@@ -74,7 +74,7 @@ func PathMTUDiscover(ctx context.Context, icmpAddrs []netip.Addr, tcpAddrs []net
}
mtu, err = tcp.PathMTUDiscover(ctx, tcpAddrs, minMTU, maxPossibleMTU, tryTimeout, fw, logger)
if err != nil {
if errors.Is(err, iptables.ErrMarkMatchModuleMissing) {
if errors.Is(err, iptables.ErrKernelModuleMissing) {
logger.Debugf("aborting TCP path MTU discovery: %s", err)
if icmpSuccess {
return maxPossibleMTU, nil // only rely on ICMP PMTUD results
+1 -1
View File
@@ -35,7 +35,7 @@ func getFirewall(t *testing.T) *firewall.Config {
noopLogger := &noopLogger{}
cmder := command.New()
var err error
testFirewall, err = firewall.NewConfig(t.Context(), noopLogger, cmder, nil, nil)
testFirewall, err = firewall.NewConfig(t.Context(), noopLogger, cmder, nil, nil, nil)
if errors.Is(err, iptables.ErrNotSupported) {
t.Skip("iptables not installed, skipping TCP PMTUD tests")
}
+1 -1
View File
@@ -43,7 +43,7 @@ func findHighestMSSDestination(ctx context.Context, familyToFD map[int]fileDescr
if result.err != nil {
switch {
case err != nil: // error already occurred for another findMSS goroutine
case errors.Is(result.err, iptables.ErrMarkMatchModuleMissing):
case errors.Is(result.err, iptables.ErrKernelModuleMissing):
err = fmt.Errorf("finding MSS for %s: %w", result.dst, result.err)
case dst.Addr().Is6() && errors.Is(result.err, ip.ErrNetworkUnreachable):
// silently discard IPv6 network unreachable errors since they are common