Files
gluetun/internal/firewall/tcp.go
T
Quentin McGaw 04d7cef294 hotfix(pmtud/tcp): block kernel from racing to send RST packets
- this makes PMTUD TCP reliable
- this only works on kernels with the mark module
- on kernels without the mark module, the icmp pmtud mtu found is used
2026-02-17 21:46:24 +00:00

98 lines
2.4 KiB
Go

package firewall
import (
"context"
"errors"
"fmt"
"net/netip"
"os"
)
type tcpFlags struct {
mask []tcpFlag
comparison []tcpFlag
}
type tcpFlag uint8
const (
tcpFlagFIN tcpFlag = 1 << iota
tcpFlagSYN
tcpFlagRST
tcpFlagPSH
tcpFlagACK
tcpFlagURG
tcpFlagECE
tcpFlagCWR
)
func (f tcpFlag) String() string {
switch f {
case tcpFlagFIN:
return "FIN"
case tcpFlagSYN:
return "SYN"
case tcpFlagRST:
return "RST"
case tcpFlagPSH:
return "PSH"
case tcpFlagACK:
return "ACK"
case tcpFlagURG:
return "URG"
case tcpFlagECE:
return "ECE"
case tcpFlagCWR:
return "CWR"
default:
panic(fmt.Sprintf("%s: %d", errTCPFlagUnknown, f))
}
}
var errTCPFlagUnknown = errors.New("unknown TCP flag")
func parseTCPFlag(s string) (tcpFlag, error) {
allFlags := []tcpFlag{
tcpFlagFIN, tcpFlagSYN, tcpFlagRST, tcpFlagPSH,
tcpFlagACK, tcpFlagURG, tcpFlagECE, tcpFlagCWR,
}
for _, flag := range allFlags {
if s == fmt.Sprintf("%#02x", uint8(flag)) || s == flag.String() {
return flag, nil
}
}
return 0, fmt.Errorf("%w: %s", errTCPFlagUnknown, s)
}
var ErrMarkMatchModuleMissing = errors.New("kernel is missing the mark module libxt_mark.so")
// TempDropOutputTCPRST temporarily drops outgoing TCP RST packets to the specified address and port,
// for any TCP packets not marked with the excludeMark given.
// This is necessary for TCP path MTU discovery to work, as the kernel will try to terminate the connection
// by sending a TCP RST packet, although we want to handle the connection manually.
func (c *Config) TempDropOutputTCPRST(ctx context.Context,
addrPort netip.AddrPort, excludeMark int) (
revert func(ctx context.Context) error, err error,
) {
_, err = os.Stat("/usr/lib/xtables/libxt_mark.so")
if err != nil && errors.Is(err, os.ErrNotExist) {
return nil, fmt.Errorf("%w", ErrMarkMatchModuleMissing)
}
const template = "%s OUTPUT -p tcp -d %s --dport %d --tcp-flags RST RST -m mark ! --mark %d -j DROP" //nolint:dupword
instruction := fmt.Sprintf(template, "--append", addrPort.Addr(), addrPort.Port(), excludeMark)
revertInstruction := fmt.Sprintf(template, "--delete", addrPort.Addr(), addrPort.Port(), excludeMark)
run := c.runIptablesInstruction
if addrPort.Addr().Is6() {
run = c.runIP6tablesInstruction
}
revert = func(ctx context.Context) error {
return run(ctx, revertInstruction)
}
err = run(ctx, instruction)
if err != nil {
return nil, fmt.Errorf("running instruction: %w", err)
}
return revert, nil
}