mirror of
https://github.com/qdm12/gluetun.git
synced 2026-05-10 04:30:20 +02:00
chore(firewall): split apart iptables specific code in internal/firewall/iptables
This commit is contained in:
@@ -0,0 +1,262 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type iptablesInstruction struct {
|
||||
table string // defaults to "filter", and can be "nat" for example.
|
||||
append bool
|
||||
chain string // for example INPUT, PREROUTING. Cannot be empty.
|
||||
target string // for example ACCEPT. Can be empty.
|
||||
protocol string // "tcp" or "udp" or "" for all protocols.
|
||||
inputInterface string // for example "tun0" or "" for any interface.
|
||||
outputInterface string // for example "tun0" or "" for any interface.
|
||||
source netip.Prefix // if not valid, then it is unspecified.
|
||||
sourcePort uint16 // if zero, there is no source port
|
||||
destination netip.Prefix // if not valid, then it is unspecified.
|
||||
destinationPort uint16 // if zero, there is no destination port
|
||||
toPorts []uint16 // if empty, there is no redirection
|
||||
ctstate []string // if empty, there is no ctstate
|
||||
tcpFlags tcpFlags
|
||||
mark mark
|
||||
}
|
||||
|
||||
func (i *iptablesInstruction) setDefaults() {
|
||||
if i.table == "" {
|
||||
i.table = "filter"
|
||||
}
|
||||
}
|
||||
|
||||
// equalToRule ignores the append boolean flag of the instruction to compare against the rule.
|
||||
func (i *iptablesInstruction) equalToRule(table, chain string, rule chainRule) (equal bool) {
|
||||
switch {
|
||||
case i.table != table:
|
||||
return false
|
||||
case i.chain != chain:
|
||||
return false
|
||||
case i.target != rule.target:
|
||||
return false
|
||||
case i.protocol != rule.protocol:
|
||||
return false
|
||||
case i.destinationPort != rule.destinationPort:
|
||||
return false
|
||||
case i.sourcePort != rule.sourcePort:
|
||||
return false
|
||||
case !slices.Equal(i.toPorts, rule.redirPorts):
|
||||
return false
|
||||
case !slices.Equal(i.ctstate, rule.ctstate):
|
||||
return false
|
||||
case !networkInterfacesEqual(i.inputInterface, rule.inputInterface):
|
||||
return false
|
||||
case !networkInterfacesEqual(i.outputInterface, rule.outputInterface):
|
||||
return false
|
||||
case !ipPrefixesEqual(i.source, rule.source):
|
||||
return false
|
||||
case !ipPrefixesEqual(i.destination, rule.destination):
|
||||
return false
|
||||
case !slices.Equal(i.tcpFlags.mask, rule.tcpFlags.mask) ||
|
||||
!slices.Equal(i.tcpFlags.comparison, rule.tcpFlags.comparison):
|
||||
return false
|
||||
case i.mark != rule.mark:
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// instruction can be "" which equivalent to the "*" chain rule interface.
|
||||
func networkInterfacesEqual(instruction, chainRule string) bool {
|
||||
return instruction == chainRule || (instruction == "" && chainRule == "*")
|
||||
}
|
||||
|
||||
func ipPrefixesEqual(instruction, chainRule netip.Prefix) bool {
|
||||
return instruction == chainRule ||
|
||||
(!instruction.IsValid() && chainRule.Bits() == 0 && chainRule.Addr().IsUnspecified())
|
||||
}
|
||||
|
||||
var ErrIptablesCommandMalformed = errors.New("iptables command is malformed")
|
||||
|
||||
func parseIptablesInstruction(s string) (instruction iptablesInstruction, err error) {
|
||||
if s == "" {
|
||||
return iptablesInstruction{}, fmt.Errorf("%w: empty instruction", ErrIptablesCommandMalformed)
|
||||
}
|
||||
fields := strings.Fields(s)
|
||||
|
||||
i := 0
|
||||
for i < len(fields) {
|
||||
consumed, err := parseInstructionFlag(fields[i:], &instruction)
|
||||
if err != nil {
|
||||
return iptablesInstruction{}, fmt.Errorf("parsing %q: %w", s, err)
|
||||
}
|
||||
i += consumed
|
||||
}
|
||||
|
||||
instruction.setDefaults()
|
||||
return instruction, nil
|
||||
}
|
||||
|
||||
func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (consumed int, err error) {
|
||||
consumed, err = preCheckInstructionFields(fields)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
flag := fields[0]
|
||||
value := fields[1]
|
||||
|
||||
switch flag {
|
||||
case "-t", "--table":
|
||||
instruction.table = value
|
||||
case "-D", "--delete":
|
||||
instruction.append = false
|
||||
instruction.chain = value
|
||||
case "-A", "--append":
|
||||
instruction.append = true
|
||||
instruction.chain = value
|
||||
case "-j", "--jump":
|
||||
instruction.target = value
|
||||
case "-p", "--protocol":
|
||||
instruction.protocol = value
|
||||
case "-m", "--match":
|
||||
consumed, err = parseMatchModule(fields, instruction)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing match module: %w", err)
|
||||
}
|
||||
case "--mark":
|
||||
const base = 0 // auto-detect
|
||||
const bits = 32
|
||||
value, err := strconv.ParseUint(value, base, bits)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing mark value %q: %w", fields[2], err)
|
||||
}
|
||||
instruction.mark.value = uint(value)
|
||||
case "-i", "--in-interface":
|
||||
instruction.inputInterface = value
|
||||
case "-o", "--out-interface":
|
||||
instruction.outputInterface = value
|
||||
case "-s", "--source":
|
||||
instruction.source, err = parseIPPrefix(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing source IP CIDR: %w", err)
|
||||
}
|
||||
case "--sport":
|
||||
instruction.sourcePort, err = parsePort(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing source port: %w", err)
|
||||
}
|
||||
case "-d", "--destination":
|
||||
instruction.destination, err = parseIPPrefix(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing destination IP CIDR: %w", err)
|
||||
}
|
||||
case "--dport":
|
||||
instruction.destinationPort, err = parsePort(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing destination port: %w", err)
|
||||
}
|
||||
case "--ctstate":
|
||||
instruction.ctstate = strings.Split(value, ",")
|
||||
case "--to-ports":
|
||||
instruction.toPorts, err = parseToPorts(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing port redirection: %w", err)
|
||||
}
|
||||
case "--tcp-flags":
|
||||
mask, comparison := value, fields[2]
|
||||
instruction.tcpFlags, err = parseTCPFlags(mask + "/" + comparison)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing TCP flags: %w", err)
|
||||
}
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: unknown key %q", ErrIptablesCommandMalformed, flag)
|
||||
}
|
||||
return consumed, nil
|
||||
}
|
||||
|
||||
func preCheckInstructionFields(fields []string) (consumed int, err error) {
|
||||
flag := fields[0]
|
||||
// All flags use one value after the flag, except the following:
|
||||
switch flag {
|
||||
case "--tcp-flags": // -m can have 1 or 2 values
|
||||
const expected = 3
|
||||
if len(fields) < expected {
|
||||
return 0, fmt.Errorf("%w: flag %q requires at least 2 values, but got %s",
|
||||
ErrIptablesCommandMalformed, flag, strings.Join(fields, " "))
|
||||
}
|
||||
return expected, nil
|
||||
default:
|
||||
const expected = 2
|
||||
if len(fields) < expected {
|
||||
return 0, fmt.Errorf("%w: flag %q requires a value, but got none",
|
||||
ErrIptablesCommandMalformed, flag)
|
||||
}
|
||||
return expected, nil
|
||||
}
|
||||
}
|
||||
|
||||
func parseIPPrefix(value string) (prefix netip.Prefix, err error) {
|
||||
slashIndex := strings.Index(value, "/")
|
||||
if slashIndex >= 0 {
|
||||
return netip.ParsePrefix(value)
|
||||
}
|
||||
|
||||
ip, err := netip.ParseAddr(value)
|
||||
if err != nil {
|
||||
return netip.Prefix{}, fmt.Errorf("parsing IP address: %w", err)
|
||||
}
|
||||
return netip.PrefixFrom(ip, ip.BitLen()), nil
|
||||
}
|
||||
|
||||
func parsePort(value string) (port uint16, err error) {
|
||||
const base, bitLength = 10, 16
|
||||
portValue, err := strconv.ParseUint(value, base, bitLength)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return uint16(portValue), nil
|
||||
}
|
||||
|
||||
func parseMatchModule(fields []string, instruction *iptablesInstruction) (
|
||||
consumed int, err error,
|
||||
) {
|
||||
_ = fields[consumed] // -m or --match flag already detected
|
||||
consumed++
|
||||
switch fields[consumed] {
|
||||
case "tcp", "udp":
|
||||
consumed++
|
||||
// for now ignore the protocol match since it's auto-loaded
|
||||
// when parsing the -p/--protocol flag, and we don't need to
|
||||
// parse it twice.
|
||||
case "mark":
|
||||
consumed++
|
||||
switch fields[consumed] {
|
||||
case "!":
|
||||
consumed++
|
||||
instruction.mark.invert = true
|
||||
default:
|
||||
return consumed, fmt.Errorf("%w: unsupported match mark with value: %s",
|
||||
ErrIptablesCommandMalformed, fields[2])
|
||||
}
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: unknown match value: %s",
|
||||
ErrIptablesCommandMalformed, fields[consumed])
|
||||
}
|
||||
return consumed, nil
|
||||
}
|
||||
|
||||
func parseToPorts(value string) (toPorts []uint16, err error) {
|
||||
portStrings := strings.Split(value, ",")
|
||||
toPorts = make([]uint16, len(portStrings))
|
||||
for i, portString := range portStrings {
|
||||
toPorts[i], err = parsePort(portString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return toPorts, nil
|
||||
}
|
||||
Reference in New Issue
Block a user