mirror of
https://github.com/qdm12/gluetun.git
synced 2026-05-09 20:29:23 +02:00
chore(firewall): split apart iptables specific code in internal/firewall/iptables
This commit is contained in:
@@ -0,0 +1,61 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
var _ gomock.Matcher = (*cmdMatcher)(nil)
|
||||
|
||||
type cmdMatcher struct {
|
||||
path string
|
||||
argsRegex []string
|
||||
argsRegexp []*regexp.Regexp
|
||||
}
|
||||
|
||||
func (cm *cmdMatcher) Matches(x interface{}) bool {
|
||||
cmd, ok := x.(*exec.Cmd)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
if cmd.Path != cm.path {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(cmd.Args) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
arguments := cmd.Args[1:]
|
||||
if len(arguments) != len(cm.argsRegex) {
|
||||
return false
|
||||
}
|
||||
|
||||
for i, arg := range arguments {
|
||||
if !cm.argsRegexp[i].MatchString(arg) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (cm *cmdMatcher) String() string {
|
||||
return fmt.Sprintf("path %s, argument regular expressions %v", cm.path, cm.argsRegex)
|
||||
}
|
||||
|
||||
func newCmdMatcher(path string, argsRegex ...string) *cmdMatcher {
|
||||
argsRegexp := make([]*regexp.Regexp, len(argsRegex))
|
||||
for i, argRegex := range argsRegex {
|
||||
argsRegexp[i] = regexp.MustCompile(argRegex)
|
||||
}
|
||||
return &cmdMatcher{
|
||||
path: path,
|
||||
argsRegex: argsRegex,
|
||||
argsRegexp: argsRegexp,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,102 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// isDeleteMatchInstruction returns true if the iptables instruction
|
||||
// is a delete instruction by rule matching. It returns false if the
|
||||
// instruction is a delete instruction by line number, or not a delete
|
||||
// instruction.
|
||||
func isDeleteMatchInstruction(instruction string) bool {
|
||||
fields := strings.Fields(instruction)
|
||||
for i, field := range fields {
|
||||
switch {
|
||||
case field != "-D" && field != "--delete":
|
||||
continue
|
||||
case i == len(fields)-1: // malformed: missing chain name
|
||||
return false
|
||||
case i == len(fields)-2: // chain name is last field
|
||||
return true
|
||||
default:
|
||||
// chain name is fields[i+1]
|
||||
const base, bitLength = 10, 16
|
||||
_, err := strconv.ParseUint(fields[i+2], base, bitLength)
|
||||
return err != nil // not a line number
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func deleteIPTablesRule(ctx context.Context, iptablesBinary, instruction string,
|
||||
runner CmdRunner, logger Logger,
|
||||
) (err error) {
|
||||
targetRule, err := parseIptablesInstruction(instruction)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing iptables command: %w", err)
|
||||
}
|
||||
|
||||
lineNumber, err := findLineNumber(ctx, iptablesBinary,
|
||||
targetRule, runner, logger)
|
||||
if err != nil {
|
||||
return fmt.Errorf("finding iptables chain rule line number: %w", err)
|
||||
} else if lineNumber == 0 {
|
||||
logger.Debug("rule matching \"" + instruction + "\" not found")
|
||||
return nil
|
||||
}
|
||||
logger.Debug(fmt.Sprintf("found iptables chain rule matching %q at line number %d",
|
||||
instruction, lineNumber))
|
||||
|
||||
cmd := exec.CommandContext(ctx, iptablesBinary, "-t", targetRule.table,
|
||||
"-D", targetRule.chain, fmt.Sprint(lineNumber)) // #nosec G204
|
||||
logger.Debug(cmd.String())
|
||||
output, err := runner.Run(cmd)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("command failed: %q: %w", cmd, err)
|
||||
if output != "" {
|
||||
err = fmt.Errorf("%w: %s", err, output)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// findLineNumber finds the line number of an iptables rule.
|
||||
// It returns 0 if the rule is not found.
|
||||
func findLineNumber(ctx context.Context, iptablesBinary string,
|
||||
instruction iptablesInstruction, runner CmdRunner, logger Logger) (
|
||||
lineNumber uint16, err error,
|
||||
) {
|
||||
listFlags := []string{
|
||||
"-t", instruction.table, "-L", instruction.chain,
|
||||
"--line-numbers", "-n", "-v",
|
||||
}
|
||||
cmd := exec.CommandContext(ctx, iptablesBinary, listFlags...) // #nosec G204
|
||||
logger.Debug(cmd.String())
|
||||
output, err := runner.Run(cmd)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("command failed: %q: %w", cmd, err)
|
||||
if output != "" {
|
||||
err = fmt.Errorf("%w: %s", err, output)
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
|
||||
chain, err := parseChain(output)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing chain list: %w", err)
|
||||
}
|
||||
|
||||
for _, rule := range chain.rules {
|
||||
if instruction.equalToRule(instruction.table, chain.name, rule) {
|
||||
return rule.lineNumber, nil
|
||||
}
|
||||
}
|
||||
|
||||
return 0, nil
|
||||
}
|
||||
@@ -0,0 +1,186 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_isDeleteMatchInstruction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
instruction string
|
||||
isDeleteMatch bool
|
||||
}{
|
||||
"not_delete": {
|
||||
instruction: "-t nat -A PREROUTING -i tun0 -j ACCEPT",
|
||||
},
|
||||
"malformed_missing_chain_name": {
|
||||
instruction: "-t nat -D",
|
||||
},
|
||||
"delete_chain_name_last_field": {
|
||||
instruction: "-t nat --delete PREROUTING",
|
||||
isDeleteMatch: true,
|
||||
},
|
||||
"delete_match": {
|
||||
instruction: "-t nat --delete PREROUTING -i tun0 -j ACCEPT",
|
||||
isDeleteMatch: true,
|
||||
},
|
||||
"delete_line_number_last_field": {
|
||||
instruction: "-t nat -D PREROUTING 2",
|
||||
},
|
||||
"delete_line_number": {
|
||||
instruction: "-t nat -D PREROUTING 2 -i tun0 -j ACCEPT",
|
||||
},
|
||||
}
|
||||
for name, testCase := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
isDeleteMatch := isDeleteMatchInstruction(testCase.instruction)
|
||||
|
||||
assert.Equal(t, testCase.isDeleteMatch, isDeleteMatch)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newCmdMatcherListRules(iptablesBinary, table, chain string) *cmdMatcher { //nolint:unparam
|
||||
return newCmdMatcher(iptablesBinary, "^-t$", "^"+table+"$", "^-L$", "^"+chain+"$",
|
||||
"^--line-numbers$", "^-n$", "^-v$")
|
||||
}
|
||||
|
||||
func Test_deleteIPTablesRule(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const iptablesBinary = "/sbin/iptables"
|
||||
errTest := errors.New("test error")
|
||||
|
||||
testCases := map[string]struct {
|
||||
instruction string
|
||||
makeRunner func(ctrl *gomock.Controller) *MockCmdRunner
|
||||
makeLogger func(ctrl *gomock.Controller) *MockLogger
|
||||
errWrapped error
|
||||
errMessage string
|
||||
}{
|
||||
"invalid_instruction": {
|
||||
instruction: "invalid",
|
||||
errWrapped: ErrIptablesCommandMalformed,
|
||||
errMessage: "parsing iptables command: parsing \"invalid\": " +
|
||||
"iptables command is malformed: flag \"invalid\" requires a value, but got none",
|
||||
},
|
||||
"list_error": {
|
||||
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
|
||||
makeRunner: func(ctrl *gomock.Controller) *MockCmdRunner {
|
||||
runner := NewMockCmdRunner(ctrl)
|
||||
runner.EXPECT().
|
||||
Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")).
|
||||
Return("", errTest)
|
||||
return runner
|
||||
},
|
||||
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
|
||||
logger := NewMockLogger(ctrl)
|
||||
logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v")
|
||||
return logger
|
||||
},
|
||||
errWrapped: errTest,
|
||||
errMessage: `finding iptables chain rule line number: command failed: ` +
|
||||
`"/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v": test error`,
|
||||
},
|
||||
"rule_not_found": {
|
||||
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
|
||||
makeRunner: func(ctrl *gomock.Controller) *MockCmdRunner {
|
||||
runner := NewMockCmdRunner(ctrl)
|
||||
runner.EXPECT().Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")).
|
||||
Return(`Chain PREROUTING (policy ACCEPT 0 packets, 0 bytes)
|
||||
num pkts bytes target prot opt in out source destination
|
||||
1 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:5000 redir ports 9999`, //nolint:lll
|
||||
nil)
|
||||
return runner
|
||||
},
|
||||
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
|
||||
logger := NewMockLogger(ctrl)
|
||||
logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v")
|
||||
logger.EXPECT().Debug("rule matching \"-t nat --delete PREROUTING " +
|
||||
"-i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678\" not found")
|
||||
return logger
|
||||
},
|
||||
},
|
||||
"rule_found_delete_error": {
|
||||
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
|
||||
makeRunner: func(ctrl *gomock.Controller) *MockCmdRunner {
|
||||
runner := NewMockCmdRunner(ctrl)
|
||||
runner.EXPECT().Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")).
|
||||
Return("Chain PREROUTING (policy ACCEPT 0 packets, 0 bytes)\n"+
|
||||
"num pkts bytes target prot opt in out source destination \n"+
|
||||
"1 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:5000 redir ports 9999\n"+ //nolint:lll
|
||||
"2 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:43716 redir ports 5678\n", //nolint:lll
|
||||
nil)
|
||||
runner.EXPECT().Run(newCmdMatcher(iptablesBinary, "^-t$", "^nat$",
|
||||
"^-D$", "^PREROUTING$", "^2$")).Return("details", errTest)
|
||||
return runner
|
||||
},
|
||||
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
|
||||
logger := NewMockLogger(ctrl)
|
||||
logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v")
|
||||
logger.EXPECT().Debug("found iptables chain rule matching \"-t nat --delete PREROUTING " +
|
||||
"-i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678\" at line number 2")
|
||||
logger.EXPECT().Debug("/sbin/iptables -t nat -D PREROUTING 2")
|
||||
return logger
|
||||
},
|
||||
errWrapped: errTest,
|
||||
errMessage: "command failed: \"/sbin/iptables -t nat -D PREROUTING 2\": test error: details",
|
||||
},
|
||||
"rule_found_delete_success": {
|
||||
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
|
||||
makeRunner: func(ctrl *gomock.Controller) *MockCmdRunner {
|
||||
runner := NewMockCmdRunner(ctrl)
|
||||
runner.EXPECT().Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")).
|
||||
Return("Chain PREROUTING (policy ACCEPT 0 packets, 0 bytes)\n"+
|
||||
"num pkts bytes target prot opt in out source destination \n"+
|
||||
"1 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:5000 redir ports 9999\n"+ //nolint:lll
|
||||
"2 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:43716 redir ports 5678\n", //nolint:lll
|
||||
nil)
|
||||
runner.EXPECT().Run(newCmdMatcher(iptablesBinary, "^-t$", "^nat$",
|
||||
"^-D$", "^PREROUTING$", "^2$")).Return("", nil)
|
||||
return runner
|
||||
},
|
||||
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
|
||||
logger := NewMockLogger(ctrl)
|
||||
logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v")
|
||||
logger.EXPECT().Debug("found iptables chain rule matching \"-t nat --delete PREROUTING " +
|
||||
"-i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678\" at line number 2")
|
||||
logger.EXPECT().Debug("/sbin/iptables -t nat -D PREROUTING 2")
|
||||
return logger
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
ctx := context.Background()
|
||||
instruction := testCase.instruction
|
||||
var runner *MockCmdRunner
|
||||
if testCase.makeRunner != nil {
|
||||
runner = testCase.makeRunner(ctrl)
|
||||
}
|
||||
var logger *MockLogger
|
||||
if testCase.makeLogger != nil {
|
||||
logger = testCase.makeLogger(ctrl)
|
||||
}
|
||||
|
||||
err := deleteIPTablesRule(ctx, iptablesBinary, instruction, runner, logger)
|
||||
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if testCase.errWrapped != nil {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,36 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
runner CmdRunner
|
||||
logger Logger
|
||||
iptablesMutex sync.Mutex
|
||||
ip6tablesMutex sync.Mutex
|
||||
|
||||
// Fixed state
|
||||
ipTables string
|
||||
ip6Tables string
|
||||
}
|
||||
|
||||
func New(ctx context.Context, runner CmdRunner, logger Logger) (*Config, error) {
|
||||
iptables, err := checkIptablesSupport(ctx, runner, "iptables", "iptables-nft", "iptables-legacy")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ip6tables, err := findIP6tablesSupported(ctx, runner)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Config{
|
||||
runner: runner,
|
||||
logger: logger,
|
||||
ipTables: iptables,
|
||||
ip6Tables: ip6tables,
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
package iptables
|
||||
|
||||
import "os/exec"
|
||||
|
||||
type CmdRunner interface {
|
||||
Run(cmd *exec.Cmd) (output string, err error)
|
||||
}
|
||||
|
||||
type Logger interface {
|
||||
Debug(s string)
|
||||
Info(s string)
|
||||
Warn(s string)
|
||||
Error(s string)
|
||||
}
|
||||
@@ -0,0 +1,70 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// findIP6tablesSupported checks for multiple iptables implementations
|
||||
// and returns the iptables path that is supported. If none work, an
|
||||
// empty string path is returned.
|
||||
func findIP6tablesSupported(ctx context.Context, runner CmdRunner) (
|
||||
ip6tablesPath string, err error,
|
||||
) {
|
||||
ip6tablesPath, err = checkIptablesSupport(ctx, runner, "ip6tables", "ip6tables-legacy")
|
||||
if errors.Is(err, ErrNotSupported) {
|
||||
return "", nil
|
||||
} else if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return ip6tablesPath, nil
|
||||
}
|
||||
|
||||
func (c *Config) runIP6tablesInstructions(ctx context.Context, instructions []string) error {
|
||||
for _, instruction := range instructions {
|
||||
if err := c.runIP6tablesInstruction(ctx, instruction); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) runIP6tablesInstruction(ctx context.Context, instruction string) error {
|
||||
if c.ip6Tables == "" {
|
||||
return nil
|
||||
}
|
||||
c.ip6tablesMutex.Lock() // only one ip6tables command at once
|
||||
defer c.ip6tablesMutex.Unlock()
|
||||
|
||||
if isDeleteMatchInstruction(instruction) {
|
||||
return deleteIPTablesRule(ctx, c.ip6Tables, instruction,
|
||||
c.runner, c.logger)
|
||||
}
|
||||
|
||||
flags := strings.Fields(instruction)
|
||||
cmd := exec.CommandContext(ctx, c.ip6Tables, flags...) // #nosec G204
|
||||
c.logger.Debug(cmd.String())
|
||||
if output, err := c.runner.Run(cmd); err != nil {
|
||||
return fmt.Errorf("command failed: \"%s %s\": %s: %w",
|
||||
c.ip6Tables, instruction, output, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var ErrPolicyNotValid = errors.New("policy is not valid")
|
||||
|
||||
func (c *Config) SetIPv6AllPolicies(ctx context.Context, policy string) error {
|
||||
switch policy {
|
||||
case "ACCEPT", "DROP":
|
||||
default:
|
||||
return fmt.Errorf("%w: %s", ErrPolicyNotValid, policy)
|
||||
}
|
||||
return c.runIP6tablesInstructions(ctx, []string{
|
||||
"--policy INPUT " + policy,
|
||||
"--policy OUTPUT " + policy,
|
||||
"--policy FORWARD " + policy,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,347 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrIPTablesVersionTooShort = errors.New("iptables version string is too short")
|
||||
ErrPolicyUnknown = errors.New("unknown policy")
|
||||
ErrNeedIP6Tables = errors.New("ip6tables is required, please upgrade your kernel to support it")
|
||||
)
|
||||
|
||||
func appendOrDelete(remove bool) string {
|
||||
if remove {
|
||||
return "--delete"
|
||||
}
|
||||
return "--append"
|
||||
}
|
||||
|
||||
// flipRule changes an append rule in a delete rule or a delete rule into an
|
||||
// append rule.
|
||||
func flipRule(rule string) string {
|
||||
fields := strings.Fields(rule)
|
||||
for i, field := range fields {
|
||||
switch field {
|
||||
case "-A", "--append":
|
||||
fields[i] = "--delete"
|
||||
case "-D", "--delete":
|
||||
fields[i] = "--append"
|
||||
}
|
||||
}
|
||||
return strings.Join(fields, " ")
|
||||
}
|
||||
|
||||
// Version obtains the version of the installed iptables.
|
||||
func (c *Config) Version(ctx context.Context) (string, error) {
|
||||
cmd := exec.CommandContext(ctx, c.ipTables, "--version") //nolint:gosec
|
||||
output, err := c.runner.Run(cmd)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
words := strings.Fields(output)
|
||||
const minWords = 2
|
||||
if len(words) < minWords {
|
||||
return "", fmt.Errorf("%w: %s", ErrIPTablesVersionTooShort, output)
|
||||
}
|
||||
return "iptables " + words[1], nil
|
||||
}
|
||||
|
||||
func (c *Config) runIptablesInstructions(ctx context.Context, instructions []string) error {
|
||||
for _, instruction := range instructions {
|
||||
if err := c.runIptablesInstruction(ctx, instruction); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) runIptablesInstruction(ctx context.Context, instruction string) error {
|
||||
c.iptablesMutex.Lock() // only one iptables command at once
|
||||
defer c.iptablesMutex.Unlock()
|
||||
|
||||
if isDeleteMatchInstruction(instruction) {
|
||||
return deleteIPTablesRule(ctx, c.ipTables, instruction,
|
||||
c.runner, c.logger)
|
||||
}
|
||||
|
||||
flags := strings.Fields(instruction)
|
||||
cmd := exec.CommandContext(ctx, c.ipTables, flags...) // #nosec G204
|
||||
c.logger.Debug(cmd.String())
|
||||
if output, err := c.runner.Run(cmd); err != nil {
|
||||
return fmt.Errorf("command failed: \"%s %s\": %s: %w",
|
||||
c.ipTables, instruction, output, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) ClearAllRules(ctx context.Context) error {
|
||||
tables := []string{"filter"}
|
||||
for _, table := range tables {
|
||||
return c.runMixedIptablesInstructions(ctx, []string{
|
||||
"-t " + table + " --flush", // flush all chains
|
||||
"-t " + table + " --delete-chain", // delete all chains
|
||||
})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) SetIPv4AllPolicies(ctx context.Context, policy string) error {
|
||||
switch policy {
|
||||
case "ACCEPT", "DROP":
|
||||
default:
|
||||
return fmt.Errorf("%w: %s", ErrPolicyUnknown, policy)
|
||||
}
|
||||
return c.runIptablesInstructions(ctx, []string{
|
||||
"--policy INPUT " + policy,
|
||||
"--policy OUTPUT " + policy,
|
||||
"--policy FORWARD " + policy,
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Config) AcceptInputThroughInterface(ctx context.Context, intf string, remove bool) error {
|
||||
return c.runMixedIptablesInstruction(ctx, fmt.Sprintf(
|
||||
"%s INPUT -i %s -j ACCEPT", appendOrDelete(remove), intf,
|
||||
))
|
||||
}
|
||||
|
||||
func (c *Config) AcceptInputToSubnet(ctx context.Context, intf string,
|
||||
destination netip.Prefix, remove bool,
|
||||
) error {
|
||||
interfaceFlag := "-i " + intf
|
||||
if intf == "*" { // all interfaces
|
||||
interfaceFlag = ""
|
||||
}
|
||||
|
||||
instruction := fmt.Sprintf("%s INPUT %s -d %s -j ACCEPT",
|
||||
appendOrDelete(remove), interfaceFlag, destination.String())
|
||||
|
||||
if destination.Addr().Is4() {
|
||||
return c.runIptablesInstruction(ctx, instruction)
|
||||
}
|
||||
if c.ip6Tables == "" {
|
||||
return fmt.Errorf("accept input to subnet %s: %w", destination, ErrNeedIP6Tables)
|
||||
}
|
||||
return c.runIP6tablesInstruction(ctx, instruction)
|
||||
}
|
||||
|
||||
func (c *Config) AcceptOutputThroughInterface(ctx context.Context, intf string, remove bool) error {
|
||||
return c.runMixedIptablesInstruction(ctx, fmt.Sprintf(
|
||||
"%s OUTPUT -o %s -j ACCEPT", appendOrDelete(remove), intf,
|
||||
))
|
||||
}
|
||||
|
||||
func (c *Config) AcceptEstablishedRelatedTraffic(ctx context.Context, remove bool) error {
|
||||
return c.runMixedIptablesInstructions(ctx, []string{
|
||||
fmt.Sprintf("%s OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", appendOrDelete(remove)),
|
||||
fmt.Sprintf("%s INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", appendOrDelete(remove)),
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Config) AcceptOutputTrafficToVPN(ctx context.Context,
|
||||
defaultInterface string, connection models.Connection, remove bool,
|
||||
) error {
|
||||
protocol := connection.Protocol
|
||||
if protocol == "tcp-client" {
|
||||
protocol = "tcp"
|
||||
}
|
||||
instruction := fmt.Sprintf("%s OUTPUT -d %s -o %s -p %s -m %s --dport %d -j ACCEPT",
|
||||
appendOrDelete(remove), connection.IP, defaultInterface, protocol,
|
||||
protocol, connection.Port)
|
||||
if connection.IP.Is4() {
|
||||
return c.runIptablesInstruction(ctx, instruction)
|
||||
} else if c.ip6Tables == "" {
|
||||
return fmt.Errorf("accept output to VPN server: %w", ErrNeedIP6Tables)
|
||||
}
|
||||
return c.runIP6tablesInstruction(ctx, instruction)
|
||||
}
|
||||
|
||||
// AcceptOutputFromIPToSubnet accepts outgoing traffic from sourceIP to destinationSubnet
|
||||
// on the interface intf. If intf is empty, it is set to "*" which means all interfaces.
|
||||
// If remove is true, the rule is removed instead of added.
|
||||
// Thanks to @npawelek.
|
||||
func (c *Config) AcceptOutputFromIPToSubnet(ctx context.Context,
|
||||
intf string, sourceIP netip.Addr, destinationSubnet netip.Prefix, remove bool,
|
||||
) error {
|
||||
doIPv4 := sourceIP.Is4() && destinationSubnet.Addr().Is4()
|
||||
|
||||
interfaceFlag := "-o " + intf
|
||||
if intf == "*" { // all interfaces
|
||||
interfaceFlag = ""
|
||||
}
|
||||
|
||||
instruction := fmt.Sprintf("%s OUTPUT %s -s %s -d %s -j ACCEPT",
|
||||
appendOrDelete(remove), interfaceFlag, sourceIP.String(), destinationSubnet.String())
|
||||
|
||||
if doIPv4 {
|
||||
return c.runIptablesInstruction(ctx, instruction)
|
||||
} else if c.ip6Tables == "" {
|
||||
return fmt.Errorf("accept output from %s to %s: %w", sourceIP, destinationSubnet, ErrNeedIP6Tables)
|
||||
}
|
||||
return c.runIP6tablesInstruction(ctx, instruction)
|
||||
}
|
||||
|
||||
// AcceptIpv6MulticastOutput accepts outgoing traffic to the IPv6 multicast address
|
||||
// ff02::1:ff00:0/104, which is used for NDP (Neighbor Discovery Protocol) to resolve
|
||||
// IPv6 addresses to MAC addresses. If intf is empty, it is set to "*" which means
|
||||
// all interfaces. If remove is true, the rule is removed instead of added.
|
||||
func (c *Config) AcceptIpv6MulticastOutput(ctx context.Context,
|
||||
intf string, remove bool,
|
||||
) error {
|
||||
interfaceFlag := "-o " + intf
|
||||
if intf == "*" { // all interfaces
|
||||
interfaceFlag = ""
|
||||
}
|
||||
instruction := fmt.Sprintf("%s OUTPUT %s -d ff02::1:ff00:0/104 -j ACCEPT",
|
||||
appendOrDelete(remove), interfaceFlag)
|
||||
return c.runIP6tablesInstruction(ctx, instruction)
|
||||
}
|
||||
|
||||
// AcceptInputToPort accepts incoming traffic on the specified port, for both TCP and UDP
|
||||
// protocols, on the interface intf. If intf is empty, it is set to "*" which means all interfaces.
|
||||
// If remove is true, the rule is removed instead of added. This is used for port forwarding, with
|
||||
// intf set to the VPN tunnel interface.
|
||||
func (c *Config) AcceptInputToPort(ctx context.Context, intf string, port uint16, remove bool) error {
|
||||
interfaceFlag := "-i " + intf
|
||||
if intf == "*" { // all interfaces
|
||||
interfaceFlag = ""
|
||||
}
|
||||
return c.runMixedIptablesInstructions(ctx, []string{
|
||||
fmt.Sprintf("%s INPUT %s -p tcp -m tcp --dport %d -j ACCEPT", appendOrDelete(remove), interfaceFlag, port),
|
||||
fmt.Sprintf("%s INPUT %s -p udp -m udp --dport %d -j ACCEPT", appendOrDelete(remove), interfaceFlag, port),
|
||||
})
|
||||
}
|
||||
|
||||
// RedirectPort redirects incoming traffic on the specified source port to the
|
||||
// specified destination port, for both TCP and UDP protocols, on the interface intf.
|
||||
// If intf is empty, it is set to "*" which means all interfaces. If remove is true,
|
||||
// the redirection is removed instead of added. This is used for VPN server side
|
||||
// port forwarding, with intf set to the VPN tunnel interface.
|
||||
func (c *Config) RedirectPort(ctx context.Context, intf string,
|
||||
sourcePort, destinationPort uint16, remove bool,
|
||||
) (err error) {
|
||||
interfaceFlag := "-i " + intf
|
||||
if intf == "*" { // all interfaces
|
||||
interfaceFlag = ""
|
||||
}
|
||||
|
||||
err = c.runIptablesInstructions(ctx, []string{
|
||||
fmt.Sprintf("-t nat %s PREROUTING %s -p tcp --dport %d -j REDIRECT --to-ports %d",
|
||||
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
|
||||
fmt.Sprintf("%s INPUT %s -p tcp -m tcp --dport %d -j ACCEPT",
|
||||
appendOrDelete(remove), interfaceFlag, destinationPort),
|
||||
fmt.Sprintf("-t nat %s PREROUTING %s -p udp --dport %d -j REDIRECT --to-ports %d",
|
||||
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
|
||||
fmt.Sprintf("%s INPUT %s -p udp -m udp --dport %d -j ACCEPT",
|
||||
appendOrDelete(remove), interfaceFlag, destinationPort),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("redirecting IPv4 source port %d to destination port %d on interface %s: %w",
|
||||
sourcePort, destinationPort, intf, err)
|
||||
}
|
||||
|
||||
err = c.runIP6tablesInstructions(ctx, []string{
|
||||
fmt.Sprintf("-t nat %s PREROUTING %s -p tcp --dport %d -j REDIRECT --to-ports %d",
|
||||
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
|
||||
fmt.Sprintf("%s INPUT %s -p tcp -m tcp --dport %d -j ACCEPT",
|
||||
appendOrDelete(remove), interfaceFlag, destinationPort),
|
||||
fmt.Sprintf("-t nat %s PREROUTING %s -p udp --dport %d -j REDIRECT --to-ports %d",
|
||||
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
|
||||
fmt.Sprintf("%s INPUT %s -p udp -m udp --dport %d -j ACCEPT",
|
||||
appendOrDelete(remove), interfaceFlag, destinationPort),
|
||||
})
|
||||
if err != nil {
|
||||
errMessage := err.Error()
|
||||
if strings.Contains(errMessage, "can't initialize ip6tables table `nat': Table does not exist") {
|
||||
if !remove {
|
||||
c.logger.Warn("IPv6 port redirection disabled because your kernel does not support IPv6 NAT: " + errMessage)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("redirecting IPv6 source port %d to destination port %d on interface %s: %w",
|
||||
sourcePort, destinationPort, intf, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) RunUserPostRules(ctx context.Context, filepath string, remove bool) error {
|
||||
file, err := os.OpenFile(filepath, os.O_RDONLY, 0)
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
b, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
_ = file.Close()
|
||||
return err
|
||||
}
|
||||
if err := file.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
lines := strings.Split(string(b), "\n")
|
||||
successfulRules := []string{}
|
||||
defer func() {
|
||||
// transaction-like rollback
|
||||
if err == nil || ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
for _, rule := range successfulRules {
|
||||
_ = c.runIptablesInstruction(ctx, flipRule(rule))
|
||||
}
|
||||
}()
|
||||
for _, line := range lines {
|
||||
var ipv4 bool
|
||||
var rule string
|
||||
switch {
|
||||
case strings.HasPrefix(line, "iptables "):
|
||||
ipv4 = true
|
||||
rule = strings.TrimPrefix(line, "iptables ")
|
||||
case strings.HasPrefix(line, "iptables-nft "):
|
||||
ipv4 = true
|
||||
rule = strings.TrimPrefix(line, "iptables-nft ")
|
||||
case strings.HasPrefix(line, "iptables-legacy "):
|
||||
ipv4 = true
|
||||
rule = strings.TrimPrefix(line, "iptables-legacy ")
|
||||
case strings.HasPrefix(line, "ip6tables "):
|
||||
ipv4 = false
|
||||
rule = strings.TrimPrefix(line, "ip6tables ")
|
||||
case strings.HasPrefix(line, "ip6tables-nft "):
|
||||
ipv4 = false
|
||||
rule = strings.TrimPrefix(line, "ip6tables-nft ")
|
||||
case strings.HasPrefix(line, "ip6tables-legacy "):
|
||||
ipv4 = false
|
||||
rule = strings.TrimPrefix(line, "ip6tables-legacy ")
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
if remove {
|
||||
rule = flipRule(rule)
|
||||
}
|
||||
|
||||
switch {
|
||||
case ipv4:
|
||||
err = c.runIptablesInstruction(ctx, rule)
|
||||
case c.ip6Tables == "":
|
||||
err = fmt.Errorf("running user ip6tables rule: %w", ErrNeedIP6Tables)
|
||||
default: // ipv6
|
||||
err = c.runIP6tablesInstruction(ctx, rule)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
successfulRules = append(successfulRules, rule)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
func (c *Config) runMixedIptablesInstructions(ctx context.Context, instructions []string) error {
|
||||
for _, instruction := range instructions {
|
||||
if err := c.runMixedIptablesInstruction(ctx, instruction); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) runMixedIptablesInstruction(ctx context.Context, instruction string) error {
|
||||
if err := c.runIptablesInstruction(ctx, instruction); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.runIP6tablesInstruction(ctx, instruction)
|
||||
}
|
||||
@@ -0,0 +1,523 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type chain struct {
|
||||
name string
|
||||
policy string
|
||||
packets uint64
|
||||
bytes uint64
|
||||
rules []chainRule
|
||||
}
|
||||
|
||||
type chainRule struct {
|
||||
lineNumber uint16 // starts from 1 and cannot be zero.
|
||||
packets uint64
|
||||
bytes uint64
|
||||
target string // "ACCEPT", "DROP", "REJECT" or "REDIRECT"
|
||||
protocol string // "icmp", "tcp", "udp" or "" for all protocols.
|
||||
inputInterface string // input interface, for example "tun0" or "*""
|
||||
outputInterface string // output interface, for example "eth0" or "*""
|
||||
source netip.Prefix // source IP CIDR, for example 0.0.0.0/0. Must be valid.
|
||||
sourcePort uint16 // Not specified if set to zero.
|
||||
destination netip.Prefix // destination IP CIDR, for example 0.0.0.0/0. Must be valid.
|
||||
destinationPort uint16 // Not specified if set to zero.
|
||||
redirPorts []uint16 // Not specified if empty.
|
||||
ctstate []string // for example ["RELATED","ESTABLISHED"]. Can be empty.
|
||||
tcpFlags tcpFlags
|
||||
mark mark
|
||||
}
|
||||
|
||||
type mark struct {
|
||||
invert bool
|
||||
value uint
|
||||
}
|
||||
|
||||
var ErrChainListMalformed = errors.New("iptables chain list output is malformed")
|
||||
|
||||
func parseChain(iptablesOutput string) (c chain, err error) {
|
||||
// Text example:
|
||||
// Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
|
||||
// pkts bytes target prot opt in out source destination
|
||||
// 0 0 ACCEPT 17 -- tun0 * 0.0.0.0/0 0.0.0.0/0 udp dpt:55405
|
||||
// 0 0 ACCEPT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:55405
|
||||
// 0 0 DROP 0 -- tun0 * 0.0.0.0/0 0.0.0.0/0
|
||||
iptablesOutput = strings.TrimSpace(iptablesOutput)
|
||||
linesWithComments := strings.Split(iptablesOutput, "\n")
|
||||
|
||||
// Filter out lines starting with a '#' character
|
||||
lines := make([]string, 0, len(linesWithComments))
|
||||
for _, line := range linesWithComments {
|
||||
if strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
lines = append(lines, line)
|
||||
}
|
||||
|
||||
const minLines = 2 // chain general information line + legend line
|
||||
if len(lines) < minLines {
|
||||
return chain{}, fmt.Errorf("%w: not enough lines to process in: %s",
|
||||
ErrChainListMalformed, iptablesOutput)
|
||||
}
|
||||
|
||||
c, err = parseChainGeneralDataLine(lines[0])
|
||||
if err != nil {
|
||||
return chain{}, fmt.Errorf("parsing chain general data line: %w", err)
|
||||
}
|
||||
|
||||
// Sanity check for the legend line
|
||||
expectedLegendFields := []string{"num", "pkts", "bytes", "target", "prot", "opt", "in", "out", "source", "destination"}
|
||||
legendLine := strings.TrimSpace(lines[1])
|
||||
legendFields := strings.Fields(legendLine)
|
||||
if !slices.Equal(expectedLegendFields, legendFields) {
|
||||
return chain{}, fmt.Errorf("%w: legend %q is not the expected %q",
|
||||
ErrChainListMalformed, legendLine, strings.Join(expectedLegendFields, " "))
|
||||
}
|
||||
|
||||
lines = lines[2:] // remove chain general information line and legend line
|
||||
if len(lines) == 0 {
|
||||
return c, nil
|
||||
}
|
||||
|
||||
c.rules = make([]chainRule, len(lines))
|
||||
for i, line := range lines {
|
||||
c.rules[i], err = parseChainRuleLine(line)
|
||||
if err != nil {
|
||||
return chain{}, fmt.Errorf("parsing chain rule %q: %w", line, err)
|
||||
}
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// parseChainGeneralDataLine parses the first line of iptables chain list output.
|
||||
// For example, it can parse the following line:
|
||||
// Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
|
||||
// It returns a chain struct with the parsed data.
|
||||
func parseChainGeneralDataLine(line string) (base chain, err error) {
|
||||
line = strings.TrimSpace(line)
|
||||
runesToRemove := []rune{'(', ')', ','}
|
||||
for _, r := range runesToRemove {
|
||||
line = strings.ReplaceAll(line, string(r), "")
|
||||
}
|
||||
|
||||
fields := strings.Fields(line)
|
||||
const expectedNumberOfFields = 8
|
||||
if len(fields) != expectedNumberOfFields {
|
||||
return chain{}, fmt.Errorf("%w: expected %d fields in %q",
|
||||
ErrChainListMalformed, expectedNumberOfFields, line)
|
||||
}
|
||||
|
||||
// Sanity checks
|
||||
indexToExpectedValue := map[int]string{
|
||||
0: "Chain",
|
||||
2: "policy",
|
||||
5: "packets",
|
||||
7: "bytes",
|
||||
}
|
||||
for index, expectedValue := range indexToExpectedValue {
|
||||
if fields[index] == expectedValue {
|
||||
continue
|
||||
}
|
||||
return chain{}, fmt.Errorf("%w: expected %q for field %d in %q",
|
||||
ErrChainListMalformed, expectedValue, index, line)
|
||||
}
|
||||
|
||||
base.name = fields[1] // chain name could be custom
|
||||
base.policy = fields[3]
|
||||
err = checkTarget(base.policy)
|
||||
if err != nil {
|
||||
return chain{}, fmt.Errorf("policy target in %q: %w", line, err)
|
||||
}
|
||||
|
||||
packets, err := parseMetricSize(fields[4])
|
||||
if err != nil {
|
||||
return chain{}, fmt.Errorf("parsing packets: %w", err)
|
||||
}
|
||||
base.packets = packets
|
||||
|
||||
bytes, err := parseMetricSize(fields[6])
|
||||
if err != nil {
|
||||
return chain{}, fmt.Errorf("parsing bytes: %w", err)
|
||||
}
|
||||
base.bytes = bytes
|
||||
|
||||
return base, nil
|
||||
}
|
||||
|
||||
var ErrChainRuleMalformed = errors.New("chain rule is malformed")
|
||||
|
||||
func parseChainRuleLine(line string) (rule chainRule, err error) {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
return chainRule{}, fmt.Errorf("%w: empty line", ErrChainRuleMalformed)
|
||||
}
|
||||
|
||||
fields := strings.Fields(line)
|
||||
|
||||
const minFields = 10
|
||||
if len(fields) < minFields {
|
||||
return chainRule{}, fmt.Errorf("%w: not enough fields", ErrChainRuleMalformed)
|
||||
}
|
||||
|
||||
for fieldIndex, field := range fields[:minFields] {
|
||||
err = parseChainRuleField(fieldIndex, field, &rule)
|
||||
if err != nil {
|
||||
return chainRule{}, fmt.Errorf("parsing chain rule field: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(fields) > minFields {
|
||||
err = parseChainRuleOptionalFields(fields[minFields:], &rule)
|
||||
if err != nil {
|
||||
return chainRule{}, fmt.Errorf("parsing optional fields: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func parseChainRuleField(fieldIndex int, field string, rule *chainRule) (err error) {
|
||||
if field == "" {
|
||||
return fmt.Errorf("%w: empty field at index %d", ErrChainRuleMalformed, fieldIndex)
|
||||
}
|
||||
|
||||
const (
|
||||
numIndex = iota
|
||||
packetsIndex
|
||||
bytesIndex
|
||||
targetIndex
|
||||
protocolIndex
|
||||
optIndex
|
||||
inputInterfaceIndex
|
||||
outputInterfaceIndex
|
||||
sourceIndex
|
||||
destinationIndex
|
||||
)
|
||||
|
||||
switch fieldIndex {
|
||||
case numIndex:
|
||||
rule.lineNumber, err = parseLineNumber(field)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing line number: %w", err)
|
||||
}
|
||||
case packetsIndex:
|
||||
rule.packets, err = parseMetricSize(field)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing packets: %w", err)
|
||||
}
|
||||
case bytesIndex:
|
||||
rule.bytes, err = parseMetricSize(field)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing bytes: %w", err)
|
||||
}
|
||||
case targetIndex:
|
||||
err = checkTarget(field)
|
||||
if err != nil {
|
||||
return fmt.Errorf("checking target: %w", err)
|
||||
}
|
||||
rule.target = field
|
||||
case protocolIndex:
|
||||
rule.protocol, err = parseProtocol(field)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing protocol: %w", err)
|
||||
}
|
||||
case optIndex: // ignored
|
||||
case inputInterfaceIndex:
|
||||
rule.inputInterface = field
|
||||
case outputInterfaceIndex:
|
||||
rule.outputInterface = field
|
||||
case sourceIndex:
|
||||
rule.source, err = parseIPPrefix(field)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing source IP CIDR: %w", err)
|
||||
}
|
||||
case destinationIndex:
|
||||
rule.destination, err = parseIPPrefix(field)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing destination IP CIDR: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err error) {
|
||||
i := 0
|
||||
for i < len(optionalFields) {
|
||||
switch optionalFields[i] {
|
||||
case "udp":
|
||||
i++
|
||||
consumed, err := parseUDPOptional(optionalFields[i:], rule)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing UDP optional fields: %w", err)
|
||||
}
|
||||
i += consumed
|
||||
case "tcp":
|
||||
i++
|
||||
consumed, err := parseTCPOptional(optionalFields[i:], rule)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing TCP optional fields: %w", err)
|
||||
}
|
||||
i += consumed
|
||||
case "redir":
|
||||
i++
|
||||
switch optionalFields[i] {
|
||||
case "ports":
|
||||
i++
|
||||
ports, err := parsePortsCSV(optionalFields[i])
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing redirection ports: %w", err)
|
||||
}
|
||||
rule.redirPorts = ports
|
||||
i++
|
||||
default:
|
||||
return fmt.Errorf("%w: unexpected %q after redir",
|
||||
ErrChainRuleMalformed, optionalFields[1])
|
||||
}
|
||||
case "ctstate":
|
||||
i++
|
||||
rule.ctstate = strings.Split(optionalFields[i], ",")
|
||||
i++
|
||||
case "mark":
|
||||
i++
|
||||
mark, consumed, err := parseMark(optionalFields[i:])
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing mark: %w", err)
|
||||
}
|
||||
rule.mark = mark
|
||||
i += consumed
|
||||
default:
|
||||
return fmt.Errorf("%w: unexpected optional field: %s",
|
||||
ErrChainRuleMalformed, optionalFields[i])
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var errUDPOptionalUnknown = errors.New("unknown UDP optional field")
|
||||
|
||||
func parseUDPOptional(optionalFields []string, rule *chainRule) (consumed int, err error) {
|
||||
for _, value := range optionalFields {
|
||||
if !strings.ContainsRune(value, ':') {
|
||||
// no longer a UDP-associated option
|
||||
return consumed, nil
|
||||
}
|
||||
switch {
|
||||
case strings.HasPrefix(value, "dpt:"):
|
||||
rule.destinationPort, err = parseDestinationPort(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing destination port: %w", err)
|
||||
}
|
||||
consumed++
|
||||
case strings.HasPrefix(value, "spt:"):
|
||||
rule.sourcePort, err = parseSourcePort(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing source port: %w", err)
|
||||
}
|
||||
consumed++
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: %s", errUDPOptionalUnknown, value)
|
||||
}
|
||||
}
|
||||
return consumed, nil
|
||||
}
|
||||
|
||||
var errTCPOptionalUnknown = errors.New("unknown TCP optional field")
|
||||
|
||||
func parseTCPOptional(optionalFields []string, rule *chainRule) (consumed int, err error) {
|
||||
for _, value := range optionalFields {
|
||||
if !strings.ContainsRune(value, ':') {
|
||||
// no longer a TCP-associated option
|
||||
return consumed, nil
|
||||
}
|
||||
switch {
|
||||
case strings.HasPrefix(value, "dpt:"):
|
||||
rule.destinationPort, err = parseDestinationPort(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing destination port: %w", err)
|
||||
}
|
||||
consumed++
|
||||
case strings.HasPrefix(value, "spt:"):
|
||||
rule.sourcePort, err = parseSourcePort(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing source port: %w", err)
|
||||
}
|
||||
consumed++
|
||||
case strings.HasPrefix(value, "flags:"):
|
||||
rule.tcpFlags, err = parseTCPFlags(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing TCP flags: %w", err)
|
||||
}
|
||||
consumed++
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: %s", errTCPOptionalUnknown, value)
|
||||
}
|
||||
}
|
||||
return consumed, nil
|
||||
}
|
||||
|
||||
func parseDestinationPort(value string) (port uint16, err error) {
|
||||
value = strings.TrimPrefix(value, "dpt:")
|
||||
return parsePort(value)
|
||||
}
|
||||
|
||||
func parseSourcePort(value string) (port uint16, err error) {
|
||||
value = strings.TrimPrefix(value, "spt:")
|
||||
return parsePort(value)
|
||||
}
|
||||
|
||||
var errTCPFlagsMalformed = errors.New("TCP flags are malformed")
|
||||
|
||||
func parseTCPFlags(value string) (tcpFlags, error) {
|
||||
value = strings.TrimPrefix(value, "flags:")
|
||||
fields := strings.Split(value, "/")
|
||||
const expectedFields = 2
|
||||
if len(fields) != expectedFields {
|
||||
return tcpFlags{}, fmt.Errorf("%w: expected format 'flags:<mask>/<comparison>' in %q",
|
||||
errTCPFlagsMalformed, value)
|
||||
}
|
||||
maskFlags := strings.Split(fields[0], ",")
|
||||
mask := make([]tcpFlag, len(maskFlags))
|
||||
var err error
|
||||
for i, maskFlag := range maskFlags {
|
||||
mask[i], err = parseTCPFlag(maskFlag)
|
||||
if err != nil {
|
||||
return tcpFlags{}, fmt.Errorf("parsing TCP mask flags: %w", err)
|
||||
}
|
||||
}
|
||||
comparisonFlags := strings.Split(fields[1], ",")
|
||||
comparison := make([]tcpFlag, len(comparisonFlags))
|
||||
for i, comparisonFlag := range comparisonFlags {
|
||||
comparison[i], err = parseTCPFlag(comparisonFlag)
|
||||
if err != nil {
|
||||
return tcpFlags{}, fmt.Errorf("parsing TCP comparison flags: %w", err)
|
||||
}
|
||||
}
|
||||
return tcpFlags{
|
||||
mask: mask,
|
||||
comparison: comparison,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parsePortsCSV(s string) (ports []uint16, err error) {
|
||||
if s == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
fields := strings.Split(s, ",")
|
||||
ports = make([]uint16, len(fields))
|
||||
for i, field := range fields {
|
||||
ports[i], err = parsePort(field)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return ports, nil
|
||||
}
|
||||
|
||||
var errMarkValueMalformed = errors.New("mark value is malformed")
|
||||
|
||||
func parseMark(optionalFields []string) (m mark, consumed int, err error) {
|
||||
switch optionalFields[consumed] {
|
||||
case "match":
|
||||
consumed++
|
||||
if optionalFields[consumed] == "!" {
|
||||
m.invert = true
|
||||
consumed++
|
||||
}
|
||||
|
||||
const base = 0 // auto-detect
|
||||
const bits = 32
|
||||
value, err := strconv.ParseUint(optionalFields[consumed], base, bits)
|
||||
if err != nil {
|
||||
return mark{}, 0, fmt.Errorf("%w: %s", errMarkValueMalformed, optionalFields[consumed])
|
||||
}
|
||||
m.value = uint(value)
|
||||
consumed++
|
||||
default:
|
||||
return mark{}, 0, fmt.Errorf("%w: unexpected mark mode field: %s",
|
||||
ErrChainRuleMalformed, optionalFields[consumed])
|
||||
}
|
||||
return m, consumed, nil
|
||||
}
|
||||
|
||||
var ErrLineNumberIsZero = errors.New("line number is zero")
|
||||
|
||||
func parseLineNumber(s string) (n uint16, err error) {
|
||||
const base, bitLength = 10, 16
|
||||
lineNumber, err := strconv.ParseUint(s, base, bitLength)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
} else if lineNumber == 0 {
|
||||
return 0, fmt.Errorf("%w", ErrLineNumberIsZero)
|
||||
}
|
||||
return uint16(lineNumber), nil
|
||||
}
|
||||
|
||||
var ErrTargetUnknown = errors.New("unknown target")
|
||||
|
||||
func checkTarget(target string) (err error) {
|
||||
switch target {
|
||||
case "ACCEPT", "DROP", "REJECT", "REDIRECT":
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("%w: %s", ErrTargetUnknown, target)
|
||||
}
|
||||
|
||||
var ErrProtocolUnknown = errors.New("unknown protocol")
|
||||
|
||||
func parseProtocol(s string) (protocol string, err error) {
|
||||
switch s {
|
||||
case "0", "all":
|
||||
case "1", "icmp":
|
||||
protocol = "icmp"
|
||||
case "6", "tcp":
|
||||
protocol = "tcp"
|
||||
case "17", "udp":
|
||||
protocol = "udp"
|
||||
default:
|
||||
return "", fmt.Errorf("%w: %s", ErrProtocolUnknown, s)
|
||||
}
|
||||
return protocol, nil
|
||||
}
|
||||
|
||||
var ErrMetricSizeMalformed = errors.New("metric size is malformed")
|
||||
|
||||
// parseMetricSize parses a metric size string like 140K or 226M and
|
||||
// returns the raw integer matching it.
|
||||
func parseMetricSize(size string) (n uint64, err error) {
|
||||
if size == "" {
|
||||
return n, fmt.Errorf("%w: empty string", ErrMetricSizeMalformed)
|
||||
}
|
||||
|
||||
//nolint:mnd
|
||||
multiplerLetterToValue := map[byte]uint64{
|
||||
'K': 1000,
|
||||
'M': 1000000,
|
||||
'G': 1000000000,
|
||||
'T': 1000000000000,
|
||||
}
|
||||
|
||||
lastCharacter := size[len(size)-1]
|
||||
multiplier, ok := multiplerLetterToValue[lastCharacter]
|
||||
if ok { // multiplier present
|
||||
size = size[:len(size)-1]
|
||||
} else {
|
||||
multiplier = 1
|
||||
}
|
||||
|
||||
const base, bitLength = 10, 64
|
||||
n, err = strconv.ParseUint(size, base, bitLength)
|
||||
if err != nil {
|
||||
return n, fmt.Errorf("%w: %w", ErrMetricSizeMalformed, err)
|
||||
}
|
||||
n *= multiplier
|
||||
return n, nil
|
||||
}
|
||||
@@ -0,0 +1,144 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_parseChain(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
iptablesOutput string
|
||||
table chain
|
||||
errWrapped error
|
||||
errMessage string
|
||||
}{
|
||||
"no_output": {
|
||||
errWrapped: ErrChainListMalformed,
|
||||
errMessage: "iptables chain list output is malformed: not enough lines to process in: ",
|
||||
},
|
||||
"single_line_only": {
|
||||
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)`,
|
||||
errWrapped: ErrChainListMalformed,
|
||||
errMessage: "iptables chain list output is malformed: not enough lines to process in: " +
|
||||
"Chain INPUT (policy ACCEPT 140K packets, 226M bytes)",
|
||||
},
|
||||
"malformed_general_data_line": {
|
||||
iptablesOutput: `Chain INPUT
|
||||
num pkts bytes target prot opt in out source destination`,
|
||||
errWrapped: ErrChainListMalformed,
|
||||
errMessage: "parsing chain general data line: iptables chain list output is malformed: " +
|
||||
"expected 8 fields in \"Chain INPUT\"",
|
||||
},
|
||||
"malformed_legend": {
|
||||
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
|
||||
num pkts bytes target prot opt in out source`,
|
||||
errWrapped: ErrChainListMalformed,
|
||||
errMessage: "iptables chain list output is malformed: legend " +
|
||||
"\"num pkts bytes target prot opt in out source\" " +
|
||||
"is not the expected \"num pkts bytes target prot opt in out source destination\"",
|
||||
},
|
||||
"no_rule": {
|
||||
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
|
||||
num pkts bytes target prot opt in out source destination`,
|
||||
table: chain{
|
||||
name: "INPUT",
|
||||
policy: "ACCEPT",
|
||||
packets: 140000,
|
||||
bytes: 226000000,
|
||||
},
|
||||
},
|
||||
"some_rules": {
|
||||
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
|
||||
num pkts bytes target prot opt in out source destination
|
||||
1 0 0 ACCEPT 17 -- tun0 * 0.0.0.0/0 0.0.0.0/0 udp dpt:55405
|
||||
2 0 0 ACCEPT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:55405
|
||||
3 0 0 ACCEPT 1 -- tun0 * 0.0.0.0/0 0.0.0.0/0
|
||||
4 0 0 DROP 0 -- tun0 * 1.2.3.4 0.0.0.0/0
|
||||
5 0 0 ACCEPT all -- tun0 * 1.2.3.4 0.0.0.0/0
|
||||
`,
|
||||
table: chain{
|
||||
name: "INPUT",
|
||||
policy: "ACCEPT",
|
||||
packets: 140000,
|
||||
bytes: 226000000,
|
||||
rules: []chainRule{
|
||||
{
|
||||
lineNumber: 1,
|
||||
packets: 0,
|
||||
bytes: 0,
|
||||
target: "ACCEPT",
|
||||
protocol: "udp",
|
||||
inputInterface: "tun0",
|
||||
outputInterface: "*",
|
||||
source: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
destination: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
destinationPort: 55405,
|
||||
},
|
||||
{
|
||||
lineNumber: 2,
|
||||
packets: 0,
|
||||
bytes: 0,
|
||||
target: "ACCEPT",
|
||||
protocol: "tcp",
|
||||
inputInterface: "tun0",
|
||||
outputInterface: "*",
|
||||
source: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
destination: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
destinationPort: 55405,
|
||||
},
|
||||
{
|
||||
lineNumber: 3,
|
||||
packets: 0,
|
||||
bytes: 0,
|
||||
target: "ACCEPT",
|
||||
protocol: "icmp",
|
||||
inputInterface: "tun0",
|
||||
outputInterface: "*",
|
||||
source: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
destination: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
},
|
||||
{
|
||||
lineNumber: 4,
|
||||
packets: 0,
|
||||
bytes: 0,
|
||||
target: "DROP",
|
||||
protocol: "",
|
||||
inputInterface: "tun0",
|
||||
outputInterface: "*",
|
||||
source: netip.MustParsePrefix("1.2.3.4/32"),
|
||||
destination: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
},
|
||||
{
|
||||
lineNumber: 5,
|
||||
packets: 0,
|
||||
bytes: 0,
|
||||
target: "ACCEPT",
|
||||
protocol: "",
|
||||
inputInterface: "tun0",
|
||||
outputInterface: "*",
|
||||
source: netip.MustParsePrefix("1.2.3.4/32"),
|
||||
destination: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
table, err := parseChain(testCase.iptablesOutput)
|
||||
|
||||
assert.Equal(t, testCase.table, table)
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if testCase.errWrapped != nil {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
package iptables
|
||||
|
||||
//go:generate mockgen -destination=mocks_test.go -package $GOPACKAGE . CmdRunner,Logger
|
||||
@@ -0,0 +1,121 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/qdm12/gluetun/internal/firewall/iptables (interfaces: CmdRunner,Logger)
|
||||
|
||||
// Package iptables is a generated GoMock package.
|
||||
package iptables
|
||||
|
||||
import (
|
||||
exec "os/exec"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// MockCmdRunner is a mock of CmdRunner interface.
|
||||
type MockCmdRunner struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockCmdRunnerMockRecorder
|
||||
}
|
||||
|
||||
// MockCmdRunnerMockRecorder is the mock recorder for MockCmdRunner.
|
||||
type MockCmdRunnerMockRecorder struct {
|
||||
mock *MockCmdRunner
|
||||
}
|
||||
|
||||
// NewMockCmdRunner creates a new mock instance.
|
||||
func NewMockCmdRunner(ctrl *gomock.Controller) *MockCmdRunner {
|
||||
mock := &MockCmdRunner{ctrl: ctrl}
|
||||
mock.recorder = &MockCmdRunnerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockCmdRunner) EXPECT() *MockCmdRunnerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Run mocks base method.
|
||||
func (m *MockCmdRunner) Run(arg0 *exec.Cmd) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Run", arg0)
|
||||
ret0, _ := ret[0].(string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Run indicates an expected call of Run.
|
||||
func (mr *MockCmdRunnerMockRecorder) Run(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockCmdRunner)(nil).Run), arg0)
|
||||
}
|
||||
|
||||
// MockLogger is a mock of Logger interface.
|
||||
type MockLogger struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockLoggerMockRecorder
|
||||
}
|
||||
|
||||
// MockLoggerMockRecorder is the mock recorder for MockLogger.
|
||||
type MockLoggerMockRecorder struct {
|
||||
mock *MockLogger
|
||||
}
|
||||
|
||||
// NewMockLogger creates a new mock instance.
|
||||
func NewMockLogger(ctrl *gomock.Controller) *MockLogger {
|
||||
mock := &MockLogger{ctrl: ctrl}
|
||||
mock.recorder = &MockLoggerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockLogger) EXPECT() *MockLoggerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Debug mocks base method.
|
||||
func (m *MockLogger) Debug(arg0 string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Debug", arg0)
|
||||
}
|
||||
|
||||
// Debug indicates an expected call of Debug.
|
||||
func (mr *MockLoggerMockRecorder) Debug(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockLogger)(nil).Debug), arg0)
|
||||
}
|
||||
|
||||
// Error mocks base method.
|
||||
func (m *MockLogger) Error(arg0 string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Error", arg0)
|
||||
}
|
||||
|
||||
// Error indicates an expected call of Error.
|
||||
func (mr *MockLoggerMockRecorder) Error(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Error", reflect.TypeOf((*MockLogger)(nil).Error), arg0)
|
||||
}
|
||||
|
||||
// Info mocks base method.
|
||||
func (m *MockLogger) Info(arg0 string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Info", arg0)
|
||||
}
|
||||
|
||||
// Info indicates an expected call of Info.
|
||||
func (mr *MockLoggerMockRecorder) Info(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Info", reflect.TypeOf((*MockLogger)(nil).Info), arg0)
|
||||
}
|
||||
|
||||
// Warn mocks base method.
|
||||
func (m *MockLogger) Warn(arg0 string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Warn", arg0)
|
||||
}
|
||||
|
||||
// Warn indicates an expected call of Warn.
|
||||
func (mr *MockLoggerMockRecorder) Warn(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Warn", reflect.TypeOf((*MockLogger)(nil).Warn), arg0)
|
||||
}
|
||||
@@ -0,0 +1,262 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type iptablesInstruction struct {
|
||||
table string // defaults to "filter", and can be "nat" for example.
|
||||
append bool
|
||||
chain string // for example INPUT, PREROUTING. Cannot be empty.
|
||||
target string // for example ACCEPT. Can be empty.
|
||||
protocol string // "tcp" or "udp" or "" for all protocols.
|
||||
inputInterface string // for example "tun0" or "" for any interface.
|
||||
outputInterface string // for example "tun0" or "" for any interface.
|
||||
source netip.Prefix // if not valid, then it is unspecified.
|
||||
sourcePort uint16 // if zero, there is no source port
|
||||
destination netip.Prefix // if not valid, then it is unspecified.
|
||||
destinationPort uint16 // if zero, there is no destination port
|
||||
toPorts []uint16 // if empty, there is no redirection
|
||||
ctstate []string // if empty, there is no ctstate
|
||||
tcpFlags tcpFlags
|
||||
mark mark
|
||||
}
|
||||
|
||||
func (i *iptablesInstruction) setDefaults() {
|
||||
if i.table == "" {
|
||||
i.table = "filter"
|
||||
}
|
||||
}
|
||||
|
||||
// equalToRule ignores the append boolean flag of the instruction to compare against the rule.
|
||||
func (i *iptablesInstruction) equalToRule(table, chain string, rule chainRule) (equal bool) {
|
||||
switch {
|
||||
case i.table != table:
|
||||
return false
|
||||
case i.chain != chain:
|
||||
return false
|
||||
case i.target != rule.target:
|
||||
return false
|
||||
case i.protocol != rule.protocol:
|
||||
return false
|
||||
case i.destinationPort != rule.destinationPort:
|
||||
return false
|
||||
case i.sourcePort != rule.sourcePort:
|
||||
return false
|
||||
case !slices.Equal(i.toPorts, rule.redirPorts):
|
||||
return false
|
||||
case !slices.Equal(i.ctstate, rule.ctstate):
|
||||
return false
|
||||
case !networkInterfacesEqual(i.inputInterface, rule.inputInterface):
|
||||
return false
|
||||
case !networkInterfacesEqual(i.outputInterface, rule.outputInterface):
|
||||
return false
|
||||
case !ipPrefixesEqual(i.source, rule.source):
|
||||
return false
|
||||
case !ipPrefixesEqual(i.destination, rule.destination):
|
||||
return false
|
||||
case !slices.Equal(i.tcpFlags.mask, rule.tcpFlags.mask) ||
|
||||
!slices.Equal(i.tcpFlags.comparison, rule.tcpFlags.comparison):
|
||||
return false
|
||||
case i.mark != rule.mark:
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// instruction can be "" which equivalent to the "*" chain rule interface.
|
||||
func networkInterfacesEqual(instruction, chainRule string) bool {
|
||||
return instruction == chainRule || (instruction == "" && chainRule == "*")
|
||||
}
|
||||
|
||||
func ipPrefixesEqual(instruction, chainRule netip.Prefix) bool {
|
||||
return instruction == chainRule ||
|
||||
(!instruction.IsValid() && chainRule.Bits() == 0 && chainRule.Addr().IsUnspecified())
|
||||
}
|
||||
|
||||
var ErrIptablesCommandMalformed = errors.New("iptables command is malformed")
|
||||
|
||||
func parseIptablesInstruction(s string) (instruction iptablesInstruction, err error) {
|
||||
if s == "" {
|
||||
return iptablesInstruction{}, fmt.Errorf("%w: empty instruction", ErrIptablesCommandMalformed)
|
||||
}
|
||||
fields := strings.Fields(s)
|
||||
|
||||
i := 0
|
||||
for i < len(fields) {
|
||||
consumed, err := parseInstructionFlag(fields[i:], &instruction)
|
||||
if err != nil {
|
||||
return iptablesInstruction{}, fmt.Errorf("parsing %q: %w", s, err)
|
||||
}
|
||||
i += consumed
|
||||
}
|
||||
|
||||
instruction.setDefaults()
|
||||
return instruction, nil
|
||||
}
|
||||
|
||||
func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (consumed int, err error) {
|
||||
consumed, err = preCheckInstructionFields(fields)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
flag := fields[0]
|
||||
value := fields[1]
|
||||
|
||||
switch flag {
|
||||
case "-t", "--table":
|
||||
instruction.table = value
|
||||
case "-D", "--delete":
|
||||
instruction.append = false
|
||||
instruction.chain = value
|
||||
case "-A", "--append":
|
||||
instruction.append = true
|
||||
instruction.chain = value
|
||||
case "-j", "--jump":
|
||||
instruction.target = value
|
||||
case "-p", "--protocol":
|
||||
instruction.protocol = value
|
||||
case "-m", "--match":
|
||||
consumed, err = parseMatchModule(fields, instruction)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing match module: %w", err)
|
||||
}
|
||||
case "--mark":
|
||||
const base = 0 // auto-detect
|
||||
const bits = 32
|
||||
value, err := strconv.ParseUint(value, base, bits)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing mark value %q: %w", fields[2], err)
|
||||
}
|
||||
instruction.mark.value = uint(value)
|
||||
case "-i", "--in-interface":
|
||||
instruction.inputInterface = value
|
||||
case "-o", "--out-interface":
|
||||
instruction.outputInterface = value
|
||||
case "-s", "--source":
|
||||
instruction.source, err = parseIPPrefix(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing source IP CIDR: %w", err)
|
||||
}
|
||||
case "--sport":
|
||||
instruction.sourcePort, err = parsePort(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing source port: %w", err)
|
||||
}
|
||||
case "-d", "--destination":
|
||||
instruction.destination, err = parseIPPrefix(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing destination IP CIDR: %w", err)
|
||||
}
|
||||
case "--dport":
|
||||
instruction.destinationPort, err = parsePort(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing destination port: %w", err)
|
||||
}
|
||||
case "--ctstate":
|
||||
instruction.ctstate = strings.Split(value, ",")
|
||||
case "--to-ports":
|
||||
instruction.toPorts, err = parseToPorts(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing port redirection: %w", err)
|
||||
}
|
||||
case "--tcp-flags":
|
||||
mask, comparison := value, fields[2]
|
||||
instruction.tcpFlags, err = parseTCPFlags(mask + "/" + comparison)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing TCP flags: %w", err)
|
||||
}
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: unknown key %q", ErrIptablesCommandMalformed, flag)
|
||||
}
|
||||
return consumed, nil
|
||||
}
|
||||
|
||||
func preCheckInstructionFields(fields []string) (consumed int, err error) {
|
||||
flag := fields[0]
|
||||
// All flags use one value after the flag, except the following:
|
||||
switch flag {
|
||||
case "--tcp-flags": // -m can have 1 or 2 values
|
||||
const expected = 3
|
||||
if len(fields) < expected {
|
||||
return 0, fmt.Errorf("%w: flag %q requires at least 2 values, but got %s",
|
||||
ErrIptablesCommandMalformed, flag, strings.Join(fields, " "))
|
||||
}
|
||||
return expected, nil
|
||||
default:
|
||||
const expected = 2
|
||||
if len(fields) < expected {
|
||||
return 0, fmt.Errorf("%w: flag %q requires a value, but got none",
|
||||
ErrIptablesCommandMalformed, flag)
|
||||
}
|
||||
return expected, nil
|
||||
}
|
||||
}
|
||||
|
||||
func parseIPPrefix(value string) (prefix netip.Prefix, err error) {
|
||||
slashIndex := strings.Index(value, "/")
|
||||
if slashIndex >= 0 {
|
||||
return netip.ParsePrefix(value)
|
||||
}
|
||||
|
||||
ip, err := netip.ParseAddr(value)
|
||||
if err != nil {
|
||||
return netip.Prefix{}, fmt.Errorf("parsing IP address: %w", err)
|
||||
}
|
||||
return netip.PrefixFrom(ip, ip.BitLen()), nil
|
||||
}
|
||||
|
||||
func parsePort(value string) (port uint16, err error) {
|
||||
const base, bitLength = 10, 16
|
||||
portValue, err := strconv.ParseUint(value, base, bitLength)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return uint16(portValue), nil
|
||||
}
|
||||
|
||||
func parseMatchModule(fields []string, instruction *iptablesInstruction) (
|
||||
consumed int, err error,
|
||||
) {
|
||||
_ = fields[consumed] // -m or --match flag already detected
|
||||
consumed++
|
||||
switch fields[consumed] {
|
||||
case "tcp", "udp":
|
||||
consumed++
|
||||
// for now ignore the protocol match since it's auto-loaded
|
||||
// when parsing the -p/--protocol flag, and we don't need to
|
||||
// parse it twice.
|
||||
case "mark":
|
||||
consumed++
|
||||
switch fields[consumed] {
|
||||
case "!":
|
||||
consumed++
|
||||
instruction.mark.invert = true
|
||||
default:
|
||||
return consumed, fmt.Errorf("%w: unsupported match mark with value: %s",
|
||||
ErrIptablesCommandMalformed, fields[2])
|
||||
}
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: unknown match value: %s",
|
||||
ErrIptablesCommandMalformed, fields[consumed])
|
||||
}
|
||||
return consumed, nil
|
||||
}
|
||||
|
||||
func parseToPorts(value string) (toPorts []uint16, err error) {
|
||||
portStrings := strings.Split(value, ",")
|
||||
toPorts = make([]uint16, len(portStrings))
|
||||
for i, portString := range portStrings {
|
||||
toPorts[i], err = parsePort(portString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return toPorts, nil
|
||||
}
|
||||
@@ -0,0 +1,136 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_parseIptablesInstruction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
s string
|
||||
instruction iptablesInstruction
|
||||
errWrapped error
|
||||
errMessage string
|
||||
}{
|
||||
"no_instruction": {
|
||||
errWrapped: ErrIptablesCommandMalformed,
|
||||
errMessage: "iptables command is malformed: empty instruction",
|
||||
},
|
||||
"uneven_fields": {
|
||||
s: "-A",
|
||||
errWrapped: ErrIptablesCommandMalformed,
|
||||
errMessage: "parsing \"-A\": iptables command is malformed: flag \"-A\" requires a value, but got none",
|
||||
},
|
||||
"unknown_key": {
|
||||
s: "-x something",
|
||||
errWrapped: ErrIptablesCommandMalformed,
|
||||
errMessage: "parsing \"-x something\": iptables command is malformed: unknown key \"-x\"",
|
||||
},
|
||||
"one_pair": {
|
||||
s: "-A INPUT",
|
||||
instruction: iptablesInstruction{
|
||||
table: "filter",
|
||||
chain: "INPUT",
|
||||
append: true,
|
||||
},
|
||||
},
|
||||
"instruction_A": {
|
||||
s: "-A INPUT -i tun0 -p tcp -m tcp -s 1.2.3.4/32 -d 5.6.7.8 --dport 10000 -j ACCEPT",
|
||||
instruction: iptablesInstruction{
|
||||
table: "filter",
|
||||
chain: "INPUT",
|
||||
append: true,
|
||||
inputInterface: "tun0",
|
||||
protocol: "tcp",
|
||||
source: netip.MustParsePrefix("1.2.3.4/32"),
|
||||
destination: netip.MustParsePrefix("5.6.7.8/32"),
|
||||
destinationPort: 10000,
|
||||
target: "ACCEPT",
|
||||
},
|
||||
},
|
||||
"nat_redirection": {
|
||||
s: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
|
||||
instruction: iptablesInstruction{
|
||||
table: "nat",
|
||||
chain: "PREROUTING",
|
||||
append: false,
|
||||
inputInterface: "tun0",
|
||||
protocol: "tcp",
|
||||
destinationPort: 43716,
|
||||
target: "REDIRECT",
|
||||
toPorts: []uint16{5678},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rule, err := parseIptablesInstruction(testCase.s)
|
||||
|
||||
assert.Equal(t, testCase.instruction, rule)
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if testCase.errWrapped != nil {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_parseIPPrefix(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
value string
|
||||
prefix netip.Prefix
|
||||
errMessage string
|
||||
}{
|
||||
"empty": {
|
||||
errMessage: `parsing IP address: ParseAddr(""): unable to parse IP`,
|
||||
},
|
||||
"invalid": {
|
||||
value: "invalid",
|
||||
errMessage: `parsing IP address: ParseAddr("invalid"): unable to parse IP`,
|
||||
},
|
||||
"valid_ipv4_with_bits": {
|
||||
value: "10.0.0.0/16",
|
||||
prefix: netip.PrefixFrom(netip.AddrFrom4([4]byte{10, 0, 0, 0}), 16),
|
||||
},
|
||||
"valid_ipv4_without_bits": {
|
||||
value: "10.0.0.4",
|
||||
prefix: netip.PrefixFrom(netip.AddrFrom4([4]byte{10, 0, 0, 4}), 32),
|
||||
},
|
||||
"valid_ipv6_with_bits": {
|
||||
value: "2001:db8::/32",
|
||||
prefix: netip.PrefixFrom(
|
||||
netip.AddrFrom16([16]byte{0x20, 0x01, 0x0d, 0xb8}),
|
||||
32),
|
||||
},
|
||||
"valid_ipv6_without_bits": {
|
||||
value: "2001:db8::",
|
||||
prefix: netip.PrefixFrom(
|
||||
netip.AddrFrom16([16]byte{0x20, 0x01, 0x0d, 0xb8}),
|
||||
128),
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
prefix, err := parseIPPrefix(testCase.value)
|
||||
|
||||
assert.Equal(t, testCase.prefix, prefix)
|
||||
if testCase.errMessage != "" {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,162 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"os/exec"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNetAdminMissing = errors.New("NET_ADMIN capability is missing")
|
||||
ErrTestRuleCleanup = errors.New("failed cleaning up test rule")
|
||||
ErrInputPolicyNotFound = errors.New("input policy not found")
|
||||
ErrNotSupported = errors.New("no iptables supported found")
|
||||
)
|
||||
|
||||
func checkIptablesSupport(ctx context.Context, runner CmdRunner,
|
||||
iptablesPathsToTry ...string,
|
||||
) (iptablesPath string, err error) {
|
||||
iptablesPathToUnsupportedMessage := make(map[string]string, len(iptablesPathsToTry))
|
||||
for _, pathToTest := range iptablesPathsToTry {
|
||||
ok, unsupportedMessage, err := testIptablesPath(ctx, pathToTest, runner)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("for %s: %w", pathToTest, err)
|
||||
} else if ok {
|
||||
iptablesPath = pathToTest
|
||||
break
|
||||
}
|
||||
iptablesPathToUnsupportedMessage[pathToTest] = unsupportedMessage
|
||||
}
|
||||
|
||||
if iptablesPath != "" {
|
||||
// some paths may be unsupported but that does not matter
|
||||
// since we found one working.
|
||||
return iptablesPath, nil
|
||||
}
|
||||
|
||||
allArePermissionDenied := true
|
||||
allUnsupportedMessages := make(sort.StringSlice, 0, len(iptablesPathToUnsupportedMessage))
|
||||
for iptablesPath, unsupportedMessage := range iptablesPathToUnsupportedMessage {
|
||||
if !isPermissionDenied(unsupportedMessage) {
|
||||
allArePermissionDenied = false
|
||||
}
|
||||
unsupportedMessage = iptablesPath + ": " + unsupportedMessage
|
||||
allUnsupportedMessages = append(allUnsupportedMessages, unsupportedMessage)
|
||||
}
|
||||
|
||||
allUnsupportedMessages.Sort() // predictable order for tests
|
||||
|
||||
if allArePermissionDenied {
|
||||
// If the error is related to a denied permission for all iptables path,
|
||||
// return an error describing what to do from an end-user perspective.
|
||||
return "", fmt.Errorf("%w: %s", ErrNetAdminMissing, strings.Join(allUnsupportedMessages, "; "))
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("%w: errors encountered are: %s",
|
||||
ErrNotSupported, strings.Join(allUnsupportedMessages, "; "))
|
||||
}
|
||||
|
||||
func testIptablesPath(ctx context.Context, path string,
|
||||
runner CmdRunner) (ok bool, unsupportedMessage string,
|
||||
criticalErr error,
|
||||
) {
|
||||
// Just listing iptables rules often work but we need
|
||||
// to modify them to ensure we can support the iptables
|
||||
// being tested.
|
||||
|
||||
// Append a test rule with a random interface name to the OUTPUT table.
|
||||
// This should not affect existing rules or the network traffic.
|
||||
testInterfaceName := randomInterfaceName()
|
||||
cmd := exec.CommandContext(ctx, path,
|
||||
"-A", "OUTPUT", "-o", testInterfaceName, "-j", "DROP")
|
||||
output, err := runner.Run(cmd)
|
||||
if err != nil {
|
||||
unsupportedMessage = fmt.Sprintf("%s (%s)", output, err)
|
||||
return false, unsupportedMessage, nil
|
||||
}
|
||||
|
||||
// Remove the random rule added previously for test.
|
||||
cmd = exec.CommandContext(ctx, path,
|
||||
"-D", "OUTPUT", "-o", testInterfaceName, "-j", "DROP")
|
||||
output, err = runner.Run(cmd)
|
||||
if err != nil {
|
||||
// this is a critical error, we want to make sure our test rule gets removed.
|
||||
criticalErr = fmt.Errorf("%w: %s (%s)", ErrTestRuleCleanup, output, err)
|
||||
return false, "", criticalErr
|
||||
}
|
||||
|
||||
// Set policy as the existing policy so no mutation is done.
|
||||
// This is an extra check for some buggy kernels where setting the policy
|
||||
// does not work.
|
||||
cmd = exec.CommandContext(ctx, path, "-nL", "INPUT")
|
||||
output, err = runner.Run(cmd)
|
||||
if err != nil {
|
||||
unsupportedMessage = fmt.Sprintf("%s (%s)", output, err)
|
||||
return false, unsupportedMessage, nil
|
||||
}
|
||||
|
||||
var inputPolicy string
|
||||
for _, line := range strings.Split(output, "\n") {
|
||||
inputPolicy, ok = extractInputPolicy(line)
|
||||
if ok {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if inputPolicy == "" {
|
||||
criticalErr = fmt.Errorf("%w: in INPUT rules: %s", ErrInputPolicyNotFound, output)
|
||||
return false, "", criticalErr
|
||||
}
|
||||
|
||||
// Set the policy for the INPUT table to the existing policy found.
|
||||
cmd = exec.CommandContext(ctx, path, "--policy", "INPUT", inputPolicy)
|
||||
output, err = runner.Run(cmd)
|
||||
if err != nil {
|
||||
unsupportedMessage = fmt.Sprintf("%s (%s)", output, err)
|
||||
return false, unsupportedMessage, nil
|
||||
}
|
||||
|
||||
return true, "", nil // success
|
||||
}
|
||||
|
||||
func isPermissionDenied(errMessage string) (ok bool) {
|
||||
const permissionDeniedString = "Permission denied (you must be root)"
|
||||
return strings.Contains(errMessage, permissionDeniedString)
|
||||
}
|
||||
|
||||
func extractInputPolicy(line string) (policy string, ok bool) {
|
||||
const prefixToFind = "Chain INPUT (policy "
|
||||
i := strings.Index(line, prefixToFind)
|
||||
if i == -1 {
|
||||
return "", false
|
||||
}
|
||||
|
||||
startIndex := i + len(prefixToFind)
|
||||
endIndex := strings.Index(line, ")")
|
||||
if endIndex < 0 {
|
||||
return "", false
|
||||
}
|
||||
|
||||
policy = line[startIndex:endIndex]
|
||||
policy = strings.TrimSpace(policy)
|
||||
if policy == "" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return policy, true
|
||||
}
|
||||
|
||||
func randomInterfaceName() (interfaceName string) {
|
||||
const size = 15
|
||||
letterRunes := []rune("abcdefghijklmnopqrstuvwxyz0123456789")
|
||||
b := make([]rune, size)
|
||||
for i := range b {
|
||||
letterIndex := rand.Intn(len(letterRunes)) //nolint:gosec
|
||||
b[i] = letterRunes[letterIndex]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
@@ -0,0 +1,345 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newAppendTestRuleMatcher(path string) *cmdMatcher {
|
||||
return newCmdMatcher(path,
|
||||
"^-A$", "^OUTPUT$", "^-o$", "^[a-z0-9]{15}$",
|
||||
"^-j$", "^DROP$")
|
||||
}
|
||||
|
||||
func newDeleteTestRuleMatcher(path string) *cmdMatcher {
|
||||
return newCmdMatcher(path,
|
||||
"^-D$", "^OUTPUT$", "^-o$", "^[a-z0-9]{15}$",
|
||||
"^-j$", "^DROP$")
|
||||
}
|
||||
|
||||
func newListInputRulesMatcher(path string) *cmdMatcher {
|
||||
return newCmdMatcher(path,
|
||||
"^-nL$", "^INPUT$")
|
||||
}
|
||||
|
||||
func newSetPolicyMatcher(path, inputPolicy string) *cmdMatcher { //nolint:unparam
|
||||
return newCmdMatcher(path,
|
||||
"^--policy$", "^INPUT$", "^"+inputPolicy+"$")
|
||||
}
|
||||
|
||||
func Test_checkIptablesSupport(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
errDummy := errors.New("exit code 4")
|
||||
const inputPolicy = "ACCEPT"
|
||||
|
||||
testCases := map[string]struct {
|
||||
buildRunner func(ctrl *gomock.Controller) CmdRunner
|
||||
iptablesPathsToTry []string
|
||||
iptablesPath string
|
||||
errSentinel error
|
||||
errMessage string
|
||||
}{
|
||||
"critical error when checking": {
|
||||
buildRunner: func(ctrl *gomock.Controller) CmdRunner {
|
||||
runner := NewMockCmdRunner(ctrl)
|
||||
runner.EXPECT().Run(newAppendTestRuleMatcher("path1")).
|
||||
Return("", nil)
|
||||
runner.EXPECT().Run(newDeleteTestRuleMatcher("path1")).
|
||||
Return("output", errDummy)
|
||||
return runner
|
||||
},
|
||||
iptablesPathsToTry: []string{"path1", "path2"},
|
||||
errSentinel: ErrTestRuleCleanup,
|
||||
errMessage: "for path1: failed cleaning up test rule: " +
|
||||
"output (exit code 4)",
|
||||
},
|
||||
"found valid path": {
|
||||
buildRunner: func(ctrl *gomock.Controller) CmdRunner {
|
||||
runner := NewMockCmdRunner(ctrl)
|
||||
runner.EXPECT().Run(newAppendTestRuleMatcher("path1")).
|
||||
Return("", nil)
|
||||
runner.EXPECT().Run(newDeleteTestRuleMatcher("path1")).
|
||||
Return("", nil)
|
||||
runner.EXPECT().Run(newListInputRulesMatcher("path1")).
|
||||
Return("Chain INPUT (policy "+inputPolicy+")", nil)
|
||||
runner.EXPECT().Run(newSetPolicyMatcher("path1", inputPolicy)).
|
||||
Return("", nil)
|
||||
return runner
|
||||
},
|
||||
iptablesPathsToTry: []string{"path1", "path2"},
|
||||
iptablesPath: "path1",
|
||||
},
|
||||
"all permission denied": {
|
||||
buildRunner: func(ctrl *gomock.Controller) CmdRunner {
|
||||
runner := NewMockCmdRunner(ctrl)
|
||||
runner.EXPECT().Run(newAppendTestRuleMatcher("path1")).
|
||||
Return("Permission denied (you must be root) more context", errDummy)
|
||||
runner.EXPECT().Run(newAppendTestRuleMatcher("path2")).
|
||||
Return("context: Permission denied (you must be root)", errDummy)
|
||||
return runner
|
||||
},
|
||||
iptablesPathsToTry: []string{"path1", "path2"},
|
||||
errSentinel: ErrNetAdminMissing,
|
||||
errMessage: "NET_ADMIN capability is missing: " +
|
||||
"path1: Permission denied (you must be root) more context (exit code 4); " +
|
||||
"path2: context: Permission denied (you must be root) (exit code 4)",
|
||||
},
|
||||
"no valid path": {
|
||||
buildRunner: func(ctrl *gomock.Controller) CmdRunner {
|
||||
runner := NewMockCmdRunner(ctrl)
|
||||
runner.EXPECT().Run(newAppendTestRuleMatcher("path1")).
|
||||
Return("output 1", errDummy)
|
||||
runner.EXPECT().Run(newAppendTestRuleMatcher("path2")).
|
||||
Return("output 2", errDummy)
|
||||
return runner
|
||||
},
|
||||
iptablesPathsToTry: []string{"path1", "path2"},
|
||||
errSentinel: ErrNotSupported,
|
||||
errMessage: "no iptables supported found: " +
|
||||
"errors encountered are: " +
|
||||
"path1: output 1 (exit code 4); " +
|
||||
"path2: output 2 (exit code 4)",
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
runner := testCase.buildRunner(ctrl)
|
||||
|
||||
iptablesPath, err := checkIptablesSupport(ctx, runner, testCase.iptablesPathsToTry...)
|
||||
|
||||
require.ErrorIs(t, err, testCase.errSentinel)
|
||||
if testCase.errSentinel != nil {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
}
|
||||
assert.Equal(t, testCase.iptablesPath, iptablesPath)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_testIptablesPath(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
const path = "dummypath"
|
||||
errDummy := errors.New("exit code 4")
|
||||
const inputPolicy = "ACCEPT"
|
||||
|
||||
testCases := map[string]struct {
|
||||
buildRunner func(ctrl *gomock.Controller) CmdRunner
|
||||
ok bool
|
||||
unsupportedMessage string
|
||||
criticalErrWrapped error
|
||||
criticalErrMessage string
|
||||
}{
|
||||
"append test rule permission denied": {
|
||||
buildRunner: func(ctrl *gomock.Controller) CmdRunner {
|
||||
runner := NewMockCmdRunner(ctrl)
|
||||
runner.EXPECT().Run(newAppendTestRuleMatcher(path)).
|
||||
Return("Permission denied (you must be root)", errDummy)
|
||||
return runner
|
||||
},
|
||||
unsupportedMessage: "Permission denied (you must be root) (exit code 4)",
|
||||
},
|
||||
"append test rule unsupported": {
|
||||
buildRunner: func(ctrl *gomock.Controller) CmdRunner {
|
||||
runner := NewMockCmdRunner(ctrl)
|
||||
runner.EXPECT().Run(newAppendTestRuleMatcher(path)).
|
||||
Return("some output", errDummy)
|
||||
return runner
|
||||
},
|
||||
unsupportedMessage: "some output (exit code 4)",
|
||||
},
|
||||
"remove test rule error": {
|
||||
buildRunner: func(ctrl *gomock.Controller) CmdRunner {
|
||||
runner := NewMockCmdRunner(ctrl)
|
||||
runner.EXPECT().Run(newAppendTestRuleMatcher(path)).Return("", nil)
|
||||
runner.EXPECT().Run(newDeleteTestRuleMatcher(path)).
|
||||
Return("some output", errDummy)
|
||||
return runner
|
||||
},
|
||||
criticalErrWrapped: ErrTestRuleCleanup,
|
||||
criticalErrMessage: "failed cleaning up test rule: some output (exit code 4)",
|
||||
},
|
||||
"list input rules permission denied": {
|
||||
buildRunner: func(ctrl *gomock.Controller) CmdRunner {
|
||||
runner := NewMockCmdRunner(ctrl)
|
||||
runner.EXPECT().Run(newAppendTestRuleMatcher(path)).Return("", nil)
|
||||
runner.EXPECT().Run(newDeleteTestRuleMatcher(path)).Return("", nil)
|
||||
runner.EXPECT().Run(newListInputRulesMatcher(path)).
|
||||
Return("Permission denied (you must be root)", errDummy)
|
||||
return runner
|
||||
},
|
||||
unsupportedMessage: "Permission denied (you must be root) (exit code 4)",
|
||||
},
|
||||
"list input rules unsupported": {
|
||||
buildRunner: func(ctrl *gomock.Controller) CmdRunner {
|
||||
runner := NewMockCmdRunner(ctrl)
|
||||
runner.EXPECT().Run(newAppendTestRuleMatcher(path)).Return("", nil)
|
||||
runner.EXPECT().Run(newDeleteTestRuleMatcher(path)).Return("", nil)
|
||||
runner.EXPECT().Run(newListInputRulesMatcher(path)).
|
||||
Return("some output", errDummy)
|
||||
return runner
|
||||
},
|
||||
unsupportedMessage: "some output (exit code 4)",
|
||||
},
|
||||
"list input rules no policy": {
|
||||
buildRunner: func(ctrl *gomock.Controller) CmdRunner {
|
||||
runner := NewMockCmdRunner(ctrl)
|
||||
runner.EXPECT().Run(newAppendTestRuleMatcher(path)).Return("", nil)
|
||||
runner.EXPECT().Run(newDeleteTestRuleMatcher(path)).Return("", nil)
|
||||
runner.EXPECT().Run(newListInputRulesMatcher(path)).
|
||||
Return("some\noutput", nil)
|
||||
return runner
|
||||
},
|
||||
criticalErrWrapped: ErrInputPolicyNotFound,
|
||||
criticalErrMessage: "input policy not found: in INPUT rules: some\noutput",
|
||||
},
|
||||
"set policy permission denied": {
|
||||
buildRunner: func(ctrl *gomock.Controller) CmdRunner {
|
||||
runner := NewMockCmdRunner(ctrl)
|
||||
runner.EXPECT().Run(newAppendTestRuleMatcher(path)).Return("", nil)
|
||||
runner.EXPECT().Run(newDeleteTestRuleMatcher(path)).Return("", nil)
|
||||
runner.EXPECT().Run(newListInputRulesMatcher(path)).
|
||||
Return("\nChain INPUT (policy "+inputPolicy+")\nAA\n", nil)
|
||||
runner.EXPECT().Run(newSetPolicyMatcher(path, inputPolicy)).
|
||||
Return("Permission denied (you must be root)", errDummy)
|
||||
return runner
|
||||
},
|
||||
unsupportedMessage: "Permission denied (you must be root) (exit code 4)",
|
||||
},
|
||||
"set policy unsupported": {
|
||||
buildRunner: func(ctrl *gomock.Controller) CmdRunner {
|
||||
runner := NewMockCmdRunner(ctrl)
|
||||
runner.EXPECT().Run(newAppendTestRuleMatcher(path)).Return("", nil)
|
||||
runner.EXPECT().Run(newDeleteTestRuleMatcher(path)).Return("", nil)
|
||||
runner.EXPECT().Run(newListInputRulesMatcher(path)).
|
||||
Return("\nChain INPUT (policy "+inputPolicy+")\nBB\n", nil)
|
||||
runner.EXPECT().Run(newSetPolicyMatcher(path, inputPolicy)).
|
||||
Return("some output", errDummy)
|
||||
return runner
|
||||
},
|
||||
unsupportedMessage: "some output (exit code 4)",
|
||||
},
|
||||
"success": {
|
||||
buildRunner: func(ctrl *gomock.Controller) CmdRunner {
|
||||
runner := NewMockCmdRunner(ctrl)
|
||||
runner.EXPECT().Run(newAppendTestRuleMatcher(path)).Return("", nil)
|
||||
runner.EXPECT().Run(newDeleteTestRuleMatcher(path)).Return("", nil)
|
||||
runner.EXPECT().Run(newListInputRulesMatcher(path)).
|
||||
Return("\nChain INPUT (policy "+inputPolicy+")\nCC\n", nil)
|
||||
runner.EXPECT().Run(newSetPolicyMatcher(path, inputPolicy)).
|
||||
Return("some output", nil)
|
||||
return runner
|
||||
},
|
||||
ok: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
runner := testCase.buildRunner(ctrl)
|
||||
|
||||
ok, unsupportedMessage, criticalErr := testIptablesPath(ctx, path, runner)
|
||||
|
||||
assert.Equal(t, testCase.ok, ok)
|
||||
assert.Equal(t, testCase.unsupportedMessage, unsupportedMessage)
|
||||
assert.ErrorIs(t, criticalErr, testCase.criticalErrWrapped)
|
||||
if testCase.criticalErrWrapped != nil {
|
||||
assert.EqualError(t, criticalErr, testCase.criticalErrMessage)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_isPermissionDenied(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
errMessage string
|
||||
ok bool
|
||||
}{
|
||||
"empty error": {},
|
||||
"other error": {
|
||||
errMessage: "some error",
|
||||
},
|
||||
"permission denied": {
|
||||
errMessage: "Permission denied (you must be root) have you tried blabla",
|
||||
ok: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ok := isPermissionDenied(testCase.errMessage)
|
||||
|
||||
assert.Equal(t, testCase.ok, ok)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_extractInputPolicy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
line string
|
||||
policy string
|
||||
ok bool
|
||||
}{
|
||||
"empty line": {},
|
||||
"random line": {
|
||||
line: "random line",
|
||||
},
|
||||
"only first part": {
|
||||
line: "Chain INPUT (policy ",
|
||||
},
|
||||
"empty policy": {
|
||||
line: "Chain INPUT (policy )",
|
||||
},
|
||||
"ACCEPT policy": {
|
||||
line: "Chain INPUT (policy ACCEPT)",
|
||||
policy: "ACCEPT",
|
||||
ok: true,
|
||||
},
|
||||
|
||||
"ACCEPT policy with surrounding garbage": {
|
||||
line: "garbage Chain INPUT (policy ACCEPT\t) )g()arbage",
|
||||
policy: "ACCEPT",
|
||||
ok: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
policy, ok := extractInputPolicy(testCase.line)
|
||||
|
||||
assert.Equal(t, testCase.policy, policy)
|
||||
assert.Equal(t, testCase.ok, ok)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_randomInterfaceName(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const expectedRegex = `^[a-z0-9]{15}$`
|
||||
interfaceName := randomInterfaceName()
|
||||
assert.Regexp(t, expectedRegex, interfaceName)
|
||||
}
|
||||
@@ -0,0 +1,98 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os"
|
||||
)
|
||||
|
||||
type tcpFlags struct {
|
||||
mask []tcpFlag
|
||||
comparison []tcpFlag
|
||||
}
|
||||
|
||||
type tcpFlag uint8
|
||||
|
||||
const (
|
||||
tcpFlagFIN tcpFlag = 1 << iota
|
||||
tcpFlagSYN
|
||||
tcpFlagRST
|
||||
tcpFlagPSH
|
||||
tcpFlagACK
|
||||
tcpFlagURG
|
||||
tcpFlagECE
|
||||
tcpFlagCWR
|
||||
)
|
||||
|
||||
func (f tcpFlag) String() string {
|
||||
switch f {
|
||||
case tcpFlagFIN:
|
||||
return "FIN"
|
||||
case tcpFlagSYN:
|
||||
return "SYN"
|
||||
case tcpFlagRST:
|
||||
return "RST"
|
||||
case tcpFlagPSH:
|
||||
return "PSH"
|
||||
case tcpFlagACK:
|
||||
return "ACK"
|
||||
case tcpFlagURG:
|
||||
return "URG"
|
||||
case tcpFlagECE:
|
||||
return "ECE"
|
||||
case tcpFlagCWR:
|
||||
return "CWR"
|
||||
default:
|
||||
panic(fmt.Sprintf("%s: %d", errTCPFlagUnknown, f))
|
||||
}
|
||||
}
|
||||
|
||||
var errTCPFlagUnknown = errors.New("unknown TCP flag")
|
||||
|
||||
func parseTCPFlag(s string) (tcpFlag, error) {
|
||||
allFlags := []tcpFlag{
|
||||
tcpFlagFIN, tcpFlagSYN, tcpFlagRST, tcpFlagPSH,
|
||||
tcpFlagACK, tcpFlagURG, tcpFlagECE, tcpFlagCWR,
|
||||
}
|
||||
for _, flag := range allFlags {
|
||||
if s == fmt.Sprintf("%#02x", uint8(flag)) || s == flag.String() {
|
||||
return flag, nil
|
||||
}
|
||||
}
|
||||
return 0, fmt.Errorf("%w: %s", errTCPFlagUnknown, s)
|
||||
}
|
||||
|
||||
var ErrMarkMatchModuleMissing = errors.New("kernel is missing the mark module libxt_mark.so")
|
||||
|
||||
// TempDropOutputTCPRST temporarily drops outgoing TCP RST packets to the specified address and port,
|
||||
// for any TCP packets not marked with the excludeMark given.
|
||||
// This is necessary for TCP path MTU discovery to work, as the kernel will try to terminate the connection
|
||||
// by sending a TCP RST packet, although we want to handle the connection manually.
|
||||
func (c *Config) TempDropOutputTCPRST(ctx context.Context,
|
||||
src, dst netip.AddrPort, excludeMark int) (
|
||||
revert func(ctx context.Context) error, err error,
|
||||
) {
|
||||
_, err = os.Stat("/usr/lib/xtables/libxt_mark.so")
|
||||
if err != nil && errors.Is(err, os.ErrNotExist) {
|
||||
return nil, fmt.Errorf("%w", ErrMarkMatchModuleMissing)
|
||||
}
|
||||
|
||||
const template = "%s OUTPUT -p tcp -s %s --sport %d -d %s --dport %d " +
|
||||
"--tcp-flags RST RST -m mark ! --mark %d -j DROP" //nolint:dupword
|
||||
instruction := fmt.Sprintf(template, "--append", src.Addr(), src.Port(), dst.Addr(), dst.Port(), excludeMark)
|
||||
revertInstruction := fmt.Sprintf(template, "--delete", src.Addr(), src.Port(), dst.Addr(), dst.Port(), excludeMark)
|
||||
run := c.runIptablesInstruction
|
||||
if dst.Addr().Is6() {
|
||||
run = c.runIP6tablesInstruction
|
||||
}
|
||||
revert = func(ctx context.Context) error {
|
||||
return run(ctx, revertInstruction)
|
||||
}
|
||||
err = run(ctx, instruction)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("running instruction: %w", err)
|
||||
}
|
||||
return revert, nil
|
||||
}
|
||||
Reference in New Issue
Block a user