feat(firewall): atomic iptables operations

- all operations rollback on failure
- disabling the firewall means rolling back to its state before enabling it
- aligns with nftables atomicity feature
This commit is contained in:
Quentin McGaw
2026-02-26 22:58:52 +00:00
parent 0d0c0fb143
commit 2bb4deccd5
7 changed files with 238 additions and 117 deletions
+17 -43
View File
@@ -22,9 +22,7 @@ func (c *Config) SetEnabled(ctx context.Context, enabled bool) (err error) {
if !enabled { if !enabled {
c.logger.Info("disabling...") c.logger.Info("disabling...")
if err = c.disable(ctx); err != nil { c.restore(ctx)
return fmt.Errorf("disabling firewall: %w", err)
}
c.enabled = false c.enabled = false
c.logger.Info("disabled successfully") c.logger.Info("disabled successfully")
return nil return nil
@@ -41,37 +39,12 @@ func (c *Config) SetEnabled(ctx context.Context, enabled bool) (err error) {
return nil return nil
} }
func (c *Config) disable(ctx context.Context) (err error) {
if err = c.impl.ClearAllRules(ctx); err != nil {
return fmt.Errorf("clearing all rules: %w", err)
}
if err = c.impl.SetIPv4AllPolicies(ctx, "ACCEPT"); err != nil {
return fmt.Errorf("setting ipv4 policies: %w", err)
}
if err = c.impl.SetIPv6AllPolicies(ctx, "ACCEPT"); err != nil {
return fmt.Errorf("setting ipv6 policies: %w", err)
}
const remove = true
err = c.redirectPorts(ctx, remove)
if err != nil {
return fmt.Errorf("removing port redirections: %w", err)
}
return nil
}
// To use in defered call when enabling the firewall.
func (c *Config) fallbackToDisabled(ctx context.Context) {
if ctx.Err() != nil {
return
}
if err := c.disable(ctx); err != nil {
c.logger.Error("failed reversing firewall changes: " + err.Error())
}
}
func (c *Config) enable(ctx context.Context) (err error) { func (c *Config) enable(ctx context.Context) (err error) {
c.restore, err = c.impl.SaveAndRestore(ctx)
if err != nil {
return fmt.Errorf("saving firewall rules: %w", err)
}
if err = c.impl.SetIPv4AllPolicies(ctx, "DROP"); err != nil { if err = c.impl.SetIPv4AllPolicies(ctx, "DROP"); err != nil {
return err return err
} }
@@ -82,21 +55,21 @@ func (c *Config) enable(ctx context.Context) (err error) {
defer func() { defer func() {
if err != nil { if err != nil {
c.fallbackToDisabled(ctx) c.restore(context.Background())
} }
}() }()
const remove = false
// Loopback traffic // Loopback traffic
if err = c.impl.AcceptInputThroughInterface(ctx, "lo", remove); err != nil { if err = c.impl.AcceptInputThroughInterface(ctx, "lo"); err != nil {
return err return err
} }
const remove = false
if err = c.impl.AcceptOutputThroughInterface(ctx, "lo", remove); err != nil { if err = c.impl.AcceptOutputThroughInterface(ctx, "lo", remove); err != nil {
return err return err
} }
if err = c.impl.AcceptEstablishedRelatedTraffic(ctx, remove); err != nil { if err = c.impl.AcceptEstablishedRelatedTraffic(ctx); err != nil {
return err return err
} }
@@ -117,7 +90,7 @@ func (c *Config) enable(ctx context.Context) (err error) {
continue continue
} }
localInterfaces[network.InterfaceName] = struct{}{} localInterfaces[network.InterfaceName] = struct{}{}
err = c.impl.AcceptIpv6MulticastOutput(ctx, network.InterfaceName, remove) err = c.impl.AcceptIpv6MulticastOutput(ctx, network.InterfaceName)
if err != nil { if err != nil {
return fmt.Errorf("accepting IPv6 multicast output: %w", err) return fmt.Errorf("accepting IPv6 multicast output: %w", err)
} }
@@ -130,7 +103,7 @@ func (c *Config) enable(ctx context.Context) (err error) {
// Allows packets from any IP address to go through eth0 / local network // Allows packets from any IP address to go through eth0 / local network
// to reach Gluetun. // to reach Gluetun.
for _, network := range c.localNetworks { for _, network := range c.localNetworks {
if err := c.impl.AcceptInputToSubnet(ctx, network.InterfaceName, network.IPNet, remove); err != nil { if err := c.impl.AcceptInputToSubnet(ctx, network.InterfaceName, network.IPNet); err != nil {
return err return err
} }
} }
@@ -139,12 +112,12 @@ func (c *Config) enable(ctx context.Context) (err error) {
return err return err
} }
err = c.redirectPorts(ctx, remove) err = c.redirectPorts(ctx)
if err != nil { if err != nil {
return fmt.Errorf("redirecting ports: %w", err) return fmt.Errorf("redirecting ports: %w", err)
} }
if err := c.impl.RunUserPostRules(ctx, c.customRulesPath, remove); err != nil { if err := c.impl.RunUserPostRules(ctx, c.customRulesPath); err != nil {
return fmt.Errorf("running user defined post firewall rules: %w", err) return fmt.Errorf("running user defined post firewall rules: %w", err)
} }
@@ -214,8 +187,9 @@ func (c *Config) allowInputPorts(ctx context.Context) (err error) {
return nil return nil
} }
func (c *Config) redirectPorts(ctx context.Context, remove bool) (err error) { func (c *Config) redirectPorts(ctx context.Context) (err error) {
for _, portRedirection := range c.portRedirections { for _, portRedirection := range c.portRedirections {
const remove = false
err = c.impl.RedirectPort(ctx, portRedirection.interfaceName, portRedirection.sourcePort, err = c.impl.RedirectPort(ctx, portRedirection.interfaceName, portRedirection.sourcePort,
portRedirection.destinationPort, remove) portRedirection.destinationPort, remove)
if err != nil { if err != nil {
+1
View File
@@ -23,6 +23,7 @@ type Config struct {
// State // State
enabled bool enabled bool
restore func(context.Context)
vpnConnection models.Connection vpnConnection models.Connection
vpnIntf string vpnIntf string
outboundSubnets []netip.Prefix outboundSubnets []netip.Prefix
+6 -6
View File
@@ -20,20 +20,20 @@ type Logger interface {
} }
type firewallImpl interface { //nolint:interfacebloat type firewallImpl interface { //nolint:interfacebloat
AcceptEstablishedRelatedTraffic(ctx context.Context, remove bool) error SaveAndRestore(ctx context.Context) (restore func(context.Context), err error)
AcceptInputThroughInterface(ctx context.Context, intf string, remove bool) error AcceptEstablishedRelatedTraffic(ctx context.Context) error
AcceptInputThroughInterface(ctx context.Context, intf string) error
AcceptInputToPort(ctx context.Context, intf string, port uint16, remove bool) error AcceptInputToPort(ctx context.Context, intf string, port uint16, remove bool) error
AcceptInputToSubnet(ctx context.Context, intf string, subnet netip.Prefix, remove bool) error AcceptInputToSubnet(ctx context.Context, intf string, subnet netip.Prefix) error
AcceptIpv6MulticastOutput(ctx context.Context, intf string, remove bool) error AcceptIpv6MulticastOutput(ctx context.Context, intf string) error
AcceptOutputFromIPToSubnet(ctx context.Context, intf string, assignedIP netip.Addr, AcceptOutputFromIPToSubnet(ctx context.Context, intf string, assignedIP netip.Addr,
subnet netip.Prefix, remove bool) error subnet netip.Prefix, remove bool) error
AcceptOutputThroughInterface(ctx context.Context, intf string, remove bool) error AcceptOutputThroughInterface(ctx context.Context, intf string, remove bool) error
AcceptOutputTrafficToVPN(ctx context.Context, intf string, AcceptOutputTrafficToVPN(ctx context.Context, intf string,
connection models.Connection, remove bool) error connection models.Connection, remove bool) error
ClearAllRules(ctx context.Context) error
RedirectPort(ctx context.Context, intf string, sourcePort, RedirectPort(ctx context.Context, intf string, sourcePort,
destinationPort uint16, remove bool) error destinationPort uint16, remove bool) error
RunUserPostRules(ctx context.Context, customRulesPath string, remove bool) error RunUserPostRules(ctx context.Context, customRulesPath string) error
SetIPv4AllPolicies(ctx context.Context, policy string) error SetIPv4AllPolicies(ctx context.Context, policy string) error
SetIPv6AllPolicies(ctx context.Context, policy string) error SetIPv6AllPolicies(ctx context.Context, policy string) error
TempDropOutputTCPRST(ctx context.Context, src, dst netip.AddrPort, excludeMark int) ( TempDropOutputTCPRST(ctx context.Context, src, dst netip.AddrPort, excludeMark int) (
+85
View File
@@ -0,0 +1,85 @@
package iptables
import (
"context"
"fmt"
"os/exec"
"strings"
)
// SaveAndRestore saves the current iptables and ip6tables rules and
// returns a restore function that can be called to restore the saved rules.
func (c *Config) SaveAndRestore(ctx context.Context) (restore func(context.Context), err error) {
c.iptablesMutex.Lock()
c.ip6tablesMutex.Lock()
defer c.iptablesMutex.Unlock()
defer c.ip6tablesMutex.Unlock()
return c.saveAndRestore(ctx)
}
// callers MUST always lock both the [Config] iptablesMutex and the ip6tablesMutex
// before calling this function. Note the restore function does not interact with mutexes
// so the caller must make sure the mutexes are locked when calling the restore function.
func (c *Config) saveAndRestore(ctx context.Context) (restore func(context.Context), err error) {
restoreIPv4, err := c.saveAndRestoreIPv4(ctx)
if err != nil {
return nil, err
}
restoreIPv6, err := c.saveAndRestoreIPv6(ctx)
if err != nil {
return nil, err
}
restore = func(ctx context.Context) {
restoreIPv4(ctx)
if restoreIPv6 != nil {
restoreIPv6(ctx)
}
}
return restore, nil
}
// Callers of saveAndRestoreIPv4 MUST always lock the [Config] iptablesMutex
// before calling this function.
func (c *Config) saveAndRestoreIPv4(ctx context.Context) (restore func(context.Context), err error) {
cmd := exec.CommandContext(ctx, c.ipTables+"-save") //nolint:gosec
data, err := c.runner.Run(cmd)
if err != nil {
return nil, fmt.Errorf("saving IPv4 iptables: %w", err)
}
restore = func(ctx context.Context) {
cmd := exec.CommandContext(ctx, c.ipTables+"-restore") //nolint:gosec
cmd.Stdin = strings.NewReader(data)
output, err := c.runner.Run(cmd)
if err != nil {
c.logger.Warn(fmt.Sprintf("restoring IPv4 iptables failed: %v: %s", err, output))
}
}
return restore, nil
}
// Callers of saveAndRestoreIPv6 MUST always lock the [Config] ip6tablesMutex
// before calling this function.
func (c *Config) saveAndRestoreIPv6(ctx context.Context) (restore func(context.Context), err error) {
if c.ip6Tables == "" {
return nil, nil //nolint:nilnil
}
cmd := exec.CommandContext(ctx, c.ip6Tables+"-save") //nolint:gosec
data, err := c.runner.Run(cmd)
if err != nil {
return nil, fmt.Errorf("saving IPv6 iptables: %w", err)
}
restore = func(ctx context.Context) {
cmd = exec.CommandContext(ctx, c.ip6Tables+"-restore") //nolint:gosec
cmd.Stdin = strings.NewReader(data)
output, err := c.runner.Run(cmd)
if err != nil {
c.logger.Warn(fmt.Sprintf("restoring IPv6 iptables failed: %v: %s", err, output))
}
}
return restore, nil
}
+31 -3
View File
@@ -24,8 +24,23 @@ func findIP6tablesSupported(ctx context.Context, runner CmdRunner) (
} }
func (c *Config) runIP6tablesInstructions(ctx context.Context, instructions []string) error { func (c *Config) runIP6tablesInstructions(ctx context.Context, instructions []string) error {
c.ip6tablesMutex.Lock() // only one ip6tables command at once
defer c.ip6tablesMutex.Unlock()
restore, err := c.saveAndRestoreIPv6(ctx)
if err != nil {
return err
}
err = c.runIP6tablesInstructionsNoSave(ctx, instructions)
if err != nil {
restore(ctx)
}
return err
}
func (c *Config) runIP6tablesInstructionsNoSave(ctx context.Context, instructions []string) error {
for _, instruction := range instructions { for _, instruction := range instructions {
if err := c.runIP6tablesInstruction(ctx, instruction); err != nil { if err := c.runIP6tablesInstructionNoSave(ctx, instruction); err != nil {
return err return err
} }
} }
@@ -33,11 +48,24 @@ func (c *Config) runIP6tablesInstructions(ctx context.Context, instructions []st
} }
func (c *Config) runIP6tablesInstruction(ctx context.Context, instruction string) error { func (c *Config) runIP6tablesInstruction(ctx context.Context, instruction string) error {
c.ip6tablesMutex.Lock() // only one ip6tables command at once
defer c.ip6tablesMutex.Unlock()
restore, err := c.saveAndRestoreIPv6(ctx)
if err != nil {
return err
}
err = c.runIP6tablesInstructionNoSave(ctx, instruction)
if err != nil {
restore(ctx)
}
return err
}
func (c *Config) runIP6tablesInstructionNoSave(ctx context.Context, instruction string) error {
if c.ip6Tables == "" { if c.ip6Tables == "" {
return nil return nil
} }
c.ip6tablesMutex.Lock() // only one ip6tables command at once
defer c.ip6tablesMutex.Unlock()
if isDeleteMatchInstruction(instruction) { if isDeleteMatchInstruction(instruction) {
return deleteIPTablesRule(ctx, c.ip6Tables, instruction, return deleteIPTablesRule(ctx, c.ip6Tables, instruction,
+67 -62
View File
@@ -26,21 +26,6 @@ func appendOrDelete(remove bool) string {
return "--append" return "--append"
} }
// flipRule changes an append rule in a delete rule or a delete rule into an
// append rule.
func flipRule(rule string) string {
fields := strings.Fields(rule)
for i, field := range fields {
switch field {
case "-A", "--append":
fields[i] = "--delete"
case "-D", "--delete":
fields[i] = "--append"
}
}
return strings.Join(fields, " ")
}
// Version obtains the version of the installed iptables. // Version obtains the version of the installed iptables.
func (c *Config) Version(ctx context.Context) (string, error) { func (c *Config) Version(ctx context.Context) (string, error) {
cmd := exec.CommandContext(ctx, c.ipTables, "--version") //nolint:gosec cmd := exec.CommandContext(ctx, c.ipTables, "--version") //nolint:gosec
@@ -57,8 +42,24 @@ func (c *Config) Version(ctx context.Context) (string, error) {
} }
func (c *Config) runIptablesInstructions(ctx context.Context, instructions []string) error { func (c *Config) runIptablesInstructions(ctx context.Context, instructions []string) error {
c.iptablesMutex.Lock()
defer c.iptablesMutex.Unlock()
restore, err := c.saveAndRestoreIPv4(ctx)
if err != nil {
return err
}
err = c.runIptablesInstructionsNoSave(ctx, instructions)
if err != nil {
restore(ctx)
}
return err
}
func (c *Config) runIptablesInstructionsNoSave(ctx context.Context, instructions []string) error {
for _, instruction := range instructions { for _, instruction := range instructions {
if err := c.runIptablesInstruction(ctx, instruction); err != nil { if err := c.runIptablesInstructionNoSave(ctx, instruction); err != nil {
return err return err
} }
} }
@@ -69,6 +70,19 @@ func (c *Config) runIptablesInstruction(ctx context.Context, instruction string)
c.iptablesMutex.Lock() // only one iptables command at once c.iptablesMutex.Lock() // only one iptables command at once
defer c.iptablesMutex.Unlock() defer c.iptablesMutex.Unlock()
restore, err := c.saveAndRestoreIPv4(ctx)
if err != nil {
return err
}
err = c.runIptablesInstructionNoSave(ctx, instruction)
if err != nil {
restore(ctx)
}
return err
}
func (c *Config) runIptablesInstructionNoSave(ctx context.Context, instruction string) error {
if isDeleteMatchInstruction(instruction) { if isDeleteMatchInstruction(instruction) {
return deleteIPTablesRule(ctx, c.ipTables, instruction, return deleteIPTablesRule(ctx, c.ipTables, instruction,
c.runner, c.logger) c.runner, c.logger)
@@ -84,17 +98,6 @@ func (c *Config) runIptablesInstruction(ctx context.Context, instruction string)
return nil return nil
} }
func (c *Config) ClearAllRules(ctx context.Context) error {
tables := []string{"filter"}
for _, table := range tables {
return c.runMixedIptablesInstructions(ctx, []string{
"-t " + table + " --flush", // flush all chains
"-t " + table + " --delete-chain", // delete all chains
})
}
return nil
}
func (c *Config) SetIPv4AllPolicies(ctx context.Context, policy string) error { func (c *Config) SetIPv4AllPolicies(ctx context.Context, policy string) error {
switch policy { switch policy {
case "ACCEPT", "DROP": case "ACCEPT", "DROP":
@@ -108,22 +111,19 @@ func (c *Config) SetIPv4AllPolicies(ctx context.Context, policy string) error {
}) })
} }
func (c *Config) AcceptInputThroughInterface(ctx context.Context, intf string, remove bool) error { func (c *Config) AcceptInputThroughInterface(ctx context.Context, intf string) error {
return c.runMixedIptablesInstruction(ctx, fmt.Sprintf( return c.runMixedIptablesInstruction(ctx, fmt.Sprintf(
"%s INPUT -i %s -j ACCEPT", appendOrDelete(remove), intf, "--append INPUT -i %s -j ACCEPT", intf))
))
} }
func (c *Config) AcceptInputToSubnet(ctx context.Context, intf string, func (c *Config) AcceptInputToSubnet(ctx context.Context, intf string, destination netip.Prefix) error {
destination netip.Prefix, remove bool,
) error {
interfaceFlag := "-i " + intf interfaceFlag := "-i " + intf
if intf == "*" { // all interfaces if intf == "*" { // all interfaces
interfaceFlag = "" interfaceFlag = ""
} }
instruction := fmt.Sprintf("%s INPUT %s -d %s -j ACCEPT", instruction := fmt.Sprintf("--append INPUT %s -d %s -j ACCEPT",
appendOrDelete(remove), interfaceFlag, destination.String()) interfaceFlag, destination.String())
if destination.Addr().Is4() { if destination.Addr().Is4() {
return c.runIptablesInstruction(ctx, instruction) return c.runIptablesInstruction(ctx, instruction)
@@ -140,10 +140,10 @@ func (c *Config) AcceptOutputThroughInterface(ctx context.Context, intf string,
)) ))
} }
func (c *Config) AcceptEstablishedRelatedTraffic(ctx context.Context, remove bool) error { func (c *Config) AcceptEstablishedRelatedTraffic(ctx context.Context) error {
return c.runMixedIptablesInstructions(ctx, []string{ return c.runMixedIptablesInstructions(ctx, []string{
fmt.Sprintf("%s OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", appendOrDelete(remove)), "--append OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT",
fmt.Sprintf("%s INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", appendOrDelete(remove)), "--append INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT",
}) })
} }
@@ -194,15 +194,12 @@ func (c *Config) AcceptOutputFromIPToSubnet(ctx context.Context,
// ff02::1:ff00:0/104, which is used for NDP (Neighbor Discovery Protocol) to resolve // ff02::1:ff00:0/104, which is used for NDP (Neighbor Discovery Protocol) to resolve
// IPv6 addresses to MAC addresses. If intf is empty, it is set to "*" which means // IPv6 addresses to MAC addresses. If intf is empty, it is set to "*" which means
// all interfaces. If remove is true, the rule is removed instead of added. // all interfaces. If remove is true, the rule is removed instead of added.
func (c *Config) AcceptIpv6MulticastOutput(ctx context.Context, func (c *Config) AcceptIpv6MulticastOutput(ctx context.Context, intf string) error {
intf string, remove bool,
) error {
interfaceFlag := "-o " + intf interfaceFlag := "-o " + intf
if intf == "*" { // all interfaces if intf == "*" { // all interfaces
interfaceFlag = "" interfaceFlag = ""
} }
instruction := fmt.Sprintf("%s OUTPUT %s -d ff02::1:ff00:0/104 -j ACCEPT", instruction := fmt.Sprintf("--append OUTPUT %s -d ff02::1:ff00:0/104 -j ACCEPT", interfaceFlag)
appendOrDelete(remove), interfaceFlag)
return c.runIP6tablesInstruction(ctx, instruction) return c.runIP6tablesInstruction(ctx, instruction)
} }
@@ -234,7 +231,17 @@ func (c *Config) RedirectPort(ctx context.Context, intf string,
interfaceFlag = "" interfaceFlag = ""
} }
err = c.runIptablesInstructions(ctx, []string{ c.iptablesMutex.Lock()
c.ip6tablesMutex.Lock()
defer c.iptablesMutex.Unlock()
defer c.ip6tablesMutex.Unlock()
restore, err := c.saveAndRestore(ctx)
if err != nil {
return err
}
err = c.runIptablesInstructionsNoSave(ctx, []string{
fmt.Sprintf("-t nat %s PREROUTING %s -p tcp --dport %d -j REDIRECT --to-ports %d", fmt.Sprintf("-t nat %s PREROUTING %s -p tcp --dport %d -j REDIRECT --to-ports %d",
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort), appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
fmt.Sprintf("%s INPUT %s -p tcp -m tcp --dport %d -j ACCEPT", fmt.Sprintf("%s INPUT %s -p tcp -m tcp --dport %d -j ACCEPT",
@@ -245,11 +252,12 @@ func (c *Config) RedirectPort(ctx context.Context, intf string,
appendOrDelete(remove), interfaceFlag, destinationPort), appendOrDelete(remove), interfaceFlag, destinationPort),
}) })
if err != nil { if err != nil {
restore(ctx)
return fmt.Errorf("redirecting IPv4 source port %d to destination port %d on interface %s: %w", return fmt.Errorf("redirecting IPv4 source port %d to destination port %d on interface %s: %w",
sourcePort, destinationPort, intf, err) sourcePort, destinationPort, intf, err)
} }
err = c.runIP6tablesInstructions(ctx, []string{ err = c.runIP6tablesInstructionsNoSave(ctx, []string{
fmt.Sprintf("-t nat %s PREROUTING %s -p tcp --dport %d -j REDIRECT --to-ports %d", fmt.Sprintf("-t nat %s PREROUTING %s -p tcp --dport %d -j REDIRECT --to-ports %d",
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort), appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
fmt.Sprintf("%s INPUT %s -p tcp -m tcp --dport %d -j ACCEPT", fmt.Sprintf("%s INPUT %s -p tcp -m tcp --dport %d -j ACCEPT",
@@ -260,6 +268,7 @@ func (c *Config) RedirectPort(ctx context.Context, intf string,
appendOrDelete(remove), interfaceFlag, destinationPort), appendOrDelete(remove), interfaceFlag, destinationPort),
}) })
if err != nil { if err != nil {
restore(ctx) // just in case
errMessage := err.Error() errMessage := err.Error()
if strings.Contains(errMessage, "can't initialize ip6tables table `nat': Table does not exist") { if strings.Contains(errMessage, "can't initialize ip6tables table `nat': Table does not exist") {
if !remove { if !remove {
@@ -273,7 +282,7 @@ func (c *Config) RedirectPort(ctx context.Context, intf string,
return nil return nil
} }
func (c *Config) RunUserPostRules(ctx context.Context, filepath string, remove bool) error { func (c *Config) RunUserPostRules(ctx context.Context, filepath string) error {
file, err := os.OpenFile(filepath, os.O_RDONLY, 0) file, err := os.OpenFile(filepath, os.O_RDONLY, 0)
if os.IsNotExist(err) { if os.IsNotExist(err) {
return nil return nil
@@ -289,16 +298,17 @@ func (c *Config) RunUserPostRules(ctx context.Context, filepath string, remove b
return err return err
} }
lines := strings.Split(string(b), "\n") lines := strings.Split(string(b), "\n")
successfulRules := []string{}
defer func() { c.iptablesMutex.Lock()
// transaction-like rollback c.ip6tablesMutex.Lock()
if err == nil || ctx.Err() != nil { defer c.iptablesMutex.Unlock()
return defer c.ip6tablesMutex.Unlock()
}
for _, rule := range successfulRules { restore, err := c.saveAndRestore(ctx)
_ = c.runIptablesInstruction(ctx, flipRule(rule)) if err != nil {
} return err
}() }
for _, line := range lines { for _, line := range lines {
var ipv4 bool var ipv4 bool
var rule string var rule string
@@ -325,10 +335,6 @@ func (c *Config) RunUserPostRules(ctx context.Context, filepath string, remove b
continue continue
} }
if remove {
rule = flipRule(rule)
}
switch { switch {
case ipv4: case ipv4:
err = c.runIptablesInstruction(ctx, rule) err = c.runIptablesInstruction(ctx, rule)
@@ -338,10 +344,9 @@ func (c *Config) RunUserPostRules(ctx context.Context, filepath string, remove b
err = c.runIP6tablesInstruction(ctx, rule) err = c.runIP6tablesInstruction(ctx, rule)
} }
if err != nil { if err != nil {
restore(ctx)
return err return err
} }
successfulRules = append(successfulRules, rule)
} }
return nil return nil
} }
+31 -3
View File
@@ -5,8 +5,19 @@ import (
) )
func (c *Config) runMixedIptablesInstructions(ctx context.Context, instructions []string) error { func (c *Config) runMixedIptablesInstructions(ctx context.Context, instructions []string) error {
c.iptablesMutex.Lock()
c.ip6tablesMutex.Lock()
defer c.iptablesMutex.Unlock()
defer c.ip6tablesMutex.Unlock()
restore, err := c.saveAndRestore(ctx)
if err != nil {
return err
}
for _, instruction := range instructions { for _, instruction := range instructions {
if err := c.runMixedIptablesInstruction(ctx, instruction); err != nil { if err := c.runMixedIptablesInstructionNoSave(ctx, instruction); err != nil {
restore(ctx)
return err return err
} }
} }
@@ -14,8 +25,25 @@ func (c *Config) runMixedIptablesInstructions(ctx context.Context, instructions
} }
func (c *Config) runMixedIptablesInstruction(ctx context.Context, instruction string) error { func (c *Config) runMixedIptablesInstruction(ctx context.Context, instruction string) error {
if err := c.runIptablesInstruction(ctx, instruction); err != nil { c.iptablesMutex.Lock()
c.ip6tablesMutex.Lock()
defer c.iptablesMutex.Unlock()
defer c.ip6tablesMutex.Unlock()
restore, err := c.saveAndRestore(ctx)
if err != nil {
return err return err
} }
return c.runIP6tablesInstruction(ctx, instruction) err = c.runIptablesInstructionNoSave(ctx, instruction)
if err != nil {
restore(ctx)
}
return err
}
func (c *Config) runMixedIptablesInstructionNoSave(ctx context.Context, instruction string) error {
if err := c.runIptablesInstructionNoSave(ctx, instruction); err != nil {
return err
}
return c.runIP6tablesInstructionNoSave(ctx, instruction)
} }