chore(firewall): split apart iptables specific code in internal/firewall/iptables

This commit is contained in:
Quentin McGaw
2026-02-25 03:45:17 +00:00
parent 034f8f6331
commit d21953f62e
29 changed files with 209 additions and 103 deletions
@@ -0,0 +1,61 @@
package iptables
import (
"fmt"
"os/exec"
"regexp"
"github.com/golang/mock/gomock"
)
var _ gomock.Matcher = (*cmdMatcher)(nil)
type cmdMatcher struct {
path string
argsRegex []string
argsRegexp []*regexp.Regexp
}
func (cm *cmdMatcher) Matches(x interface{}) bool {
cmd, ok := x.(*exec.Cmd)
if !ok {
return false
}
if cmd.Path != cm.path {
return false
}
if len(cmd.Args) == 0 {
return false
}
arguments := cmd.Args[1:]
if len(arguments) != len(cm.argsRegex) {
return false
}
for i, arg := range arguments {
if !cm.argsRegexp[i].MatchString(arg) {
return false
}
}
return true
}
func (cm *cmdMatcher) String() string {
return fmt.Sprintf("path %s, argument regular expressions %v", cm.path, cm.argsRegex)
}
func newCmdMatcher(path string, argsRegex ...string) *cmdMatcher {
argsRegexp := make([]*regexp.Regexp, len(argsRegex))
for i, argRegex := range argsRegex {
argsRegexp[i] = regexp.MustCompile(argRegex)
}
return &cmdMatcher{
path: path,
argsRegex: argsRegex,
argsRegexp: argsRegexp,
}
}
+102
View File
@@ -0,0 +1,102 @@
package iptables
import (
"context"
"fmt"
"os/exec"
"strconv"
"strings"
)
// isDeleteMatchInstruction returns true if the iptables instruction
// is a delete instruction by rule matching. It returns false if the
// instruction is a delete instruction by line number, or not a delete
// instruction.
func isDeleteMatchInstruction(instruction string) bool {
fields := strings.Fields(instruction)
for i, field := range fields {
switch {
case field != "-D" && field != "--delete":
continue
case i == len(fields)-1: // malformed: missing chain name
return false
case i == len(fields)-2: // chain name is last field
return true
default:
// chain name is fields[i+1]
const base, bitLength = 10, 16
_, err := strconv.ParseUint(fields[i+2], base, bitLength)
return err != nil // not a line number
}
}
return false
}
func deleteIPTablesRule(ctx context.Context, iptablesBinary, instruction string,
runner CmdRunner, logger Logger,
) (err error) {
targetRule, err := parseIptablesInstruction(instruction)
if err != nil {
return fmt.Errorf("parsing iptables command: %w", err)
}
lineNumber, err := findLineNumber(ctx, iptablesBinary,
targetRule, runner, logger)
if err != nil {
return fmt.Errorf("finding iptables chain rule line number: %w", err)
} else if lineNumber == 0 {
logger.Debug("rule matching \"" + instruction + "\" not found")
return nil
}
logger.Debug(fmt.Sprintf("found iptables chain rule matching %q at line number %d",
instruction, lineNumber))
cmd := exec.CommandContext(ctx, iptablesBinary, "-t", targetRule.table,
"-D", targetRule.chain, fmt.Sprint(lineNumber)) // #nosec G204
logger.Debug(cmd.String())
output, err := runner.Run(cmd)
if err != nil {
err = fmt.Errorf("command failed: %q: %w", cmd, err)
if output != "" {
err = fmt.Errorf("%w: %s", err, output)
}
return err
}
return nil
}
// findLineNumber finds the line number of an iptables rule.
// It returns 0 if the rule is not found.
func findLineNumber(ctx context.Context, iptablesBinary string,
instruction iptablesInstruction, runner CmdRunner, logger Logger) (
lineNumber uint16, err error,
) {
listFlags := []string{
"-t", instruction.table, "-L", instruction.chain,
"--line-numbers", "-n", "-v",
}
cmd := exec.CommandContext(ctx, iptablesBinary, listFlags...) // #nosec G204
logger.Debug(cmd.String())
output, err := runner.Run(cmd)
if err != nil {
err = fmt.Errorf("command failed: %q: %w", cmd, err)
if output != "" {
err = fmt.Errorf("%w: %s", err, output)
}
return 0, err
}
chain, err := parseChain(output)
if err != nil {
return 0, fmt.Errorf("parsing chain list: %w", err)
}
for _, rule := range chain.rules {
if instruction.equalToRule(instruction.table, chain.name, rule) {
return rule.lineNumber, nil
}
}
return 0, nil
}
+186
View File
@@ -0,0 +1,186 @@
package iptables
import (
"context"
"errors"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
)
func Test_isDeleteMatchInstruction(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
instruction string
isDeleteMatch bool
}{
"not_delete": {
instruction: "-t nat -A PREROUTING -i tun0 -j ACCEPT",
},
"malformed_missing_chain_name": {
instruction: "-t nat -D",
},
"delete_chain_name_last_field": {
instruction: "-t nat --delete PREROUTING",
isDeleteMatch: true,
},
"delete_match": {
instruction: "-t nat --delete PREROUTING -i tun0 -j ACCEPT",
isDeleteMatch: true,
},
"delete_line_number_last_field": {
instruction: "-t nat -D PREROUTING 2",
},
"delete_line_number": {
instruction: "-t nat -D PREROUTING 2 -i tun0 -j ACCEPT",
},
}
for name, testCase := range testCases {
t.Run(name, func(t *testing.T) {
t.Parallel()
isDeleteMatch := isDeleteMatchInstruction(testCase.instruction)
assert.Equal(t, testCase.isDeleteMatch, isDeleteMatch)
})
}
}
func newCmdMatcherListRules(iptablesBinary, table, chain string) *cmdMatcher { //nolint:unparam
return newCmdMatcher(iptablesBinary, "^-t$", "^"+table+"$", "^-L$", "^"+chain+"$",
"^--line-numbers$", "^-n$", "^-v$")
}
func Test_deleteIPTablesRule(t *testing.T) {
t.Parallel()
const iptablesBinary = "/sbin/iptables"
errTest := errors.New("test error")
testCases := map[string]struct {
instruction string
makeRunner func(ctrl *gomock.Controller) *MockCmdRunner
makeLogger func(ctrl *gomock.Controller) *MockLogger
errWrapped error
errMessage string
}{
"invalid_instruction": {
instruction: "invalid",
errWrapped: ErrIptablesCommandMalformed,
errMessage: "parsing iptables command: parsing \"invalid\": " +
"iptables command is malformed: flag \"invalid\" requires a value, but got none",
},
"list_error": {
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
makeRunner: func(ctrl *gomock.Controller) *MockCmdRunner {
runner := NewMockCmdRunner(ctrl)
runner.EXPECT().
Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")).
Return("", errTest)
return runner
},
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
logger := NewMockLogger(ctrl)
logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v")
return logger
},
errWrapped: errTest,
errMessage: `finding iptables chain rule line number: command failed: ` +
`"/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v": test error`,
},
"rule_not_found": {
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
makeRunner: func(ctrl *gomock.Controller) *MockCmdRunner {
runner := NewMockCmdRunner(ctrl)
runner.EXPECT().Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")).
Return(`Chain PREROUTING (policy ACCEPT 0 packets, 0 bytes)
num pkts bytes target prot opt in out source destination
1 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:5000 redir ports 9999`, //nolint:lll
nil)
return runner
},
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
logger := NewMockLogger(ctrl)
logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v")
logger.EXPECT().Debug("rule matching \"-t nat --delete PREROUTING " +
"-i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678\" not found")
return logger
},
},
"rule_found_delete_error": {
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
makeRunner: func(ctrl *gomock.Controller) *MockCmdRunner {
runner := NewMockCmdRunner(ctrl)
runner.EXPECT().Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")).
Return("Chain PREROUTING (policy ACCEPT 0 packets, 0 bytes)\n"+
"num pkts bytes target prot opt in out source destination \n"+
"1 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:5000 redir ports 9999\n"+ //nolint:lll
"2 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:43716 redir ports 5678\n", //nolint:lll
nil)
runner.EXPECT().Run(newCmdMatcher(iptablesBinary, "^-t$", "^nat$",
"^-D$", "^PREROUTING$", "^2$")).Return("details", errTest)
return runner
},
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
logger := NewMockLogger(ctrl)
logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v")
logger.EXPECT().Debug("found iptables chain rule matching \"-t nat --delete PREROUTING " +
"-i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678\" at line number 2")
logger.EXPECT().Debug("/sbin/iptables -t nat -D PREROUTING 2")
return logger
},
errWrapped: errTest,
errMessage: "command failed: \"/sbin/iptables -t nat -D PREROUTING 2\": test error: details",
},
"rule_found_delete_success": {
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
makeRunner: func(ctrl *gomock.Controller) *MockCmdRunner {
runner := NewMockCmdRunner(ctrl)
runner.EXPECT().Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")).
Return("Chain PREROUTING (policy ACCEPT 0 packets, 0 bytes)\n"+
"num pkts bytes target prot opt in out source destination \n"+
"1 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:5000 redir ports 9999\n"+ //nolint:lll
"2 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:43716 redir ports 5678\n", //nolint:lll
nil)
runner.EXPECT().Run(newCmdMatcher(iptablesBinary, "^-t$", "^nat$",
"^-D$", "^PREROUTING$", "^2$")).Return("", nil)
return runner
},
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
logger := NewMockLogger(ctrl)
logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v")
logger.EXPECT().Debug("found iptables chain rule matching \"-t nat --delete PREROUTING " +
"-i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678\" at line number 2")
logger.EXPECT().Debug("/sbin/iptables -t nat -D PREROUTING 2")
return logger
},
},
}
for name, testCase := range testCases {
t.Run(name, func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
ctx := context.Background()
instruction := testCase.instruction
var runner *MockCmdRunner
if testCase.makeRunner != nil {
runner = testCase.makeRunner(ctrl)
}
var logger *MockLogger
if testCase.makeLogger != nil {
logger = testCase.makeLogger(ctrl)
}
err := deleteIPTablesRule(ctx, iptablesBinary, instruction, runner, logger)
assert.ErrorIs(t, err, testCase.errWrapped)
if testCase.errWrapped != nil {
assert.EqualError(t, err, testCase.errMessage)
}
})
}
}
+36
View File
@@ -0,0 +1,36 @@
package iptables
import (
"context"
"sync"
)
type Config struct {
runner CmdRunner
logger Logger
iptablesMutex sync.Mutex
ip6tablesMutex sync.Mutex
// Fixed state
ipTables string
ip6Tables string
}
func New(ctx context.Context, runner CmdRunner, logger Logger) (*Config, error) {
iptables, err := checkIptablesSupport(ctx, runner, "iptables", "iptables-nft", "iptables-legacy")
if err != nil {
return nil, err
}
ip6tables, err := findIP6tablesSupported(ctx, runner)
if err != nil {
return nil, err
}
return &Config{
runner: runner,
logger: logger,
ipTables: iptables,
ip6Tables: ip6tables,
}, nil
}
+14
View File
@@ -0,0 +1,14 @@
package iptables
import "os/exec"
type CmdRunner interface {
Run(cmd *exec.Cmd) (output string, err error)
}
type Logger interface {
Debug(s string)
Info(s string)
Warn(s string)
Error(s string)
}
+70
View File
@@ -0,0 +1,70 @@
package iptables
import (
"context"
"errors"
"fmt"
"os/exec"
"strings"
)
// findIP6tablesSupported checks for multiple iptables implementations
// and returns the iptables path that is supported. If none work, an
// empty string path is returned.
func findIP6tablesSupported(ctx context.Context, runner CmdRunner) (
ip6tablesPath string, err error,
) {
ip6tablesPath, err = checkIptablesSupport(ctx, runner, "ip6tables", "ip6tables-legacy")
if errors.Is(err, ErrNotSupported) {
return "", nil
} else if err != nil {
return "", err
}
return ip6tablesPath, nil
}
func (c *Config) runIP6tablesInstructions(ctx context.Context, instructions []string) error {
for _, instruction := range instructions {
if err := c.runIP6tablesInstruction(ctx, instruction); err != nil {
return err
}
}
return nil
}
func (c *Config) runIP6tablesInstruction(ctx context.Context, instruction string) error {
if c.ip6Tables == "" {
return nil
}
c.ip6tablesMutex.Lock() // only one ip6tables command at once
defer c.ip6tablesMutex.Unlock()
if isDeleteMatchInstruction(instruction) {
return deleteIPTablesRule(ctx, c.ip6Tables, instruction,
c.runner, c.logger)
}
flags := strings.Fields(instruction)
cmd := exec.CommandContext(ctx, c.ip6Tables, flags...) // #nosec G204
c.logger.Debug(cmd.String())
if output, err := c.runner.Run(cmd); err != nil {
return fmt.Errorf("command failed: \"%s %s\": %s: %w",
c.ip6Tables, instruction, output, err)
}
return nil
}
var ErrPolicyNotValid = errors.New("policy is not valid")
func (c *Config) SetIPv6AllPolicies(ctx context.Context, policy string) error {
switch policy {
case "ACCEPT", "DROP":
default:
return fmt.Errorf("%w: %s", ErrPolicyNotValid, policy)
}
return c.runIP6tablesInstructions(ctx, []string{
"--policy INPUT " + policy,
"--policy OUTPUT " + policy,
"--policy FORWARD " + policy,
})
}
+347
View File
@@ -0,0 +1,347 @@
package iptables
import (
"context"
"errors"
"fmt"
"io"
"net/netip"
"os"
"os/exec"
"strings"
"github.com/qdm12/gluetun/internal/models"
)
var (
ErrIPTablesVersionTooShort = errors.New("iptables version string is too short")
ErrPolicyUnknown = errors.New("unknown policy")
ErrNeedIP6Tables = errors.New("ip6tables is required, please upgrade your kernel to support it")
)
func appendOrDelete(remove bool) string {
if remove {
return "--delete"
}
return "--append"
}
// flipRule changes an append rule in a delete rule or a delete rule into an
// append rule.
func flipRule(rule string) string {
fields := strings.Fields(rule)
for i, field := range fields {
switch field {
case "-A", "--append":
fields[i] = "--delete"
case "-D", "--delete":
fields[i] = "--append"
}
}
return strings.Join(fields, " ")
}
// Version obtains the version of the installed iptables.
func (c *Config) Version(ctx context.Context) (string, error) {
cmd := exec.CommandContext(ctx, c.ipTables, "--version") //nolint:gosec
output, err := c.runner.Run(cmd)
if err != nil {
return "", err
}
words := strings.Fields(output)
const minWords = 2
if len(words) < minWords {
return "", fmt.Errorf("%w: %s", ErrIPTablesVersionTooShort, output)
}
return "iptables " + words[1], nil
}
func (c *Config) runIptablesInstructions(ctx context.Context, instructions []string) error {
for _, instruction := range instructions {
if err := c.runIptablesInstruction(ctx, instruction); err != nil {
return err
}
}
return nil
}
func (c *Config) runIptablesInstruction(ctx context.Context, instruction string) error {
c.iptablesMutex.Lock() // only one iptables command at once
defer c.iptablesMutex.Unlock()
if isDeleteMatchInstruction(instruction) {
return deleteIPTablesRule(ctx, c.ipTables, instruction,
c.runner, c.logger)
}
flags := strings.Fields(instruction)
cmd := exec.CommandContext(ctx, c.ipTables, flags...) // #nosec G204
c.logger.Debug(cmd.String())
if output, err := c.runner.Run(cmd); err != nil {
return fmt.Errorf("command failed: \"%s %s\": %s: %w",
c.ipTables, instruction, output, err)
}
return nil
}
func (c *Config) ClearAllRules(ctx context.Context) error {
tables := []string{"filter"}
for _, table := range tables {
return c.runMixedIptablesInstructions(ctx, []string{
"-t " + table + " --flush", // flush all chains
"-t " + table + " --delete-chain", // delete all chains
})
}
return nil
}
func (c *Config) SetIPv4AllPolicies(ctx context.Context, policy string) error {
switch policy {
case "ACCEPT", "DROP":
default:
return fmt.Errorf("%w: %s", ErrPolicyUnknown, policy)
}
return c.runIptablesInstructions(ctx, []string{
"--policy INPUT " + policy,
"--policy OUTPUT " + policy,
"--policy FORWARD " + policy,
})
}
func (c *Config) AcceptInputThroughInterface(ctx context.Context, intf string, remove bool) error {
return c.runMixedIptablesInstruction(ctx, fmt.Sprintf(
"%s INPUT -i %s -j ACCEPT", appendOrDelete(remove), intf,
))
}
func (c *Config) AcceptInputToSubnet(ctx context.Context, intf string,
destination netip.Prefix, remove bool,
) error {
interfaceFlag := "-i " + intf
if intf == "*" { // all interfaces
interfaceFlag = ""
}
instruction := fmt.Sprintf("%s INPUT %s -d %s -j ACCEPT",
appendOrDelete(remove), interfaceFlag, destination.String())
if destination.Addr().Is4() {
return c.runIptablesInstruction(ctx, instruction)
}
if c.ip6Tables == "" {
return fmt.Errorf("accept input to subnet %s: %w", destination, ErrNeedIP6Tables)
}
return c.runIP6tablesInstruction(ctx, instruction)
}
func (c *Config) AcceptOutputThroughInterface(ctx context.Context, intf string, remove bool) error {
return c.runMixedIptablesInstruction(ctx, fmt.Sprintf(
"%s OUTPUT -o %s -j ACCEPT", appendOrDelete(remove), intf,
))
}
func (c *Config) AcceptEstablishedRelatedTraffic(ctx context.Context, remove bool) error {
return c.runMixedIptablesInstructions(ctx, []string{
fmt.Sprintf("%s OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", appendOrDelete(remove)),
fmt.Sprintf("%s INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", appendOrDelete(remove)),
})
}
func (c *Config) AcceptOutputTrafficToVPN(ctx context.Context,
defaultInterface string, connection models.Connection, remove bool,
) error {
protocol := connection.Protocol
if protocol == "tcp-client" {
protocol = "tcp"
}
instruction := fmt.Sprintf("%s OUTPUT -d %s -o %s -p %s -m %s --dport %d -j ACCEPT",
appendOrDelete(remove), connection.IP, defaultInterface, protocol,
protocol, connection.Port)
if connection.IP.Is4() {
return c.runIptablesInstruction(ctx, instruction)
} else if c.ip6Tables == "" {
return fmt.Errorf("accept output to VPN server: %w", ErrNeedIP6Tables)
}
return c.runIP6tablesInstruction(ctx, instruction)
}
// AcceptOutputFromIPToSubnet accepts outgoing traffic from sourceIP to destinationSubnet
// on the interface intf. If intf is empty, it is set to "*" which means all interfaces.
// If remove is true, the rule is removed instead of added.
// Thanks to @npawelek.
func (c *Config) AcceptOutputFromIPToSubnet(ctx context.Context,
intf string, sourceIP netip.Addr, destinationSubnet netip.Prefix, remove bool,
) error {
doIPv4 := sourceIP.Is4() && destinationSubnet.Addr().Is4()
interfaceFlag := "-o " + intf
if intf == "*" { // all interfaces
interfaceFlag = ""
}
instruction := fmt.Sprintf("%s OUTPUT %s -s %s -d %s -j ACCEPT",
appendOrDelete(remove), interfaceFlag, sourceIP.String(), destinationSubnet.String())
if doIPv4 {
return c.runIptablesInstruction(ctx, instruction)
} else if c.ip6Tables == "" {
return fmt.Errorf("accept output from %s to %s: %w", sourceIP, destinationSubnet, ErrNeedIP6Tables)
}
return c.runIP6tablesInstruction(ctx, instruction)
}
// AcceptIpv6MulticastOutput accepts outgoing traffic to the IPv6 multicast address
// ff02::1:ff00:0/104, which is used for NDP (Neighbor Discovery Protocol) to resolve
// IPv6 addresses to MAC addresses. If intf is empty, it is set to "*" which means
// all interfaces. If remove is true, the rule is removed instead of added.
func (c *Config) AcceptIpv6MulticastOutput(ctx context.Context,
intf string, remove bool,
) error {
interfaceFlag := "-o " + intf
if intf == "*" { // all interfaces
interfaceFlag = ""
}
instruction := fmt.Sprintf("%s OUTPUT %s -d ff02::1:ff00:0/104 -j ACCEPT",
appendOrDelete(remove), interfaceFlag)
return c.runIP6tablesInstruction(ctx, instruction)
}
// AcceptInputToPort accepts incoming traffic on the specified port, for both TCP and UDP
// protocols, on the interface intf. If intf is empty, it is set to "*" which means all interfaces.
// If remove is true, the rule is removed instead of added. This is used for port forwarding, with
// intf set to the VPN tunnel interface.
func (c *Config) AcceptInputToPort(ctx context.Context, intf string, port uint16, remove bool) error {
interfaceFlag := "-i " + intf
if intf == "*" { // all interfaces
interfaceFlag = ""
}
return c.runMixedIptablesInstructions(ctx, []string{
fmt.Sprintf("%s INPUT %s -p tcp -m tcp --dport %d -j ACCEPT", appendOrDelete(remove), interfaceFlag, port),
fmt.Sprintf("%s INPUT %s -p udp -m udp --dport %d -j ACCEPT", appendOrDelete(remove), interfaceFlag, port),
})
}
// RedirectPort redirects incoming traffic on the specified source port to the
// specified destination port, for both TCP and UDP protocols, on the interface intf.
// If intf is empty, it is set to "*" which means all interfaces. If remove is true,
// the redirection is removed instead of added. This is used for VPN server side
// port forwarding, with intf set to the VPN tunnel interface.
func (c *Config) RedirectPort(ctx context.Context, intf string,
sourcePort, destinationPort uint16, remove bool,
) (err error) {
interfaceFlag := "-i " + intf
if intf == "*" { // all interfaces
interfaceFlag = ""
}
err = c.runIptablesInstructions(ctx, []string{
fmt.Sprintf("-t nat %s PREROUTING %s -p tcp --dport %d -j REDIRECT --to-ports %d",
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
fmt.Sprintf("%s INPUT %s -p tcp -m tcp --dport %d -j ACCEPT",
appendOrDelete(remove), interfaceFlag, destinationPort),
fmt.Sprintf("-t nat %s PREROUTING %s -p udp --dport %d -j REDIRECT --to-ports %d",
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
fmt.Sprintf("%s INPUT %s -p udp -m udp --dport %d -j ACCEPT",
appendOrDelete(remove), interfaceFlag, destinationPort),
})
if err != nil {
return fmt.Errorf("redirecting IPv4 source port %d to destination port %d on interface %s: %w",
sourcePort, destinationPort, intf, err)
}
err = c.runIP6tablesInstructions(ctx, []string{
fmt.Sprintf("-t nat %s PREROUTING %s -p tcp --dport %d -j REDIRECT --to-ports %d",
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
fmt.Sprintf("%s INPUT %s -p tcp -m tcp --dport %d -j ACCEPT",
appendOrDelete(remove), interfaceFlag, destinationPort),
fmt.Sprintf("-t nat %s PREROUTING %s -p udp --dport %d -j REDIRECT --to-ports %d",
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
fmt.Sprintf("%s INPUT %s -p udp -m udp --dport %d -j ACCEPT",
appendOrDelete(remove), interfaceFlag, destinationPort),
})
if err != nil {
errMessage := err.Error()
if strings.Contains(errMessage, "can't initialize ip6tables table `nat': Table does not exist") {
if !remove {
c.logger.Warn("IPv6 port redirection disabled because your kernel does not support IPv6 NAT: " + errMessage)
}
return nil
}
return fmt.Errorf("redirecting IPv6 source port %d to destination port %d on interface %s: %w",
sourcePort, destinationPort, intf, err)
}
return nil
}
func (c *Config) RunUserPostRules(ctx context.Context, filepath string, remove bool) error {
file, err := os.OpenFile(filepath, os.O_RDONLY, 0)
if os.IsNotExist(err) {
return nil
} else if err != nil {
return err
}
b, err := io.ReadAll(file)
if err != nil {
_ = file.Close()
return err
}
if err := file.Close(); err != nil {
return err
}
lines := strings.Split(string(b), "\n")
successfulRules := []string{}
defer func() {
// transaction-like rollback
if err == nil || ctx.Err() != nil {
return
}
for _, rule := range successfulRules {
_ = c.runIptablesInstruction(ctx, flipRule(rule))
}
}()
for _, line := range lines {
var ipv4 bool
var rule string
switch {
case strings.HasPrefix(line, "iptables "):
ipv4 = true
rule = strings.TrimPrefix(line, "iptables ")
case strings.HasPrefix(line, "iptables-nft "):
ipv4 = true
rule = strings.TrimPrefix(line, "iptables-nft ")
case strings.HasPrefix(line, "iptables-legacy "):
ipv4 = true
rule = strings.TrimPrefix(line, "iptables-legacy ")
case strings.HasPrefix(line, "ip6tables "):
ipv4 = false
rule = strings.TrimPrefix(line, "ip6tables ")
case strings.HasPrefix(line, "ip6tables-nft "):
ipv4 = false
rule = strings.TrimPrefix(line, "ip6tables-nft ")
case strings.HasPrefix(line, "ip6tables-legacy "):
ipv4 = false
rule = strings.TrimPrefix(line, "ip6tables-legacy ")
default:
continue
}
if remove {
rule = flipRule(rule)
}
switch {
case ipv4:
err = c.runIptablesInstruction(ctx, rule)
case c.ip6Tables == "":
err = fmt.Errorf("running user ip6tables rule: %w", ErrNeedIP6Tables)
default: // ipv6
err = c.runIP6tablesInstruction(ctx, rule)
}
if err != nil {
return err
}
successfulRules = append(successfulRules, rule)
}
return nil
}
+21
View File
@@ -0,0 +1,21 @@
package iptables
import (
"context"
)
func (c *Config) runMixedIptablesInstructions(ctx context.Context, instructions []string) error {
for _, instruction := range instructions {
if err := c.runMixedIptablesInstruction(ctx, instruction); err != nil {
return err
}
}
return nil
}
func (c *Config) runMixedIptablesInstruction(ctx context.Context, instruction string) error {
if err := c.runIptablesInstruction(ctx, instruction); err != nil {
return err
}
return c.runIP6tablesInstruction(ctx, instruction)
}
+523
View File
@@ -0,0 +1,523 @@
package iptables
import (
"errors"
"fmt"
"net/netip"
"slices"
"strconv"
"strings"
)
type chain struct {
name string
policy string
packets uint64
bytes uint64
rules []chainRule
}
type chainRule struct {
lineNumber uint16 // starts from 1 and cannot be zero.
packets uint64
bytes uint64
target string // "ACCEPT", "DROP", "REJECT" or "REDIRECT"
protocol string // "icmp", "tcp", "udp" or "" for all protocols.
inputInterface string // input interface, for example "tun0" or "*""
outputInterface string // output interface, for example "eth0" or "*""
source netip.Prefix // source IP CIDR, for example 0.0.0.0/0. Must be valid.
sourcePort uint16 // Not specified if set to zero.
destination netip.Prefix // destination IP CIDR, for example 0.0.0.0/0. Must be valid.
destinationPort uint16 // Not specified if set to zero.
redirPorts []uint16 // Not specified if empty.
ctstate []string // for example ["RELATED","ESTABLISHED"]. Can be empty.
tcpFlags tcpFlags
mark mark
}
type mark struct {
invert bool
value uint
}
var ErrChainListMalformed = errors.New("iptables chain list output is malformed")
func parseChain(iptablesOutput string) (c chain, err error) {
// Text example:
// Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
// pkts bytes target prot opt in out source destination
// 0 0 ACCEPT 17 -- tun0 * 0.0.0.0/0 0.0.0.0/0 udp dpt:55405
// 0 0 ACCEPT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:55405
// 0 0 DROP 0 -- tun0 * 0.0.0.0/0 0.0.0.0/0
iptablesOutput = strings.TrimSpace(iptablesOutput)
linesWithComments := strings.Split(iptablesOutput, "\n")
// Filter out lines starting with a '#' character
lines := make([]string, 0, len(linesWithComments))
for _, line := range linesWithComments {
if strings.HasPrefix(line, "#") {
continue
}
lines = append(lines, line)
}
const minLines = 2 // chain general information line + legend line
if len(lines) < minLines {
return chain{}, fmt.Errorf("%w: not enough lines to process in: %s",
ErrChainListMalformed, iptablesOutput)
}
c, err = parseChainGeneralDataLine(lines[0])
if err != nil {
return chain{}, fmt.Errorf("parsing chain general data line: %w", err)
}
// Sanity check for the legend line
expectedLegendFields := []string{"num", "pkts", "bytes", "target", "prot", "opt", "in", "out", "source", "destination"}
legendLine := strings.TrimSpace(lines[1])
legendFields := strings.Fields(legendLine)
if !slices.Equal(expectedLegendFields, legendFields) {
return chain{}, fmt.Errorf("%w: legend %q is not the expected %q",
ErrChainListMalformed, legendLine, strings.Join(expectedLegendFields, " "))
}
lines = lines[2:] // remove chain general information line and legend line
if len(lines) == 0 {
return c, nil
}
c.rules = make([]chainRule, len(lines))
for i, line := range lines {
c.rules[i], err = parseChainRuleLine(line)
if err != nil {
return chain{}, fmt.Errorf("parsing chain rule %q: %w", line, err)
}
}
return c, nil
}
// parseChainGeneralDataLine parses the first line of iptables chain list output.
// For example, it can parse the following line:
// Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
// It returns a chain struct with the parsed data.
func parseChainGeneralDataLine(line string) (base chain, err error) {
line = strings.TrimSpace(line)
runesToRemove := []rune{'(', ')', ','}
for _, r := range runesToRemove {
line = strings.ReplaceAll(line, string(r), "")
}
fields := strings.Fields(line)
const expectedNumberOfFields = 8
if len(fields) != expectedNumberOfFields {
return chain{}, fmt.Errorf("%w: expected %d fields in %q",
ErrChainListMalformed, expectedNumberOfFields, line)
}
// Sanity checks
indexToExpectedValue := map[int]string{
0: "Chain",
2: "policy",
5: "packets",
7: "bytes",
}
for index, expectedValue := range indexToExpectedValue {
if fields[index] == expectedValue {
continue
}
return chain{}, fmt.Errorf("%w: expected %q for field %d in %q",
ErrChainListMalformed, expectedValue, index, line)
}
base.name = fields[1] // chain name could be custom
base.policy = fields[3]
err = checkTarget(base.policy)
if err != nil {
return chain{}, fmt.Errorf("policy target in %q: %w", line, err)
}
packets, err := parseMetricSize(fields[4])
if err != nil {
return chain{}, fmt.Errorf("parsing packets: %w", err)
}
base.packets = packets
bytes, err := parseMetricSize(fields[6])
if err != nil {
return chain{}, fmt.Errorf("parsing bytes: %w", err)
}
base.bytes = bytes
return base, nil
}
var ErrChainRuleMalformed = errors.New("chain rule is malformed")
func parseChainRuleLine(line string) (rule chainRule, err error) {
line = strings.TrimSpace(line)
if line == "" {
return chainRule{}, fmt.Errorf("%w: empty line", ErrChainRuleMalformed)
}
fields := strings.Fields(line)
const minFields = 10
if len(fields) < minFields {
return chainRule{}, fmt.Errorf("%w: not enough fields", ErrChainRuleMalformed)
}
for fieldIndex, field := range fields[:minFields] {
err = parseChainRuleField(fieldIndex, field, &rule)
if err != nil {
return chainRule{}, fmt.Errorf("parsing chain rule field: %w", err)
}
}
if len(fields) > minFields {
err = parseChainRuleOptionalFields(fields[minFields:], &rule)
if err != nil {
return chainRule{}, fmt.Errorf("parsing optional fields: %w", err)
}
}
return rule, nil
}
func parseChainRuleField(fieldIndex int, field string, rule *chainRule) (err error) {
if field == "" {
return fmt.Errorf("%w: empty field at index %d", ErrChainRuleMalformed, fieldIndex)
}
const (
numIndex = iota
packetsIndex
bytesIndex
targetIndex
protocolIndex
optIndex
inputInterfaceIndex
outputInterfaceIndex
sourceIndex
destinationIndex
)
switch fieldIndex {
case numIndex:
rule.lineNumber, err = parseLineNumber(field)
if err != nil {
return fmt.Errorf("parsing line number: %w", err)
}
case packetsIndex:
rule.packets, err = parseMetricSize(field)
if err != nil {
return fmt.Errorf("parsing packets: %w", err)
}
case bytesIndex:
rule.bytes, err = parseMetricSize(field)
if err != nil {
return fmt.Errorf("parsing bytes: %w", err)
}
case targetIndex:
err = checkTarget(field)
if err != nil {
return fmt.Errorf("checking target: %w", err)
}
rule.target = field
case protocolIndex:
rule.protocol, err = parseProtocol(field)
if err != nil {
return fmt.Errorf("parsing protocol: %w", err)
}
case optIndex: // ignored
case inputInterfaceIndex:
rule.inputInterface = field
case outputInterfaceIndex:
rule.outputInterface = field
case sourceIndex:
rule.source, err = parseIPPrefix(field)
if err != nil {
return fmt.Errorf("parsing source IP CIDR: %w", err)
}
case destinationIndex:
rule.destination, err = parseIPPrefix(field)
if err != nil {
return fmt.Errorf("parsing destination IP CIDR: %w", err)
}
}
return nil
}
func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err error) {
i := 0
for i < len(optionalFields) {
switch optionalFields[i] {
case "udp":
i++
consumed, err := parseUDPOptional(optionalFields[i:], rule)
if err != nil {
return fmt.Errorf("parsing UDP optional fields: %w", err)
}
i += consumed
case "tcp":
i++
consumed, err := parseTCPOptional(optionalFields[i:], rule)
if err != nil {
return fmt.Errorf("parsing TCP optional fields: %w", err)
}
i += consumed
case "redir":
i++
switch optionalFields[i] {
case "ports":
i++
ports, err := parsePortsCSV(optionalFields[i])
if err != nil {
return fmt.Errorf("parsing redirection ports: %w", err)
}
rule.redirPorts = ports
i++
default:
return fmt.Errorf("%w: unexpected %q after redir",
ErrChainRuleMalformed, optionalFields[1])
}
case "ctstate":
i++
rule.ctstate = strings.Split(optionalFields[i], ",")
i++
case "mark":
i++
mark, consumed, err := parseMark(optionalFields[i:])
if err != nil {
return fmt.Errorf("parsing mark: %w", err)
}
rule.mark = mark
i += consumed
default:
return fmt.Errorf("%w: unexpected optional field: %s",
ErrChainRuleMalformed, optionalFields[i])
}
}
return nil
}
var errUDPOptionalUnknown = errors.New("unknown UDP optional field")
func parseUDPOptional(optionalFields []string, rule *chainRule) (consumed int, err error) {
for _, value := range optionalFields {
if !strings.ContainsRune(value, ':') {
// no longer a UDP-associated option
return consumed, nil
}
switch {
case strings.HasPrefix(value, "dpt:"):
rule.destinationPort, err = parseDestinationPort(value)
if err != nil {
return 0, fmt.Errorf("parsing destination port: %w", err)
}
consumed++
case strings.HasPrefix(value, "spt:"):
rule.sourcePort, err = parseSourcePort(value)
if err != nil {
return 0, fmt.Errorf("parsing source port: %w", err)
}
consumed++
default:
return 0, fmt.Errorf("%w: %s", errUDPOptionalUnknown, value)
}
}
return consumed, nil
}
var errTCPOptionalUnknown = errors.New("unknown TCP optional field")
func parseTCPOptional(optionalFields []string, rule *chainRule) (consumed int, err error) {
for _, value := range optionalFields {
if !strings.ContainsRune(value, ':') {
// no longer a TCP-associated option
return consumed, nil
}
switch {
case strings.HasPrefix(value, "dpt:"):
rule.destinationPort, err = parseDestinationPort(value)
if err != nil {
return 0, fmt.Errorf("parsing destination port: %w", err)
}
consumed++
case strings.HasPrefix(value, "spt:"):
rule.sourcePort, err = parseSourcePort(value)
if err != nil {
return 0, fmt.Errorf("parsing source port: %w", err)
}
consumed++
case strings.HasPrefix(value, "flags:"):
rule.tcpFlags, err = parseTCPFlags(value)
if err != nil {
return 0, fmt.Errorf("parsing TCP flags: %w", err)
}
consumed++
default:
return 0, fmt.Errorf("%w: %s", errTCPOptionalUnknown, value)
}
}
return consumed, nil
}
func parseDestinationPort(value string) (port uint16, err error) {
value = strings.TrimPrefix(value, "dpt:")
return parsePort(value)
}
func parseSourcePort(value string) (port uint16, err error) {
value = strings.TrimPrefix(value, "spt:")
return parsePort(value)
}
var errTCPFlagsMalformed = errors.New("TCP flags are malformed")
func parseTCPFlags(value string) (tcpFlags, error) {
value = strings.TrimPrefix(value, "flags:")
fields := strings.Split(value, "/")
const expectedFields = 2
if len(fields) != expectedFields {
return tcpFlags{}, fmt.Errorf("%w: expected format 'flags:<mask>/<comparison>' in %q",
errTCPFlagsMalformed, value)
}
maskFlags := strings.Split(fields[0], ",")
mask := make([]tcpFlag, len(maskFlags))
var err error
for i, maskFlag := range maskFlags {
mask[i], err = parseTCPFlag(maskFlag)
if err != nil {
return tcpFlags{}, fmt.Errorf("parsing TCP mask flags: %w", err)
}
}
comparisonFlags := strings.Split(fields[1], ",")
comparison := make([]tcpFlag, len(comparisonFlags))
for i, comparisonFlag := range comparisonFlags {
comparison[i], err = parseTCPFlag(comparisonFlag)
if err != nil {
return tcpFlags{}, fmt.Errorf("parsing TCP comparison flags: %w", err)
}
}
return tcpFlags{
mask: mask,
comparison: comparison,
}, nil
}
func parsePortsCSV(s string) (ports []uint16, err error) {
if s == "" {
return nil, nil
}
fields := strings.Split(s, ",")
ports = make([]uint16, len(fields))
for i, field := range fields {
ports[i], err = parsePort(field)
if err != nil {
return nil, err
}
}
return ports, nil
}
var errMarkValueMalformed = errors.New("mark value is malformed")
func parseMark(optionalFields []string) (m mark, consumed int, err error) {
switch optionalFields[consumed] {
case "match":
consumed++
if optionalFields[consumed] == "!" {
m.invert = true
consumed++
}
const base = 0 // auto-detect
const bits = 32
value, err := strconv.ParseUint(optionalFields[consumed], base, bits)
if err != nil {
return mark{}, 0, fmt.Errorf("%w: %s", errMarkValueMalformed, optionalFields[consumed])
}
m.value = uint(value)
consumed++
default:
return mark{}, 0, fmt.Errorf("%w: unexpected mark mode field: %s",
ErrChainRuleMalformed, optionalFields[consumed])
}
return m, consumed, nil
}
var ErrLineNumberIsZero = errors.New("line number is zero")
func parseLineNumber(s string) (n uint16, err error) {
const base, bitLength = 10, 16
lineNumber, err := strconv.ParseUint(s, base, bitLength)
if err != nil {
return 0, err
} else if lineNumber == 0 {
return 0, fmt.Errorf("%w", ErrLineNumberIsZero)
}
return uint16(lineNumber), nil
}
var ErrTargetUnknown = errors.New("unknown target")
func checkTarget(target string) (err error) {
switch target {
case "ACCEPT", "DROP", "REJECT", "REDIRECT":
return nil
}
return fmt.Errorf("%w: %s", ErrTargetUnknown, target)
}
var ErrProtocolUnknown = errors.New("unknown protocol")
func parseProtocol(s string) (protocol string, err error) {
switch s {
case "0", "all":
case "1", "icmp":
protocol = "icmp"
case "6", "tcp":
protocol = "tcp"
case "17", "udp":
protocol = "udp"
default:
return "", fmt.Errorf("%w: %s", ErrProtocolUnknown, s)
}
return protocol, nil
}
var ErrMetricSizeMalformed = errors.New("metric size is malformed")
// parseMetricSize parses a metric size string like 140K or 226M and
// returns the raw integer matching it.
func parseMetricSize(size string) (n uint64, err error) {
if size == "" {
return n, fmt.Errorf("%w: empty string", ErrMetricSizeMalformed)
}
//nolint:mnd
multiplerLetterToValue := map[byte]uint64{
'K': 1000,
'M': 1000000,
'G': 1000000000,
'T': 1000000000000,
}
lastCharacter := size[len(size)-1]
multiplier, ok := multiplerLetterToValue[lastCharacter]
if ok { // multiplier present
size = size[:len(size)-1]
} else {
multiplier = 1
}
const base, bitLength = 10, 64
n, err = strconv.ParseUint(size, base, bitLength)
if err != nil {
return n, fmt.Errorf("%w: %w", ErrMetricSizeMalformed, err)
}
n *= multiplier
return n, nil
}
+144
View File
@@ -0,0 +1,144 @@
package iptables
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
)
func Test_parseChain(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
iptablesOutput string
table chain
errWrapped error
errMessage string
}{
"no_output": {
errWrapped: ErrChainListMalformed,
errMessage: "iptables chain list output is malformed: not enough lines to process in: ",
},
"single_line_only": {
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)`,
errWrapped: ErrChainListMalformed,
errMessage: "iptables chain list output is malformed: not enough lines to process in: " +
"Chain INPUT (policy ACCEPT 140K packets, 226M bytes)",
},
"malformed_general_data_line": {
iptablesOutput: `Chain INPUT
num pkts bytes target prot opt in out source destination`,
errWrapped: ErrChainListMalformed,
errMessage: "parsing chain general data line: iptables chain list output is malformed: " +
"expected 8 fields in \"Chain INPUT\"",
},
"malformed_legend": {
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
num pkts bytes target prot opt in out source`,
errWrapped: ErrChainListMalformed,
errMessage: "iptables chain list output is malformed: legend " +
"\"num pkts bytes target prot opt in out source\" " +
"is not the expected \"num pkts bytes target prot opt in out source destination\"",
},
"no_rule": {
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
num pkts bytes target prot opt in out source destination`,
table: chain{
name: "INPUT",
policy: "ACCEPT",
packets: 140000,
bytes: 226000000,
},
},
"some_rules": {
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
num pkts bytes target prot opt in out source destination
1 0 0 ACCEPT 17 -- tun0 * 0.0.0.0/0 0.0.0.0/0 udp dpt:55405
2 0 0 ACCEPT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:55405
3 0 0 ACCEPT 1 -- tun0 * 0.0.0.0/0 0.0.0.0/0
4 0 0 DROP 0 -- tun0 * 1.2.3.4 0.0.0.0/0
5 0 0 ACCEPT all -- tun0 * 1.2.3.4 0.0.0.0/0
`,
table: chain{
name: "INPUT",
policy: "ACCEPT",
packets: 140000,
bytes: 226000000,
rules: []chainRule{
{
lineNumber: 1,
packets: 0,
bytes: 0,
target: "ACCEPT",
protocol: "udp",
inputInterface: "tun0",
outputInterface: "*",
source: netip.MustParsePrefix("0.0.0.0/0"),
destination: netip.MustParsePrefix("0.0.0.0/0"),
destinationPort: 55405,
},
{
lineNumber: 2,
packets: 0,
bytes: 0,
target: "ACCEPT",
protocol: "tcp",
inputInterface: "tun0",
outputInterface: "*",
source: netip.MustParsePrefix("0.0.0.0/0"),
destination: netip.MustParsePrefix("0.0.0.0/0"),
destinationPort: 55405,
},
{
lineNumber: 3,
packets: 0,
bytes: 0,
target: "ACCEPT",
protocol: "icmp",
inputInterface: "tun0",
outputInterface: "*",
source: netip.MustParsePrefix("0.0.0.0/0"),
destination: netip.MustParsePrefix("0.0.0.0/0"),
},
{
lineNumber: 4,
packets: 0,
bytes: 0,
target: "DROP",
protocol: "",
inputInterface: "tun0",
outputInterface: "*",
source: netip.MustParsePrefix("1.2.3.4/32"),
destination: netip.MustParsePrefix("0.0.0.0/0"),
},
{
lineNumber: 5,
packets: 0,
bytes: 0,
target: "ACCEPT",
protocol: "",
inputInterface: "tun0",
outputInterface: "*",
source: netip.MustParsePrefix("1.2.3.4/32"),
destination: netip.MustParsePrefix("0.0.0.0/0"),
},
},
},
},
}
for name, testCase := range testCases {
t.Run(name, func(t *testing.T) {
t.Parallel()
table, err := parseChain(testCase.iptablesOutput)
assert.Equal(t, testCase.table, table)
assert.ErrorIs(t, err, testCase.errWrapped)
if testCase.errWrapped != nil {
assert.EqualError(t, err, testCase.errMessage)
}
})
}
}
@@ -0,0 +1,3 @@
package iptables
//go:generate mockgen -destination=mocks_test.go -package $GOPACKAGE . CmdRunner,Logger
+121
View File
@@ -0,0 +1,121 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/qdm12/gluetun/internal/firewall/iptables (interfaces: CmdRunner,Logger)
// Package iptables is a generated GoMock package.
package iptables
import (
exec "os/exec"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
)
// MockCmdRunner is a mock of CmdRunner interface.
type MockCmdRunner struct {
ctrl *gomock.Controller
recorder *MockCmdRunnerMockRecorder
}
// MockCmdRunnerMockRecorder is the mock recorder for MockCmdRunner.
type MockCmdRunnerMockRecorder struct {
mock *MockCmdRunner
}
// NewMockCmdRunner creates a new mock instance.
func NewMockCmdRunner(ctrl *gomock.Controller) *MockCmdRunner {
mock := &MockCmdRunner{ctrl: ctrl}
mock.recorder = &MockCmdRunnerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockCmdRunner) EXPECT() *MockCmdRunnerMockRecorder {
return m.recorder
}
// Run mocks base method.
func (m *MockCmdRunner) Run(arg0 *exec.Cmd) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Run", arg0)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Run indicates an expected call of Run.
func (mr *MockCmdRunnerMockRecorder) Run(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockCmdRunner)(nil).Run), arg0)
}
// MockLogger is a mock of Logger interface.
type MockLogger struct {
ctrl *gomock.Controller
recorder *MockLoggerMockRecorder
}
// MockLoggerMockRecorder is the mock recorder for MockLogger.
type MockLoggerMockRecorder struct {
mock *MockLogger
}
// NewMockLogger creates a new mock instance.
func NewMockLogger(ctrl *gomock.Controller) *MockLogger {
mock := &MockLogger{ctrl: ctrl}
mock.recorder = &MockLoggerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockLogger) EXPECT() *MockLoggerMockRecorder {
return m.recorder
}
// Debug mocks base method.
func (m *MockLogger) Debug(arg0 string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Debug", arg0)
}
// Debug indicates an expected call of Debug.
func (mr *MockLoggerMockRecorder) Debug(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockLogger)(nil).Debug), arg0)
}
// Error mocks base method.
func (m *MockLogger) Error(arg0 string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Error", arg0)
}
// Error indicates an expected call of Error.
func (mr *MockLoggerMockRecorder) Error(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Error", reflect.TypeOf((*MockLogger)(nil).Error), arg0)
}
// Info mocks base method.
func (m *MockLogger) Info(arg0 string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Info", arg0)
}
// Info indicates an expected call of Info.
func (mr *MockLoggerMockRecorder) Info(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Info", reflect.TypeOf((*MockLogger)(nil).Info), arg0)
}
// Warn mocks base method.
func (m *MockLogger) Warn(arg0 string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Warn", arg0)
}
// Warn indicates an expected call of Warn.
func (mr *MockLoggerMockRecorder) Warn(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Warn", reflect.TypeOf((*MockLogger)(nil).Warn), arg0)
}
+262
View File
@@ -0,0 +1,262 @@
package iptables
import (
"errors"
"fmt"
"net/netip"
"slices"
"strconv"
"strings"
)
type iptablesInstruction struct {
table string // defaults to "filter", and can be "nat" for example.
append bool
chain string // for example INPUT, PREROUTING. Cannot be empty.
target string // for example ACCEPT. Can be empty.
protocol string // "tcp" or "udp" or "" for all protocols.
inputInterface string // for example "tun0" or "" for any interface.
outputInterface string // for example "tun0" or "" for any interface.
source netip.Prefix // if not valid, then it is unspecified.
sourcePort uint16 // if zero, there is no source port
destination netip.Prefix // if not valid, then it is unspecified.
destinationPort uint16 // if zero, there is no destination port
toPorts []uint16 // if empty, there is no redirection
ctstate []string // if empty, there is no ctstate
tcpFlags tcpFlags
mark mark
}
func (i *iptablesInstruction) setDefaults() {
if i.table == "" {
i.table = "filter"
}
}
// equalToRule ignores the append boolean flag of the instruction to compare against the rule.
func (i *iptablesInstruction) equalToRule(table, chain string, rule chainRule) (equal bool) {
switch {
case i.table != table:
return false
case i.chain != chain:
return false
case i.target != rule.target:
return false
case i.protocol != rule.protocol:
return false
case i.destinationPort != rule.destinationPort:
return false
case i.sourcePort != rule.sourcePort:
return false
case !slices.Equal(i.toPorts, rule.redirPorts):
return false
case !slices.Equal(i.ctstate, rule.ctstate):
return false
case !networkInterfacesEqual(i.inputInterface, rule.inputInterface):
return false
case !networkInterfacesEqual(i.outputInterface, rule.outputInterface):
return false
case !ipPrefixesEqual(i.source, rule.source):
return false
case !ipPrefixesEqual(i.destination, rule.destination):
return false
case !slices.Equal(i.tcpFlags.mask, rule.tcpFlags.mask) ||
!slices.Equal(i.tcpFlags.comparison, rule.tcpFlags.comparison):
return false
case i.mark != rule.mark:
return false
default:
return true
}
}
// instruction can be "" which equivalent to the "*" chain rule interface.
func networkInterfacesEqual(instruction, chainRule string) bool {
return instruction == chainRule || (instruction == "" && chainRule == "*")
}
func ipPrefixesEqual(instruction, chainRule netip.Prefix) bool {
return instruction == chainRule ||
(!instruction.IsValid() && chainRule.Bits() == 0 && chainRule.Addr().IsUnspecified())
}
var ErrIptablesCommandMalformed = errors.New("iptables command is malformed")
func parseIptablesInstruction(s string) (instruction iptablesInstruction, err error) {
if s == "" {
return iptablesInstruction{}, fmt.Errorf("%w: empty instruction", ErrIptablesCommandMalformed)
}
fields := strings.Fields(s)
i := 0
for i < len(fields) {
consumed, err := parseInstructionFlag(fields[i:], &instruction)
if err != nil {
return iptablesInstruction{}, fmt.Errorf("parsing %q: %w", s, err)
}
i += consumed
}
instruction.setDefaults()
return instruction, nil
}
func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (consumed int, err error) {
consumed, err = preCheckInstructionFields(fields)
if err != nil {
return 0, err
}
flag := fields[0]
value := fields[1]
switch flag {
case "-t", "--table":
instruction.table = value
case "-D", "--delete":
instruction.append = false
instruction.chain = value
case "-A", "--append":
instruction.append = true
instruction.chain = value
case "-j", "--jump":
instruction.target = value
case "-p", "--protocol":
instruction.protocol = value
case "-m", "--match":
consumed, err = parseMatchModule(fields, instruction)
if err != nil {
return 0, fmt.Errorf("parsing match module: %w", err)
}
case "--mark":
const base = 0 // auto-detect
const bits = 32
value, err := strconv.ParseUint(value, base, bits)
if err != nil {
return 0, fmt.Errorf("parsing mark value %q: %w", fields[2], err)
}
instruction.mark.value = uint(value)
case "-i", "--in-interface":
instruction.inputInterface = value
case "-o", "--out-interface":
instruction.outputInterface = value
case "-s", "--source":
instruction.source, err = parseIPPrefix(value)
if err != nil {
return 0, fmt.Errorf("parsing source IP CIDR: %w", err)
}
case "--sport":
instruction.sourcePort, err = parsePort(value)
if err != nil {
return 0, fmt.Errorf("parsing source port: %w", err)
}
case "-d", "--destination":
instruction.destination, err = parseIPPrefix(value)
if err != nil {
return 0, fmt.Errorf("parsing destination IP CIDR: %w", err)
}
case "--dport":
instruction.destinationPort, err = parsePort(value)
if err != nil {
return 0, fmt.Errorf("parsing destination port: %w", err)
}
case "--ctstate":
instruction.ctstate = strings.Split(value, ",")
case "--to-ports":
instruction.toPorts, err = parseToPorts(value)
if err != nil {
return 0, fmt.Errorf("parsing port redirection: %w", err)
}
case "--tcp-flags":
mask, comparison := value, fields[2]
instruction.tcpFlags, err = parseTCPFlags(mask + "/" + comparison)
if err != nil {
return 0, fmt.Errorf("parsing TCP flags: %w", err)
}
default:
return 0, fmt.Errorf("%w: unknown key %q", ErrIptablesCommandMalformed, flag)
}
return consumed, nil
}
func preCheckInstructionFields(fields []string) (consumed int, err error) {
flag := fields[0]
// All flags use one value after the flag, except the following:
switch flag {
case "--tcp-flags": // -m can have 1 or 2 values
const expected = 3
if len(fields) < expected {
return 0, fmt.Errorf("%w: flag %q requires at least 2 values, but got %s",
ErrIptablesCommandMalformed, flag, strings.Join(fields, " "))
}
return expected, nil
default:
const expected = 2
if len(fields) < expected {
return 0, fmt.Errorf("%w: flag %q requires a value, but got none",
ErrIptablesCommandMalformed, flag)
}
return expected, nil
}
}
func parseIPPrefix(value string) (prefix netip.Prefix, err error) {
slashIndex := strings.Index(value, "/")
if slashIndex >= 0 {
return netip.ParsePrefix(value)
}
ip, err := netip.ParseAddr(value)
if err != nil {
return netip.Prefix{}, fmt.Errorf("parsing IP address: %w", err)
}
return netip.PrefixFrom(ip, ip.BitLen()), nil
}
func parsePort(value string) (port uint16, err error) {
const base, bitLength = 10, 16
portValue, err := strconv.ParseUint(value, base, bitLength)
if err != nil {
return 0, err
}
return uint16(portValue), nil
}
func parseMatchModule(fields []string, instruction *iptablesInstruction) (
consumed int, err error,
) {
_ = fields[consumed] // -m or --match flag already detected
consumed++
switch fields[consumed] {
case "tcp", "udp":
consumed++
// for now ignore the protocol match since it's auto-loaded
// when parsing the -p/--protocol flag, and we don't need to
// parse it twice.
case "mark":
consumed++
switch fields[consumed] {
case "!":
consumed++
instruction.mark.invert = true
default:
return consumed, fmt.Errorf("%w: unsupported match mark with value: %s",
ErrIptablesCommandMalformed, fields[2])
}
default:
return 0, fmt.Errorf("%w: unknown match value: %s",
ErrIptablesCommandMalformed, fields[consumed])
}
return consumed, nil
}
func parseToPorts(value string) (toPorts []uint16, err error) {
portStrings := strings.Split(value, ",")
toPorts = make([]uint16, len(portStrings))
for i, portString := range portStrings {
toPorts[i], err = parsePort(portString)
if err != nil {
return nil, err
}
}
return toPorts, nil
}
+136
View File
@@ -0,0 +1,136 @@
package iptables
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
)
func Test_parseIptablesInstruction(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
s string
instruction iptablesInstruction
errWrapped error
errMessage string
}{
"no_instruction": {
errWrapped: ErrIptablesCommandMalformed,
errMessage: "iptables command is malformed: empty instruction",
},
"uneven_fields": {
s: "-A",
errWrapped: ErrIptablesCommandMalformed,
errMessage: "parsing \"-A\": iptables command is malformed: flag \"-A\" requires a value, but got none",
},
"unknown_key": {
s: "-x something",
errWrapped: ErrIptablesCommandMalformed,
errMessage: "parsing \"-x something\": iptables command is malformed: unknown key \"-x\"",
},
"one_pair": {
s: "-A INPUT",
instruction: iptablesInstruction{
table: "filter",
chain: "INPUT",
append: true,
},
},
"instruction_A": {
s: "-A INPUT -i tun0 -p tcp -m tcp -s 1.2.3.4/32 -d 5.6.7.8 --dport 10000 -j ACCEPT",
instruction: iptablesInstruction{
table: "filter",
chain: "INPUT",
append: true,
inputInterface: "tun0",
protocol: "tcp",
source: netip.MustParsePrefix("1.2.3.4/32"),
destination: netip.MustParsePrefix("5.6.7.8/32"),
destinationPort: 10000,
target: "ACCEPT",
},
},
"nat_redirection": {
s: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
instruction: iptablesInstruction{
table: "nat",
chain: "PREROUTING",
append: false,
inputInterface: "tun0",
protocol: "tcp",
destinationPort: 43716,
target: "REDIRECT",
toPorts: []uint16{5678},
},
},
}
for name, testCase := range testCases {
t.Run(name, func(t *testing.T) {
t.Parallel()
rule, err := parseIptablesInstruction(testCase.s)
assert.Equal(t, testCase.instruction, rule)
assert.ErrorIs(t, err, testCase.errWrapped)
if testCase.errWrapped != nil {
assert.EqualError(t, err, testCase.errMessage)
}
})
}
}
func Test_parseIPPrefix(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
value string
prefix netip.Prefix
errMessage string
}{
"empty": {
errMessage: `parsing IP address: ParseAddr(""): unable to parse IP`,
},
"invalid": {
value: "invalid",
errMessage: `parsing IP address: ParseAddr("invalid"): unable to parse IP`,
},
"valid_ipv4_with_bits": {
value: "10.0.0.0/16",
prefix: netip.PrefixFrom(netip.AddrFrom4([4]byte{10, 0, 0, 0}), 16),
},
"valid_ipv4_without_bits": {
value: "10.0.0.4",
prefix: netip.PrefixFrom(netip.AddrFrom4([4]byte{10, 0, 0, 4}), 32),
},
"valid_ipv6_with_bits": {
value: "2001:db8::/32",
prefix: netip.PrefixFrom(
netip.AddrFrom16([16]byte{0x20, 0x01, 0x0d, 0xb8}),
32),
},
"valid_ipv6_without_bits": {
value: "2001:db8::",
prefix: netip.PrefixFrom(
netip.AddrFrom16([16]byte{0x20, 0x01, 0x0d, 0xb8}),
128),
},
}
for name, testCase := range testCases {
t.Run(name, func(t *testing.T) {
t.Parallel()
prefix, err := parseIPPrefix(testCase.value)
assert.Equal(t, testCase.prefix, prefix)
if testCase.errMessage != "" {
assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
}
})
}
}
+162
View File
@@ -0,0 +1,162 @@
package iptables
import (
"context"
"errors"
"fmt"
"math/rand"
"os/exec"
"sort"
"strings"
)
var (
ErrNetAdminMissing = errors.New("NET_ADMIN capability is missing")
ErrTestRuleCleanup = errors.New("failed cleaning up test rule")
ErrInputPolicyNotFound = errors.New("input policy not found")
ErrNotSupported = errors.New("no iptables supported found")
)
func checkIptablesSupport(ctx context.Context, runner CmdRunner,
iptablesPathsToTry ...string,
) (iptablesPath string, err error) {
iptablesPathToUnsupportedMessage := make(map[string]string, len(iptablesPathsToTry))
for _, pathToTest := range iptablesPathsToTry {
ok, unsupportedMessage, err := testIptablesPath(ctx, pathToTest, runner)
if err != nil {
return "", fmt.Errorf("for %s: %w", pathToTest, err)
} else if ok {
iptablesPath = pathToTest
break
}
iptablesPathToUnsupportedMessage[pathToTest] = unsupportedMessage
}
if iptablesPath != "" {
// some paths may be unsupported but that does not matter
// since we found one working.
return iptablesPath, nil
}
allArePermissionDenied := true
allUnsupportedMessages := make(sort.StringSlice, 0, len(iptablesPathToUnsupportedMessage))
for iptablesPath, unsupportedMessage := range iptablesPathToUnsupportedMessage {
if !isPermissionDenied(unsupportedMessage) {
allArePermissionDenied = false
}
unsupportedMessage = iptablesPath + ": " + unsupportedMessage
allUnsupportedMessages = append(allUnsupportedMessages, unsupportedMessage)
}
allUnsupportedMessages.Sort() // predictable order for tests
if allArePermissionDenied {
// If the error is related to a denied permission for all iptables path,
// return an error describing what to do from an end-user perspective.
return "", fmt.Errorf("%w: %s", ErrNetAdminMissing, strings.Join(allUnsupportedMessages, "; "))
}
return "", fmt.Errorf("%w: errors encountered are: %s",
ErrNotSupported, strings.Join(allUnsupportedMessages, "; "))
}
func testIptablesPath(ctx context.Context, path string,
runner CmdRunner) (ok bool, unsupportedMessage string,
criticalErr error,
) {
// Just listing iptables rules often work but we need
// to modify them to ensure we can support the iptables
// being tested.
// Append a test rule with a random interface name to the OUTPUT table.
// This should not affect existing rules or the network traffic.
testInterfaceName := randomInterfaceName()
cmd := exec.CommandContext(ctx, path,
"-A", "OUTPUT", "-o", testInterfaceName, "-j", "DROP")
output, err := runner.Run(cmd)
if err != nil {
unsupportedMessage = fmt.Sprintf("%s (%s)", output, err)
return false, unsupportedMessage, nil
}
// Remove the random rule added previously for test.
cmd = exec.CommandContext(ctx, path,
"-D", "OUTPUT", "-o", testInterfaceName, "-j", "DROP")
output, err = runner.Run(cmd)
if err != nil {
// this is a critical error, we want to make sure our test rule gets removed.
criticalErr = fmt.Errorf("%w: %s (%s)", ErrTestRuleCleanup, output, err)
return false, "", criticalErr
}
// Set policy as the existing policy so no mutation is done.
// This is an extra check for some buggy kernels where setting the policy
// does not work.
cmd = exec.CommandContext(ctx, path, "-nL", "INPUT")
output, err = runner.Run(cmd)
if err != nil {
unsupportedMessage = fmt.Sprintf("%s (%s)", output, err)
return false, unsupportedMessage, nil
}
var inputPolicy string
for _, line := range strings.Split(output, "\n") {
inputPolicy, ok = extractInputPolicy(line)
if ok {
break
}
}
if inputPolicy == "" {
criticalErr = fmt.Errorf("%w: in INPUT rules: %s", ErrInputPolicyNotFound, output)
return false, "", criticalErr
}
// Set the policy for the INPUT table to the existing policy found.
cmd = exec.CommandContext(ctx, path, "--policy", "INPUT", inputPolicy)
output, err = runner.Run(cmd)
if err != nil {
unsupportedMessage = fmt.Sprintf("%s (%s)", output, err)
return false, unsupportedMessage, nil
}
return true, "", nil // success
}
func isPermissionDenied(errMessage string) (ok bool) {
const permissionDeniedString = "Permission denied (you must be root)"
return strings.Contains(errMessage, permissionDeniedString)
}
func extractInputPolicy(line string) (policy string, ok bool) {
const prefixToFind = "Chain INPUT (policy "
i := strings.Index(line, prefixToFind)
if i == -1 {
return "", false
}
startIndex := i + len(prefixToFind)
endIndex := strings.Index(line, ")")
if endIndex < 0 {
return "", false
}
policy = line[startIndex:endIndex]
policy = strings.TrimSpace(policy)
if policy == "" {
return "", false
}
return policy, true
}
func randomInterfaceName() (interfaceName string) {
const size = 15
letterRunes := []rune("abcdefghijklmnopqrstuvwxyz0123456789")
b := make([]rune, size)
for i := range b {
letterIndex := rand.Intn(len(letterRunes)) //nolint:gosec
b[i] = letterRunes[letterIndex]
}
return string(b)
}
+345
View File
@@ -0,0 +1,345 @@
package iptables
import (
"context"
"errors"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func newAppendTestRuleMatcher(path string) *cmdMatcher {
return newCmdMatcher(path,
"^-A$", "^OUTPUT$", "^-o$", "^[a-z0-9]{15}$",
"^-j$", "^DROP$")
}
func newDeleteTestRuleMatcher(path string) *cmdMatcher {
return newCmdMatcher(path,
"^-D$", "^OUTPUT$", "^-o$", "^[a-z0-9]{15}$",
"^-j$", "^DROP$")
}
func newListInputRulesMatcher(path string) *cmdMatcher {
return newCmdMatcher(path,
"^-nL$", "^INPUT$")
}
func newSetPolicyMatcher(path, inputPolicy string) *cmdMatcher { //nolint:unparam
return newCmdMatcher(path,
"^--policy$", "^INPUT$", "^"+inputPolicy+"$")
}
func Test_checkIptablesSupport(t *testing.T) {
t.Parallel()
ctx := context.Background()
errDummy := errors.New("exit code 4")
const inputPolicy = "ACCEPT"
testCases := map[string]struct {
buildRunner func(ctrl *gomock.Controller) CmdRunner
iptablesPathsToTry []string
iptablesPath string
errSentinel error
errMessage string
}{
"critical error when checking": {
buildRunner: func(ctrl *gomock.Controller) CmdRunner {
runner := NewMockCmdRunner(ctrl)
runner.EXPECT().Run(newAppendTestRuleMatcher("path1")).
Return("", nil)
runner.EXPECT().Run(newDeleteTestRuleMatcher("path1")).
Return("output", errDummy)
return runner
},
iptablesPathsToTry: []string{"path1", "path2"},
errSentinel: ErrTestRuleCleanup,
errMessage: "for path1: failed cleaning up test rule: " +
"output (exit code 4)",
},
"found valid path": {
buildRunner: func(ctrl *gomock.Controller) CmdRunner {
runner := NewMockCmdRunner(ctrl)
runner.EXPECT().Run(newAppendTestRuleMatcher("path1")).
Return("", nil)
runner.EXPECT().Run(newDeleteTestRuleMatcher("path1")).
Return("", nil)
runner.EXPECT().Run(newListInputRulesMatcher("path1")).
Return("Chain INPUT (policy "+inputPolicy+")", nil)
runner.EXPECT().Run(newSetPolicyMatcher("path1", inputPolicy)).
Return("", nil)
return runner
},
iptablesPathsToTry: []string{"path1", "path2"},
iptablesPath: "path1",
},
"all permission denied": {
buildRunner: func(ctrl *gomock.Controller) CmdRunner {
runner := NewMockCmdRunner(ctrl)
runner.EXPECT().Run(newAppendTestRuleMatcher("path1")).
Return("Permission denied (you must be root) more context", errDummy)
runner.EXPECT().Run(newAppendTestRuleMatcher("path2")).
Return("context: Permission denied (you must be root)", errDummy)
return runner
},
iptablesPathsToTry: []string{"path1", "path2"},
errSentinel: ErrNetAdminMissing,
errMessage: "NET_ADMIN capability is missing: " +
"path1: Permission denied (you must be root) more context (exit code 4); " +
"path2: context: Permission denied (you must be root) (exit code 4)",
},
"no valid path": {
buildRunner: func(ctrl *gomock.Controller) CmdRunner {
runner := NewMockCmdRunner(ctrl)
runner.EXPECT().Run(newAppendTestRuleMatcher("path1")).
Return("output 1", errDummy)
runner.EXPECT().Run(newAppendTestRuleMatcher("path2")).
Return("output 2", errDummy)
return runner
},
iptablesPathsToTry: []string{"path1", "path2"},
errSentinel: ErrNotSupported,
errMessage: "no iptables supported found: " +
"errors encountered are: " +
"path1: output 1 (exit code 4); " +
"path2: output 2 (exit code 4)",
},
}
for name, testCase := range testCases {
t.Run(name, func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
runner := testCase.buildRunner(ctrl)
iptablesPath, err := checkIptablesSupport(ctx, runner, testCase.iptablesPathsToTry...)
require.ErrorIs(t, err, testCase.errSentinel)
if testCase.errSentinel != nil {
assert.EqualError(t, err, testCase.errMessage)
}
assert.Equal(t, testCase.iptablesPath, iptablesPath)
})
}
}
func Test_testIptablesPath(t *testing.T) {
t.Parallel()
ctx := context.Background()
const path = "dummypath"
errDummy := errors.New("exit code 4")
const inputPolicy = "ACCEPT"
testCases := map[string]struct {
buildRunner func(ctrl *gomock.Controller) CmdRunner
ok bool
unsupportedMessage string
criticalErrWrapped error
criticalErrMessage string
}{
"append test rule permission denied": {
buildRunner: func(ctrl *gomock.Controller) CmdRunner {
runner := NewMockCmdRunner(ctrl)
runner.EXPECT().Run(newAppendTestRuleMatcher(path)).
Return("Permission denied (you must be root)", errDummy)
return runner
},
unsupportedMessage: "Permission denied (you must be root) (exit code 4)",
},
"append test rule unsupported": {
buildRunner: func(ctrl *gomock.Controller) CmdRunner {
runner := NewMockCmdRunner(ctrl)
runner.EXPECT().Run(newAppendTestRuleMatcher(path)).
Return("some output", errDummy)
return runner
},
unsupportedMessage: "some output (exit code 4)",
},
"remove test rule error": {
buildRunner: func(ctrl *gomock.Controller) CmdRunner {
runner := NewMockCmdRunner(ctrl)
runner.EXPECT().Run(newAppendTestRuleMatcher(path)).Return("", nil)
runner.EXPECT().Run(newDeleteTestRuleMatcher(path)).
Return("some output", errDummy)
return runner
},
criticalErrWrapped: ErrTestRuleCleanup,
criticalErrMessage: "failed cleaning up test rule: some output (exit code 4)",
},
"list input rules permission denied": {
buildRunner: func(ctrl *gomock.Controller) CmdRunner {
runner := NewMockCmdRunner(ctrl)
runner.EXPECT().Run(newAppendTestRuleMatcher(path)).Return("", nil)
runner.EXPECT().Run(newDeleteTestRuleMatcher(path)).Return("", nil)
runner.EXPECT().Run(newListInputRulesMatcher(path)).
Return("Permission denied (you must be root)", errDummy)
return runner
},
unsupportedMessage: "Permission denied (you must be root) (exit code 4)",
},
"list input rules unsupported": {
buildRunner: func(ctrl *gomock.Controller) CmdRunner {
runner := NewMockCmdRunner(ctrl)
runner.EXPECT().Run(newAppendTestRuleMatcher(path)).Return("", nil)
runner.EXPECT().Run(newDeleteTestRuleMatcher(path)).Return("", nil)
runner.EXPECT().Run(newListInputRulesMatcher(path)).
Return("some output", errDummy)
return runner
},
unsupportedMessage: "some output (exit code 4)",
},
"list input rules no policy": {
buildRunner: func(ctrl *gomock.Controller) CmdRunner {
runner := NewMockCmdRunner(ctrl)
runner.EXPECT().Run(newAppendTestRuleMatcher(path)).Return("", nil)
runner.EXPECT().Run(newDeleteTestRuleMatcher(path)).Return("", nil)
runner.EXPECT().Run(newListInputRulesMatcher(path)).
Return("some\noutput", nil)
return runner
},
criticalErrWrapped: ErrInputPolicyNotFound,
criticalErrMessage: "input policy not found: in INPUT rules: some\noutput",
},
"set policy permission denied": {
buildRunner: func(ctrl *gomock.Controller) CmdRunner {
runner := NewMockCmdRunner(ctrl)
runner.EXPECT().Run(newAppendTestRuleMatcher(path)).Return("", nil)
runner.EXPECT().Run(newDeleteTestRuleMatcher(path)).Return("", nil)
runner.EXPECT().Run(newListInputRulesMatcher(path)).
Return("\nChain INPUT (policy "+inputPolicy+")\nAA\n", nil)
runner.EXPECT().Run(newSetPolicyMatcher(path, inputPolicy)).
Return("Permission denied (you must be root)", errDummy)
return runner
},
unsupportedMessage: "Permission denied (you must be root) (exit code 4)",
},
"set policy unsupported": {
buildRunner: func(ctrl *gomock.Controller) CmdRunner {
runner := NewMockCmdRunner(ctrl)
runner.EXPECT().Run(newAppendTestRuleMatcher(path)).Return("", nil)
runner.EXPECT().Run(newDeleteTestRuleMatcher(path)).Return("", nil)
runner.EXPECT().Run(newListInputRulesMatcher(path)).
Return("\nChain INPUT (policy "+inputPolicy+")\nBB\n", nil)
runner.EXPECT().Run(newSetPolicyMatcher(path, inputPolicy)).
Return("some output", errDummy)
return runner
},
unsupportedMessage: "some output (exit code 4)",
},
"success": {
buildRunner: func(ctrl *gomock.Controller) CmdRunner {
runner := NewMockCmdRunner(ctrl)
runner.EXPECT().Run(newAppendTestRuleMatcher(path)).Return("", nil)
runner.EXPECT().Run(newDeleteTestRuleMatcher(path)).Return("", nil)
runner.EXPECT().Run(newListInputRulesMatcher(path)).
Return("\nChain INPUT (policy "+inputPolicy+")\nCC\n", nil)
runner.EXPECT().Run(newSetPolicyMatcher(path, inputPolicy)).
Return("some output", nil)
return runner
},
ok: true,
},
}
for name, testCase := range testCases {
t.Run(name, func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
runner := testCase.buildRunner(ctrl)
ok, unsupportedMessage, criticalErr := testIptablesPath(ctx, path, runner)
assert.Equal(t, testCase.ok, ok)
assert.Equal(t, testCase.unsupportedMessage, unsupportedMessage)
assert.ErrorIs(t, criticalErr, testCase.criticalErrWrapped)
if testCase.criticalErrWrapped != nil {
assert.EqualError(t, criticalErr, testCase.criticalErrMessage)
}
})
}
}
func Test_isPermissionDenied(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
errMessage string
ok bool
}{
"empty error": {},
"other error": {
errMessage: "some error",
},
"permission denied": {
errMessage: "Permission denied (you must be root) have you tried blabla",
ok: true,
},
}
for name, testCase := range testCases {
t.Run(name, func(t *testing.T) {
t.Parallel()
ok := isPermissionDenied(testCase.errMessage)
assert.Equal(t, testCase.ok, ok)
})
}
}
func Test_extractInputPolicy(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
line string
policy string
ok bool
}{
"empty line": {},
"random line": {
line: "random line",
},
"only first part": {
line: "Chain INPUT (policy ",
},
"empty policy": {
line: "Chain INPUT (policy )",
},
"ACCEPT policy": {
line: "Chain INPUT (policy ACCEPT)",
policy: "ACCEPT",
ok: true,
},
"ACCEPT policy with surrounding garbage": {
line: "garbage Chain INPUT (policy ACCEPT\t) )g()arbage",
policy: "ACCEPT",
ok: true,
},
}
for name, testCase := range testCases {
t.Run(name, func(t *testing.T) {
t.Parallel()
policy, ok := extractInputPolicy(testCase.line)
assert.Equal(t, testCase.policy, policy)
assert.Equal(t, testCase.ok, ok)
})
}
}
func Test_randomInterfaceName(t *testing.T) {
t.Parallel()
const expectedRegex = `^[a-z0-9]{15}$`
interfaceName := randomInterfaceName()
assert.Regexp(t, expectedRegex, interfaceName)
}
+98
View File
@@ -0,0 +1,98 @@
package iptables
import (
"context"
"errors"
"fmt"
"net/netip"
"os"
)
type tcpFlags struct {
mask []tcpFlag
comparison []tcpFlag
}
type tcpFlag uint8
const (
tcpFlagFIN tcpFlag = 1 << iota
tcpFlagSYN
tcpFlagRST
tcpFlagPSH
tcpFlagACK
tcpFlagURG
tcpFlagECE
tcpFlagCWR
)
func (f tcpFlag) String() string {
switch f {
case tcpFlagFIN:
return "FIN"
case tcpFlagSYN:
return "SYN"
case tcpFlagRST:
return "RST"
case tcpFlagPSH:
return "PSH"
case tcpFlagACK:
return "ACK"
case tcpFlagURG:
return "URG"
case tcpFlagECE:
return "ECE"
case tcpFlagCWR:
return "CWR"
default:
panic(fmt.Sprintf("%s: %d", errTCPFlagUnknown, f))
}
}
var errTCPFlagUnknown = errors.New("unknown TCP flag")
func parseTCPFlag(s string) (tcpFlag, error) {
allFlags := []tcpFlag{
tcpFlagFIN, tcpFlagSYN, tcpFlagRST, tcpFlagPSH,
tcpFlagACK, tcpFlagURG, tcpFlagECE, tcpFlagCWR,
}
for _, flag := range allFlags {
if s == fmt.Sprintf("%#02x", uint8(flag)) || s == flag.String() {
return flag, nil
}
}
return 0, fmt.Errorf("%w: %s", errTCPFlagUnknown, s)
}
var ErrMarkMatchModuleMissing = errors.New("kernel is missing the mark module libxt_mark.so")
// TempDropOutputTCPRST temporarily drops outgoing TCP RST packets to the specified address and port,
// for any TCP packets not marked with the excludeMark given.
// This is necessary for TCP path MTU discovery to work, as the kernel will try to terminate the connection
// by sending a TCP RST packet, although we want to handle the connection manually.
func (c *Config) TempDropOutputTCPRST(ctx context.Context,
src, dst netip.AddrPort, excludeMark int) (
revert func(ctx context.Context) error, err error,
) {
_, err = os.Stat("/usr/lib/xtables/libxt_mark.so")
if err != nil && errors.Is(err, os.ErrNotExist) {
return nil, fmt.Errorf("%w", ErrMarkMatchModuleMissing)
}
const template = "%s OUTPUT -p tcp -s %s --sport %d -d %s --dport %d " +
"--tcp-flags RST RST -m mark ! --mark %d -j DROP" //nolint:dupword
instruction := fmt.Sprintf(template, "--append", src.Addr(), src.Port(), dst.Addr(), dst.Port(), excludeMark)
revertInstruction := fmt.Sprintf(template, "--delete", src.Addr(), src.Port(), dst.Addr(), dst.Port(), excludeMark)
run := c.runIptablesInstruction
if dst.Addr().Is6() {
run = c.runIP6tablesInstruction
}
revert = func(ctx context.Context) error {
return run(ctx, revertInstruction)
}
err = run(ctx, instruction)
if err != nil {
return nil, fmt.Errorf("running instruction: %w", err)
}
return revert, nil
}