From 069cde8a855d988558e3b73a686fdeffe066462b Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Sun, 8 Mar 2026 23:27:04 +0000 Subject: [PATCH] hotfix(pmtud): set mss on all VPN routes - fix behavior for OpenVPN splitting default route in multiple routes - fix behavior for Wireguard if user specifies AllowedIPs --- internal/routing/vpn.go | 25 +++++++++++++++---------- internal/vpn/interfaces.go | 2 +- internal/vpn/tunnelup.go | 30 ++++++++++++++++++------------ 3 files changed, 34 insertions(+), 23 deletions(-) diff --git a/internal/routing/vpn.go b/internal/routing/vpn.go index 47e795e1..e1e947da 100644 --- a/internal/routing/vpn.go +++ b/internal/routing/vpn.go @@ -50,23 +50,28 @@ func (r *Routing) VPNLocalGatewayIP(vpnIntf string) (ip netip.Addr, err error) { var ErrVPNRouteNotFound = errors.New("VPN route not found") -func (r *Routing) VPNRoute(vpnIntf string) (route netlink.Route, err error) { +func (r *Routing) VPNRoutes(vpnIntf string) (routes []netlink.Route, err error) { vpnLink, err := r.netLinker.LinkByName(vpnIntf) if err != nil { - return route, fmt.Errorf("finding link %s: %w", vpnIntf, err) + return nil, fmt.Errorf("finding link %s: %w", vpnIntf, err) } vpnLinkIndex := vpnLink.Index - routes, err := r.netLinker.RouteList(netlink.FamilyAll) + allRoutes, err := r.netLinker.RouteList(netlink.FamilyAll) if err != nil { - return route, fmt.Errorf("listing routes: %w", err) + return nil, fmt.Errorf("listing routes: %w", err) } - for _, route := range routes { - if route.LinkIndex == vpnLinkIndex && - !route.Dst.IsValid() { - return route, nil + routes = make([]netlink.Route, 0, len(allRoutes)) + for _, route := range allRoutes { + if route.LinkIndex == vpnLinkIndex { + routes = append(routes, route) } } - return route, fmt.Errorf("%w: for interface %s in %d routes", - ErrVPNRouteNotFound, vpnIntf, len(routes)) + + if len(routes) == 0 { + return nil, fmt.Errorf("%w: for interface %s in %d routes", + ErrVPNRouteNotFound, vpnIntf, len(allRoutes)) + } + + return routes, nil } diff --git a/internal/vpn/interfaces.go b/internal/vpn/interfaces.go index 199ffd51..57d7b0d6 100644 --- a/internal/vpn/interfaces.go +++ b/internal/vpn/interfaces.go @@ -24,7 +24,7 @@ type Firewall interface { type Routing interface { VPNLocalGatewayIP(vpnInterface string) (gateway netip.Addr, err error) - VPNRoute(vpnIntf string) (route netlink.Route, err error) + VPNRoutes(vpnIntf string) (route []netlink.Route, err error) } type PortForward interface { diff --git a/internal/vpn/tunnelup.go b/internal/vpn/tunnelup.go index a1ecff84..186375a1 100644 --- a/internal/vpn/tunnelup.go +++ b/internal/vpn/tunnelup.go @@ -174,9 +174,9 @@ func updateToMaxMTU(ctx context.Context, vpnInterface string, return fmt.Errorf("getting VPN gateway IP address: %w", err) } - vpnRoute, err := routing.VPNRoute(vpnInterface) + vpnRoutes, err := routing.VPNRoutes(vpnInterface) if err != nil { - return fmt.Errorf("getting VPN route: %w", err) + return fmt.Errorf("getting VPN routes: %w", err) } link, err := netlinker.LinkByName(vpnInterface) @@ -208,7 +208,7 @@ func updateToMaxMTU(ctx context.Context, vpnInterface string, logger.Infof("setting VPN interface %s MTU to maximum valid MTU %d", vpnInterface, vpnLinkMTU) } - err = setTCPMSSOnVPNRoute(vpnLinkMTU, vpnRoute, netlinker) + err = setTCPMSSOnVPNRoutes(vpnLinkMTU, vpnRoutes, netlinker) if err != nil { err = fmt.Errorf("setting safe TCP MSS for MTU %d: %w", vpnLinkMTU, err) vpnLinkMTU = originalMTU @@ -224,14 +224,20 @@ func updateToMaxMTU(ctx context.Context, vpnInterface string, return nil } -func setTCPMSSOnVPNRoute(mtu uint32, route netlink.Route, netlinker NetLinker) error { - ipHeaderLength := pconstants.IPv4HeaderLength - if route.Dst.Addr().Is6() { - ipHeaderLength = pconstants.IPv6HeaderLength +func setTCPMSSOnVPNRoutes(mtu uint32, routes []netlink.Route, netlinker NetLinker) error { + for _, route := range routes { + ipHeaderLength := pconstants.IPv4HeaderLength + if route.Dst.Addr().Is6() { + ipHeaderLength = pconstants.IPv6HeaderLength + } + const mysteriousOverhead = 20 // most likely TCP options, such as the 12B of timestamps + overhead := ipHeaderLength + pconstants.BaseTCPHeaderLength + mysteriousOverhead + mss := mtu - overhead + route.AdvMSS = mss + err := netlinker.RouteReplace(route) + if err != nil { + return fmt.Errorf("replacing route %v: %w", route, err) + } } - const mysteriousOverhead = 20 // most likely TCP options, such as the 12B of timestamps - overhead := ipHeaderLength + pconstants.BaseTCPHeaderLength + mysteriousOverhead - mss := mtu - overhead - route.AdvMSS = mss - return netlinker.RouteReplace(route) + return nil }