From 2bb4deccd53f93b9c9aa1aebe372662adebe83a2 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Thu, 26 Feb 2026 22:58:52 +0000 Subject: [PATCH] feat(firewall): atomic iptables operations - all operations rollback on failure - disabling the firewall means rolling back to its state before enabling it - aligns with nftables atomicity feature --- internal/firewall/enable.go | 60 +++------- internal/firewall/firewall.go | 1 + internal/firewall/interfaces.go | 12 +- internal/firewall/iptables/atomic.go | 85 ++++++++++++++ internal/firewall/iptables/ip6tables.go | 34 +++++- internal/firewall/iptables/iptables.go | 129 +++++++++++----------- internal/firewall/iptables/iptablesmix.go | 34 +++++- 7 files changed, 238 insertions(+), 117 deletions(-) create mode 100644 internal/firewall/iptables/atomic.go diff --git a/internal/firewall/enable.go b/internal/firewall/enable.go index f1909b17..c0f9aa8d 100644 --- a/internal/firewall/enable.go +++ b/internal/firewall/enable.go @@ -22,9 +22,7 @@ func (c *Config) SetEnabled(ctx context.Context, enabled bool) (err error) { if !enabled { c.logger.Info("disabling...") - if err = c.disable(ctx); err != nil { - return fmt.Errorf("disabling firewall: %w", err) - } + c.restore(ctx) c.enabled = false c.logger.Info("disabled successfully") return nil @@ -41,37 +39,12 @@ func (c *Config) SetEnabled(ctx context.Context, enabled bool) (err error) { return nil } -func (c *Config) disable(ctx context.Context) (err error) { - if err = c.impl.ClearAllRules(ctx); err != nil { - return fmt.Errorf("clearing all rules: %w", err) - } - if err = c.impl.SetIPv4AllPolicies(ctx, "ACCEPT"); err != nil { - return fmt.Errorf("setting ipv4 policies: %w", err) - } - if err = c.impl.SetIPv6AllPolicies(ctx, "ACCEPT"); err != nil { - return fmt.Errorf("setting ipv6 policies: %w", err) - } - - const remove = true - err = c.redirectPorts(ctx, remove) - if err != nil { - return fmt.Errorf("removing port redirections: %w", err) - } - - return nil -} - -// To use in defered call when enabling the firewall. -func (c *Config) fallbackToDisabled(ctx context.Context) { - if ctx.Err() != nil { - return - } - if err := c.disable(ctx); err != nil { - c.logger.Error("failed reversing firewall changes: " + err.Error()) - } -} - func (c *Config) enable(ctx context.Context) (err error) { + c.restore, err = c.impl.SaveAndRestore(ctx) + if err != nil { + return fmt.Errorf("saving firewall rules: %w", err) + } + if err = c.impl.SetIPv4AllPolicies(ctx, "DROP"); err != nil { return err } @@ -82,21 +55,21 @@ func (c *Config) enable(ctx context.Context) (err error) { defer func() { if err != nil { - c.fallbackToDisabled(ctx) + c.restore(context.Background()) } }() - const remove = false - // Loopback traffic - if err = c.impl.AcceptInputThroughInterface(ctx, "lo", remove); err != nil { + if err = c.impl.AcceptInputThroughInterface(ctx, "lo"); err != nil { return err } + + const remove = false if err = c.impl.AcceptOutputThroughInterface(ctx, "lo", remove); err != nil { return err } - if err = c.impl.AcceptEstablishedRelatedTraffic(ctx, remove); err != nil { + if err = c.impl.AcceptEstablishedRelatedTraffic(ctx); err != nil { return err } @@ -117,7 +90,7 @@ func (c *Config) enable(ctx context.Context) (err error) { continue } localInterfaces[network.InterfaceName] = struct{}{} - err = c.impl.AcceptIpv6MulticastOutput(ctx, network.InterfaceName, remove) + err = c.impl.AcceptIpv6MulticastOutput(ctx, network.InterfaceName) if err != nil { return fmt.Errorf("accepting IPv6 multicast output: %w", err) } @@ -130,7 +103,7 @@ func (c *Config) enable(ctx context.Context) (err error) { // Allows packets from any IP address to go through eth0 / local network // to reach Gluetun. for _, network := range c.localNetworks { - if err := c.impl.AcceptInputToSubnet(ctx, network.InterfaceName, network.IPNet, remove); err != nil { + if err := c.impl.AcceptInputToSubnet(ctx, network.InterfaceName, network.IPNet); err != nil { return err } } @@ -139,12 +112,12 @@ func (c *Config) enable(ctx context.Context) (err error) { return err } - err = c.redirectPorts(ctx, remove) + err = c.redirectPorts(ctx) if err != nil { return fmt.Errorf("redirecting ports: %w", err) } - if err := c.impl.RunUserPostRules(ctx, c.customRulesPath, remove); err != nil { + if err := c.impl.RunUserPostRules(ctx, c.customRulesPath); err != nil { return fmt.Errorf("running user defined post firewall rules: %w", err) } @@ -214,8 +187,9 @@ func (c *Config) allowInputPorts(ctx context.Context) (err error) { return nil } -func (c *Config) redirectPorts(ctx context.Context, remove bool) (err error) { +func (c *Config) redirectPorts(ctx context.Context) (err error) { for _, portRedirection := range c.portRedirections { + const remove = false err = c.impl.RedirectPort(ctx, portRedirection.interfaceName, portRedirection.sourcePort, portRedirection.destinationPort, remove) if err != nil { diff --git a/internal/firewall/firewall.go b/internal/firewall/firewall.go index 73bfcc46..3b9b902a 100644 --- a/internal/firewall/firewall.go +++ b/internal/firewall/firewall.go @@ -23,6 +23,7 @@ type Config struct { // State enabled bool + restore func(context.Context) vpnConnection models.Connection vpnIntf string outboundSubnets []netip.Prefix diff --git a/internal/firewall/interfaces.go b/internal/firewall/interfaces.go index 80a1bab5..a1938f96 100644 --- a/internal/firewall/interfaces.go +++ b/internal/firewall/interfaces.go @@ -20,20 +20,20 @@ type Logger interface { } type firewallImpl interface { //nolint:interfacebloat - AcceptEstablishedRelatedTraffic(ctx context.Context, remove bool) error - AcceptInputThroughInterface(ctx context.Context, intf string, remove bool) error + SaveAndRestore(ctx context.Context) (restore func(context.Context), err error) + AcceptEstablishedRelatedTraffic(ctx context.Context) error + AcceptInputThroughInterface(ctx context.Context, intf string) error AcceptInputToPort(ctx context.Context, intf string, port uint16, remove bool) error - AcceptInputToSubnet(ctx context.Context, intf string, subnet netip.Prefix, remove bool) error - AcceptIpv6MulticastOutput(ctx context.Context, intf string, remove bool) error + AcceptInputToSubnet(ctx context.Context, intf string, subnet netip.Prefix) error + AcceptIpv6MulticastOutput(ctx context.Context, intf string) error AcceptOutputFromIPToSubnet(ctx context.Context, intf string, assignedIP netip.Addr, subnet netip.Prefix, remove bool) error AcceptOutputThroughInterface(ctx context.Context, intf string, remove bool) error AcceptOutputTrafficToVPN(ctx context.Context, intf string, connection models.Connection, remove bool) error - ClearAllRules(ctx context.Context) error RedirectPort(ctx context.Context, intf string, sourcePort, destinationPort uint16, remove bool) error - RunUserPostRules(ctx context.Context, customRulesPath string, remove bool) error + RunUserPostRules(ctx context.Context, customRulesPath string) error SetIPv4AllPolicies(ctx context.Context, policy string) error SetIPv6AllPolicies(ctx context.Context, policy string) error TempDropOutputTCPRST(ctx context.Context, src, dst netip.AddrPort, excludeMark int) ( diff --git a/internal/firewall/iptables/atomic.go b/internal/firewall/iptables/atomic.go new file mode 100644 index 00000000..5e6a995c --- /dev/null +++ b/internal/firewall/iptables/atomic.go @@ -0,0 +1,85 @@ +package iptables + +import ( + "context" + "fmt" + "os/exec" + "strings" +) + +// SaveAndRestore saves the current iptables and ip6tables rules and +// returns a restore function that can be called to restore the saved rules. +func (c *Config) SaveAndRestore(ctx context.Context) (restore func(context.Context), err error) { + c.iptablesMutex.Lock() + c.ip6tablesMutex.Lock() + defer c.iptablesMutex.Unlock() + defer c.ip6tablesMutex.Unlock() + + return c.saveAndRestore(ctx) +} + +// callers MUST always lock both the [Config] iptablesMutex and the ip6tablesMutex +// before calling this function. Note the restore function does not interact with mutexes +// so the caller must make sure the mutexes are locked when calling the restore function. +func (c *Config) saveAndRestore(ctx context.Context) (restore func(context.Context), err error) { + restoreIPv4, err := c.saveAndRestoreIPv4(ctx) + if err != nil { + return nil, err + } + restoreIPv6, err := c.saveAndRestoreIPv6(ctx) + if err != nil { + return nil, err + } + + restore = func(ctx context.Context) { + restoreIPv4(ctx) + if restoreIPv6 != nil { + restoreIPv6(ctx) + } + } + return restore, nil +} + +// Callers of saveAndRestoreIPv4 MUST always lock the [Config] iptablesMutex +// before calling this function. +func (c *Config) saveAndRestoreIPv4(ctx context.Context) (restore func(context.Context), err error) { + cmd := exec.CommandContext(ctx, c.ipTables+"-save") //nolint:gosec + data, err := c.runner.Run(cmd) + if err != nil { + return nil, fmt.Errorf("saving IPv4 iptables: %w", err) + } + + restore = func(ctx context.Context) { + cmd := exec.CommandContext(ctx, c.ipTables+"-restore") //nolint:gosec + cmd.Stdin = strings.NewReader(data) + output, err := c.runner.Run(cmd) + if err != nil { + c.logger.Warn(fmt.Sprintf("restoring IPv4 iptables failed: %v: %s", err, output)) + } + } + return restore, nil +} + +// Callers of saveAndRestoreIPv6 MUST always lock the [Config] ip6tablesMutex +// before calling this function. +func (c *Config) saveAndRestoreIPv6(ctx context.Context) (restore func(context.Context), err error) { + if c.ip6Tables == "" { + return nil, nil //nolint:nilnil + } + + cmd := exec.CommandContext(ctx, c.ip6Tables+"-save") //nolint:gosec + data, err := c.runner.Run(cmd) + if err != nil { + return nil, fmt.Errorf("saving IPv6 iptables: %w", err) + } + + restore = func(ctx context.Context) { + cmd = exec.CommandContext(ctx, c.ip6Tables+"-restore") //nolint:gosec + cmd.Stdin = strings.NewReader(data) + output, err := c.runner.Run(cmd) + if err != nil { + c.logger.Warn(fmt.Sprintf("restoring IPv6 iptables failed: %v: %s", err, output)) + } + } + return restore, nil +} diff --git a/internal/firewall/iptables/ip6tables.go b/internal/firewall/iptables/ip6tables.go index 5dcf14e1..6e096699 100644 --- a/internal/firewall/iptables/ip6tables.go +++ b/internal/firewall/iptables/ip6tables.go @@ -24,8 +24,23 @@ func findIP6tablesSupported(ctx context.Context, runner CmdRunner) ( } func (c *Config) runIP6tablesInstructions(ctx context.Context, instructions []string) error { + c.ip6tablesMutex.Lock() // only one ip6tables command at once + defer c.ip6tablesMutex.Unlock() + + restore, err := c.saveAndRestoreIPv6(ctx) + if err != nil { + return err + } + err = c.runIP6tablesInstructionsNoSave(ctx, instructions) + if err != nil { + restore(ctx) + } + return err +} + +func (c *Config) runIP6tablesInstructionsNoSave(ctx context.Context, instructions []string) error { for _, instruction := range instructions { - if err := c.runIP6tablesInstruction(ctx, instruction); err != nil { + if err := c.runIP6tablesInstructionNoSave(ctx, instruction); err != nil { return err } } @@ -33,11 +48,24 @@ func (c *Config) runIP6tablesInstructions(ctx context.Context, instructions []st } func (c *Config) runIP6tablesInstruction(ctx context.Context, instruction string) error { + c.ip6tablesMutex.Lock() // only one ip6tables command at once + defer c.ip6tablesMutex.Unlock() + + restore, err := c.saveAndRestoreIPv6(ctx) + if err != nil { + return err + } + err = c.runIP6tablesInstructionNoSave(ctx, instruction) + if err != nil { + restore(ctx) + } + return err +} + +func (c *Config) runIP6tablesInstructionNoSave(ctx context.Context, instruction string) error { if c.ip6Tables == "" { return nil } - c.ip6tablesMutex.Lock() // only one ip6tables command at once - defer c.ip6tablesMutex.Unlock() if isDeleteMatchInstruction(instruction) { return deleteIPTablesRule(ctx, c.ip6Tables, instruction, diff --git a/internal/firewall/iptables/iptables.go b/internal/firewall/iptables/iptables.go index 64546742..486e9fab 100644 --- a/internal/firewall/iptables/iptables.go +++ b/internal/firewall/iptables/iptables.go @@ -26,21 +26,6 @@ func appendOrDelete(remove bool) string { return "--append" } -// flipRule changes an append rule in a delete rule or a delete rule into an -// append rule. -func flipRule(rule string) string { - fields := strings.Fields(rule) - for i, field := range fields { - switch field { - case "-A", "--append": - fields[i] = "--delete" - case "-D", "--delete": - fields[i] = "--append" - } - } - return strings.Join(fields, " ") -} - // Version obtains the version of the installed iptables. func (c *Config) Version(ctx context.Context) (string, error) { cmd := exec.CommandContext(ctx, c.ipTables, "--version") //nolint:gosec @@ -57,8 +42,24 @@ func (c *Config) Version(ctx context.Context) (string, error) { } func (c *Config) runIptablesInstructions(ctx context.Context, instructions []string) error { + c.iptablesMutex.Lock() + defer c.iptablesMutex.Unlock() + + restore, err := c.saveAndRestoreIPv4(ctx) + if err != nil { + return err + } + + err = c.runIptablesInstructionsNoSave(ctx, instructions) + if err != nil { + restore(ctx) + } + return err +} + +func (c *Config) runIptablesInstructionsNoSave(ctx context.Context, instructions []string) error { for _, instruction := range instructions { - if err := c.runIptablesInstruction(ctx, instruction); err != nil { + if err := c.runIptablesInstructionNoSave(ctx, instruction); err != nil { return err } } @@ -69,6 +70,19 @@ func (c *Config) runIptablesInstruction(ctx context.Context, instruction string) c.iptablesMutex.Lock() // only one iptables command at once defer c.iptablesMutex.Unlock() + restore, err := c.saveAndRestoreIPv4(ctx) + if err != nil { + return err + } + + err = c.runIptablesInstructionNoSave(ctx, instruction) + if err != nil { + restore(ctx) + } + return err +} + +func (c *Config) runIptablesInstructionNoSave(ctx context.Context, instruction string) error { if isDeleteMatchInstruction(instruction) { return deleteIPTablesRule(ctx, c.ipTables, instruction, c.runner, c.logger) @@ -84,17 +98,6 @@ func (c *Config) runIptablesInstruction(ctx context.Context, instruction string) return nil } -func (c *Config) ClearAllRules(ctx context.Context) error { - tables := []string{"filter"} - for _, table := range tables { - return c.runMixedIptablesInstructions(ctx, []string{ - "-t " + table + " --flush", // flush all chains - "-t " + table + " --delete-chain", // delete all chains - }) - } - return nil -} - func (c *Config) SetIPv4AllPolicies(ctx context.Context, policy string) error { switch policy { case "ACCEPT", "DROP": @@ -108,22 +111,19 @@ func (c *Config) SetIPv4AllPolicies(ctx context.Context, policy string) error { }) } -func (c *Config) AcceptInputThroughInterface(ctx context.Context, intf string, remove bool) error { +func (c *Config) AcceptInputThroughInterface(ctx context.Context, intf string) error { return c.runMixedIptablesInstruction(ctx, fmt.Sprintf( - "%s INPUT -i %s -j ACCEPT", appendOrDelete(remove), intf, - )) + "--append INPUT -i %s -j ACCEPT", intf)) } -func (c *Config) AcceptInputToSubnet(ctx context.Context, intf string, - destination netip.Prefix, remove bool, -) error { +func (c *Config) AcceptInputToSubnet(ctx context.Context, intf string, destination netip.Prefix) error { interfaceFlag := "-i " + intf if intf == "*" { // all interfaces interfaceFlag = "" } - instruction := fmt.Sprintf("%s INPUT %s -d %s -j ACCEPT", - appendOrDelete(remove), interfaceFlag, destination.String()) + instruction := fmt.Sprintf("--append INPUT %s -d %s -j ACCEPT", + interfaceFlag, destination.String()) if destination.Addr().Is4() { return c.runIptablesInstruction(ctx, instruction) @@ -140,10 +140,10 @@ func (c *Config) AcceptOutputThroughInterface(ctx context.Context, intf string, )) } -func (c *Config) AcceptEstablishedRelatedTraffic(ctx context.Context, remove bool) error { +func (c *Config) AcceptEstablishedRelatedTraffic(ctx context.Context) error { return c.runMixedIptablesInstructions(ctx, []string{ - fmt.Sprintf("%s OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", appendOrDelete(remove)), - fmt.Sprintf("%s INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", appendOrDelete(remove)), + "--append OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", + "--append INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", }) } @@ -194,15 +194,12 @@ func (c *Config) AcceptOutputFromIPToSubnet(ctx context.Context, // ff02::1:ff00:0/104, which is used for NDP (Neighbor Discovery Protocol) to resolve // IPv6 addresses to MAC addresses. If intf is empty, it is set to "*" which means // all interfaces. If remove is true, the rule is removed instead of added. -func (c *Config) AcceptIpv6MulticastOutput(ctx context.Context, - intf string, remove bool, -) error { +func (c *Config) AcceptIpv6MulticastOutput(ctx context.Context, intf string) error { interfaceFlag := "-o " + intf if intf == "*" { // all interfaces interfaceFlag = "" } - instruction := fmt.Sprintf("%s OUTPUT %s -d ff02::1:ff00:0/104 -j ACCEPT", - appendOrDelete(remove), interfaceFlag) + instruction := fmt.Sprintf("--append OUTPUT %s -d ff02::1:ff00:0/104 -j ACCEPT", interfaceFlag) return c.runIP6tablesInstruction(ctx, instruction) } @@ -234,7 +231,17 @@ func (c *Config) RedirectPort(ctx context.Context, intf string, interfaceFlag = "" } - err = c.runIptablesInstructions(ctx, []string{ + c.iptablesMutex.Lock() + c.ip6tablesMutex.Lock() + defer c.iptablesMutex.Unlock() + defer c.ip6tablesMutex.Unlock() + + restore, err := c.saveAndRestore(ctx) + if err != nil { + return err + } + + err = c.runIptablesInstructionsNoSave(ctx, []string{ fmt.Sprintf("-t nat %s PREROUTING %s -p tcp --dport %d -j REDIRECT --to-ports %d", appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort), fmt.Sprintf("%s INPUT %s -p tcp -m tcp --dport %d -j ACCEPT", @@ -245,11 +252,12 @@ func (c *Config) RedirectPort(ctx context.Context, intf string, appendOrDelete(remove), interfaceFlag, destinationPort), }) if err != nil { + restore(ctx) return fmt.Errorf("redirecting IPv4 source port %d to destination port %d on interface %s: %w", sourcePort, destinationPort, intf, err) } - err = c.runIP6tablesInstructions(ctx, []string{ + err = c.runIP6tablesInstructionsNoSave(ctx, []string{ fmt.Sprintf("-t nat %s PREROUTING %s -p tcp --dport %d -j REDIRECT --to-ports %d", appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort), fmt.Sprintf("%s INPUT %s -p tcp -m tcp --dport %d -j ACCEPT", @@ -260,6 +268,7 @@ func (c *Config) RedirectPort(ctx context.Context, intf string, appendOrDelete(remove), interfaceFlag, destinationPort), }) if err != nil { + restore(ctx) // just in case errMessage := err.Error() if strings.Contains(errMessage, "can't initialize ip6tables table `nat': Table does not exist") { if !remove { @@ -273,7 +282,7 @@ func (c *Config) RedirectPort(ctx context.Context, intf string, return nil } -func (c *Config) RunUserPostRules(ctx context.Context, filepath string, remove bool) error { +func (c *Config) RunUserPostRules(ctx context.Context, filepath string) error { file, err := os.OpenFile(filepath, os.O_RDONLY, 0) if os.IsNotExist(err) { return nil @@ -289,16 +298,17 @@ func (c *Config) RunUserPostRules(ctx context.Context, filepath string, remove b return err } lines := strings.Split(string(b), "\n") - successfulRules := []string{} - defer func() { - // transaction-like rollback - if err == nil || ctx.Err() != nil { - return - } - for _, rule := range successfulRules { - _ = c.runIptablesInstruction(ctx, flipRule(rule)) - } - }() + + c.iptablesMutex.Lock() + c.ip6tablesMutex.Lock() + defer c.iptablesMutex.Unlock() + defer c.ip6tablesMutex.Unlock() + + restore, err := c.saveAndRestore(ctx) + if err != nil { + return err + } + for _, line := range lines { var ipv4 bool var rule string @@ -325,10 +335,6 @@ func (c *Config) RunUserPostRules(ctx context.Context, filepath string, remove b continue } - if remove { - rule = flipRule(rule) - } - switch { case ipv4: err = c.runIptablesInstruction(ctx, rule) @@ -338,10 +344,9 @@ func (c *Config) RunUserPostRules(ctx context.Context, filepath string, remove b err = c.runIP6tablesInstruction(ctx, rule) } if err != nil { + restore(ctx) return err } - - successfulRules = append(successfulRules, rule) } return nil } diff --git a/internal/firewall/iptables/iptablesmix.go b/internal/firewall/iptables/iptablesmix.go index 9c3d2dcb..0ea85bf4 100644 --- a/internal/firewall/iptables/iptablesmix.go +++ b/internal/firewall/iptables/iptablesmix.go @@ -5,8 +5,19 @@ import ( ) func (c *Config) runMixedIptablesInstructions(ctx context.Context, instructions []string) error { + c.iptablesMutex.Lock() + c.ip6tablesMutex.Lock() + defer c.iptablesMutex.Unlock() + defer c.ip6tablesMutex.Unlock() + + restore, err := c.saveAndRestore(ctx) + if err != nil { + return err + } + for _, instruction := range instructions { - if err := c.runMixedIptablesInstruction(ctx, instruction); err != nil { + if err := c.runMixedIptablesInstructionNoSave(ctx, instruction); err != nil { + restore(ctx) return err } } @@ -14,8 +25,25 @@ func (c *Config) runMixedIptablesInstructions(ctx context.Context, instructions } func (c *Config) runMixedIptablesInstruction(ctx context.Context, instruction string) error { - if err := c.runIptablesInstruction(ctx, instruction); err != nil { + c.iptablesMutex.Lock() + c.ip6tablesMutex.Lock() + defer c.iptablesMutex.Unlock() + defer c.ip6tablesMutex.Unlock() + + restore, err := c.saveAndRestore(ctx) + if err != nil { return err } - return c.runIP6tablesInstruction(ctx, instruction) + err = c.runIptablesInstructionNoSave(ctx, instruction) + if err != nil { + restore(ctx) + } + return err +} + +func (c *Config) runMixedIptablesInstructionNoSave(ctx context.Context, instruction string) error { + if err := c.runIptablesInstructionNoSave(ctx, instruction); err != nil { + return err + } + return c.runIP6tablesInstructionNoSave(ctx, instruction) }