From d43eb1658f9772cfee607a5cf2e84aa5b0f1dafb Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Tue, 17 Feb 2026 14:15:15 +0000 Subject: [PATCH] chore(firewall): support TCP flags for future changes --- internal/firewall/delete_test.go | 4 +- internal/firewall/list.go | 127 +++++++++++++++++++++++++++---- internal/firewall/parse.go | 59 ++++++++++---- internal/firewall/parse_test.go | 2 +- internal/firewall/tcp.go | 62 +++++++++++++++ 5 files changed, 222 insertions(+), 32 deletions(-) create mode 100644 internal/firewall/tcp.go diff --git a/internal/firewall/delete_test.go b/internal/firewall/delete_test.go index 1f6b5ceb..a0f5a6fa 100644 --- a/internal/firewall/delete_test.go +++ b/internal/firewall/delete_test.go @@ -69,8 +69,8 @@ func Test_deleteIPTablesRule(t *testing.T) { "invalid_instruction": { instruction: "invalid", errWrapped: ErrIptablesCommandMalformed, - errMessage: "parsing iptables command: iptables command is malformed: " + - "fields count 1 is not even: \"invalid\"", + errMessage: "parsing iptables command: parsing \"invalid\": " + + "iptables command is malformed: flag \"invalid\" requires a value, but got none", }, "list_error": { instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678", diff --git a/internal/firewall/list.go b/internal/firewall/list.go index 93fee9ff..75e1955a 100644 --- a/internal/firewall/list.go +++ b/internal/firewall/list.go @@ -30,6 +30,7 @@ type chainRule struct { destinationPort uint16 // Not specified if set to zero. redirPorts []uint16 // Not specified if empty. ctstate []string // for example ["RELATED","ESTABLISHED"]. Can be empty. + tcpFlags tcpFlags } var ErrChainListMalformed = errors.New("iptables chain list output is malformed") @@ -241,19 +242,23 @@ func parseChainRuleField(fieldIndex int, field string, rule *chainRule) (err err } func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err error) { - for i := 0; i < len(optionalFields); i++ { - key := optionalFields[i] - switch key { - case "tcp", "udp": + i := 0 + for i < len(optionalFields) { + switch optionalFields[i] { + case "udp": i++ - value := optionalFields[i] - value = strings.TrimPrefix(value, "dpt:") - const base, bitLength = 10, 16 - destinationPort, err := strconv.ParseUint(value, base, bitLength) + consumed, err := parseUDPOptional(optionalFields[i:], rule) if err != nil { - return fmt.Errorf("parsing destination port %q: %w", value, err) + return fmt.Errorf("parsing UDP optional fields: %w", err) } - rule.destinationPort = uint16(destinationPort) + i += consumed + case "tcp": + i++ + consumed, err := parseTCPOptional(optionalFields[i:], rule) + if err != nil { + return fmt.Errorf("parsing TCP optional fields: %w", err) + } + i += consumed case "redir": i++ switch optionalFields[i] { @@ -264,20 +269,116 @@ func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err return fmt.Errorf("parsing redirection ports: %w", err) } rule.redirPorts = ports + i++ default: - return fmt.Errorf("%w: unexpected optional field: %s", - ErrChainRuleMalformed, optionalFields[i]) + return fmt.Errorf("%w: unexpected %q after redir", + ErrChainRuleMalformed, optionalFields[1]) } case "ctstate": i++ rule.ctstate = strings.Split(optionalFields[i], ",") + i++ default: - return fmt.Errorf("%w: unexpected optional field: %s", ErrChainRuleMalformed, key) + return fmt.Errorf("%w: unexpected optional field: %s", + ErrChainRuleMalformed, optionalFields[i]) } } return nil } +var errUDPOptionalUnknown = errors.New("unknown UDP optional field") + +func parseUDPOptional(optionalFields []string, rule *chainRule) (consumed int, err error) { + for _, value := range optionalFields { + if !strings.ContainsRune(value, ':') { + // no longer a UDP-associated option + return consumed, nil + } + switch { + case strings.HasPrefix(value, "dpt:"): + rule.destinationPort, err = parseDestinationPort(value) + if err != nil { + return 0, fmt.Errorf("parsing destination port: %w", err) + } + consumed++ + default: + return 0, fmt.Errorf("%w: %s", errUDPOptionalUnknown, value) + } + } + return consumed, nil +} + +var errTCPOptionalUnknown = errors.New("unknown TCP optional field") + +func parseTCPOptional(optionalFields []string, rule *chainRule) (consumed int, err error) { + for _, value := range optionalFields { + if !strings.ContainsRune(value, ':') { + // no longer a TCP-associated option + return consumed, nil + } + switch { + case strings.HasPrefix(value, "dpt:"): + rule.destinationPort, err = parseDestinationPort(value) + if err != nil { + return 0, fmt.Errorf("parsing destination port: %w", err) + } + consumed++ + case strings.HasPrefix(value, "flags:"): + rule.tcpFlags, err = parseTCPFlags(value) + if err != nil { + return 0, fmt.Errorf("parsing TCP flags: %w", err) + } + consumed++ + default: + return 0, fmt.Errorf("%w: %s", errTCPOptionalUnknown, value) + } + } + return consumed, nil +} + +func parseDestinationPort(value string) (port uint16, err error) { + value = strings.TrimPrefix(value, "dpt:") + const base, bitLength = 10, 16 + destinationPort, err := strconv.ParseUint(value, base, bitLength) + if err != nil { + return 0, fmt.Errorf("parsing %q: %w", value, err) + } + return uint16(destinationPort), nil +} + +var errTCPFlagsMalformed = errors.New("TCP flags are malformed") + +func parseTCPFlags(value string) (tcpFlags, error) { + value = strings.TrimPrefix(value, "flags:") + fields := strings.Split(value, "/") + const expectedFields = 2 + if len(fields) != expectedFields { + return tcpFlags{}, fmt.Errorf("%w: expected format 'flags:/' in %q", + errTCPFlagsMalformed, value) + } + maskFlags := strings.Split(fields[0], ",") + mask := make([]tcpFlag, len(maskFlags)) + var err error + for i, maskFlag := range maskFlags { + mask[i], err = parseTCPFlag(maskFlag) + if err != nil { + return tcpFlags{}, fmt.Errorf("parsing TCP mask flags: %w", err) + } + } + comparisonFlags := strings.Split(fields[1], ",") + comparison := make([]tcpFlag, len(comparisonFlags)) + for i, comparisonFlag := range comparisonFlags { + comparison[i], err = parseTCPFlag(comparisonFlag) + if err != nil { + return tcpFlags{}, fmt.Errorf("parsing TCP comparison flags: %w", err) + } + } + return tcpFlags{ + mask: mask, + comparison: comparison, + }, nil +} + func parsePortsCSV(s string) (ports []uint16, err error) { if s == "" { return nil, nil diff --git a/internal/firewall/parse.go b/internal/firewall/parse.go index d2d046cb..c5cb2d88 100644 --- a/internal/firewall/parse.go +++ b/internal/firewall/parse.go @@ -22,6 +22,7 @@ type iptablesInstruction struct { destinationPort uint16 // if zero, there is no destination port toPorts []uint16 // if empty, there is no redirection ctstate []string // if empty, there is no ctstate + tcpFlags tcpFlags } func (i *iptablesInstruction) setDefaults() { @@ -55,6 +56,9 @@ func (i *iptablesInstruction) equalToRule(table, chain string, rule chainRule) ( return false case !ipPrefixesEqual(i.destination, rule.destination): return false + case !slices.Equal(i.tcpFlags.mask, rule.tcpFlags.mask) || + !slices.Equal(i.tcpFlags.comparison, rule.tcpFlags.comparison): + return false default: return true } @@ -77,26 +81,43 @@ func parseIptablesInstruction(s string) (instruction iptablesInstruction, err er return iptablesInstruction{}, fmt.Errorf("%w: empty instruction", ErrIptablesCommandMalformed) } fields := strings.Fields(s) - if len(fields)%2 != 0 { - return iptablesInstruction{}, fmt.Errorf("%w: fields count %d is not even: %q", - ErrIptablesCommandMalformed, len(fields), s) - } - for i := 0; i < len(fields); i += 2 { - key := fields[i] - value := fields[i+1] - err = parseInstructionFlag(key, value, &instruction) + i := 0 + for i < len(fields) { + consumed, err := parseInstructionFlag(fields[i:], &instruction) if err != nil { return iptablesInstruction{}, fmt.Errorf("parsing %q: %w", s, err) } + i += consumed } instruction.setDefaults() return instruction, nil } -func parseInstructionFlag(key, value string, instruction *iptablesInstruction) (err error) { - switch key { +func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (consumed int, err error) { + flag := fields[0] + + // All flags use one value after the flag, except the following: + switch flag { + case "--tcp-flags": + const expected = 3 + if len(fields) < expected { + return 0, fmt.Errorf("%w: flag %q requires at least 2 values, but got %s", + ErrIptablesCommandMalformed, flag, strings.Join(fields, " ")) + } + consumed = expected + default: + const expected = 2 + if len(fields) < expected { + return 0, fmt.Errorf("%w: flag %q requires a value, but got none", + ErrIptablesCommandMalformed, flag) + } + consumed = expected + } + value := fields[1] + + switch flag { case "-t", "--table": instruction.table = value case "-D", "--delete": @@ -117,18 +138,18 @@ func parseInstructionFlag(key, value string, instruction *iptablesInstruction) ( case "-s", "--source": instruction.source, err = parseIPPrefix(value) if err != nil { - return fmt.Errorf("parsing source IP CIDR: %w", err) + return 0, fmt.Errorf("parsing source IP CIDR: %w", err) } case "-d", "--destination": instruction.destination, err = parseIPPrefix(value) if err != nil { - return fmt.Errorf("parsing destination IP CIDR: %w", err) + return 0, fmt.Errorf("parsing destination IP CIDR: %w", err) } case "--dport": const base, bitLength = 10, 16 destinationPort, err := strconv.ParseUint(value, base, bitLength) if err != nil { - return fmt.Errorf("parsing destination port: %w", err) + return 0, fmt.Errorf("parsing destination port: %w", err) } instruction.destinationPort = uint16(destinationPort) case "--ctstate": @@ -140,14 +161,20 @@ func parseInstructionFlag(key, value string, instruction *iptablesInstruction) ( const base, bitLength = 10, 16 port, err := strconv.ParseUint(portString, base, bitLength) if err != nil { - return fmt.Errorf("parsing port redirection: %w", err) + return 0, fmt.Errorf("parsing port redirection: %w", err) } instruction.toPorts[i] = uint16(port) } + case "--tcp-flags": + mask, comparison := value, fields[2] + instruction.tcpFlags, err = parseTCPFlags(mask + "/" + comparison) + if err != nil { + return 0, fmt.Errorf("parsing TCP flags: %w", err) + } default: - return fmt.Errorf("%w: unknown key %q", ErrIptablesCommandMalformed, key) + return 0, fmt.Errorf("%w: unknown key %q", ErrIptablesCommandMalformed, flag) } - return nil + return consumed, nil } func parseIPPrefix(value string) (prefix netip.Prefix, err error) { diff --git a/internal/firewall/parse_test.go b/internal/firewall/parse_test.go index ae07bc6b..fa45b098 100644 --- a/internal/firewall/parse_test.go +++ b/internal/firewall/parse_test.go @@ -23,7 +23,7 @@ func Test_parseIptablesInstruction(t *testing.T) { "uneven_fields": { s: "-A", errWrapped: ErrIptablesCommandMalformed, - errMessage: "iptables command is malformed: fields count 1 is not even: \"-A\"", + errMessage: "parsing \"-A\": iptables command is malformed: flag \"-A\" requires a value, but got none", }, "unknown_key": { s: "-x something", diff --git a/internal/firewall/tcp.go b/internal/firewall/tcp.go new file mode 100644 index 00000000..7c38f9ca --- /dev/null +++ b/internal/firewall/tcp.go @@ -0,0 +1,62 @@ +package firewall + +import ( + "errors" + "fmt" +) + +type tcpFlags struct { + mask []tcpFlag + comparison []tcpFlag +} + +type tcpFlag uint8 + +const ( + tcpFlagFIN tcpFlag = 1 << iota + tcpFlagSYN + tcpFlagRST + tcpFlagPSH + tcpFlagACK + tcpFlagURG + tcpFlagECE + tcpFlagCWR +) + +func (f tcpFlag) String() string { + switch f { + case tcpFlagFIN: + return "FIN" + case tcpFlagSYN: + return "SYN" + case tcpFlagRST: + return "RST" + case tcpFlagPSH: + return "PSH" + case tcpFlagACK: + return "ACK" + case tcpFlagURG: + return "URG" + case tcpFlagECE: + return "ECE" + case tcpFlagCWR: + return "CWR" + default: + panic(fmt.Sprintf("%s: %d", errTCPFlagUnknown, f)) + } +} + +var errTCPFlagUnknown = errors.New("unknown TCP flag") + +func parseTCPFlag(s string) (tcpFlag, error) { + allFlags := []tcpFlag{ + tcpFlagFIN, tcpFlagSYN, tcpFlagRST, tcpFlagPSH, + tcpFlagACK, tcpFlagURG, tcpFlagECE, tcpFlagCWR, + } + for _, flag := range allFlags { + if s == fmt.Sprintf("%#02x", uint8(flag)) || s == flag.String() { + return flag, nil + } + } + return 0, fmt.Errorf("%w: %s", errTCPFlagUnknown, s) +}