mirror of
https://github.com/qdm12/gluetun.git
synced 2026-05-06 20:10:11 +02:00
feat(firewall): atomic iptables operations
- all operations rollback on failure - disabling the firewall means rolling back to its state before enabling it - aligns with nftables atomicity feature
This commit is contained in:
+17
-43
@@ -22,9 +22,7 @@ func (c *Config) SetEnabled(ctx context.Context, enabled bool) (err error) {
|
||||
|
||||
if !enabled {
|
||||
c.logger.Info("disabling...")
|
||||
if err = c.disable(ctx); err != nil {
|
||||
return fmt.Errorf("disabling firewall: %w", err)
|
||||
}
|
||||
c.restore(ctx)
|
||||
c.enabled = false
|
||||
c.logger.Info("disabled successfully")
|
||||
return nil
|
||||
@@ -41,37 +39,12 @@ func (c *Config) SetEnabled(ctx context.Context, enabled bool) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) disable(ctx context.Context) (err error) {
|
||||
if err = c.impl.ClearAllRules(ctx); err != nil {
|
||||
return fmt.Errorf("clearing all rules: %w", err)
|
||||
}
|
||||
if err = c.impl.SetIPv4AllPolicies(ctx, "ACCEPT"); err != nil {
|
||||
return fmt.Errorf("setting ipv4 policies: %w", err)
|
||||
}
|
||||
if err = c.impl.SetIPv6AllPolicies(ctx, "ACCEPT"); err != nil {
|
||||
return fmt.Errorf("setting ipv6 policies: %w", err)
|
||||
}
|
||||
|
||||
const remove = true
|
||||
err = c.redirectPorts(ctx, remove)
|
||||
if err != nil {
|
||||
return fmt.Errorf("removing port redirections: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// To use in defered call when enabling the firewall.
|
||||
func (c *Config) fallbackToDisabled(ctx context.Context) {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
if err := c.disable(ctx); err != nil {
|
||||
c.logger.Error("failed reversing firewall changes: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Config) enable(ctx context.Context) (err error) {
|
||||
c.restore, err = c.impl.SaveAndRestore(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("saving firewall rules: %w", err)
|
||||
}
|
||||
|
||||
if err = c.impl.SetIPv4AllPolicies(ctx, "DROP"); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -82,21 +55,21 @@ func (c *Config) enable(ctx context.Context) (err error) {
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
c.fallbackToDisabled(ctx)
|
||||
c.restore(context.Background())
|
||||
}
|
||||
}()
|
||||
|
||||
const remove = false
|
||||
|
||||
// Loopback traffic
|
||||
if err = c.impl.AcceptInputThroughInterface(ctx, "lo", remove); err != nil {
|
||||
if err = c.impl.AcceptInputThroughInterface(ctx, "lo"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
const remove = false
|
||||
if err = c.impl.AcceptOutputThroughInterface(ctx, "lo", remove); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = c.impl.AcceptEstablishedRelatedTraffic(ctx, remove); err != nil {
|
||||
if err = c.impl.AcceptEstablishedRelatedTraffic(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -117,7 +90,7 @@ func (c *Config) enable(ctx context.Context) (err error) {
|
||||
continue
|
||||
}
|
||||
localInterfaces[network.InterfaceName] = struct{}{}
|
||||
err = c.impl.AcceptIpv6MulticastOutput(ctx, network.InterfaceName, remove)
|
||||
err = c.impl.AcceptIpv6MulticastOutput(ctx, network.InterfaceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("accepting IPv6 multicast output: %w", err)
|
||||
}
|
||||
@@ -130,7 +103,7 @@ func (c *Config) enable(ctx context.Context) (err error) {
|
||||
// Allows packets from any IP address to go through eth0 / local network
|
||||
// to reach Gluetun.
|
||||
for _, network := range c.localNetworks {
|
||||
if err := c.impl.AcceptInputToSubnet(ctx, network.InterfaceName, network.IPNet, remove); err != nil {
|
||||
if err := c.impl.AcceptInputToSubnet(ctx, network.InterfaceName, network.IPNet); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -139,12 +112,12 @@ func (c *Config) enable(ctx context.Context) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
err = c.redirectPorts(ctx, remove)
|
||||
err = c.redirectPorts(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("redirecting ports: %w", err)
|
||||
}
|
||||
|
||||
if err := c.impl.RunUserPostRules(ctx, c.customRulesPath, remove); err != nil {
|
||||
if err := c.impl.RunUserPostRules(ctx, c.customRulesPath); err != nil {
|
||||
return fmt.Errorf("running user defined post firewall rules: %w", err)
|
||||
}
|
||||
|
||||
@@ -214,8 +187,9 @@ func (c *Config) allowInputPorts(ctx context.Context) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) redirectPorts(ctx context.Context, remove bool) (err error) {
|
||||
func (c *Config) redirectPorts(ctx context.Context) (err error) {
|
||||
for _, portRedirection := range c.portRedirections {
|
||||
const remove = false
|
||||
err = c.impl.RedirectPort(ctx, portRedirection.interfaceName, portRedirection.sourcePort,
|
||||
portRedirection.destinationPort, remove)
|
||||
if err != nil {
|
||||
|
||||
@@ -23,6 +23,7 @@ type Config struct {
|
||||
|
||||
// State
|
||||
enabled bool
|
||||
restore func(context.Context)
|
||||
vpnConnection models.Connection
|
||||
vpnIntf string
|
||||
outboundSubnets []netip.Prefix
|
||||
|
||||
@@ -20,20 +20,20 @@ type Logger interface {
|
||||
}
|
||||
|
||||
type firewallImpl interface { //nolint:interfacebloat
|
||||
AcceptEstablishedRelatedTraffic(ctx context.Context, remove bool) error
|
||||
AcceptInputThroughInterface(ctx context.Context, intf string, remove bool) error
|
||||
SaveAndRestore(ctx context.Context) (restore func(context.Context), err error)
|
||||
AcceptEstablishedRelatedTraffic(ctx context.Context) error
|
||||
AcceptInputThroughInterface(ctx context.Context, intf string) error
|
||||
AcceptInputToPort(ctx context.Context, intf string, port uint16, remove bool) error
|
||||
AcceptInputToSubnet(ctx context.Context, intf string, subnet netip.Prefix, remove bool) error
|
||||
AcceptIpv6MulticastOutput(ctx context.Context, intf string, remove bool) error
|
||||
AcceptInputToSubnet(ctx context.Context, intf string, subnet netip.Prefix) error
|
||||
AcceptIpv6MulticastOutput(ctx context.Context, intf string) error
|
||||
AcceptOutputFromIPToSubnet(ctx context.Context, intf string, assignedIP netip.Addr,
|
||||
subnet netip.Prefix, remove bool) error
|
||||
AcceptOutputThroughInterface(ctx context.Context, intf string, remove bool) error
|
||||
AcceptOutputTrafficToVPN(ctx context.Context, intf string,
|
||||
connection models.Connection, remove bool) error
|
||||
ClearAllRules(ctx context.Context) error
|
||||
RedirectPort(ctx context.Context, intf string, sourcePort,
|
||||
destinationPort uint16, remove bool) error
|
||||
RunUserPostRules(ctx context.Context, customRulesPath string, remove bool) error
|
||||
RunUserPostRules(ctx context.Context, customRulesPath string) error
|
||||
SetIPv4AllPolicies(ctx context.Context, policy string) error
|
||||
SetIPv6AllPolicies(ctx context.Context, policy string) error
|
||||
TempDropOutputTCPRST(ctx context.Context, src, dst netip.AddrPort, excludeMark int) (
|
||||
|
||||
@@ -0,0 +1,85 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// SaveAndRestore saves the current iptables and ip6tables rules and
|
||||
// returns a restore function that can be called to restore the saved rules.
|
||||
func (c *Config) SaveAndRestore(ctx context.Context) (restore func(context.Context), err error) {
|
||||
c.iptablesMutex.Lock()
|
||||
c.ip6tablesMutex.Lock()
|
||||
defer c.iptablesMutex.Unlock()
|
||||
defer c.ip6tablesMutex.Unlock()
|
||||
|
||||
return c.saveAndRestore(ctx)
|
||||
}
|
||||
|
||||
// callers MUST always lock both the [Config] iptablesMutex and the ip6tablesMutex
|
||||
// before calling this function. Note the restore function does not interact with mutexes
|
||||
// so the caller must make sure the mutexes are locked when calling the restore function.
|
||||
func (c *Config) saveAndRestore(ctx context.Context) (restore func(context.Context), err error) {
|
||||
restoreIPv4, err := c.saveAndRestoreIPv4(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
restoreIPv6, err := c.saveAndRestoreIPv6(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
restore = func(ctx context.Context) {
|
||||
restoreIPv4(ctx)
|
||||
if restoreIPv6 != nil {
|
||||
restoreIPv6(ctx)
|
||||
}
|
||||
}
|
||||
return restore, nil
|
||||
}
|
||||
|
||||
// Callers of saveAndRestoreIPv4 MUST always lock the [Config] iptablesMutex
|
||||
// before calling this function.
|
||||
func (c *Config) saveAndRestoreIPv4(ctx context.Context) (restore func(context.Context), err error) {
|
||||
cmd := exec.CommandContext(ctx, c.ipTables+"-save") //nolint:gosec
|
||||
data, err := c.runner.Run(cmd)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("saving IPv4 iptables: %w", err)
|
||||
}
|
||||
|
||||
restore = func(ctx context.Context) {
|
||||
cmd := exec.CommandContext(ctx, c.ipTables+"-restore") //nolint:gosec
|
||||
cmd.Stdin = strings.NewReader(data)
|
||||
output, err := c.runner.Run(cmd)
|
||||
if err != nil {
|
||||
c.logger.Warn(fmt.Sprintf("restoring IPv4 iptables failed: %v: %s", err, output))
|
||||
}
|
||||
}
|
||||
return restore, nil
|
||||
}
|
||||
|
||||
// Callers of saveAndRestoreIPv6 MUST always lock the [Config] ip6tablesMutex
|
||||
// before calling this function.
|
||||
func (c *Config) saveAndRestoreIPv6(ctx context.Context) (restore func(context.Context), err error) {
|
||||
if c.ip6Tables == "" {
|
||||
return nil, nil //nolint:nilnil
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, c.ip6Tables+"-save") //nolint:gosec
|
||||
data, err := c.runner.Run(cmd)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("saving IPv6 iptables: %w", err)
|
||||
}
|
||||
|
||||
restore = func(ctx context.Context) {
|
||||
cmd = exec.CommandContext(ctx, c.ip6Tables+"-restore") //nolint:gosec
|
||||
cmd.Stdin = strings.NewReader(data)
|
||||
output, err := c.runner.Run(cmd)
|
||||
if err != nil {
|
||||
c.logger.Warn(fmt.Sprintf("restoring IPv6 iptables failed: %v: %s", err, output))
|
||||
}
|
||||
}
|
||||
return restore, nil
|
||||
}
|
||||
@@ -24,8 +24,23 @@ func findIP6tablesSupported(ctx context.Context, runner CmdRunner) (
|
||||
}
|
||||
|
||||
func (c *Config) runIP6tablesInstructions(ctx context.Context, instructions []string) error {
|
||||
c.ip6tablesMutex.Lock() // only one ip6tables command at once
|
||||
defer c.ip6tablesMutex.Unlock()
|
||||
|
||||
restore, err := c.saveAndRestoreIPv6(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = c.runIP6tablesInstructionsNoSave(ctx, instructions)
|
||||
if err != nil {
|
||||
restore(ctx)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Config) runIP6tablesInstructionsNoSave(ctx context.Context, instructions []string) error {
|
||||
for _, instruction := range instructions {
|
||||
if err := c.runIP6tablesInstruction(ctx, instruction); err != nil {
|
||||
if err := c.runIP6tablesInstructionNoSave(ctx, instruction); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -33,11 +48,24 @@ func (c *Config) runIP6tablesInstructions(ctx context.Context, instructions []st
|
||||
}
|
||||
|
||||
func (c *Config) runIP6tablesInstruction(ctx context.Context, instruction string) error {
|
||||
c.ip6tablesMutex.Lock() // only one ip6tables command at once
|
||||
defer c.ip6tablesMutex.Unlock()
|
||||
|
||||
restore, err := c.saveAndRestoreIPv6(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = c.runIP6tablesInstructionNoSave(ctx, instruction)
|
||||
if err != nil {
|
||||
restore(ctx)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Config) runIP6tablesInstructionNoSave(ctx context.Context, instruction string) error {
|
||||
if c.ip6Tables == "" {
|
||||
return nil
|
||||
}
|
||||
c.ip6tablesMutex.Lock() // only one ip6tables command at once
|
||||
defer c.ip6tablesMutex.Unlock()
|
||||
|
||||
if isDeleteMatchInstruction(instruction) {
|
||||
return deleteIPTablesRule(ctx, c.ip6Tables, instruction,
|
||||
|
||||
@@ -26,21 +26,6 @@ func appendOrDelete(remove bool) string {
|
||||
return "--append"
|
||||
}
|
||||
|
||||
// flipRule changes an append rule in a delete rule or a delete rule into an
|
||||
// append rule.
|
||||
func flipRule(rule string) string {
|
||||
fields := strings.Fields(rule)
|
||||
for i, field := range fields {
|
||||
switch field {
|
||||
case "-A", "--append":
|
||||
fields[i] = "--delete"
|
||||
case "-D", "--delete":
|
||||
fields[i] = "--append"
|
||||
}
|
||||
}
|
||||
return strings.Join(fields, " ")
|
||||
}
|
||||
|
||||
// Version obtains the version of the installed iptables.
|
||||
func (c *Config) Version(ctx context.Context) (string, error) {
|
||||
cmd := exec.CommandContext(ctx, c.ipTables, "--version") //nolint:gosec
|
||||
@@ -57,8 +42,24 @@ func (c *Config) Version(ctx context.Context) (string, error) {
|
||||
}
|
||||
|
||||
func (c *Config) runIptablesInstructions(ctx context.Context, instructions []string) error {
|
||||
c.iptablesMutex.Lock()
|
||||
defer c.iptablesMutex.Unlock()
|
||||
|
||||
restore, err := c.saveAndRestoreIPv4(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = c.runIptablesInstructionsNoSave(ctx, instructions)
|
||||
if err != nil {
|
||||
restore(ctx)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Config) runIptablesInstructionsNoSave(ctx context.Context, instructions []string) error {
|
||||
for _, instruction := range instructions {
|
||||
if err := c.runIptablesInstruction(ctx, instruction); err != nil {
|
||||
if err := c.runIptablesInstructionNoSave(ctx, instruction); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -69,6 +70,19 @@ func (c *Config) runIptablesInstruction(ctx context.Context, instruction string)
|
||||
c.iptablesMutex.Lock() // only one iptables command at once
|
||||
defer c.iptablesMutex.Unlock()
|
||||
|
||||
restore, err := c.saveAndRestoreIPv4(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = c.runIptablesInstructionNoSave(ctx, instruction)
|
||||
if err != nil {
|
||||
restore(ctx)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Config) runIptablesInstructionNoSave(ctx context.Context, instruction string) error {
|
||||
if isDeleteMatchInstruction(instruction) {
|
||||
return deleteIPTablesRule(ctx, c.ipTables, instruction,
|
||||
c.runner, c.logger)
|
||||
@@ -84,17 +98,6 @@ func (c *Config) runIptablesInstruction(ctx context.Context, instruction string)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) ClearAllRules(ctx context.Context) error {
|
||||
tables := []string{"filter"}
|
||||
for _, table := range tables {
|
||||
return c.runMixedIptablesInstructions(ctx, []string{
|
||||
"-t " + table + " --flush", // flush all chains
|
||||
"-t " + table + " --delete-chain", // delete all chains
|
||||
})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) SetIPv4AllPolicies(ctx context.Context, policy string) error {
|
||||
switch policy {
|
||||
case "ACCEPT", "DROP":
|
||||
@@ -108,22 +111,19 @@ func (c *Config) SetIPv4AllPolicies(ctx context.Context, policy string) error {
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Config) AcceptInputThroughInterface(ctx context.Context, intf string, remove bool) error {
|
||||
func (c *Config) AcceptInputThroughInterface(ctx context.Context, intf string) error {
|
||||
return c.runMixedIptablesInstruction(ctx, fmt.Sprintf(
|
||||
"%s INPUT -i %s -j ACCEPT", appendOrDelete(remove), intf,
|
||||
))
|
||||
"--append INPUT -i %s -j ACCEPT", intf))
|
||||
}
|
||||
|
||||
func (c *Config) AcceptInputToSubnet(ctx context.Context, intf string,
|
||||
destination netip.Prefix, remove bool,
|
||||
) error {
|
||||
func (c *Config) AcceptInputToSubnet(ctx context.Context, intf string, destination netip.Prefix) error {
|
||||
interfaceFlag := "-i " + intf
|
||||
if intf == "*" { // all interfaces
|
||||
interfaceFlag = ""
|
||||
}
|
||||
|
||||
instruction := fmt.Sprintf("%s INPUT %s -d %s -j ACCEPT",
|
||||
appendOrDelete(remove), interfaceFlag, destination.String())
|
||||
instruction := fmt.Sprintf("--append INPUT %s -d %s -j ACCEPT",
|
||||
interfaceFlag, destination.String())
|
||||
|
||||
if destination.Addr().Is4() {
|
||||
return c.runIptablesInstruction(ctx, instruction)
|
||||
@@ -140,10 +140,10 @@ func (c *Config) AcceptOutputThroughInterface(ctx context.Context, intf string,
|
||||
))
|
||||
}
|
||||
|
||||
func (c *Config) AcceptEstablishedRelatedTraffic(ctx context.Context, remove bool) error {
|
||||
func (c *Config) AcceptEstablishedRelatedTraffic(ctx context.Context) error {
|
||||
return c.runMixedIptablesInstructions(ctx, []string{
|
||||
fmt.Sprintf("%s OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", appendOrDelete(remove)),
|
||||
fmt.Sprintf("%s INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", appendOrDelete(remove)),
|
||||
"--append OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT",
|
||||
"--append INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT",
|
||||
})
|
||||
}
|
||||
|
||||
@@ -194,15 +194,12 @@ func (c *Config) AcceptOutputFromIPToSubnet(ctx context.Context,
|
||||
// ff02::1:ff00:0/104, which is used for NDP (Neighbor Discovery Protocol) to resolve
|
||||
// IPv6 addresses to MAC addresses. If intf is empty, it is set to "*" which means
|
||||
// all interfaces. If remove is true, the rule is removed instead of added.
|
||||
func (c *Config) AcceptIpv6MulticastOutput(ctx context.Context,
|
||||
intf string, remove bool,
|
||||
) error {
|
||||
func (c *Config) AcceptIpv6MulticastOutput(ctx context.Context, intf string) error {
|
||||
interfaceFlag := "-o " + intf
|
||||
if intf == "*" { // all interfaces
|
||||
interfaceFlag = ""
|
||||
}
|
||||
instruction := fmt.Sprintf("%s OUTPUT %s -d ff02::1:ff00:0/104 -j ACCEPT",
|
||||
appendOrDelete(remove), interfaceFlag)
|
||||
instruction := fmt.Sprintf("--append OUTPUT %s -d ff02::1:ff00:0/104 -j ACCEPT", interfaceFlag)
|
||||
return c.runIP6tablesInstruction(ctx, instruction)
|
||||
}
|
||||
|
||||
@@ -234,7 +231,17 @@ func (c *Config) RedirectPort(ctx context.Context, intf string,
|
||||
interfaceFlag = ""
|
||||
}
|
||||
|
||||
err = c.runIptablesInstructions(ctx, []string{
|
||||
c.iptablesMutex.Lock()
|
||||
c.ip6tablesMutex.Lock()
|
||||
defer c.iptablesMutex.Unlock()
|
||||
defer c.ip6tablesMutex.Unlock()
|
||||
|
||||
restore, err := c.saveAndRestore(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = c.runIptablesInstructionsNoSave(ctx, []string{
|
||||
fmt.Sprintf("-t nat %s PREROUTING %s -p tcp --dport %d -j REDIRECT --to-ports %d",
|
||||
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
|
||||
fmt.Sprintf("%s INPUT %s -p tcp -m tcp --dport %d -j ACCEPT",
|
||||
@@ -245,11 +252,12 @@ func (c *Config) RedirectPort(ctx context.Context, intf string,
|
||||
appendOrDelete(remove), interfaceFlag, destinationPort),
|
||||
})
|
||||
if err != nil {
|
||||
restore(ctx)
|
||||
return fmt.Errorf("redirecting IPv4 source port %d to destination port %d on interface %s: %w",
|
||||
sourcePort, destinationPort, intf, err)
|
||||
}
|
||||
|
||||
err = c.runIP6tablesInstructions(ctx, []string{
|
||||
err = c.runIP6tablesInstructionsNoSave(ctx, []string{
|
||||
fmt.Sprintf("-t nat %s PREROUTING %s -p tcp --dport %d -j REDIRECT --to-ports %d",
|
||||
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
|
||||
fmt.Sprintf("%s INPUT %s -p tcp -m tcp --dport %d -j ACCEPT",
|
||||
@@ -260,6 +268,7 @@ func (c *Config) RedirectPort(ctx context.Context, intf string,
|
||||
appendOrDelete(remove), interfaceFlag, destinationPort),
|
||||
})
|
||||
if err != nil {
|
||||
restore(ctx) // just in case
|
||||
errMessage := err.Error()
|
||||
if strings.Contains(errMessage, "can't initialize ip6tables table `nat': Table does not exist") {
|
||||
if !remove {
|
||||
@@ -273,7 +282,7 @@ func (c *Config) RedirectPort(ctx context.Context, intf string,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) RunUserPostRules(ctx context.Context, filepath string, remove bool) error {
|
||||
func (c *Config) RunUserPostRules(ctx context.Context, filepath string) error {
|
||||
file, err := os.OpenFile(filepath, os.O_RDONLY, 0)
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
@@ -289,16 +298,17 @@ func (c *Config) RunUserPostRules(ctx context.Context, filepath string, remove b
|
||||
return err
|
||||
}
|
||||
lines := strings.Split(string(b), "\n")
|
||||
successfulRules := []string{}
|
||||
defer func() {
|
||||
// transaction-like rollback
|
||||
if err == nil || ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
for _, rule := range successfulRules {
|
||||
_ = c.runIptablesInstruction(ctx, flipRule(rule))
|
||||
}
|
||||
}()
|
||||
|
||||
c.iptablesMutex.Lock()
|
||||
c.ip6tablesMutex.Lock()
|
||||
defer c.iptablesMutex.Unlock()
|
||||
defer c.ip6tablesMutex.Unlock()
|
||||
|
||||
restore, err := c.saveAndRestore(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, line := range lines {
|
||||
var ipv4 bool
|
||||
var rule string
|
||||
@@ -325,10 +335,6 @@ func (c *Config) RunUserPostRules(ctx context.Context, filepath string, remove b
|
||||
continue
|
||||
}
|
||||
|
||||
if remove {
|
||||
rule = flipRule(rule)
|
||||
}
|
||||
|
||||
switch {
|
||||
case ipv4:
|
||||
err = c.runIptablesInstruction(ctx, rule)
|
||||
@@ -338,10 +344,9 @@ func (c *Config) RunUserPostRules(ctx context.Context, filepath string, remove b
|
||||
err = c.runIP6tablesInstruction(ctx, rule)
|
||||
}
|
||||
if err != nil {
|
||||
restore(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
successfulRules = append(successfulRules, rule)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -5,8 +5,19 @@ import (
|
||||
)
|
||||
|
||||
func (c *Config) runMixedIptablesInstructions(ctx context.Context, instructions []string) error {
|
||||
c.iptablesMutex.Lock()
|
||||
c.ip6tablesMutex.Lock()
|
||||
defer c.iptablesMutex.Unlock()
|
||||
defer c.ip6tablesMutex.Unlock()
|
||||
|
||||
restore, err := c.saveAndRestore(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, instruction := range instructions {
|
||||
if err := c.runMixedIptablesInstruction(ctx, instruction); err != nil {
|
||||
if err := c.runMixedIptablesInstructionNoSave(ctx, instruction); err != nil {
|
||||
restore(ctx)
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -14,8 +25,25 @@ func (c *Config) runMixedIptablesInstructions(ctx context.Context, instructions
|
||||
}
|
||||
|
||||
func (c *Config) runMixedIptablesInstruction(ctx context.Context, instruction string) error {
|
||||
if err := c.runIptablesInstruction(ctx, instruction); err != nil {
|
||||
c.iptablesMutex.Lock()
|
||||
c.ip6tablesMutex.Lock()
|
||||
defer c.iptablesMutex.Unlock()
|
||||
defer c.ip6tablesMutex.Unlock()
|
||||
|
||||
restore, err := c.saveAndRestore(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.runIP6tablesInstruction(ctx, instruction)
|
||||
err = c.runIptablesInstructionNoSave(ctx, instruction)
|
||||
if err != nil {
|
||||
restore(ctx)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Config) runMixedIptablesInstructionNoSave(ctx context.Context, instruction string) error {
|
||||
if err := c.runIptablesInstructionNoSave(ctx, instruction); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.runIP6tablesInstructionNoSave(ctx, instruction)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user