feat(firewall): atomic iptables operations

- all operations rollback on failure
- disabling the firewall means rolling back to its state before enabling it
- aligns with nftables atomicity feature
This commit is contained in:
Quentin McGaw
2026-02-26 22:58:52 +00:00
parent 0d0c0fb143
commit 2bb4deccd5
7 changed files with 238 additions and 117 deletions
+85
View File
@@ -0,0 +1,85 @@
package iptables
import (
"context"
"fmt"
"os/exec"
"strings"
)
// SaveAndRestore saves the current iptables and ip6tables rules and
// returns a restore function that can be called to restore the saved rules.
func (c *Config) SaveAndRestore(ctx context.Context) (restore func(context.Context), err error) {
c.iptablesMutex.Lock()
c.ip6tablesMutex.Lock()
defer c.iptablesMutex.Unlock()
defer c.ip6tablesMutex.Unlock()
return c.saveAndRestore(ctx)
}
// callers MUST always lock both the [Config] iptablesMutex and the ip6tablesMutex
// before calling this function. Note the restore function does not interact with mutexes
// so the caller must make sure the mutexes are locked when calling the restore function.
func (c *Config) saveAndRestore(ctx context.Context) (restore func(context.Context), err error) {
restoreIPv4, err := c.saveAndRestoreIPv4(ctx)
if err != nil {
return nil, err
}
restoreIPv6, err := c.saveAndRestoreIPv6(ctx)
if err != nil {
return nil, err
}
restore = func(ctx context.Context) {
restoreIPv4(ctx)
if restoreIPv6 != nil {
restoreIPv6(ctx)
}
}
return restore, nil
}
// Callers of saveAndRestoreIPv4 MUST always lock the [Config] iptablesMutex
// before calling this function.
func (c *Config) saveAndRestoreIPv4(ctx context.Context) (restore func(context.Context), err error) {
cmd := exec.CommandContext(ctx, c.ipTables+"-save") //nolint:gosec
data, err := c.runner.Run(cmd)
if err != nil {
return nil, fmt.Errorf("saving IPv4 iptables: %w", err)
}
restore = func(ctx context.Context) {
cmd := exec.CommandContext(ctx, c.ipTables+"-restore") //nolint:gosec
cmd.Stdin = strings.NewReader(data)
output, err := c.runner.Run(cmd)
if err != nil {
c.logger.Warn(fmt.Sprintf("restoring IPv4 iptables failed: %v: %s", err, output))
}
}
return restore, nil
}
// Callers of saveAndRestoreIPv6 MUST always lock the [Config] ip6tablesMutex
// before calling this function.
func (c *Config) saveAndRestoreIPv6(ctx context.Context) (restore func(context.Context), err error) {
if c.ip6Tables == "" {
return nil, nil //nolint:nilnil
}
cmd := exec.CommandContext(ctx, c.ip6Tables+"-save") //nolint:gosec
data, err := c.runner.Run(cmd)
if err != nil {
return nil, fmt.Errorf("saving IPv6 iptables: %w", err)
}
restore = func(ctx context.Context) {
cmd = exec.CommandContext(ctx, c.ip6Tables+"-restore") //nolint:gosec
cmd.Stdin = strings.NewReader(data)
output, err := c.runner.Run(cmd)
if err != nil {
c.logger.Warn(fmt.Sprintf("restoring IPv6 iptables failed: %v: %s", err, output))
}
}
return restore, nil
}
+31 -3
View File
@@ -24,8 +24,23 @@ func findIP6tablesSupported(ctx context.Context, runner CmdRunner) (
}
func (c *Config) runIP6tablesInstructions(ctx context.Context, instructions []string) error {
c.ip6tablesMutex.Lock() // only one ip6tables command at once
defer c.ip6tablesMutex.Unlock()
restore, err := c.saveAndRestoreIPv6(ctx)
if err != nil {
return err
}
err = c.runIP6tablesInstructionsNoSave(ctx, instructions)
if err != nil {
restore(ctx)
}
return err
}
func (c *Config) runIP6tablesInstructionsNoSave(ctx context.Context, instructions []string) error {
for _, instruction := range instructions {
if err := c.runIP6tablesInstruction(ctx, instruction); err != nil {
if err := c.runIP6tablesInstructionNoSave(ctx, instruction); err != nil {
return err
}
}
@@ -33,11 +48,24 @@ func (c *Config) runIP6tablesInstructions(ctx context.Context, instructions []st
}
func (c *Config) runIP6tablesInstruction(ctx context.Context, instruction string) error {
c.ip6tablesMutex.Lock() // only one ip6tables command at once
defer c.ip6tablesMutex.Unlock()
restore, err := c.saveAndRestoreIPv6(ctx)
if err != nil {
return err
}
err = c.runIP6tablesInstructionNoSave(ctx, instruction)
if err != nil {
restore(ctx)
}
return err
}
func (c *Config) runIP6tablesInstructionNoSave(ctx context.Context, instruction string) error {
if c.ip6Tables == "" {
return nil
}
c.ip6tablesMutex.Lock() // only one ip6tables command at once
defer c.ip6tablesMutex.Unlock()
if isDeleteMatchInstruction(instruction) {
return deleteIPTablesRule(ctx, c.ip6Tables, instruction,
+67 -62
View File
@@ -26,21 +26,6 @@ func appendOrDelete(remove bool) string {
return "--append"
}
// flipRule changes an append rule in a delete rule or a delete rule into an
// append rule.
func flipRule(rule string) string {
fields := strings.Fields(rule)
for i, field := range fields {
switch field {
case "-A", "--append":
fields[i] = "--delete"
case "-D", "--delete":
fields[i] = "--append"
}
}
return strings.Join(fields, " ")
}
// Version obtains the version of the installed iptables.
func (c *Config) Version(ctx context.Context) (string, error) {
cmd := exec.CommandContext(ctx, c.ipTables, "--version") //nolint:gosec
@@ -57,8 +42,24 @@ func (c *Config) Version(ctx context.Context) (string, error) {
}
func (c *Config) runIptablesInstructions(ctx context.Context, instructions []string) error {
c.iptablesMutex.Lock()
defer c.iptablesMutex.Unlock()
restore, err := c.saveAndRestoreIPv4(ctx)
if err != nil {
return err
}
err = c.runIptablesInstructionsNoSave(ctx, instructions)
if err != nil {
restore(ctx)
}
return err
}
func (c *Config) runIptablesInstructionsNoSave(ctx context.Context, instructions []string) error {
for _, instruction := range instructions {
if err := c.runIptablesInstruction(ctx, instruction); err != nil {
if err := c.runIptablesInstructionNoSave(ctx, instruction); err != nil {
return err
}
}
@@ -69,6 +70,19 @@ func (c *Config) runIptablesInstruction(ctx context.Context, instruction string)
c.iptablesMutex.Lock() // only one iptables command at once
defer c.iptablesMutex.Unlock()
restore, err := c.saveAndRestoreIPv4(ctx)
if err != nil {
return err
}
err = c.runIptablesInstructionNoSave(ctx, instruction)
if err != nil {
restore(ctx)
}
return err
}
func (c *Config) runIptablesInstructionNoSave(ctx context.Context, instruction string) error {
if isDeleteMatchInstruction(instruction) {
return deleteIPTablesRule(ctx, c.ipTables, instruction,
c.runner, c.logger)
@@ -84,17 +98,6 @@ func (c *Config) runIptablesInstruction(ctx context.Context, instruction string)
return nil
}
func (c *Config) ClearAllRules(ctx context.Context) error {
tables := []string{"filter"}
for _, table := range tables {
return c.runMixedIptablesInstructions(ctx, []string{
"-t " + table + " --flush", // flush all chains
"-t " + table + " --delete-chain", // delete all chains
})
}
return nil
}
func (c *Config) SetIPv4AllPolicies(ctx context.Context, policy string) error {
switch policy {
case "ACCEPT", "DROP":
@@ -108,22 +111,19 @@ func (c *Config) SetIPv4AllPolicies(ctx context.Context, policy string) error {
})
}
func (c *Config) AcceptInputThroughInterface(ctx context.Context, intf string, remove bool) error {
func (c *Config) AcceptInputThroughInterface(ctx context.Context, intf string) error {
return c.runMixedIptablesInstruction(ctx, fmt.Sprintf(
"%s INPUT -i %s -j ACCEPT", appendOrDelete(remove), intf,
))
"--append INPUT -i %s -j ACCEPT", intf))
}
func (c *Config) AcceptInputToSubnet(ctx context.Context, intf string,
destination netip.Prefix, remove bool,
) error {
func (c *Config) AcceptInputToSubnet(ctx context.Context, intf string, destination netip.Prefix) error {
interfaceFlag := "-i " + intf
if intf == "*" { // all interfaces
interfaceFlag = ""
}
instruction := fmt.Sprintf("%s INPUT %s -d %s -j ACCEPT",
appendOrDelete(remove), interfaceFlag, destination.String())
instruction := fmt.Sprintf("--append INPUT %s -d %s -j ACCEPT",
interfaceFlag, destination.String())
if destination.Addr().Is4() {
return c.runIptablesInstruction(ctx, instruction)
@@ -140,10 +140,10 @@ func (c *Config) AcceptOutputThroughInterface(ctx context.Context, intf string,
))
}
func (c *Config) AcceptEstablishedRelatedTraffic(ctx context.Context, remove bool) error {
func (c *Config) AcceptEstablishedRelatedTraffic(ctx context.Context) error {
return c.runMixedIptablesInstructions(ctx, []string{
fmt.Sprintf("%s OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", appendOrDelete(remove)),
fmt.Sprintf("%s INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", appendOrDelete(remove)),
"--append OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT",
"--append INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT",
})
}
@@ -194,15 +194,12 @@ func (c *Config) AcceptOutputFromIPToSubnet(ctx context.Context,
// ff02::1:ff00:0/104, which is used for NDP (Neighbor Discovery Protocol) to resolve
// IPv6 addresses to MAC addresses. If intf is empty, it is set to "*" which means
// all interfaces. If remove is true, the rule is removed instead of added.
func (c *Config) AcceptIpv6MulticastOutput(ctx context.Context,
intf string, remove bool,
) error {
func (c *Config) AcceptIpv6MulticastOutput(ctx context.Context, intf string) error {
interfaceFlag := "-o " + intf
if intf == "*" { // all interfaces
interfaceFlag = ""
}
instruction := fmt.Sprintf("%s OUTPUT %s -d ff02::1:ff00:0/104 -j ACCEPT",
appendOrDelete(remove), interfaceFlag)
instruction := fmt.Sprintf("--append OUTPUT %s -d ff02::1:ff00:0/104 -j ACCEPT", interfaceFlag)
return c.runIP6tablesInstruction(ctx, instruction)
}
@@ -234,7 +231,17 @@ func (c *Config) RedirectPort(ctx context.Context, intf string,
interfaceFlag = ""
}
err = c.runIptablesInstructions(ctx, []string{
c.iptablesMutex.Lock()
c.ip6tablesMutex.Lock()
defer c.iptablesMutex.Unlock()
defer c.ip6tablesMutex.Unlock()
restore, err := c.saveAndRestore(ctx)
if err != nil {
return err
}
err = c.runIptablesInstructionsNoSave(ctx, []string{
fmt.Sprintf("-t nat %s PREROUTING %s -p tcp --dport %d -j REDIRECT --to-ports %d",
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
fmt.Sprintf("%s INPUT %s -p tcp -m tcp --dport %d -j ACCEPT",
@@ -245,11 +252,12 @@ func (c *Config) RedirectPort(ctx context.Context, intf string,
appendOrDelete(remove), interfaceFlag, destinationPort),
})
if err != nil {
restore(ctx)
return fmt.Errorf("redirecting IPv4 source port %d to destination port %d on interface %s: %w",
sourcePort, destinationPort, intf, err)
}
err = c.runIP6tablesInstructions(ctx, []string{
err = c.runIP6tablesInstructionsNoSave(ctx, []string{
fmt.Sprintf("-t nat %s PREROUTING %s -p tcp --dport %d -j REDIRECT --to-ports %d",
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
fmt.Sprintf("%s INPUT %s -p tcp -m tcp --dport %d -j ACCEPT",
@@ -260,6 +268,7 @@ func (c *Config) RedirectPort(ctx context.Context, intf string,
appendOrDelete(remove), interfaceFlag, destinationPort),
})
if err != nil {
restore(ctx) // just in case
errMessage := err.Error()
if strings.Contains(errMessage, "can't initialize ip6tables table `nat': Table does not exist") {
if !remove {
@@ -273,7 +282,7 @@ func (c *Config) RedirectPort(ctx context.Context, intf string,
return nil
}
func (c *Config) RunUserPostRules(ctx context.Context, filepath string, remove bool) error {
func (c *Config) RunUserPostRules(ctx context.Context, filepath string) error {
file, err := os.OpenFile(filepath, os.O_RDONLY, 0)
if os.IsNotExist(err) {
return nil
@@ -289,16 +298,17 @@ func (c *Config) RunUserPostRules(ctx context.Context, filepath string, remove b
return err
}
lines := strings.Split(string(b), "\n")
successfulRules := []string{}
defer func() {
// transaction-like rollback
if err == nil || ctx.Err() != nil {
return
}
for _, rule := range successfulRules {
_ = c.runIptablesInstruction(ctx, flipRule(rule))
}
}()
c.iptablesMutex.Lock()
c.ip6tablesMutex.Lock()
defer c.iptablesMutex.Unlock()
defer c.ip6tablesMutex.Unlock()
restore, err := c.saveAndRestore(ctx)
if err != nil {
return err
}
for _, line := range lines {
var ipv4 bool
var rule string
@@ -325,10 +335,6 @@ func (c *Config) RunUserPostRules(ctx context.Context, filepath string, remove b
continue
}
if remove {
rule = flipRule(rule)
}
switch {
case ipv4:
err = c.runIptablesInstruction(ctx, rule)
@@ -338,10 +344,9 @@ func (c *Config) RunUserPostRules(ctx context.Context, filepath string, remove b
err = c.runIP6tablesInstruction(ctx, rule)
}
if err != nil {
restore(ctx)
return err
}
successfulRules = append(successfulRules, rule)
}
return nil
}
+31 -3
View File
@@ -5,8 +5,19 @@ import (
)
func (c *Config) runMixedIptablesInstructions(ctx context.Context, instructions []string) error {
c.iptablesMutex.Lock()
c.ip6tablesMutex.Lock()
defer c.iptablesMutex.Unlock()
defer c.ip6tablesMutex.Unlock()
restore, err := c.saveAndRestore(ctx)
if err != nil {
return err
}
for _, instruction := range instructions {
if err := c.runMixedIptablesInstruction(ctx, instruction); err != nil {
if err := c.runMixedIptablesInstructionNoSave(ctx, instruction); err != nil {
restore(ctx)
return err
}
}
@@ -14,8 +25,25 @@ func (c *Config) runMixedIptablesInstructions(ctx context.Context, instructions
}
func (c *Config) runMixedIptablesInstruction(ctx context.Context, instruction string) error {
if err := c.runIptablesInstruction(ctx, instruction); err != nil {
c.iptablesMutex.Lock()
c.ip6tablesMutex.Lock()
defer c.iptablesMutex.Unlock()
defer c.ip6tablesMutex.Unlock()
restore, err := c.saveAndRestore(ctx)
if err != nil {
return err
}
return c.runIP6tablesInstruction(ctx, instruction)
err = c.runIptablesInstructionNoSave(ctx, instruction)
if err != nil {
restore(ctx)
}
return err
}
func (c *Config) runMixedIptablesInstructionNoSave(ctx context.Context, instruction string) error {
if err := c.runIptablesInstructionNoSave(ctx, instruction); err != nil {
return err
}
return c.runIP6tablesInstructionNoSave(ctx, instruction)
}