chore(firewall): support TCP flags for future changes

This commit is contained in:
Quentin McGaw
2026-02-17 14:15:15 +00:00
parent 36dfd5b631
commit d43eb1658f
5 changed files with 222 additions and 32 deletions
+114 -13
View File
@@ -30,6 +30,7 @@ type chainRule struct {
destinationPort uint16 // Not specified if set to zero.
redirPorts []uint16 // Not specified if empty.
ctstate []string // for example ["RELATED","ESTABLISHED"]. Can be empty.
tcpFlags tcpFlags
}
var ErrChainListMalformed = errors.New("iptables chain list output is malformed")
@@ -241,19 +242,23 @@ func parseChainRuleField(fieldIndex int, field string, rule *chainRule) (err err
}
func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err error) {
for i := 0; i < len(optionalFields); i++ {
key := optionalFields[i]
switch key {
case "tcp", "udp":
i := 0
for i < len(optionalFields) {
switch optionalFields[i] {
case "udp":
i++
value := optionalFields[i]
value = strings.TrimPrefix(value, "dpt:")
const base, bitLength = 10, 16
destinationPort, err := strconv.ParseUint(value, base, bitLength)
consumed, err := parseUDPOptional(optionalFields[i:], rule)
if err != nil {
return fmt.Errorf("parsing destination port %q: %w", value, err)
return fmt.Errorf("parsing UDP optional fields: %w", err)
}
rule.destinationPort = uint16(destinationPort)
i += consumed
case "tcp":
i++
consumed, err := parseTCPOptional(optionalFields[i:], rule)
if err != nil {
return fmt.Errorf("parsing TCP optional fields: %w", err)
}
i += consumed
case "redir":
i++
switch optionalFields[i] {
@@ -264,20 +269,116 @@ func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err
return fmt.Errorf("parsing redirection ports: %w", err)
}
rule.redirPorts = ports
i++
default:
return fmt.Errorf("%w: unexpected optional field: %s",
ErrChainRuleMalformed, optionalFields[i])
return fmt.Errorf("%w: unexpected %q after redir",
ErrChainRuleMalformed, optionalFields[1])
}
case "ctstate":
i++
rule.ctstate = strings.Split(optionalFields[i], ",")
i++
default:
return fmt.Errorf("%w: unexpected optional field: %s", ErrChainRuleMalformed, key)
return fmt.Errorf("%w: unexpected optional field: %s",
ErrChainRuleMalformed, optionalFields[i])
}
}
return nil
}
var errUDPOptionalUnknown = errors.New("unknown UDP optional field")
func parseUDPOptional(optionalFields []string, rule *chainRule) (consumed int, err error) {
for _, value := range optionalFields {
if !strings.ContainsRune(value, ':') {
// no longer a UDP-associated option
return consumed, nil
}
switch {
case strings.HasPrefix(value, "dpt:"):
rule.destinationPort, err = parseDestinationPort(value)
if err != nil {
return 0, fmt.Errorf("parsing destination port: %w", err)
}
consumed++
default:
return 0, fmt.Errorf("%w: %s", errUDPOptionalUnknown, value)
}
}
return consumed, nil
}
var errTCPOptionalUnknown = errors.New("unknown TCP optional field")
func parseTCPOptional(optionalFields []string, rule *chainRule) (consumed int, err error) {
for _, value := range optionalFields {
if !strings.ContainsRune(value, ':') {
// no longer a TCP-associated option
return consumed, nil
}
switch {
case strings.HasPrefix(value, "dpt:"):
rule.destinationPort, err = parseDestinationPort(value)
if err != nil {
return 0, fmt.Errorf("parsing destination port: %w", err)
}
consumed++
case strings.HasPrefix(value, "flags:"):
rule.tcpFlags, err = parseTCPFlags(value)
if err != nil {
return 0, fmt.Errorf("parsing TCP flags: %w", err)
}
consumed++
default:
return 0, fmt.Errorf("%w: %s", errTCPOptionalUnknown, value)
}
}
return consumed, nil
}
func parseDestinationPort(value string) (port uint16, err error) {
value = strings.TrimPrefix(value, "dpt:")
const base, bitLength = 10, 16
destinationPort, err := strconv.ParseUint(value, base, bitLength)
if err != nil {
return 0, fmt.Errorf("parsing %q: %w", value, err)
}
return uint16(destinationPort), nil
}
var errTCPFlagsMalformed = errors.New("TCP flags are malformed")
func parseTCPFlags(value string) (tcpFlags, error) {
value = strings.TrimPrefix(value, "flags:")
fields := strings.Split(value, "/")
const expectedFields = 2
if len(fields) != expectedFields {
return tcpFlags{}, fmt.Errorf("%w: expected format 'flags:<mask>/<comparison>' in %q",
errTCPFlagsMalformed, value)
}
maskFlags := strings.Split(fields[0], ",")
mask := make([]tcpFlag, len(maskFlags))
var err error
for i, maskFlag := range maskFlags {
mask[i], err = parseTCPFlag(maskFlag)
if err != nil {
return tcpFlags{}, fmt.Errorf("parsing TCP mask flags: %w", err)
}
}
comparisonFlags := strings.Split(fields[1], ",")
comparison := make([]tcpFlag, len(comparisonFlags))
for i, comparisonFlag := range comparisonFlags {
comparison[i], err = parseTCPFlag(comparisonFlag)
if err != nil {
return tcpFlags{}, fmt.Errorf("parsing TCP comparison flags: %w", err)
}
}
return tcpFlags{
mask: mask,
comparison: comparison,
}, nil
}
func parsePortsCSV(s string) (ports []uint16, err error) {
if s == "" {
return nil, nil