Files
gluetun/internal/pmtud/tcp/multi.go
T
Quentin McGaw c6b211ef9b feat(pmtud/tcp): support mixed IPv4 and IPv6 TCP servers
- Add default cloudflare and google tls ipv6 servers to default tcp servers
- update integration test to try against both ipv4 and ipv6 servers
2026-02-19 17:11:16 +00:00

176 lines
4.8 KiB
Go

package tcp
import (
"context"
"errors"
"fmt"
"net/netip"
"time"
"github.com/qdm12/gluetun/internal/pmtud/constants"
"github.com/qdm12/gluetun/internal/pmtud/ip"
"github.com/qdm12/gluetun/internal/pmtud/test"
)
// ErrMTUNotFound is returned, wrapped with additional context, by the MTU
// search when none of the tested MTUs succeeded.
var ErrMTUNotFound = errors.New("MTU not found")
// testUnit holds one candidate MTU together with the outcome of its test.
type testUnit struct {
// mtu is the candidate MTU value being tested.
mtu uint32
// ok is set to true when the test for mtu completed without error.
ok bool
}
// excludeMark is the mark value passed to the raw sockets setup, the MSS
// search and the MTU tests in this file.
// NOTE(review): presumably a socket/packet mark that lets firewall rules
// single out the probe traffic - confirm against startRawSockets and runTest.
const excludeMark = 4545
// PathMTUDiscover first finds the destination TCP server with the highest
// available MSS, in order to be able to test the highest possible MTU.
// If a server has an MSS larger than maxPossibleMTU, this one is used.
// It then performs a binary search of the MTU between minMTU and maxPossibleMTU,
// by sending IP packets with the Don't Fragment bit set and checking if they
// are received or not, exploiting the stateful nature of TCP to be able to
// correlate replies to the sent packets.
//
// dsts may mix IPv4 and IPv6 destinations: raw sockets are started for the
// families returned by ip.GetFamilies(dsts), and the socket matching the
// family of the selected destination is the one used for the MTU search.
//
// Once the destination is selected, maxPossibleMTU is replaced by the sum of
// the IP header length for that destination's family, the base TCP header
// length and the MSS found, before running the MTU search.
//
// tryTimeout bounds the MSS search context and is also forwarded to the MSS
// search and to the MTU search.
//
// The returned error wraps the underlying error with "starting raw sockets",
// "listening for TCP replies" or "finding MSS" depending on the failing step;
// the error from the MTU search itself is returned as is.
func PathMTUDiscover(ctx context.Context, dsts []netip.AddrPort,
minMTU, maxPossibleMTU uint32, tryTimeout time.Duration,
firewall Firewall, logger Logger,
) (mtu uint32, err error) {
families := ip.GetFamilies(dsts)
familyToFD, stop, err := startRawSockets(families, excludeMark)
if err != nil {
return 0, fmt.Errorf("starting raw sockets: %w", err)
}
defer stop()
// The tracker listens in a background goroutine for the whole duration of
// the discovery; its result is reported on the unbuffered trackerErrCh.
tracker := newTracker(familyToFD)
trackerCtx, trackerCancel := context.WithCancel(ctx)
defer trackerCancel()
trackerErrCh := make(chan error)
go func() {
trackerErrCh <- tracker.listen(trackerCtx)
}()
type mssResult struct {
dst netip.AddrPort
mss uint32
err error
}
// Step 1: search for the destination with the highest MSS, in its own
// goroutine and bounded by tryTimeout.
mssResultCh := make(chan mssResult)
mssCtx, mssCancel := context.WithTimeout(ctx, tryTimeout)
defer mssCancel()
go func() {
dst, mss, err := findHighestMSSDestination(mssCtx, familyToFD, dsts, excludeMark,
maxPossibleMTU, tryTimeout, tracker, firewall, logger)
mssResultCh <- mssResult{dst: dst, mss: mss, err: err}
}()
var highestMSSDst netip.AddrPort
// Wait for whichever of the tracker or the MSS search finishes first.
// On every early return the other goroutine is canceled and its channel is
// received from, so that it is not left blocked sending on an unbuffered
// channel after this function returns.
select {
case err = <-trackerErrCh:
// The tracker stopped before the MSS search finished.
mssCancel()
<-mssResultCh
return 0, fmt.Errorf("listening for TCP replies: %w", err)
case result := <-mssResultCh:
if result.err != nil {
// Relies on tracker.listen returning once trackerCtx is canceled.
trackerCancel()
<-trackerErrCh
return 0, fmt.Errorf("finding MSS: %w", result.err)
}
highestMSSDst = result.dst
ipHeaderLength := ip.HeaderLength(highestMSSDst.Addr().Is4())
// Upper bound of the search derived from the MSS found.
maxPossibleMTU = ipHeaderLength + constants.BaseTCPHeaderLength + result.mss
}
// Use the raw socket of the same IP family as the selected destination.
fd := familyToFD[ip.GetFamily(highestMSSDst)]
type pmtudResult struct {
mtu uint32
err error
}
// Step 2: run the MTU search against the selected destination, again in
// its own goroutine so a tracker failure can interrupt it.
resultCh := make(chan pmtudResult)
pmtudCtx, pmtudCancel := context.WithCancel(ctx)
defer pmtudCancel()
go func() {
mtu, err := pathMTUDiscover(pmtudCtx, fd, highestMSSDst, minMTU, maxPossibleMTU,
excludeMark, tryTimeout, tracker, firewall, logger)
resultCh <- pmtudResult{mtu: mtu, err: err}
}()
// Same cancel-then-drain teardown as for the MSS search above.
select {
case err = <-trackerErrCh:
pmtudCancel()
<-resultCh
return 0, fmt.Errorf("listening for TCP replies: %w", err)
case result := <-resultCh:
// The search is over: stop the tracker and wait for it to exit; its
// return value is discarded at this point.
trackerCancel()
<-trackerErrCh
return result.mtu, result.err
}
}
// errTimedOut is wrapped into the cancellation cause of the MTU tests run
// context, so that an expiry of the try timeout can be told apart from a
// cancellation of the parent context.
var errTimedOut = errors.New("timed out")
// pathMTUDiscover tests a set of candidate MTUs between minMTU and
// maxPossibleMTU against dst, and narrows the range recursively until a
// single candidate remains.
//
// The candidates come from test.MakeMTUsToTest. When only one candidate is
// produced, minMTU is returned directly without running any test. Otherwise
// one test per candidate is run concurrently, all sharing a context which
// expires after tryTimeout. Hitting that timeout is not fatal: the tests
// that did not succeed by then simply count as failed. Any other
// cancellation of the context is returned wrapped with "running MTU tests".
//
// If the last candidate passed, its MTU is returned. Otherwise the search
// recurses between the last passing candidate and one less than the
// candidate right after it. If no candidate passed, ErrMTUNotFound is
// returned wrapped.
func pathMTUDiscover(ctx context.Context, fd fileDescriptor,
	dst netip.AddrPort, minMTU, maxPossibleMTU uint32, excludeMark int,
	tryTimeout time.Duration, tracker *tracker, firewall Firewall,
	logger Logger,
) (mtu uint32, err error) {
	candidates := test.MakeMTUsToTest(minMTU, maxPossibleMTU)
	if len(candidates) == 1 { // minMTU == maxPossibleMTU, nothing left to narrow
		return minMTU, nil
	}
	logger.Debugf("TCP testing the following MTUs: %v", candidates)

	units := make([]testUnit, 0, len(candidates))
	for _, candidate := range candidates {
		units = append(units, testUnit{mtu: candidate})
	}

	timeoutCause := fmt.Errorf("%w: after %s", errTimedOut, tryTimeout)
	runCtx, runCancel := context.WithTimeoutCause(ctx, tryTimeout, timeoutCause)
	defer runCancel()

	// Each goroutine owns exactly one unit and signals on finished once it
	// has recorded its outcome.
	finished := make(chan struct{})
	for idx := range units {
		unit := &units[idx]
		go func() {
			testErr := runTest(runCtx, dst, unit.mtu, excludeMark,
				fd, tracker, firewall, logger)
			unit.ok = testErr == nil
			finished <- struct{}{}
		}()
	}

	// Wait for every test to report back, noting why the run context ended
	// if it does so before all of them are done.
	var runErr error
	for pending := len(units); pending > 0; {
		select {
		case <-finished:
			pending--
		case <-runCtx.Done(): // timeout or parent context canceled
			runErr = context.Cause(runCtx)
			// every goroutine still sends its signal, so drain them all
			for ; pending > 0; pending-- {
				<-finished
			}
		}
	}

	if runErr != nil && !errors.Is(runErr, errTimedOut) {
		// canceled for a reason other than the tryTimeout expiring
		return 0, fmt.Errorf("running MTU tests: %w", runErr)
	}

	last := len(units) - 1
	if units[last].ok {
		return units[last].mtu, nil
	}

	for idx := last - 1; idx >= 0; idx-- {
		if !units[idx].ok {
			continue
		}
		// runCtx is no longer needed; release its resources before recursing.
		runCancel()
		return pathMTUDiscover(ctx, fd, dst,
			units[idx].mtu, units[idx+1].mtu-1, excludeMark,
			tryTimeout, tracker, firewall, logger)
	}

	return 0, fmt.Errorf("%w: your connection might not be working at all", ErrMTUNotFound)
}