diff --git a/internal/firewall/enable.go b/internal/firewall/enable.go index c0f9aa8d..6f5cb18b 100644 --- a/internal/firewall/enable.go +++ b/internal/firewall/enable.go @@ -45,6 +45,12 @@ func (c *Config) enable(ctx context.Context) (err error) { return fmt.Errorf("saving firewall rules: %w", err) } + defer func() { + if err != nil { + c.restore(context.Background()) + } + }() + if err = c.impl.SetIPv4AllPolicies(ctx, "DROP"); err != nil { return err } @@ -53,12 +59,6 @@ func (c *Config) enable(ctx context.Context) (err error) { return err } - defer func() { - if err != nil { - c.restore(context.Background()) - } - }() - // Loopback traffic if err = c.impl.AcceptInputThroughInterface(ctx, "lo"); err != nil { return err diff --git a/internal/firewall/iptables/iptables.go b/internal/firewall/iptables/iptables.go index 486e9fab..d97b306b 100644 --- a/internal/firewall/iptables/iptables.go +++ b/internal/firewall/iptables/iptables.go @@ -337,11 +337,11 @@ func (c *Config) RunUserPostRules(ctx context.Context, filepath string) error { switch { case ipv4: - err = c.runIptablesInstruction(ctx, rule) + err = c.runIptablesInstructionNoSave(ctx, rule) case c.ip6Tables == "": err = fmt.Errorf("running user ip6tables rule: %w", ErrNeedIP6Tables) default: // ipv6 - err = c.runIP6tablesInstruction(ctx, rule) + err = c.runIP6tablesInstructionNoSave(ctx, rule) } if err != nil { restore(ctx) diff --git a/internal/firewall/iptables/iptablesmix.go b/internal/firewall/iptables/iptablesmix.go index 0ea85bf4..32c75c45 100644 --- a/internal/firewall/iptables/iptablesmix.go +++ b/internal/firewall/iptables/iptablesmix.go @@ -34,7 +34,7 @@ func (c *Config) runMixedIptablesInstruction(ctx context.Context, instruction st if err != nil { return err } - err = c.runIptablesInstructionNoSave(ctx, instruction) + err = c.runMixedIptablesInstructionNoSave(ctx, instruction) if err != nil { restore(ctx) }