mirror of
https://github.com/qdm12/gluetun.git
synced 2026-05-14 04:20:04 +02:00
feat(firewall): atomic iptables operations
- all operations rollback on failure - disabling the firewall means rolling back to its state before enabling it - aligns with nftables atomicity feature
This commit is contained in:
+17
-43
@@ -22,9 +22,7 @@ func (c *Config) SetEnabled(ctx context.Context, enabled bool) (err error) {
|
||||
|
||||
if !enabled {
|
||||
c.logger.Info("disabling...")
|
||||
if err = c.disable(ctx); err != nil {
|
||||
return fmt.Errorf("disabling firewall: %w", err)
|
||||
}
|
||||
c.restore(ctx)
|
||||
c.enabled = false
|
||||
c.logger.Info("disabled successfully")
|
||||
return nil
|
||||
@@ -41,37 +39,12 @@ func (c *Config) SetEnabled(ctx context.Context, enabled bool) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) disable(ctx context.Context) (err error) {
|
||||
if err = c.impl.ClearAllRules(ctx); err != nil {
|
||||
return fmt.Errorf("clearing all rules: %w", err)
|
||||
}
|
||||
if err = c.impl.SetIPv4AllPolicies(ctx, "ACCEPT"); err != nil {
|
||||
return fmt.Errorf("setting ipv4 policies: %w", err)
|
||||
}
|
||||
if err = c.impl.SetIPv6AllPolicies(ctx, "ACCEPT"); err != nil {
|
||||
return fmt.Errorf("setting ipv6 policies: %w", err)
|
||||
}
|
||||
|
||||
const remove = true
|
||||
err = c.redirectPorts(ctx, remove)
|
||||
if err != nil {
|
||||
return fmt.Errorf("removing port redirections: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// To use in defered call when enabling the firewall.
|
||||
func (c *Config) fallbackToDisabled(ctx context.Context) {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
if err := c.disable(ctx); err != nil {
|
||||
c.logger.Error("failed reversing firewall changes: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Config) enable(ctx context.Context) (err error) {
|
||||
c.restore, err = c.impl.SaveAndRestore(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("saving firewall rules: %w", err)
|
||||
}
|
||||
|
||||
if err = c.impl.SetIPv4AllPolicies(ctx, "DROP"); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -82,21 +55,21 @@ func (c *Config) enable(ctx context.Context) (err error) {
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
c.fallbackToDisabled(ctx)
|
||||
c.restore(context.Background())
|
||||
}
|
||||
}()
|
||||
|
||||
const remove = false
|
||||
|
||||
// Loopback traffic
|
||||
if err = c.impl.AcceptInputThroughInterface(ctx, "lo", remove); err != nil {
|
||||
if err = c.impl.AcceptInputThroughInterface(ctx, "lo"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
const remove = false
|
||||
if err = c.impl.AcceptOutputThroughInterface(ctx, "lo", remove); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = c.impl.AcceptEstablishedRelatedTraffic(ctx, remove); err != nil {
|
||||
if err = c.impl.AcceptEstablishedRelatedTraffic(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -117,7 +90,7 @@ func (c *Config) enable(ctx context.Context) (err error) {
|
||||
continue
|
||||
}
|
||||
localInterfaces[network.InterfaceName] = struct{}{}
|
||||
err = c.impl.AcceptIpv6MulticastOutput(ctx, network.InterfaceName, remove)
|
||||
err = c.impl.AcceptIpv6MulticastOutput(ctx, network.InterfaceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("accepting IPv6 multicast output: %w", err)
|
||||
}
|
||||
@@ -130,7 +103,7 @@ func (c *Config) enable(ctx context.Context) (err error) {
|
||||
// Allows packets from any IP address to go through eth0 / local network
|
||||
// to reach Gluetun.
|
||||
for _, network := range c.localNetworks {
|
||||
if err := c.impl.AcceptInputToSubnet(ctx, network.InterfaceName, network.IPNet, remove); err != nil {
|
||||
if err := c.impl.AcceptInputToSubnet(ctx, network.InterfaceName, network.IPNet); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -139,12 +112,12 @@ func (c *Config) enable(ctx context.Context) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
err = c.redirectPorts(ctx, remove)
|
||||
err = c.redirectPorts(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("redirecting ports: %w", err)
|
||||
}
|
||||
|
||||
if err := c.impl.RunUserPostRules(ctx, c.customRulesPath, remove); err != nil {
|
||||
if err := c.impl.RunUserPostRules(ctx, c.customRulesPath); err != nil {
|
||||
return fmt.Errorf("running user defined post firewall rules: %w", err)
|
||||
}
|
||||
|
||||
@@ -214,8 +187,9 @@ func (c *Config) allowInputPorts(ctx context.Context) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) redirectPorts(ctx context.Context, remove bool) (err error) {
|
||||
func (c *Config) redirectPorts(ctx context.Context) (err error) {
|
||||
for _, portRedirection := range c.portRedirections {
|
||||
const remove = false
|
||||
err = c.impl.RedirectPort(ctx, portRedirection.interfaceName, portRedirection.sourcePort,
|
||||
portRedirection.destinationPort, remove)
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user