chore(firewall): support TCP flags for future changes

This commit is contained in:
Quentin McGaw
2026-02-17 14:15:15 +00:00
parent 36dfd5b631
commit d43eb1658f
5 changed files with 222 additions and 32 deletions
+43 -16
View File
@@ -22,6 +22,7 @@ type iptablesInstruction struct {
destinationPort uint16 // if zero, there is no destination port
toPorts []uint16 // if empty, there is no redirection
ctstate []string // if empty, there is no ctstate
tcpFlags tcpFlags
}
func (i *iptablesInstruction) setDefaults() {
@@ -55,6 +56,9 @@ func (i *iptablesInstruction) equalToRule(table, chain string, rule chainRule) (
return false
case !ipPrefixesEqual(i.destination, rule.destination):
return false
case !slices.Equal(i.tcpFlags.mask, rule.tcpFlags.mask) ||
!slices.Equal(i.tcpFlags.comparison, rule.tcpFlags.comparison):
return false
default:
return true
}
@@ -77,26 +81,43 @@ func parseIptablesInstruction(s string) (instruction iptablesInstruction, err er
return iptablesInstruction{}, fmt.Errorf("%w: empty instruction", ErrIptablesCommandMalformed)
}
fields := strings.Fields(s)
if len(fields)%2 != 0 {
return iptablesInstruction{}, fmt.Errorf("%w: fields count %d is not even: %q",
ErrIptablesCommandMalformed, len(fields), s)
}
for i := 0; i < len(fields); i += 2 {
key := fields[i]
value := fields[i+1]
err = parseInstructionFlag(key, value, &instruction)
i := 0
for i < len(fields) {
consumed, err := parseInstructionFlag(fields[i:], &instruction)
if err != nil {
return iptablesInstruction{}, fmt.Errorf("parsing %q: %w", s, err)
}
i += consumed
}
instruction.setDefaults()
return instruction, nil
}
func parseInstructionFlag(key, value string, instruction *iptablesInstruction) (err error) {
switch key {
func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (consumed int, err error) {
flag := fields[0]
// All flags use one value after the flag, except the following:
switch flag {
case "--tcp-flags":
const expected = 3
if len(fields) < expected {
return 0, fmt.Errorf("%w: flag %q requires at least 2 values, but got %s",
ErrIptablesCommandMalformed, flag, strings.Join(fields, " "))
}
consumed = expected
default:
const expected = 2
if len(fields) < expected {
return 0, fmt.Errorf("%w: flag %q requires a value, but got none",
ErrIptablesCommandMalformed, flag)
}
consumed = expected
}
value := fields[1]
switch flag {
case "-t", "--table":
instruction.table = value
case "-D", "--delete":
@@ -117,18 +138,18 @@ func parseInstructionFlag(key, value string, instruction *iptablesInstruction) (
case "-s", "--source":
instruction.source, err = parseIPPrefix(value)
if err != nil {
return fmt.Errorf("parsing source IP CIDR: %w", err)
return 0, fmt.Errorf("parsing source IP CIDR: %w", err)
}
case "-d", "--destination":
instruction.destination, err = parseIPPrefix(value)
if err != nil {
return fmt.Errorf("parsing destination IP CIDR: %w", err)
return 0, fmt.Errorf("parsing destination IP CIDR: %w", err)
}
case "--dport":
const base, bitLength = 10, 16
destinationPort, err := strconv.ParseUint(value, base, bitLength)
if err != nil {
return fmt.Errorf("parsing destination port: %w", err)
return 0, fmt.Errorf("parsing destination port: %w", err)
}
instruction.destinationPort = uint16(destinationPort)
case "--ctstate":
@@ -140,14 +161,20 @@ func parseInstructionFlag(key, value string, instruction *iptablesInstruction) (
const base, bitLength = 10, 16
port, err := strconv.ParseUint(portString, base, bitLength)
if err != nil {
return fmt.Errorf("parsing port redirection: %w", err)
return 0, fmt.Errorf("parsing port redirection: %w", err)
}
instruction.toPorts[i] = uint16(port)
}
case "--tcp-flags":
mask, comparison := value, fields[2]
instruction.tcpFlags, err = parseTCPFlags(mask + "/" + comparison)
if err != nil {
return 0, fmt.Errorf("parsing TCP flags: %w", err)
}
default:
return fmt.Errorf("%w: unknown key %q", ErrIptablesCommandMalformed, key)
return 0, fmt.Errorf("%w: unknown key %q", ErrIptablesCommandMalformed, flag)
}
return nil
return consumed, nil
}
func parseIPPrefix(value string) (prefix netip.Prefix, err error) {