diff --git a/internal/pmtud/pmtud.go b/internal/pmtud/pmtud.go index edf07ca6..f06d882b 100644 --- a/internal/pmtud/pmtud.go +++ b/internal/pmtud/pmtud.go @@ -66,7 +66,7 @@ func PathMTUDiscover(ctx context.Context, icmpAddrs []netip.Addr, tcpAddrs []net minMTU = constants.MinIPv6MTU } if icmpSuccess { - const mtuMargin = 300 + const mtuMargin = 150 minMTU = max(maxPossibleMTU-mtuMargin, minMTU) } mtu, err = tcp.PathMTUDiscover(ctx, tcpAddrs, minMTU, maxPossibleMTU, tryTimeout, fw, logger) diff --git a/internal/pmtud/tcp/multi.go b/internal/pmtud/tcp/multi.go index ceff4b06..f4f6b8a2 100644 --- a/internal/pmtud/tcp/multi.go +++ b/internal/pmtud/tcp/multi.go @@ -12,7 +12,10 @@ import ( "github.com/qdm12/gluetun/internal/pmtud/test" ) -var ErrMTUNotFound = errors.New("MTU not found") +var ( + ErrMTUNotFound = errors.New("MTU not found") + ErrMSSTooSmall = errors.New("TCP MSS is too small to find the MTU") +) type testUnit struct { mtu uint32 @@ -63,24 +66,28 @@ func PathMTUDiscover(ctx context.Context, dsts []netip.AddrPort, maxPossibleMTU, tryTimeout, tracker, firewall, logger) mssResultCh <- mssResult{dst: dst, mss: mss, err: err} }() - var highestMSSDst netip.AddrPort + var result mssResult select { case err = <-trackerErrCh: mssCancel() <-mssResultCh return 0, fmt.Errorf("listening for TCP replies: %w", err) - case result := <-mssResultCh: - if result.err != nil { - trackerCancel() - <-trackerErrCh - return 0, fmt.Errorf("finding MSS: %w", result.err) - } - highestMSSDst = result.dst - ipHeaderLength := ip.HeaderLength(highestMSSDst.Addr().Is4()) - maxPossibleMTU = ipHeaderLength + constants.BaseTCPHeaderLength + result.mss + case result = <-mssResultCh: + } + if result.err != nil { + trackerCancel() + <-trackerErrCh + return 0, fmt.Errorf("finding MSS: %w", result.err) + } + ipHeaderLength := ip.HeaderLength(result.dst.Addr().Is4()) + maxPossibleMTU = ipHeaderLength + constants.BaseTCPHeaderLength + result.mss + if minMTU > maxPossibleMTU { + // Occasionally, the MSS is a lot smaller than the MTU found using ICMP + const safetyBuffer = 100 + minMTU = maxPossibleMTU - safetyBuffer } - fd := familyToFD[ip.GetFamily(highestMSSDst)] + fd := familyToFD[ip.GetFamily(result.dst)] type pmtudResult struct { mtu uint32 @@ -90,7 +97,7 @@ func PathMTUDiscover(ctx context.Context, dsts []netip.AddrPort, pmtudCtx, pmtudCancel := context.WithCancel(ctx) defer pmtudCancel() go func() { - mtu, err := pathMTUDiscover(pmtudCtx, fd, highestMSSDst, minMTU, maxPossibleMTU, + mtu, err := pathMTUDiscover(pmtudCtx, fd, result.dst, minMTU, maxPossibleMTU, excludeMark, tryTimeout, tracker, firewall, logger) resultCh <- pmtudResult{mtu: mtu, err: err} }()