feat(pmtud/tcp): use the TCP server with highest MSS to run MTU tests

This commit is contained in:
Quentin McGaw
2026-02-19 14:03:46 +00:00
parent fb85ae79d1
commit 8d86470905
10 changed files with 323 additions and 59 deletions
+55 -14
View File
@@ -8,6 +8,7 @@ import (
"time"
"github.com/qdm12/gluetun/internal/pmtud/constants"
"github.com/qdm12/gluetun/internal/pmtud/ip"
"github.com/qdm12/gluetun/internal/pmtud/test"
)
@@ -18,22 +19,31 @@ type testUnit struct {
ok bool
}
func PathMTUDiscover(ctx context.Context, dst netip.AddrPort,
const excludeMark = 4545
// PathMTUDiscover first finds the destination TCP server with the highest
// available MSS, in order to be able to test the highest possible MTU.
// If a server has an MSS larger than maxPossibleMTU, this one is used.
// It then performs a binary search of the MTU between minMTU and maxPossibleMTU,
// by sending IP packets with the Don't Fragment bit set and checking if they
// are received or not, exploiting the stateful nature of TCP to be able to
// correlate replies to the sent packets.
// Note all dsts must be of the same IP family (all IPv4 or all IPv6).
func PathMTUDiscover(ctx context.Context, dsts []netip.AddrPort,
minMTU, maxPossibleMTU uint32, tryTimeout time.Duration,
firewall Firewall, logger Logger,
) (mtu uint32, err error) {
family := constants.AF_INET
if dst.Addr().Is6() {
if dsts[0].Addr().Is6() {
family = constants.AF_INET6
}
const excludeMark = 4325
fd, stop, err := startRawSocket(family, excludeMark)
if err != nil {
return 0, fmt.Errorf("starting raw socket: %w", err)
}
defer stop()
tracker := newTracker(fd, dst.Addr().Is4())
tracker := newTracker(fd, family == constants.AF_INET)
trackerCtx, trackerCancel := context.WithCancel(ctx)
defer trackerCancel()
@@ -42,28 +52,59 @@ func PathMTUDiscover(ctx context.Context, dst netip.AddrPort,
trackerErrCh <- tracker.listen(trackerCtx)
}()
pmtudCtx, pmtudCancel := context.WithCancel(ctx)
defer pmtudCancel()
type result struct {
type mssResult struct {
dst netip.AddrPort
mss uint32
err error
}
mssResultCh := make(chan mssResult)
ctx, cancel := context.WithTimeout(ctx, tryTimeout)
defer cancel()
go func() {
dst, mss, err := findHighestMSSDestination(ctx, fd, dsts, excludeMark,
maxPossibleMTU, tryTimeout, tracker, firewall, logger)
mssResultCh <- mssResult{dst: dst, mss: mss, err: err}
}()
var highestMSSDst netip.AddrPort
select {
case err = <-trackerErrCh:
cancel()
<-mssResultCh
return 0, fmt.Errorf("listening for TCP replies: %w", err)
case result := <-mssResultCh:
if result.err != nil {
trackerCancel()
<-trackerErrCh
return 0, fmt.Errorf("finding MSS: %w", result.err)
}
highestMSSDst = result.dst
ipHeaderLength := ip.HeaderLength(highestMSSDst.Addr().Is4())
maxPossibleMTU = ipHeaderLength + constants.BaseTCPHeaderLength + result.mss
}
type pmtudResult struct {
mtu uint32
err error
}
pmtudResultCh := make(chan result)
resultCh := make(chan pmtudResult)
ctx, cancel = context.WithCancel(ctx)
defer cancel()
go func() {
mtu, err := pathMTUDiscover(pmtudCtx, fd, dst, minMTU, maxPossibleMTU,
mtu, err := pathMTUDiscover(ctx, fd, highestMSSDst, minMTU, maxPossibleMTU,
excludeMark, tryTimeout, tracker, firewall, logger)
pmtudResultCh <- result{mtu: mtu, err: err}
resultCh <- pmtudResult{mtu: mtu, err: err}
}()
select {
case err = <-trackerErrCh:
pmtudCancel()
<-pmtudResultCh
cancel()
<-resultCh
return 0, fmt.Errorf("listening for TCP replies: %w", err)
case res := <-pmtudResultCh:
case result := <-resultCh:
trackerCancel()
<-trackerErrCh
return res.mtu, res.err
return result.mtu, result.err
}
}