diff --git a/Dockerfile b/Dockerfile index 7fa29518..72d24eee 100644 --- a/Dockerfile +++ b/Dockerfile @@ -114,7 +114,7 @@ ENV VPN_SERVICE_PROVIDER=pia \ WIREGUARD_IMPLEMENTATION=auto \ # PMTUD PMTUD_ICMP_ADDRESSES=1.1.1.1,8.8.8.8 \ - PMTUD_TCP_ADDRESSES=1.1.1.1:443,8.8.8.8:443,1.1.1.1:53,8.8.8.8:53 \ + PMTUD_TCP_ADDRESSES=1.1.1.1:443,8.8.8.8:443,1.1.1.1:53,8.8.8.8:53,[2606:4700:4700::1111]:53,[2001:4860:4860::8888]:53,[2606:4700:4700::1111]:443,[2001:4860:4860::8888]:443 \ # VPN server filtering SERVER_REGIONS= \ SERVER_COUNTRIES= \ diff --git a/internal/configuration/settings/pmtud.go b/internal/configuration/settings/pmtud.go index 6255c893..e447aa70 100644 --- a/internal/configuration/settings/pmtud.go +++ b/internal/configuration/settings/pmtud.go @@ -4,7 +4,6 @@ import ( "errors" "fmt" "net/netip" - "strings" "github.com/qdm12/gosettings" "github.com/qdm12/gosettings/reader" @@ -70,6 +69,10 @@ func (p *PMTUD) setDefaults() { netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), dnsPort), netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), tlsPort), netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), tlsPort), + netip.AddrPortFrom(netip.MustParseAddr("2606:4700:4700::1111"), dnsPort), + netip.AddrPortFrom(netip.MustParseAddr("2001:4860:4860::8888"), dnsPort), + netip.AddrPortFrom(netip.MustParseAddr("2606:4700:4700::1111"), tlsPort), + netip.AddrPortFrom(netip.MustParseAddr("2001:4860:4860::8888"), tlsPort), } p.TCPAddresses = gosettings.DefaultSlice(p.TCPAddresses, defaultTCPAddresses) } @@ -81,17 +84,15 @@ func (p PMTUD) String() string { func (p PMTUD) toLinesNode() (node *gotree.Node) { node = gotree.New("Path MTU discovery:") - addrs := make([]string, len(p.ICMPAddresses)) - for i, addr := range p.ICMPAddresses { - addrs[i] = addr.String() + icmpAddrNode := node.Append("ICMP addresses:") + for _, addr := range p.ICMPAddresses { + icmpAddrNode.Append(addr.String()) } - node.Appendf("ICMP addresses: %s", strings.Join(addrs, ", ")) - addrs = make([]string, len(p.TCPAddresses)) - for i, addr := range p.TCPAddresses { - addrs[i] = addr.String() + tcpAddrNode := node.Append("TCP addresses:") + for _, addr := range p.TCPAddresses { + tcpAddrNode.Append(addr.String()) } - node.Appendf("TCP addresses: %s", strings.Join(addrs, ", ")) return node } diff --git a/internal/configuration/settings/settings_test.go b/internal/configuration/settings/settings_test.go index c0976cd9..73c3cb28 100644 --- a/internal/configuration/settings/settings_test.go +++ b/internal/configuration/settings/settings_test.go @@ -38,8 +38,18 @@ func Test_Settings_String(t *testing.T) { | | ├── Run OpenVPN as: root | | └── Verbosity level: 1 | └── Path MTU discovery: -| ├── ICMP addresses: 1.1.1.1, 8.8.8.8 -| └── TCP addresses: 1.1.1.1:53, 8.8.8.8:53, 1.1.1.1:443, 8.8.8.8:443 +| ├── ICMP addresses: +| | ├── 1.1.1.1 +| | └── 8.8.8.8 +| └── TCP addresses: +| ├── 1.1.1.1:53 +| ├── 8.8.8.8:53 +| ├── 1.1.1.1:443 +| ├── 8.8.8.8:443 +| ├── [2606:4700:4700::1111]:53 +| ├── [2001:4860:4860::8888]:53 +| ├── [2606:4700:4700::1111]:443 +| └── [2001:4860:4860::8888]:443 ├── DNS settings: | ├── Keep existing nameserver(s): no | ├── DNS server address to use: 127.0.0.1 diff --git a/internal/pmtud/ip/family.go b/internal/pmtud/ip/family.go new file mode 100644 index 00000000..a40e1f13 --- /dev/null +++ b/internal/pmtud/ip/family.go @@ -0,0 +1,27 @@ +package ip + +import ( + "net/netip" + "slices" + + "github.com/qdm12/gluetun/internal/pmtud/constants" +) + +func GetFamilies(dsts []netip.AddrPort) (families []int) { + const maxFamilies = 2 + families = make([]int, 0, maxFamilies) + for _, dst := range dsts { + family := GetFamily(dst) + if !slices.Contains(families, family) { + families = append(families, family) + } + } + return families +} + +func GetFamily(dst netip.AddrPort) int { + if dst.Addr().Is4() { + return constants.AF_INET + } + return constants.AF_INET6 +} diff --git a/internal/pmtud/tcp/helpers_test.go b/internal/pmtud/tcp/helpers_test.go index 80b21614..7eeb0c8b 100644 --- a/internal/pmtud/tcp/helpers_test.go +++ b/internal/pmtud/tcp/helpers_test.go @@ -78,24 +78,30 @@ func findLoopbackMTU(netlinker *netlink.NetLink) (mtu uint32, err error) { return 0, fmt.Errorf("%w: no loopback route found", errRouteNotFound) } -func findDefaultIPv4RouteMTU(netlinker *netlink.NetLink) (mtu uint32, err error) { +func findDefaultRouteMTU(netlinker *netlink.NetLink) (mtu uint32, err error) { noopLogger := &noopLogger{} routing := routing.New(netlinker, noopLogger) defaultRoutes, err := routing.DefaultRoutes() if err != nil { return 0, fmt.Errorf("getting default routes: %w", err) } - for _, route := range defaultRoutes { - if route.Family != netlink.FamilyV4 { - continue + families := []uint8{constants.AF_INET, constants.AF_INET6} + for _, family := range families { + for _, route := range defaultRoutes { + if route.Family != family { + continue + } + link, err := netlinker.LinkByName(route.NetInterface) + if err != nil { + return 0, fmt.Errorf("getting link by name: %w", err) + } + mtu = max(mtu, link.MTU) } - link, err := netlinker.LinkByName(defaultRoutes[0].NetInterface) - if err != nil { - return 0, fmt.Errorf("getting link by name: %w", err) - } - return link.MTU, nil } - return 0, fmt.Errorf("%w: no default route found", errRouteNotFound) + if mtu == 0 { + return 0, fmt.Errorf("%w: no default route found", errRouteNotFound) + } + return mtu, nil } func reserveClosedPort(t *testing.T) (port uint16) { diff --git a/internal/pmtud/tcp/mss.go b/internal/pmtud/tcp/mss.go index 518403b4..dbfd6897 100644 --- a/internal/pmtud/tcp/mss.go +++ b/internal/pmtud/tcp/mss.go @@ -14,7 +14,7 @@ import ( // findHighestMSSDestination finds the destination with the highest // MSS amongst the provided destinations. -func findHighestMSSDestination(ctx context.Context, fd fileDescriptor, +func findHighestMSSDestination(ctx context.Context, familyToFD map[int]fileDescriptor, dsts []netip.AddrPort, excludeMark int, maxPossibleMTU uint32, timeout time.Duration, tracker *tracker, fw Firewall, logger Logger) ( dst netip.AddrPort, mss uint32, err error, @@ -30,6 +30,7 @@ func findHighestMSSDestination(ctx context.Context, fd fileDescriptor, defer cancel() for _, dst := range dsts { go func(dst netip.AddrPort) { + fd := familyToFD[ip.GetFamily(dst)] mss, err := findMSS(ctx, fd, dst, excludeMark, tracker, fw, logger) resultCh <- result{dst: dst, mss: mss, err: err} }(dst) diff --git a/internal/pmtud/tcp/mss_test.go b/internal/pmtud/tcp/mss_test.go index b92c8926..181b772f 100644 --- a/internal/pmtud/tcp/mss_test.go +++ b/internal/pmtud/tcp/mss_test.go @@ -18,17 +18,16 @@ func Test_findHighestMSSDestination(t *testing.T) { t.Parallel() netlinker := netlink.New(&noopLogger{}) - defaultIPv4MTU, err := findDefaultIPv4RouteMTU(netlinker) - require.NoError(t, err, "finding default IPv4 route MTU") + defaultMTU, err := findDefaultRouteMTU(netlinker) + require.NoError(t, err, "finding default route MTU") ctx, cancel := context.WithCancel(t.Context()) - const family = constants.AF_INET - fd, stop, err := startRawSocket(family, excludeMark) + families := []int{constants.AF_INET, constants.AF_INET6} + familyToFD, stop, err := startRawSockets(families, excludeMark) require.NoError(t, err) - const ipv4 = true - tracker := newTracker(fd, ipv4) + tracker := newTracker(familyToFD) trackerCh := make(chan error) go func() { trackerCh <- tracker.listen(ctx) @@ -44,13 +43,15 @@ func Test_findHighestMSSDestination(t *testing.T) { dsts := []netip.AddrPort{ netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443), netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 443), + netip.AddrPortFrom(netip.MustParseAddr("2606:4700:4700::1111"), 443), + netip.AddrPortFrom(netip.MustParseAddr("2001:4860:4860::8888"), 443), } const timeout = time.Second fw := getFirewall(t) logger := &noopLogger{} - dst, mss, err := findHighestMSSDestination(t.Context(), fd, dsts, - excludeMark, defaultIPv4MTU, timeout, tracker, fw, logger) + dst, mss, err := findHighestMSSDestination(t.Context(), familyToFD, dsts, + excludeMark, defaultMTU, timeout, tracker, fw, logger) require.NoError(t, err, "finding highest MSS destination") assert.Contains(t, dsts, dst, "destination should be in the provided list") assert.Greater(t, mss, uint32(1000), "MSS should be greater than 1000") diff --git a/internal/pmtud/tcp/multi.go b/internal/pmtud/tcp/multi.go index 764933b1..ceff4b06 100644 --- a/internal/pmtud/tcp/multi.go +++ b/internal/pmtud/tcp/multi.go @@ -33,17 +33,14 @@ func PathMTUDiscover(ctx context.Context, dsts []netip.AddrPort, minMTU, maxPossibleMTU uint32, tryTimeout time.Duration, firewall Firewall, logger Logger, ) (mtu uint32, err error) { - family := constants.AF_INET - if dsts[0].Addr().Is6() { - family = constants.AF_INET6 - } - fd, stop, err := startRawSocket(family, excludeMark) + families := ip.GetFamilies(dsts) + familyToFD, stop, err := startRawSockets(families, excludeMark) if err != nil { - return 0, fmt.Errorf("starting raw socket: %w", err) + return 0, fmt.Errorf("starting raw sockets: %w", err) } defer stop() - tracker := newTracker(fd, family == constants.AF_INET) + tracker := newTracker(familyToFD) trackerCtx, trackerCancel := context.WithCancel(ctx) defer trackerCancel() @@ -62,7 +59,7 @@ func PathMTUDiscover(ctx context.Context, dsts []netip.AddrPort, mssCtx, mssCancel := context.WithTimeout(ctx, tryTimeout) defer mssCancel() go func() { - dst, mss, err := findHighestMSSDestination(mssCtx, fd, dsts, excludeMark, + dst, mss, err := findHighestMSSDestination(mssCtx, familyToFD, dsts, excludeMark, maxPossibleMTU, tryTimeout, tracker, firewall, logger) mssResultCh <- mssResult{dst: dst, mss: mss, err: err} }() @@ -83,6 +80,8 @@ func PathMTUDiscover(ctx context.Context, dsts []netip.AddrPort, maxPossibleMTU = ipHeaderLength + constants.BaseTCPHeaderLength + result.mss } + fd := familyToFD[ip.GetFamily(highestMSSDst)] + type pmtudResult struct { mtu uint32 err error diff --git a/internal/pmtud/tcp/tcp.go b/internal/pmtud/tcp/tcp.go index 0bc00cf8..03931f98 100644 --- a/internal/pmtud/tcp/tcp.go +++ b/internal/pmtud/tcp/tcp.go @@ -10,6 +10,29 @@ import ( "github.com/qdm12/gluetun/internal/pmtud/ip" ) +func startRawSockets(families []int, excludeMark int) (familyToSocket map[int]fileDescriptor, stop func(), err error) { + familyToSocket = make(map[int]fileDescriptor, len(families)) + stops := make([]func(), 0, len(families)) + for _, family := range families { + fd, stop, err := startRawSocket(family, excludeMark) + if err != nil { + for _, stop := range stops { + stop() + } + return nil, nil, fmt.Errorf("starting raw socket for family %d: %w", family, err) + } + stops = append(stops, stop) + familyToSocket[family] = fd + } + + stop = func() { + for _, stop := range stops { + stop() + } + } + return familyToSocket, stop, nil +} + func startRawSocket(family, excludeMark int) (fd fileDescriptor, stop func(), err error) { fdPlatform, err := socket(family, constants.SOCK_RAW, constants.IPPROTO_TCP) if err != nil { @@ -33,10 +56,10 @@ func startRawSocket(family, excludeMark int) (fd fileDescriptor, stop func(), er } // Allow sending packets larger than cached PMTU (for PMTUD probing) - err = setMTUDiscovery(fdPlatform) + err = setMTUDiscovery(fdPlatform, family == constants.AF_INET) if err != nil { _ = closeSocket(fdPlatform) - return 0, nil, fmt.Errorf("setting IP_MTU_DISCOVER: %w", err) + return 0, nil, fmt.Errorf("setting MTU discovery options: %w", err) } // use polling because some Linux systems do not cancel diff --git a/internal/pmtud/tcp/tcp_integration_test.go b/internal/pmtud/tcp/tcp_integration_test.go index d3374ec0..9bdebdfe 100644 --- a/internal/pmtud/tcp/tcp_integration_test.go +++ b/internal/pmtud/tcp/tcp_integration_test.go @@ -18,10 +18,21 @@ import ( func Test_PathMTUDiscover(t *testing.T) { t.Parallel() - noopLogger := log.New(log.SetLevel(log.LevelDebug)) + + const tryTimeout = time.Second + deadline, ok := t.Deadline() + if ok { + timeLeft := time.Until(deadline) + const maxTimeNeeded = tryTimeout * 4 // MSS discovery + 3 MTU tries + require.GreaterOrEqual(t, timeLeft, maxTimeNeeded, + "not enough time remaining for TCP PMTUD test, need %s and got %s", + maxTimeNeeded, timeLeft) + } + + logger := log.New(log.SetLevel(log.LevelDebug)) cmder := command.New() - fw, err := firewall.NewConfig(t.Context(), noopLogger, cmder, nil, nil) + fw, err := firewall.NewConfig(t.Context(), logger, cmder, nil, nil) if errors.Is(err, firewall.ErrIPTablesNotSupported) { t.Skip("iptables not installed, skipping TCP PMTUD tests") } @@ -32,11 +43,12 @@ func Test_PathMTUDiscover(t *testing.T) { netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443), netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 53), netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 443), + netip.AddrPortFrom(netip.MustParseAddr("2606:4700:4700::1111"), 443), + netip.AddrPortFrom(netip.MustParseAddr("2001:4860:4860::8888"), 443), } const minMTU = constants.MinIPv6MTU const maxMTU = constants.MaxEthernetFrameSize - const tryTimeout = time.Second - mtu, err := PathMTUDiscover(t.Context(), dsts, minMTU, maxMTU, tryTimeout, fw, noopLogger) + mtu, err := PathMTUDiscover(t.Context(), dsts, minMTU, maxMTU, tryTimeout, fw, logger) require.NoError(t, err, "discovering path MTU") assert.Greater(t, mtu, uint32(0), "MTU should be greater than 0") t.Logf("discovered path MTU is %d", mtu) diff --git a/internal/pmtud/tcp/tcp_linux.go b/internal/pmtud/tcp/tcp_linux.go index b751fe1e..86ad33c4 100644 --- a/internal/pmtud/tcp/tcp_linux.go +++ b/internal/pmtud/tcp/tcp_linux.go @@ -13,6 +13,9 @@ func setMark(fd, excludeMark int) error { return unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_MARK, excludeMark) } -func setMTUDiscovery(fd int) error { - return unix.SetsockoptInt(fd, unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_PROBE) +func setMTUDiscovery(fd int, ipv4 bool) error { + if ipv4 { + return unix.SetsockoptInt(fd, unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_PROBE) + } + return unix.SetsockoptInt(fd, unix.IPPROTO_IPV6, unix.IPV6_MTU_DISCOVER, unix.IPV6_PMTUDISC_PROBE) } diff --git a/internal/pmtud/tcp/tcp_test.go b/internal/pmtud/tcp/tcp_test.go index d644bd39..3099d6d3 100644 --- a/internal/pmtud/tcp/tcp_test.go +++ b/internal/pmtud/tcp/tcp_test.go @@ -26,17 +26,15 @@ func Test_runTest(t *testing.T) { netlinker := netlink.New(noopLogger) loopbackMTU, err := findLoopbackMTU(netlinker) require.NoError(t, err, "finding loopback IPv4 MTU") - defaultIPv4MTU, err := findDefaultIPv4RouteMTU(netlinker) - require.NoError(t, err, "finding default IPv4 route MTU") + defaultMTU, err := findDefaultRouteMTU(netlinker) + require.NoError(t, err, "finding default route MTU") ctx, cancel := context.WithCancel(t.Context()) - const family = constants.AF_INET - fd, stop, err := startRawSocket(family, excludeMark) + familyToFD, stop, err := startRawSockets([]int{constants.AF_INET, constants.AF_INET6}, excludeMark) require.NoError(t, err) - const ipv4 = true - tracker := newTracker(fd, ipv4) + tracker := newTracker(familyToFD) trackerCh := make(chan error) go func() { trackerCh <- tracker.listen(ctx) @@ -71,24 +69,24 @@ func Test_runTest(t *testing.T) { "remote_not_listening": { timeout: 50 * time.Millisecond, server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 12345), - mtu: defaultIPv4MTU - mtuSafetyBuffer, + mtu: defaultMTU - mtuSafetyBuffer, }, "1.1.1.1:443": { timeout: 5 * time.Second, server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443), - mtu: defaultIPv4MTU - mtuSafetyBuffer, + mtu: defaultMTU - mtuSafetyBuffer, success: true, }, "1.1.1.1:80": { timeout: 5 * time.Second, server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 80), - mtu: defaultIPv4MTU - mtuSafetyBuffer, + mtu: defaultMTU - mtuSafetyBuffer, success: true, }, "8.8.8.8:443": { timeout: 5 * time.Second, server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 443), - mtu: defaultIPv4MTU - mtuSafetyBuffer, + mtu: defaultMTU - mtuSafetyBuffer, success: true, }, } @@ -99,6 +97,7 @@ func Test_runTest(t *testing.T) { ctrl := gomock.NewController(t) dst := testCase.server + fd := familyToFD[ip.GetFamily(dst)] const proto = constants.IPPROTO_TCP src, cleanup, err := ip.SrcAddr(dst, proto) diff --git a/internal/pmtud/tcp/tcp_unspecified.go b/internal/pmtud/tcp/tcp_unspecified.go index 8943e9d1..9e90ca9d 100644 --- a/internal/pmtud/tcp/tcp_unspecified.go +++ b/internal/pmtud/tcp/tcp_unspecified.go @@ -6,6 +6,6 @@ func setMark(fd, excludeMark int) error { panic("not implemented") } -func setMTUDiscovery(fd int) error { +func setMTUDiscovery(fd int, ipv4 bool) error { panic("not implemented") } diff --git a/internal/pmtud/tcp/tcp_windows.go b/internal/pmtud/tcp/tcp_windows.go index eec7ff8f..23ae8e29 100644 --- a/internal/pmtud/tcp/tcp_windows.go +++ b/internal/pmtud/tcp/tcp_windows.go @@ -35,7 +35,7 @@ func setMark(fd windows.Handle, _ int) error { panic("not implemented") } -func setMTUDiscovery(fd windows.Handle) error { +func setMTUDiscovery(fd windows.Handle, ipv4 bool) error { panic("not implemented") } diff --git a/internal/pmtud/tcp/tracker.go b/internal/pmtud/tcp/tracker.go index a178b7f0..e83d7a1d 100644 --- a/internal/pmtud/tcp/tracker.go +++ b/internal/pmtud/tcp/tracker.go @@ -12,8 +12,7 @@ import ( ) type tracker struct { - fd fileDescriptor - ipv4 bool + familyToFD map[int]fileDescriptor mutex sync.RWMutex portsToDispatch map[uint32]dispatch } @@ -23,10 +22,9 @@ type dispatch struct { abort <-chan struct{} } -func newTracker(fd fileDescriptor, ipv4 bool) *tracker { +func newTracker(familyToFD map[int]fileDescriptor) *tracker { return &tracker{ - fd: fd, - ipv4: ipv4, + familyToFD: familyToFD, portsToDispatch: make(map[uint32]dispatch), } } @@ -57,11 +55,36 @@ func (t *tracker) unregister(localPort, remotePort uint16) { delete(t.portsToDispatch, key) } -// listen listens for incoming TCP packets and dispatches them to the -// correct channel based on the source and destination port. +func (t *tracker) listen(ctx context.Context) (err error) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + type result struct { + family int + err error + } + resultCh := make(chan result) + for family, fd := range t.familyToFD { + go func(family int, fd fileDescriptor) { + err := t.listenFD(ctx, fd, family == constants.AF_INET) + resultCh <- result{family: family, err: err} + }(family, fd) + } + + for range t.familyToFD { + result := <-resultCh + if err == nil && result.err != nil { + cancel() // stop the other listener if it is still running + err = fmt.Errorf("listening for family %d: %w", result.family, result.err) + } + } + return err +} + +// listenFD listens for incoming TCP packets on the given file descriptor, +// and dispatches them to the correct channel based on the source and destination port. // If the context has a deadline associated, this one is used on the socket. // Note it returns a nil error on context cancellation. -func (t *tracker) listen(ctx context.Context) error { +func (t *tracker) listenFD(ctx context.Context, fd fileDescriptor, ipv4 bool) error { deadline, hasDeadline := ctx.Deadline() for ctx.Err() == nil { if hasDeadline { @@ -69,14 +92,14 @@ func (t *tracker) listen(ctx context.Context) error { if remaining <= 0 { return nil } - err := setSocketTimeout(t.fd, remaining) + err := setSocketTimeout(fd, remaining) if err != nil { return fmt.Errorf("setting socket receive timeout: %w", err) } } reply := make([]byte, constants.MaxEthernetFrameSize) - n, _, err := recvFrom(t.fd, reply, 0) + n, _, err := recvFrom(fd, reply, 0) if err != nil { switch { case errors.Is(err, constants.EAGAIN), errors.Is(err, constants.EWOULDBLOCK): @@ -91,7 +114,7 @@ func (t *tracker) listen(ctx context.Context) error { } reply = reply[:n] - if t.ipv4 { + if ipv4 { var ok bool reply, ok = stripIPv4Header(reply) if !ok {