mirror of
https://github.com/qdm12/gluetun.git
synced 2026-05-07 04:20:12 +02:00
chore(pmtud/tcp): restrict temp firewall rules to source ip and source port
This commit is contained in:
+21
-10
@@ -26,6 +26,7 @@ type chainRule struct {
|
||||
inputInterface string // input interface, for example "tun0" or "*""
|
||||
outputInterface string // output interface, for example "eth0" or "*""
|
||||
source netip.Prefix // source IP CIDR, for example 0.0.0.0/0. Must be valid.
|
||||
sourcePort uint16 // Not specified if set to zero.
|
||||
destination netip.Prefix // destination IP CIDR, for example 0.0.0.0/0. Must be valid.
|
||||
destinationPort uint16 // Not specified if set to zero.
|
||||
redirPorts []uint16 // Not specified if empty.
|
||||
@@ -315,6 +316,12 @@ func parseUDPOptional(optionalFields []string, rule *chainRule) (consumed int, e
|
||||
return 0, fmt.Errorf("parsing destination port: %w", err)
|
||||
}
|
||||
consumed++
|
||||
case strings.HasPrefix(value, "spt:"):
|
||||
rule.sourcePort, err = parseSourcePort(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing source port: %w", err)
|
||||
}
|
||||
consumed++
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: %s", errUDPOptionalUnknown, value)
|
||||
}
|
||||
@@ -337,6 +344,12 @@ func parseTCPOptional(optionalFields []string, rule *chainRule) (consumed int, e
|
||||
return 0, fmt.Errorf("parsing destination port: %w", err)
|
||||
}
|
||||
consumed++
|
||||
case strings.HasPrefix(value, "spt:"):
|
||||
rule.sourcePort, err = parseSourcePort(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing source port: %w", err)
|
||||
}
|
||||
consumed++
|
||||
case strings.HasPrefix(value, "flags:"):
|
||||
rule.tcpFlags, err = parseTCPFlags(value)
|
||||
if err != nil {
|
||||
@@ -352,12 +365,12 @@ func parseTCPOptional(optionalFields []string, rule *chainRule) (consumed int, e
|
||||
|
||||
func parseDestinationPort(value string) (port uint16, err error) {
|
||||
value = strings.TrimPrefix(value, "dpt:")
|
||||
const base, bitLength = 10, 16
|
||||
destinationPort, err := strconv.ParseUint(value, base, bitLength)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing %q: %w", value, err)
|
||||
}
|
||||
return uint16(destinationPort), nil
|
||||
return parsePort(value)
|
||||
}
|
||||
|
||||
func parseSourcePort(value string) (port uint16, err error) {
|
||||
value = strings.TrimPrefix(value, "spt:")
|
||||
return parsePort(value)
|
||||
}
|
||||
|
||||
var errTCPFlagsMalformed = errors.New("TCP flags are malformed")
|
||||
@@ -401,12 +414,10 @@ func parsePortsCSV(s string) (ports []uint16, err error) {
|
||||
fields := strings.Split(s, ",")
|
||||
ports = make([]uint16, len(fields))
|
||||
for i, field := range fields {
|
||||
const base, bitLength = 10, 16
|
||||
port, err := strconv.ParseUint(field, base, bitLength)
|
||||
ports[i], err = parsePort(field)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing port %q: %w", field, err)
|
||||
return nil, err
|
||||
}
|
||||
ports[i] = uint16(port)
|
||||
}
|
||||
return ports, nil
|
||||
}
|
||||
|
||||
+89
-44
@@ -18,6 +18,7 @@ type iptablesInstruction struct {
|
||||
inputInterface string // for example "tun0" or "" for any interface.
|
||||
outputInterface string // for example "tun0" or "" for any interface.
|
||||
source netip.Prefix // if not valid, then it is unspecified.
|
||||
sourcePort uint16 // if zero, there is no source port
|
||||
destination netip.Prefix // if not valid, then it is unspecified.
|
||||
destinationPort uint16 // if zero, there is no destination port
|
||||
toPorts []uint16 // if empty, there is no redirection
|
||||
@@ -45,6 +46,8 @@ func (i *iptablesInstruction) equalToRule(table, chain string, rule chainRule) (
|
||||
return false
|
||||
case i.destinationPort != rule.destinationPort:
|
||||
return false
|
||||
case i.sourcePort != rule.sourcePort:
|
||||
return false
|
||||
case !slices.Equal(i.toPorts, rule.redirPorts):
|
||||
return false
|
||||
case !slices.Equal(i.ctstate, rule.ctstate):
|
||||
@@ -99,25 +102,11 @@ func parseIptablesInstruction(s string) (instruction iptablesInstruction, err er
|
||||
}
|
||||
|
||||
func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (consumed int, err error) {
|
||||
flag := fields[0]
|
||||
|
||||
// All flags use one value after the flag, except the following:
|
||||
switch flag {
|
||||
case "--tcp-flags": // -m can have 1 or 2 values
|
||||
const expected = 3
|
||||
if len(fields) < expected {
|
||||
return 0, fmt.Errorf("%w: flag %q requires at least 2 values, but got %s",
|
||||
ErrIptablesCommandMalformed, flag, strings.Join(fields, " "))
|
||||
}
|
||||
consumed = expected
|
||||
default:
|
||||
const expected = 2
|
||||
if len(fields) < expected {
|
||||
return 0, fmt.Errorf("%w: flag %q requires a value, but got none",
|
||||
ErrIptablesCommandMalformed, flag)
|
||||
}
|
||||
consumed = expected
|
||||
consumed, err = preCheckInstructionFields(fields)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
flag := fields[0]
|
||||
value := fields[1]
|
||||
|
||||
switch flag {
|
||||
@@ -134,20 +123,9 @@ func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (co
|
||||
case "-p", "--protocol":
|
||||
instruction.protocol = value
|
||||
case "-m", "--match":
|
||||
consumed = 2 // -m can have 1 or 2 values, so it consumes 2 or 3 fields.
|
||||
switch value {
|
||||
case "tcp", "udp": // for now ignore the protocol match since it's auto-loaded
|
||||
case "mark":
|
||||
switch fields[2] {
|
||||
case "!":
|
||||
consumed++
|
||||
instruction.mark.invert = true
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: unsupported match mark with value: %s",
|
||||
ErrIptablesCommandMalformed, fields[2])
|
||||
}
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: unknown match value: %s", ErrIptablesCommandMalformed, value)
|
||||
consumed, err = parseMatchModule(fields, instruction)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing match module: %w", err)
|
||||
}
|
||||
case "--mark":
|
||||
const base = 0 // auto-detect
|
||||
@@ -166,30 +144,27 @@ func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (co
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing source IP CIDR: %w", err)
|
||||
}
|
||||
case "--sport":
|
||||
instruction.sourcePort, err = parsePort(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing source port: %w", err)
|
||||
}
|
||||
case "-d", "--destination":
|
||||
instruction.destination, err = parseIPPrefix(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing destination IP CIDR: %w", err)
|
||||
}
|
||||
case "--dport":
|
||||
const base, bitLength = 10, 16
|
||||
destinationPort, err := strconv.ParseUint(value, base, bitLength)
|
||||
instruction.destinationPort, err = parsePort(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing destination port: %w", err)
|
||||
}
|
||||
instruction.destinationPort = uint16(destinationPort)
|
||||
case "--ctstate":
|
||||
instruction.ctstate = strings.Split(value, ",")
|
||||
case "--to-ports":
|
||||
portStrings := strings.Split(value, ",")
|
||||
instruction.toPorts = make([]uint16, len(portStrings))
|
||||
for i, portString := range portStrings {
|
||||
const base, bitLength = 10, 16
|
||||
port, err := strconv.ParseUint(portString, base, bitLength)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing port redirection: %w", err)
|
||||
}
|
||||
instruction.toPorts[i] = uint16(port)
|
||||
instruction.toPorts, err = parseToPorts(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing port redirection: %w", err)
|
||||
}
|
||||
case "--tcp-flags":
|
||||
mask, comparison := value, fields[2]
|
||||
@@ -203,6 +178,27 @@ func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (co
|
||||
return consumed, nil
|
||||
}
|
||||
|
||||
func preCheckInstructionFields(fields []string) (consumed int, err error) {
|
||||
flag := fields[0]
|
||||
// All flags use one value after the flag, except the following:
|
||||
switch flag {
|
||||
case "--tcp-flags": // -m can have 1 or 2 values
|
||||
const expected = 3
|
||||
if len(fields) < expected {
|
||||
return 0, fmt.Errorf("%w: flag %q requires at least 2 values, but got %s",
|
||||
ErrIptablesCommandMalformed, flag, strings.Join(fields, " "))
|
||||
}
|
||||
return expected, nil
|
||||
default:
|
||||
const expected = 2
|
||||
if len(fields) < expected {
|
||||
return 0, fmt.Errorf("%w: flag %q requires a value, but got none",
|
||||
ErrIptablesCommandMalformed, flag)
|
||||
}
|
||||
return expected, nil
|
||||
}
|
||||
}
|
||||
|
||||
func parseIPPrefix(value string) (prefix netip.Prefix, err error) {
|
||||
slashIndex := strings.Index(value, "/")
|
||||
if slashIndex >= 0 {
|
||||
@@ -215,3 +211,52 @@ func parseIPPrefix(value string) (prefix netip.Prefix, err error) {
|
||||
}
|
||||
return netip.PrefixFrom(ip, ip.BitLen()), nil
|
||||
}
|
||||
|
||||
func parsePort(value string) (port uint16, err error) {
|
||||
const base, bitLength = 10, 16
|
||||
portValue, err := strconv.ParseUint(value, base, bitLength)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return uint16(portValue), nil
|
||||
}
|
||||
|
||||
func parseMatchModule(fields []string, instruction *iptablesInstruction) (
|
||||
consumed int, err error,
|
||||
) {
|
||||
_ = fields[consumed] // -m or --match flag already detected
|
||||
consumed++
|
||||
switch fields[consumed] {
|
||||
case "tcp", "udp":
|
||||
consumed++
|
||||
// for now ignore the protocol match since it's auto-loaded
|
||||
// when parsing the -p/--protocol flag, and we don't need to
|
||||
// parse it twice.
|
||||
case "mark":
|
||||
consumed++
|
||||
switch fields[consumed] {
|
||||
case "!":
|
||||
consumed++
|
||||
instruction.mark.invert = true
|
||||
default:
|
||||
return consumed, fmt.Errorf("%w: unsupported match mark with value: %s",
|
||||
ErrIptablesCommandMalformed, fields[2])
|
||||
}
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: unknown match value: %s",
|
||||
ErrIptablesCommandMalformed, fields[consumed])
|
||||
}
|
||||
return consumed, nil
|
||||
}
|
||||
|
||||
func parseToPorts(value string) (toPorts []uint16, err error) {
|
||||
portStrings := strings.Split(value, ",")
|
||||
toPorts = make([]uint16, len(portStrings))
|
||||
for i, portString := range portStrings {
|
||||
toPorts[i], err = parsePort(portString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return toPorts, nil
|
||||
}
|
||||
|
||||
@@ -71,7 +71,7 @@ var ErrMarkMatchModuleMissing = errors.New("kernel is missing the mark module li
|
||||
// This is necessary for TCP path MTU discovery to work, as the kernel will try to terminate the connection
|
||||
// by sending a TCP RST packet, although we want to handle the connection manually.
|
||||
func (c *Config) TempDropOutputTCPRST(ctx context.Context,
|
||||
addrPort netip.AddrPort, excludeMark int) (
|
||||
src, dst netip.AddrPort, excludeMark int) (
|
||||
revert func(ctx context.Context) error, err error,
|
||||
) {
|
||||
_, err = os.Stat("/usr/lib/xtables/libxt_mark.so")
|
||||
@@ -79,11 +79,12 @@ func (c *Config) TempDropOutputTCPRST(ctx context.Context,
|
||||
return nil, fmt.Errorf("%w", ErrMarkMatchModuleMissing)
|
||||
}
|
||||
|
||||
const template = "%s OUTPUT -p tcp -d %s --dport %d --tcp-flags RST RST -m mark ! --mark %d -j DROP" //nolint:dupword
|
||||
instruction := fmt.Sprintf(template, "--append", addrPort.Addr(), addrPort.Port(), excludeMark)
|
||||
revertInstruction := fmt.Sprintf(template, "--delete", addrPort.Addr(), addrPort.Port(), excludeMark)
|
||||
const template = "%s OUTPUT -p tcp -s %s --sport %d -d %s --dport %d " +
|
||||
"--tcp-flags RST RST -m mark ! --mark %d -j DROP" //nolint:dupword
|
||||
instruction := fmt.Sprintf(template, "--append", src.Addr(), src.Port(), dst.Addr(), dst.Port(), excludeMark)
|
||||
revertInstruction := fmt.Sprintf(template, "--delete", src.Addr(), src.Port(), dst.Addr(), dst.Port(), excludeMark)
|
||||
run := c.runIptablesInstruction
|
||||
if addrPort.Addr().Is6() {
|
||||
if dst.Addr().Is6() {
|
||||
run = c.runIP6tablesInstruction
|
||||
}
|
||||
revert = func(ctx context.Context) error {
|
||||
|
||||
Reference in New Issue
Block a user