package firewall import ( "errors" "fmt" "net/netip" "slices" "strconv" "strings" ) type operation uint8 const ( opNone operation = iota opAppend opDelete opInsert opReplace ) type iptablesInstruction struct { table string // defaults to "filter", and can be "nat" for example. operation operation chain string // for example INPUT, PREROUTING. Cannot be empty. target string // for example ACCEPT. Can be empty. protocol string // "tcp" or "udp" or "" for all protocols. inputInterface string // for example "tun0" or "" for any interface. outputInterface string // for example "tun0" or "" for any interface. source netip.Prefix // if not valid, then it is unspecified. destination netip.Prefix // if not valid, then it is unspecified. destinationPort uint16 // if zero, there is no destination port toPorts []uint16 // if empty, there is no redirection ctstate []string // if empty, there is no ctstate lineNumber uint16 // for replace operation, the line number to replace } func (i *iptablesInstruction) setDefaults() { if i.table == "" { i.table = "filter" } } // equalToRule ignores the append boolean flag of the instruction to compare against the rule. func (i *iptablesInstruction) equalToRule(table, chain string, rule chainRule) (equal bool) { switch { case i.table != table: return false case i.chain != chain: return false case i.target != rule.target: return false case i.protocol != rule.protocol: return false case i.destinationPort != rule.destinationPort: return false case !slices.Equal(i.toPorts, rule.redirPorts): return false case !slices.Equal(i.ctstate, rule.ctstate): return false case !networkInterfacesEqual(i.inputInterface, rule.inputInterface): return false case !networkInterfacesEqual(i.outputInterface, rule.outputInterface): return false case !ipPrefixesEqual(i.source, rule.source): return false case !ipPrefixesEqual(i.destination, rule.destination): return false default: return true } } func (i *iptablesInstruction) String() string { var sb strings.Builder if i.table != "" && i.table != "filter" { sb.WriteString(fmt.Sprintf("-t %s ", i.table)) } switch i.operation { case opNone: panic("no operation specified") case opAppend: sb.WriteString(fmt.Sprintf("--append %s ", i.chain)) case opDelete: sb.WriteString(fmt.Sprintf("--delete %s ", i.chain)) case opInsert: sb.WriteString(fmt.Sprintf("--insert %s ", i.chain)) case opReplace: sb.WriteString(fmt.Sprintf("--replace %s %d ", i.chain, i.lineNumber)) } if i.inputInterface != "" { sb.WriteString(fmt.Sprintf("-i %s ", i.inputInterface)) } if i.outputInterface != "" { sb.WriteString(fmt.Sprintf("-o %s ", i.outputInterface)) } if i.protocol != "" { sb.WriteString(fmt.Sprintf("-p %s ", i.protocol)) } if i.source.IsValid() { sb.WriteString(fmt.Sprintf("-s %s ", i.source.String())) } if i.destination.IsValid() { sb.WriteString(fmt.Sprintf("-d %s ", i.destination.String())) } if i.destinationPort != 0 { sb.WriteString(fmt.Sprintf("--dport %d ", i.destinationPort)) } if len(i.ctstate) > 0 { sb.WriteString(fmt.Sprintf("--ctstate %s ", strings.Join(i.ctstate, ","))) } if len(i.toPorts) > 0 { var portStrings []string for _, port := range i.toPorts { portStrings = append(portStrings, strconv.FormatUint(uint64(port), 10)) } sb.WriteString(fmt.Sprintf("--to-ports %s ", strings.Join(portStrings, ","))) } if i.target != "" { sb.WriteString(fmt.Sprintf("-j %s ", i.target)) } return strings.TrimSpace(sb.String()) } // instruction can be "" which equivalent to the "*" chain rule interface. func networkInterfacesEqual(instruction, chainRule string) bool { return instruction == chainRule || (instruction == "" && chainRule == "*") } func ipPrefixesEqual(instruction, chainRule netip.Prefix) bool { return instruction == chainRule || (!instruction.IsValid() && chainRule.Bits() == 0 && chainRule.Addr().IsUnspecified()) } var ErrIptablesCommandMalformed = errors.New("iptables command is malformed") func parseIptablesInstruction(s string) (instruction iptablesInstruction, err error) { if s == "" { return iptablesInstruction{}, fmt.Errorf("%w: empty instruction", ErrIptablesCommandMalformed) } fields := strings.Fields(s) i := 0 for i < len(fields) { consumed, err := parseInstructionFlag(fields[i:], &instruction) if err != nil { return iptablesInstruction{}, fmt.Errorf("parsing %q: %w", s, err) } i += consumed } instruction.setDefaults() return instruction, nil } func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (consumed int, err error) { flag := fields[0] // All flags use one value after the flag, except the following: switch flag { case "-R", "--replace": const expected = 3 if len(fields) < expected { return 0, fmt.Errorf("%w: flag %q requires at least 2 values, but got %s", ErrIptablesCommandMalformed, flag, strings.Join(fields, " ")) } consumed = expected default: const expected = 2 if len(fields) < expected { return 0, fmt.Errorf("%w: flag %q requires a value, but got none", ErrIptablesCommandMalformed, flag) } consumed = expected } value := fields[1] switch flag { case "-t", "--table": instruction.table = value case "-D", "--delete": instruction.operation = opDelete instruction.chain = value case "-A", "--append": instruction.operation = opAppend instruction.chain = value case "-I", "--insert": instruction.operation = opInsert instruction.chain = value case "-R", "--replace": instruction.operation = opReplace instruction.chain = value const base, bits = 10, 16 n, err := strconv.ParseUint(fields[2], base, bits) if err != nil { return 0, fmt.Errorf("parsing line number for --replace operation: %w", err) } instruction.lineNumber = uint16(n) case "-j", "--jump": instruction.target = value case "-p", "--protocol": instruction.protocol = value case "-m", "--match": // ignore match case "-i", "--in-interface": instruction.inputInterface = value case "-o", "--out-interface": instruction.outputInterface = value case "-s", "--source": instruction.source, err = parseIPPrefix(value) if err != nil { return 0, fmt.Errorf("parsing source IP CIDR: %w", err) } case "-d", "--destination": instruction.destination, err = parseIPPrefix(value) if err != nil { return 0, fmt.Errorf("parsing destination IP CIDR: %w", err) } case "--dport": const base, bitLength = 10, 16 destinationPort, err := strconv.ParseUint(value, base, bitLength) if err != nil { return 0, fmt.Errorf("parsing destination port: %w", err) } instruction.destinationPort = uint16(destinationPort) case "--ctstate": instruction.ctstate = strings.Split(value, ",") case "--to-ports": portStrings := strings.Split(value, ",") instruction.toPorts = make([]uint16, len(portStrings)) for i, portString := range portStrings { const base, bitLength = 10, 16 port, err := strconv.ParseUint(portString, base, bitLength) if err != nil { return 0, fmt.Errorf("parsing port redirection: %w", err) } instruction.toPorts[i] = uint16(port) } default: return 0, fmt.Errorf("%w: unknown flag %q", ErrIptablesCommandMalformed, flag) } return consumed, nil } func parseIPPrefix(value string) (prefix netip.Prefix, err error) { slashIndex := strings.Index(value, "/") if slashIndex >= 0 { return netip.ParsePrefix(value) } ip, err := netip.ParseAddr(value) if err != nil { return netip.Prefix{}, fmt.Errorf("parsing IP address: %w", err) } return netip.PrefixFrom(ip, ip.BitLen()), nil }