feat(pmtud/tcp): use the TCP server with highest MSS to run MTU tests

This commit is contained in:
Quentin McGaw
2026-02-19 14:03:46 +00:00
parent fb85ae79d1
commit 8d86470905
10 changed files with 323 additions and 59 deletions
+7
View File
@@ -7,6 +7,13 @@ import (
"github.com/qdm12/gluetun/internal/pmtud/constants" "github.com/qdm12/gluetun/internal/pmtud/constants"
) )
// HeaderLength returns the IP header length constant matching the IP
// family: the IPv4 one when ipv4 is true, the IPv6 one otherwise.
func HeaderLength(ipv4 bool) uint32 {
	if !ipv4 {
		return constants.IPv6HeaderLength
	}
	return constants.IPv4HeaderLength
}
func HeaderV4(srcIP, dstIP netip.Addr, payloadLength uint32) []byte { func HeaderV4(srcIP, dstIP netip.Addr, payloadLength uint32) []byte {
ipHeader := make([]byte, constants.IPv4HeaderLength) ipHeader := make([]byte, constants.IPv4HeaderLength)
const version byte = 4 const version byte = 4
+19 -27
View File
@@ -61,32 +61,24 @@ func PathMTUDiscover(ctx context.Context, icmpAddrs []netip.Addr, tcpAddrs []net
} }
} }
for _, addrPort := range tcpAddrs { minMTU := constants.MinIPv4MTU
minMTU := constants.MinIPv4MTU if tcpAddrs[0].Addr().Is6() {
if addrPort.Addr().Is6() { minMTU = constants.MinIPv6MTU
minMTU = constants.MinIPv6MTU
}
if icmpSuccess {
const mtuMargin = 150
minMTU = max(maxPossibleMTU-mtuMargin, minMTU)
}
mtu, err = tcp.PathMTUDiscover(ctx, addrPort, minMTU, maxPossibleMTU, tryTimeout, fw, logger)
if err != nil {
if errors.Is(err, firewall.ErrMarkMatchModuleMissing) {
logger.Debugf("aborting TCP path MTU discovery: %s", err)
if icmpSuccess {
return maxPossibleMTU, nil // only rely on ICMP PMTUD results
}
return 0, fmt.Errorf("%w", ErrPMTUDFailICMPAndTCP)
}
logger.Debugf("TCP path MTU discovery to %s failed: %s", addrPort, err)
continue
}
logger.Debugf("TCP path MTU discovery to %s found maximum valid MTU %d", addrPort, mtu)
return mtu, nil
} }
if icmpSuccess {
// TCP PMTUD failed for all addresses for external reasons, const mtuMargin = 150
// so do not take the risk and return an error. minMTU = max(maxPossibleMTU-mtuMargin, minMTU)
return 0, fmt.Errorf("TCP path MTU discovery: last error: %w", err) }
mtu, err = tcp.PathMTUDiscover(ctx, tcpAddrs, minMTU, maxPossibleMTU, tryTimeout, fw, logger)
if err != nil {
if errors.Is(err, firewall.ErrMarkMatchModuleMissing) {
logger.Debugf("aborting TCP path MTU discovery: %s", err)
if icmpSuccess {
return maxPossibleMTU, nil // only rely on ICMP PMTUD results
}
}
return 0, fmt.Errorf("%w", ErrPMTUDFailICMPAndTCP)
}
logger.Debugf("TCP path MTU discovery found maximum valid MTU %d", mtu)
return mtu, nil
} }
+32
View File
@@ -3,8 +3,11 @@ package tcp
import ( import (
"errors" "errors"
"fmt" "fmt"
"sync"
"testing" "testing"
"github.com/qdm12/gluetun/internal/command"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/netlink" "github.com/qdm12/gluetun/internal/netlink"
"github.com/qdm12/gluetun/internal/pmtud/constants" "github.com/qdm12/gluetun/internal/pmtud/constants"
"github.com/qdm12/gluetun/internal/routing" "github.com/qdm12/gluetun/internal/routing"
@@ -14,6 +17,35 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
// testFirewall must be global to prevent parallel tests from interfering
// with each other since they would interact with the same filter table.
// The first test to use it initializes it, and the rest reuse it.
// testFirewallErr keeps the outcome of that single initialization so that
// every test observes it, and not only the first caller.
var (
	testFirewall     *firewall.Config //nolint:gochecknoglobals
	testFirewallErr  error            //nolint:gochecknoglobals
	testFirewallOnce sync.Once        //nolint:gochecknoglobals
)

// getFirewall returns the shared firewall instance, creating it on first use.
// If iptables is not supported, the calling test is skipped; any other
// creation error fails the calling test.
func getFirewall(t *testing.T) *firewall.Config {
	t.Helper()

	// Only record the result inside the Once: t.Skip and require.NoError
	// stop the calling goroutine, which must not happen from within Do,
	// and the recorded error must be reported to every test, not just
	// the one which happened to run the initialization.
	testFirewallOnce.Do(func() {
		testFirewall, testFirewallErr = firewall.NewConfig(
			t.Context(), &noopLogger{}, command.New(), nil, nil)
	})

	if errors.Is(testFirewallErr, firewall.ErrIPTablesNotSupported) {
		t.Skip("iptables not installed, skipping TCP PMTUD tests")
	}
	require.NoError(t, testFirewallErr, "creating firewall config")
	return testFirewall
}
type noopLogger struct{} type noopLogger struct{}
func (l *noopLogger) Patch(_ ...log.Option) {} func (l *noopLogger) Patch(_ ...log.Option) {}
+138
View File
@@ -0,0 +1,138 @@
package tcp
import (
"context"
"errors"
"fmt"
"net/netip"
"time"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/pmtud/constants"
"github.com/qdm12/gluetun/internal/pmtud/ip"
)
// errNoMSSDestination is returned when no destination replied with a valid MSS.
var errNoMSSDestination = errors.New("no destination with a valid MSS found")

// findHighestMSSDestination finds the destination with the highest
// MSS amongst the provided destinations.
//
// One findMSS probe is run concurrently per destination, all of them sharing
// a context bounded by timeout. As soon as a destination advertises an MSS
// large enough to test maxPossibleMTU, it is returned without waiting for the
// other probes. Otherwise the destination with the highest MSS is returned.
//
// An error is returned if no destination produced a valid MSS. In that case,
// the error wraps firewall.ErrMarkMatchModuleMissing if a probe failed
// because of it, and errNoMSSDestination otherwise.
func findHighestMSSDestination(ctx context.Context, fd fileDescriptor,
	dsts []netip.AddrPort, excludeMark int, maxPossibleMTU uint32,
	timeout time.Duration, tracker *tracker, fw Firewall, logger Logger) (
	dst netip.AddrPort, mss uint32, err error,
) {
	type result struct {
		dst netip.AddrPort
		mss uint32
		err error
	}
	// Buffered with one slot per probe so that every goroutine can deliver
	// its result and exit, even if this function returned early and no one
	// is receiving anymore.
	resultCh := make(chan result, len(dsts))

	ctx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	for _, dst := range dsts {
		go func(dst netip.AddrPort) {
			mss, err := findMSS(ctx, fd, dst, excludeMark, tracker, fw, logger)
			resultCh <- result{dst: dst, mss: mss, err: err}
		}(dst)
	}

	for range dsts {
		result := <-resultCh
		if result.err != nil {
			if err == nil && errors.Is(result.err, firewall.ErrMarkMatchModuleMissing) {
				// Keep the first such error since the caller relies on it
				// to abort TCP path MTU discovery altogether.
				err = fmt.Errorf("finding MSS for %s: %w", result.dst, result.err)
				continue
			}
			logger.Debugf("finding MSS for %s failed: %s", result.dst, result.err)
			continue
		}

		ipHeaderLength := ip.HeaderLength(result.dst.Addr().Is4())
		maxNeededMSS := maxPossibleMTU - ipHeaderLength - constants.BaseTCPHeaderLength
		switch {
		case result.mss >= maxNeededMSS:
			logger.Debugf("%s has an MSS of %d bytes which is equal or higher than "+
				"the maximum needed MSS of %d bytes for the maximum possible MTU of %d bytes",
				result.dst, result.mss, maxNeededMSS, maxPossibleMTU)
			return result.dst, result.mss, nil
		case result.mss > mss:
			mss = result.mss
			dst = result.dst
		}
	}

	// findMSS never succeeds with an MSS of zero, so a zero MSS here
	// means that not a single probe succeeded.
	if mss == 0 {
		if err == nil {
			err = fmt.Errorf("%w: tried %d destinations", errNoMSSDestination, len(dsts))
		}
		return netip.AddrPort{}, 0, err
	}

	maxTestableMTU := ip.HeaderLength(dst.Addr().Is4()) + constants.BaseTCPHeaderLength + mss
	logger.Debugf("server %s has the highest MSS %d allowing to test the MTU up to %d",
		dst, mss, maxTestableMTU)
	return dst, mss, nil
}
var errMSSNotFound = errors.New("MSS option not found in reply")

// findMSS sends a TCP SYN packet to dst on the raw socket fd and returns
// the MSS option value found in the SYN-ACK reply.
//
// Outgoing TCP RST packets between the source address and dst are
// temporarily dropped through the firewall for the duration of the probe,
// and the firewall change is reverted before returning. The reply is
// received through the tracker, on a channel registered for the source and
// destination ports. Once a valid SYN-ACK is received, an RST packet is sent
// to dst to terminate the handshake.
//
// It returns an error if the context is done before a reply arrives, if the
// reply is not a SYN-ACK acknowledging the SYN sent, or if the reply carries
// no MSS option, in which case the error wraps errMSSNotFound.
func findMSS(ctx context.Context, fd fileDescriptor, dst netip.AddrPort,
	excludeMark int, tracker *tracker, firewall Firewall, logger Logger) (
	mss uint32, err error,
) {
	const proto = constants.IPPROTO_TCP
	src, cleanup, err := ip.SrcAddr(dst, proto)
	if err != nil {
		return 0, fmt.Errorf("getting source address: %w", err)
	}
	defer cleanup()

	revert, err := firewall.TempDropOutputTCPRST(ctx, src, dst, excludeMark)
	if err != nil {
		return 0, fmt.Errorf("temporarily dropping outgoing TCP RST packets: %w", err)
	}
	defer func() {
		// we don't want to skip reverting the firewall changes
		// even if the context is already expired, so we use a
		// background context here.
		err := revert(context.Background())
		if err != nil {
			logger.Warnf("reverting firewall changes: %s", err)
		}
	}()

	ch := make(chan []byte)
	abort := make(chan struct{})
	defer close(abort)
	tracker.register(src.Port(), dst.Port(), ch, abort)
	defer tracker.unregister(src.Port(), dst.Port())

	dstSockAddr := makeSockAddr(dst)
	synPacket, synSeq := createSYNPacket(src, dst, 0)
	const sendToFlags = 0
	err = sendTo(fd, synPacket, sendToFlags, dstSockAddr)
	if err != nil {
		return 0, fmt.Errorf("sending SYN packet: %w", err)
	}

	var reply []byte
	select {
	case <-ctx.Done():
		// Best effort reset: the context error is what matters to the caller.
		_ = sendRST(fd, src, dst, synSeq+1)
		return 0, ctx.Err()
	case reply = <-ch:
	}

	replyHeader, err := parseTCPHeader(reply)
	switch {
	case err != nil:
		return 0, fmt.Errorf("parsing reply TCP header: %w", err)
	case replyHeader.typ != packetTypeSYNACK:
		return 0, fmt.Errorf("%w: unexpected packet type %s", errTCPPacketNotSynAck, replyHeader.typ)
	case replyHeader.ack != synSeq+1:
		return 0, fmt.Errorf("%w: expected %d, got %d", errTCPSynAckAckMismatch, synSeq+1, replyHeader.ack)
	case replyHeader.options.mss == 0:
		// The SYN-ACK itself is valid, so still reset the connection as
		// done on the other paths. This is best effort: the missing MSS
		// is the error worth reporting.
		_ = sendRST(fd, src, dst, replyHeader.ack)
		// The sentinel already carries the full message.
		return 0, fmt.Errorf("%w", errMSSNotFound)
	}

	err = sendRST(fd, src, dst, replyHeader.ack)
	if err != nil {
		return 0, fmt.Errorf("sending RST packet: %w", err)
	}

	return replyHeader.options.mss, nil
}
+59
View File
@@ -0,0 +1,59 @@
//go:build linux
package tcp
import (
"context"
"net/netip"
"testing"
"time"
"github.com/qdm12/gluetun/internal/netlink"
"github.com/qdm12/gluetun/internal/pmtud/constants"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_findHighestMSSDestination(t *testing.T) {
t.Parallel()
netlinker := netlink.New(&noopLogger{})
defaultIPv4MTU, err := findDefaultIPv4RouteMTU(netlinker)
require.NoError(t, err, "finding default IPv4 route MTU")
ctx, cancel := context.WithCancel(t.Context())
const family = constants.AF_INET
fd, stop, err := startRawSocket(family, excludeMark)
require.NoError(t, err)
const ipv4 = true
tracker := newTracker(fd, ipv4)
trackerCh := make(chan error)
go func() {
trackerCh <- tracker.listen(ctx)
}()
t.Cleanup(func() {
stop()
cancel() // stop listening
err = <-trackerCh
require.NoError(t, err)
})
dsts := []netip.AddrPort{
netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443),
netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 443),
}
const timeout = time.Second
fw := getFirewall(t)
logger := &noopLogger{}
dst, mss, err := findHighestMSSDestination(t.Context(), fd, dsts,
excludeMark, defaultIPv4MTU, timeout, tracker, fw, logger)
require.NoError(t, err, "finding highest MSS destination")
assert.Contains(t, dsts, dst, "destination should be in the provided list")
assert.Greater(t, mss, uint32(1000), "MSS should be greater than 1000")
assert.LessOrEqual(t, mss, constants.MaxEthernetFrameSize,
"MSS should be less than or equal to the maximum Ethernet frame size ")
}
+55 -14
View File
@@ -8,6 +8,7 @@ import (
"time" "time"
"github.com/qdm12/gluetun/internal/pmtud/constants" "github.com/qdm12/gluetun/internal/pmtud/constants"
"github.com/qdm12/gluetun/internal/pmtud/ip"
"github.com/qdm12/gluetun/internal/pmtud/test" "github.com/qdm12/gluetun/internal/pmtud/test"
) )
@@ -18,22 +19,31 @@ type testUnit struct {
ok bool ok bool
} }
func PathMTUDiscover(ctx context.Context, dst netip.AddrPort, const excludeMark = 4545
// PathMTUDiscover first finds the destination TCP server with the highest
// available MSS, in order to be able to test the highest possible MTU.
// If a server advertises an MSS large enough to test maxPossibleMTU, it is used right away.
// It then performs a binary search of the MTU between minMTU and maxPossibleMTU,
// by sending IP packets with the Don't Fragment bit set and checking if they
// are received or not, exploiting the stateful nature of TCP to be able to
// correlate replies to the sent packets.
// Note all dsts must be of the same IP family (all IPv4 or all IPv6).
func PathMTUDiscover(ctx context.Context, dsts []netip.AddrPort,
minMTU, maxPossibleMTU uint32, tryTimeout time.Duration, minMTU, maxPossibleMTU uint32, tryTimeout time.Duration,
firewall Firewall, logger Logger, firewall Firewall, logger Logger,
) (mtu uint32, err error) { ) (mtu uint32, err error) {
family := constants.AF_INET family := constants.AF_INET
if dst.Addr().Is6() { if dsts[0].Addr().Is6() {
family = constants.AF_INET6 family = constants.AF_INET6
} }
const excludeMark = 4325
fd, stop, err := startRawSocket(family, excludeMark) fd, stop, err := startRawSocket(family, excludeMark)
if err != nil { if err != nil {
return 0, fmt.Errorf("starting raw socket: %w", err) return 0, fmt.Errorf("starting raw socket: %w", err)
} }
defer stop() defer stop()
tracker := newTracker(fd, dst.Addr().Is4()) tracker := newTracker(fd, family == constants.AF_INET)
trackerCtx, trackerCancel := context.WithCancel(ctx) trackerCtx, trackerCancel := context.WithCancel(ctx)
defer trackerCancel() defer trackerCancel()
@@ -42,28 +52,59 @@ func PathMTUDiscover(ctx context.Context, dst netip.AddrPort,
trackerErrCh <- tracker.listen(trackerCtx) trackerErrCh <- tracker.listen(trackerCtx)
}() }()
pmtudCtx, pmtudCancel := context.WithCancel(ctx) type mssResult struct {
defer pmtudCancel() dst netip.AddrPort
type result struct { mss uint32
err error
}
mssResultCh := make(chan mssResult)
ctx, cancel := context.WithTimeout(ctx, tryTimeout)
defer cancel()
go func() {
dst, mss, err := findHighestMSSDestination(ctx, fd, dsts, excludeMark,
maxPossibleMTU, tryTimeout, tracker, firewall, logger)
mssResultCh <- mssResult{dst: dst, mss: mss, err: err}
}()
var highestMSSDst netip.AddrPort
select {
case err = <-trackerErrCh:
cancel()
<-mssResultCh
return 0, fmt.Errorf("listening for TCP replies: %w", err)
case result := <-mssResultCh:
if result.err != nil {
trackerCancel()
<-trackerErrCh
return 0, fmt.Errorf("finding MSS: %w", result.err)
}
highestMSSDst = result.dst
ipHeaderLength := ip.HeaderLength(highestMSSDst.Addr().Is4())
maxPossibleMTU = ipHeaderLength + constants.BaseTCPHeaderLength + result.mss
}
type pmtudResult struct {
mtu uint32 mtu uint32
err error err error
} }
pmtudResultCh := make(chan result) resultCh := make(chan pmtudResult)
ctx, cancel = context.WithCancel(ctx)
defer cancel()
go func() { go func() {
mtu, err := pathMTUDiscover(pmtudCtx, fd, dst, minMTU, maxPossibleMTU, mtu, err := pathMTUDiscover(ctx, fd, highestMSSDst, minMTU, maxPossibleMTU,
excludeMark, tryTimeout, tracker, firewall, logger) excludeMark, tryTimeout, tracker, firewall, logger)
pmtudResultCh <- result{mtu: mtu, err: err} resultCh <- pmtudResult{mtu: mtu, err: err}
}() }()
select { select {
case err = <-trackerErrCh: case err = <-trackerErrCh:
pmtudCancel() cancel()
<-pmtudResultCh <-resultCh
return 0, fmt.Errorf("listening for TCP replies: %w", err) return 0, fmt.Errorf("listening for TCP replies: %w", err)
case res := <-pmtudResultCh: case result := <-resultCh:
trackerCancel() trackerCancel()
<-trackerErrCh <-trackerErrCh
return res.mtu, res.err return result.mtu, result.err
} }
} }
+2 -2
View File
@@ -129,8 +129,8 @@ func runTest(ctx context.Context, dst netip.AddrPort, mtu uint32,
if firstReplyHeader.options.mss != 0 { if firstReplyHeader.options.mss != 0 {
// If the server sent an MSS option, make sure our test packet is not larger than that MSS. // If the server sent an MSS option, make sure our test packet is not larger than that MSS.
tcpDataLength := getPayloadLength(mtu, dst) - constants.BaseTCPHeaderLength tcpDataLength := getPayloadLength(mtu, dst) - constants.BaseTCPHeaderLength
if tcpDataLength > uint32(firstReplyHeader.options.mss) { if tcpDataLength > firstReplyHeader.options.mss {
diff := tcpDataLength - uint32(firstReplyHeader.options.mss) diff := tcpDataLength - firstReplyHeader.options.mss
minMTU := constants.MinIPv4MTU minMTU := constants.MinIPv4MTU
if dst.Addr().Is6() { if dst.Addr().Is6() {
minMTU = constants.MinIPv6MTU minMTU = constants.MinIPv6MTU
+8 -3
View File
@@ -27,12 +27,17 @@ func Test_PathMTUDiscover(t *testing.T) {
} }
require.NoError(t, err, "creating firewall config") require.NoError(t, err, "creating firewall config")
dst := netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 80) dsts := []netip.AddrPort{
netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 53),
netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443),
netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 53),
netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 443),
}
const minMTU = constants.MinIPv6MTU const minMTU = constants.MinIPv6MTU
const maxMTU = constants.MaxEthernetFrameSize const maxMTU = constants.MaxEthernetFrameSize
const tryTimeout = time.Second const tryTimeout = time.Second
mtu, err := PathMTUDiscover(t.Context(), dst, minMTU, maxMTU, tryTimeout, fw, noopLogger) mtu, err := PathMTUDiscover(t.Context(), dsts, minMTU, maxMTU, tryTimeout, fw, noopLogger)
require.NoError(t, err, "discovering path MTU") require.NoError(t, err, "discovering path MTU")
assert.Greater(t, mtu, uint32(0), "MTU should be greater than 0") assert.Greater(t, mtu, uint32(0), "MTU should be greater than 0")
t.Logf("discovered path MTU to %s is %d", dst, mtu) t.Logf("discovered path MTU is %d", mtu)
} }
+1 -11
View File
@@ -4,14 +4,11 @@ package tcp
import ( import (
"context" "context"
"errors"
"net/netip" "net/netip"
"testing" "testing"
"time" "time"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
"github.com/qdm12/gluetun/internal/command"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/netlink" "github.com/qdm12/gluetun/internal/netlink"
"github.com/qdm12/gluetun/internal/pmtud/constants" "github.com/qdm12/gluetun/internal/pmtud/constants"
"github.com/qdm12/gluetun/internal/pmtud/ip" "github.com/qdm12/gluetun/internal/pmtud/ip"
@@ -26,13 +23,6 @@ func Test_runTest(t *testing.T) {
noopLogger := &noopLogger{} noopLogger := &noopLogger{}
cmder := command.New()
fw, err := firewall.NewConfig(t.Context(), noopLogger, cmder, nil, nil)
if errors.Is(err, firewall.ErrIPTablesNotSupported) {
t.Skip("iptables not installed, skipping TCP PMTUD tests")
}
require.NoError(t, err, "creating firewall config")
netlinker := netlink.New(noopLogger) netlinker := netlink.New(noopLogger)
loopbackMTU, err := findLoopbackMTU(netlinker) loopbackMTU, err := findLoopbackMTU(netlinker)
require.NoError(t, err, "finding loopback IPv4 MTU") require.NoError(t, err, "finding loopback IPv4 MTU")
@@ -42,7 +32,6 @@ func Test_runTest(t *testing.T) {
ctx, cancel := context.WithCancel(t.Context()) ctx, cancel := context.WithCancel(t.Context())
const family = constants.AF_INET const family = constants.AF_INET
const excludeMark = 4545
fd, stop, err := startRawSocket(family, excludeMark) fd, stop, err := startRawSocket(family, excludeMark)
require.NoError(t, err) require.NoError(t, err)
@@ -116,6 +105,7 @@ func Test_runTest(t *testing.T) {
require.NoError(t, err, "getting source address to reach remote server %s", dst) require.NoError(t, err, "getting source address to reach remote server %s", dst)
t.Cleanup(cleanup) t.Cleanup(cleanup)
fw := getFirewall(t)
revert, err := fw.TempDropOutputTCPRST(t.Context(), src, dst, excludeMark) revert, err := fw.TempDropOutputTCPRST(t.Context(), src, dst, excludeMark)
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(func() { t.Cleanup(func() {
+2 -2
View File
@@ -199,7 +199,7 @@ func parseTCPHeader(b []byte) (header tcpHeader, err error) {
} }
type options struct { type options struct {
mss uint16 mss uint32
windowScale *uint8 // Pointer to differentiate between 0 and "not present" windowScale *uint8 // Pointer to differentiate between 0 and "not present"
sackPermitted bool sackPermitted bool
timestamps *optionTimestamps timestamps *optionTimestamps
@@ -266,7 +266,7 @@ func parseTCPOptions(b []byte) (parsed options, err error) {
return options{}, fmt.Errorf("%w: MSS option at offset %d has length %d, expected %d", return options{}, fmt.Errorf("%w: MSS option at offset %d has length %d, expected %d",
ErrTCPOptionMSSInvalid, i, length, expectedLength) ErrTCPOptionMSSInvalid, i, length, expectedLength)
} }
parsed.mss = binary.BigEndian.Uint16(data) parsed.mss = uint32(binary.BigEndian.Uint16(data))
case optionTypeWindowScale: case optionTypeWindowScale:
const expectedLength = 3 const expectedLength = 3
if length != expectedLength { if length != expectedLength {