chore(firewall): support TCP flags for future changes

This commit is contained in:
Quentin McGaw
2026-02-17 14:15:15 +00:00
parent 36dfd5b631
commit d43eb1658f
5 changed files with 222 additions and 32 deletions
+2 -2
View File
@@ -69,8 +69,8 @@ func Test_deleteIPTablesRule(t *testing.T) {
"invalid_instruction": { "invalid_instruction": {
instruction: "invalid", instruction: "invalid",
errWrapped: ErrIptablesCommandMalformed, errWrapped: ErrIptablesCommandMalformed,
errMessage: "parsing iptables command: iptables command is malformed: " + errMessage: "parsing iptables command: parsing \"invalid\": " +
"fields count 1 is not even: \"invalid\"", "iptables command is malformed: flag \"invalid\" requires a value, but got none",
}, },
"list_error": { "list_error": {
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678", instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
+114 -13
View File
@@ -30,6 +30,7 @@ type chainRule struct {
destinationPort uint16 // Not specified if set to zero. destinationPort uint16 // Not specified if set to zero.
redirPorts []uint16 // Not specified if empty. redirPorts []uint16 // Not specified if empty.
ctstate []string // for example ["RELATED","ESTABLISHED"]. Can be empty. ctstate []string // for example ["RELATED","ESTABLISHED"]. Can be empty.
tcpFlags tcpFlags
} }
var ErrChainListMalformed = errors.New("iptables chain list output is malformed") var ErrChainListMalformed = errors.New("iptables chain list output is malformed")
@@ -241,19 +242,23 @@ func parseChainRuleField(fieldIndex int, field string, rule *chainRule) (err err
} }
func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err error) { func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err error) {
for i := 0; i < len(optionalFields); i++ { i := 0
key := optionalFields[i] for i < len(optionalFields) {
switch key { switch optionalFields[i] {
case "tcp", "udp": case "udp":
i++ i++
value := optionalFields[i] consumed, err := parseUDPOptional(optionalFields[i:], rule)
value = strings.TrimPrefix(value, "dpt:")
const base, bitLength = 10, 16
destinationPort, err := strconv.ParseUint(value, base, bitLength)
if err != nil { if err != nil {
return fmt.Errorf("parsing destination port %q: %w", value, err) return fmt.Errorf("parsing UDP optional fields: %w", err)
} }
rule.destinationPort = uint16(destinationPort) i += consumed
case "tcp":
i++
consumed, err := parseTCPOptional(optionalFields[i:], rule)
if err != nil {
return fmt.Errorf("parsing TCP optional fields: %w", err)
}
i += consumed
case "redir": case "redir":
i++ i++
switch optionalFields[i] { switch optionalFields[i] {
@@ -264,20 +269,116 @@ func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err
return fmt.Errorf("parsing redirection ports: %w", err) return fmt.Errorf("parsing redirection ports: %w", err)
} }
rule.redirPorts = ports rule.redirPorts = ports
i++
default: default:
return fmt.Errorf("%w: unexpected optional field: %s", return fmt.Errorf("%w: unexpected %q after redir",
ErrChainRuleMalformed, optionalFields[i]) ErrChainRuleMalformed, optionalFields[1])
} }
case "ctstate": case "ctstate":
i++ i++
rule.ctstate = strings.Split(optionalFields[i], ",") rule.ctstate = strings.Split(optionalFields[i], ",")
i++
default: default:
return fmt.Errorf("%w: unexpected optional field: %s", ErrChainRuleMalformed, key) return fmt.Errorf("%w: unexpected optional field: %s",
ErrChainRuleMalformed, optionalFields[i])
} }
} }
return nil return nil
} }
var errUDPOptionalUnknown = errors.New("unknown UDP optional field")
func parseUDPOptional(optionalFields []string, rule *chainRule) (consumed int, err error) {
for _, value := range optionalFields {
if !strings.ContainsRune(value, ':') {
// no longer a UDP-associated option
return consumed, nil
}
switch {
case strings.HasPrefix(value, "dpt:"):
rule.destinationPort, err = parseDestinationPort(value)
if err != nil {
return 0, fmt.Errorf("parsing destination port: %w", err)
}
consumed++
default:
return 0, fmt.Errorf("%w: %s", errUDPOptionalUnknown, value)
}
}
return consumed, nil
}
var errTCPOptionalUnknown = errors.New("unknown TCP optional field")
func parseTCPOptional(optionalFields []string, rule *chainRule) (consumed int, err error) {
for _, value := range optionalFields {
if !strings.ContainsRune(value, ':') {
// no longer a TCP-associated option
return consumed, nil
}
switch {
case strings.HasPrefix(value, "dpt:"):
rule.destinationPort, err = parseDestinationPort(value)
if err != nil {
return 0, fmt.Errorf("parsing destination port: %w", err)
}
consumed++
case strings.HasPrefix(value, "flags:"):
rule.tcpFlags, err = parseTCPFlags(value)
if err != nil {
return 0, fmt.Errorf("parsing TCP flags: %w", err)
}
consumed++
default:
return 0, fmt.Errorf("%w: %s", errTCPOptionalUnknown, value)
}
}
return consumed, nil
}
func parseDestinationPort(value string) (port uint16, err error) {
value = strings.TrimPrefix(value, "dpt:")
const base, bitLength = 10, 16
destinationPort, err := strconv.ParseUint(value, base, bitLength)
if err != nil {
return 0, fmt.Errorf("parsing %q: %w", value, err)
}
return uint16(destinationPort), nil
}
var errTCPFlagsMalformed = errors.New("TCP flags are malformed")
func parseTCPFlags(value string) (tcpFlags, error) {
value = strings.TrimPrefix(value, "flags:")
fields := strings.Split(value, "/")
const expectedFields = 2
if len(fields) != expectedFields {
return tcpFlags{}, fmt.Errorf("%w: expected format 'flags:<mask>/<comparison>' in %q",
errTCPFlagsMalformed, value)
}
maskFlags := strings.Split(fields[0], ",")
mask := make([]tcpFlag, len(maskFlags))
var err error
for i, maskFlag := range maskFlags {
mask[i], err = parseTCPFlag(maskFlag)
if err != nil {
return tcpFlags{}, fmt.Errorf("parsing TCP mask flags: %w", err)
}
}
comparisonFlags := strings.Split(fields[1], ",")
comparison := make([]tcpFlag, len(comparisonFlags))
for i, comparisonFlag := range comparisonFlags {
comparison[i], err = parseTCPFlag(comparisonFlag)
if err != nil {
return tcpFlags{}, fmt.Errorf("parsing TCP comparison flags: %w", err)
}
}
return tcpFlags{
mask: mask,
comparison: comparison,
}, nil
}
func parsePortsCSV(s string) (ports []uint16, err error) { func parsePortsCSV(s string) (ports []uint16, err error) {
if s == "" { if s == "" {
return nil, nil return nil, nil
+43 -16
View File
@@ -22,6 +22,7 @@ type iptablesInstruction struct {
destinationPort uint16 // if zero, there is no destination port destinationPort uint16 // if zero, there is no destination port
toPorts []uint16 // if empty, there is no redirection toPorts []uint16 // if empty, there is no redirection
ctstate []string // if empty, there is no ctstate ctstate []string // if empty, there is no ctstate
tcpFlags tcpFlags
} }
func (i *iptablesInstruction) setDefaults() { func (i *iptablesInstruction) setDefaults() {
@@ -55,6 +56,9 @@ func (i *iptablesInstruction) equalToRule(table, chain string, rule chainRule) (
return false return false
case !ipPrefixesEqual(i.destination, rule.destination): case !ipPrefixesEqual(i.destination, rule.destination):
return false return false
case !slices.Equal(i.tcpFlags.mask, rule.tcpFlags.mask) ||
!slices.Equal(i.tcpFlags.comparison, rule.tcpFlags.comparison):
return false
default: default:
return true return true
} }
@@ -77,26 +81,43 @@ func parseIptablesInstruction(s string) (instruction iptablesInstruction, err er
return iptablesInstruction{}, fmt.Errorf("%w: empty instruction", ErrIptablesCommandMalformed) return iptablesInstruction{}, fmt.Errorf("%w: empty instruction", ErrIptablesCommandMalformed)
} }
fields := strings.Fields(s) fields := strings.Fields(s)
if len(fields)%2 != 0 {
return iptablesInstruction{}, fmt.Errorf("%w: fields count %d is not even: %q",
ErrIptablesCommandMalformed, len(fields), s)
}
for i := 0; i < len(fields); i += 2 { i := 0
key := fields[i] for i < len(fields) {
value := fields[i+1] consumed, err := parseInstructionFlag(fields[i:], &instruction)
err = parseInstructionFlag(key, value, &instruction)
if err != nil { if err != nil {
return iptablesInstruction{}, fmt.Errorf("parsing %q: %w", s, err) return iptablesInstruction{}, fmt.Errorf("parsing %q: %w", s, err)
} }
i += consumed
} }
instruction.setDefaults() instruction.setDefaults()
return instruction, nil return instruction, nil
} }
func parseInstructionFlag(key, value string, instruction *iptablesInstruction) (err error) { func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (consumed int, err error) {
switch key { flag := fields[0]
// All flags use one value after the flag, except the following:
switch flag {
case "--tcp-flags":
const expected = 3
if len(fields) < expected {
return 0, fmt.Errorf("%w: flag %q requires at least 2 values, but got %s",
ErrIptablesCommandMalformed, flag, strings.Join(fields, " "))
}
consumed = expected
default:
const expected = 2
if len(fields) < expected {
return 0, fmt.Errorf("%w: flag %q requires a value, but got none",
ErrIptablesCommandMalformed, flag)
}
consumed = expected
}
value := fields[1]
switch flag {
case "-t", "--table": case "-t", "--table":
instruction.table = value instruction.table = value
case "-D", "--delete": case "-D", "--delete":
@@ -117,18 +138,18 @@ func parseInstructionFlag(key, value string, instruction *iptablesInstruction) (
case "-s", "--source": case "-s", "--source":
instruction.source, err = parseIPPrefix(value) instruction.source, err = parseIPPrefix(value)
if err != nil { if err != nil {
return fmt.Errorf("parsing source IP CIDR: %w", err) return 0, fmt.Errorf("parsing source IP CIDR: %w", err)
} }
case "-d", "--destination": case "-d", "--destination":
instruction.destination, err = parseIPPrefix(value) instruction.destination, err = parseIPPrefix(value)
if err != nil { if err != nil {
return fmt.Errorf("parsing destination IP CIDR: %w", err) return 0, fmt.Errorf("parsing destination IP CIDR: %w", err)
} }
case "--dport": case "--dport":
const base, bitLength = 10, 16 const base, bitLength = 10, 16
destinationPort, err := strconv.ParseUint(value, base, bitLength) destinationPort, err := strconv.ParseUint(value, base, bitLength)
if err != nil { if err != nil {
return fmt.Errorf("parsing destination port: %w", err) return 0, fmt.Errorf("parsing destination port: %w", err)
} }
instruction.destinationPort = uint16(destinationPort) instruction.destinationPort = uint16(destinationPort)
case "--ctstate": case "--ctstate":
@@ -140,14 +161,20 @@ func parseInstructionFlag(key, value string, instruction *iptablesInstruction) (
const base, bitLength = 10, 16 const base, bitLength = 10, 16
port, err := strconv.ParseUint(portString, base, bitLength) port, err := strconv.ParseUint(portString, base, bitLength)
if err != nil { if err != nil {
return fmt.Errorf("parsing port redirection: %w", err) return 0, fmt.Errorf("parsing port redirection: %w", err)
} }
instruction.toPorts[i] = uint16(port) instruction.toPorts[i] = uint16(port)
} }
case "--tcp-flags":
mask, comparison := value, fields[2]
instruction.tcpFlags, err = parseTCPFlags(mask + "/" + comparison)
if err != nil {
return 0, fmt.Errorf("parsing TCP flags: %w", err)
}
default: default:
return fmt.Errorf("%w: unknown key %q", ErrIptablesCommandMalformed, key) return 0, fmt.Errorf("%w: unknown key %q", ErrIptablesCommandMalformed, flag)
} }
return nil return consumed, nil
} }
func parseIPPrefix(value string) (prefix netip.Prefix, err error) { func parseIPPrefix(value string) (prefix netip.Prefix, err error) {
+1 -1
View File
@@ -23,7 +23,7 @@ func Test_parseIptablesInstruction(t *testing.T) {
"uneven_fields": { "uneven_fields": {
s: "-A", s: "-A",
errWrapped: ErrIptablesCommandMalformed, errWrapped: ErrIptablesCommandMalformed,
errMessage: "iptables command is malformed: fields count 1 is not even: \"-A\"", errMessage: "parsing \"-A\": iptables command is malformed: flag \"-A\" requires a value, but got none",
}, },
"unknown_key": { "unknown_key": {
s: "-x something", s: "-x something",
+62
View File
@@ -0,0 +1,62 @@
package firewall
import (
"errors"
"fmt"
)
type tcpFlags struct {
mask []tcpFlag
comparison []tcpFlag
}
type tcpFlag uint8
const (
tcpFlagFIN tcpFlag = 1 << iota
tcpFlagSYN
tcpFlagRST
tcpFlagPSH
tcpFlagACK
tcpFlagURG
tcpFlagECE
tcpFlagCWR
)
func (f tcpFlag) String() string {
switch f {
case tcpFlagFIN:
return "FIN"
case tcpFlagSYN:
return "SYN"
case tcpFlagRST:
return "RST"
case tcpFlagPSH:
return "PSH"
case tcpFlagACK:
return "ACK"
case tcpFlagURG:
return "URG"
case tcpFlagECE:
return "ECE"
case tcpFlagCWR:
return "CWR"
default:
panic(fmt.Sprintf("%s: %d", errTCPFlagUnknown, f))
}
}
var errTCPFlagUnknown = errors.New("unknown TCP flag")
func parseTCPFlag(s string) (tcpFlag, error) {
allFlags := []tcpFlag{
tcpFlagFIN, tcpFlagSYN, tcpFlagRST, tcpFlagPSH,
tcpFlagACK, tcpFlagURG, tcpFlagECE, tcpFlagCWR,
}
for _, flag := range allFlags {
if s == fmt.Sprintf("%#02x", uint8(flag)) || s == flag.String() {
return flag, nil
}
}
return 0, fmt.Errorf("%w: %s", errTCPFlagUnknown, s)
}