mirror of
https://github.com/qdm12/gluetun.git
synced 2026-05-07 04:20:12 +02:00
chore(pmtud/tcp): restrict temp firewall rules to source ip and source port
This commit is contained in:
+21
-10
@@ -26,6 +26,7 @@ type chainRule struct {
|
|||||||
inputInterface string // input interface, for example "tun0" or "*""
|
inputInterface string // input interface, for example "tun0" or "*""
|
||||||
outputInterface string // output interface, for example "eth0" or "*""
|
outputInterface string // output interface, for example "eth0" or "*""
|
||||||
source netip.Prefix // source IP CIDR, for example 0.0.0.0/0. Must be valid.
|
source netip.Prefix // source IP CIDR, for example 0.0.0.0/0. Must be valid.
|
||||||
|
sourcePort uint16 // Not specified if set to zero.
|
||||||
destination netip.Prefix // destination IP CIDR, for example 0.0.0.0/0. Must be valid.
|
destination netip.Prefix // destination IP CIDR, for example 0.0.0.0/0. Must be valid.
|
||||||
destinationPort uint16 // Not specified if set to zero.
|
destinationPort uint16 // Not specified if set to zero.
|
||||||
redirPorts []uint16 // Not specified if empty.
|
redirPorts []uint16 // Not specified if empty.
|
||||||
@@ -315,6 +316,12 @@ func parseUDPOptional(optionalFields []string, rule *chainRule) (consumed int, e
|
|||||||
return 0, fmt.Errorf("parsing destination port: %w", err)
|
return 0, fmt.Errorf("parsing destination port: %w", err)
|
||||||
}
|
}
|
||||||
consumed++
|
consumed++
|
||||||
|
case strings.HasPrefix(value, "spt:"):
|
||||||
|
rule.sourcePort, err = parseSourcePort(value)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("parsing source port: %w", err)
|
||||||
|
}
|
||||||
|
consumed++
|
||||||
default:
|
default:
|
||||||
return 0, fmt.Errorf("%w: %s", errUDPOptionalUnknown, value)
|
return 0, fmt.Errorf("%w: %s", errUDPOptionalUnknown, value)
|
||||||
}
|
}
|
||||||
@@ -337,6 +344,12 @@ func parseTCPOptional(optionalFields []string, rule *chainRule) (consumed int, e
|
|||||||
return 0, fmt.Errorf("parsing destination port: %w", err)
|
return 0, fmt.Errorf("parsing destination port: %w", err)
|
||||||
}
|
}
|
||||||
consumed++
|
consumed++
|
||||||
|
case strings.HasPrefix(value, "spt:"):
|
||||||
|
rule.sourcePort, err = parseSourcePort(value)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("parsing source port: %w", err)
|
||||||
|
}
|
||||||
|
consumed++
|
||||||
case strings.HasPrefix(value, "flags:"):
|
case strings.HasPrefix(value, "flags:"):
|
||||||
rule.tcpFlags, err = parseTCPFlags(value)
|
rule.tcpFlags, err = parseTCPFlags(value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -352,12 +365,12 @@ func parseTCPOptional(optionalFields []string, rule *chainRule) (consumed int, e
|
|||||||
|
|
||||||
func parseDestinationPort(value string) (port uint16, err error) {
|
func parseDestinationPort(value string) (port uint16, err error) {
|
||||||
value = strings.TrimPrefix(value, "dpt:")
|
value = strings.TrimPrefix(value, "dpt:")
|
||||||
const base, bitLength = 10, 16
|
return parsePort(value)
|
||||||
destinationPort, err := strconv.ParseUint(value, base, bitLength)
|
}
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("parsing %q: %w", value, err)
|
func parseSourcePort(value string) (port uint16, err error) {
|
||||||
}
|
value = strings.TrimPrefix(value, "spt:")
|
||||||
return uint16(destinationPort), nil
|
return parsePort(value)
|
||||||
}
|
}
|
||||||
|
|
||||||
var errTCPFlagsMalformed = errors.New("TCP flags are malformed")
|
var errTCPFlagsMalformed = errors.New("TCP flags are malformed")
|
||||||
@@ -401,12 +414,10 @@ func parsePortsCSV(s string) (ports []uint16, err error) {
|
|||||||
fields := strings.Split(s, ",")
|
fields := strings.Split(s, ",")
|
||||||
ports = make([]uint16, len(fields))
|
ports = make([]uint16, len(fields))
|
||||||
for i, field := range fields {
|
for i, field := range fields {
|
||||||
const base, bitLength = 10, 16
|
ports[i], err = parsePort(field)
|
||||||
port, err := strconv.ParseUint(field, base, bitLength)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("parsing port %q: %w", field, err)
|
return nil, err
|
||||||
}
|
}
|
||||||
ports[i] = uint16(port)
|
|
||||||
}
|
}
|
||||||
return ports, nil
|
return ports, nil
|
||||||
}
|
}
|
||||||
|
|||||||
+89
-44
@@ -18,6 +18,7 @@ type iptablesInstruction struct {
|
|||||||
inputInterface string // for example "tun0" or "" for any interface.
|
inputInterface string // for example "tun0" or "" for any interface.
|
||||||
outputInterface string // for example "tun0" or "" for any interface.
|
outputInterface string // for example "tun0" or "" for any interface.
|
||||||
source netip.Prefix // if not valid, then it is unspecified.
|
source netip.Prefix // if not valid, then it is unspecified.
|
||||||
|
sourcePort uint16 // if zero, there is no source port
|
||||||
destination netip.Prefix // if not valid, then it is unspecified.
|
destination netip.Prefix // if not valid, then it is unspecified.
|
||||||
destinationPort uint16 // if zero, there is no destination port
|
destinationPort uint16 // if zero, there is no destination port
|
||||||
toPorts []uint16 // if empty, there is no redirection
|
toPorts []uint16 // if empty, there is no redirection
|
||||||
@@ -45,6 +46,8 @@ func (i *iptablesInstruction) equalToRule(table, chain string, rule chainRule) (
|
|||||||
return false
|
return false
|
||||||
case i.destinationPort != rule.destinationPort:
|
case i.destinationPort != rule.destinationPort:
|
||||||
return false
|
return false
|
||||||
|
case i.sourcePort != rule.sourcePort:
|
||||||
|
return false
|
||||||
case !slices.Equal(i.toPorts, rule.redirPorts):
|
case !slices.Equal(i.toPorts, rule.redirPorts):
|
||||||
return false
|
return false
|
||||||
case !slices.Equal(i.ctstate, rule.ctstate):
|
case !slices.Equal(i.ctstate, rule.ctstate):
|
||||||
@@ -99,25 +102,11 @@ func parseIptablesInstruction(s string) (instruction iptablesInstruction, err er
|
|||||||
}
|
}
|
||||||
|
|
||||||
func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (consumed int, err error) {
|
func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (consumed int, err error) {
|
||||||
flag := fields[0]
|
consumed, err = preCheckInstructionFields(fields)
|
||||||
|
if err != nil {
|
||||||
// All flags use one value after the flag, except the following:
|
return 0, err
|
||||||
switch flag {
|
|
||||||
case "--tcp-flags": // -m can have 1 or 2 values
|
|
||||||
const expected = 3
|
|
||||||
if len(fields) < expected {
|
|
||||||
return 0, fmt.Errorf("%w: flag %q requires at least 2 values, but got %s",
|
|
||||||
ErrIptablesCommandMalformed, flag, strings.Join(fields, " "))
|
|
||||||
}
|
|
||||||
consumed = expected
|
|
||||||
default:
|
|
||||||
const expected = 2
|
|
||||||
if len(fields) < expected {
|
|
||||||
return 0, fmt.Errorf("%w: flag %q requires a value, but got none",
|
|
||||||
ErrIptablesCommandMalformed, flag)
|
|
||||||
}
|
|
||||||
consumed = expected
|
|
||||||
}
|
}
|
||||||
|
flag := fields[0]
|
||||||
value := fields[1]
|
value := fields[1]
|
||||||
|
|
||||||
switch flag {
|
switch flag {
|
||||||
@@ -134,20 +123,9 @@ func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (co
|
|||||||
case "-p", "--protocol":
|
case "-p", "--protocol":
|
||||||
instruction.protocol = value
|
instruction.protocol = value
|
||||||
case "-m", "--match":
|
case "-m", "--match":
|
||||||
consumed = 2 // -m can have 1 or 2 values, so it consumes 2 or 3 fields.
|
consumed, err = parseMatchModule(fields, instruction)
|
||||||
switch value {
|
if err != nil {
|
||||||
case "tcp", "udp": // for now ignore the protocol match since it's auto-loaded
|
return 0, fmt.Errorf("parsing match module: %w", err)
|
||||||
case "mark":
|
|
||||||
switch fields[2] {
|
|
||||||
case "!":
|
|
||||||
consumed++
|
|
||||||
instruction.mark.invert = true
|
|
||||||
default:
|
|
||||||
return 0, fmt.Errorf("%w: unsupported match mark with value: %s",
|
|
||||||
ErrIptablesCommandMalformed, fields[2])
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return 0, fmt.Errorf("%w: unknown match value: %s", ErrIptablesCommandMalformed, value)
|
|
||||||
}
|
}
|
||||||
case "--mark":
|
case "--mark":
|
||||||
const base = 0 // auto-detect
|
const base = 0 // auto-detect
|
||||||
@@ -166,30 +144,27 @@ func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (co
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("parsing source IP CIDR: %w", err)
|
return 0, fmt.Errorf("parsing source IP CIDR: %w", err)
|
||||||
}
|
}
|
||||||
|
case "--sport":
|
||||||
|
instruction.sourcePort, err = parsePort(value)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("parsing source port: %w", err)
|
||||||
|
}
|
||||||
case "-d", "--destination":
|
case "-d", "--destination":
|
||||||
instruction.destination, err = parseIPPrefix(value)
|
instruction.destination, err = parseIPPrefix(value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("parsing destination IP CIDR: %w", err)
|
return 0, fmt.Errorf("parsing destination IP CIDR: %w", err)
|
||||||
}
|
}
|
||||||
case "--dport":
|
case "--dport":
|
||||||
const base, bitLength = 10, 16
|
instruction.destinationPort, err = parsePort(value)
|
||||||
destinationPort, err := strconv.ParseUint(value, base, bitLength)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("parsing destination port: %w", err)
|
return 0, fmt.Errorf("parsing destination port: %w", err)
|
||||||
}
|
}
|
||||||
instruction.destinationPort = uint16(destinationPort)
|
|
||||||
case "--ctstate":
|
case "--ctstate":
|
||||||
instruction.ctstate = strings.Split(value, ",")
|
instruction.ctstate = strings.Split(value, ",")
|
||||||
case "--to-ports":
|
case "--to-ports":
|
||||||
portStrings := strings.Split(value, ",")
|
instruction.toPorts, err = parseToPorts(value)
|
||||||
instruction.toPorts = make([]uint16, len(portStrings))
|
if err != nil {
|
||||||
for i, portString := range portStrings {
|
return 0, fmt.Errorf("parsing port redirection: %w", err)
|
||||||
const base, bitLength = 10, 16
|
|
||||||
port, err := strconv.ParseUint(portString, base, bitLength)
|
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("parsing port redirection: %w", err)
|
|
||||||
}
|
|
||||||
instruction.toPorts[i] = uint16(port)
|
|
||||||
}
|
}
|
||||||
case "--tcp-flags":
|
case "--tcp-flags":
|
||||||
mask, comparison := value, fields[2]
|
mask, comparison := value, fields[2]
|
||||||
@@ -203,6 +178,27 @@ func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (co
|
|||||||
return consumed, nil
|
return consumed, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func preCheckInstructionFields(fields []string) (consumed int, err error) {
|
||||||
|
flag := fields[0]
|
||||||
|
// All flags use one value after the flag, except the following:
|
||||||
|
switch flag {
|
||||||
|
case "--tcp-flags": // -m can have 1 or 2 values
|
||||||
|
const expected = 3
|
||||||
|
if len(fields) < expected {
|
||||||
|
return 0, fmt.Errorf("%w: flag %q requires at least 2 values, but got %s",
|
||||||
|
ErrIptablesCommandMalformed, flag, strings.Join(fields, " "))
|
||||||
|
}
|
||||||
|
return expected, nil
|
||||||
|
default:
|
||||||
|
const expected = 2
|
||||||
|
if len(fields) < expected {
|
||||||
|
return 0, fmt.Errorf("%w: flag %q requires a value, but got none",
|
||||||
|
ErrIptablesCommandMalformed, flag)
|
||||||
|
}
|
||||||
|
return expected, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func parseIPPrefix(value string) (prefix netip.Prefix, err error) {
|
func parseIPPrefix(value string) (prefix netip.Prefix, err error) {
|
||||||
slashIndex := strings.Index(value, "/")
|
slashIndex := strings.Index(value, "/")
|
||||||
if slashIndex >= 0 {
|
if slashIndex >= 0 {
|
||||||
@@ -215,3 +211,52 @@ func parseIPPrefix(value string) (prefix netip.Prefix, err error) {
|
|||||||
}
|
}
|
||||||
return netip.PrefixFrom(ip, ip.BitLen()), nil
|
return netip.PrefixFrom(ip, ip.BitLen()), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parsePort(value string) (port uint16, err error) {
|
||||||
|
const base, bitLength = 10, 16
|
||||||
|
portValue, err := strconv.ParseUint(value, base, bitLength)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return uint16(portValue), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseMatchModule(fields []string, instruction *iptablesInstruction) (
|
||||||
|
consumed int, err error,
|
||||||
|
) {
|
||||||
|
_ = fields[consumed] // -m or --match flag already detected
|
||||||
|
consumed++
|
||||||
|
switch fields[consumed] {
|
||||||
|
case "tcp", "udp":
|
||||||
|
consumed++
|
||||||
|
// for now ignore the protocol match since it's auto-loaded
|
||||||
|
// when parsing the -p/--protocol flag, and we don't need to
|
||||||
|
// parse it twice.
|
||||||
|
case "mark":
|
||||||
|
consumed++
|
||||||
|
switch fields[consumed] {
|
||||||
|
case "!":
|
||||||
|
consumed++
|
||||||
|
instruction.mark.invert = true
|
||||||
|
default:
|
||||||
|
return consumed, fmt.Errorf("%w: unsupported match mark with value: %s",
|
||||||
|
ErrIptablesCommandMalformed, fields[2])
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return 0, fmt.Errorf("%w: unknown match value: %s",
|
||||||
|
ErrIptablesCommandMalformed, fields[consumed])
|
||||||
|
}
|
||||||
|
return consumed, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseToPorts(value string) (toPorts []uint16, err error) {
|
||||||
|
portStrings := strings.Split(value, ",")
|
||||||
|
toPorts = make([]uint16, len(portStrings))
|
||||||
|
for i, portString := range portStrings {
|
||||||
|
toPorts[i], err = parsePort(portString)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return toPorts, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -71,7 +71,7 @@ var ErrMarkMatchModuleMissing = errors.New("kernel is missing the mark module li
|
|||||||
// This is necessary for TCP path MTU discovery to work, as the kernel will try to terminate the connection
|
// This is necessary for TCP path MTU discovery to work, as the kernel will try to terminate the connection
|
||||||
// by sending a TCP RST packet, although we want to handle the connection manually.
|
// by sending a TCP RST packet, although we want to handle the connection manually.
|
||||||
func (c *Config) TempDropOutputTCPRST(ctx context.Context,
|
func (c *Config) TempDropOutputTCPRST(ctx context.Context,
|
||||||
addrPort netip.AddrPort, excludeMark int) (
|
src, dst netip.AddrPort, excludeMark int) (
|
||||||
revert func(ctx context.Context) error, err error,
|
revert func(ctx context.Context) error, err error,
|
||||||
) {
|
) {
|
||||||
_, err = os.Stat("/usr/lib/xtables/libxt_mark.so")
|
_, err = os.Stat("/usr/lib/xtables/libxt_mark.so")
|
||||||
@@ -79,11 +79,12 @@ func (c *Config) TempDropOutputTCPRST(ctx context.Context,
|
|||||||
return nil, fmt.Errorf("%w", ErrMarkMatchModuleMissing)
|
return nil, fmt.Errorf("%w", ErrMarkMatchModuleMissing)
|
||||||
}
|
}
|
||||||
|
|
||||||
const template = "%s OUTPUT -p tcp -d %s --dport %d --tcp-flags RST RST -m mark ! --mark %d -j DROP" //nolint:dupword
|
const template = "%s OUTPUT -p tcp -s %s --sport %d -d %s --dport %d " +
|
||||||
instruction := fmt.Sprintf(template, "--append", addrPort.Addr(), addrPort.Port(), excludeMark)
|
"--tcp-flags RST RST -m mark ! --mark %d -j DROP" //nolint:dupword
|
||||||
revertInstruction := fmt.Sprintf(template, "--delete", addrPort.Addr(), addrPort.Port(), excludeMark)
|
instruction := fmt.Sprintf(template, "--append", src.Addr(), src.Port(), dst.Addr(), dst.Port(), excludeMark)
|
||||||
|
revertInstruction := fmt.Sprintf(template, "--delete", src.Addr(), src.Port(), dst.Addr(), dst.Port(), excludeMark)
|
||||||
run := c.runIptablesInstruction
|
run := c.runIptablesInstruction
|
||||||
if addrPort.Addr().Is6() {
|
if dst.Addr().Is6() {
|
||||||
run = c.runIP6tablesInstruction
|
run = c.runIP6tablesInstruction
|
||||||
}
|
}
|
||||||
revert = func(ctx context.Context) error {
|
revert = func(ctx context.Context) error {
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ func PathMTUDiscover(ctx context.Context, icmpAddrs []netip.Addr, tcpAddrs []net
|
|||||||
const mtuMargin = 150
|
const mtuMargin = 150
|
||||||
minMTU = max(maxPossibleMTU-mtuMargin, minMTU)
|
minMTU = max(maxPossibleMTU-mtuMargin, minMTU)
|
||||||
}
|
}
|
||||||
mtu, err = tcp.PathMTUDiscover(ctx, addrPort, minMTU, maxPossibleMTU, fw, logger)
|
mtu, err = tcp.PathMTUDiscover(ctx, addrPort, minMTU, maxPossibleMTU, tryTimeout, fw, logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, firewall.ErrMarkMatchModuleMissing) {
|
if errors.Is(err, firewall.ErrMarkMatchModuleMissing) {
|
||||||
logger.Debugf("aborting TCP path MTU discovery: %s", err)
|
logger.Debugf("aborting TCP path MTU discovery: %s", err)
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Firewall interface {
|
type Firewall interface {
|
||||||
TempDropOutputTCPRST(ctx context.Context, addrPort netip.AddrPort,
|
TempDropOutputTCPRST(ctx context.Context, src, dst netip.AddrPort,
|
||||||
excludeMark int) (revert func(ctx context.Context) error, err error)
|
excludeMark int) (revert func(ctx context.Context) error, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+70
-45
@@ -18,26 +18,61 @@ type testUnit struct {
|
|||||||
ok bool
|
ok bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func PathMTUDiscover(ctx context.Context, addrPort netip.AddrPort,
|
func PathMTUDiscover(ctx context.Context, dst netip.AddrPort,
|
||||||
minMTU, maxPossibleMTU uint32, firewall Firewall, logger Logger,
|
minMTU, maxPossibleMTU uint32, tryTimeout time.Duration,
|
||||||
|
firewall Firewall, logger Logger,
|
||||||
) (mtu uint32, err error) {
|
) (mtu uint32, err error) {
|
||||||
const excludeMark = 4325
|
family := constants.AF_INET
|
||||||
revert, err := firewall.TempDropOutputTCPRST(ctx, addrPort, excludeMark)
|
if dst.Addr().Is6() {
|
||||||
if err != nil {
|
family = constants.AF_INET6
|
||||||
return 0, fmt.Errorf("temporarily dropping outgoing TCP RST packets: %w", err)
|
|
||||||
}
|
}
|
||||||
defer func() {
|
const excludeMark = 4325
|
||||||
err := revert(ctx)
|
fd, stop, err := startRawSocket(family, excludeMark)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Warnf("reverting firewall changes: %s", err)
|
return 0, fmt.Errorf("starting raw socket: %w", err)
|
||||||
}
|
}
|
||||||
|
defer stop()
|
||||||
|
|
||||||
|
tracker := newTracker(fd, dst.Addr().Is4())
|
||||||
|
|
||||||
|
trackerCtx, trackerCancel := context.WithCancel(ctx)
|
||||||
|
defer trackerCancel()
|
||||||
|
trackerErrCh := make(chan error)
|
||||||
|
go func() {
|
||||||
|
trackerErrCh <- tracker.listen(trackerCtx)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return pathMTUDiscover(ctx, addrPort, minMTU, maxPossibleMTU, excludeMark, logger)
|
pmtudCtx, pmtudCancel := context.WithCancel(ctx)
|
||||||
|
defer pmtudCancel()
|
||||||
|
type result struct {
|
||||||
|
mtu uint32
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
pmtudResultCh := make(chan result)
|
||||||
|
go func() {
|
||||||
|
mtu, err := pathMTUDiscover(pmtudCtx, fd, dst, minMTU, maxPossibleMTU,
|
||||||
|
excludeMark, tryTimeout, tracker, firewall, logger)
|
||||||
|
pmtudResultCh <- result{mtu: mtu, err: err}
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err = <-trackerErrCh:
|
||||||
|
pmtudCancel()
|
||||||
|
<-pmtudResultCh
|
||||||
|
return 0, fmt.Errorf("listening for TCP replies: %w", err)
|
||||||
|
case res := <-pmtudResultCh:
|
||||||
|
trackerCancel()
|
||||||
|
<-trackerErrCh
|
||||||
|
return res.mtu, res.err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func pathMTUDiscover(ctx context.Context, addrPort netip.AddrPort,
|
var errTimedOut = errors.New("timed out")
|
||||||
minMTU, maxPossibleMTU uint32, excludeMark int, logger Logger,
|
|
||||||
|
func pathMTUDiscover(ctx context.Context, fd fileDescriptor,
|
||||||
|
dst netip.AddrPort, minMTU, maxPossibleMTU uint32, excludeMark int,
|
||||||
|
tryTimeout time.Duration, tracker *tracker, firewall Firewall,
|
||||||
|
logger Logger,
|
||||||
) (mtu uint32, err error) {
|
) (mtu uint32, err error) {
|
||||||
mtusToTest := test.MakeMTUsToTest(minMTU, maxPossibleMTU)
|
mtusToTest := test.MakeMTUsToTest(minMTU, maxPossibleMTU)
|
||||||
if len(mtusToTest) == 1 { // only minMTU because minMTU == maxPossibleMTU
|
if len(mtusToTest) == 1 { // only minMTU because minMTU == maxPossibleMTU
|
||||||
@@ -50,30 +85,14 @@ func pathMTUDiscover(ctx context.Context, addrPort netip.AddrPort,
|
|||||||
tests[i] = testUnit{mtu: mtusToTest[i]}
|
tests[i] = testUnit{mtu: mtusToTest[i]}
|
||||||
}
|
}
|
||||||
|
|
||||||
family := constants.AF_INET
|
errCause := fmt.Errorf("%w: after %s", errTimedOut, tryTimeout)
|
||||||
if addrPort.Addr().Is6() {
|
runCtx, runCancel := context.WithTimeoutCause(ctx, tryTimeout, errCause)
|
||||||
family = constants.AF_INET6
|
defer runCancel()
|
||||||
}
|
|
||||||
fd, stop, err := startRawSocket(family, excludeMark)
|
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("starting raw socket: %w", err)
|
|
||||||
}
|
|
||||||
defer stop()
|
|
||||||
|
|
||||||
tracker := newTracker(fd, addrPort.Addr().Is4())
|
|
||||||
|
|
||||||
const timeout = time.Second
|
|
||||||
runCtx, cancel := context.WithTimeout(ctx, timeout)
|
|
||||||
defer cancel()
|
|
||||||
errCh := make(chan error)
|
|
||||||
go func() {
|
|
||||||
errCh <- tracker.listen(runCtx)
|
|
||||||
}()
|
|
||||||
|
|
||||||
doneCh := make(chan struct{})
|
doneCh := make(chan struct{})
|
||||||
for i := range tests {
|
for i := range tests {
|
||||||
go func(i int) {
|
go func(i int) {
|
||||||
err := runTest(runCtx, fd, tracker, src, dst, tests[i].mtu)
|
err := runTest(runCtx, dst, tests[i].mtu, excludeMark,
|
||||||
|
fd, tracker, firewall, logger)
|
||||||
tests[i].ok = err == nil
|
tests[i].ok = err == nil
|
||||||
doneCh <- struct{}{}
|
doneCh <- struct{}{}
|
||||||
}(i)
|
}(i)
|
||||||
@@ -82,27 +101,33 @@ func pathMTUDiscover(ctx context.Context, addrPort netip.AddrPort,
|
|||||||
i := 0
|
i := 0
|
||||||
for i < len(tests) {
|
for i < len(tests) {
|
||||||
select {
|
select {
|
||||||
|
case <-runCtx.Done(): // timeout or parent context canceled
|
||||||
|
err = context.Cause(runCtx)
|
||||||
|
// collect remaining done signals
|
||||||
|
for i < len(tests) {
|
||||||
|
<-doneCh
|
||||||
|
i++
|
||||||
|
}
|
||||||
case <-doneCh:
|
case <-doneCh:
|
||||||
i++
|
i++
|
||||||
case err := <-errCh:
|
|
||||||
if err == nil { // timeout
|
|
||||||
cancel()
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
return 0, fmt.Errorf("listening for TCP replies: %w", err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err != nil && !errors.Is(err, errTimedOut) {
|
||||||
|
// context is canceled but did not timeout after tryTimeout
|
||||||
|
return 0, fmt.Errorf("running MTU tests: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
if tests[len(tests)-1].ok {
|
if tests[len(tests)-1].ok {
|
||||||
return tests[len(tests)-1].mtu, nil
|
return tests[len(tests)-1].mtu, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := len(tests) - 2; i >= 0; i-- { //nolint:mnd
|
for i := len(tests) - 2; i >= 0; i-- { //nolint:mnd
|
||||||
if tests[i].ok {
|
if tests[i].ok {
|
||||||
stop()
|
runCancel() // just to release resources although runCtx is no longer used
|
||||||
cancel()
|
return pathMTUDiscover(ctx, fd, dst,
|
||||||
return pathMTUDiscover(ctx, addrPort,
|
tests[i].mtu, tests[i+1].mtu-1, excludeMark,
|
||||||
tests[i].mtu, tests[i+1].mtu-1, excludeMark, logger)
|
tryTimeout, tracker, firewall, logger)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -64,8 +64,9 @@ var (
|
|||||||
// Craft and send a raw TCP packet to test the MTU.
|
// Craft and send a raw TCP packet to test the MTU.
|
||||||
// It expects either an RST reply (if no server is listening)
|
// It expects either an RST reply (if no server is listening)
|
||||||
// or a SYN-ACK/ACK reply (if a server is listening).
|
// or a SYN-ACK/ACK reply (if a server is listening).
|
||||||
func runTest(ctx context.Context, fd fileDescriptor,
|
func runTest(ctx context.Context, dst netip.AddrPort, mtu uint32,
|
||||||
tracker *tracker, dst netip.AddrPort, mtu uint32,
|
excludeMark int, fd fileDescriptor, tracker *tracker,
|
||||||
|
firewall Firewall, logger Logger,
|
||||||
) error {
|
) error {
|
||||||
const proto = constants.IPPROTO_TCP
|
const proto = constants.IPPROTO_TCP
|
||||||
src, cleanup, err := ip.SrcAddr(dst, proto)
|
src, cleanup, err := ip.SrcAddr(dst, proto)
|
||||||
@@ -74,6 +75,20 @@ func runTest(ctx context.Context, fd fileDescriptor,
|
|||||||
}
|
}
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
|
revert, err := firewall.TempDropOutputTCPRST(ctx, src, dst, excludeMark)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("temporarily dropping outgoing TCP RST packets: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
// we don't want to skip reverting the firewall changes
|
||||||
|
// even if the context is already expired, so we use a
|
||||||
|
// background context here.
|
||||||
|
err := revert(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
logger.Warnf("reverting firewall changes: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
ch := make(chan []byte)
|
ch := make(chan []byte)
|
||||||
abort := make(chan struct{})
|
abort := make(chan struct{})
|
||||||
defer close(abort)
|
defer close(abort)
|
||||||
|
|||||||
@@ -0,0 +1,38 @@
|
|||||||
|
//go:build integration
|
||||||
|
|
||||||
|
package tcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/qdm12/gluetun/internal/command"
|
||||||
|
"github.com/qdm12/gluetun/internal/firewall"
|
||||||
|
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||||
|
"github.com/qdm12/log"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_PathMTUDiscover(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
noopLogger := log.New(log.SetLevel(log.LevelDebug))
|
||||||
|
|
||||||
|
cmder := command.New()
|
||||||
|
fw, err := firewall.NewConfig(t.Context(), noopLogger, cmder, nil, nil)
|
||||||
|
if errors.Is(err, firewall.ErrIPTablesNotSupported) {
|
||||||
|
t.Skip("iptables not installed, skipping TCP PMTUD tests")
|
||||||
|
}
|
||||||
|
require.NoError(t, err, "creating firewall config")
|
||||||
|
|
||||||
|
dst := netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 80)
|
||||||
|
const minMTU = constants.MinIPv6MTU
|
||||||
|
const maxMTU = constants.MaxEthernetFrameSize
|
||||||
|
const tryTimeout = time.Second
|
||||||
|
mtu, err := PathMTUDiscover(t.Context(), dst, minMTU, maxMTU, tryTimeout, fw, noopLogger)
|
||||||
|
require.NoError(t, err, "discovering path MTU")
|
||||||
|
assert.Greater(t, mtu, uint32(0), "MTU should be greater than 0")
|
||||||
|
t.Logf("discovered path MTU to %s is %d", dst, mtu)
|
||||||
|
}
|
||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
"github.com/qdm12/gluetun/internal/firewall"
|
"github.com/qdm12/gluetun/internal/firewall"
|
||||||
"github.com/qdm12/gluetun/internal/netlink"
|
"github.com/qdm12/gluetun/internal/netlink"
|
||||||
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||||
|
"github.com/qdm12/gluetun/internal/pmtud/ip"
|
||||||
"github.com/qdm12/gluetun/internal/routing"
|
"github.com/qdm12/gluetun/internal/routing"
|
||||||
"github.com/qdm12/log"
|
"github.com/qdm12/log"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@@ -24,11 +25,7 @@ import (
|
|||||||
func Test_runTest(t *testing.T) {
|
func Test_runTest(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
serverAddrs := map[string]netip.AddrPort{
|
localNonListenPort := reserveClosedPort(t)
|
||||||
"cloudflare-http": netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 80),
|
|
||||||
"cloudflare-https": netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443),
|
|
||||||
"google-https": netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 443),
|
|
||||||
}
|
|
||||||
|
|
||||||
noopLogger := &noopLogger{}
|
noopLogger := &noopLogger{}
|
||||||
|
|
||||||
@@ -39,17 +36,6 @@ func Test_runTest(t *testing.T) {
|
|||||||
}
|
}
|
||||||
require.NoError(t, err, "creating firewall config")
|
require.NoError(t, err, "creating firewall config")
|
||||||
|
|
||||||
// Prevent Kernel from sending RST packets back to servers
|
|
||||||
const excludeMark = 4324
|
|
||||||
for _, addrPort := range serverAddrs {
|
|
||||||
revert, err := fw.TempDropOutputTCPRST(t.Context(), addrPort, excludeMark)
|
|
||||||
require.NoError(t, err)
|
|
||||||
t.Cleanup(func() {
|
|
||||||
err := revert(context.Background())
|
|
||||||
assert.NoError(t, err)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
netlinker := netlink.New(noopLogger)
|
netlinker := netlink.New(noopLogger)
|
||||||
loopbackMTU, err := findLoopbackMTU(netlinker)
|
loopbackMTU, err := findLoopbackMTU(netlinker)
|
||||||
require.NoError(t, err, "finding loopback IPv4 MTU")
|
require.NoError(t, err, "finding loopback IPv4 MTU")
|
||||||
@@ -59,6 +45,7 @@ func Test_runTest(t *testing.T) {
|
|||||||
ctx, cancel := context.WithCancel(t.Context())
|
ctx, cancel := context.WithCancel(t.Context())
|
||||||
|
|
||||||
const family = constants.AF_INET
|
const family = constants.AF_INET
|
||||||
|
const excludeMark = 4545
|
||||||
fd, stop, err := startRawSocket(family, excludeMark)
|
fd, stop, err := startRawSocket(family, excludeMark)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@@ -69,6 +56,11 @@ func Test_runTest(t *testing.T) {
|
|||||||
trackerCh <- tracker.listen(ctx)
|
trackerCh <- tracker.listen(ctx)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
// Our local ethernet MTU could be 1500, and the server could advertise
|
||||||
|
// an MSS of 1400, but the real link to the server could have an MTU of 1300,
|
||||||
|
// so we need to adjust our test so it passes. We are not actually path MTU
|
||||||
|
// discovering here, just testing that we can receive the expected TCP packets
|
||||||
|
// for a given MTU.
|
||||||
const mtuSafetyBuffer = 200
|
const mtuSafetyBuffer = 200
|
||||||
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
@@ -80,48 +72,36 @@ func Test_runTest(t *testing.T) {
|
|||||||
|
|
||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
dst func(t *testing.T) netip.AddrPort
|
server netip.AddrPort
|
||||||
mtu uint32
|
mtu uint32
|
||||||
success bool
|
success bool
|
||||||
}{
|
}{
|
||||||
"local_not_listening": {
|
"local_not_listening": {
|
||||||
timeout: time.Hour,
|
timeout: time.Hour,
|
||||||
dst: func(t *testing.T) netip.AddrPort {
|
server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), localNonListenPort),
|
||||||
t.Helper()
|
|
||||||
port := reserveClosedPort(t)
|
|
||||||
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), port)
|
|
||||||
},
|
|
||||||
mtu: loopbackMTU,
|
mtu: loopbackMTU,
|
||||||
success: true,
|
success: true,
|
||||||
},
|
},
|
||||||
"remote_not_listening": {
|
"remote_not_listening": {
|
||||||
timeout: 50 * time.Millisecond,
|
timeout: 50 * time.Millisecond,
|
||||||
dst: func(_ *testing.T) netip.AddrPort {
|
server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 12345),
|
||||||
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 12345)
|
mtu: defaultIPv4MTU - mtuSafetyBuffer,
|
||||||
},
|
|
||||||
mtu: defaultIPv4MTU - mtuSafetyBuffer,
|
|
||||||
},
|
},
|
||||||
"1.1.1.1:443": {
|
"1.1.1.1:443": {
|
||||||
timeout: 5 * time.Second,
|
timeout: 5 * time.Second,
|
||||||
dst: func(_ *testing.T) netip.AddrPort {
|
server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443),
|
||||||
return serverAddrs["cloudflare-https"]
|
|
||||||
},
|
|
||||||
mtu: defaultIPv4MTU - mtuSafetyBuffer,
|
mtu: defaultIPv4MTU - mtuSafetyBuffer,
|
||||||
success: true,
|
success: true,
|
||||||
},
|
},
|
||||||
"1.1.1.1:80": {
|
"1.1.1.1:80": {
|
||||||
timeout: 5 * time.Second,
|
timeout: 5 * time.Second,
|
||||||
dst: func(_ *testing.T) netip.AddrPort {
|
server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 80),
|
||||||
return serverAddrs["cloudflare-http"]
|
|
||||||
},
|
|
||||||
mtu: defaultIPv4MTU - mtuSafetyBuffer,
|
mtu: defaultIPv4MTU - mtuSafetyBuffer,
|
||||||
success: true,
|
success: true,
|
||||||
},
|
},
|
||||||
"8.8.8.8:443": {
|
"8.8.8.8:443": {
|
||||||
timeout: 5 * time.Second,
|
timeout: 5 * time.Second,
|
||||||
dst: func(_ *testing.T) netip.AddrPort {
|
server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 443),
|
||||||
return serverAddrs["google-https"]
|
|
||||||
},
|
|
||||||
mtu: defaultIPv4MTU - mtuSafetyBuffer,
|
mtu: defaultIPv4MTU - mtuSafetyBuffer,
|
||||||
success: true,
|
success: true,
|
||||||
},
|
},
|
||||||
@@ -131,11 +111,24 @@ func Test_runTest(t *testing.T) {
|
|||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
dst := testCase.dst(t)
|
dst := testCase.server
|
||||||
|
|
||||||
|
const proto = constants.IPPROTO_TCP
|
||||||
|
src, cleanup, err := ip.SrcAddr(dst, proto)
|
||||||
|
require.NoError(t, err, "getting source address to reach remote server %s", dst)
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
|
||||||
|
revert, err := fw.TempDropOutputTCPRST(t.Context(), src, dst, excludeMark)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err := revert(context.Background())
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(t.Context(), testCase.timeout)
|
ctx, cancel := context.WithTimeout(t.Context(), testCase.timeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
err := runTest(ctx, fd, tracker, dst, testCase.mtu)
|
err = runTest(ctx, dst, testCase.mtu, excludeMark,
|
||||||
|
fd, tracker, fw, noopLogger)
|
||||||
if testCase.success {
|
if testCase.success {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
} else {
|
} else {
|
||||||
@@ -230,4 +223,5 @@ func (l *noopLogger) Debug(_ string) {}
|
|||||||
func (l *noopLogger) Debugf(_ string, _ ...any) {}
|
func (l *noopLogger) Debugf(_ string, _ ...any) {}
|
||||||
func (l *noopLogger) Info(_ string) {}
|
func (l *noopLogger) Info(_ string) {}
|
||||||
func (l *noopLogger) Warn(_ string) {}
|
func (l *noopLogger) Warn(_ string) {}
|
||||||
|
func (l *noopLogger) Warnf(_ string, _ ...any) {}
|
||||||
func (l *noopLogger) Error(_ string) {}
|
func (l *noopLogger) Error(_ string) {}
|
||||||
|
|||||||
Reference in New Issue
Block a user