feat(firewall): atomic iptables operations

- all operations rollback on failure
- disabling the firewall means rolling back to its state before enabling it
- aligns with nftables atomicity feature
This commit is contained in:
Quentin McGaw
2026-02-26 22:58:52 +00:00
parent 0d0c0fb143
commit 2bb4deccd5
7 changed files with 238 additions and 117 deletions
+17 -43
View File
@@ -22,9 +22,7 @@ func (c *Config) SetEnabled(ctx context.Context, enabled bool) (err error) {
if !enabled {
c.logger.Info("disabling...")
if err = c.disable(ctx); err != nil {
return fmt.Errorf("disabling firewall: %w", err)
}
c.restore(ctx)
c.enabled = false
c.logger.Info("disabled successfully")
return nil
@@ -41,37 +39,12 @@ func (c *Config) SetEnabled(ctx context.Context, enabled bool) (err error) {
return nil
}
func (c *Config) disable(ctx context.Context) (err error) {
if err = c.impl.ClearAllRules(ctx); err != nil {
return fmt.Errorf("clearing all rules: %w", err)
}
if err = c.impl.SetIPv4AllPolicies(ctx, "ACCEPT"); err != nil {
return fmt.Errorf("setting ipv4 policies: %w", err)
}
if err = c.impl.SetIPv6AllPolicies(ctx, "ACCEPT"); err != nil {
return fmt.Errorf("setting ipv6 policies: %w", err)
}
const remove = true
err = c.redirectPorts(ctx, remove)
if err != nil {
return fmt.Errorf("removing port redirections: %w", err)
}
return nil
}
// To use in defered call when enabling the firewall.
func (c *Config) fallbackToDisabled(ctx context.Context) {
if ctx.Err() != nil {
return
}
if err := c.disable(ctx); err != nil {
c.logger.Error("failed reversing firewall changes: " + err.Error())
}
}
func (c *Config) enable(ctx context.Context) (err error) {
c.restore, err = c.impl.SaveAndRestore(ctx)
if err != nil {
return fmt.Errorf("saving firewall rules: %w", err)
}
if err = c.impl.SetIPv4AllPolicies(ctx, "DROP"); err != nil {
return err
}
@@ -82,21 +55,21 @@ func (c *Config) enable(ctx context.Context) (err error) {
defer func() {
if err != nil {
c.fallbackToDisabled(ctx)
c.restore(context.Background())
}
}()
const remove = false
// Loopback traffic
if err = c.impl.AcceptInputThroughInterface(ctx, "lo", remove); err != nil {
if err = c.impl.AcceptInputThroughInterface(ctx, "lo"); err != nil {
return err
}
const remove = false
if err = c.impl.AcceptOutputThroughInterface(ctx, "lo", remove); err != nil {
return err
}
if err = c.impl.AcceptEstablishedRelatedTraffic(ctx, remove); err != nil {
if err = c.impl.AcceptEstablishedRelatedTraffic(ctx); err != nil {
return err
}
@@ -117,7 +90,7 @@ func (c *Config) enable(ctx context.Context) (err error) {
continue
}
localInterfaces[network.InterfaceName] = struct{}{}
err = c.impl.AcceptIpv6MulticastOutput(ctx, network.InterfaceName, remove)
err = c.impl.AcceptIpv6MulticastOutput(ctx, network.InterfaceName)
if err != nil {
return fmt.Errorf("accepting IPv6 multicast output: %w", err)
}
@@ -130,7 +103,7 @@ func (c *Config) enable(ctx context.Context) (err error) {
// Allows packets from any IP address to go through eth0 / local network
// to reach Gluetun.
for _, network := range c.localNetworks {
if err := c.impl.AcceptInputToSubnet(ctx, network.InterfaceName, network.IPNet, remove); err != nil {
if err := c.impl.AcceptInputToSubnet(ctx, network.InterfaceName, network.IPNet); err != nil {
return err
}
}
@@ -139,12 +112,12 @@ func (c *Config) enable(ctx context.Context) (err error) {
return err
}
err = c.redirectPorts(ctx, remove)
err = c.redirectPorts(ctx)
if err != nil {
return fmt.Errorf("redirecting ports: %w", err)
}
if err := c.impl.RunUserPostRules(ctx, c.customRulesPath, remove); err != nil {
if err := c.impl.RunUserPostRules(ctx, c.customRulesPath); err != nil {
return fmt.Errorf("running user defined post firewall rules: %w", err)
}
@@ -214,8 +187,9 @@ func (c *Config) allowInputPorts(ctx context.Context) (err error) {
return nil
}
func (c *Config) redirectPorts(ctx context.Context, remove bool) (err error) {
func (c *Config) redirectPorts(ctx context.Context) (err error) {
for _, portRedirection := range c.portRedirections {
const remove = false
err = c.impl.RedirectPort(ctx, portRedirection.interfaceName, portRedirection.sourcePort,
portRedirection.destinationPort, remove)
if err != nil {