chore(firewall): split apart iptables specific code in internal/firewall/iptables

This commit is contained in:
Quentin McGaw
2026-02-25 03:45:17 +00:00
parent 034f8f6331
commit d21953f62e
29 changed files with 209 additions and 103 deletions
+1 -1
View File
@@ -283,7 +283,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
err = printVersions(ctx, logger, []printVersionElement{ err = printVersions(ctx, logger, []printVersionElement{
{name: "Alpine", getVersion: alpineConf.Version}, {name: "Alpine", getVersion: alpineConf.Version},
{name: "OpenVPN", getVersion: ovpnVersion}, {name: "OpenVPN", getVersion: ovpnVersion},
{name: "IPtables", getVersion: firewallConf.Version}, {name: "Firewall", getVersion: firewallConf.Version},
}) })
if err != nil { if err != nil {
return err return err
+23 -23
View File
@@ -42,13 +42,13 @@ func (c *Config) SetEnabled(ctx context.Context, enabled bool) (err error) {
} }
func (c *Config) disable(ctx context.Context) (err error) { func (c *Config) disable(ctx context.Context) (err error) {
if err = c.clearAllRules(ctx); err != nil { if err = c.impl.ClearAllRules(ctx); err != nil {
return fmt.Errorf("clearing all rules: %w", err) return fmt.Errorf("clearing all rules: %w", err)
} }
if err = c.setIPv4AllPolicies(ctx, "ACCEPT"); err != nil { if err = c.impl.SetIPv4AllPolicies(ctx, "ACCEPT"); err != nil {
return fmt.Errorf("setting ipv4 policies: %w", err) return fmt.Errorf("setting ipv4 policies: %w", err)
} }
if err = c.setIPv6AllPolicies(ctx, "ACCEPT"); err != nil { if err = c.impl.SetIPv6AllPolicies(ctx, "ACCEPT"); err != nil {
return fmt.Errorf("setting ipv6 policies: %w", err) return fmt.Errorf("setting ipv6 policies: %w", err)
} }
@@ -72,33 +72,31 @@ func (c *Config) fallbackToDisabled(ctx context.Context) {
} }
func (c *Config) enable(ctx context.Context) (err error) { func (c *Config) enable(ctx context.Context) (err error) {
touched := false if err = c.impl.SetIPv4AllPolicies(ctx, "DROP"); err != nil {
if err = c.setIPv4AllPolicies(ctx, "DROP"); err != nil {
return err
}
touched = true
if err = c.setIPv6AllPolicies(ctx, "DROP"); err != nil {
return err return err
} }
const remove = false if err = c.impl.SetIPv6AllPolicies(ctx, "DROP"); err != nil {
return err
}
defer func() { defer func() {
if touched && err != nil { if err != nil {
c.fallbackToDisabled(ctx) c.fallbackToDisabled(ctx)
} }
}() }()
const remove = false
// Loopback traffic // Loopback traffic
if err = c.acceptInputThroughInterface(ctx, "lo", remove); err != nil { if err = c.impl.AcceptInputThroughInterface(ctx, "lo", remove); err != nil {
return err return err
} }
if err = c.acceptOutputThroughInterface(ctx, "lo", remove); err != nil { if err = c.impl.AcceptOutputThroughInterface(ctx, "lo", remove); err != nil {
return err return err
} }
if err = c.acceptEstablishedRelatedTraffic(ctx, remove); err != nil { if err = c.impl.AcceptEstablishedRelatedTraffic(ctx, remove); err != nil {
return err return err
} }
@@ -108,7 +106,9 @@ func (c *Config) enable(ctx context.Context) (err error) {
localInterfaces := make(map[string]struct{}, len(c.localNetworks)) localInterfaces := make(map[string]struct{}, len(c.localNetworks))
for _, network := range c.localNetworks { for _, network := range c.localNetworks {
if err := c.acceptOutputFromIPToSubnet(ctx, network.InterfaceName, network.IP, network.IPNet, remove); err != nil { err = c.impl.AcceptOutputFromIPToSubnet(ctx,
network.InterfaceName, network.IP, network.IPNet, remove)
if err != nil {
return err return err
} }
@@ -117,7 +117,7 @@ func (c *Config) enable(ctx context.Context) (err error) {
continue continue
} }
localInterfaces[network.InterfaceName] = struct{}{} localInterfaces[network.InterfaceName] = struct{}{}
err = c.acceptIpv6MulticastOutput(ctx, network.InterfaceName, remove) err = c.impl.AcceptIpv6MulticastOutput(ctx, network.InterfaceName, remove)
if err != nil { if err != nil {
return fmt.Errorf("accepting IPv6 multicast output: %w", err) return fmt.Errorf("accepting IPv6 multicast output: %w", err)
} }
@@ -130,7 +130,7 @@ func (c *Config) enable(ctx context.Context) (err error) {
// Allows packets from any IP address to go through eth0 / local network // Allows packets from any IP address to go through eth0 / local network
// to reach Gluetun. // to reach Gluetun.
for _, network := range c.localNetworks { for _, network := range c.localNetworks {
if err := c.acceptInputToSubnet(ctx, network.InterfaceName, network.IPNet, remove); err != nil { if err := c.impl.AcceptInputToSubnet(ctx, network.InterfaceName, network.IPNet, remove); err != nil {
return err return err
} }
} }
@@ -144,7 +144,7 @@ func (c *Config) enable(ctx context.Context) (err error) {
return fmt.Errorf("redirecting ports: %w", err) return fmt.Errorf("redirecting ports: %w", err)
} }
if err := c.runUserPostRules(ctx, c.customRulesPath, remove); err != nil { if err := c.impl.RunUserPostRules(ctx, c.customRulesPath, remove); err != nil {
return fmt.Errorf("running user defined post firewall rules: %w", err) return fmt.Errorf("running user defined post firewall rules: %w", err)
} }
@@ -164,7 +164,7 @@ func (c *Config) allowVPNIP(ctx context.Context) (err error) {
continue continue
} }
interfacesSeen[defaultRoute.NetInterface] = struct{}{} interfacesSeen[defaultRoute.NetInterface] = struct{}{}
err = c.acceptOutputTrafficToVPN(ctx, defaultRoute.NetInterface, c.vpnConnection, remove) err = c.impl.AcceptOutputTrafficToVPN(ctx, defaultRoute.NetInterface, c.vpnConnection, remove)
if err != nil { if err != nil {
return fmt.Errorf("accepting output traffic through VPN: %w", err) return fmt.Errorf("accepting output traffic through VPN: %w", err)
} }
@@ -186,7 +186,7 @@ func (c *Config) allowOutboundSubnets(ctx context.Context) (err error) {
firewallUpdated = true firewallUpdated = true
const remove = false const remove = false
err := c.acceptOutputFromIPToSubnet(ctx, defaultRoute.NetInterface, err := c.impl.AcceptOutputFromIPToSubnet(ctx, defaultRoute.NetInterface,
defaultRoute.AssignedIP, subnet, remove) defaultRoute.AssignedIP, subnet, remove)
if err != nil { if err != nil {
return err return err
@@ -204,7 +204,7 @@ func (c *Config) allowInputPorts(ctx context.Context) (err error) {
for port, netInterfaces := range c.allowedInputPorts { for port, netInterfaces := range c.allowedInputPorts {
for netInterface := range netInterfaces { for netInterface := range netInterfaces {
const remove = false const remove = false
err = c.acceptInputToPort(ctx, netInterface, port, remove) err = c.impl.AcceptInputToPort(ctx, netInterface, port, remove)
if err != nil { if err != nil {
return fmt.Errorf("accepting input port %d on interface %s: %w", return fmt.Errorf("accepting input port %d on interface %s: %w",
port, netInterface, err) port, netInterface, err)
@@ -216,7 +216,7 @@ func (c *Config) allowInputPorts(ctx context.Context) (err error) {
func (c *Config) redirectPorts(ctx context.Context, remove bool) (err error) { func (c *Config) redirectPorts(ctx context.Context, remove bool) (err error) {
for _, portRedirection := range c.portRedirections { for _, portRedirection := range c.portRedirections {
err = c.redirectPort(ctx, portRedirection.interfaceName, portRedirection.sourcePort, err = c.impl.RedirectPort(ctx, portRedirection.interfaceName, portRedirection.sourcePort,
portRedirection.destinationPort, remove) portRedirection.destinationPort, remove)
if err != nil { if err != nil {
return err return err
+14 -21
View File
@@ -2,24 +2,23 @@ package firewall
import ( import (
"context" "context"
"fmt"
"net/netip" "net/netip"
"sync" "sync"
"github.com/qdm12/gluetun/internal/firewall/iptables"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/routing" "github.com/qdm12/gluetun/internal/routing"
) )
type Config struct { type Config struct {
runner CmdRunner runner CmdRunner
logger Logger logger Logger
iptablesMutex sync.Mutex defaultRoutes []routing.DefaultRoute
ip6tablesMutex sync.Mutex localNetworks []routing.LocalNetwork
defaultRoutes []routing.DefaultRoute
localNetworks []routing.LocalNetwork
// Fixed state // Fixed
ipTables string impl firewallImpl
ip6Tables string
customRulesPath string customRulesPath string
// State // State
@@ -38,25 +37,19 @@ func NewConfig(ctx context.Context, logger Logger,
runner CmdRunner, defaultRoutes []routing.DefaultRoute, runner CmdRunner, defaultRoutes []routing.DefaultRoute,
localNetworks []routing.LocalNetwork, localNetworks []routing.LocalNetwork,
) (config *Config, err error) { ) (config *Config, err error) {
iptables, err := checkIptablesSupport(ctx, runner, "iptables", "iptables-nft", "iptables-legacy") impl, err := iptables.New(ctx, runner, logger)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("creating iptables firewall: %w", err)
}
ip6tables, err := findIP6tablesSupported(ctx, runner)
if err != nil {
return nil, err
} }
return &Config{ return &Config{
runner: runner, runner: runner,
logger: logger, logger: logger,
allowedInputPorts: make(map[uint16]map[string]struct{}), allowedInputPorts: make(map[uint16]map[string]struct{}),
ipTables: iptables,
ip6Tables: ip6tables,
customRulesPath: "/iptables/post-rules.txt",
// Obtained from routing // Obtained from routing
defaultRoutes: defaultRoutes, defaultRoutes: defaultRoutes,
localNetworks: localNetworks, localNetworks: localNetworks,
impl: impl,
customRulesPath: "/iptables/post-rules.txt",
}, nil }, nil
} }
+29 -1
View File
@@ -1,6 +1,12 @@
package firewall package firewall
import "os/exec" import (
"context"
"net/netip"
"os/exec"
"github.com/qdm12/gluetun/internal/models"
)
type CmdRunner interface { type CmdRunner interface {
Run(cmd *exec.Cmd) (output string, err error) Run(cmd *exec.Cmd) (output string, err error)
@@ -12,3 +18,25 @@ type Logger interface {
Warn(s string) Warn(s string)
Error(s string) Error(s string)
} }
type firewallImpl interface { //nolint:interfacebloat
AcceptEstablishedRelatedTraffic(ctx context.Context, remove bool) error
AcceptInputThroughInterface(ctx context.Context, intf string, remove bool) error
AcceptInputToPort(ctx context.Context, intf string, port uint16, remove bool) error
AcceptInputToSubnet(ctx context.Context, intf string, subnet netip.Prefix, remove bool) error
AcceptIpv6MulticastOutput(ctx context.Context, intf string, remove bool) error
AcceptOutputFromIPToSubnet(ctx context.Context, intf string, assignedIP netip.Addr,
subnet netip.Prefix, remove bool) error
AcceptOutputThroughInterface(ctx context.Context, intf string, remove bool) error
AcceptOutputTrafficToVPN(ctx context.Context, intf string,
connection models.Connection, remove bool) error
ClearAllRules(ctx context.Context) error
RedirectPort(ctx context.Context, intf string, sourcePort,
destinationPort uint16, remove bool) error
RunUserPostRules(ctx context.Context, customRulesPath string, remove bool) error
SetIPv4AllPolicies(ctx context.Context, policy string) error
SetIPv6AllPolicies(ctx context.Context, policy string) error
TempDropOutputTCPRST(ctx context.Context, src, dst netip.AddrPort, excludeMark int) (
revert func(ctx context.Context) error, err error)
Version(ctx context.Context) (version string, err error)
}
@@ -1,4 +1,4 @@
package firewall package iptables
import ( import (
"fmt" "fmt"
@@ -1,4 +1,4 @@
package firewall package iptables
import ( import (
"context" "context"
@@ -1,4 +1,4 @@
package firewall package iptables
import ( import (
"context" "context"
+36
View File
@@ -0,0 +1,36 @@
package iptables
import (
"context"
"sync"
)
type Config struct {
runner CmdRunner
logger Logger
iptablesMutex sync.Mutex
ip6tablesMutex sync.Mutex
// Fixed state
ipTables string
ip6Tables string
}
func New(ctx context.Context, runner CmdRunner, logger Logger) (*Config, error) {
iptables, err := checkIptablesSupport(ctx, runner, "iptables", "iptables-nft", "iptables-legacy")
if err != nil {
return nil, err
}
ip6tables, err := findIP6tablesSupported(ctx, runner)
if err != nil {
return nil, err
}
return &Config{
runner: runner,
logger: logger,
ipTables: iptables,
ip6Tables: ip6tables,
}, nil
}
+14
View File
@@ -0,0 +1,14 @@
package iptables
import "os/exec"
type CmdRunner interface {
Run(cmd *exec.Cmd) (output string, err error)
}
type Logger interface {
Debug(s string)
Info(s string)
Warn(s string)
Error(s string)
}
@@ -1,4 +1,4 @@
package firewall package iptables
import ( import (
"context" "context"
@@ -14,8 +14,8 @@ import (
func findIP6tablesSupported(ctx context.Context, runner CmdRunner) ( func findIP6tablesSupported(ctx context.Context, runner CmdRunner) (
ip6tablesPath string, err error, ip6tablesPath string, err error,
) { ) {
ip6tablesPath, err = checkIptablesSupport(ctx, runner, "ip6tables", "ip6tables-nft", "ip6tables-legacy") ip6tablesPath, err = checkIptablesSupport(ctx, runner, "ip6tables", "ip6tables-legacy")
if errors.Is(err, ErrIPTablesNotSupported) { if errors.Is(err, ErrNotSupported) {
return "", nil return "", nil
} else if err != nil { } else if err != nil {
return "", err return "", err
@@ -56,7 +56,7 @@ func (c *Config) runIP6tablesInstruction(ctx context.Context, instruction string
var ErrPolicyNotValid = errors.New("policy is not valid") var ErrPolicyNotValid = errors.New("policy is not valid")
func (c *Config) setIPv6AllPolicies(ctx context.Context, policy string) error { func (c *Config) SetIPv6AllPolicies(ctx context.Context, policy string) error {
switch policy { switch policy {
case "ACCEPT", "DROP": case "ACCEPT", "DROP":
default: default:
@@ -1,4 +1,4 @@
package firewall package iptables
import ( import (
"context" "context"
@@ -53,7 +53,7 @@ func (c *Config) Version(ctx context.Context) (string, error) {
if len(words) < minWords { if len(words) < minWords {
return "", fmt.Errorf("%w: %s", ErrIPTablesVersionTooShort, output) return "", fmt.Errorf("%w: %s", ErrIPTablesVersionTooShort, output)
} }
return words[1], nil return "iptables " + words[1], nil
} }
func (c *Config) runIptablesInstructions(ctx context.Context, instructions []string) error { func (c *Config) runIptablesInstructions(ctx context.Context, instructions []string) error {
@@ -84,7 +84,7 @@ func (c *Config) runIptablesInstruction(ctx context.Context, instruction string)
return nil return nil
} }
func (c *Config) clearAllRules(ctx context.Context) error { func (c *Config) ClearAllRules(ctx context.Context) error {
tables := []string{"filter"} tables := []string{"filter"}
for _, table := range tables { for _, table := range tables {
return c.runMixedIptablesInstructions(ctx, []string{ return c.runMixedIptablesInstructions(ctx, []string{
@@ -95,7 +95,7 @@ func (c *Config) clearAllRules(ctx context.Context) error {
return nil return nil
} }
func (c *Config) setIPv4AllPolicies(ctx context.Context, policy string) error { func (c *Config) SetIPv4AllPolicies(ctx context.Context, policy string) error {
switch policy { switch policy {
case "ACCEPT", "DROP": case "ACCEPT", "DROP":
default: default:
@@ -108,13 +108,13 @@ func (c *Config) setIPv4AllPolicies(ctx context.Context, policy string) error {
}) })
} }
func (c *Config) acceptInputThroughInterface(ctx context.Context, intf string, remove bool) error { func (c *Config) AcceptInputThroughInterface(ctx context.Context, intf string, remove bool) error {
return c.runMixedIptablesInstruction(ctx, fmt.Sprintf( return c.runMixedIptablesInstruction(ctx, fmt.Sprintf(
"%s INPUT -i %s -j ACCEPT", appendOrDelete(remove), intf, "%s INPUT -i %s -j ACCEPT", appendOrDelete(remove), intf,
)) ))
} }
func (c *Config) acceptInputToSubnet(ctx context.Context, intf string, func (c *Config) AcceptInputToSubnet(ctx context.Context, intf string,
destination netip.Prefix, remove bool, destination netip.Prefix, remove bool,
) error { ) error {
interfaceFlag := "-i " + intf interfaceFlag := "-i " + intf
@@ -134,20 +134,20 @@ func (c *Config) acceptInputToSubnet(ctx context.Context, intf string,
return c.runIP6tablesInstruction(ctx, instruction) return c.runIP6tablesInstruction(ctx, instruction)
} }
func (c *Config) acceptOutputThroughInterface(ctx context.Context, intf string, remove bool) error { func (c *Config) AcceptOutputThroughInterface(ctx context.Context, intf string, remove bool) error {
return c.runMixedIptablesInstruction(ctx, fmt.Sprintf( return c.runMixedIptablesInstruction(ctx, fmt.Sprintf(
"%s OUTPUT -o %s -j ACCEPT", appendOrDelete(remove), intf, "%s OUTPUT -o %s -j ACCEPT", appendOrDelete(remove), intf,
)) ))
} }
func (c *Config) acceptEstablishedRelatedTraffic(ctx context.Context, remove bool) error { func (c *Config) AcceptEstablishedRelatedTraffic(ctx context.Context, remove bool) error {
return c.runMixedIptablesInstructions(ctx, []string{ return c.runMixedIptablesInstructions(ctx, []string{
fmt.Sprintf("%s OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", appendOrDelete(remove)), fmt.Sprintf("%s OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", appendOrDelete(remove)),
fmt.Sprintf("%s INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", appendOrDelete(remove)), fmt.Sprintf("%s INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", appendOrDelete(remove)),
}) })
} }
func (c *Config) acceptOutputTrafficToVPN(ctx context.Context, func (c *Config) AcceptOutputTrafficToVPN(ctx context.Context,
defaultInterface string, connection models.Connection, remove bool, defaultInterface string, connection models.Connection, remove bool,
) error { ) error {
protocol := connection.Protocol protocol := connection.Protocol
@@ -165,8 +165,11 @@ func (c *Config) acceptOutputTrafficToVPN(ctx context.Context,
return c.runIP6tablesInstruction(ctx, instruction) return c.runIP6tablesInstruction(ctx, instruction)
} }
// AcceptOutputFromIPToSubnet accepts outgoing traffic from sourceIP to destinationSubnet
// on the interface intf. If intf is empty, it is set to "*" which means all interfaces.
// If remove is true, the rule is removed instead of added.
// Thanks to @npawelek. // Thanks to @npawelek.
func (c *Config) acceptOutputFromIPToSubnet(ctx context.Context, func (c *Config) AcceptOutputFromIPToSubnet(ctx context.Context,
intf string, sourceIP netip.Addr, destinationSubnet netip.Prefix, remove bool, intf string, sourceIP netip.Addr, destinationSubnet netip.Prefix, remove bool,
) error { ) error {
doIPv4 := sourceIP.Is4() && destinationSubnet.Addr().Is4() doIPv4 := sourceIP.Is4() && destinationSubnet.Addr().Is4()
@@ -187,8 +190,11 @@ func (c *Config) acceptOutputFromIPToSubnet(ctx context.Context,
return c.runIP6tablesInstruction(ctx, instruction) return c.runIP6tablesInstruction(ctx, instruction)
} }
// NDP uses multicast address (theres no broadcast in IPv6 like ARP uses in IPv4). // AcceptIpv6MulticastOutput accepts outgoing traffic to the IPv6 multicast address
func (c *Config) acceptIpv6MulticastOutput(ctx context.Context, // ff02::1:ff00:0/104, which is used for NDP (Neighbor Discovery Protocol) to resolve
// IPv6 addresses to MAC addresses. If intf is empty, it is set to "*" which means
// all interfaces. If remove is true, the rule is removed instead of added.
func (c *Config) AcceptIpv6MulticastOutput(ctx context.Context,
intf string, remove bool, intf string, remove bool,
) error { ) error {
interfaceFlag := "-o " + intf interfaceFlag := "-o " + intf
@@ -200,8 +206,11 @@ func (c *Config) acceptIpv6MulticastOutput(ctx context.Context,
return c.runIP6tablesInstruction(ctx, instruction) return c.runIP6tablesInstruction(ctx, instruction)
} }
// Used for port forwarding, with intf set to tun. // AcceptInputToPort accepts incoming traffic on the specified port, for both TCP and UDP
func (c *Config) acceptInputToPort(ctx context.Context, intf string, port uint16, remove bool) error { // protocols, on the interface intf. If intf is empty, it is set to "*" which means all interfaces.
// If remove is true, the rule is removed instead of added. This is used for port forwarding, with
// intf set to the VPN tunnel interface.
func (c *Config) AcceptInputToPort(ctx context.Context, intf string, port uint16, remove bool) error {
interfaceFlag := "-i " + intf interfaceFlag := "-i " + intf
if intf == "*" { // all interfaces if intf == "*" { // all interfaces
interfaceFlag = "" interfaceFlag = ""
@@ -212,8 +221,12 @@ func (c *Config) acceptInputToPort(ctx context.Context, intf string, port uint16
}) })
} }
// Used for VPN server side port forwarding, with intf set to the VPN tunnel interface. // RedirectPort redirects incoming traffic on the specified source port to the
func (c *Config) redirectPort(ctx context.Context, intf string, // specified destination port, for both TCP and UDP protocols, on the interface intf.
// If intf is empty, it is set to "*" which means all interfaces. If remove is true,
// the redirection is removed instead of added. This is used for VPN server side
// port forwarding, with intf set to the VPN tunnel interface.
func (c *Config) RedirectPort(ctx context.Context, intf string,
sourcePort, destinationPort uint16, remove bool, sourcePort, destinationPort uint16, remove bool,
) (err error) { ) (err error) {
interfaceFlag := "-i " + intf interfaceFlag := "-i " + intf
@@ -260,7 +273,7 @@ func (c *Config) redirectPort(ctx context.Context, intf string,
return nil return nil
} }
func (c *Config) runUserPostRules(ctx context.Context, filepath string, remove bool) error { func (c *Config) RunUserPostRules(ctx context.Context, filepath string, remove bool) error {
file, err := os.OpenFile(filepath, os.O_RDONLY, 0) file, err := os.OpenFile(filepath, os.O_RDONLY, 0)
if os.IsNotExist(err) { if os.IsNotExist(err) {
return nil return nil
@@ -1,4 +1,4 @@
package firewall package iptables
import ( import (
"context" "context"
@@ -1,4 +1,4 @@
package firewall package iptables
import ( import (
"errors" "errors"
@@ -1,4 +1,4 @@
package firewall package iptables
import ( import (
"net/netip" "net/netip"
@@ -1,3 +1,3 @@
package firewall package iptables
//go:generate mockgen -destination=mocks_test.go -package $GOPACKAGE . CmdRunner,Logger //go:generate mockgen -destination=mocks_test.go -package $GOPACKAGE . CmdRunner,Logger
@@ -1,8 +1,8 @@
// Code generated by MockGen. DO NOT EDIT. // Code generated by MockGen. DO NOT EDIT.
// Source: github.com/qdm12/gluetun/internal/firewall (interfaces: CmdRunner,Logger) // Source: github.com/qdm12/gluetun/internal/firewall/iptables (interfaces: CmdRunner,Logger)
// Package firewall is a generated GoMock package. // Package iptables is a generated GoMock package.
package firewall package iptables
import ( import (
exec "os/exec" exec "os/exec"
@@ -1,4 +1,4 @@
package firewall package iptables
import ( import (
"errors" "errors"
@@ -1,4 +1,4 @@
package firewall package iptables
import ( import (
"net/netip" "net/netip"
@@ -1,4 +1,4 @@
package firewall package iptables
import ( import (
"context" "context"
@@ -11,10 +11,10 @@ import (
) )
var ( var (
ErrNetAdminMissing = errors.New("NET_ADMIN capability is missing") ErrNetAdminMissing = errors.New("NET_ADMIN capability is missing")
ErrTestRuleCleanup = errors.New("failed cleaning up test rule") ErrTestRuleCleanup = errors.New("failed cleaning up test rule")
ErrInputPolicyNotFound = errors.New("input policy not found") ErrInputPolicyNotFound = errors.New("input policy not found")
ErrIPTablesNotSupported = errors.New("no iptables supported found") ErrNotSupported = errors.New("no iptables supported found")
) )
func checkIptablesSupport(ctx context.Context, runner CmdRunner, func checkIptablesSupport(ctx context.Context, runner CmdRunner,
@@ -57,7 +57,7 @@ func checkIptablesSupport(ctx context.Context, runner CmdRunner,
} }
return "", fmt.Errorf("%w: errors encountered are: %s", return "", fmt.Errorf("%w: errors encountered are: %s",
ErrIPTablesNotSupported, strings.Join(allUnsupportedMessages, "; ")) ErrNotSupported, strings.Join(allUnsupportedMessages, "; "))
} }
func testIptablesPath(ctx context.Context, path string, func testIptablesPath(ctx context.Context, path string,
@@ -1,4 +1,4 @@
package firewall package iptables
import ( import (
"context" "context"
@@ -101,7 +101,7 @@ func Test_checkIptablesSupport(t *testing.T) {
return runner return runner
}, },
iptablesPathsToTry: []string{"path1", "path2"}, iptablesPathsToTry: []string{"path1", "path2"},
errSentinel: ErrIPTablesNotSupported, errSentinel: ErrNotSupported,
errMessage: "no iptables supported found: " + errMessage: "no iptables supported found: " +
"errors encountered are: " + "errors encountered are: " +
"path1: output 1 (exit code 4); " + "path1: output 1 (exit code 4); " +
@@ -1,4 +1,4 @@
package firewall package iptables
import ( import (
"context" "context"
+2 -2
View File
@@ -48,7 +48,7 @@ func (c *Config) removeOutboundSubnets(ctx context.Context, subnets []netip.Pref
} }
firewallUpdated = true firewallUpdated = true
err := c.acceptOutputFromIPToSubnet(ctx, defaultRoute.NetInterface, err := c.impl.AcceptOutputFromIPToSubnet(ctx, defaultRoute.NetInterface,
defaultRoute.AssignedIP, subNet, remove) defaultRoute.AssignedIP, subNet, remove)
if err != nil { if err != nil {
c.logger.Error("cannot remove outdated outbound subnet: " + err.Error()) c.logger.Error("cannot remove outdated outbound subnet: " + err.Error())
@@ -77,7 +77,7 @@ func (c *Config) addOutboundSubnets(ctx context.Context, subnets []netip.Prefix)
} }
firewallUpdated = true firewallUpdated = true
err := c.acceptOutputFromIPToSubnet(ctx, defaultRoute.NetInterface, err := c.impl.AcceptOutputFromIPToSubnet(ctx, defaultRoute.NetInterface,
defaultRoute.AssignedIP, subnet, remove) defaultRoute.AssignedIP, subnet, remove)
if err != nil { if err != nil {
return err return err
+2 -2
View File
@@ -35,7 +35,7 @@ func (c *Config) SetAllowedPort(ctx context.Context, port uint16, intf string) (
c.logger.Info("setting allowed input port " + fmt.Sprint(port) + " through interface " + intf + "...") c.logger.Info("setting allowed input port " + fmt.Sprint(port) + " through interface " + intf + "...")
const remove = false const remove = false
if err := c.acceptInputToPort(ctx, intf, port, remove); err != nil { if err := c.impl.AcceptInputToPort(ctx, intf, port, remove); err != nil {
return fmt.Errorf("allowing input to port %d through interface %s: %w", return fmt.Errorf("allowing input to port %d through interface %s: %w",
port, intf, err) port, intf, err)
} }
@@ -68,7 +68,7 @@ func (c *Config) RemoveAllowedPort(ctx context.Context, port uint16) (err error)
const remove = true const remove = true
for netInterface := range interfacesSet { for netInterface := range interfacesSet {
err := c.acceptInputToPort(ctx, netInterface, port, remove) err := c.impl.AcceptInputToPort(ctx, netInterface, port, remove)
if err != nil { if err != nil {
return fmt.Errorf("removing allowed port %d on interface %s: %w", return fmt.Errorf("removing allowed port %d on interface %s: %w",
port, netInterface, err) port, netInterface, err)
+2 -2
View File
@@ -50,7 +50,7 @@ func (c *Config) RedirectPort(ctx context.Context, intf string, sourcePort,
return nil return nil
case conflict != nil: case conflict != nil:
const remove = true const remove = true
err = c.redirectPort(ctx, conflict.interfaceName, conflict.sourcePort, err = c.impl.RedirectPort(ctx, conflict.interfaceName, conflict.sourcePort,
conflict.destinationPort, remove) conflict.destinationPort, remove)
if err != nil { if err != nil {
return fmt.Errorf("removing conflicting redirection: %w", err) return fmt.Errorf("removing conflicting redirection: %w", err)
@@ -60,7 +60,7 @@ func (c *Config) RedirectPort(ctx context.Context, intf string, sourcePort,
} }
const remove = false const remove = false
err = c.redirectPort(ctx, intf, sourcePort, destinationPort, remove) err = c.impl.RedirectPort(ctx, intf, sourcePort, destinationPort, remove)
if err != nil { if err != nil {
return fmt.Errorf("redirecting port: %w", err) return fmt.Errorf("redirecting port: %w", err)
} }
+4 -4
View File
@@ -28,7 +28,7 @@ func (c *Config) SetVPNConnection(ctx context.Context,
remove := true remove := true
if c.vpnConnection.IP.IsValid() { if c.vpnConnection.IP.IsValid() {
for _, defaultRoute := range c.defaultRoutes { for _, defaultRoute := range c.defaultRoutes {
if err := c.acceptOutputTrafficToVPN(ctx, defaultRoute.NetInterface, c.vpnConnection, remove); err != nil { if err := c.impl.AcceptOutputTrafficToVPN(ctx, defaultRoute.NetInterface, c.vpnConnection, remove); err != nil {
c.logger.Error("cannot remove outdated VPN connection rule: " + err.Error()) c.logger.Error("cannot remove outdated VPN connection rule: " + err.Error())
} }
} }
@@ -36,7 +36,7 @@ func (c *Config) SetVPNConnection(ctx context.Context,
c.vpnConnection = models.Connection{} c.vpnConnection = models.Connection{}
if c.vpnIntf != "" { if c.vpnIntf != "" {
if err = c.acceptOutputThroughInterface(ctx, c.vpnIntf, remove); err != nil { if err = c.impl.AcceptOutputThroughInterface(ctx, c.vpnIntf, remove); err != nil {
c.logger.Error("cannot remove outdated VPN interface rule: " + err.Error()) c.logger.Error("cannot remove outdated VPN interface rule: " + err.Error())
} }
} }
@@ -45,13 +45,13 @@ func (c *Config) SetVPNConnection(ctx context.Context,
remove = false remove = false
for _, defaultRoute := range c.defaultRoutes { for _, defaultRoute := range c.defaultRoutes {
if err := c.acceptOutputTrafficToVPN(ctx, defaultRoute.NetInterface, connection, remove); err != nil { if err := c.impl.AcceptOutputTrafficToVPN(ctx, defaultRoute.NetInterface, connection, remove); err != nil {
return fmt.Errorf("allowing output traffic through VPN connection: %w", err) return fmt.Errorf("allowing output traffic through VPN connection: %w", err)
} }
} }
c.vpnConnection = connection c.vpnConnection = connection
if err = c.acceptOutputThroughInterface(ctx, vpnIntf, remove); err != nil { if err = c.impl.AcceptOutputThroughInterface(ctx, vpnIntf, remove); err != nil {
return fmt.Errorf("accepting output traffic through interface %s: %w", vpnIntf, err) return fmt.Errorf("accepting output traffic through interface %s: %w", vpnIntf, err)
} }
c.vpnIntf = vpnIntf c.vpnIntf = vpnIntf
+21
View File
@@ -0,0 +1,21 @@
package firewall
import (
"context"
"net/netip"
)
func (c *Config) Version(ctx context.Context) (version string, err error) {
return c.impl.Version(ctx)
}
// TempDropOutputTCPRST temporarily drops outgoing TCP RST packets to the specified address and port,
// for any TCP packets not marked with the excludeMark given.
// This is necessary for TCP path MTU discovery to work, as the kernel will try to terminate the connection
// by sending a TCP RST packet, although we want to handle the connection manually.
func (c *Config) TempDropOutputTCPRST(ctx context.Context,
src, dst netip.AddrPort, excludeMark int) (
revert func(ctx context.Context) error, err error,
) {
return c.impl.TempDropOutputTCPRST(ctx, src, dst, excludeMark)
}
+2 -2
View File
@@ -7,7 +7,7 @@ import (
"net/netip" "net/netip"
"time" "time"
"github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/firewall/iptables"
"github.com/qdm12/gluetun/internal/pmtud/constants" "github.com/qdm12/gluetun/internal/pmtud/constants"
"github.com/qdm12/gluetun/internal/pmtud/icmp" "github.com/qdm12/gluetun/internal/pmtud/icmp"
"github.com/qdm12/gluetun/internal/pmtud/tcp" "github.com/qdm12/gluetun/internal/pmtud/tcp"
@@ -71,7 +71,7 @@ func PathMTUDiscover(ctx context.Context, icmpAddrs []netip.Addr, tcpAddrs []net
} }
mtu, err = tcp.PathMTUDiscover(ctx, tcpAddrs, minMTU, maxPossibleMTU, tryTimeout, fw, logger) mtu, err = tcp.PathMTUDiscover(ctx, tcpAddrs, minMTU, maxPossibleMTU, tryTimeout, fw, logger)
if err != nil { if err != nil {
if errors.Is(err, firewall.ErrMarkMatchModuleMissing) { if errors.Is(err, iptables.ErrMarkMatchModuleMissing) {
logger.Debugf("aborting TCP path MTU discovery: %s", err) logger.Debugf("aborting TCP path MTU discovery: %s", err)
if icmpSuccess { if icmpSuccess {
return maxPossibleMTU, nil // only rely on ICMP PMTUD results return maxPossibleMTU, nil // only rely on ICMP PMTUD results
+2 -1
View File
@@ -8,6 +8,7 @@ import (
"github.com/qdm12/gluetun/internal/command" "github.com/qdm12/gluetun/internal/command"
"github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/firewall/iptables"
"github.com/qdm12/gluetun/internal/netlink" "github.com/qdm12/gluetun/internal/netlink"
"github.com/qdm12/gluetun/internal/pmtud/constants" "github.com/qdm12/gluetun/internal/pmtud/constants"
"github.com/qdm12/gluetun/internal/routing" "github.com/qdm12/gluetun/internal/routing"
@@ -35,7 +36,7 @@ func getFirewall(t *testing.T) *firewall.Config {
cmder := command.New() cmder := command.New()
var err error var err error
testFirewall, err = firewall.NewConfig(t.Context(), noopLogger, cmder, nil, nil) testFirewall, err = firewall.NewConfig(t.Context(), noopLogger, cmder, nil, nil)
if errors.Is(err, firewall.ErrIPTablesNotSupported) { if errors.Is(err, iptables.ErrNotSupported) {
t.Skip("iptables not installed, skipping TCP PMTUD tests") t.Skip("iptables not installed, skipping TCP PMTUD tests")
} }
require.NoError(t, err, "creating firewall config") require.NoError(t, err, "creating firewall config")
+2 -2
View File
@@ -7,7 +7,7 @@ import (
"net/netip" "net/netip"
"time" "time"
"github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/firewall/iptables"
"github.com/qdm12/gluetun/internal/pmtud/constants" "github.com/qdm12/gluetun/internal/pmtud/constants"
"github.com/qdm12/gluetun/internal/pmtud/ip" "github.com/qdm12/gluetun/internal/pmtud/ip"
) )
@@ -43,7 +43,7 @@ func findHighestMSSDestination(ctx context.Context, familyToFD map[int]fileDescr
if result.err != nil { if result.err != nil {
switch { switch {
case err != nil: // error already occurred for another findMSS goroutine case err != nil: // error already occurred for another findMSS goroutine
case errors.Is(result.err, firewall.ErrMarkMatchModuleMissing): case errors.Is(result.err, iptables.ErrMarkMatchModuleMissing):
err = fmt.Errorf("finding MSS for %s: %w", result.dst, result.err) err = fmt.Errorf("finding MSS for %s: %w", result.dst, result.err)
case dst.Addr().Is6() && errors.Is(result.err, ip.ErrNetworkUnreachable): case dst.Addr().Is6() && errors.Is(result.err, ip.ErrNetworkUnreachable):
// silently discard IPv6 network unreachable errors since they are common // silently discard IPv6 network unreachable errors since they are common