mirror of
https://github.com/qdm12/gluetun.git
synced 2026-05-17 06:00:15 +02:00
chore: do not use sentinel errors when unneeded
- main reason being it's a burden to always define sentinel errors at global scope, wrap them with `%w` instead of using a string directly - only use sentinel errors when it has to be checked using `errors.Is` - replace all usage of these sentinel errors in `fmt.Errorf` with direct strings that were in the sentinel error - exclude the sentinel error definition requirement from .golangci.yml - update unit tests to use ContainersError instead of ErrorIs so it stays as a "not a change detector test" without requiring a sentinel error
This commit is contained in:
@@ -57,18 +57,15 @@ func Test_deleteIPTablesRule(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const iptablesBinary = "/sbin/iptables"
|
||||
errTest := errors.New("test error")
|
||||
|
||||
testCases := map[string]struct {
|
||||
instruction string
|
||||
makeRunner func(ctrl *gomock.Controller) *MockCmdRunner
|
||||
makeLogger func(ctrl *gomock.Controller) *MockLogger
|
||||
errWrapped error
|
||||
errMessage string
|
||||
}{
|
||||
"invalid_instruction": {
|
||||
instruction: "invalid",
|
||||
errWrapped: ErrIptablesCommandMalformed,
|
||||
errMessage: "parsing iptables command: parsing \"invalid\": " +
|
||||
"iptables command is malformed: flag \"invalid\" requires a value, but got none",
|
||||
},
|
||||
@@ -78,7 +75,7 @@ func Test_deleteIPTablesRule(t *testing.T) {
|
||||
runner := NewMockCmdRunner(ctrl)
|
||||
runner.EXPECT().
|
||||
Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")).
|
||||
Return("", errTest)
|
||||
Return("", errors.New("test error"))
|
||||
return runner
|
||||
},
|
||||
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
|
||||
@@ -86,7 +83,6 @@ func Test_deleteIPTablesRule(t *testing.T) {
|
||||
logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v")
|
||||
return logger
|
||||
},
|
||||
errWrapped: errTest,
|
||||
errMessage: `finding iptables chain rule line number: command failed: ` +
|
||||
`"/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v": test error`,
|
||||
},
|
||||
@@ -120,7 +116,7 @@ func Test_deleteIPTablesRule(t *testing.T) {
|
||||
"2 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:43716 redir ports 5678\n", //nolint:lll
|
||||
nil)
|
||||
runner.EXPECT().Run(newCmdMatcher(iptablesBinary, "^-t$", "^nat$",
|
||||
"^-D$", "^PREROUTING$", "^2$")).Return("details", errTest)
|
||||
"^-D$", "^PREROUTING$", "^2$")).Return("details", errors.New("test error"))
|
||||
return runner
|
||||
},
|
||||
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
|
||||
@@ -131,7 +127,6 @@ func Test_deleteIPTablesRule(t *testing.T) {
|
||||
logger.EXPECT().Debug("/sbin/iptables -t nat -D PREROUTING 2")
|
||||
return logger
|
||||
},
|
||||
errWrapped: errTest,
|
||||
errMessage: "command failed: \"/sbin/iptables -t nat -D PREROUTING 2\": test error: details",
|
||||
},
|
||||
"rule_found_delete_success": {
|
||||
@@ -177,9 +172,10 @@ func Test_deleteIPTablesRule(t *testing.T) {
|
||||
|
||||
err := deleteIPTablesRule(ctx, iptablesBinary, instruction, runner, logger)
|
||||
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if testCase.errWrapped != nil {
|
||||
if testCase.errMessage != "" {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -82,13 +82,11 @@ func (c *Config) runIP6tablesInstructionNoSave(ctx context.Context, instruction
|
||||
return nil
|
||||
}
|
||||
|
||||
var ErrPolicyNotValid = errors.New("policy is not valid")
|
||||
|
||||
func (c *Config) SetIPv6AllPolicies(ctx context.Context, policy string) error {
|
||||
switch policy {
|
||||
case "ACCEPT", "DROP":
|
||||
default:
|
||||
return fmt.Errorf("%w: %s", ErrPolicyNotValid, policy)
|
||||
return fmt.Errorf("policy is not valid: %s", policy)
|
||||
}
|
||||
return c.runIP6tablesInstructions(ctx, []string{
|
||||
"--policy INPUT " + policy,
|
||||
|
||||
@@ -2,7 +2,6 @@ package iptables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
@@ -13,10 +12,8 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrIPTablesVersionTooShort = errors.New("iptables version string is too short")
|
||||
ErrPolicyUnknown = errors.New("unknown policy")
|
||||
ErrNeedIP6Tables = errors.New("ip6tables is required, please upgrade your kernel to support it")
|
||||
const (
|
||||
needIP6Tables = "ip6tables is required, please upgrade your kernel"
|
||||
)
|
||||
|
||||
func appendOrDelete(remove bool) string {
|
||||
@@ -36,7 +33,7 @@ func (c *Config) Version(ctx context.Context) (string, error) {
|
||||
words := strings.Fields(output)
|
||||
const minWords = 2
|
||||
if len(words) < minWords {
|
||||
return "", fmt.Errorf("%w: %s", ErrIPTablesVersionTooShort, output)
|
||||
return "", fmt.Errorf("iptables version string is too short: %s", output)
|
||||
}
|
||||
return "iptables " + words[1], nil
|
||||
}
|
||||
@@ -102,7 +99,7 @@ func (c *Config) SetIPv4AllPolicies(ctx context.Context, policy string) error {
|
||||
switch policy {
|
||||
case "ACCEPT", "DROP":
|
||||
default:
|
||||
return fmt.Errorf("%w: %s", ErrPolicyUnknown, policy)
|
||||
return fmt.Errorf("unknown policy: %s", policy)
|
||||
}
|
||||
return c.runIptablesInstructions(ctx, []string{
|
||||
"--policy INPUT " + policy,
|
||||
@@ -129,7 +126,7 @@ func (c *Config) AcceptInputToSubnet(ctx context.Context, intf string, destinati
|
||||
return c.runIptablesInstruction(ctx, instruction)
|
||||
}
|
||||
if c.ip6Tables == "" {
|
||||
return fmt.Errorf("accept input to subnet %s: %w", destination, ErrNeedIP6Tables)
|
||||
return fmt.Errorf("accept input to subnet %s: %s", destination, needIP6Tables)
|
||||
}
|
||||
return c.runIP6tablesInstruction(ctx, instruction)
|
||||
}
|
||||
@@ -157,7 +154,7 @@ func (c *Config) AcceptOutputTrafficToVPN(ctx context.Context,
|
||||
if connection.IP.Is4() {
|
||||
return c.runIptablesInstruction(ctx, instruction)
|
||||
} else if c.ip6Tables == "" {
|
||||
return fmt.Errorf("accept output to VPN server: %w", ErrNeedIP6Tables)
|
||||
return fmt.Errorf("accept output to VPN server %s: %s", connection.IP, needIP6Tables)
|
||||
}
|
||||
return c.runIP6tablesInstruction(ctx, instruction)
|
||||
}
|
||||
@@ -175,7 +172,7 @@ func (c *Config) AcceptOutput(ctx context.Context,
|
||||
if ip.Is4() {
|
||||
return c.runIptablesInstruction(ctx, instruction)
|
||||
} else if c.ip6Tables == "" {
|
||||
return fmt.Errorf("accept output to VPN server: %w", ErrNeedIP6Tables)
|
||||
return fmt.Errorf("accept output to VPN server %s: %s", ip, needIP6Tables)
|
||||
}
|
||||
return c.runIP6tablesInstruction(ctx, instruction)
|
||||
}
|
||||
@@ -200,7 +197,7 @@ func (c *Config) AcceptOutputFromIPToSubnet(ctx context.Context,
|
||||
if doIPv4 {
|
||||
return c.runIptablesInstruction(ctx, instruction)
|
||||
} else if c.ip6Tables == "" {
|
||||
return fmt.Errorf("accept output from %s to %s: %w", sourceIP, destinationSubnet, ErrNeedIP6Tables)
|
||||
return fmt.Errorf("accept output from %s to %s: %s", sourceIP, destinationSubnet, needIP6Tables)
|
||||
}
|
||||
return c.runIP6tablesInstruction(ctx, instruction)
|
||||
}
|
||||
@@ -350,7 +347,7 @@ func (c *Config) RunUserPostRules(ctx context.Context, filepath string) error {
|
||||
case ipv4:
|
||||
err = c.runIptablesInstructionNoSave(ctx, rule)
|
||||
case c.ip6Tables == "":
|
||||
err = fmt.Errorf("running user ip6tables rule: %w", ErrNeedIP6Tables)
|
||||
err = fmt.Errorf("running user ip6tables rule: %s", needIP6Tables)
|
||||
default: // ipv6
|
||||
err = c.runIP6tablesInstructionNoSave(ctx, rule)
|
||||
}
|
||||
|
||||
@@ -40,8 +40,6 @@ type mark struct {
|
||||
value uint
|
||||
}
|
||||
|
||||
var ErrChainListMalformed = errors.New("iptables chain list output is malformed")
|
||||
|
||||
func parseChain(iptablesOutput string) (c chain, err error) {
|
||||
// Text example:
|
||||
// Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
|
||||
@@ -63,8 +61,8 @@ func parseChain(iptablesOutput string) (c chain, err error) {
|
||||
|
||||
const minLines = 2 // chain general information line + legend line
|
||||
if len(lines) < minLines {
|
||||
return chain{}, fmt.Errorf("%w: not enough lines to process in: %s",
|
||||
ErrChainListMalformed, iptablesOutput)
|
||||
return chain{}, fmt.Errorf("iptables chain list output is malformed: not enough lines to process in: %s",
|
||||
iptablesOutput)
|
||||
}
|
||||
|
||||
c, err = parseChainGeneralDataLine(lines[0])
|
||||
@@ -77,8 +75,8 @@ func parseChain(iptablesOutput string) (c chain, err error) {
|
||||
legendLine := strings.TrimSpace(lines[1])
|
||||
legendFields := strings.Fields(legendLine)
|
||||
if !slices.Equal(expectedLegendFields, legendFields) {
|
||||
return chain{}, fmt.Errorf("%w: legend %q is not the expected %q",
|
||||
ErrChainListMalformed, legendLine, strings.Join(expectedLegendFields, " "))
|
||||
return chain{}, fmt.Errorf("iptables chain list output is malformed: legend %q is not the expected %q",
|
||||
legendLine, strings.Join(expectedLegendFields, " "))
|
||||
}
|
||||
|
||||
lines = lines[2:] // remove chain general information line and legend line
|
||||
@@ -111,8 +109,8 @@ func parseChainGeneralDataLine(line string) (base chain, err error) {
|
||||
fields := strings.Fields(line)
|
||||
const expectedNumberOfFields = 8
|
||||
if len(fields) != expectedNumberOfFields {
|
||||
return chain{}, fmt.Errorf("%w: expected %d fields in %q",
|
||||
ErrChainListMalformed, expectedNumberOfFields, line)
|
||||
return chain{}, fmt.Errorf("iptables chain list output is malformed: expected %d fields in %q",
|
||||
expectedNumberOfFields, line)
|
||||
}
|
||||
|
||||
// Sanity checks
|
||||
@@ -126,8 +124,8 @@ func parseChainGeneralDataLine(line string) (base chain, err error) {
|
||||
if fields[index] == expectedValue {
|
||||
continue
|
||||
}
|
||||
return chain{}, fmt.Errorf("%w: expected %q for field %d in %q",
|
||||
ErrChainListMalformed, expectedValue, index, line)
|
||||
return chain{}, fmt.Errorf("iptables chain list output is malformed: expected %q for field %d in %q",
|
||||
expectedValue, index, line)
|
||||
}
|
||||
|
||||
base.name = fields[1] // chain name could be custom
|
||||
@@ -152,19 +150,17 @@ func parseChainGeneralDataLine(line string) (base chain, err error) {
|
||||
return base, nil
|
||||
}
|
||||
|
||||
var ErrChainRuleMalformed = errors.New("chain rule is malformed")
|
||||
|
||||
func parseChainRuleLine(line string) (rule chainRule, err error) {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
return chainRule{}, fmt.Errorf("%w: empty line", ErrChainRuleMalformed)
|
||||
return chainRule{}, errors.New("chain rule is malformed: empty line")
|
||||
}
|
||||
|
||||
fields := strings.Fields(line)
|
||||
|
||||
const minFields = 10
|
||||
if len(fields) < minFields {
|
||||
return chainRule{}, fmt.Errorf("%w: not enough fields", ErrChainRuleMalformed)
|
||||
return chainRule{}, errors.New("chain rule is malformed: not enough fields")
|
||||
}
|
||||
|
||||
for fieldIndex, field := range fields[:minFields] {
|
||||
@@ -186,7 +182,7 @@ func parseChainRuleLine(line string) (rule chainRule, err error) {
|
||||
|
||||
func parseChainRuleField(fieldIndex int, field string, rule *chainRule) (err error) {
|
||||
if field == "" {
|
||||
return fmt.Errorf("%w: empty field at index %d", ErrChainRuleMalformed, fieldIndex)
|
||||
return fmt.Errorf("chain rule is malformed: empty field at index %d", fieldIndex)
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -278,8 +274,8 @@ func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err
|
||||
rule.redirPorts = ports
|
||||
i++
|
||||
default:
|
||||
return fmt.Errorf("%w: unexpected %q after redir",
|
||||
ErrChainRuleMalformed, optionalFields[1])
|
||||
return fmt.Errorf("chain rule is malformed: unexpected %q after redir",
|
||||
optionalFields[1])
|
||||
}
|
||||
case "ctstate":
|
||||
i++
|
||||
@@ -294,15 +290,13 @@ func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err
|
||||
rule.mark = mark
|
||||
i += consumed
|
||||
default:
|
||||
return fmt.Errorf("%w: unexpected optional field: %s",
|
||||
ErrChainRuleMalformed, optionalFields[i])
|
||||
return fmt.Errorf("chain rule is malformed: unexpected optional field: %s",
|
||||
optionalFields[i])
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var errUDPOptionalUnknown = errors.New("unknown UDP optional field")
|
||||
|
||||
func parseUDPOptional(optionalFields []string, rule *chainRule) (consumed int, err error) {
|
||||
for _, value := range optionalFields {
|
||||
if !strings.ContainsRune(value, ':') {
|
||||
@@ -323,14 +317,12 @@ func parseUDPOptional(optionalFields []string, rule *chainRule) (consumed int, e
|
||||
}
|
||||
consumed++
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: %s", errUDPOptionalUnknown, value)
|
||||
return 0, fmt.Errorf("unknown UDP optional field: %s", value)
|
||||
}
|
||||
}
|
||||
return consumed, nil
|
||||
}
|
||||
|
||||
var errTCPOptionalUnknown = errors.New("unknown TCP optional field")
|
||||
|
||||
func parseTCPOptional(optionalFields []string, rule *chainRule) (consumed int, err error) {
|
||||
for _, value := range optionalFields {
|
||||
if !strings.ContainsRune(value, ':') {
|
||||
@@ -357,7 +349,7 @@ func parseTCPOptional(optionalFields []string, rule *chainRule) (consumed int, e
|
||||
}
|
||||
consumed++
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: %s", errTCPOptionalUnknown, value)
|
||||
return 0, fmt.Errorf("unknown TCP optional field: %s", value)
|
||||
}
|
||||
}
|
||||
return consumed, nil
|
||||
@@ -373,15 +365,13 @@ func parseSourcePort(value string) (port uint16, err error) {
|
||||
return parsePort(value)
|
||||
}
|
||||
|
||||
var errTCPFlagsMalformed = errors.New("TCP flags are malformed")
|
||||
|
||||
func parseTCPFlags(value string) (tcpFlags, error) {
|
||||
value = strings.TrimPrefix(value, "flags:")
|
||||
fields := strings.Split(value, "/")
|
||||
const expectedFields = 2
|
||||
if len(fields) != expectedFields {
|
||||
return tcpFlags{}, fmt.Errorf("%w: expected format 'flags:<mask>/<comparison>' in %q",
|
||||
errTCPFlagsMalformed, value)
|
||||
return tcpFlags{}, fmt.Errorf("TCP flags are malformed: expected format 'flags:<mask>/<comparison>' in %q",
|
||||
value)
|
||||
}
|
||||
maskFlags := strings.Split(fields[0], ",")
|
||||
mask := make([]tcpFlag, len(maskFlags))
|
||||
@@ -422,8 +412,6 @@ func parsePortsCSV(s string) (ports []uint16, err error) {
|
||||
return ports, nil
|
||||
}
|
||||
|
||||
var errMarkValueMalformed = errors.New("mark value is malformed")
|
||||
|
||||
func parseMark(optionalFields []string) (m mark, consumed int, err error) {
|
||||
switch optionalFields[consumed] {
|
||||
case "match":
|
||||
@@ -437,42 +425,36 @@ func parseMark(optionalFields []string) (m mark, consumed int, err error) {
|
||||
const bits = 32
|
||||
value, err := strconv.ParseUint(optionalFields[consumed], base, bits)
|
||||
if err != nil {
|
||||
return mark{}, 0, fmt.Errorf("%w: %s", errMarkValueMalformed, optionalFields[consumed])
|
||||
return mark{}, 0, fmt.Errorf("mark value is malformed: %s", optionalFields[consumed])
|
||||
}
|
||||
m.value = uint(value)
|
||||
consumed++
|
||||
default:
|
||||
return mark{}, 0, fmt.Errorf("%w: unexpected mark mode field: %s",
|
||||
ErrChainRuleMalformed, optionalFields[consumed])
|
||||
return mark{}, 0, fmt.Errorf("chain rule is malformed: unexpected mark mode field: %s",
|
||||
optionalFields[consumed])
|
||||
}
|
||||
return m, consumed, nil
|
||||
}
|
||||
|
||||
var ErrLineNumberIsZero = errors.New("line number is zero")
|
||||
|
||||
func parseLineNumber(s string) (n uint16, err error) {
|
||||
const base, bitLength = 10, 16
|
||||
lineNumber, err := strconv.ParseUint(s, base, bitLength)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
} else if lineNumber == 0 {
|
||||
return 0, fmt.Errorf("%w", ErrLineNumberIsZero)
|
||||
return 0, errors.New("line number is zero")
|
||||
}
|
||||
return uint16(lineNumber), nil
|
||||
}
|
||||
|
||||
var ErrTargetUnknown = errors.New("unknown target")
|
||||
|
||||
func checkTarget(target string) (err error) {
|
||||
switch target {
|
||||
case "ACCEPT", "DROP", "REJECT", "REDIRECT":
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("%w: %s", ErrTargetUnknown, target)
|
||||
return fmt.Errorf("unknown target: %s", target)
|
||||
}
|
||||
|
||||
var ErrProtocolUnknown = errors.New("unknown protocol")
|
||||
|
||||
func parseProtocol(s string) (protocol string, err error) {
|
||||
switch s {
|
||||
case "0", "all":
|
||||
@@ -483,18 +465,16 @@ func parseProtocol(s string) (protocol string, err error) {
|
||||
case "17", "udp":
|
||||
protocol = "udp"
|
||||
default:
|
||||
return "", fmt.Errorf("%w: %s", ErrProtocolUnknown, s)
|
||||
return "", fmt.Errorf("unknown protocol: %s", s)
|
||||
}
|
||||
return protocol, nil
|
||||
}
|
||||
|
||||
var ErrMetricSizeMalformed = errors.New("metric size is malformed")
|
||||
|
||||
// parseMetricSize parses a metric size string like 140K or 226M and
|
||||
// returns the raw integer matching it.
|
||||
func parseMetricSize(size string) (n uint64, err error) {
|
||||
if size == "" {
|
||||
return n, fmt.Errorf("%w: empty string", ErrMetricSizeMalformed)
|
||||
return n, errors.New("metric size is malformed: empty string")
|
||||
}
|
||||
|
||||
//nolint:mnd
|
||||
@@ -516,7 +496,7 @@ func parseMetricSize(size string) (n uint64, err error) {
|
||||
const base, bitLength = 10, 64
|
||||
n, err = strconv.ParseUint(size, base, bitLength)
|
||||
if err != nil {
|
||||
return n, fmt.Errorf("%w: %w", ErrMetricSizeMalformed, err)
|
||||
return n, fmt.Errorf("metric size is malformed: %w", err)
|
||||
}
|
||||
n *= multiplier
|
||||
return n, nil
|
||||
|
||||
@@ -13,30 +13,25 @@ func Test_parseChain(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
iptablesOutput string
|
||||
table chain
|
||||
errWrapped error
|
||||
errMessage string
|
||||
}{
|
||||
"no_output": {
|
||||
errWrapped: ErrChainListMalformed,
|
||||
errMessage: "iptables chain list output is malformed: not enough lines to process in: ",
|
||||
},
|
||||
"single_line_only": {
|
||||
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)`,
|
||||
errWrapped: ErrChainListMalformed,
|
||||
errMessage: "iptables chain list output is malformed: not enough lines to process in: " +
|
||||
"Chain INPUT (policy ACCEPT 140K packets, 226M bytes)",
|
||||
},
|
||||
"malformed_general_data_line": {
|
||||
iptablesOutput: `Chain INPUT
|
||||
num pkts bytes target prot opt in out source destination`,
|
||||
errWrapped: ErrChainListMalformed,
|
||||
errMessage: "parsing chain general data line: iptables chain list output is malformed: " +
|
||||
"expected 8 fields in \"Chain INPUT\"",
|
||||
},
|
||||
"malformed_legend": {
|
||||
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
|
||||
num pkts bytes target prot opt in out source`,
|
||||
errWrapped: ErrChainListMalformed,
|
||||
errMessage: "iptables chain list output is malformed: legend " +
|
||||
"\"num pkts bytes target prot opt in out source\" " +
|
||||
"is not the expected \"num pkts bytes target prot opt in out source destination\"",
|
||||
@@ -135,9 +130,10 @@ num pkts bytes target prot opt in out source destinati
|
||||
table, err := parseChain(testCase.iptablesOutput)
|
||||
|
||||
assert.Equal(t, testCase.table, table)
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if testCase.errWrapped != nil {
|
||||
if testCase.errMessage != "" {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -80,11 +80,9 @@ func ipPrefixesEqual(instruction, chainRule netip.Prefix) bool {
|
||||
(!instruction.IsValid() && chainRule.Bits() == 0 && chainRule.Addr().IsUnspecified())
|
||||
}
|
||||
|
||||
var ErrIptablesCommandMalformed = errors.New("iptables command is malformed")
|
||||
|
||||
func parseIptablesInstruction(s string) (instruction iptablesInstruction, err error) {
|
||||
if s == "" {
|
||||
return iptablesInstruction{}, fmt.Errorf("%w: empty instruction", ErrIptablesCommandMalformed)
|
||||
return iptablesInstruction{}, errors.New("iptables command is malformed: empty instruction")
|
||||
}
|
||||
fields := strings.Fields(s)
|
||||
|
||||
@@ -173,7 +171,7 @@ func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (co
|
||||
return 0, fmt.Errorf("parsing TCP flags: %w", err)
|
||||
}
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: unknown key %q", ErrIptablesCommandMalformed, flag)
|
||||
return 0, fmt.Errorf("iptables command is malformed: unknown key %q", flag)
|
||||
}
|
||||
return consumed, nil
|
||||
}
|
||||
@@ -185,15 +183,15 @@ func preCheckInstructionFields(fields []string) (consumed int, err error) {
|
||||
case "--tcp-flags": // -m can have 1 or 2 values
|
||||
const expected = 3
|
||||
if len(fields) < expected {
|
||||
return 0, fmt.Errorf("%w: flag %q requires at least 2 values, but got %s",
|
||||
ErrIptablesCommandMalformed, flag, strings.Join(fields, " "))
|
||||
return 0, fmt.Errorf("iptables command is malformed: flag %q requires at least 2 values, but got %s",
|
||||
flag, strings.Join(fields, " "))
|
||||
}
|
||||
return expected, nil
|
||||
default:
|
||||
const expected = 2
|
||||
if len(fields) < expected {
|
||||
return 0, fmt.Errorf("%w: flag %q requires a value, but got none",
|
||||
ErrIptablesCommandMalformed, flag)
|
||||
return 0, fmt.Errorf("iptables command is malformed: flag %q requires a value, but got none",
|
||||
flag)
|
||||
}
|
||||
return expected, nil
|
||||
}
|
||||
@@ -239,12 +237,12 @@ func parseMatchModule(fields []string, instruction *iptablesInstruction) (
|
||||
consumed++
|
||||
instruction.mark.invert = true
|
||||
default:
|
||||
return consumed, fmt.Errorf("%w: unsupported match mark with value: %s",
|
||||
ErrIptablesCommandMalformed, fields[2])
|
||||
return consumed, fmt.Errorf("iptables command is malformed: unsupported match mark with value: %s",
|
||||
fields[2])
|
||||
}
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: unknown match value: %s",
|
||||
ErrIptablesCommandMalformed, fields[consumed])
|
||||
return 0, fmt.Errorf("iptables command is malformed: unknown match value: %s",
|
||||
fields[consumed])
|
||||
}
|
||||
return consumed, nil
|
||||
}
|
||||
|
||||
@@ -13,21 +13,17 @@ func Test_parseIptablesInstruction(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
s string
|
||||
instruction iptablesInstruction
|
||||
errWrapped error
|
||||
errMessage string
|
||||
}{
|
||||
"no_instruction": {
|
||||
errWrapped: ErrIptablesCommandMalformed,
|
||||
errMessage: "iptables command is malformed: empty instruction",
|
||||
},
|
||||
"uneven_fields": {
|
||||
s: "-A",
|
||||
errWrapped: ErrIptablesCommandMalformed,
|
||||
errMessage: "parsing \"-A\": iptables command is malformed: flag \"-A\" requires a value, but got none",
|
||||
},
|
||||
"unknown_key": {
|
||||
s: "-x something",
|
||||
errWrapped: ErrIptablesCommandMalformed,
|
||||
errMessage: "parsing \"-x something\": iptables command is malformed: unknown key \"-x\"",
|
||||
},
|
||||
"one_pair": {
|
||||
@@ -74,9 +70,10 @@ func Test_parseIptablesInstruction(t *testing.T) {
|
||||
rule, err := parseIptablesInstruction(testCase.s)
|
||||
|
||||
assert.Equal(t, testCase.instruction, rule)
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if testCase.errWrapped != nil {
|
||||
if testCase.errMessage != "" {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -10,12 +10,7 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNetAdminMissing = errors.New("NET_ADMIN capability is missing")
|
||||
ErrTestRuleCleanup = errors.New("failed cleaning up test rule")
|
||||
ErrInputPolicyNotFound = errors.New("input policy not found")
|
||||
ErrNotSupported = errors.New("no iptables supported found")
|
||||
)
|
||||
var ErrNotSupported = errors.New("no iptables supported found")
|
||||
|
||||
func checkIptablesSupport(ctx context.Context, runner CmdRunner,
|
||||
iptablesPathsToTry ...string,
|
||||
@@ -53,7 +48,7 @@ func checkIptablesSupport(ctx context.Context, runner CmdRunner,
|
||||
if allArePermissionDenied {
|
||||
// If the error is related to a denied permission for all iptables path,
|
||||
// return an error describing what to do from an end-user perspective.
|
||||
return "", fmt.Errorf("%w: %s", ErrNetAdminMissing, strings.Join(allUnsupportedMessages, "; "))
|
||||
return "", fmt.Errorf("NET_ADMIN capability is missing: %s", strings.Join(allUnsupportedMessages, "; "))
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("%w: errors encountered are: %s",
|
||||
@@ -85,7 +80,7 @@ func testIptablesPath(ctx context.Context, path string,
|
||||
output, err = runner.Run(cmd)
|
||||
if err != nil {
|
||||
// this is a critical error, we want to make sure our test rule gets removed.
|
||||
criticalErr = fmt.Errorf("%w: %s (%s)", ErrTestRuleCleanup, output, err)
|
||||
criticalErr = fmt.Errorf("failed cleaning up test rule: %s (%s)", output, err)
|
||||
return false, "", criticalErr
|
||||
}
|
||||
|
||||
@@ -108,7 +103,7 @@ func testIptablesPath(ctx context.Context, path string,
|
||||
}
|
||||
|
||||
if inputPolicy == "" {
|
||||
criticalErr = fmt.Errorf("%w: in INPUT rules: %s", ErrInputPolicyNotFound, output)
|
||||
criticalErr = fmt.Errorf("input policy not found: in INPUT rules: %s", output)
|
||||
return false, "", criticalErr
|
||||
}
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newAppendTestRuleMatcher(path string) *cmdMatcher {
|
||||
@@ -43,7 +42,6 @@ func Test_checkIptablesSupport(t *testing.T) {
|
||||
buildRunner func(ctrl *gomock.Controller) CmdRunner
|
||||
iptablesPathsToTry []string
|
||||
iptablesPath string
|
||||
errSentinel error
|
||||
errMessage string
|
||||
}{
|
||||
"critical error when checking": {
|
||||
@@ -56,7 +54,6 @@ func Test_checkIptablesSupport(t *testing.T) {
|
||||
return runner
|
||||
},
|
||||
iptablesPathsToTry: []string{"path1", "path2"},
|
||||
errSentinel: ErrTestRuleCleanup,
|
||||
errMessage: "for path1: failed cleaning up test rule: " +
|
||||
"output (exit code 4)",
|
||||
},
|
||||
@@ -86,7 +83,6 @@ func Test_checkIptablesSupport(t *testing.T) {
|
||||
return runner
|
||||
},
|
||||
iptablesPathsToTry: []string{"path1", "path2"},
|
||||
errSentinel: ErrNetAdminMissing,
|
||||
errMessage: "NET_ADMIN capability is missing: " +
|
||||
"path1: Permission denied (you must be root) more context (exit code 4); " +
|
||||
"path2: context: Permission denied (you must be root) (exit code 4)",
|
||||
@@ -101,7 +97,6 @@ func Test_checkIptablesSupport(t *testing.T) {
|
||||
return runner
|
||||
},
|
||||
iptablesPathsToTry: []string{"path1", "path2"},
|
||||
errSentinel: ErrNotSupported,
|
||||
errMessage: "no iptables supported found: " +
|
||||
"errors encountered are: " +
|
||||
"path1: output 1 (exit code 4); " +
|
||||
@@ -118,9 +113,10 @@ func Test_checkIptablesSupport(t *testing.T) {
|
||||
|
||||
iptablesPath, err := checkIptablesSupport(ctx, runner, testCase.iptablesPathsToTry...)
|
||||
|
||||
require.ErrorIs(t, err, testCase.errSentinel)
|
||||
if testCase.errSentinel != nil {
|
||||
if testCase.errMessage != "" {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, testCase.iptablesPath, iptablesPath)
|
||||
})
|
||||
@@ -139,7 +135,6 @@ func Test_testIptablesPath(t *testing.T) {
|
||||
buildRunner func(ctrl *gomock.Controller) CmdRunner
|
||||
ok bool
|
||||
unsupportedMessage string
|
||||
criticalErrWrapped error
|
||||
criticalErrMessage string
|
||||
}{
|
||||
"append test rule permission denied": {
|
||||
@@ -168,7 +163,6 @@ func Test_testIptablesPath(t *testing.T) {
|
||||
Return("some output", errDummy)
|
||||
return runner
|
||||
},
|
||||
criticalErrWrapped: ErrTestRuleCleanup,
|
||||
criticalErrMessage: "failed cleaning up test rule: some output (exit code 4)",
|
||||
},
|
||||
"list input rules permission denied": {
|
||||
@@ -202,7 +196,6 @@ func Test_testIptablesPath(t *testing.T) {
|
||||
Return("some\noutput", nil)
|
||||
return runner
|
||||
},
|
||||
criticalErrWrapped: ErrInputPolicyNotFound,
|
||||
criticalErrMessage: "input policy not found: in INPUT rules: some\noutput",
|
||||
},
|
||||
"set policy permission denied": {
|
||||
@@ -257,9 +250,10 @@ func Test_testIptablesPath(t *testing.T) {
|
||||
|
||||
assert.Equal(t, testCase.ok, ok)
|
||||
assert.Equal(t, testCase.unsupportedMessage, unsupportedMessage)
|
||||
assert.ErrorIs(t, criticalErr, testCase.criticalErrWrapped)
|
||||
if testCase.criticalErrWrapped != nil {
|
||||
if testCase.criticalErrMessage != "" {
|
||||
assert.EqualError(t, criticalErr, testCase.criticalErrMessage)
|
||||
} else {
|
||||
assert.NoError(t, criticalErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -45,12 +45,10 @@ func (f tcpFlag) String() string {
|
||||
case tcpFlagCWR:
|
||||
return "CWR"
|
||||
default:
|
||||
panic(fmt.Sprintf("%s: %d", errTCPFlagUnknown, f))
|
||||
panic(fmt.Sprintf("unknown TCP flag: %d", f))
|
||||
}
|
||||
}
|
||||
|
||||
var errTCPFlagUnknown = errors.New("unknown TCP flag")
|
||||
|
||||
func parseTCPFlag(s string) (tcpFlag, error) {
|
||||
allFlags := []tcpFlag{
|
||||
tcpFlagFIN, tcpFlagSYN, tcpFlagRST, tcpFlagPSH,
|
||||
@@ -61,7 +59,7 @@ func parseTCPFlag(s string) (tcpFlag, error) {
|
||||
return flag, nil
|
||||
}
|
||||
}
|
||||
return 0, fmt.Errorf("%w: %s", errTCPFlagUnknown, s)
|
||||
return 0, fmt.Errorf("unknown TCP flag: %s", s)
|
||||
}
|
||||
|
||||
var ErrMarkMatchModuleMissing = errors.New("kernel is missing the mark module libxt_mark.so")
|
||||
|
||||
Reference in New Issue
Block a user