hotfix(pmtud): detect IPv6 usage in VPN connection

This commit is contained in:
Quentin McGaw
2026-05-09 14:40:04 +00:00
parent 445f99d9dc
commit 5b01324d5f
7 changed files with 69 additions and 24 deletions
+1 -8
View File
@@ -4,7 +4,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/netip" "net/netip"
"strings"
"github.com/jsimonetti/rtnetlink" "github.com/jsimonetti/rtnetlink"
"github.com/qdm12/gluetun/internal/pmtud/constants" "github.com/qdm12/gluetun/internal/pmtud/constants"
@@ -28,10 +27,7 @@ func SrcAddr(dst netip.AddrPort, proto int) (src netip.AddrPort, cleanup func(),
return netip.AddrPortFrom(srcAddr, srcPort), cleanup, nil return netip.AddrPortFrom(srcAddr, srcPort), cleanup, nil
} }
var ( var errNoRoute = errors.New("no route to destination")
errNoRoute = errors.New("no route to destination")
ErrNetworkUnreachable = errors.New("network unreachable")
)
func srcIP(dst netip.Addr) (netip.Addr, error) { func srcIP(dst netip.Addr) (netip.Addr, error) {
conn, err := rtnetlink.Dial(nil) conn, err := rtnetlink.Dial(nil)
@@ -54,9 +50,6 @@ func srcIP(dst netip.Addr) (netip.Addr, error) {
} }
messages, err := conn.Route.Get(requestMessage) messages, err := conn.Route.Get(requestMessage)
if err != nil { if err != nil {
if strings.Contains(err.Error(), "network is unreachable") {
err = ErrNetworkUnreachable
}
return netip.Addr{}, fmt.Errorf("getting routes to %s: %w", dst, err) return netip.Addr{}, fmt.Errorf("getting routes to %s: %w", dst, err)
} }
-3
View File
@@ -43,9 +43,6 @@ func findHighestMSSDestination(ctx context.Context, familyToFD map[int]fileDescr
case err != nil: // error already occurred for another findMSS goroutine case err != nil: // error already occurred for another findMSS goroutine
case errors.Is(result.err, iptables.ErrMarkMatchModuleMissing): case errors.Is(result.err, iptables.ErrMarkMatchModuleMissing):
err = fmt.Errorf("finding MSS for %s: %w", result.dst, result.err) err = fmt.Errorf("finding MSS for %s: %w", result.dst, result.err)
case dst.Addr().Is6() && errors.Is(result.err, ip.ErrNetworkUnreachable):
// silently discard IPv6 network unreachable errors since they are common
// and expected when the host doesn't have IPv6 connectivity
default: // another error not due to the match module missing default: // another error not due to the match module missing
logger.Debugf("finding MSS for %s failed: %s", result.dst, result.err) logger.Debugf("finding MSS for %s failed: %s", result.dst, result.err)
} }
+3 -5
View File
@@ -1,21 +1,19 @@
package pmtud package pmtud
import ( import (
"net/netip"
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/constants/vpn" "github.com/qdm12/gluetun/internal/constants/vpn"
pconstants "github.com/qdm12/gluetun/internal/pmtud/constants" pconstants "github.com/qdm12/gluetun/internal/pmtud/constants"
) )
// MaxTheoreticalVPNMTU returns the theoretical maximum MTU for a VPN tunnel // MaxTheoreticalVPNMTU returns the theoretical maximum MTU for a VPN tunnel
// given the VPN type, network protocol, and VPN gateway IP address. // given the VPN type, network protocol, and whether IPv6 is used.
// This is notably useful to skip testing MTU values higher than this value. // This is notably useful to skip testing MTU values higher than this value.
// The function panics if the network or VPN type is unknown. // The function panics if the network or VPN type is unknown.
func MaxTheoreticalVPNMTU(vpnType, network string, vpnGateway netip.Addr) uint32 { func MaxTheoreticalVPNMTU(vpnType, network string, ipv6 bool) uint32 {
const physicalLinkMTU = pconstants.MaxEthernetFrameSize const physicalLinkMTU = pconstants.MaxEthernetFrameSize
vpnLinkMTU := physicalLinkMTU vpnLinkMTU := physicalLinkMTU
if vpnGateway.Is4() { if !ipv6 {
vpnLinkMTU -= pconstants.IPv4HeaderLength vpnLinkMTU -= pconstants.IPv4HeaderLength
} else { } else {
vpnLinkMTU -= pconstants.IPv6HeaderLength vpnLinkMTU -= pconstants.IPv6HeaderLength
+1
View File
@@ -61,6 +61,7 @@ type Storage interface {
} }
type NetLinker interface { type NetLinker interface {
AddrList(linkIndex uint32, family uint8) (addresses []netip.Prefix, err error)
AddrReplace(linkIndex uint32, addr netip.Prefix) error AddrReplace(linkIndex uint32, addr netip.Prefix) error
Router Router
Ruler Ruler
+48
View File
@@ -0,0 +1,48 @@
package vpn
import (
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/vpn"
"github.com/qdm12/gluetun/internal/netlink"
)
func (l *Loop) isIPv6Used(settings settings.VPN) bool {
if !l.ipv6SupportLevel.IsSupported() {
return false
}
switch settings.Type {
case vpn.AmneziaWg:
for _, prefix := range settings.AmneziaWg.Wireguard.Addresses {
if prefix.Addr().Is6() {
return true
}
}
return false
case vpn.OpenVPN:
link, err := l.netLinker.LinkByName(settings.OpenVPN.Interface)
if err != nil {
l.logger.Warnf("assuming IPv6 is not supported, cannot get OpenVPN link by name: %v", err)
return false
}
ipv6Prefixes, err := l.netLinker.AddrList(link.Index, netlink.FamilyV6)
if err != nil {
l.logger.Warnf("assuming IPv6 is not supported, cannot list OpenVPN link addresses: %v", err)
return false
}
for _, prefix := range ipv6Prefixes {
if prefix.Addr().IsGlobalUnicast() && !prefix.Addr().IsPrivate() {
return true
}
}
return false
case vpn.Wireguard:
for _, prefix := range settings.Wireguard.Addresses {
if prefix.Addr().Is6() {
return true
}
}
return false
default:
panic("vpn type not implemented: " + settings.Type)
}
}
+1
View File
@@ -59,6 +59,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
enabled: settings.Type != vpn.Wireguard || *settings.Wireguard.MTU == 0, enabled: settings.Type != vpn.Wireguard || *settings.Wireguard.MTU == 0,
vpnType: settings.Type, vpnType: settings.Type,
network: connection.Protocol, network: connection.Protocol,
ipv6: l.isIPv6Used(settings),
icmpAddrs: settings.PMTUD.ICMPAddresses, icmpAddrs: settings.PMTUD.ICMPAddresses,
tcpAddrs: settings.PMTUD.TCPAddresses, tcpAddrs: settings.PMTUD.TCPAddresses,
}, },
+15 -8
View File
@@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"net/netip" "net/netip"
"slices"
"strings" "strings"
"time" "time"
@@ -40,6 +41,8 @@ type tunnelUpPMTUDData struct {
// network is used to find the network level header overhead. // network is used to find the network level header overhead.
// It can be [constants.UDP] or [constants.TCP]. // It can be [constants.UDP] or [constants.TCP].
network string network string
// ipv6 is true if the VPN connection supports IPv6.
ipv6 bool
// icmpAddrs is the list of addresses to use for ICMP path MTU discovery. // icmpAddrs is the list of addresses to use for ICMP path MTU discovery.
// Each address should handle ICMP packets for PMTUD to work. // Each address should handle ICMP packets for PMTUD to work.
icmpAddrs []netip.Addr icmpAddrs []netip.Addr
@@ -69,7 +72,7 @@ func (l *Loop) onTunnelUp(ctx, loopCtx context.Context, data tunnelUpData) {
if data.pmtud.enabled { if data.pmtud.enabled {
mtuLogger := l.logger.New(log.SetComponent("MTU discovery")) mtuLogger := l.logger.New(log.SetComponent("MTU discovery"))
err := updateToMaxMTU(ctx, data.vpnIntf, data.pmtud.vpnType, err := updateToMaxMTU(ctx, data.vpnIntf, data.pmtud.vpnType,
data.pmtud.network, data.pmtud.icmpAddrs, data.pmtud.tcpAddrs, data.pmtud.network, data.pmtud.ipv6, data.pmtud.icmpAddrs, data.pmtud.tcpAddrs,
l.netLinker, l.routing, l.fw, mtuLogger) l.netLinker, l.routing, l.fw, mtuLogger)
if err != nil { if err != nil {
mtuLogger.Error(err.Error()) mtuLogger.Error(err.Error())
@@ -173,16 +176,11 @@ func (l *Loop) restartVPN(ctx context.Context, healthErr error) {
} }
func updateToMaxMTU(ctx context.Context, vpnInterface string, func updateToMaxMTU(ctx context.Context, vpnInterface string,
vpnType, network string, icmpAddrs []netip.Addr, tcpAddrs []netip.AddrPort, vpnType, network string, ipv6 bool, icmpAddrs []netip.Addr, tcpAddrs []netip.AddrPort,
netlinker NetLinker, routing Routing, firewall tcp.Firewall, logger *log.Logger, netlinker NetLinker, routing Routing, firewall tcp.Firewall, logger *log.Logger,
) error { ) error {
logger.Info("finding maximum MTU, this can take up to 6 seconds") logger.Info("finding maximum MTU, this can take up to 6 seconds")
vpnGatewayIP, err := routing.VPNLocalGatewayIP(vpnInterface)
if err != nil {
return fmt.Errorf("getting VPN gateway IP address: %w", err)
}
vpnRoutes, err := routing.VPNRoutes(vpnInterface) vpnRoutes, err := routing.VPNRoutes(vpnInterface)
if err != nil { if err != nil {
return fmt.Errorf("getting VPN routes: %w", err) return fmt.Errorf("getting VPN routes: %w", err)
@@ -195,7 +193,7 @@ func updateToMaxMTU(ctx context.Context, vpnInterface string,
originalMTU := link.MTU originalMTU := link.MTU
vpnLinkMTU := pmtud.MaxTheoreticalVPNMTU(vpnType, network, vpnGatewayIP) vpnLinkMTU := pmtud.MaxTheoreticalVPNMTU(vpnType, network, ipv6)
// Setting the VPN link MTU to 1500 might interrupt the connection until // Setting the VPN link MTU to 1500 might interrupt the connection until
// the new MTU is set again, but this is necessary to find the highest valid MTU. // the new MTU is set again, but this is necessary to find the highest valid MTU.
@@ -206,6 +204,15 @@ func updateToMaxMTU(ctx context.Context, vpnInterface string,
return fmt.Errorf("setting VPN interface %s MTU to %d: %w", vpnInterface, vpnLinkMTU, err) return fmt.Errorf("setting VPN interface %s MTU to %d: %w", vpnInterface, vpnLinkMTU, err)
} }
if !ipv6 {
icmpAddrs = slices.DeleteFunc(icmpAddrs, func(addr netip.Addr) bool {
return addr.Is6()
})
tcpAddrs = slices.DeleteFunc(tcpAddrs, func(addr netip.AddrPort) bool {
return addr.Addr().Is6()
})
}
const pingTimeout = time.Second const pingTimeout = time.Second
vpnLinkMTU, err = pmtud.PathMTUDiscover(ctx, icmpAddrs, tcpAddrs, vpnLinkMTU, err = pmtud.PathMTUDiscover(ctx, icmpAddrs, tcpAddrs,
vpnLinkMTU, pingTimeout, firewall, logger) vpnLinkMTU, pingTimeout, firewall, logger)