mirror of
https://github.com/qdm12/gluetun.git
synced 2026-05-07 04:20:12 +02:00
04d7cef294
- This makes TCP PMTUD reliable. - This only works on kernels with the mark module. - On kernels without the mark module, the MTU found by ICMP PMTUD is used instead.
108 lines
2.5 KiB
Go
108 lines
2.5 KiB
Go
package tcp
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net/netip"
|
|
"time"
|
|
|
|
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
|
"github.com/qdm12/gluetun/internal/pmtud/test"
|
|
)
|
|
|
|
var (
	// ErrMTUNotFound is returned (wrapped) by PathMTUDiscover when none of
	// the tested MTUs succeeded.
	ErrMTUNotFound = errors.New("MTU not found")
)
|
|
|
|
// testUnit pairs one candidate MTU with the outcome of probing it.
type testUnit struct {
	mtu uint32 // candidate MTU value to probe
	ok  bool   // true when runTest returned a nil error for mtu
}
|
|
|
|
func PathMTUDiscover(ctx context.Context, addrPort netip.AddrPort,
|
|
minMTU, maxPossibleMTU uint32, firewall Firewall, logger Logger,
|
|
) (mtu uint32, err error) {
|
|
const excludeMark = 4325
|
|
revert, err := firewall.TempDropOutputTCPRST(ctx, addrPort, excludeMark)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("temporarily dropping outgoing TCP RST packets: %w", err)
|
|
}
|
|
defer func() {
|
|
err := revert(ctx)
|
|
if err != nil {
|
|
logger.Warnf("reverting firewall changes: %s", err)
|
|
}
|
|
}()
|
|
|
|
return pathMTUDiscover(ctx, addrPort, minMTU, maxPossibleMTU, excludeMark, logger)
|
|
}
|
|
|
|
func pathMTUDiscover(ctx context.Context, addrPort netip.AddrPort,
|
|
minMTU, maxPossibleMTU uint32, excludeMark int, logger Logger,
|
|
) (mtu uint32, err error) {
|
|
mtusToTest := test.MakeMTUsToTest(minMTU, maxPossibleMTU)
|
|
if len(mtusToTest) == 1 { // only minMTU because minMTU == maxPossibleMTU
|
|
return minMTU, nil
|
|
}
|
|
logger.Debugf("TCP testing the following MTUs: %v", mtusToTest)
|
|
|
|
tests := make([]testUnit, len(mtusToTest))
|
|
for i := range mtusToTest {
|
|
tests[i] = testUnit{mtu: mtusToTest[i]}
|
|
}
|
|
|
|
family := constants.AF_INET
|
|
if addrPort.Addr().Is6() {
|
|
family = constants.AF_INET6
|
|
}
|
|
fd, stop, err := startRawSocket(family, excludeMark)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("starting raw socket: %w", err)
|
|
}
|
|
defer stop()
|
|
|
|
tracker := newTracker(fd, addrPort.Addr().Is4())
|
|
|
|
const timeout = time.Second
|
|
runCtx, cancel := context.WithTimeout(ctx, timeout)
|
|
defer cancel()
|
|
errCh := make(chan error)
|
|
go func() {
|
|
errCh <- tracker.listen(runCtx)
|
|
}()
|
|
|
|
doneCh := make(chan struct{})
|
|
for i := range tests {
|
|
go func(i int) {
|
|
err := runTest(runCtx, fd, tracker, addrPort, tests[i].mtu)
|
|
tests[i].ok = err == nil
|
|
doneCh <- struct{}{}
|
|
}(i)
|
|
}
|
|
|
|
for range tests {
|
|
select {
|
|
case <-doneCh:
|
|
case err := <-errCh:
|
|
if err == nil { // timeout
|
|
break
|
|
}
|
|
return 0, fmt.Errorf("listening for TCP replies: %w", err)
|
|
}
|
|
}
|
|
|
|
if tests[len(tests)-1].ok {
|
|
return tests[len(tests)-1].mtu, nil
|
|
}
|
|
|
|
for i := len(tests) - 2; i >= 0; i-- { //nolint:mnd
|
|
if tests[i].ok {
|
|
stop()
|
|
cancel()
|
|
return pathMTUDiscover(ctx, addrPort,
|
|
tests[i].mtu, tests[i+1].mtu-1, excludeMark, logger)
|
|
}
|
|
}
|
|
|
|
return 0, fmt.Errorf("%w: your connection might not be working at all", ErrMTUNotFound)
|
|
}
|