mirror of
https://github.com/qdm12/gluetun.git
synced 2026-05-07 04:20:12 +02:00
chore(firewall): support TCP flags for future changes
This commit is contained in:
@@ -69,8 +69,8 @@ func Test_deleteIPTablesRule(t *testing.T) {
|
|||||||
"invalid_instruction": {
|
"invalid_instruction": {
|
||||||
instruction: "invalid",
|
instruction: "invalid",
|
||||||
errWrapped: ErrIptablesCommandMalformed,
|
errWrapped: ErrIptablesCommandMalformed,
|
||||||
errMessage: "parsing iptables command: iptables command is malformed: " +
|
errMessage: "parsing iptables command: parsing \"invalid\": " +
|
||||||
"fields count 1 is not even: \"invalid\"",
|
"iptables command is malformed: flag \"invalid\" requires a value, but got none",
|
||||||
},
|
},
|
||||||
"list_error": {
|
"list_error": {
|
||||||
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
|
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
|
||||||
|
|||||||
+114
-13
@@ -30,6 +30,7 @@ type chainRule struct {
|
|||||||
destinationPort uint16 // Not specified if set to zero.
|
destinationPort uint16 // Not specified if set to zero.
|
||||||
redirPorts []uint16 // Not specified if empty.
|
redirPorts []uint16 // Not specified if empty.
|
||||||
ctstate []string // for example ["RELATED","ESTABLISHED"]. Can be empty.
|
ctstate []string // for example ["RELATED","ESTABLISHED"]. Can be empty.
|
||||||
|
tcpFlags tcpFlags
|
||||||
}
|
}
|
||||||
|
|
||||||
var ErrChainListMalformed = errors.New("iptables chain list output is malformed")
|
var ErrChainListMalformed = errors.New("iptables chain list output is malformed")
|
||||||
@@ -241,19 +242,23 @@ func parseChainRuleField(fieldIndex int, field string, rule *chainRule) (err err
|
|||||||
}
|
}
|
||||||
|
|
||||||
func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err error) {
|
func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err error) {
|
||||||
for i := 0; i < len(optionalFields); i++ {
|
i := 0
|
||||||
key := optionalFields[i]
|
for i < len(optionalFields) {
|
||||||
switch key {
|
switch optionalFields[i] {
|
||||||
case "tcp", "udp":
|
case "udp":
|
||||||
i++
|
i++
|
||||||
value := optionalFields[i]
|
consumed, err := parseUDPOptional(optionalFields[i:], rule)
|
||||||
value = strings.TrimPrefix(value, "dpt:")
|
|
||||||
const base, bitLength = 10, 16
|
|
||||||
destinationPort, err := strconv.ParseUint(value, base, bitLength)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("parsing destination port %q: %w", value, err)
|
return fmt.Errorf("parsing UDP optional fields: %w", err)
|
||||||
}
|
}
|
||||||
rule.destinationPort = uint16(destinationPort)
|
i += consumed
|
||||||
|
case "tcp":
|
||||||
|
i++
|
||||||
|
consumed, err := parseTCPOptional(optionalFields[i:], rule)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parsing TCP optional fields: %w", err)
|
||||||
|
}
|
||||||
|
i += consumed
|
||||||
case "redir":
|
case "redir":
|
||||||
i++
|
i++
|
||||||
switch optionalFields[i] {
|
switch optionalFields[i] {
|
||||||
@@ -264,20 +269,116 @@ func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err
|
|||||||
return fmt.Errorf("parsing redirection ports: %w", err)
|
return fmt.Errorf("parsing redirection ports: %w", err)
|
||||||
}
|
}
|
||||||
rule.redirPorts = ports
|
rule.redirPorts = ports
|
||||||
|
i++
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("%w: unexpected optional field: %s",
|
return fmt.Errorf("%w: unexpected %q after redir",
|
||||||
ErrChainRuleMalformed, optionalFields[i])
|
ErrChainRuleMalformed, optionalFields[1])
|
||||||
}
|
}
|
||||||
case "ctstate":
|
case "ctstate":
|
||||||
i++
|
i++
|
||||||
rule.ctstate = strings.Split(optionalFields[i], ",")
|
rule.ctstate = strings.Split(optionalFields[i], ",")
|
||||||
|
i++
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("%w: unexpected optional field: %s", ErrChainRuleMalformed, key)
|
return fmt.Errorf("%w: unexpected optional field: %s",
|
||||||
|
ErrChainRuleMalformed, optionalFields[i])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var errUDPOptionalUnknown = errors.New("unknown UDP optional field")
|
||||||
|
|
||||||
|
func parseUDPOptional(optionalFields []string, rule *chainRule) (consumed int, err error) {
|
||||||
|
for _, value := range optionalFields {
|
||||||
|
if !strings.ContainsRune(value, ':') {
|
||||||
|
// no longer a UDP-associated option
|
||||||
|
return consumed, nil
|
||||||
|
}
|
||||||
|
switch {
|
||||||
|
case strings.HasPrefix(value, "dpt:"):
|
||||||
|
rule.destinationPort, err = parseDestinationPort(value)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("parsing destination port: %w", err)
|
||||||
|
}
|
||||||
|
consumed++
|
||||||
|
default:
|
||||||
|
return 0, fmt.Errorf("%w: %s", errUDPOptionalUnknown, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return consumed, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var errTCPOptionalUnknown = errors.New("unknown TCP optional field")
|
||||||
|
|
||||||
|
func parseTCPOptional(optionalFields []string, rule *chainRule) (consumed int, err error) {
|
||||||
|
for _, value := range optionalFields {
|
||||||
|
if !strings.ContainsRune(value, ':') {
|
||||||
|
// no longer a TCP-associated option
|
||||||
|
return consumed, nil
|
||||||
|
}
|
||||||
|
switch {
|
||||||
|
case strings.HasPrefix(value, "dpt:"):
|
||||||
|
rule.destinationPort, err = parseDestinationPort(value)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("parsing destination port: %w", err)
|
||||||
|
}
|
||||||
|
consumed++
|
||||||
|
case strings.HasPrefix(value, "flags:"):
|
||||||
|
rule.tcpFlags, err = parseTCPFlags(value)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("parsing TCP flags: %w", err)
|
||||||
|
}
|
||||||
|
consumed++
|
||||||
|
default:
|
||||||
|
return 0, fmt.Errorf("%w: %s", errTCPOptionalUnknown, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return consumed, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseDestinationPort(value string) (port uint16, err error) {
|
||||||
|
value = strings.TrimPrefix(value, "dpt:")
|
||||||
|
const base, bitLength = 10, 16
|
||||||
|
destinationPort, err := strconv.ParseUint(value, base, bitLength)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("parsing %q: %w", value, err)
|
||||||
|
}
|
||||||
|
return uint16(destinationPort), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var errTCPFlagsMalformed = errors.New("TCP flags are malformed")
|
||||||
|
|
||||||
|
func parseTCPFlags(value string) (tcpFlags, error) {
|
||||||
|
value = strings.TrimPrefix(value, "flags:")
|
||||||
|
fields := strings.Split(value, "/")
|
||||||
|
const expectedFields = 2
|
||||||
|
if len(fields) != expectedFields {
|
||||||
|
return tcpFlags{}, fmt.Errorf("%w: expected format 'flags:<mask>/<comparison>' in %q",
|
||||||
|
errTCPFlagsMalformed, value)
|
||||||
|
}
|
||||||
|
maskFlags := strings.Split(fields[0], ",")
|
||||||
|
mask := make([]tcpFlag, len(maskFlags))
|
||||||
|
var err error
|
||||||
|
for i, maskFlag := range maskFlags {
|
||||||
|
mask[i], err = parseTCPFlag(maskFlag)
|
||||||
|
if err != nil {
|
||||||
|
return tcpFlags{}, fmt.Errorf("parsing TCP mask flags: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
comparisonFlags := strings.Split(fields[1], ",")
|
||||||
|
comparison := make([]tcpFlag, len(comparisonFlags))
|
||||||
|
for i, comparisonFlag := range comparisonFlags {
|
||||||
|
comparison[i], err = parseTCPFlag(comparisonFlag)
|
||||||
|
if err != nil {
|
||||||
|
return tcpFlags{}, fmt.Errorf("parsing TCP comparison flags: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return tcpFlags{
|
||||||
|
mask: mask,
|
||||||
|
comparison: comparison,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func parsePortsCSV(s string) (ports []uint16, err error) {
|
func parsePortsCSV(s string) (ports []uint16, err error) {
|
||||||
if s == "" {
|
if s == "" {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
|
|||||||
+43
-16
@@ -22,6 +22,7 @@ type iptablesInstruction struct {
|
|||||||
destinationPort uint16 // if zero, there is no destination port
|
destinationPort uint16 // if zero, there is no destination port
|
||||||
toPorts []uint16 // if empty, there is no redirection
|
toPorts []uint16 // if empty, there is no redirection
|
||||||
ctstate []string // if empty, there is no ctstate
|
ctstate []string // if empty, there is no ctstate
|
||||||
|
tcpFlags tcpFlags
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *iptablesInstruction) setDefaults() {
|
func (i *iptablesInstruction) setDefaults() {
|
||||||
@@ -55,6 +56,9 @@ func (i *iptablesInstruction) equalToRule(table, chain string, rule chainRule) (
|
|||||||
return false
|
return false
|
||||||
case !ipPrefixesEqual(i.destination, rule.destination):
|
case !ipPrefixesEqual(i.destination, rule.destination):
|
||||||
return false
|
return false
|
||||||
|
case !slices.Equal(i.tcpFlags.mask, rule.tcpFlags.mask) ||
|
||||||
|
!slices.Equal(i.tcpFlags.comparison, rule.tcpFlags.comparison):
|
||||||
|
return false
|
||||||
default:
|
default:
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@@ -77,26 +81,43 @@ func parseIptablesInstruction(s string) (instruction iptablesInstruction, err er
|
|||||||
return iptablesInstruction{}, fmt.Errorf("%w: empty instruction", ErrIptablesCommandMalformed)
|
return iptablesInstruction{}, fmt.Errorf("%w: empty instruction", ErrIptablesCommandMalformed)
|
||||||
}
|
}
|
||||||
fields := strings.Fields(s)
|
fields := strings.Fields(s)
|
||||||
if len(fields)%2 != 0 {
|
|
||||||
return iptablesInstruction{}, fmt.Errorf("%w: fields count %d is not even: %q",
|
|
||||||
ErrIptablesCommandMalformed, len(fields), s)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; i < len(fields); i += 2 {
|
i := 0
|
||||||
key := fields[i]
|
for i < len(fields) {
|
||||||
value := fields[i+1]
|
consumed, err := parseInstructionFlag(fields[i:], &instruction)
|
||||||
err = parseInstructionFlag(key, value, &instruction)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return iptablesInstruction{}, fmt.Errorf("parsing %q: %w", s, err)
|
return iptablesInstruction{}, fmt.Errorf("parsing %q: %w", s, err)
|
||||||
}
|
}
|
||||||
|
i += consumed
|
||||||
}
|
}
|
||||||
|
|
||||||
instruction.setDefaults()
|
instruction.setDefaults()
|
||||||
return instruction, nil
|
return instruction, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseInstructionFlag(key, value string, instruction *iptablesInstruction) (err error) {
|
func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (consumed int, err error) {
|
||||||
switch key {
|
flag := fields[0]
|
||||||
|
|
||||||
|
// All flags use one value after the flag, except the following:
|
||||||
|
switch flag {
|
||||||
|
case "--tcp-flags":
|
||||||
|
const expected = 3
|
||||||
|
if len(fields) < expected {
|
||||||
|
return 0, fmt.Errorf("%w: flag %q requires at least 2 values, but got %s",
|
||||||
|
ErrIptablesCommandMalformed, flag, strings.Join(fields, " "))
|
||||||
|
}
|
||||||
|
consumed = expected
|
||||||
|
default:
|
||||||
|
const expected = 2
|
||||||
|
if len(fields) < expected {
|
||||||
|
return 0, fmt.Errorf("%w: flag %q requires a value, but got none",
|
||||||
|
ErrIptablesCommandMalformed, flag)
|
||||||
|
}
|
||||||
|
consumed = expected
|
||||||
|
}
|
||||||
|
value := fields[1]
|
||||||
|
|
||||||
|
switch flag {
|
||||||
case "-t", "--table":
|
case "-t", "--table":
|
||||||
instruction.table = value
|
instruction.table = value
|
||||||
case "-D", "--delete":
|
case "-D", "--delete":
|
||||||
@@ -117,18 +138,18 @@ func parseInstructionFlag(key, value string, instruction *iptablesInstruction) (
|
|||||||
case "-s", "--source":
|
case "-s", "--source":
|
||||||
instruction.source, err = parseIPPrefix(value)
|
instruction.source, err = parseIPPrefix(value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("parsing source IP CIDR: %w", err)
|
return 0, fmt.Errorf("parsing source IP CIDR: %w", err)
|
||||||
}
|
}
|
||||||
case "-d", "--destination":
|
case "-d", "--destination":
|
||||||
instruction.destination, err = parseIPPrefix(value)
|
instruction.destination, err = parseIPPrefix(value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("parsing destination IP CIDR: %w", err)
|
return 0, fmt.Errorf("parsing destination IP CIDR: %w", err)
|
||||||
}
|
}
|
||||||
case "--dport":
|
case "--dport":
|
||||||
const base, bitLength = 10, 16
|
const base, bitLength = 10, 16
|
||||||
destinationPort, err := strconv.ParseUint(value, base, bitLength)
|
destinationPort, err := strconv.ParseUint(value, base, bitLength)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("parsing destination port: %w", err)
|
return 0, fmt.Errorf("parsing destination port: %w", err)
|
||||||
}
|
}
|
||||||
instruction.destinationPort = uint16(destinationPort)
|
instruction.destinationPort = uint16(destinationPort)
|
||||||
case "--ctstate":
|
case "--ctstate":
|
||||||
@@ -140,14 +161,20 @@ func parseInstructionFlag(key, value string, instruction *iptablesInstruction) (
|
|||||||
const base, bitLength = 10, 16
|
const base, bitLength = 10, 16
|
||||||
port, err := strconv.ParseUint(portString, base, bitLength)
|
port, err := strconv.ParseUint(portString, base, bitLength)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("parsing port redirection: %w", err)
|
return 0, fmt.Errorf("parsing port redirection: %w", err)
|
||||||
}
|
}
|
||||||
instruction.toPorts[i] = uint16(port)
|
instruction.toPorts[i] = uint16(port)
|
||||||
}
|
}
|
||||||
|
case "--tcp-flags":
|
||||||
|
mask, comparison := value, fields[2]
|
||||||
|
instruction.tcpFlags, err = parseTCPFlags(mask + "/" + comparison)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("parsing TCP flags: %w", err)
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("%w: unknown key %q", ErrIptablesCommandMalformed, key)
|
return 0, fmt.Errorf("%w: unknown key %q", ErrIptablesCommandMalformed, flag)
|
||||||
}
|
}
|
||||||
return nil
|
return consumed, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseIPPrefix(value string) (prefix netip.Prefix, err error) {
|
func parseIPPrefix(value string) (prefix netip.Prefix, err error) {
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ func Test_parseIptablesInstruction(t *testing.T) {
|
|||||||
"uneven_fields": {
|
"uneven_fields": {
|
||||||
s: "-A",
|
s: "-A",
|
||||||
errWrapped: ErrIptablesCommandMalformed,
|
errWrapped: ErrIptablesCommandMalformed,
|
||||||
errMessage: "iptables command is malformed: fields count 1 is not even: \"-A\"",
|
errMessage: "parsing \"-A\": iptables command is malformed: flag \"-A\" requires a value, but got none",
|
||||||
},
|
},
|
||||||
"unknown_key": {
|
"unknown_key": {
|
||||||
s: "-x something",
|
s: "-x something",
|
||||||
|
|||||||
@@ -0,0 +1,62 @@
|
|||||||
|
package firewall
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type tcpFlags struct {
|
||||||
|
mask []tcpFlag
|
||||||
|
comparison []tcpFlag
|
||||||
|
}
|
||||||
|
|
||||||
|
type tcpFlag uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
tcpFlagFIN tcpFlag = 1 << iota
|
||||||
|
tcpFlagSYN
|
||||||
|
tcpFlagRST
|
||||||
|
tcpFlagPSH
|
||||||
|
tcpFlagACK
|
||||||
|
tcpFlagURG
|
||||||
|
tcpFlagECE
|
||||||
|
tcpFlagCWR
|
||||||
|
)
|
||||||
|
|
||||||
|
func (f tcpFlag) String() string {
|
||||||
|
switch f {
|
||||||
|
case tcpFlagFIN:
|
||||||
|
return "FIN"
|
||||||
|
case tcpFlagSYN:
|
||||||
|
return "SYN"
|
||||||
|
case tcpFlagRST:
|
||||||
|
return "RST"
|
||||||
|
case tcpFlagPSH:
|
||||||
|
return "PSH"
|
||||||
|
case tcpFlagACK:
|
||||||
|
return "ACK"
|
||||||
|
case tcpFlagURG:
|
||||||
|
return "URG"
|
||||||
|
case tcpFlagECE:
|
||||||
|
return "ECE"
|
||||||
|
case tcpFlagCWR:
|
||||||
|
return "CWR"
|
||||||
|
default:
|
||||||
|
panic(fmt.Sprintf("%s: %d", errTCPFlagUnknown, f))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var errTCPFlagUnknown = errors.New("unknown TCP flag")
|
||||||
|
|
||||||
|
func parseTCPFlag(s string) (tcpFlag, error) {
|
||||||
|
allFlags := []tcpFlag{
|
||||||
|
tcpFlagFIN, tcpFlagSYN, tcpFlagRST, tcpFlagPSH,
|
||||||
|
tcpFlagACK, tcpFlagURG, tcpFlagECE, tcpFlagCWR,
|
||||||
|
}
|
||||||
|
for _, flag := range allFlags {
|
||||||
|
if s == fmt.Sprintf("%#02x", uint8(flag)) || s == flag.String() {
|
||||||
|
return flag, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0, fmt.Errorf("%w: %s", errTCPFlagUnknown, s)
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user