From 8d8647090506f2520d3f45c6c089957cf053c1e1 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Thu, 19 Feb 2026 14:03:46 +0000 Subject: [PATCH] feat(pmtud/tcp): use the TCP server with highest MSS to run MTU tests --- internal/pmtud/ip/ipheader.go | 7 ++ internal/pmtud/pmtud.go | 46 +++---- internal/pmtud/tcp/helpers_test.go | 32 +++++ internal/pmtud/tcp/mss.go | 138 +++++++++++++++++++++ internal/pmtud/tcp/mss_test.go | 59 +++++++++ internal/pmtud/tcp/multi.go | 69 ++++++++--- internal/pmtud/tcp/tcp.go | 4 +- internal/pmtud/tcp/tcp_integration_test.go | 11 +- internal/pmtud/tcp/tcp_test.go | 12 +- internal/pmtud/tcp/tcpheader.go | 4 +- 10 files changed, 323 insertions(+), 59 deletions(-) create mode 100644 internal/pmtud/tcp/mss.go create mode 100644 internal/pmtud/tcp/mss_test.go diff --git a/internal/pmtud/ip/ipheader.go b/internal/pmtud/ip/ipheader.go index eb662607..b0139b9b 100644 --- a/internal/pmtud/ip/ipheader.go +++ b/internal/pmtud/ip/ipheader.go @@ -7,6 +7,13 @@ import ( "github.com/qdm12/gluetun/internal/pmtud/constants" ) +func HeaderLength(ipv4 bool) uint32 { + if ipv4 { + return constants.IPv4HeaderLength + } + return constants.IPv6HeaderLength +} + func HeaderV4(srcIP, dstIP netip.Addr, payloadLength uint32) []byte { ipHeader := make([]byte, constants.IPv4HeaderLength) const version byte = 4 diff --git a/internal/pmtud/pmtud.go b/internal/pmtud/pmtud.go index 937aad56..716f416f 100644 --- a/internal/pmtud/pmtud.go +++ b/internal/pmtud/pmtud.go @@ -61,32 +61,24 @@ func PathMTUDiscover(ctx context.Context, icmpAddrs []netip.Addr, tcpAddrs []net } } - for _, addrPort := range tcpAddrs { - minMTU := constants.MinIPv4MTU - if addrPort.Addr().Is6() { - minMTU = constants.MinIPv6MTU - } - if icmpSuccess { - const mtuMargin = 150 - minMTU = max(maxPossibleMTU-mtuMargin, minMTU) - } - mtu, err = tcp.PathMTUDiscover(ctx, addrPort, minMTU, maxPossibleMTU, tryTimeout, fw, logger) - if err != nil { - if errors.Is(err, firewall.ErrMarkMatchModuleMissing) { - logger.Debugf("aborting TCP path MTU discovery: %s", err) - if icmpSuccess { - return maxPossibleMTU, nil // only rely on ICMP PMTUD results - } - return 0, fmt.Errorf("%w", ErrPMTUDFailICMPAndTCP) - } - logger.Debugf("TCP path MTU discovery to %s failed: %s", addrPort, err) - continue - } - logger.Debugf("TCP path MTU discovery to %s found maximum valid MTU %d", addrPort, mtu) - return mtu, nil + minMTU := constants.MinIPv4MTU + if tcpAddrs[0].Addr().Is6() { + minMTU = constants.MinIPv6MTU } - - // TCP PMTUD failed for all addresses for external reasons, - // so do not take the risk and return an error. - return 0, fmt.Errorf("TCP path MTU discovery: last error: %w", err) + if icmpSuccess { + const mtuMargin = 150 + minMTU = max(maxPossibleMTU-mtuMargin, minMTU) + } + mtu, err = tcp.PathMTUDiscover(ctx, tcpAddrs, minMTU, maxPossibleMTU, tryTimeout, fw, logger) + if err != nil { + if errors.Is(err, firewall.ErrMarkMatchModuleMissing) { + logger.Debugf("aborting TCP path MTU discovery: %s", err) + if icmpSuccess { + return maxPossibleMTU, nil // only rely on ICMP PMTUD results + } + } + return 0, fmt.Errorf("%w", ErrPMTUDFailICMPAndTCP) + } + logger.Debugf("TCP path MTU discovery found maximum valid MTU %d", mtu) + return mtu, nil } diff --git a/internal/pmtud/tcp/helpers_test.go b/internal/pmtud/tcp/helpers_test.go index 53d4fbb3..80b21614 100644 --- a/internal/pmtud/tcp/helpers_test.go +++ b/internal/pmtud/tcp/helpers_test.go @@ -3,8 +3,11 @@ package tcp import ( "errors" "fmt" + "sync" "testing" + "github.com/qdm12/gluetun/internal/command" + "github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/netlink" "github.com/qdm12/gluetun/internal/pmtud/constants" "github.com/qdm12/gluetun/internal/routing" @@ -14,6 +17,35 @@ import ( "golang.org/x/sys/unix" ) +// testFirewall must be global to prevent parallel tests from interfering +// with each other since they would interact with the same filter table. +// The first test to use should initialize it, and the rest will reuse it. +var ( + testFirewall *firewall.Config //nolint:gochecknoglobals + testFirewallOnce sync.Once //nolint:gochecknoglobals +) + +// getFirewall returns a Firewall instance, initializing it if needed. If +// iptables is not supported, it skips the test. +func getFirewall(t *testing.T) *firewall.Config { + t.Helper() + + testFirewallOnce.Do(func() { + noopLogger := &noopLogger{} + cmder := command.New() + var err error + testFirewall, err = firewall.NewConfig(t.Context(), noopLogger, cmder, nil, nil) + if errors.Is(err, firewall.ErrIPTablesNotSupported) { + t.Skip("iptables not installed, skipping TCP PMTUD tests") + } + require.NoError(t, err, "creating firewall config") + }) + if testFirewall == nil { + t.Skip("iptables not installed, skipping TCP PMTUD tests") + } + return testFirewall +} + type noopLogger struct{} func (l *noopLogger) Patch(_ ...log.Option) {} diff --git a/internal/pmtud/tcp/mss.go b/internal/pmtud/tcp/mss.go new file mode 100644 index 00000000..518403b4 --- /dev/null +++ b/internal/pmtud/tcp/mss.go @@ -0,0 +1,138 @@ +package tcp + +import ( + "context" + "errors" + "fmt" + "net/netip" + "time" + + "github.com/qdm12/gluetun/internal/firewall" + "github.com/qdm12/gluetun/internal/pmtud/constants" + "github.com/qdm12/gluetun/internal/pmtud/ip" +) + +// findHighestMSSDestination finds the destination with the highest +// MSS amongst the provided destinations. +func findHighestMSSDestination(ctx context.Context, fd fileDescriptor, + dsts []netip.AddrPort, excludeMark int, maxPossibleMTU uint32, + timeout time.Duration, tracker *tracker, fw Firewall, logger Logger) ( + dst netip.AddrPort, mss uint32, err error, +) { + type result struct { + dst netip.AddrPort + mss uint32 + err error + } + resultCh := make(chan result) + + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + for _, dst := range dsts { + go func(dst netip.AddrPort) { + mss, err := findMSS(ctx, fd, dst, excludeMark, tracker, fw, logger) + resultCh <- result{dst: dst, mss: mss, err: err} + }(dst) + } + + for range dsts { + result := <-resultCh + if result.err != nil { + switch { + case err != nil: // error already occurred for another findMSS goroutine + case errors.Is(result.err, firewall.ErrMarkMatchModuleMissing): + err = fmt.Errorf("finding MSS for %s: %w", result.dst, result.err) + default: // another error not due to the match module missing + logger.Debugf("finding MSS for %s failed: %s", result.dst, result.err) + } + continue + } + ipHeaderLength := ip.HeaderLength(result.dst.Addr().Is4()) + maxNeededMSS := maxPossibleMTU - ipHeaderLength - constants.BaseTCPHeaderLength + switch { + case result.mss >= maxNeededMSS: + logger.Debugf("%s has an MSS of %d bytes which is equal or higher than "+ + "the maximum needed MSS of %d bytes for the maximum possible MTU of %d bytes", + result.dst, result.mss, maxNeededMSS, maxPossibleMTU) + return result.dst, result.mss, nil + case result.mss > mss: + mss = result.mss + dst = result.dst + } + } + + maxPossibleMTU = ip.HeaderLength(dst.Addr().Is4()) + constants.BaseTCPHeaderLength + mss + logger.Debugf("server %s has the highest MSS %d allowing to test the MTU up to %d", + dst, mss, maxPossibleMTU) + return dst, mss, nil +} + +var errMSSNotFound = errors.New("MSS option not found in reply") + +func findMSS(ctx context.Context, fd fileDescriptor, dst netip.AddrPort, + excludeMark int, tracker *tracker, firewall Firewall, logger Logger) ( + mss uint32, err error, +) { + const proto = constants.IPPROTO_TCP + src, cleanup, err := ip.SrcAddr(dst, proto) + if err != nil { + return 0, fmt.Errorf("getting source address: %w", err) + } + defer cleanup() + + revert, err := firewall.TempDropOutputTCPRST(ctx, src, dst, excludeMark) + if err != nil { + return 0, fmt.Errorf("temporarily dropping outgoing TCP RST packets: %w", err) + } + defer func() { + // we don't want to skip reverting the firewall changes + // even if the context is already expired, so we use a + // background context here. + err := revert(context.Background()) + if err != nil { + logger.Warnf("reverting firewall changes: %s", err) + } + }() + + ch := make(chan []byte) + abort := make(chan struct{}) + defer close(abort) + tracker.register(src.Port(), dst.Port(), ch, abort) + defer tracker.unregister(src.Port(), dst.Port()) + + dstSockAddr := makeSockAddr(dst) + + synPacket, synSeq := createSYNPacket(src, dst, 0) + const sendToFlags = 0 + err = sendTo(fd, synPacket, sendToFlags, dstSockAddr) + if err != nil { + return 0, fmt.Errorf("sending SYN packet: %w", err) + } + + var reply []byte + select { + case <-ctx.Done(): + _ = sendRST(fd, src, dst, synSeq+1) + return 0, ctx.Err() + case reply = <-ch: + } + + replyHeader, err := parseTCPHeader(reply) + switch { + case err != nil: + return 0, fmt.Errorf("parsing reply TCP header: %w", err) + case replyHeader.typ != packetTypeSYNACK: + return 0, fmt.Errorf("%w: unexpected packet type %s", errTCPPacketNotSynAck, replyHeader.typ) + case replyHeader.ack != synSeq+1: + return 0, fmt.Errorf("%w: expected %d, got %d", errTCPSynAckAckMismatch, synSeq+1, replyHeader.ack) + case replyHeader.options.mss == 0: + return 0, fmt.Errorf("%w: MSS option not found in reply", errMSSNotFound) + } + + err = sendRST(fd, src, dst, replyHeader.ack) + if err != nil { + return 0, fmt.Errorf("sending RST packet: %w", err) + } + + return replyHeader.options.mss, nil +} diff --git a/internal/pmtud/tcp/mss_test.go b/internal/pmtud/tcp/mss_test.go new file mode 100644 index 00000000..b92c8926 --- /dev/null +++ b/internal/pmtud/tcp/mss_test.go @@ -0,0 +1,59 @@ +//go:build linux + +package tcp + +import ( + "context" + "net/netip" + "testing" + "time" + + "github.com/qdm12/gluetun/internal/netlink" + "github.com/qdm12/gluetun/internal/pmtud/constants" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_findHighestMSSDestination(t *testing.T) { + t.Parallel() + + netlinker := netlink.New(&noopLogger{}) + defaultIPv4MTU, err := findDefaultIPv4RouteMTU(netlinker) + require.NoError(t, err, "finding default IPv4 route MTU") + + ctx, cancel := context.WithCancel(t.Context()) + + const family = constants.AF_INET + fd, stop, err := startRawSocket(family, excludeMark) + require.NoError(t, err) + + const ipv4 = true + tracker := newTracker(fd, ipv4) + trackerCh := make(chan error) + go func() { + trackerCh <- tracker.listen(ctx) + }() + + t.Cleanup(func() { + stop() + cancel() // stop listening + err = <-trackerCh + require.NoError(t, err) + }) + + dsts := []netip.AddrPort{ + netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443), + netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 443), + } + const timeout = time.Second + fw := getFirewall(t) + logger := &noopLogger{} + + dst, mss, err := findHighestMSSDestination(t.Context(), fd, dsts, + excludeMark, defaultIPv4MTU, timeout, tracker, fw, logger) + require.NoError(t, err, "finding highest MSS destination") + assert.Contains(t, dsts, dst, "destination should be in the provided list") + assert.Greater(t, mss, uint32(1000), "MSS should be greater than 1000") + assert.LessOrEqual(t, mss, constants.MaxEthernetFrameSize, + "MSS should be less than or equal to the maximum Ethernet frame size ") +} diff --git a/internal/pmtud/tcp/multi.go b/internal/pmtud/tcp/multi.go index a36ec259..275dfe3f 100644 --- a/internal/pmtud/tcp/multi.go +++ b/internal/pmtud/tcp/multi.go @@ -8,6 +8,7 @@ import ( "time" "github.com/qdm12/gluetun/internal/pmtud/constants" + "github.com/qdm12/gluetun/internal/pmtud/ip" "github.com/qdm12/gluetun/internal/pmtud/test" ) @@ -18,22 +19,31 @@ type testUnit struct { ok bool } -func PathMTUDiscover(ctx context.Context, dst netip.AddrPort, +const excludeMark = 4545 + +// PathMTUDiscover first finds the destination TCP server with the highest +// available MSS, in order to be able to test the highest possible MTU. +// If a server has an MSS larger than maxPossibleMTU, this one is used. +// It then performs a binary search of the MTU between minMTU and maxPossibleMTU, +// by sending IP packets with the Don't Fragment bit set and checking if they +// are received or not, exploiting the stateful nature of TCP to be able to +// correlate replies to the sent packets. +// Note all dsts must be of the same IP family (all IPv4 or all IPv6). +func PathMTUDiscover(ctx context.Context, dsts []netip.AddrPort, minMTU, maxPossibleMTU uint32, tryTimeout time.Duration, firewall Firewall, logger Logger, ) (mtu uint32, err error) { family := constants.AF_INET - if dst.Addr().Is6() { + if dsts[0].Addr().Is6() { family = constants.AF_INET6 } - const excludeMark = 4325 fd, stop, err := startRawSocket(family, excludeMark) if err != nil { return 0, fmt.Errorf("starting raw socket: %w", err) } defer stop() - tracker := newTracker(fd, dst.Addr().Is4()) + tracker := newTracker(fd, family == constants.AF_INET) trackerCtx, trackerCancel := context.WithCancel(ctx) defer trackerCancel() @@ -42,28 +52,59 @@ func PathMTUDiscover(ctx context.Context, dst netip.AddrPort, trackerErrCh <- tracker.listen(trackerCtx) }() - pmtudCtx, pmtudCancel := context.WithCancel(ctx) - defer pmtudCancel() - type result struct { + type mssResult struct { + dst netip.AddrPort + mss uint32 + err error + } + mssResultCh := make(chan mssResult) + + ctx, cancel := context.WithTimeout(ctx, tryTimeout) + defer cancel() + go func() { + dst, mss, err := findHighestMSSDestination(ctx, fd, dsts, excludeMark, + maxPossibleMTU, tryTimeout, tracker, firewall, logger) + mssResultCh <- mssResult{dst: dst, mss: mss, err: err} + }() + var highestMSSDst netip.AddrPort + select { + case err = <-trackerErrCh: + cancel() + <-mssResultCh + return 0, fmt.Errorf("listening for TCP replies: %w", err) + case result := <-mssResultCh: + if result.err != nil { + trackerCancel() + <-trackerErrCh + return 0, fmt.Errorf("finding MSS: %w", result.err) + } + highestMSSDst = result.dst + ipHeaderLength := ip.HeaderLength(highestMSSDst.Addr().Is4()) + maxPossibleMTU = ipHeaderLength + constants.BaseTCPHeaderLength + result.mss + } + + type pmtudResult struct { mtu uint32 err error } - pmtudResultCh := make(chan result) + resultCh := make(chan pmtudResult) + ctx, cancel = context.WithCancel(ctx) + defer cancel() go func() { - mtu, err := pathMTUDiscover(pmtudCtx, fd, dst, minMTU, maxPossibleMTU, + mtu, err := pathMTUDiscover(ctx, fd, highestMSSDst, minMTU, maxPossibleMTU, excludeMark, tryTimeout, tracker, firewall, logger) - pmtudResultCh <- result{mtu: mtu, err: err} + resultCh <- pmtudResult{mtu: mtu, err: err} }() select { case err = <-trackerErrCh: - pmtudCancel() - <-pmtudResultCh + cancel() + <-resultCh return 0, fmt.Errorf("listening for TCP replies: %w", err) - case res := <-pmtudResultCh: + case result := <-resultCh: trackerCancel() <-trackerErrCh - return res.mtu, res.err + return result.mtu, result.err } } diff --git a/internal/pmtud/tcp/tcp.go b/internal/pmtud/tcp/tcp.go index 6b9cc171..0bc00cf8 100644 --- a/internal/pmtud/tcp/tcp.go +++ b/internal/pmtud/tcp/tcp.go @@ -129,8 +129,8 @@ func runTest(ctx context.Context, dst netip.AddrPort, mtu uint32, if firstReplyHeader.options.mss != 0 { // If the server sent an MSS option, make sure our test packet is not larger than that MSS. tcpDataLength := getPayloadLength(mtu, dst) - constants.BaseTCPHeaderLength - if tcpDataLength > uint32(firstReplyHeader.options.mss) { - diff := tcpDataLength - uint32(firstReplyHeader.options.mss) + if tcpDataLength > firstReplyHeader.options.mss { + diff := tcpDataLength - firstReplyHeader.options.mss minMTU := constants.MinIPv4MTU if dst.Addr().Is6() { minMTU = constants.MinIPv6MTU diff --git a/internal/pmtud/tcp/tcp_integration_test.go b/internal/pmtud/tcp/tcp_integration_test.go index 3c05a723..d3374ec0 100644 --- a/internal/pmtud/tcp/tcp_integration_test.go +++ b/internal/pmtud/tcp/tcp_integration_test.go @@ -27,12 +27,17 @@ func Test_PathMTUDiscover(t *testing.T) { } require.NoError(t, err, "creating firewall config") - dst := netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 80) + dsts := []netip.AddrPort{ + netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 53), + netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443), + netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 53), + netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 443), + } const minMTU = constants.MinIPv6MTU const maxMTU = constants.MaxEthernetFrameSize const tryTimeout = time.Second - mtu, err := PathMTUDiscover(t.Context(), dst, minMTU, maxMTU, tryTimeout, fw, noopLogger) + mtu, err := PathMTUDiscover(t.Context(), dsts, minMTU, maxMTU, tryTimeout, fw, noopLogger) require.NoError(t, err, "discovering path MTU") assert.Greater(t, mtu, uint32(0), "MTU should be greater than 0") - t.Logf("discovered path MTU to %s is %d", dst, mtu) + t.Logf("discovered path MTU is %d", mtu) } diff --git a/internal/pmtud/tcp/tcp_test.go b/internal/pmtud/tcp/tcp_test.go index 5947436c..d644bd39 100644 --- a/internal/pmtud/tcp/tcp_test.go +++ b/internal/pmtud/tcp/tcp_test.go @@ -4,14 +4,11 @@ package tcp import ( "context" - "errors" "net/netip" "testing" "time" gomock "github.com/golang/mock/gomock" - "github.com/qdm12/gluetun/internal/command" - "github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/netlink" "github.com/qdm12/gluetun/internal/pmtud/constants" "github.com/qdm12/gluetun/internal/pmtud/ip" @@ -26,13 +23,6 @@ func Test_runTest(t *testing.T) { noopLogger := &noopLogger{} - cmder := command.New() - fw, err := firewall.NewConfig(t.Context(), noopLogger, cmder, nil, nil) - if errors.Is(err, firewall.ErrIPTablesNotSupported) { - t.Skip("iptables not installed, skipping TCP PMTUD tests") - } - require.NoError(t, err, "creating firewall config") - netlinker := netlink.New(noopLogger) loopbackMTU, err := findLoopbackMTU(netlinker) require.NoError(t, err, "finding loopback IPv4 MTU") @@ -42,7 +32,6 @@ func Test_runTest(t *testing.T) { ctx, cancel := context.WithCancel(t.Context()) const family = constants.AF_INET - const excludeMark = 4545 fd, stop, err := startRawSocket(family, excludeMark) require.NoError(t, err) @@ -116,6 +105,7 @@ func Test_runTest(t *testing.T) { require.NoError(t, err, "getting source address to reach remote server %s", dst) t.Cleanup(cleanup) + fw := getFirewall(t) revert, err := fw.TempDropOutputTCPRST(t.Context(), src, dst, excludeMark) require.NoError(t, err) t.Cleanup(func() { diff --git a/internal/pmtud/tcp/tcpheader.go b/internal/pmtud/tcp/tcpheader.go index c3b3c111..489052f8 100644 --- a/internal/pmtud/tcp/tcpheader.go +++ b/internal/pmtud/tcp/tcpheader.go @@ -199,7 +199,7 @@ func parseTCPHeader(b []byte) (header tcpHeader, err error) { } type options struct { - mss uint16 + mss uint32 windowScale *uint8 // Pointer to differentiate between 0 and "not present" sackPermitted bool timestamps *optionTimestamps @@ -266,7 +266,7 @@ func parseTCPOptions(b []byte) (parsed options, err error) { return options{}, fmt.Errorf("%w: MSS option at offset %d has length %d, expected %d", ErrTCPOptionMSSInvalid, i, length, expectedLength) } - parsed.mss = binary.BigEndian.Uint16(data) + parsed.mss = uint32(binary.BigEndian.Uint16(data)) case optionTypeWindowScale: const expectedLength = 3 if length != expectedLength {