Files
gluetun/internal/firewall/parse.go
T
2026-02-10 16:19:08 +00:00

257 lines
7.6 KiB
Go

package firewall
import (
"errors"
"fmt"
"net/netip"
"slices"
"strconv"
"strings"
)
// operation is the action requested by an iptables instruction.
type operation uint8

// Operations an iptablesInstruction can carry. opNone is the zero value and
// means no -A/-D/-I/-R flag was parsed; String panics when given it.
const (
	opNone    operation = 0 // no operation flag seen
	opAppend  operation = 1 // -A, --append
	opDelete  operation = 2 // -D, --delete
	opInsert  operation = 3 // -I, --insert
	opReplace operation = 4 // -R, --replace
)
// iptablesInstruction is the parsed form of a single iptables command line,
// as returned by parseIptablesInstruction and rendered back by String.
type iptablesInstruction struct {
	table           string       // defaults to "filter", and can be "nat" for example.
	operation       operation    // opNone until an -A/-D/-I/-R flag is parsed.
	chain           string       // for example INPUT, PREROUTING. Expected non-empty (not enforced by the parser).
	target          string       // for example ACCEPT. Can be empty.
	protocol        string       // "tcp" or "udp" or "" for all protocols.
	inputInterface  string       // for example "tun0" or "" for any interface.
	outputInterface string       // for example "tun0" or "" for any interface.
	source          netip.Prefix // if not valid, then it is unspecified.
	destination     netip.Prefix // if not valid, then it is unspecified.
	destinationPort uint16       // if zero, there is no destination port (--dport).
	toPorts         []uint16     // if empty, there is no redirection (--to-ports).
	ctstate         []string     // if empty, there is no ctstate (--ctstate).
	lineNumber      uint16       // for replace operation, the line number to replace.
}
// setDefaults fills in fields left empty by parsing: a missing table
// becomes "filter".
func (i *iptablesInstruction) setDefaults() {
	if i.table != "" {
		return
	}
	i.table = "filter"
}
// equalToRule reports whether the instruction describes the same rule as
// rule, which was found in the given table and chain. The instruction's
// operation and lineNumber are not compared.
//
// Interfaces and IP prefixes are compared with networkInterfacesEqual and
// ipPrefixesEqual, so an unset instruction field matches the listed "any"
// form ("*", 0.0.0.0/0 or ::/0). Redirection ports are compared in order,
// while ctstate values are compared as a set, ignoring their order.
func (i *iptablesInstruction) equalToRule(table, chain string, rule chainRule) (equal bool) {
	// ctstate is a set of connection states: the listed rule may present them
	// in a different order than the instruction (e.g. RELATED,ESTABLISHED vs
	// ESTABLISHED,RELATED), so compare sorted copies without mutating inputs.
	sameStates := func(a, b []string) bool {
		if len(a) != len(b) {
			return false
		}
		a, b = slices.Clone(a), slices.Clone(b)
		slices.Sort(a)
		slices.Sort(b)
		return slices.Equal(a, b)
	}

	switch {
	case i.table != table:
		return false
	case i.chain != chain:
		return false
	case i.target != rule.target:
		return false
	case i.protocol != rule.protocol:
		return false
	case i.destinationPort != rule.destinationPort:
		return false
	case !slices.Equal(i.toPorts, rule.redirPorts):
		return false
	case !sameStates(i.ctstate, rule.ctstate):
		return false
	case !networkInterfacesEqual(i.inputInterface, rule.inputInterface):
		return false
	case !networkInterfacesEqual(i.outputInterface, rule.outputInterface):
		return false
	case !ipPrefixesEqual(i.source, rule.source):
		return false
	case !ipPrefixesEqual(i.destination, rule.destination):
		return false
	default:
		return true
	}
}
// String renders the instruction as iptables arguments. The table is omitted
// when it is "" or "filter", and every other optional field is omitted when
// unset. It panics if no operation is set.
//
// NOTE(review): --to-ports is written before -j, and --ctstate without any
// -m match. If this output is ever executed by iptables rather than only
// displayed or compared, confirm iptables accepts that ordering.
func (i *iptablesInstruction) String() string {
	var b strings.Builder

	if table := i.table; table != "" && table != "filter" {
		fmt.Fprintf(&b, "-t %s ", table)
	}

	switch i.operation {
	case opAppend:
		fmt.Fprintf(&b, "--append %s ", i.chain)
	case opDelete:
		fmt.Fprintf(&b, "--delete %s ", i.chain)
	case opInsert:
		fmt.Fprintf(&b, "--insert %s ", i.chain)
	case opReplace:
		fmt.Fprintf(&b, "--replace %s %d ", i.chain, i.lineNumber)
	case opNone:
		panic("no operation specified")
	}

	if i.inputInterface != "" {
		fmt.Fprintf(&b, "-i %s ", i.inputInterface)
	}
	if i.outputInterface != "" {
		fmt.Fprintf(&b, "-o %s ", i.outputInterface)
	}
	if i.protocol != "" {
		fmt.Fprintf(&b, "-p %s ", i.protocol)
	}
	if i.source.IsValid() {
		fmt.Fprintf(&b, "-s %s ", i.source.String())
	}
	if i.destination.IsValid() {
		fmt.Fprintf(&b, "-d %s ", i.destination.String())
	}
	if i.destinationPort != 0 {
		fmt.Fprintf(&b, "--dport %d ", i.destinationPort)
	}
	if len(i.ctstate) > 0 {
		fmt.Fprintf(&b, "--ctstate %s ", strings.Join(i.ctstate, ","))
	}
	if len(i.toPorts) > 0 {
		ports := make([]string, len(i.toPorts))
		for idx, port := range i.toPorts {
			ports[idx] = strconv.FormatUint(uint64(port), 10)
		}
		fmt.Fprintf(&b, "--to-ports %s ", strings.Join(ports, ","))
	}
	if i.target != "" {
		fmt.Fprintf(&b, "-j %s ", i.target)
	}

	// Every piece above ends with a space; drop the trailing one.
	return strings.TrimSpace(b.String())
}
// networkInterfacesEqual reports whether an instruction interface matches a
// chain rule interface. An empty instruction interface means "any interface",
// which a chain rule represents as "*".
func networkInterfacesEqual(instruction, chainRule string) bool {
	if instruction == "" && chainRule == "*" {
		return true
	}
	return instruction == chainRule
}
// ipPrefixesEqual reports whether an instruction prefix matches a chain rule
// prefix. An unset (invalid) instruction prefix means "any address" and
// matches a chain rule prefix of length zero with an unspecified address,
// such as 0.0.0.0/0 or ::/0.
func ipPrefixesEqual(instruction, chainRule netip.Prefix) bool {
	if instruction == chainRule {
		return true
	}
	if instruction.IsValid() {
		return false
	}
	coversEverything := chainRule.Bits() == 0 && chainRule.Addr().IsUnspecified()
	return coversEverything
}
// ErrIptablesCommandMalformed is wrapped by parsing errors for an iptables
// instruction that is empty, has a flag missing its value(s), or uses an
// unknown flag. Match it with errors.Is.
var (
	ErrIptablesCommandMalformed = errors.New("iptables command is malformed")
)
// parseIptablesInstruction parses an iptables command line such as
// "-t nat --append PREROUTING -p tcp --dport 80 -j REDIRECT --to-ports 8080"
// into an iptablesInstruction.
//
// Fields are whitespace-separated and handled left to right by
// parseInstructionFlag, so a flag given twice overwrites its earlier value.
// The table defaults to "filter" when no -t/--table flag is present.
//
// It returns an error wrapping ErrIptablesCommandMalformed when s is empty
// or only whitespace. Errors from individual flags are returned prefixed
// with the quoted instruction.
func parseIptablesInstruction(s string) (instruction iptablesInstruction, err error) {
	fields := strings.Fields(s)
	if len(fields) == 0 {
		// Checking the fields rather than s == "" also rejects whitespace-only
		// input. Without this, "   " would yield an instruction with no
		// operation and no chain.
		return iptablesInstruction{}, fmt.Errorf("%w: empty instruction", ErrIptablesCommandMalformed)
	}

	for i := 0; i < len(fields); {
		consumed, flagErr := parseInstructionFlag(fields[i:], &instruction)
		if flagErr != nil {
			return iptablesInstruction{}, fmt.Errorf("parsing %q: %w", s, flagErr)
		}
		i += consumed
	}

	instruction.setDefaults()
	return instruction, nil
}
// parseInstructionFlag parses the flag at fields[0], together with the
// value(s) after it, and stores the result in instruction.
//
// It returns the number of fields used: 3 for -R/--replace (chain and line
// number) and 2 for every other flag. The value of -m/--match is consumed
// but ignored. Missing values and unknown flags produce errors wrapping
// ErrIptablesCommandMalformed. Invalid numbers and IP addresses produce
// wrapped strconv/parseIPPrefix errors instead.
//
// fields must be non-empty. instruction may be partially updated when an
// error is returned.
func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (consumed int, err error) {
	flag := fields[0]
	replacing := flag == "-R" || flag == "--replace"

	// Every flag takes exactly one value, except --replace which takes two.
	consumed = 2
	if replacing {
		consumed = 3
	}
	if len(fields) < consumed {
		if replacing {
			return 0, fmt.Errorf("%w: flag %q requires at least 2 values, but got %s",
				ErrIptablesCommandMalformed, flag, strings.Join(fields, " "))
		}
		return 0, fmt.Errorf("%w: flag %q requires a value, but got none",
			ErrIptablesCommandMalformed, flag)
	}
	value := fields[1]

	// toUint16 parses a base-10 unsigned number that must fit in 16 bits.
	toUint16 := func(s string) (uint16, error) {
		const base, bitSize = 10, 16
		n, parseErr := strconv.ParseUint(s, base, bitSize)
		return uint16(n), parseErr
	}

	switch flag {
	case "-t", "--table":
		instruction.table = value
	case "-A", "--append":
		instruction.operation, instruction.chain = opAppend, value
	case "-D", "--delete":
		instruction.operation, instruction.chain = opDelete, value
	case "-I", "--insert":
		instruction.operation, instruction.chain = opInsert, value
	case "-R", "--replace":
		instruction.operation, instruction.chain = opReplace, value
		lineNumber, parseErr := toUint16(fields[2])
		if parseErr != nil {
			return 0, fmt.Errorf("parsing line number for --replace operation: %w", parseErr)
		}
		instruction.lineNumber = lineNumber
	case "-j", "--jump":
		instruction.target = value
	case "-p", "--protocol":
		instruction.protocol = value
	case "-m", "--match":
		// Match module names are not recorded.
	case "-i", "--in-interface":
		instruction.inputInterface = value
	case "-o", "--out-interface":
		instruction.outputInterface = value
	case "-s", "--source":
		if instruction.source, err = parseIPPrefix(value); err != nil {
			return 0, fmt.Errorf("parsing source IP CIDR: %w", err)
		}
	case "-d", "--destination":
		if instruction.destination, err = parseIPPrefix(value); err != nil {
			return 0, fmt.Errorf("parsing destination IP CIDR: %w", err)
		}
	case "--dport":
		port, parseErr := toUint16(value)
		if parseErr != nil {
			return 0, fmt.Errorf("parsing destination port: %w", parseErr)
		}
		instruction.destinationPort = port
	case "--ctstate":
		instruction.ctstate = strings.Split(value, ",")
	case "--to-ports":
		rawPorts := strings.Split(value, ",")
		instruction.toPorts = make([]uint16, len(rawPorts))
		for idx, raw := range rawPorts {
			port, parseErr := toUint16(raw)
			if parseErr != nil {
				return 0, fmt.Errorf("parsing port redirection: %w", parseErr)
			}
			instruction.toPorts[idx] = port
		}
	default:
		return 0, fmt.Errorf("%w: unknown flag %q", ErrIptablesCommandMalformed, flag)
	}

	return consumed, nil
}
// parseIPPrefix parses value as either a CIDR prefix ("10.0.0.0/8") or a
// bare IP address ("10.0.0.1", "::1"). A bare address becomes a single-host
// prefix: /32 for IPv4, /128 for IPv6.
//
// CIDR prefixes are returned in canonical form with host bits cleared, so
// "10.0.0.5/24" parses to 10.0.0.0/24. This makes the result compare equal
// (via ==) to the same network written in its canonical form, which is how
// iptables lists it.
//
// Errors from netip.ParsePrefix are returned as-is, since they already name
// the function and the input. Address parsing errors are wrapped with context.
func parseIPPrefix(value string) (prefix netip.Prefix, err error) {
	if strings.Contains(value, "/") {
		prefix, err = netip.ParsePrefix(value)
		if err != nil {
			return netip.Prefix{}, err
		}
		return prefix.Masked(), nil
	}

	ip, err := netip.ParseAddr(value)
	if err != nil {
		return netip.Prefix{}, fmt.Errorf("parsing IP address: %w", err)
	}
	return netip.PrefixFrom(ip, ip.BitLen()), nil
}