mirror of
https://github.com/qdm12/gluetun.git
synced 2026-05-15 21:16:02 +02:00
hotfix(pmtud/tcp): block kernel from racing to send RST packets
- this makes PMTUD TCP reliable - this only works on kernels with the mark module - on kernels without the mark module, the icmp pmtud mtu found is used
This commit is contained in:
@@ -31,6 +31,12 @@ type chainRule struct {
|
||||
redirPorts []uint16 // Not specified if empty.
|
||||
ctstate []string // for example ["RELATED","ESTABLISHED"]. Can be empty.
|
||||
tcpFlags tcpFlags
|
||||
mark mark
|
||||
}
|
||||
|
||||
type mark struct {
|
||||
invert bool
|
||||
value uint
|
||||
}
|
||||
|
||||
var ErrChainListMalformed = errors.New("iptables chain list output is malformed")
|
||||
@@ -278,6 +284,14 @@ func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err
|
||||
i++
|
||||
rule.ctstate = strings.Split(optionalFields[i], ",")
|
||||
i++
|
||||
case "mark":
|
||||
i++
|
||||
mark, consumed, err := parseMark(optionalFields[i:])
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing mark: %w", err)
|
||||
}
|
||||
rule.mark = mark
|
||||
i += consumed
|
||||
default:
|
||||
return fmt.Errorf("%w: unexpected optional field: %s",
|
||||
ErrChainRuleMalformed, optionalFields[i])
|
||||
@@ -397,6 +411,32 @@ func parsePortsCSV(s string) (ports []uint16, err error) {
|
||||
return ports, nil
|
||||
}
|
||||
|
||||
var errMarkValueMalformed = errors.New("mark value is malformed")
|
||||
|
||||
func parseMark(optionalFields []string) (m mark, consumed int, err error) {
|
||||
switch optionalFields[consumed] {
|
||||
case "match":
|
||||
consumed++
|
||||
if optionalFields[consumed] == "!" {
|
||||
m.invert = true
|
||||
consumed++
|
||||
}
|
||||
|
||||
const base = 0 // auto-detect
|
||||
const bits = 32
|
||||
value, err := strconv.ParseUint(optionalFields[consumed], base, bits)
|
||||
if err != nil {
|
||||
return mark{}, 0, fmt.Errorf("%w: %s", errMarkValueMalformed, optionalFields[consumed])
|
||||
}
|
||||
m.value = uint(value)
|
||||
consumed++
|
||||
default:
|
||||
return mark{}, 0, fmt.Errorf("%w: unexpected mark mode field: %s",
|
||||
ErrChainRuleMalformed, optionalFields[consumed])
|
||||
}
|
||||
return m, consumed, nil
|
||||
}
|
||||
|
||||
var ErrLineNumberIsZero = errors.New("line number is zero")
|
||||
|
||||
func parseLineNumber(s string) (n uint16, err error) {
|
||||
|
||||
@@ -23,6 +23,7 @@ type iptablesInstruction struct {
|
||||
toPorts []uint16 // if empty, there is no redirection
|
||||
ctstate []string // if empty, there is no ctstate
|
||||
tcpFlags tcpFlags
|
||||
mark mark
|
||||
}
|
||||
|
||||
func (i *iptablesInstruction) setDefaults() {
|
||||
@@ -59,6 +60,8 @@ func (i *iptablesInstruction) equalToRule(table, chain string, rule chainRule) (
|
||||
case !slices.Equal(i.tcpFlags.mask, rule.tcpFlags.mask) ||
|
||||
!slices.Equal(i.tcpFlags.comparison, rule.tcpFlags.comparison):
|
||||
return false
|
||||
case i.mark != rule.mark:
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
@@ -100,7 +103,7 @@ func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (co
|
||||
|
||||
// All flags use one value after the flag, except the following:
|
||||
switch flag {
|
||||
case "--tcp-flags":
|
||||
case "--tcp-flags": // -m can have 1 or 2 values
|
||||
const expected = 3
|
||||
if len(fields) < expected {
|
||||
return 0, fmt.Errorf("%w: flag %q requires at least 2 values, but got %s",
|
||||
@@ -130,7 +133,30 @@ func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (co
|
||||
instruction.target = value
|
||||
case "-p", "--protocol":
|
||||
instruction.protocol = value
|
||||
case "-m", "--match": // ignore match
|
||||
case "-m", "--match":
|
||||
consumed = 2 // -m can have 1 or 2 values, so it consumes 2 or 3 fields.
|
||||
switch value {
|
||||
case "tcp", "udp": // for now ignore the protocol match since it's auto-loaded
|
||||
case "mark":
|
||||
switch fields[2] {
|
||||
case "!":
|
||||
consumed++
|
||||
instruction.mark.invert = true
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: unsupported match mark with value: %s",
|
||||
ErrIptablesCommandMalformed, fields[2])
|
||||
}
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: unknown match value: %s", ErrIptablesCommandMalformed, value)
|
||||
}
|
||||
case "--mark":
|
||||
const base = 0 // auto-detect
|
||||
const bits = 32
|
||||
value, err := strconv.ParseUint(value, base, bits)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing mark value %q: %w", fields[2], err)
|
||||
}
|
||||
instruction.mark.value = uint(value)
|
||||
case "-i", "--in-interface":
|
||||
instruction.inputInterface = value
|
||||
case "-o", "--out-interface":
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os"
|
||||
)
|
||||
|
||||
type tcpFlags struct {
|
||||
@@ -60,3 +63,35 @@ func parseTCPFlag(s string) (tcpFlag, error) {
|
||||
}
|
||||
return 0, fmt.Errorf("%w: %s", errTCPFlagUnknown, s)
|
||||
}
|
||||
|
||||
var ErrMarkMatchModuleMissing = errors.New("kernel is missing the mark module libxt_mark.so")
|
||||
|
||||
// TempDropOutputTCPRST temporarily drops outgoing TCP RST packets to the specified address and port,
|
||||
// for any TCP packets not marked with the excludeMark given.
|
||||
// This is necessary for TCP path MTU discovery to work, as the kernel will try to terminate the connection
|
||||
// by sending a TCP RST packet, although we want to handle the connection manually.
|
||||
func (c *Config) TempDropOutputTCPRST(ctx context.Context,
|
||||
addrPort netip.AddrPort, excludeMark int) (
|
||||
revert func(ctx context.Context) error, err error,
|
||||
) {
|
||||
_, err = os.Stat("/usr/lib/xtables/libxt_mark.so")
|
||||
if err != nil && errors.Is(err, os.ErrNotExist) {
|
||||
return nil, fmt.Errorf("%w", ErrMarkMatchModuleMissing)
|
||||
}
|
||||
|
||||
const template = "%s OUTPUT -p tcp -d %s --dport %d --tcp-flags RST RST -m mark ! --mark %d -j DROP" //nolint:dupword
|
||||
instruction := fmt.Sprintf(template, "--append", addrPort.Addr(), addrPort.Port(), excludeMark)
|
||||
revertInstruction := fmt.Sprintf(template, "--delete", addrPort.Addr(), addrPort.Port(), excludeMark)
|
||||
run := c.runIptablesInstruction
|
||||
if addrPort.Addr().Is6() {
|
||||
run = c.runIP6tablesInstruction
|
||||
}
|
||||
revert = func(ctx context.Context) error {
|
||||
return run(ctx, revertInstruction)
|
||||
}
|
||||
err = run(ctx, instruction)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("running instruction: %w", err)
|
||||
}
|
||||
return revert, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user