diff --git a/internal/firewall/iptables/iptables.go b/internal/firewall/iptables/iptables.go index cd6c3100..2d571351 100644 --- a/internal/firewall/iptables/iptables.go +++ b/internal/firewall/iptables/iptables.go @@ -141,13 +141,14 @@ func (c *Config) AcceptOutputThroughInterface(ctx context.Context, intf string, } func (c *Config) AcceptEstablishedRelatedTraffic(ctx context.Context) error { - if !c.modules.nfConntrack.ok { - return fmt.Errorf("%w: %s", ErrKernelModuleMissing, c.modules.nfConntrack.name) - } - return c.runMixedIptablesInstructions(ctx, []string{ + err := c.runMixedIptablesInstructions(ctx, []string{ "--append OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", "--append INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", }) + if err != nil && !c.modules.nfConntrack.ok { + return fmt.Errorf("%w: %s", ErrKernelModuleMissing, c.modules.nfConntrack.name) + } + return err } // AcceptOutputPublicOnlyNewTraffic adds rules to mark new output connections, and to accept @@ -157,11 +158,6 @@ func (c *Config) AcceptEstablishedRelatedTraffic(ctx context.Context) error { // If the relevant kernel modules (nf_conntrack, xt_conntrack and xt_connmark) // are not available, it returns an error indicating which kernel module is missing. func (c *Config) AcceptOutputPublicOnlyNewTraffic(ctx context.Context) error { - err := checkKernelModulesAreOK(c.modules.nfConntrack, c.modules.xtConntrack, c.modules.xtConnmark) - if err != nil { - return fmt.Errorf("checking kernel modules: %w", err) - } - ipv4Instructions, ipv6Instructions := makeCreatePublicIPChainInstructions() appendToBoth := func(instruction string) { ipv4Instructions = append(ipv4Instructions, instruction) @@ -188,16 +184,26 @@ func (c *Config) AcceptOutputPublicOnlyNewTraffic(ctx context.Context) error { return err } + kernelErr := checkKernelModulesAreOK(c.modules.nfConntrack, + c.modules.xtConntrack, c.modules.xtConnmark) + err = c.runIptablesInstructionsNoSave(ctx, ipv4Instructions) if err != nil { restore(ctx) + if strings.Contains(err.Error(), "support") && kernelErr != nil { + err = fmt.Errorf("%w: %w", err, kernelErr) + } return err } err = c.runIP6tablesInstructionsNoSave(ctx, ipv6Instructions) if err != nil { restore(ctx) + if strings.Contains(err.Error(), "support") && kernelErr != nil { + err = fmt.Errorf("%w: %w", err, kernelErr) + } return err } + return nil } @@ -211,11 +217,6 @@ func (c *Config) RejectOutputPublicTraffic(ctx context.Context, remove bool) err return c.runMixedIptablesInstructions(ctx, removeInstructions) } - err := checkKernelModulesAreOK(c.modules.nfConntrack, c.modules.nfRejectIPv4, c.modules.xtReject) - if err != nil { - return fmt.Errorf("checking kernel modules: %w", err) - } - ipv4Instructions, ipv6Instructions := makeCreatePublicIPChainInstructions() appendToBoth := func(instruction string) { ipv4Instructions = append(ipv4Instructions, instruction) @@ -229,15 +230,26 @@ func (c *Config) RejectOutputPublicTraffic(ctx context.Context, remove bool) err "-j REJECT --reject-with tcp-reset") appendToBoth("-I OUTPUT -j PUBLIC_ONLY") - err = c.runIptablesInstructions(ctx, ipv4Instructions) + kernelErr := checkKernelModulesAreOK(c.modules.nfConntrack, + c.modules.nfRejectIPv4, c.modules.xtReject) + + err := c.runIptablesInstructions(ctx, ipv4Instructions) if err != nil { + if strings.Contains(err.Error(), "support") && kernelErr != nil { + err = fmt.Errorf("%w: %w", err, kernelErr) + } return err } + err = c.runIP6tablesInstructions(ctx, ipv6Instructions) if err != nil { _ = c.runIptablesInstructions(ctx, removeInstructions) + if strings.Contains(err.Error(), "support") && kernelErr != nil { + err = fmt.Errorf("%w: %w", err, kernelErr) + } return err } + return nil } diff --git a/internal/netlink/conntrack_linux.go b/internal/netlink/conntrack_linux.go index 08832077..6b69a2c2 100644 --- a/internal/netlink/conntrack_linux.go +++ b/internal/netlink/conntrack_linux.go @@ -12,12 +12,11 @@ import ( var ErrConntrackNetlinkNotSupported = errors.New("nf_conntrack_netlink is not supported by the kernel") func (n *NetLink) FlushConntrack() error { - if !n.conntrackNetlink { - return fmt.Errorf("%w", ErrConntrackNetlinkNotSupported) - } - conn, err := netfilter.Dial(nil) if err != nil { + if !n.conntrackNetlink { + err = fmt.Errorf("%w: %w", err, ErrConntrackNetlinkNotSupported) + } return fmt.Errorf("dialing netfilter: %w", err) } defer conn.Close() @@ -36,6 +35,9 @@ func (n *NetLink) FlushConntrack() error { _, err = conn.Query(request) if err != nil { + if !n.conntrackNetlink { + err = fmt.Errorf("%w: %w", err, ErrConntrackNetlinkNotSupported) + } return fmt.Errorf("querying netlink request: %w", err) } return nil