feat(pmtud/tcp): use the TCP server with highest MSS to run MTU tests

This commit is contained in:
Quentin McGaw
2026-02-19 14:03:46 +00:00
parent fb85ae79d1
commit 8d86470905
10 changed files with 323 additions and 59 deletions
+7
View File
@@ -7,6 +7,13 @@ import (
"github.com/qdm12/gluetun/internal/pmtud/constants"
)
// HeaderLength returns the IP header length constant matching the IP family:
// constants.IPv4HeaderLength when ipv4 is true and constants.IPv6HeaderLength
// otherwise.
func HeaderLength(ipv4 bool) uint32 {
	if !ipv4 {
		return constants.IPv6HeaderLength
	}
	return constants.IPv4HeaderLength
}
func HeaderV4(srcIP, dstIP netip.Addr, payloadLength uint32) []byte {
ipHeader := make([]byte, constants.IPv4HeaderLength)
const version byte = 4
+4 -12
View File
@@ -61,32 +61,24 @@ func PathMTUDiscover(ctx context.Context, icmpAddrs []netip.Addr, tcpAddrs []net
}
}
for _, addrPort := range tcpAddrs {
minMTU := constants.MinIPv4MTU
if addrPort.Addr().Is6() {
if tcpAddrs[0].Addr().Is6() {
minMTU = constants.MinIPv6MTU
}
if icmpSuccess {
const mtuMargin = 150
minMTU = max(maxPossibleMTU-mtuMargin, minMTU)
}
mtu, err = tcp.PathMTUDiscover(ctx, addrPort, minMTU, maxPossibleMTU, tryTimeout, fw, logger)
mtu, err = tcp.PathMTUDiscover(ctx, tcpAddrs, minMTU, maxPossibleMTU, tryTimeout, fw, logger)
if err != nil {
if errors.Is(err, firewall.ErrMarkMatchModuleMissing) {
logger.Debugf("aborting TCP path MTU discovery: %s", err)
if icmpSuccess {
return maxPossibleMTU, nil // only rely on ICMP PMTUD results
}
}
return 0, fmt.Errorf("%w", ErrPMTUDFailICMPAndTCP)
}
logger.Debugf("TCP path MTU discovery to %s failed: %s", addrPort, err)
continue
}
logger.Debugf("TCP path MTU discovery to %s found maximum valid MTU %d", addrPort, mtu)
logger.Debugf("TCP path MTU discovery found maximum valid MTU %d", mtu)
return mtu, nil
}
// TCP PMTUD failed for all addresses for external reasons,
// so do not take the risk and return an error.
return 0, fmt.Errorf("TCP path MTU discovery: last error: %w", err)
}
+32
View File
@@ -3,8 +3,11 @@ package tcp
import (
"errors"
"fmt"
"sync"
"testing"
"github.com/qdm12/gluetun/internal/command"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/netlink"
"github.com/qdm12/gluetun/internal/pmtud/constants"
"github.com/qdm12/gluetun/internal/routing"
@@ -14,6 +17,35 @@ import (
"golang.org/x/sys/unix"
)
// testFirewall must be global to prevent parallel tests from interfering
// with each other since they would interact with the same filter table.
// The first test to use it initializes it, and the rest reuse it.
// testFirewallErr keeps the outcome of that single initialization so that
// every test calling getFirewall reacts to it in the same way, and not only
// the test which happened to run the initialization.
var (
	testFirewall     *firewall.Config //nolint:gochecknoglobals
	testFirewallErr  error            //nolint:gochecknoglobals
	testFirewallOnce sync.Once        //nolint:gochecknoglobals
)

// getFirewall returns the shared firewall configuration, creating it on the
// first call. If iptables is not supported, the calling test is skipped.
// If the creation failed for any other reason, the calling test fails, so
// that a broken setup is never reported as a skipped test.
func getFirewall(t *testing.T) *firewall.Config {
	t.Helper()

	// Only record the outcome inside Do: t.Skip and require.NoError exit
	// the calling goroutine, which would mark the Once as done while leaving
	// later callers unable to tell why the firewall is missing.
	testFirewallOnce.Do(func() {
		logger := &noopLogger{}
		cmder := command.New()
		testFirewall, testFirewallErr = firewall.NewConfig(t.Context(),
			logger, cmder, nil, nil)
	})

	if errors.Is(testFirewallErr, firewall.ErrIPTablesNotSupported) {
		t.Skip("iptables not installed, skipping TCP PMTUD tests")
	}
	require.NoError(t, testFirewallErr, "creating firewall config")
	return testFirewall
}
type noopLogger struct{}
func (l *noopLogger) Patch(_ ...log.Option) {}
+138
View File
@@ -0,0 +1,138 @@
package tcp
import (
"context"
"errors"
"fmt"
"net/netip"
"time"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/pmtud/constants"
"github.com/qdm12/gluetun/internal/pmtud/ip"
)
// errNoDestinationMSS is returned by findHighestMSSDestination when none of
// the destinations produced an MSS.
var errNoDestinationMSS = errors.New("no destination replied with a valid MSS")

// findHighestMSSDestination finds the destination with the highest
// MSS amongst the provided destinations.
//
// One findMSS call is started concurrently per destination, all bounded by
// timeout. As soon as a destination reports an MSS large enough to cover
// maxPossibleMTU (minus the IP and base TCP header lengths), it is returned
// without waiting for the other destinations. Otherwise the destination with
// the highest MSS is returned.
//
// If no destination produced an MSS, an error is returned: the error wrapping
// firewall.ErrMarkMatchModuleMissing if one occurred, or errNoDestinationMSS
// otherwise. Other per-destination errors are only logged at the debug level.
func findHighestMSSDestination(ctx context.Context, fd fileDescriptor,
	dsts []netip.AddrPort, excludeMark int, maxPossibleMTU uint32,
	timeout time.Duration, tracker *tracker, fw Firewall, logger Logger) (
	dst netip.AddrPort, mss uint32, err error,
) {
	type result struct {
		dst netip.AddrPort
		mss uint32
		err error
	}
	// The channel is buffered with one slot per goroutine so that each
	// goroutine can always deliver its result and exit, even when this
	// function returns early without reading the remaining results.
	resultCh := make(chan result, len(dsts))

	ctx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	for _, dst := range dsts {
		go func(dst netip.AddrPort) {
			mss, err := findMSS(ctx, fd, dst, excludeMark, tracker, fw, logger)
			resultCh <- result{dst: dst, mss: mss, err: err}
		}(dst)
	}

	for range dsts {
		res := <-resultCh
		if res.err != nil {
			switch {
			case err != nil: // error already occurred for another findMSS goroutine
			case errors.Is(res.err, firewall.ErrMarkMatchModuleMissing):
				err = fmt.Errorf("finding MSS for %s: %w", res.dst, res.err)
			default: // another error not due to the match module missing
				logger.Debugf("finding MSS for %s failed: %s", res.dst, res.err)
			}
			continue
		}

		ipHeaderLength := ip.HeaderLength(res.dst.Addr().Is4())
		maxNeededMSS := maxPossibleMTU - ipHeaderLength - constants.BaseTCPHeaderLength
		switch {
		case res.mss >= maxNeededMSS:
			logger.Debugf("%s has an MSS of %d bytes which is equal or higher than "+
				"the maximum needed MSS of %d bytes for the maximum possible MTU of %d bytes",
				res.dst, res.mss, maxNeededMSS, maxPossibleMTU)
			return res.dst, res.mss, nil
		case res.mss > mss:
			mss = res.mss
			dst = res.dst
		}
	}

	if !dst.IsValid() {
		// No destination succeeded: report the recorded error if there is
		// one instead of returning a zero destination as a success.
		if err == nil {
			err = fmt.Errorf("%w: tried %d destinations", errNoDestinationMSS, len(dsts))
		}
		return netip.AddrPort{}, 0, err
	}

	maxTestableMTU := ip.HeaderLength(dst.Addr().Is4()) + constants.BaseTCPHeaderLength + mss
	logger.Debugf("server %s has the highest MSS %d allowing to test the MTU up to %d",
		dst, mss, maxTestableMTU)
	return dst, mss, nil
}
var errMSSNotFound = errors.New("MSS option not found in reply")

// findMSS sends a TCP SYN packet to dst through the raw socket fd and returns
// the MSS option value found in the reply.
//
// Outgoing TCP RST packets between the source and dst are temporarily dropped
// by the firewall for the duration of the call, and this is reverted before
// returning. The reply is received through the tracker, on a channel
// registered for the source and destination ports.
//
// An error is returned if the context is done before a reply is received, if
// the reply cannot be parsed, if it is not a SYN-ACK acknowledging the SYN
// sent, or if it carries no MSS option (errMSSNotFound).
func findMSS(ctx context.Context, fd fileDescriptor, dst netip.AddrPort,
	excludeMark int, tracker *tracker, firewall Firewall, logger Logger) (
	mss uint32, err error,
) {
	const proto = constants.IPPROTO_TCP
	src, cleanup, err := ip.SrcAddr(dst, proto)
	if err != nil {
		return 0, fmt.Errorf("getting source address: %w", err)
	}
	defer cleanup()

	// NOTE(review): presumably needed so a RST sent outside of this code in
	// response to the unexpected SYN-ACK does not reach the server - confirm.
	revert, err := firewall.TempDropOutputTCPRST(ctx, src, dst, excludeMark)
	if err != nil {
		return 0, fmt.Errorf("temporarily dropping outgoing TCP RST packets: %w", err)
	}
	defer func() {
		// we don't want to skip reverting the firewall changes
		// even if the context is already expired, so we use a
		// background context here.
		err := revert(context.Background())
		if err != nil {
			logger.Warnf("reverting firewall changes: %s", err)
		}
	}()

	ch := make(chan []byte)
	abort := make(chan struct{})
	defer close(abort)
	tracker.register(src.Port(), dst.Port(), ch, abort)
	defer tracker.unregister(src.Port(), dst.Port())

	dstSockAddr := makeSockAddr(dst)

	synPacket, synSeq := createSYNPacket(src, dst, 0)
	const sendToFlags = 0
	err = sendTo(fd, synPacket, sendToFlags, dstSockAddr)
	if err != nil {
		return 0, fmt.Errorf("sending SYN packet: %w", err)
	}

	var reply []byte
	select {
	case <-ctx.Done():
		// Best effort reset: the context error is the one worth reporting,
		// so a failure to send the RST is deliberately ignored.
		_ = sendRST(fd, src, dst, synSeq+1)
		return 0, ctx.Err()
	case reply = <-ch:
	}

	replyHeader, err := parseTCPHeader(reply)
	switch {
	case err != nil:
		return 0, fmt.Errorf("parsing reply TCP header: %w", err)
	case replyHeader.typ != packetTypeSYNACK:
		return 0, fmt.Errorf("%w: unexpected packet type %s", errTCPPacketNotSynAck, replyHeader.typ)
	case replyHeader.ack != synSeq+1:
		return 0, fmt.Errorf("%w: expected %d, got %d", errTCPSynAckAckMismatch, synSeq+1, replyHeader.ack)
	case replyHeader.options.mss == 0:
		// The sentinel already carries the description, so only add the
		// destination as context instead of repeating the same text.
		return 0, fmt.Errorf("%w: from %s", errMSSNotFound, dst)
	}

	err = sendRST(fd, src, dst, replyHeader.ack)
	if err != nil {
		return 0, fmt.Errorf("sending RST packet: %w", err)
	}

	return replyHeader.options.mss, nil
}
+59
View File
@@ -0,0 +1,59 @@
//go:build linux
package tcp
import (
"context"
"net/netip"
"testing"
"time"
"github.com/qdm12/gluetun/internal/netlink"
"github.com/qdm12/gluetun/internal/pmtud/constants"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_findHighestMSSDestination(t *testing.T) {
t.Parallel()
netlinker := netlink.New(&noopLogger{})
defaultIPv4MTU, err := findDefaultIPv4RouteMTU(netlinker)
require.NoError(t, err, "finding default IPv4 route MTU")
ctx, cancel := context.WithCancel(t.Context())
const family = constants.AF_INET
fd, stop, err := startRawSocket(family, excludeMark)
require.NoError(t, err)
const ipv4 = true
tracker := newTracker(fd, ipv4)
trackerCh := make(chan error)
go func() {
trackerCh <- tracker.listen(ctx)
}()
t.Cleanup(func() {
stop()
cancel() // stop listening
err = <-trackerCh
require.NoError(t, err)
})
dsts := []netip.AddrPort{
netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443),
netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 443),
}
const timeout = time.Second
fw := getFirewall(t)
logger := &noopLogger{}
dst, mss, err := findHighestMSSDestination(t.Context(), fd, dsts,
excludeMark, defaultIPv4MTU, timeout, tracker, fw, logger)
require.NoError(t, err, "finding highest MSS destination")
assert.Contains(t, dsts, dst, "destination should be in the provided list")
assert.Greater(t, mss, uint32(1000), "MSS should be greater than 1000")
assert.LessOrEqual(t, mss, constants.MaxEthernetFrameSize,
"MSS should be less than or equal to the maximum Ethernet frame size ")
}
+55 -14
View File
@@ -8,6 +8,7 @@ import (
"time"
"github.com/qdm12/gluetun/internal/pmtud/constants"
"github.com/qdm12/gluetun/internal/pmtud/ip"
"github.com/qdm12/gluetun/internal/pmtud/test"
)
@@ -18,22 +19,31 @@ type testUnit struct {
ok bool
}
func PathMTUDiscover(ctx context.Context, dst netip.AddrPort,
const excludeMark = 4545
// PathMTUDiscover first finds the destination TCP server with the highest
// available MSS, in order to be able to test the highest possible MTU.
// If a server has an MSS large enough to cover maxPossibleMTU (minus the IP and base TCP header lengths), this one is used.
// It then performs a binary search of the MTU between minMTU and maxPossibleMTU,
// by sending IP packets with the Don't Fragment bit set and checking if they
// are received or not, exploiting the stateful nature of TCP to be able to
// correlate replies to the sent packets.
// Note all dsts must be of the same IP family (all IPv4 or all IPv6).
func PathMTUDiscover(ctx context.Context, dsts []netip.AddrPort,
minMTU, maxPossibleMTU uint32, tryTimeout time.Duration,
firewall Firewall, logger Logger,
) (mtu uint32, err error) {
family := constants.AF_INET
if dst.Addr().Is6() {
if dsts[0].Addr().Is6() {
family = constants.AF_INET6
}
const excludeMark = 4325
fd, stop, err := startRawSocket(family, excludeMark)
if err != nil {
return 0, fmt.Errorf("starting raw socket: %w", err)
}
defer stop()
tracker := newTracker(fd, dst.Addr().Is4())
tracker := newTracker(fd, family == constants.AF_INET)
trackerCtx, trackerCancel := context.WithCancel(ctx)
defer trackerCancel()
@@ -42,28 +52,59 @@ func PathMTUDiscover(ctx context.Context, dst netip.AddrPort,
trackerErrCh <- tracker.listen(trackerCtx)
}()
pmtudCtx, pmtudCancel := context.WithCancel(ctx)
defer pmtudCancel()
type result struct {
type mssResult struct {
dst netip.AddrPort
mss uint32
err error
}
mssResultCh := make(chan mssResult)
ctx, cancel := context.WithTimeout(ctx, tryTimeout)
defer cancel()
go func() {
dst, mss, err := findHighestMSSDestination(ctx, fd, dsts, excludeMark,
maxPossibleMTU, tryTimeout, tracker, firewall, logger)
mssResultCh <- mssResult{dst: dst, mss: mss, err: err}
}()
var highestMSSDst netip.AddrPort
select {
case err = <-trackerErrCh:
cancel()
<-mssResultCh
return 0, fmt.Errorf("listening for TCP replies: %w", err)
case result := <-mssResultCh:
if result.err != nil {
trackerCancel()
<-trackerErrCh
return 0, fmt.Errorf("finding MSS: %w", result.err)
}
highestMSSDst = result.dst
ipHeaderLength := ip.HeaderLength(highestMSSDst.Addr().Is4())
maxPossibleMTU = ipHeaderLength + constants.BaseTCPHeaderLength + result.mss
}
type pmtudResult struct {
mtu uint32
err error
}
pmtudResultCh := make(chan result)
resultCh := make(chan pmtudResult)
ctx, cancel = context.WithCancel(ctx)
defer cancel()
go func() {
mtu, err := pathMTUDiscover(pmtudCtx, fd, dst, minMTU, maxPossibleMTU,
mtu, err := pathMTUDiscover(ctx, fd, highestMSSDst, minMTU, maxPossibleMTU,
excludeMark, tryTimeout, tracker, firewall, logger)
pmtudResultCh <- result{mtu: mtu, err: err}
resultCh <- pmtudResult{mtu: mtu, err: err}
}()
select {
case err = <-trackerErrCh:
pmtudCancel()
<-pmtudResultCh
cancel()
<-resultCh
return 0, fmt.Errorf("listening for TCP replies: %w", err)
case res := <-pmtudResultCh:
case result := <-resultCh:
trackerCancel()
<-trackerErrCh
return res.mtu, res.err
return result.mtu, result.err
}
}
+2 -2
View File
@@ -129,8 +129,8 @@ func runTest(ctx context.Context, dst netip.AddrPort, mtu uint32,
if firstReplyHeader.options.mss != 0 {
// If the server sent an MSS option, make sure our test packet is not larger than that MSS.
tcpDataLength := getPayloadLength(mtu, dst) - constants.BaseTCPHeaderLength
if tcpDataLength > uint32(firstReplyHeader.options.mss) {
diff := tcpDataLength - uint32(firstReplyHeader.options.mss)
if tcpDataLength > firstReplyHeader.options.mss {
diff := tcpDataLength - firstReplyHeader.options.mss
minMTU := constants.MinIPv4MTU
if dst.Addr().Is6() {
minMTU = constants.MinIPv6MTU
+8 -3
View File
@@ -27,12 +27,17 @@ func Test_PathMTUDiscover(t *testing.T) {
}
require.NoError(t, err, "creating firewall config")
dst := netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 80)
dsts := []netip.AddrPort{
netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 53),
netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443),
netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 53),
netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 443),
}
const minMTU = constants.MinIPv6MTU
const maxMTU = constants.MaxEthernetFrameSize
const tryTimeout = time.Second
mtu, err := PathMTUDiscover(t.Context(), dst, minMTU, maxMTU, tryTimeout, fw, noopLogger)
mtu, err := PathMTUDiscover(t.Context(), dsts, minMTU, maxMTU, tryTimeout, fw, noopLogger)
require.NoError(t, err, "discovering path MTU")
assert.Greater(t, mtu, uint32(0), "MTU should be greater than 0")
t.Logf("discovered path MTU to %s is %d", dst, mtu)
t.Logf("discovered path MTU is %d", mtu)
}
+1 -11
View File
@@ -4,14 +4,11 @@ package tcp
import (
"context"
"errors"
"net/netip"
"testing"
"time"
gomock "github.com/golang/mock/gomock"
"github.com/qdm12/gluetun/internal/command"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/netlink"
"github.com/qdm12/gluetun/internal/pmtud/constants"
"github.com/qdm12/gluetun/internal/pmtud/ip"
@@ -26,13 +23,6 @@ func Test_runTest(t *testing.T) {
noopLogger := &noopLogger{}
cmder := command.New()
fw, err := firewall.NewConfig(t.Context(), noopLogger, cmder, nil, nil)
if errors.Is(err, firewall.ErrIPTablesNotSupported) {
t.Skip("iptables not installed, skipping TCP PMTUD tests")
}
require.NoError(t, err, "creating firewall config")
netlinker := netlink.New(noopLogger)
loopbackMTU, err := findLoopbackMTU(netlinker)
require.NoError(t, err, "finding loopback IPv4 MTU")
@@ -42,7 +32,6 @@ func Test_runTest(t *testing.T) {
ctx, cancel := context.WithCancel(t.Context())
const family = constants.AF_INET
const excludeMark = 4545
fd, stop, err := startRawSocket(family, excludeMark)
require.NoError(t, err)
@@ -116,6 +105,7 @@ func Test_runTest(t *testing.T) {
require.NoError(t, err, "getting source address to reach remote server %s", dst)
t.Cleanup(cleanup)
fw := getFirewall(t)
revert, err := fw.TempDropOutputTCPRST(t.Context(), src, dst, excludeMark)
require.NoError(t, err)
t.Cleanup(func() {
+2 -2
View File
@@ -199,7 +199,7 @@ func parseTCPHeader(b []byte) (header tcpHeader, err error) {
}
type options struct {
mss uint16
mss uint32
windowScale *uint8 // Pointer to differentiate between 0 and "not present"
sackPermitted bool
timestamps *optionTimestamps
@@ -266,7 +266,7 @@ func parseTCPOptions(b []byte) (parsed options, err error) {
return options{}, fmt.Errorf("%w: MSS option at offset %d has length %d, expected %d",
ErrTCPOptionMSSInvalid, i, length, expectedLength)
}
parsed.mss = binary.BigEndian.Uint16(data)
parsed.mss = uint32(binary.BigEndian.Uint16(data))
case optionTypeWindowScale:
const expectedLength = 3
if length != expectedLength {