mirror of
https://github.com/qdm12/gluetun.git
synced 2026-05-06 20:10:11 +02:00
chore(firewall): support TCP flags for future changes
This commit is contained in:
@@ -69,8 +69,8 @@ func Test_deleteIPTablesRule(t *testing.T) {
|
||||
"invalid_instruction": {
|
||||
instruction: "invalid",
|
||||
errWrapped: ErrIptablesCommandMalformed,
|
||||
errMessage: "parsing iptables command: iptables command is malformed: " +
|
||||
"fields count 1 is not even: \"invalid\"",
|
||||
errMessage: "parsing iptables command: parsing \"invalid\": " +
|
||||
"iptables command is malformed: flag \"invalid\" requires a value, but got none",
|
||||
},
|
||||
"list_error": {
|
||||
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
|
||||
|
||||
+114
-13
@@ -30,6 +30,7 @@ type chainRule struct {
|
||||
destinationPort uint16 // Not specified if set to zero.
|
||||
redirPorts []uint16 // Not specified if empty.
|
||||
ctstate []string // for example ["RELATED","ESTABLISHED"]. Can be empty.
|
||||
tcpFlags tcpFlags
|
||||
}
|
||||
|
||||
var ErrChainListMalformed = errors.New("iptables chain list output is malformed")
|
||||
@@ -241,19 +242,23 @@ func parseChainRuleField(fieldIndex int, field string, rule *chainRule) (err err
|
||||
}
|
||||
|
||||
func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err error) {
|
||||
for i := 0; i < len(optionalFields); i++ {
|
||||
key := optionalFields[i]
|
||||
switch key {
|
||||
case "tcp", "udp":
|
||||
i := 0
|
||||
for i < len(optionalFields) {
|
||||
switch optionalFields[i] {
|
||||
case "udp":
|
||||
i++
|
||||
value := optionalFields[i]
|
||||
value = strings.TrimPrefix(value, "dpt:")
|
||||
const base, bitLength = 10, 16
|
||||
destinationPort, err := strconv.ParseUint(value, base, bitLength)
|
||||
consumed, err := parseUDPOptional(optionalFields[i:], rule)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing destination port %q: %w", value, err)
|
||||
return fmt.Errorf("parsing UDP optional fields: %w", err)
|
||||
}
|
||||
rule.destinationPort = uint16(destinationPort)
|
||||
i += consumed
|
||||
case "tcp":
|
||||
i++
|
||||
consumed, err := parseTCPOptional(optionalFields[i:], rule)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing TCP optional fields: %w", err)
|
||||
}
|
||||
i += consumed
|
||||
case "redir":
|
||||
i++
|
||||
switch optionalFields[i] {
|
||||
@@ -264,20 +269,116 @@ func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err
|
||||
return fmt.Errorf("parsing redirection ports: %w", err)
|
||||
}
|
||||
rule.redirPorts = ports
|
||||
i++
|
||||
default:
|
||||
return fmt.Errorf("%w: unexpected optional field: %s",
|
||||
ErrChainRuleMalformed, optionalFields[i])
|
||||
return fmt.Errorf("%w: unexpected %q after redir",
|
||||
ErrChainRuleMalformed, optionalFields[1])
|
||||
}
|
||||
case "ctstate":
|
||||
i++
|
||||
rule.ctstate = strings.Split(optionalFields[i], ",")
|
||||
i++
|
||||
default:
|
||||
return fmt.Errorf("%w: unexpected optional field: %s", ErrChainRuleMalformed, key)
|
||||
return fmt.Errorf("%w: unexpected optional field: %s",
|
||||
ErrChainRuleMalformed, optionalFields[i])
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var errUDPOptionalUnknown = errors.New("unknown UDP optional field")
|
||||
|
||||
func parseUDPOptional(optionalFields []string, rule *chainRule) (consumed int, err error) {
|
||||
for _, value := range optionalFields {
|
||||
if !strings.ContainsRune(value, ':') {
|
||||
// no longer a UDP-associated option
|
||||
return consumed, nil
|
||||
}
|
||||
switch {
|
||||
case strings.HasPrefix(value, "dpt:"):
|
||||
rule.destinationPort, err = parseDestinationPort(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing destination port: %w", err)
|
||||
}
|
||||
consumed++
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: %s", errUDPOptionalUnknown, value)
|
||||
}
|
||||
}
|
||||
return consumed, nil
|
||||
}
|
||||
|
||||
var errTCPOptionalUnknown = errors.New("unknown TCP optional field")
|
||||
|
||||
func parseTCPOptional(optionalFields []string, rule *chainRule) (consumed int, err error) {
|
||||
for _, value := range optionalFields {
|
||||
if !strings.ContainsRune(value, ':') {
|
||||
// no longer a TCP-associated option
|
||||
return consumed, nil
|
||||
}
|
||||
switch {
|
||||
case strings.HasPrefix(value, "dpt:"):
|
||||
rule.destinationPort, err = parseDestinationPort(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing destination port: %w", err)
|
||||
}
|
||||
consumed++
|
||||
case strings.HasPrefix(value, "flags:"):
|
||||
rule.tcpFlags, err = parseTCPFlags(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing TCP flags: %w", err)
|
||||
}
|
||||
consumed++
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: %s", errTCPOptionalUnknown, value)
|
||||
}
|
||||
}
|
||||
return consumed, nil
|
||||
}
|
||||
|
||||
func parseDestinationPort(value string) (port uint16, err error) {
|
||||
value = strings.TrimPrefix(value, "dpt:")
|
||||
const base, bitLength = 10, 16
|
||||
destinationPort, err := strconv.ParseUint(value, base, bitLength)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing %q: %w", value, err)
|
||||
}
|
||||
return uint16(destinationPort), nil
|
||||
}
|
||||
|
||||
var errTCPFlagsMalformed = errors.New("TCP flags are malformed")
|
||||
|
||||
func parseTCPFlags(value string) (tcpFlags, error) {
|
||||
value = strings.TrimPrefix(value, "flags:")
|
||||
fields := strings.Split(value, "/")
|
||||
const expectedFields = 2
|
||||
if len(fields) != expectedFields {
|
||||
return tcpFlags{}, fmt.Errorf("%w: expected format 'flags:<mask>/<comparison>' in %q",
|
||||
errTCPFlagsMalformed, value)
|
||||
}
|
||||
maskFlags := strings.Split(fields[0], ",")
|
||||
mask := make([]tcpFlag, len(maskFlags))
|
||||
var err error
|
||||
for i, maskFlag := range maskFlags {
|
||||
mask[i], err = parseTCPFlag(maskFlag)
|
||||
if err != nil {
|
||||
return tcpFlags{}, fmt.Errorf("parsing TCP mask flags: %w", err)
|
||||
}
|
||||
}
|
||||
comparisonFlags := strings.Split(fields[1], ",")
|
||||
comparison := make([]tcpFlag, len(comparisonFlags))
|
||||
for i, comparisonFlag := range comparisonFlags {
|
||||
comparison[i], err = parseTCPFlag(comparisonFlag)
|
||||
if err != nil {
|
||||
return tcpFlags{}, fmt.Errorf("parsing TCP comparison flags: %w", err)
|
||||
}
|
||||
}
|
||||
return tcpFlags{
|
||||
mask: mask,
|
||||
comparison: comparison,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parsePortsCSV(s string) (ports []uint16, err error) {
|
||||
if s == "" {
|
||||
return nil, nil
|
||||
|
||||
+43
-16
@@ -22,6 +22,7 @@ type iptablesInstruction struct {
|
||||
destinationPort uint16 // if zero, there is no destination port
|
||||
toPorts []uint16 // if empty, there is no redirection
|
||||
ctstate []string // if empty, there is no ctstate
|
||||
tcpFlags tcpFlags
|
||||
}
|
||||
|
||||
func (i *iptablesInstruction) setDefaults() {
|
||||
@@ -55,6 +56,9 @@ func (i *iptablesInstruction) equalToRule(table, chain string, rule chainRule) (
|
||||
return false
|
||||
case !ipPrefixesEqual(i.destination, rule.destination):
|
||||
return false
|
||||
case !slices.Equal(i.tcpFlags.mask, rule.tcpFlags.mask) ||
|
||||
!slices.Equal(i.tcpFlags.comparison, rule.tcpFlags.comparison):
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
@@ -77,26 +81,43 @@ func parseIptablesInstruction(s string) (instruction iptablesInstruction, err er
|
||||
return iptablesInstruction{}, fmt.Errorf("%w: empty instruction", ErrIptablesCommandMalformed)
|
||||
}
|
||||
fields := strings.Fields(s)
|
||||
if len(fields)%2 != 0 {
|
||||
return iptablesInstruction{}, fmt.Errorf("%w: fields count %d is not even: %q",
|
||||
ErrIptablesCommandMalformed, len(fields), s)
|
||||
}
|
||||
|
||||
for i := 0; i < len(fields); i += 2 {
|
||||
key := fields[i]
|
||||
value := fields[i+1]
|
||||
err = parseInstructionFlag(key, value, &instruction)
|
||||
i := 0
|
||||
for i < len(fields) {
|
||||
consumed, err := parseInstructionFlag(fields[i:], &instruction)
|
||||
if err != nil {
|
||||
return iptablesInstruction{}, fmt.Errorf("parsing %q: %w", s, err)
|
||||
}
|
||||
i += consumed
|
||||
}
|
||||
|
||||
instruction.setDefaults()
|
||||
return instruction, nil
|
||||
}
|
||||
|
||||
func parseInstructionFlag(key, value string, instruction *iptablesInstruction) (err error) {
|
||||
switch key {
|
||||
func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (consumed int, err error) {
|
||||
flag := fields[0]
|
||||
|
||||
// All flags use one value after the flag, except the following:
|
||||
switch flag {
|
||||
case "--tcp-flags":
|
||||
const expected = 3
|
||||
if len(fields) < expected {
|
||||
return 0, fmt.Errorf("%w: flag %q requires at least 2 values, but got %s",
|
||||
ErrIptablesCommandMalformed, flag, strings.Join(fields, " "))
|
||||
}
|
||||
consumed = expected
|
||||
default:
|
||||
const expected = 2
|
||||
if len(fields) < expected {
|
||||
return 0, fmt.Errorf("%w: flag %q requires a value, but got none",
|
||||
ErrIptablesCommandMalformed, flag)
|
||||
}
|
||||
consumed = expected
|
||||
}
|
||||
value := fields[1]
|
||||
|
||||
switch flag {
|
||||
case "-t", "--table":
|
||||
instruction.table = value
|
||||
case "-D", "--delete":
|
||||
@@ -117,18 +138,18 @@ func parseInstructionFlag(key, value string, instruction *iptablesInstruction) (
|
||||
case "-s", "--source":
|
||||
instruction.source, err = parseIPPrefix(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing source IP CIDR: %w", err)
|
||||
return 0, fmt.Errorf("parsing source IP CIDR: %w", err)
|
||||
}
|
||||
case "-d", "--destination":
|
||||
instruction.destination, err = parseIPPrefix(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing destination IP CIDR: %w", err)
|
||||
return 0, fmt.Errorf("parsing destination IP CIDR: %w", err)
|
||||
}
|
||||
case "--dport":
|
||||
const base, bitLength = 10, 16
|
||||
destinationPort, err := strconv.ParseUint(value, base, bitLength)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing destination port: %w", err)
|
||||
return 0, fmt.Errorf("parsing destination port: %w", err)
|
||||
}
|
||||
instruction.destinationPort = uint16(destinationPort)
|
||||
case "--ctstate":
|
||||
@@ -140,14 +161,20 @@ func parseInstructionFlag(key, value string, instruction *iptablesInstruction) (
|
||||
const base, bitLength = 10, 16
|
||||
port, err := strconv.ParseUint(portString, base, bitLength)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing port redirection: %w", err)
|
||||
return 0, fmt.Errorf("parsing port redirection: %w", err)
|
||||
}
|
||||
instruction.toPorts[i] = uint16(port)
|
||||
}
|
||||
case "--tcp-flags":
|
||||
mask, comparison := value, fields[2]
|
||||
instruction.tcpFlags, err = parseTCPFlags(mask + "/" + comparison)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing TCP flags: %w", err)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("%w: unknown key %q", ErrIptablesCommandMalformed, key)
|
||||
return 0, fmt.Errorf("%w: unknown key %q", ErrIptablesCommandMalformed, flag)
|
||||
}
|
||||
return nil
|
||||
return consumed, nil
|
||||
}
|
||||
|
||||
func parseIPPrefix(value string) (prefix netip.Prefix, err error) {
|
||||
|
||||
@@ -23,7 +23,7 @@ func Test_parseIptablesInstruction(t *testing.T) {
|
||||
"uneven_fields": {
|
||||
s: "-A",
|
||||
errWrapped: ErrIptablesCommandMalformed,
|
||||
errMessage: "iptables command is malformed: fields count 1 is not even: \"-A\"",
|
||||
errMessage: "parsing \"-A\": iptables command is malformed: flag \"-A\" requires a value, but got none",
|
||||
},
|
||||
"unknown_key": {
|
||||
s: "-x something",
|
||||
|
||||
@@ -0,0 +1,62 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type tcpFlags struct {
|
||||
mask []tcpFlag
|
||||
comparison []tcpFlag
|
||||
}
|
||||
|
||||
type tcpFlag uint8
|
||||
|
||||
const (
|
||||
tcpFlagFIN tcpFlag = 1 << iota
|
||||
tcpFlagSYN
|
||||
tcpFlagRST
|
||||
tcpFlagPSH
|
||||
tcpFlagACK
|
||||
tcpFlagURG
|
||||
tcpFlagECE
|
||||
tcpFlagCWR
|
||||
)
|
||||
|
||||
func (f tcpFlag) String() string {
|
||||
switch f {
|
||||
case tcpFlagFIN:
|
||||
return "FIN"
|
||||
case tcpFlagSYN:
|
||||
return "SYN"
|
||||
case tcpFlagRST:
|
||||
return "RST"
|
||||
case tcpFlagPSH:
|
||||
return "PSH"
|
||||
case tcpFlagACK:
|
||||
return "ACK"
|
||||
case tcpFlagURG:
|
||||
return "URG"
|
||||
case tcpFlagECE:
|
||||
return "ECE"
|
||||
case tcpFlagCWR:
|
||||
return "CWR"
|
||||
default:
|
||||
panic(fmt.Sprintf("%s: %d", errTCPFlagUnknown, f))
|
||||
}
|
||||
}
|
||||
|
||||
var errTCPFlagUnknown = errors.New("unknown TCP flag")
|
||||
|
||||
func parseTCPFlag(s string) (tcpFlag, error) {
|
||||
allFlags := []tcpFlag{
|
||||
tcpFlagFIN, tcpFlagSYN, tcpFlagRST, tcpFlagPSH,
|
||||
tcpFlagACK, tcpFlagURG, tcpFlagECE, tcpFlagCWR,
|
||||
}
|
||||
for _, flag := range allFlags {
|
||||
if s == fmt.Sprintf("%#02x", uint8(flag)) || s == flag.String() {
|
||||
return flag, nil
|
||||
}
|
||||
}
|
||||
return 0, fmt.Errorf("%w: %s", errTCPFlagUnknown, s)
|
||||
}
|
||||
Reference in New Issue
Block a user