diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index c222a632..49b19cc7 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -227,7 +227,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, firewallLogger.Patch(log.SetLevel(log.LevelDebug)) } firewallConf, err := firewall.NewConfig(ctx, firewallLogger, cmder, - defaultRoutes, localNetworks) + netLinker, defaultRoutes, localNetworks) if err != nil { return err } @@ -237,10 +237,6 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, if err != nil { return err } - err = netLinker.FlushConntrack() - if err != nil { - logger.Warnf("flushing conntrack failed: %s", err) - } } // TODO run this in a loop or in openvpn to reload from file without restarting diff --git a/internal/firewall/enable.go b/internal/firewall/enable.go index c0f9aa8d..28328e26 100644 --- a/internal/firewall/enable.go +++ b/internal/firewall/enable.go @@ -121,6 +121,11 @@ func (c *Config) enable(ctx context.Context) (err error) { return fmt.Errorf("running user defined post firewall rules: %w", err) } + err = c.netlinker.FlushConntrack() + if err != nil { + c.logger.Warn("flushing conntrack failed: " + err.Error()) + } + return nil } diff --git a/internal/firewall/firewall.go b/internal/firewall/firewall.go index 3b9b902a..aba68831 100644 --- a/internal/firewall/firewall.go +++ b/internal/firewall/firewall.go @@ -13,6 +13,7 @@ import ( type Config struct { runner CmdRunner + netlinker Netlinker logger Logger defaultRoutes []routing.DefaultRoute localNetworks []routing.LocalNetwork @@ -35,8 +36,8 @@ type Config struct { // NewConfig creates a new Config instance and returns an error // if no iptables implementation is available. func NewConfig(ctx context.Context, logger Logger, - runner CmdRunner, defaultRoutes []routing.DefaultRoute, - localNetworks []routing.LocalNetwork, + runner CmdRunner, netlinker Netlinker, + defaultRoutes []routing.DefaultRoute, localNetworks []routing.LocalNetwork, ) (config *Config, err error) { impl, err := iptables.New(ctx, runner, logger) if err != nil { @@ -45,6 +46,7 @@ func NewConfig(ctx context.Context, logger Logger, return &Config{ runner: runner, + netlinker: netlinker, logger: logger, allowedInputPorts: make(map[uint16]map[string]struct{}), // Obtained from routing diff --git a/internal/firewall/interfaces.go b/internal/firewall/interfaces.go index a1938f96..3352b3bb 100644 --- a/internal/firewall/interfaces.go +++ b/internal/firewall/interfaces.go @@ -19,6 +19,10 @@ type Logger interface { Error(s string) } +type Netlinker interface { + FlushConntrack() error +} + type firewallImpl interface { //nolint:interfacebloat SaveAndRestore(ctx context.Context) (restore func(context.Context), err error) AcceptEstablishedRelatedTraffic(ctx context.Context) error