chore: do not use sentinel errors when unneeded

- main reason being it's a burden to always define sentinel errors at global scope, wrap them with `%w` instead of using a string directly
- only use sentinel errors when it has to be checked using `errors.Is`
- replace all usage of these sentinel errors in `fmt.Errorf` with direct strings that were in the sentinel error
- exclude the sentinel error definition requirement from .golangci.yml
- update unit tests to use ContainersError instead of ErrorIs so it stays as a "not a change detector test" without requiring a sentinel error
This commit is contained in:
Quentin McGaw
2026-05-02 00:50:16 +00:00
parent 9b6f048fe8
commit 4a78989d9d
172 changed files with 666 additions and 1433 deletions
+5 -9
View File
@@ -57,18 +57,15 @@ func Test_deleteIPTablesRule(t *testing.T) {
t.Parallel()
const iptablesBinary = "/sbin/iptables"
errTest := errors.New("test error")
testCases := map[string]struct {
instruction string
makeRunner func(ctrl *gomock.Controller) *MockCmdRunner
makeLogger func(ctrl *gomock.Controller) *MockLogger
errWrapped error
errMessage string
}{
"invalid_instruction": {
instruction: "invalid",
errWrapped: ErrIptablesCommandMalformed,
errMessage: "parsing iptables command: parsing \"invalid\": " +
"iptables command is malformed: flag \"invalid\" requires a value, but got none",
},
@@ -78,7 +75,7 @@ func Test_deleteIPTablesRule(t *testing.T) {
runner := NewMockCmdRunner(ctrl)
runner.EXPECT().
Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")).
Return("", errTest)
Return("", errors.New("test error"))
return runner
},
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
@@ -86,7 +83,6 @@ func Test_deleteIPTablesRule(t *testing.T) {
logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v")
return logger
},
errWrapped: errTest,
errMessage: `finding iptables chain rule line number: command failed: ` +
`"/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v": test error`,
},
@@ -120,7 +116,7 @@ func Test_deleteIPTablesRule(t *testing.T) {
"2 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:43716 redir ports 5678\n", //nolint:lll
nil)
runner.EXPECT().Run(newCmdMatcher(iptablesBinary, "^-t$", "^nat$",
"^-D$", "^PREROUTING$", "^2$")).Return("details", errTest)
"^-D$", "^PREROUTING$", "^2$")).Return("details", errors.New("test error"))
return runner
},
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
@@ -131,7 +127,6 @@ func Test_deleteIPTablesRule(t *testing.T) {
logger.EXPECT().Debug("/sbin/iptables -t nat -D PREROUTING 2")
return logger
},
errWrapped: errTest,
errMessage: "command failed: \"/sbin/iptables -t nat -D PREROUTING 2\": test error: details",
},
"rule_found_delete_success": {
@@ -177,9 +172,10 @@ func Test_deleteIPTablesRule(t *testing.T) {
err := deleteIPTablesRule(ctx, iptablesBinary, instruction, runner, logger)
assert.ErrorIs(t, err, testCase.errWrapped)
if testCase.errWrapped != nil {
if testCase.errMessage != "" {
assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
}
})
}
+1 -3
View File
@@ -82,13 +82,11 @@ func (c *Config) runIP6tablesInstructionNoSave(ctx context.Context, instruction
return nil
}
var ErrPolicyNotValid = errors.New("policy is not valid")
func (c *Config) SetIPv6AllPolicies(ctx context.Context, policy string) error {
switch policy {
case "ACCEPT", "DROP":
default:
return fmt.Errorf("%w: %s", ErrPolicyNotValid, policy)
return fmt.Errorf("policy is not valid: %s", policy)
}
return c.runIP6tablesInstructions(ctx, []string{
"--policy INPUT " + policy,
+9 -12
View File
@@ -2,7 +2,6 @@ package iptables
import (
"context"
"errors"
"fmt"
"io"
"net/netip"
@@ -13,10 +12,8 @@ import (
"github.com/qdm12/gluetun/internal/models"
)
var (
ErrIPTablesVersionTooShort = errors.New("iptables version string is too short")
ErrPolicyUnknown = errors.New("unknown policy")
ErrNeedIP6Tables = errors.New("ip6tables is required, please upgrade your kernel to support it")
const (
needIP6Tables = "ip6tables is required, please upgrade your kernel"
)
func appendOrDelete(remove bool) string {
@@ -36,7 +33,7 @@ func (c *Config) Version(ctx context.Context) (string, error) {
words := strings.Fields(output)
const minWords = 2
if len(words) < minWords {
return "", fmt.Errorf("%w: %s", ErrIPTablesVersionTooShort, output)
return "", fmt.Errorf("iptables version string is too short: %s", output)
}
return "iptables " + words[1], nil
}
@@ -102,7 +99,7 @@ func (c *Config) SetIPv4AllPolicies(ctx context.Context, policy string) error {
switch policy {
case "ACCEPT", "DROP":
default:
return fmt.Errorf("%w: %s", ErrPolicyUnknown, policy)
return fmt.Errorf("unknown policy: %s", policy)
}
return c.runIptablesInstructions(ctx, []string{
"--policy INPUT " + policy,
@@ -129,7 +126,7 @@ func (c *Config) AcceptInputToSubnet(ctx context.Context, intf string, destinati
return c.runIptablesInstruction(ctx, instruction)
}
if c.ip6Tables == "" {
return fmt.Errorf("accept input to subnet %s: %w", destination, ErrNeedIP6Tables)
return fmt.Errorf("accept input to subnet %s: %s", destination, needIP6Tables)
}
return c.runIP6tablesInstruction(ctx, instruction)
}
@@ -157,7 +154,7 @@ func (c *Config) AcceptOutputTrafficToVPN(ctx context.Context,
if connection.IP.Is4() {
return c.runIptablesInstruction(ctx, instruction)
} else if c.ip6Tables == "" {
return fmt.Errorf("accept output to VPN server: %w", ErrNeedIP6Tables)
return fmt.Errorf("accept output to VPN server %s: %s", connection.IP, needIP6Tables)
}
return c.runIP6tablesInstruction(ctx, instruction)
}
@@ -175,7 +172,7 @@ func (c *Config) AcceptOutput(ctx context.Context,
if ip.Is4() {
return c.runIptablesInstruction(ctx, instruction)
} else if c.ip6Tables == "" {
return fmt.Errorf("accept output to VPN server: %w", ErrNeedIP6Tables)
return fmt.Errorf("accept output to VPN server %s: %s", ip, needIP6Tables)
}
return c.runIP6tablesInstruction(ctx, instruction)
}
@@ -200,7 +197,7 @@ func (c *Config) AcceptOutputFromIPToSubnet(ctx context.Context,
if doIPv4 {
return c.runIptablesInstruction(ctx, instruction)
} else if c.ip6Tables == "" {
return fmt.Errorf("accept output from %s to %s: %w", sourceIP, destinationSubnet, ErrNeedIP6Tables)
return fmt.Errorf("accept output from %s to %s: %s", sourceIP, destinationSubnet, needIP6Tables)
}
return c.runIP6tablesInstruction(ctx, instruction)
}
@@ -350,7 +347,7 @@ func (c *Config) RunUserPostRules(ctx context.Context, filepath string) error {
case ipv4:
err = c.runIptablesInstructionNoSave(ctx, rule)
case c.ip6Tables == "":
err = fmt.Errorf("running user ip6tables rule: %w", ErrNeedIP6Tables)
err = fmt.Errorf("running user ip6tables rule: %s", needIP6Tables)
default: // ipv6
err = c.runIP6tablesInstructionNoSave(ctx, rule)
}
+27 -47
View File
@@ -40,8 +40,6 @@ type mark struct {
value uint
}
var ErrChainListMalformed = errors.New("iptables chain list output is malformed")
func parseChain(iptablesOutput string) (c chain, err error) {
// Text example:
// Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
@@ -63,8 +61,8 @@ func parseChain(iptablesOutput string) (c chain, err error) {
const minLines = 2 // chain general information line + legend line
if len(lines) < minLines {
return chain{}, fmt.Errorf("%w: not enough lines to process in: %s",
ErrChainListMalformed, iptablesOutput)
return chain{}, fmt.Errorf("iptables chain list output is malformed: not enough lines to process in: %s",
iptablesOutput)
}
c, err = parseChainGeneralDataLine(lines[0])
@@ -77,8 +75,8 @@ func parseChain(iptablesOutput string) (c chain, err error) {
legendLine := strings.TrimSpace(lines[1])
legendFields := strings.Fields(legendLine)
if !slices.Equal(expectedLegendFields, legendFields) {
return chain{}, fmt.Errorf("%w: legend %q is not the expected %q",
ErrChainListMalformed, legendLine, strings.Join(expectedLegendFields, " "))
return chain{}, fmt.Errorf("iptables chain list output is malformed: legend %q is not the expected %q",
legendLine, strings.Join(expectedLegendFields, " "))
}
lines = lines[2:] // remove chain general information line and legend line
@@ -111,8 +109,8 @@ func parseChainGeneralDataLine(line string) (base chain, err error) {
fields := strings.Fields(line)
const expectedNumberOfFields = 8
if len(fields) != expectedNumberOfFields {
return chain{}, fmt.Errorf("%w: expected %d fields in %q",
ErrChainListMalformed, expectedNumberOfFields, line)
return chain{}, fmt.Errorf("iptables chain list output is malformed: expected %d fields in %q",
expectedNumberOfFields, line)
}
// Sanity checks
@@ -126,8 +124,8 @@ func parseChainGeneralDataLine(line string) (base chain, err error) {
if fields[index] == expectedValue {
continue
}
return chain{}, fmt.Errorf("%w: expected %q for field %d in %q",
ErrChainListMalformed, expectedValue, index, line)
return chain{}, fmt.Errorf("iptables chain list output is malformed: expected %q for field %d in %q",
expectedValue, index, line)
}
base.name = fields[1] // chain name could be custom
@@ -152,19 +150,17 @@ func parseChainGeneralDataLine(line string) (base chain, err error) {
return base, nil
}
var ErrChainRuleMalformed = errors.New("chain rule is malformed")
func parseChainRuleLine(line string) (rule chainRule, err error) {
line = strings.TrimSpace(line)
if line == "" {
return chainRule{}, fmt.Errorf("%w: empty line", ErrChainRuleMalformed)
return chainRule{}, errors.New("chain rule is malformed: empty line")
}
fields := strings.Fields(line)
const minFields = 10
if len(fields) < minFields {
return chainRule{}, fmt.Errorf("%w: not enough fields", ErrChainRuleMalformed)
return chainRule{}, errors.New("chain rule is malformed: not enough fields")
}
for fieldIndex, field := range fields[:minFields] {
@@ -186,7 +182,7 @@ func parseChainRuleLine(line string) (rule chainRule, err error) {
func parseChainRuleField(fieldIndex int, field string, rule *chainRule) (err error) {
if field == "" {
return fmt.Errorf("%w: empty field at index %d", ErrChainRuleMalformed, fieldIndex)
return fmt.Errorf("chain rule is malformed: empty field at index %d", fieldIndex)
}
const (
@@ -278,8 +274,8 @@ func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err
rule.redirPorts = ports
i++
default:
return fmt.Errorf("%w: unexpected %q after redir",
ErrChainRuleMalformed, optionalFields[1])
return fmt.Errorf("chain rule is malformed: unexpected %q after redir",
optionalFields[1])
}
case "ctstate":
i++
@@ -294,15 +290,13 @@ func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err
rule.mark = mark
i += consumed
default:
return fmt.Errorf("%w: unexpected optional field: %s",
ErrChainRuleMalformed, optionalFields[i])
return fmt.Errorf("chain rule is malformed: unexpected optional field: %s",
optionalFields[i])
}
}
return nil
}
var errUDPOptionalUnknown = errors.New("unknown UDP optional field")
func parseUDPOptional(optionalFields []string, rule *chainRule) (consumed int, err error) {
for _, value := range optionalFields {
if !strings.ContainsRune(value, ':') {
@@ -323,14 +317,12 @@ func parseUDPOptional(optionalFields []string, rule *chainRule) (consumed int, e
}
consumed++
default:
return 0, fmt.Errorf("%w: %s", errUDPOptionalUnknown, value)
return 0, fmt.Errorf("unknown UDP optional field: %s", value)
}
}
return consumed, nil
}
var errTCPOptionalUnknown = errors.New("unknown TCP optional field")
func parseTCPOptional(optionalFields []string, rule *chainRule) (consumed int, err error) {
for _, value := range optionalFields {
if !strings.ContainsRune(value, ':') {
@@ -357,7 +349,7 @@ func parseTCPOptional(optionalFields []string, rule *chainRule) (consumed int, e
}
consumed++
default:
return 0, fmt.Errorf("%w: %s", errTCPOptionalUnknown, value)
return 0, fmt.Errorf("unknown TCP optional field: %s", value)
}
}
return consumed, nil
@@ -373,15 +365,13 @@ func parseSourcePort(value string) (port uint16, err error) {
return parsePort(value)
}
var errTCPFlagsMalformed = errors.New("TCP flags are malformed")
func parseTCPFlags(value string) (tcpFlags, error) {
value = strings.TrimPrefix(value, "flags:")
fields := strings.Split(value, "/")
const expectedFields = 2
if len(fields) != expectedFields {
return tcpFlags{}, fmt.Errorf("%w: expected format 'flags:<mask>/<comparison>' in %q",
errTCPFlagsMalformed, value)
return tcpFlags{}, fmt.Errorf("TCP flags are malformed: expected format 'flags:<mask>/<comparison>' in %q",
value)
}
maskFlags := strings.Split(fields[0], ",")
mask := make([]tcpFlag, len(maskFlags))
@@ -422,8 +412,6 @@ func parsePortsCSV(s string) (ports []uint16, err error) {
return ports, nil
}
var errMarkValueMalformed = errors.New("mark value is malformed")
func parseMark(optionalFields []string) (m mark, consumed int, err error) {
switch optionalFields[consumed] {
case "match":
@@ -437,42 +425,36 @@ func parseMark(optionalFields []string) (m mark, consumed int, err error) {
const bits = 32
value, err := strconv.ParseUint(optionalFields[consumed], base, bits)
if err != nil {
return mark{}, 0, fmt.Errorf("%w: %s", errMarkValueMalformed, optionalFields[consumed])
return mark{}, 0, fmt.Errorf("mark value is malformed: %s", optionalFields[consumed])
}
m.value = uint(value)
consumed++
default:
return mark{}, 0, fmt.Errorf("%w: unexpected mark mode field: %s",
ErrChainRuleMalformed, optionalFields[consumed])
return mark{}, 0, fmt.Errorf("chain rule is malformed: unexpected mark mode field: %s",
optionalFields[consumed])
}
return m, consumed, nil
}
var ErrLineNumberIsZero = errors.New("line number is zero")
func parseLineNumber(s string) (n uint16, err error) {
const base, bitLength = 10, 16
lineNumber, err := strconv.ParseUint(s, base, bitLength)
if err != nil {
return 0, err
} else if lineNumber == 0 {
return 0, fmt.Errorf("%w", ErrLineNumberIsZero)
return 0, errors.New("line number is zero")
}
return uint16(lineNumber), nil
}
var ErrTargetUnknown = errors.New("unknown target")
func checkTarget(target string) (err error) {
switch target {
case "ACCEPT", "DROP", "REJECT", "REDIRECT":
return nil
}
return fmt.Errorf("%w: %s", ErrTargetUnknown, target)
return fmt.Errorf("unknown target: %s", target)
}
var ErrProtocolUnknown = errors.New("unknown protocol")
func parseProtocol(s string) (protocol string, err error) {
switch s {
case "0", "all":
@@ -483,18 +465,16 @@ func parseProtocol(s string) (protocol string, err error) {
case "17", "udp":
protocol = "udp"
default:
return "", fmt.Errorf("%w: %s", ErrProtocolUnknown, s)
return "", fmt.Errorf("unknown protocol: %s", s)
}
return protocol, nil
}
var ErrMetricSizeMalformed = errors.New("metric size is malformed")
// parseMetricSize parses a metric size string like 140K or 226M and
// returns the raw integer matching it.
func parseMetricSize(size string) (n uint64, err error) {
if size == "" {
return n, fmt.Errorf("%w: empty string", ErrMetricSizeMalformed)
return n, errors.New("metric size is malformed: empty string")
}
//nolint:mnd
@@ -516,7 +496,7 @@ func parseMetricSize(size string) (n uint64, err error) {
const base, bitLength = 10, 64
n, err = strconv.ParseUint(size, base, bitLength)
if err != nil {
return n, fmt.Errorf("%w: %w", ErrMetricSizeMalformed, err)
return n, fmt.Errorf("metric size is malformed: %w", err)
}
n *= multiplier
return n, nil
+3 -7
View File
@@ -13,30 +13,25 @@ func Test_parseChain(t *testing.T) {
testCases := map[string]struct {
iptablesOutput string
table chain
errWrapped error
errMessage string
}{
"no_output": {
errWrapped: ErrChainListMalformed,
errMessage: "iptables chain list output is malformed: not enough lines to process in: ",
},
"single_line_only": {
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)`,
errWrapped: ErrChainListMalformed,
errMessage: "iptables chain list output is malformed: not enough lines to process in: " +
"Chain INPUT (policy ACCEPT 140K packets, 226M bytes)",
},
"malformed_general_data_line": {
iptablesOutput: `Chain INPUT
num pkts bytes target prot opt in out source destination`,
errWrapped: ErrChainListMalformed,
errMessage: "parsing chain general data line: iptables chain list output is malformed: " +
"expected 8 fields in \"Chain INPUT\"",
},
"malformed_legend": {
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
num pkts bytes target prot opt in out source`,
errWrapped: ErrChainListMalformed,
errMessage: "iptables chain list output is malformed: legend " +
"\"num pkts bytes target prot opt in out source\" " +
"is not the expected \"num pkts bytes target prot opt in out source destination\"",
@@ -135,9 +130,10 @@ num pkts bytes target prot opt in out source destinati
table, err := parseChain(testCase.iptablesOutput)
assert.Equal(t, testCase.table, table)
assert.ErrorIs(t, err, testCase.errWrapped)
if testCase.errWrapped != nil {
if testCase.errMessage != "" {
assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
}
})
}
+10 -12
View File
@@ -80,11 +80,9 @@ func ipPrefixesEqual(instruction, chainRule netip.Prefix) bool {
(!instruction.IsValid() && chainRule.Bits() == 0 && chainRule.Addr().IsUnspecified())
}
var ErrIptablesCommandMalformed = errors.New("iptables command is malformed")
func parseIptablesInstruction(s string) (instruction iptablesInstruction, err error) {
if s == "" {
return iptablesInstruction{}, fmt.Errorf("%w: empty instruction", ErrIptablesCommandMalformed)
return iptablesInstruction{}, errors.New("iptables command is malformed: empty instruction")
}
fields := strings.Fields(s)
@@ -173,7 +171,7 @@ func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (co
return 0, fmt.Errorf("parsing TCP flags: %w", err)
}
default:
return 0, fmt.Errorf("%w: unknown key %q", ErrIptablesCommandMalformed, flag)
return 0, fmt.Errorf("iptables command is malformed: unknown key %q", flag)
}
return consumed, nil
}
@@ -185,15 +183,15 @@ func preCheckInstructionFields(fields []string) (consumed int, err error) {
case "--tcp-flags": // -m can have 1 or 2 values
const expected = 3
if len(fields) < expected {
return 0, fmt.Errorf("%w: flag %q requires at least 2 values, but got %s",
ErrIptablesCommandMalformed, flag, strings.Join(fields, " "))
return 0, fmt.Errorf("iptables command is malformed: flag %q requires at least 2 values, but got %s",
flag, strings.Join(fields, " "))
}
return expected, nil
default:
const expected = 2
if len(fields) < expected {
return 0, fmt.Errorf("%w: flag %q requires a value, but got none",
ErrIptablesCommandMalformed, flag)
return 0, fmt.Errorf("iptables command is malformed: flag %q requires a value, but got none",
flag)
}
return expected, nil
}
@@ -239,12 +237,12 @@ func parseMatchModule(fields []string, instruction *iptablesInstruction) (
consumed++
instruction.mark.invert = true
default:
return consumed, fmt.Errorf("%w: unsupported match mark with value: %s",
ErrIptablesCommandMalformed, fields[2])
return consumed, fmt.Errorf("iptables command is malformed: unsupported match mark with value: %s",
fields[2])
}
default:
return 0, fmt.Errorf("%w: unknown match value: %s",
ErrIptablesCommandMalformed, fields[consumed])
return 0, fmt.Errorf("iptables command is malformed: unknown match value: %s",
fields[consumed])
}
return consumed, nil
}
+3 -6
View File
@@ -13,21 +13,17 @@ func Test_parseIptablesInstruction(t *testing.T) {
testCases := map[string]struct {
s string
instruction iptablesInstruction
errWrapped error
errMessage string
}{
"no_instruction": {
errWrapped: ErrIptablesCommandMalformed,
errMessage: "iptables command is malformed: empty instruction",
},
"uneven_fields": {
s: "-A",
errWrapped: ErrIptablesCommandMalformed,
errMessage: "parsing \"-A\": iptables command is malformed: flag \"-A\" requires a value, but got none",
},
"unknown_key": {
s: "-x something",
errWrapped: ErrIptablesCommandMalformed,
errMessage: "parsing \"-x something\": iptables command is malformed: unknown key \"-x\"",
},
"one_pair": {
@@ -74,9 +70,10 @@ func Test_parseIptablesInstruction(t *testing.T) {
rule, err := parseIptablesInstruction(testCase.s)
assert.Equal(t, testCase.instruction, rule)
assert.ErrorIs(t, err, testCase.errWrapped)
if testCase.errWrapped != nil {
if testCase.errMessage != "" {
assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
}
})
}
+4 -9
View File
@@ -10,12 +10,7 @@ import (
"strings"
)
var (
ErrNetAdminMissing = errors.New("NET_ADMIN capability is missing")
ErrTestRuleCleanup = errors.New("failed cleaning up test rule")
ErrInputPolicyNotFound = errors.New("input policy not found")
ErrNotSupported = errors.New("no iptables supported found")
)
var ErrNotSupported = errors.New("no iptables supported found")
func checkIptablesSupport(ctx context.Context, runner CmdRunner,
iptablesPathsToTry ...string,
@@ -53,7 +48,7 @@ func checkIptablesSupport(ctx context.Context, runner CmdRunner,
if allArePermissionDenied {
// If the error is related to a denied permission for all iptables path,
// return an error describing what to do from an end-user perspective.
return "", fmt.Errorf("%w: %s", ErrNetAdminMissing, strings.Join(allUnsupportedMessages, "; "))
return "", fmt.Errorf("NET_ADMIN capability is missing: %s", strings.Join(allUnsupportedMessages, "; "))
}
return "", fmt.Errorf("%w: errors encountered are: %s",
@@ -85,7 +80,7 @@ func testIptablesPath(ctx context.Context, path string,
output, err = runner.Run(cmd)
if err != nil {
// this is a critical error, we want to make sure our test rule gets removed.
criticalErr = fmt.Errorf("%w: %s (%s)", ErrTestRuleCleanup, output, err)
criticalErr = fmt.Errorf("failed cleaning up test rule: %s (%s)", output, err)
return false, "", criticalErr
}
@@ -108,7 +103,7 @@ func testIptablesPath(ctx context.Context, path string,
}
if inputPolicy == "" {
criticalErr = fmt.Errorf("%w: in INPUT rules: %s", ErrInputPolicyNotFound, output)
criticalErr = fmt.Errorf("input policy not found: in INPUT rules: %s", output)
return false, "", criticalErr
}
+6 -12
View File
@@ -7,7 +7,6 @@ import (
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func newAppendTestRuleMatcher(path string) *cmdMatcher {
@@ -43,7 +42,6 @@ func Test_checkIptablesSupport(t *testing.T) {
buildRunner func(ctrl *gomock.Controller) CmdRunner
iptablesPathsToTry []string
iptablesPath string
errSentinel error
errMessage string
}{
"critical error when checking": {
@@ -56,7 +54,6 @@ func Test_checkIptablesSupport(t *testing.T) {
return runner
},
iptablesPathsToTry: []string{"path1", "path2"},
errSentinel: ErrTestRuleCleanup,
errMessage: "for path1: failed cleaning up test rule: " +
"output (exit code 4)",
},
@@ -86,7 +83,6 @@ func Test_checkIptablesSupport(t *testing.T) {
return runner
},
iptablesPathsToTry: []string{"path1", "path2"},
errSentinel: ErrNetAdminMissing,
errMessage: "NET_ADMIN capability is missing: " +
"path1: Permission denied (you must be root) more context (exit code 4); " +
"path2: context: Permission denied (you must be root) (exit code 4)",
@@ -101,7 +97,6 @@ func Test_checkIptablesSupport(t *testing.T) {
return runner
},
iptablesPathsToTry: []string{"path1", "path2"},
errSentinel: ErrNotSupported,
errMessage: "no iptables supported found: " +
"errors encountered are: " +
"path1: output 1 (exit code 4); " +
@@ -118,9 +113,10 @@ func Test_checkIptablesSupport(t *testing.T) {
iptablesPath, err := checkIptablesSupport(ctx, runner, testCase.iptablesPathsToTry...)
require.ErrorIs(t, err, testCase.errSentinel)
if testCase.errSentinel != nil {
if testCase.errMessage != "" {
assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
}
assert.Equal(t, testCase.iptablesPath, iptablesPath)
})
@@ -139,7 +135,6 @@ func Test_testIptablesPath(t *testing.T) {
buildRunner func(ctrl *gomock.Controller) CmdRunner
ok bool
unsupportedMessage string
criticalErrWrapped error
criticalErrMessage string
}{
"append test rule permission denied": {
@@ -168,7 +163,6 @@ func Test_testIptablesPath(t *testing.T) {
Return("some output", errDummy)
return runner
},
criticalErrWrapped: ErrTestRuleCleanup,
criticalErrMessage: "failed cleaning up test rule: some output (exit code 4)",
},
"list input rules permission denied": {
@@ -202,7 +196,6 @@ func Test_testIptablesPath(t *testing.T) {
Return("some\noutput", nil)
return runner
},
criticalErrWrapped: ErrInputPolicyNotFound,
criticalErrMessage: "input policy not found: in INPUT rules: some\noutput",
},
"set policy permission denied": {
@@ -257,9 +250,10 @@ func Test_testIptablesPath(t *testing.T) {
assert.Equal(t, testCase.ok, ok)
assert.Equal(t, testCase.unsupportedMessage, unsupportedMessage)
assert.ErrorIs(t, criticalErr, testCase.criticalErrWrapped)
if testCase.criticalErrWrapped != nil {
if testCase.criticalErrMessage != "" {
assert.EqualError(t, criticalErr, testCase.criticalErrMessage)
} else {
assert.NoError(t, criticalErr)
}
})
}
+2 -4
View File
@@ -45,12 +45,10 @@ func (f tcpFlag) String() string {
case tcpFlagCWR:
return "CWR"
default:
panic(fmt.Sprintf("%s: %d", errTCPFlagUnknown, f))
panic(fmt.Sprintf("unknown TCP flag: %d", f))
}
}
var errTCPFlagUnknown = errors.New("unknown TCP flag")
func parseTCPFlag(s string) (tcpFlag, error) {
allFlags := []tcpFlag{
tcpFlagFIN, tcpFlagSYN, tcpFlagRST, tcpFlagPSH,
@@ -61,7 +59,7 @@ func parseTCPFlag(s string) (tcpFlag, error) {
return flag, nil
}
}
return 0, fmt.Errorf("%w: %s", errTCPFlagUnknown, s)
return 0, fmt.Errorf("unknown TCP flag: %s", s)
}
var ErrMarkMatchModuleMissing = errors.New("kernel is missing the mark module libxt_mark.so")