Path MTU discovery fixes and improvements (#3109)

- Existing option `WIREGUARD_MTU` , if set, disables PMTUD and is used
- New option `PMTUD_ICMP_ADDRESSES=1.1.1.1,8.8.8.8` and `PMTUD_TCP_ADDRESSES=1.1.1.1:443,8.8.8.8:443`
- ICMP PMTUD now targets external-by-default IP addresses
- New TCP PMTUD (binary search only) as a second MTU confirmation and fallback mechanism.
- Force set TCP MSS to MTU - IP header - TCP base header - "magic 20 bytes" 🎆
- Fix #3108
This commit is contained in:
Quentin McGaw
2026-02-15 01:40:34 +01:00
committed by GitHub
parent 8f1fda7646
commit be92aa2ac4
59 changed files with 2050 additions and 376 deletions
+1 -1
View File
@@ -1,2 +1,2 @@
FROM ghcr.io/qdm12/godevcontainer:v0.21-alpine
RUN apk add wireguard-tools htop openssl
RUN apk add wireguard-tools htop openssl tcpdump
+1
View File
@@ -22,6 +22,7 @@ linters:
- "^disabled$"
# Firewall and routing strings
- "^(ACCEPT|DROP)$"
- "^--append$"
- "^--delete$"
- "^all$"
- "^(tcp|udp)$"
+4 -1
View File
@@ -110,8 +110,11 @@ ENV VPN_SERVICE_PROVIDER=pia \
WIREGUARD_PERSISTENT_KEEPALIVE_INTERVAL=0 \
WIREGUARD_ADDRESSES= \
WIREGUARD_ADDRESSES_SECRETFILE=/run/secrets/wireguard_addresses \
WIREGUARD_MTU=1320 \
WIREGUARD_MTU= \
WIREGUARD_IMPLEMENTATION=auto \
# PMTUD
PMTUD_ICMP_ADDRESSES=1.1.1.1,8.8.8.8 \
PMTUD_TCP_ADDRESSES=1.1.1.1:443,8.8.8.8:443 \
# VPN server filtering
SERVER_REGIONS= \
SERVER_COUNTRIES= \
+108
View File
@@ -0,0 +1,108 @@
package settings
import (
"errors"
"fmt"
"net/netip"
"strings"
"github.com/qdm12/gosettings"
"github.com/qdm12/gosettings/reader"
"github.com/qdm12/gotree"
)
// PMTUD contains settings to configure Path MTU Discovery.
type PMTUD struct {
// ICMPAddresses is the redundancy list of addresses to use
// for ICMP path MTU discovery. Each address MUST handle ICMP
// packets for PMTUD to work.
// It cannot be nil in the internal state.
ICMPAddresses []netip.Addr `json:"icmp_addresses"`
// TCPAddresses is the redundancy list of addresses to use
// for TCP path MTU discovery. Each address MUST have a listening
// TCP server on the port specified.
// It cannot be nil in the internal state.
TCPAddresses []netip.AddrPort `json:"tcp_addresses"`
}
var (
ErrPMTUDICMPAddressNotValid = errors.New("PMTUD ICMP address is not valid")
ErrPMTUDTCPAddressNotValid = errors.New("PMTUD TCP address is not valid")
)
// Validate validates PMTUD settings.
func (p PMTUD) validate() (err error) {
for i, addr := range p.ICMPAddresses {
if !addr.IsValid() {
return fmt.Errorf("%w: at index %d", ErrPMTUDICMPAddressNotValid, i)
}
}
for i, addr := range p.TCPAddresses {
if !addr.IsValid() {
return fmt.Errorf("%w: at index %d", ErrPMTUDTCPAddressNotValid, i)
}
}
return nil
}
func (p *PMTUD) copy() (copied PMTUD) {
return PMTUD{
ICMPAddresses: gosettings.CopySlice(p.ICMPAddresses),
TCPAddresses: gosettings.CopySlice(p.TCPAddresses),
}
}
func (p *PMTUD) overrideWith(other PMTUD) {
p.ICMPAddresses = gosettings.OverrideWithSlice(p.ICMPAddresses, other.ICMPAddresses)
p.TCPAddresses = gosettings.OverrideWithSlice(p.TCPAddresses, other.TCPAddresses)
}
func (p *PMTUD) setDefaults() {
defaultICMPAddresses := []netip.Addr{
netip.AddrFrom4([4]byte{1, 1, 1, 1}),
netip.AddrFrom4([4]byte{8, 8, 8, 8}),
}
p.ICMPAddresses = gosettings.DefaultSlice(p.ICMPAddresses, defaultICMPAddresses)
const tlsPort = 443
defaultTCPAddresses := []netip.AddrPort{
netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), tlsPort),
netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), tlsPort),
}
p.TCPAddresses = gosettings.DefaultSlice(p.TCPAddresses, defaultTCPAddresses)
}
func (p PMTUD) String() string {
return p.toLinesNode().String()
}
func (p PMTUD) toLinesNode() (node *gotree.Node) {
node = gotree.New("Path MTU discovery:")
addrs := make([]string, len(p.ICMPAddresses))
for i, addr := range p.ICMPAddresses {
addrs[i] = addr.String()
}
node.Appendf("ICMP addresses: %s", strings.Join(addrs, ", "))
addrs = make([]string, len(p.TCPAddresses))
for i, addr := range p.TCPAddresses {
addrs[i] = addr.String()
}
node.Appendf("TCP addresses: %s", strings.Join(addrs, ", "))
return node
}
func (p *PMTUD) read(r *reader.Reader) (err error) {
p.ICMPAddresses, err = r.CSVNetipAddresses("PMTUD_ICMP_ADDRESSES")
if err != nil {
return err
}
p.TCPAddresses, err = r.CSVNetipAddrPorts("PMTUD_TCP_ADDRESSES")
if err != nil {
return err
}
return nil
}
@@ -29,14 +29,17 @@ func Test_Settings_String(t *testing.T) {
| | └── OpenVPN server selection settings:
| | ├── Protocol: UDP
| | └── Private Internet Access encryption preset: strong
| ── OpenVPN settings:
| ├── OpenVPN version: 2.6
| ├── User: [not set]
| ├── Password: [not set]
| ├── Private Internet Access encryption preset: strong
| ├── Network interface: tun0
| ├── Run OpenVPN as: root
| └── Verbosity level: 1
| ── OpenVPN settings:
| | ├── OpenVPN version: 2.6
| | ├── User: [not set]
| | ├── Password: [not set]
| | ├── Private Internet Access encryption preset: strong
| | ├── Network interface: tun0
| | ├── Run OpenVPN as: root
| | └── Verbosity level: 1
| └── Path MTU discovery:
| ├── ICMP addresses: 1.1.1.1, 8.8.8.8
| └── TCP addresses: 1.1.1.1:443, 8.8.8.8:443
├── DNS settings:
| ├── Keep existing nameserver(s): no
| ├── DNS server address to use: 127.0.0.1
+15
View File
@@ -18,6 +18,7 @@ type VPN struct {
Provider Provider `json:"provider"`
OpenVPN OpenVPN `json:"openvpn"`
Wireguard Wireguard `json:"wireguard"`
PMTUD PMTUD `json:"pmtud"`
}
// TODO v4 remove pointer for receiver (because of Surfshark).
@@ -45,6 +46,11 @@ func (v *VPN) Validate(filterChoicesGetter FilterChoicesGetter, ipv6Supported bo
}
}
err = v.PMTUD.validate()
if err != nil {
return fmt.Errorf("PMTUD settings: %w", err)
}
return nil
}
@@ -54,6 +60,7 @@ func (v *VPN) Copy() (copied VPN) {
Provider: v.Provider.copy(),
OpenVPN: v.OpenVPN.copy(),
Wireguard: v.Wireguard.copy(),
PMTUD: v.PMTUD.copy(),
}
}
@@ -62,6 +69,7 @@ func (v *VPN) OverrideWith(other VPN) {
v.Provider.overrideWith(other.Provider)
v.OpenVPN.overrideWith(other.OpenVPN)
v.Wireguard.overrideWith(other.Wireguard)
v.PMTUD.overrideWith(other.PMTUD)
}
func (v *VPN) setDefaults() {
@@ -69,6 +77,7 @@ func (v *VPN) setDefaults() {
v.Provider.setDefaults()
v.OpenVPN.setDefaults(v.Provider.Name)
v.Wireguard.setDefaults(v.Provider.Name)
v.PMTUD.setDefaults()
}
func (v VPN) String() string {
@@ -85,6 +94,7 @@ func (v VPN) toLinesNode() (node *gotree.Node) {
} else {
node.AppendNode(v.Wireguard.toLinesNode())
}
node.AppendNode(v.PMTUD.toLinesNode())
return node
}
@@ -107,5 +117,10 @@ func (v *VPN) read(r *reader.Reader) (err error) {
return fmt.Errorf("wireguard: %w", err)
}
err = v.PMTUD.read(r)
if err != nil {
return fmt.Errorf("PMTUD: %w", err)
}
return nil
}
+10 -15
View File
@@ -38,15 +38,9 @@ type Wireguard struct {
Interface string `json:"interface"`
PersistentKeepaliveInterval *time.Duration `json:"persistent_keep_alive_interval"`
// Maximum Transmission Unit (MTU) of the Wireguard interface.
// It cannot be zero in the internal state, and defaults to
// 1320. Note it is not the wireguard-go MTU default of 1420
// because this impacts bandwidth a lot on some VPN providers,
// see https://github.com/qdm12/gluetun/issues/1650.
// It has been lowered to 1320 following quite a bit of
// investigation in the issue:
// https://github.com/qdm12/gluetun/issues/2533.
// Note this should now be replaced with the PMTUD feature.
MTU uint32 `json:"mtu"`
// It cannot be nil in the internal state, and defaults to
// 0 indicating to use PMTUD.
MTU *uint32 `json:"mtu"`
// Implementation is the Wireguard implementation to use.
// It can be "auto", "userspace" or "kernelspace".
// It defaults to "auto" and cannot be the empty string
@@ -195,8 +189,7 @@ func (w *Wireguard) setDefaults(vpnProvider string) {
w.AllowedIPs = gosettings.DefaultSlice(w.AllowedIPs, defaultAllowedIPs)
w.PersistentKeepaliveInterval = gosettings.DefaultPointer(w.PersistentKeepaliveInterval, 0)
w.Interface = gosettings.DefaultComparable(w.Interface, "wg0")
const defaultMTU = 1320
w.MTU = gosettings.DefaultComparable(w.MTU, defaultMTU)
w.MTU = gosettings.DefaultPointer(w.MTU, 0)
w.Implementation = gosettings.DefaultComparable(w.Implementation, "auto")
}
@@ -232,7 +225,11 @@ func (w Wireguard) toLinesNode() (node *gotree.Node) {
}
interfaceNode := node.Appendf("Network interface: %s", w.Interface)
interfaceNode.Appendf("MTU: %d", w.MTU)
if *w.MTU == 0 {
interfaceNode.Append("MTU: use path MTU discovery")
} else {
interfaceNode.Appendf("MTU: %d", *w.MTU)
}
if w.Implementation != "auto" {
node.Appendf("Implementation: %s", w.Implementation)
@@ -273,11 +270,9 @@ func (w *Wireguard) read(r *reader.Reader) (err error) {
return err
}
mtuPtr, err := r.Uint32Ptr("WIREGUARD_MTU")
w.MTU, err = r.Uint32Ptr("WIREGUARD_MTU")
if err != nil {
return err
} else if mtuPtr != nil {
w.MTU = *mtuPtr
}
return nil
}
+17 -14
View File
@@ -29,17 +29,16 @@ func appendOrDelete(remove bool) string {
// flipRule changes an append rule in a delete rule or a delete rule into an
// append rule.
func flipRule(rule string) string {
switch {
case strings.HasPrefix(rule, "-A"):
return strings.Replace(rule, "-A", "-D", 1)
case strings.HasPrefix(rule, "--append"):
return strings.Replace(rule, "--append", "-D", 1)
case strings.HasPrefix(rule, "-D"):
return strings.Replace(rule, "-D", "-A", 1)
case strings.HasPrefix(rule, "--delete"):
return strings.Replace(rule, "--delete", "-A", 1)
fields := strings.Fields(rule)
for i, field := range fields {
switch field {
case "-A", "--append":
fields[i] = "--delete"
case "-D", "--delete":
fields[i] = "--append"
}
}
return rule
return strings.Join(fields, " ")
}
// Version obtains the version of the installed iptables.
@@ -86,10 +85,14 @@ func (c *Config) runIptablesInstruction(ctx context.Context, instruction string)
}
func (c *Config) clearAllRules(ctx context.Context) error {
return c.runMixedIptablesInstructions(ctx, []string{
"--flush", // flush all chains
"--delete-chain", // delete all chains
})
tables := []string{"filter"}
for _, table := range tables {
return c.runMixedIptablesInstructions(ctx, []string{
"-t " + table + " --flush", // flush all chains
"-t " + table + " --delete-chain", // delete all chains
})
}
return nil
}
func (c *Config) setIPv4AllPolicies(ctx context.Context, policy string) error {
+13 -1
View File
@@ -18,6 +18,7 @@ type Route struct {
Type uint8
Scope uint8
Proto uint8
AdvMSS uint32
}
func (r *Route) fromMessage(message rtnetlink.RouteMessage) {
@@ -35,6 +36,9 @@ func (r *Route) fromMessage(message rtnetlink.RouteMessage) {
r.Type = message.Type
r.Scope = message.Scope
r.Proto = message.Protocol
if metrics := message.Attributes.Metrics; metrics != nil {
r.AdvMSS = metrics.AdvMSS
}
}
func (r Route) message() *rtnetlink.RouteMessage {
@@ -58,7 +62,6 @@ func (r Route) message() *rtnetlink.RouteMessage {
Protocol: r.Proto,
Attributes: rtnetlink.RouteAttributes{
OutIface: r.LinkIndex,
Dst: *dst, // there should always be a dst for routes
Gateway: netipAddrToNetIP(r.Gw),
Priority: r.Priority,
Table: extendedTable,
@@ -67,6 +70,15 @@ func (r Route) message() *rtnetlink.RouteMessage {
if src != nil { // src is optional
message.Attributes.Src = *src
}
if dst != nil {
message.Attributes.Dst = *dst
}
if r.AdvMSS != 0 {
if message.Attributes.Metrics == nil {
message.Attributes.Metrics = &rtnetlink.RouteMetrics{}
}
message.Attributes.Metrics.AdvMSS = r.AdvMSS
}
return message
}
+24
View File
@@ -0,0 +1,24 @@
package constants
const (
MaxEthernetFrameSize uint32 = 1500
// MinIPv4MTU is defined according to
// https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media
MinIPv4MTU uint32 = 68
MinIPv6MTU uint32 = 1280
IPv4HeaderLength uint32 = 20
IPv6HeaderLength uint32 = 40
UDPHeaderLength uint32 = 8
// BaseTCPHeaderLength is the TCP header length without options,
// which is the minimum TCP header length.
BaseTCPHeaderLength uint32 = 20
// MaxTCPHeaderLength is the TCP header length with the maximum options length of 40 bytes.
// Note this is a hard maximum because of the 4-bit data offset field in the TCP header (15x4=60).
MaxTCPHeaderLength uint32 = 60
WireguardHeaderLength uint32 = 32
OpenVPNHeaderMaxLength uint32 = 1 + // opcode
8 + // session id
4 + // packet id
28 // max possible auth tag/iv
)
-29
View File
@@ -1,29 +0,0 @@
package pmtud
import (
"context"
"errors"
"fmt"
"net"
"strings"
"time"
)
var (
ErrICMPNotPermitted = errors.New("ICMP not permitted")
ErrICMPDestinationUnreachable = errors.New("ICMP destination unreachable")
ErrICMPCommunicationAdministrativelyProhibited = errors.New("communication administratively prohibited")
ErrICMPBodyUnsupported = errors.New("ICMP body type is not supported")
)
func wrapConnErr(err error, timedCtx context.Context, pingTimeout time.Duration) error { //nolint:revive
switch {
case strings.HasSuffix(err.Error(), "sendto: operation not permitted"):
err = fmt.Errorf("%w", ErrICMPNotPermitted)
case errors.Is(timedCtx.Err(), context.DeadlineExceeded):
err = fmt.Errorf("%w (timed out after %s)", net.ErrClosed, pingTimeout)
case timedCtx.Err() != nil:
err = timedCtx.Err()
}
return err
}
@@ -1,4 +1,4 @@
package pmtud
package icmp
import (
"net"
@@ -1,4 +1,4 @@
package pmtud
package icmp
import (
"bytes"
@@ -9,17 +9,17 @@ import (
)
var (
ErrICMPNextHopMTUTooLow = errors.New("ICMP Next Hop MTU is too low")
ErrICMPNextHopMTUTooHigh = errors.New("ICMP Next Hop MTU is too high")
ErrNextHopMTUTooLow = errors.New("ICMP Next Hop MTU is too low")
ErrNextHopMTUTooHigh = errors.New("ICMP Next Hop MTU is too high")
)
func checkMTU(mtu, minMTU, physicalLinkMTU uint32) (err error) {
switch {
case mtu < minMTU:
return fmt.Errorf("%w: %d", ErrICMPNextHopMTUTooLow, mtu)
return fmt.Errorf("%w: %d", ErrNextHopMTUTooLow, mtu)
case mtu > physicalLinkMTU:
return fmt.Errorf("%w: %d is larger than physical link MTU %d",
ErrICMPNextHopMTUTooHigh, mtu, physicalLinkMTU)
ErrNextHopMTUTooHigh, mtu, physicalLinkMTU)
default:
return nil
}
@@ -34,13 +34,13 @@ func checkInvokingReplyIDMatch(icmpProtocol int, received []byte,
}
inboundBody, ok := inboundMessage.Body.(*icmp.Echo)
if !ok {
return false, fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, inboundMessage.Body)
return false, fmt.Errorf("%w: %T", ErrBodyUnsupported, inboundMessage.Body)
}
outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert
return inboundBody.ID == outboundBody.ID, nil
}
var ErrICMPIDMismatch = errors.New("ICMP id mismatch")
var ErrIDMismatch = errors.New("ICMP id mismatch")
func checkEchoReply(icmpProtocol int, received []byte,
outboundMessage *icmp.Message, truncatedBody bool,
@@ -51,12 +51,12 @@ func checkEchoReply(icmpProtocol int, received []byte,
}
inboundBody, ok := inboundMessage.Body.(*icmp.Echo)
if !ok {
return fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, inboundMessage.Body)
return fmt.Errorf("%w: %T", ErrBodyUnsupported, inboundMessage.Body)
}
outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert
if inboundBody.ID != outboundBody.ID {
return fmt.Errorf("%w: sent id %d and received id %d",
ErrICMPIDMismatch, outboundBody.ID, inboundBody.ID)
ErrIDMismatch, outboundBody.ID, inboundBody.ID)
}
err = checkEchoBodies(outboundBody.Data, inboundBody.Data, truncatedBody)
if err != nil {
@@ -65,19 +65,19 @@ func checkEchoReply(icmpProtocol int, received []byte,
return nil
}
var ErrICMPEchoDataMismatch = errors.New("ICMP data mismatch")
var ErrEchoDataMismatch = errors.New("ICMP data mismatch")
func checkEchoBodies(sent, received []byte, receivedTruncated bool) (err error) {
if len(received) > len(sent) {
return fmt.Errorf("%w: sent %d bytes and received %d bytes",
ErrICMPEchoDataMismatch, len(sent), len(received))
ErrEchoDataMismatch, len(sent), len(received))
}
if receivedTruncated {
sent = sent[:len(received)]
}
if !bytes.Equal(received, sent) {
return fmt.Errorf("%w: sent %x and received %x",
ErrICMPEchoDataMismatch, sent, received)
ErrEchoDataMismatch, sent, received)
}
return nil
}
@@ -1,6 +1,6 @@
//go:build !linux && !windows
package pmtud
package icmp
// setDontFragment for platforms other than Linux and Windows
// is not implemented, so we just return assuming the don't
@@ -1,4 +1,4 @@
package pmtud
package icmp
import (
"syscall"
@@ -1,6 +1,4 @@
//go:build windows
package pmtud
package icmp
import (
"syscall"
+30
View File
@@ -0,0 +1,30 @@
package icmp
import (
"context"
"errors"
"fmt"
"net"
"strings"
"time"
)
var (
ErrNotPermitted = errors.New("ICMP not permitted")
ErrDestinationUnreachable = errors.New("ICMP destination unreachable")
ErrCommunicationAdministrativelyProhibited = errors.New("communication administratively prohibited")
ErrBodyUnsupported = errors.New("ICMP body type is not supported")
ErrMTUNotFound = errors.New("MTU not found")
)
func wrapConnErr(err error, timedCtx context.Context, pingTimeout time.Duration) error { //nolint:revive
switch {
case strings.HasSuffix(err.Error(), "sendto: operation not permitted"):
err = fmt.Errorf("%w", ErrNotPermitted)
case errors.Is(timedCtx.Err(), context.DeadlineExceeded):
err = fmt.Errorf("%w (timed out after %s)", net.ErrClosed, pingTimeout)
case timedCtx.Err() != nil:
err = timedCtx.Err()
}
return err
}
+53
View File
@@ -0,0 +1,53 @@
package icmp
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"time"
"github.com/qdm12/gluetun/internal/pmtud/constants"
)
// PathMTUDiscover discovers the path MTU to the given IP address
// using ICMP.
// It first tries to get the next hop MTU using ICMP messages.
// If that fails, it falls back to sending echo requests with
// different packet sizes to find the maximum MTU.
// The function returns [ErrMTUNotFound] if the MTU could not be determined.
func PathMTUDiscover(ctx context.Context, ip netip.Addr,
physicalLinkMTU uint32, timeout time.Duration, logger Logger,
) (mtu uint32, err error) {
if ip.Is4() {
logger.Debug("finding IPv4 next hop MTU")
mtu, err = findIPv4NextHopMTU(ctx, ip, physicalLinkMTU, timeout, logger)
switch {
case err == nil:
return mtu, nil
case errors.Is(err, net.ErrClosed) || errors.Is(err, ErrCommunicationAdministrativelyProhibited): // blackhole
default:
return 0, fmt.Errorf("finding IPv4 next hop MTU: %w", err)
}
} else {
logger.Debug("requesting IPv6 ICMP packet-too-big reply")
mtu, err = getIPv6PacketTooBig(ctx, ip, physicalLinkMTU, timeout, logger)
switch {
case err == nil:
return mtu, nil
case errors.Is(err, net.ErrClosed): // blackhole
default:
return 0, fmt.Errorf("getting IPv6 packet-too-big message: %w", err)
}
}
// Fall back method: send echo requests with different packet
// sizes and check which ones succeed to find the maximum MTU.
logger.Debug("falling back to sending different sized echo packets")
minMTU := constants.MinIPv4MTU
if ip.Is6() {
minMTU = constants.MinIPv6MTU
}
return pmtudMultiSizes(ctx, ip, minMTU, physicalLinkMTU, timeout, logger)
}
+7
View File
@@ -0,0 +1,7 @@
package icmp
type Logger interface {
Debug(msg string)
Debugf(msg string, args ...any)
Warnf(msg string, args ...any)
}
@@ -1,4 +1,4 @@
package pmtud
package icmp
import (
"context"
@@ -11,14 +11,13 @@ import (
"syscall"
"time"
"github.com/qdm12/gluetun/internal/pmtud/constants"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
)
const (
// see https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media
minIPv4MTU uint32 = 68
icmpv4Protocol int = 1
icmpv4Protocol = 1
)
func listenICMPv4(ctx context.Context) (conn net.PacketConn, err error) {
@@ -38,7 +37,7 @@ func listenICMPv4(ctx context.Context) (conn net.PacketConn, err error) {
packetConn, err := listenConfig.ListenPacket(ctx, "ip4:icmp", listenAddress)
if err != nil {
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrICMPNotPermitted)
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrNotPermitted)
}
return nil, err
}
@@ -83,7 +82,9 @@ func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr,
buffer := make([]byte, physicalLinkMTU)
for { // for loop in case we read an echo reply for another ICMP request
// for loop in case we read an ICMP message from another ICMP request
// or TCP/UDP traffic triggering an ICMP response.
for {
// Note we need to read the whole packet in one call to ReadFrom, so the buffer
// must be large enough to read the entire reply packet. See:
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
@@ -108,24 +109,27 @@ func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr,
switch typedBody := inboundMessage.Body.(type) {
case *icmp.DstUnreach:
const fragmentationRequiredAndDFFlagSetCode = 4
const portUnreachable = 3
const communicationAdministrativelyProhibitedCode = 13
switch inboundMessage.Code {
case fragmentationRequiredAndDFFlagSetCode:
case portUnreachable: // triggered by TCP or UDP from applications
continue // ignore and wait for the next message
case communicationAdministrativelyProhibitedCode:
return 0, fmt.Errorf("%w: %w (code %d)",
ErrICMPDestinationUnreachable,
ErrICMPCommunicationAdministrativelyProhibited,
ErrDestinationUnreachable,
ErrCommunicationAdministrativelyProhibited,
inboundMessage.Code)
default:
return 0, fmt.Errorf("%w: code %d",
ErrICMPDestinationUnreachable, inboundMessage.Code)
ErrDestinationUnreachable, inboundMessage.Code)
}
// See https://datatracker.ietf.org/doc/html/rfc1191#section-4
// Note: the go library does not handle this NextHopMTU section.
nextHopMTU := packetBytes[6:8]
mtu = uint32(binary.BigEndian.Uint16(nextHopMTU))
err = checkMTU(mtu, minIPv4MTU, physicalLinkMTU)
err = checkMTU(mtu, constants.MinIPv4MTU, physicalLinkMTU)
if err != nil {
return 0, fmt.Errorf("checking next-hop-mtu found: %w", err)
}
@@ -153,7 +157,7 @@ func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr,
inboundID, outboundID)
continue
default:
return 0, fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, typedBody)
return 0, fmt.Errorf("%w: %T", ErrBodyUnsupported, typedBody)
}
}
}
@@ -1,4 +1,4 @@
package pmtud
package icmp
import (
"context"
@@ -8,12 +8,12 @@ import (
"strings"
"time"
"github.com/qdm12/gluetun/internal/pmtud/constants"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv6"
)
const (
minIPv6MTU = 1280
icmpv6Protocol = 58
)
@@ -23,7 +23,7 @@ func listenICMPv6(ctx context.Context) (conn net.PacketConn, err error) {
packetConn, err := listenConfig.ListenPacket(ctx, "ip6:ipv6-icmp", listenAddress)
if err != nil {
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrICMPNotPermitted)
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrNotPermitted)
}
return nil, err
}
@@ -85,7 +85,7 @@ func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr,
case *icmp.PacketTooBig:
// https://datatracker.ietf.org/doc/html/rfc1885#section-3.2
mtu = uint32(typedBody.MTU) //nolint:gosec
err = checkMTU(mtu, minIPv6MTU, physicalLinkMTU)
err = checkMTU(mtu, constants.MinIPv6MTU, physicalLinkMTU)
if err != nil {
return 0, fmt.Errorf("checking MTU: %w", err)
}
@@ -103,7 +103,7 @@ func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr,
if err != nil {
return 0, fmt.Errorf("checking invoking message id: %w", err)
} else if idMatch {
return 0, fmt.Errorf("%w", ErrICMPDestinationUnreachable)
return 0, fmt.Errorf("%w", ErrDestinationUnreachable)
}
logger.Debug("discarding received ICMP destination unreachable reply with an unknown id")
continue
@@ -116,7 +116,7 @@ func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr,
inboundID, outboundID)
continue
default:
return 0, fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, typedBody)
return 0, fmt.Errorf("%w: %T", ErrBodyUnsupported, typedBody)
}
}
}
@@ -1,4 +1,4 @@
package pmtud
package icmp
import (
cryptorand "crypto/rand"
+187
View File
@@ -0,0 +1,187 @@
package icmp
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"strings"
"time"
"github.com/qdm12/gluetun/internal/pmtud/test"
"golang.org/x/net/icmp"
)
type icmpTestUnit struct {
mtu uint32
echoID uint16
sentBytes int
ok bool
}
func pmtudMultiSizes(ctx context.Context, ip netip.Addr,
minMTU, maxPossibleMTU uint32, pingTimeout time.Duration,
logger Logger,
) (maxMTU uint32, err error) {
var ipVersion string
var conn net.PacketConn
if ip.Is4() {
ipVersion = "v4"
conn, err = listenICMPv4(ctx)
} else {
ipVersion = "v6"
conn, err = listenICMPv6(ctx)
}
if err != nil {
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrNotPermitted)
}
return 0, fmt.Errorf("listening for ICMP packets: %w", err)
}
mtusToTest := test.MakeMTUsToTest(minMTU, maxPossibleMTU)
if len(mtusToTest) == 1 { // only minMTU because minMTU == maxPossibleMTU
return minMTU, nil
}
logger.Debugf("ICMP testing the following MTUs: %v", mtusToTest)
tests := make([]icmpTestUnit, len(mtusToTest))
for i := range mtusToTest {
tests[i] = icmpTestUnit{mtu: mtusToTest[i]}
}
timedCtx, cancel := context.WithTimeout(ctx, pingTimeout)
defer cancel()
go func() {
<-timedCtx.Done()
conn.Close()
}()
for i := range tests {
id, message := buildMessageToSend(ipVersion, tests[i].mtu)
tests[i].echoID = id
encodedMessage, err := message.Marshal(nil)
if err != nil {
return 0, fmt.Errorf("encoding ICMP message: %w", err)
}
tests[i].sentBytes = len(encodedMessage)
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()})
if err != nil {
if strings.HasSuffix(err.Error(), "sendto: operation not permitted") {
err = fmt.Errorf("%w", ErrNotPermitted)
}
return 0, fmt.Errorf("writing ICMP message: %w", err)
}
}
err = collectReplies(conn, ipVersion, tests, logger)
switch {
case err == nil: // max possible MTU is working
return tests[len(tests)-1].mtu, nil
case err != nil && errors.Is(err, net.ErrClosed):
// we have timeouts (IPv4 testing or IPv6 PMTUD blackholes)
// so find the highest MTU which worked.
// Note we start from index len(tests) - 2 since the max MTU
// cannot be working if we had a timeout.
for i := len(tests) - 2; i >= 0; i-- { //nolint:mnd
if tests[i].ok {
return pmtudMultiSizes(ctx, ip, tests[i].mtu, tests[i+1].mtu-1,
pingTimeout, logger)
}
}
// All MTUs failed.
return 0, fmt.Errorf("%w: ICMP might be blocked", ErrMTUNotFound)
case err != nil:
return 0, fmt.Errorf("collecting ICMP echo replies: %w", err)
default:
panic("unreachable")
}
}
// The theoretical limit is 4GiB for IPv6 MTU path discovery jumbograms, but that would
// create huge buffers which we don't really want to support anyway.
// The standard frame maximum MTU is 1500 bytes, and there are Jumbo frames with
// a conventional maximum of 9000 bytes. However, some manufacturers support up
// 9216-20 = 9196 bytes for the maximum MTU. We thus use buffers of size 9196 to
// match eventual Jumbo frames. More information at:
// https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media
const maxPossibleMTU = 9196
func collectReplies(conn net.PacketConn, ipVersion string,
tests []icmpTestUnit, logger Logger,
) (err error) {
echoIDToTestIndex := make(map[uint16]int, len(tests))
for i, test := range tests {
echoIDToTestIndex[test.echoID] = i
}
buffer := make([]byte, maxPossibleMTU)
idsFound := 0
for idsFound < len(tests) {
// Note we need to read the whole packet in one call to ReadFrom, so the buffer
// must be large enough to read the entire reply packet. See:
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
bytesRead, _, err := conn.ReadFrom(buffer)
if err != nil {
return fmt.Errorf("reading from ICMP connection: %w", err)
}
packetBytes := buffer[:bytesRead]
ipPacketLength := len(packetBytes)
var icmpProtocol int
switch ipVersion {
case "v4":
icmpProtocol = icmpv4Protocol
case "v6":
icmpProtocol = icmpv6Protocol
default:
panic(fmt.Sprintf("unknown IP version: %s", ipVersion))
}
// Parse the ICMP message
// Note: this parsing works for a truncated 556 bytes ICMP reply packet.
message, err := icmp.ParseMessage(icmpProtocol, packetBytes)
if err != nil {
return fmt.Errorf("parsing message: %w", err)
}
echoBody, ok := message.Body.(*icmp.Echo)
if !ok {
return fmt.Errorf("%w: %T", ErrBodyUnsupported, message.Body)
}
id := uint16(echoBody.ID) //nolint:gosec
testIndex, testing := echoIDToTestIndex[id]
if !testing { // not an id we expected so ignore it
logger.Warnf("ignoring ICMP reply with unexpected ID %d (type: %d, code: %d, length: %d)",
echoBody.ID, message.Type, message.Code, ipPacketLength)
continue
}
idsFound++
sentBytes := tests[testIndex].sentBytes
// echo reply should be at most the number of bytes sent,
// and can be lower, more precisely 556 bytes, in case
// the host we are reaching wants to stay out of trouble
// and ensure its echo reply goes through without
// fragmentation, see the following page:
// https://datatracker.ietf.org/doc/html/rfc1122#page-59
const conservativeReplyLength = 556
truncated := ipPacketLength < sentBytes &&
ipPacketLength == conservativeReplyLength
// Check the packet size is the same if the reply is not truncated
if !truncated && sentBytes != ipPacketLength {
return fmt.Errorf("%w: sent %dB and received %dB",
ErrEchoDataMismatch, sentBytes, ipPacketLength)
}
// Truncated reply or matching reply size
tests[testIndex].ok = true
}
return nil
}
+73
View File
@@ -0,0 +1,73 @@
package ip
import (
"encoding/binary"
"net/netip"
"syscall"
"github.com/qdm12/gluetun/internal/pmtud/constants"
)
func HeaderV4(srcIP, dstIP netip.Addr, payloadLength uint32) []byte {
ipHeader := make([]byte, constants.IPv4HeaderLength)
const version byte = 4
const headerLength byte = 20 / 4 // in 32-bit words
ipHeader[0] = (version << 4) | headerLength //nolint:mnd
ipHeader[1] = 0 // type of Service
putUint16(ipHeader[2:], uint16(constants.IPv4HeaderLength+payloadLength)) //nolint:gosec
ipHeader[4], ipHeader[5] = 0, 0 // identification
const flagsAndOffset uint16 = 0x4000 // DF bit set
putUint16(ipHeader[6:], flagsAndOffset)
ipHeader[8] = 64 // ttl
ipHeader[9] = syscall.IPPROTO_TCP
srcIPBytes := srcIP.As4()
copy(ipHeader[12:16], srcIPBytes[:])
dstIPBytes := dstIP.As4()
copy(ipHeader[16:20], dstIPBytes[:])
checksum := ipChecksum(ipHeader)
ipHeader[10] = byte(checksum >> 8) //nolint:mnd
ipHeader[11] = byte(checksum & 0xff) //nolint:mnd
return ipHeader
}
// ipChecksum calculates the checksum for the IP header.
//
//nolint:mnd
func ipChecksum(header []byte) uint16 {
sum := uint32(0)
for i := 0; i < len(header)-1; i += 2 {
sum += uint32(header[i])<<8 + uint32(header[i+1])
}
if len(header)%2 != 0 {
sum += uint32(header[len(header)-1]) << 8
}
for (sum >> 16) > 0 {
sum = (sum & 0xFFFF) + (sum >> 16)
}
return ^uint16(sum) //nolint:gosec
}
// HeaderV6 makes an IPv6 header.
// payloadLen is the length of the payload following the header.
// nextHeader can be byte([syscall.IPPROTO_TCP]) for example.
func HeaderV6(srcIP, dstIP netip.Addr,
payloadLen uint16, nextHeader byte,
) []byte {
ipv6Header := make([]byte, constants.IPv6HeaderLength)
ipv6Header[0] = 0x60 // version (4 bits) | traffic Class (4 bits)
ipv6Header[1] = 0x00 // traffic Class (4 bits) | flow label (4 bits)
// Flow Label (remaining 16 bits)
ipv6Header[2] = 0x00
ipv6Header[3] = 0x00
binary.BigEndian.PutUint16(ipv6Header[4:], payloadLen)
ipv6Header[6] = nextHeader
const hopLimit = 64
ipv6Header[7] = hopLimit
copy(ipv6Header[8:24], srcIP.AsSlice())
copy(ipv6Header[24:40], dstIP.AsSlice())
return ipv6Header
}
+9
View File
@@ -0,0 +1,9 @@
package ip
import (
"encoding/binary"
)
func putUint16(b []byte, v uint16) {
binary.NativeEndian.PutUint16(b, v)
}
@@ -0,0 +1,9 @@
//go:build !darwin
package ip
import "encoding/binary"
func putUint16(b []byte, v uint16) {
binary.BigEndian.PutUint16(b, v)
}
+9
View File
@@ -0,0 +1,9 @@
//go:build linux || darwin
package ip
import "syscall"
func SetIPv4HeaderIncluded(fd int) error {
return syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
}
+7
View File
@@ -0,0 +1,7 @@
//go:build !linux && !windows && !darwin
package ip
func SetIPv4HeaderIncluded(fd int) error {
panic("not implemented")
}
+12
View File
@@ -0,0 +1,12 @@
package ip
import (
"syscall"
"golang.org/x/sys/windows"
)
func SetIPv4HeaderIncluded(handle syscall.Handle) error {
const ipHdrIncluded = windows.IP_HDRINCL
return syscall.SetsockoptInt(handle, syscall.IPPROTO_IP, ipHdrIncluded, 1)
}
+5
View File
@@ -0,0 +1,5 @@
package ip
func SetIPv6HeaderIncluded(fd int) error {
panic("darwin does not allow an application to build IPv6 headers")
}
+8
View File
@@ -0,0 +1,8 @@
package ip
import "syscall"
func SetIPv6HeaderIncluded(fd int) error {
const ipv6HdrIncluded = 36 // IPV6_HDRINCL
return syscall.SetsockoptInt(fd, syscall.IPPROTO_IPV6, ipv6HdrIncluded, 1)
}
+7
View File
@@ -0,0 +1,7 @@
//go:build !linux && !windows && !darwin
package ip
func SetIPv6HeaderIncluded(fd int) error {
panic("not implemented")
}
+7
View File
@@ -0,0 +1,7 @@
package ip
import "syscall"
func SetIPv6HeaderIncluded(fd syscall.Handle) error {
panic("windows does not allow an application to build IPv6 headers")
}
+123
View File
@@ -0,0 +1,123 @@
package ip
import (
"fmt"
"net/netip"
"syscall"
"github.com/jsimonetti/rtnetlink"
)
// SrcAddr determines the appropriate source IP address to use when sending a packet to the
// specified destination. It also reserves an ephemeral source port for the specified protocol
// to ensure that the port is not used by other processes. The cleanup function returned should
// be called to release the reserved port when done.
func SrcAddr(dst netip.AddrPort, proto int) (src netip.AddrPort, cleanup func(), err error) {
srcAddr, err := srcIP(dst.Addr())
if err != nil {
return netip.AddrPort{}, nil, fmt.Errorf("finding source IP: %w", err)
}
srcPort, cleanup, err := srcPort(srcAddr, proto)
if err != nil {
return netip.AddrPort{}, nil, fmt.Errorf("reserving source port: %w", err)
}
return netip.AddrPortFrom(srcAddr, srcPort), cleanup, nil
}
var errNoRoute = fmt.Errorf("no route to destination")
func srcIP(dst netip.Addr) (netip.Addr, error) {
conn, err := rtnetlink.Dial(nil)
if err != nil {
return netip.Addr{}, err
}
defer conn.Close()
family := uint8(syscall.AF_INET)
if dst.Is6() {
family = syscall.AF_INET6
}
// Request route to destination
requestMessage := &rtnetlink.RouteMessage{
Family: family,
Attributes: rtnetlink.RouteAttributes{
Dst: dst.AsSlice(),
},
}
messages, err := conn.Route.Get(requestMessage)
if err != nil {
return netip.Addr{}, fmt.Errorf("getting routes to %s: %w", dst, err)
}
for _, message := range messages {
if message.Attributes.Src == nil {
continue
}
ipv6 := message.Attributes.Src.To4() == nil
if ipv6 {
return netip.AddrFrom16([16]byte(message.Attributes.Src)), nil
}
return netip.AddrFrom4([4]byte(message.Attributes.Src)), nil
}
return netip.Addr{}, fmt.Errorf("%w: in %d route(s)", errNoRoute, len(messages))
}
// srcPort reserves an ephemeral source port by opening a socket for the
// protocol specified and binds it to the provided source address.
// It doesn't actually listen on the port.
// The cleanup function returned should be called to release the port when done.
func srcPort(srcAddr netip.Addr, proto int) (srcPort uint16, cleanup func(), err error) {
family := syscall.AF_INET
if srcAddr.Is6() {
family = syscall.AF_INET6
}
fd, err := syscall.Socket(family, syscall.SOCK_STREAM, proto)
if err != nil {
return 0, nil, fmt.Errorf("creating reservation socket: %w", err)
}
cleanup = func() {
_ = syscall.Close(fd)
}
// Bind to port 0 to get an ephemeral port
const port = 0
var bindAddr syscall.Sockaddr
if srcAddr.Is4() {
bindAddr = &syscall.SockaddrInet4{
Port: port,
Addr: srcAddr.As4(),
}
} else {
bindAddr = &syscall.SockaddrInet6{
Port: port,
Addr: srcAddr.As16(),
}
}
err = syscall.Bind(fd, bindAddr)
if err != nil {
cleanup()
return 0, nil, fmt.Errorf("binding reservation socket: %w", err)
}
sockAddr, err := syscall.Getsockname(fd)
if err != nil {
cleanup()
return 0, nil, fmt.Errorf("getting bound socket name: %w", err)
}
switch typedSockAddr := sockAddr.(type) {
case *syscall.SockaddrInet4:
srcPort = uint16(typedSockAddr.Port) //nolint:gosec
case *syscall.SockaddrInet6:
srcPort = uint16(typedSockAddr.Port) //nolint:gosec
default:
panic(fmt.Sprintf("unexpected sockaddr type: %T", typedSockAddr))
}
return srcPort, cleanup, nil
}
+38 -233
View File
@@ -4,268 +4,73 @@ import (
"context"
"errors"
"fmt"
"math"
"net"
"net/netip"
"strings"
"time"
"golang.org/x/net/icmp"
"github.com/qdm12/gluetun/internal/pmtud/constants"
"github.com/qdm12/gluetun/internal/pmtud/icmp"
"github.com/qdm12/gluetun/internal/pmtud/tcp"
)
var ErrMTUNotFound = errors.New("path MTU discovery failed to find MTU")
// PathMTUDiscover discovers the maximum MTU for the path to the given ip address.
// PathMTUDiscover discovers the maximum MTU using both ICMP and TCP.
// Multiple ICMP addresses and TCP addresses can be specified for redundancy.
// ICMP PMTUD is run first. If successful, the range of possible MTU values to
// check for TCP PMTUD is reduced to [maxMTU-150, maxMTU] where maxMTU is the
// maximum MTU found with ICMP PMTUD. Otherwise, TCP PMTUD is run with the
// whole range of possible MTU values up to the physical link MTU to check.
// If the physicalLinkMTU is zero, it defaults to 1500 which is the ethernet standard MTU.
// If the pingTimeout is zero, it defaults to 1 second.
// If the logger is nil, a no-op logger is used.
// It returns [ErrMTUNotFound] if the MTU could not be determined.
func PathMTUDiscover(ctx context.Context, ip netip.Addr,
physicalLinkMTU uint32, pingTimeout time.Duration, logger Logger) (
func PathMTUDiscover(ctx context.Context, icmpAddrs []netip.Addr, tcpAddrs []netip.AddrPort,
physicalLinkMTU uint32, tryTimeout time.Duration, logger Logger) (
mtu uint32, err error,
) {
if physicalLinkMTU == 0 {
const ethernetStandardMTU = 1500
physicalLinkMTU = ethernetStandardMTU
}
if pingTimeout == 0 {
pingTimeout = time.Second
if tryTimeout == 0 {
tryTimeout = time.Second
}
if logger == nil {
logger = &noopLogger{}
}
if ip.Is4() {
logger.Debug("finding IPv4 next hop MTU")
mtu, err = findIPv4NextHopMTU(ctx, ip, physicalLinkMTU, pingTimeout, logger)
// Try finding the MTU using ICMP
maxPossibleMTU := physicalLinkMTU
icmpSuccess := false
for _, icmpIP := range icmpAddrs {
mtu, err := icmp.PathMTUDiscover(ctx, icmpIP, physicalLinkMTU,
tryTimeout, logger)
switch {
case err == nil:
return mtu, nil
case errors.Is(err, net.ErrClosed) || errors.Is(err, ErrICMPCommunicationAdministrativelyProhibited): // blackhole
logger.Debugf("ICMP path MTU discovery against %s found maximum valid MTU %d", icmpIP, mtu)
icmpSuccess = true
maxPossibleMTU = mtu
case errors.Is(err, icmp.ErrNotPermitted), errors.Is(err, icmp.ErrMTUNotFound):
logger.Debugf("ICMP path MTU discovery failed: %s", err)
default:
return 0, fmt.Errorf("finding IPv4 next hop MTU: %w", err)
}
} else {
logger.Debug("requesting IPv6 ICMP packet-too-big reply")
mtu, err = getIPv6PacketTooBig(ctx, ip, physicalLinkMTU, pingTimeout, logger)
switch {
case err == nil:
return mtu, nil
case errors.Is(err, net.ErrClosed): // blackhole
default:
return 0, fmt.Errorf("getting IPv6 packet-too-big message: %w", err)
return 0, fmt.Errorf("ICMP path MTU discovery: %w", err)
}
}
// Fall back method: send echo requests with different packet
// sizes and check which ones succeed to find the maximum MTU.
logger.Debug("falling back to sending different sized echo packets")
minMTU := minIPv4MTU
if ip.Is6() {
minMTU = minIPv6MTU
}
return pmtudMultiSizes(ctx, ip, minMTU, physicalLinkMTU, pingTimeout, logger)
}
type pmtudTestUnit struct {
mtu uint32
echoID uint16
sentBytes int
ok bool
}
func pmtudMultiSizes(ctx context.Context, ip netip.Addr,
minMTU, maxPossibleMTU uint32, pingTimeout time.Duration,
logger Logger,
) (maxMTU uint32, err error) {
var ipVersion string
var conn net.PacketConn
if ip.Is4() {
ipVersion = "v4"
conn, err = listenICMPv4(ctx)
} else {
ipVersion = "v6"
conn, err = listenICMPv6(ctx)
}
if err != nil {
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrICMPNotPermitted)
for _, addrPort := range tcpAddrs {
minMTU := constants.MinIPv4MTU
if addrPort.Addr().Is6() {
minMTU = constants.MinIPv6MTU
}
return 0, fmt.Errorf("listening for ICMP packets: %w", err)
}
mtusToTest := makeMTUsToTest(minMTU, maxPossibleMTU)
if len(mtusToTest) == 1 { // only minMTU because minMTU == maxPossibleMTU
return minMTU, nil
}
logger.Debugf("testing the following MTUs: %v", mtusToTest)
tests := make([]pmtudTestUnit, len(mtusToTest))
for i := range mtusToTest {
tests[i] = pmtudTestUnit{mtu: mtusToTest[i]}
}
timedCtx, cancel := context.WithTimeout(ctx, pingTimeout)
defer cancel()
go func() {
<-timedCtx.Done()
conn.Close()
}()
for i := range tests {
id, message := buildMessageToSend(ipVersion, tests[i].mtu)
tests[i].echoID = id
encodedMessage, err := message.Marshal(nil)
if icmpSuccess {
const mtuMargin = 150
minMTU = max(maxPossibleMTU-mtuMargin, minMTU)
}
mtu, err = tcp.PathMTUDiscover(ctx, addrPort, minMTU, maxPossibleMTU, logger)
if err != nil {
return 0, fmt.Errorf("encoding ICMP message: %w", err)
}
tests[i].sentBytes = len(encodedMessage)
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()})
if err != nil {
if strings.HasSuffix(err.Error(), "sendto: operation not permitted") {
err = fmt.Errorf("%w", ErrICMPNotPermitted)
}
return 0, fmt.Errorf("writing ICMP message: %w", err)
}
}
err = collectReplies(conn, ipVersion, tests, logger)
switch {
case err == nil: // max possible MTU is working
return tests[len(tests)-1].mtu, nil
case err != nil && errors.Is(err, net.ErrClosed):
// we have timeouts (IPv4 testing or IPv6 PMTUD blackholes)
// so find the highest MTU which worked.
// Note we start from index len(tests) - 2 since the max MTU
// cannot be working if we had a timeout.
for i := len(tests) - 2; i >= 0; i-- { //nolint:mnd
if tests[i].ok {
return pmtudMultiSizes(ctx, ip, tests[i].mtu, tests[i+1].mtu-1,
pingTimeout, logger)
}
}
// All MTUs failed.
return 0, fmt.Errorf("%w: ICMP might be blocked", ErrMTUNotFound)
case err != nil:
return 0, fmt.Errorf("collecting ICMP echo replies: %w", err)
default:
panic("unreachable")
}
}
// Create the MTU slice of length 11 such that:
// - the first element is the minMTU
// - the last element is the maxMTU
// - elements in-between are separated as close to each other
// The number 11 is chosen to find the final MTU in 3 searches,
// with a total search space of 1728 MTUs which is enough;
// to find it in 2 searches requires 37 parallel queries which
// could be blocked by firewalls.
func makeMTUsToTest(minMTU, maxMTU uint32) (mtus []uint32) {
const mtusLength = 11 // find the final MTU in 3 searches
diff := maxMTU - minMTU
switch {
case minMTU > maxMTU:
panic("minMTU > maxMTU")
case diff <= mtusLength:
mtus = make([]uint32, 0, diff)
for mtu := minMTU; mtu <= maxMTU; mtu++ {
mtus = append(mtus, mtu)
}
default:
step := float64(diff) / float64(mtusLength-1)
mtus = make([]uint32, 0, mtusLength)
for mtu := float64(minMTU); len(mtus) < mtusLength-1; mtu += step {
mtus = append(mtus, uint32(math.Round(mtu)))
}
mtus = append(mtus, maxMTU) // last element is the maxMTU
}
return mtus
}
func collectReplies(conn net.PacketConn, ipVersion string,
tests []pmtudTestUnit, logger Logger,
) (err error) {
echoIDToTestIndex := make(map[uint16]int, len(tests))
for i, test := range tests {
echoIDToTestIndex[test.echoID] = i
}
// The theoretical limit is 4GiB for IPv6 MTU path discovery jumbograms, but that would
// create huge buffers which we don't really want to support anyway.
// The standard frame maximum MTU is 1500 bytes, and there are Jumbo frames with
// a conventional maximum of 9000 bytes. However, some manufacturers support up
// 9216-20 = 9196 bytes for the maximum MTU. We thus use buffers of size 9196 to
// match eventual Jumbo frames. More information at:
// https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media
const maxPossibleMTU = 9196
buffer := make([]byte, maxPossibleMTU)
idsFound := 0
for idsFound < len(tests) {
// Note we need to read the whole packet in one call to ReadFrom, so the buffer
// must be large enough to read the entire reply packet. See:
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
bytesRead, _, err := conn.ReadFrom(buffer)
if err != nil {
return fmt.Errorf("reading from ICMP connection: %w", err)
}
packetBytes := buffer[:bytesRead]
ipPacketLength := len(packetBytes)
var icmpProtocol int
switch ipVersion {
case "v4":
icmpProtocol = icmpv4Protocol
case "v6":
icmpProtocol = icmpv6Protocol
default:
panic(fmt.Sprintf("unknown IP version: %s", ipVersion))
}
// Parse the ICMP message
// Note: this parsing works for a truncated 556 bytes ICMP reply packet.
message, err := icmp.ParseMessage(icmpProtocol, packetBytes)
if err != nil {
return fmt.Errorf("parsing message: %w", err)
}
echoBody, ok := message.Body.(*icmp.Echo)
if !ok {
return fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, message.Body)
}
id := uint16(echoBody.ID) //nolint:gosec
testIndex, testing := echoIDToTestIndex[id]
if !testing { // not an id we expected so ignore it
logger.Warnf("ignoring ICMP reply with unexpected ID %d (type: %d, code: %d, length: %d)",
echoBody.ID, message.Type, message.Code, ipPacketLength)
logger.Debugf("TCP path MTU discovery to %s failed: %s", addrPort, err)
continue
}
idsFound++
sentBytes := tests[testIndex].sentBytes
// echo reply should be at most the number of bytes sent,
// and can be lower, more precisely 556 bytes, in case
// the host we are reaching wants to stay out of trouble
// and ensure its echo reply goes through without
// fragmentation, see the following page:
// https://datatracker.ietf.org/doc/html/rfc1122#page-59
const conservativeReplyLength = 556
truncated := ipPacketLength < sentBytes &&
ipPacketLength == conservativeReplyLength
// Check the packet size is the same if the reply is not truncated
if !truncated && sentBytes != ipPacketLength {
return fmt.Errorf("%w: sent %dB and received %dB",
ErrICMPEchoDataMismatch, sentBytes, ipPacketLength)
}
// Truncated reply or matching reply size
tests[testIndex].ok = true
logger.Debugf("TCP path MTU discovery to %s found maximum valid MTU %d", addrPort, mtu)
return mtu, nil
}
return nil
return 0, fmt.Errorf("TCP path MTU discovery: last error: %w", err)
}
+7
View File
@@ -0,0 +1,7 @@
package tcp
type Logger interface {
Debug(msg string)
Debugf(msg string, args ...any)
Warnf(msg string, args ...any)
}
@@ -0,0 +1,3 @@
package tcp
//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . Logger
+80
View File
@@ -0,0 +1,80 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/qdm12/gluetun/internal/pmtud/tcp (interfaces: Logger)
// Package tcp is a generated GoMock package.
package tcp
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
)
// MockLogger is a mock of Logger interface.
type MockLogger struct {
ctrl *gomock.Controller
recorder *MockLoggerMockRecorder
}
// MockLoggerMockRecorder is the mock recorder for MockLogger.
type MockLoggerMockRecorder struct {
mock *MockLogger
}
// NewMockLogger creates a new mock instance.
func NewMockLogger(ctrl *gomock.Controller) *MockLogger {
mock := &MockLogger{ctrl: ctrl}
mock.recorder = &MockLoggerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockLogger) EXPECT() *MockLoggerMockRecorder {
return m.recorder
}
// Debug mocks base method.
func (m *MockLogger) Debug(arg0 string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Debug", arg0)
}
// Debug indicates an expected call of Debug.
func (mr *MockLoggerMockRecorder) Debug(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockLogger)(nil).Debug), arg0)
}
// Debugf mocks base method.
func (m *MockLogger) Debugf(arg0 string, arg1 ...interface{}) {
m.ctrl.T.Helper()
varargs := []interface{}{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
m.ctrl.Call(m, "Debugf", varargs...)
}
// Debugf indicates an expected call of Debugf.
func (mr *MockLoggerMockRecorder) Debugf(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]interface{}{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debugf", reflect.TypeOf((*MockLogger)(nil).Debugf), varargs...)
}
// Warnf mocks base method.
func (m *MockLogger) Warnf(arg0 string, arg1 ...interface{}) {
m.ctrl.T.Helper()
varargs := []interface{}{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
m.ctrl.Call(m, "Warnf", varargs...)
}
// Warnf indicates an expected call of Warnf.
func (mr *MockLoggerMockRecorder) Warnf(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]interface{}{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Warnf", reflect.TypeOf((*MockLogger)(nil).Warnf), varargs...)
}
+89
View File
@@ -0,0 +1,89 @@
package tcp
import (
"context"
"errors"
"fmt"
"net/netip"
"syscall"
"time"
"github.com/qdm12/gluetun/internal/pmtud/test"
)
var ErrMTUNotFound = errors.New("MTU not found")
type testUnit struct {
mtu uint32
ok bool
}
func PathMTUDiscover(ctx context.Context, addrPort netip.AddrPort,
minMTU, maxPossibleMTU uint32, logger Logger,
) (mtu uint32, err error) {
mtusToTest := test.MakeMTUsToTest(minMTU, maxPossibleMTU)
if len(mtusToTest) == 1 { // only minMTU because minMTU == maxPossibleMTU
return minMTU, nil
}
logger.Debugf("TCP testing the following MTUs: %v", mtusToTest)
tests := make([]testUnit, len(mtusToTest))
for i := range mtusToTest {
tests[i] = testUnit{mtu: mtusToTest[i]}
}
family := syscall.AF_INET
if addrPort.Addr().Is6() {
family = syscall.AF_INET6
}
fd, stop, err := startRawSocket(family)
if err != nil {
return 0, fmt.Errorf("starting raw socket: %w", err)
}
defer stop()
tracker := newTracker(fd, addrPort.Addr().Is4())
const timeout = time.Second
runCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
errCh := make(chan error)
go func() {
errCh <- tracker.listen(runCtx)
}()
doneCh := make(chan struct{})
for i := range tests {
go func(i int) {
err := runTest(runCtx, fd, tracker, addrPort, tests[i].mtu)
tests[i].ok = err == nil
doneCh <- struct{}{}
}(i)
}
for range tests {
select {
case <-doneCh:
case err := <-errCh:
if err == nil { // timeout
break
}
return 0, fmt.Errorf("listening for TCP replies: %w", err)
}
}
if tests[len(tests)-1].ok {
return tests[len(tests)-1].mtu, nil
}
for i := len(tests) - 2; i >= 0; i-- { //nolint:mnd
if tests[i].ok {
stop()
cancel()
return PathMTUDiscover(ctx, addrPort,
tests[i].mtu, tests[i+1].mtu-1, logger)
}
}
return 0, fmt.Errorf("%w: your connection might not be working at all", ErrMTUNotFound)
}
+89
View File
@@ -0,0 +1,89 @@
package tcp
import (
"math/rand/v2"
"net/netip"
"syscall"
"github.com/qdm12/gluetun/internal/pmtud/constants"
"github.com/qdm12/gluetun/internal/pmtud/ip"
)
// createSYNPacket creates a TCP SYN packet for initiating a handshake.
// SYN packets have normally no data payload, so you SHOULD set mtu to 0.
// However, in some cases where the server closes the connection with RST immediately,
// it can be useful to add some data payload to a SYN packet and check if the server still
// replies. Only set mtu to a non zero value if you know what you are doing.
func createSYNPacket(src, dst netip.AddrPort, mtu uint32) (packet []byte, seq uint32) {
seq = rand.Uint32() //nolint:gosec
const ack = 0 // SYN has no ACK number
payloadLength := constants.BaseTCPHeaderLength // no data payload
if mtu > 0 {
payloadLength = getPayloadLength(mtu, dst)
}
return createPacket(src, dst, seq, ack, payloadLength, synFlag), seq
}
// createACKPacket creates a TCP ACK packet.
// If the mtu is set to 0, no payload is sent.
// Otherwise, the payload is calculated to test the MTU given.
func createACKPacket(src, dst netip.AddrPort, seq, ack uint32, mtu uint32) []byte {
payloadLength := constants.BaseTCPHeaderLength // no data payload
if mtu > 0 {
payloadLength = getPayloadLength(mtu, dst)
}
const flags = ackFlag | pshFlag
return createPacket(src, dst, seq, ack, payloadLength, flags)
}
func createRSTPacket(src, dst netip.AddrPort, seq, ack uint32) []byte {
const payloadLength = constants.BaseTCPHeaderLength // no data payload
return createPacket(src, dst, seq, ack, payloadLength, rstFlag)
}
func getPayloadLength(mtu uint32, dst netip.AddrPort) uint32 {
var ipHeaderLength uint32
if dst.Addr().Is4() {
ipHeaderLength = constants.IPv4HeaderLength
} else {
ipHeaderLength = constants.IPv6HeaderLength
}
if mtu < ipHeaderLength+constants.BaseTCPHeaderLength {
panic("MTU too small to hold IP and TCP headers")
}
return mtu - ipHeaderLength
}
func createPacket(src, dst netip.AddrPort,
seq, ack, payloadLength uint32, flags byte,
) []byte {
if payloadLength < constants.BaseTCPHeaderLength {
panic("payload length is too small to hold TCP header")
}
var ipHeader []byte
if dst.Addr().Is4() {
ipHeader = ip.HeaderV4(src.Addr(), dst.Addr(), payloadLength)
} else {
ipHeader = ip.HeaderV6(src.Addr(), dst.Addr(),
uint16(payloadLength), byte(syscall.IPPROTO_TCP)) //nolint:gosec
}
tcpHeader := makeTCPHeader(src.Port(), dst.Port(), seq, ack, flags)
// data is just zeroes
dataLength := int(payloadLength) - int(constants.BaseTCPHeaderLength)
var data []byte
if dataLength > 0 {
data = make([]byte, dataLength)
}
checksum := tcpChecksum(ipHeader, tcpHeader, data)
tcpHeader[16] = byte(checksum >> 8) //nolint:mnd
tcpHeader[17] = byte(checksum & 0xff) //nolint:mnd
packet := make([]byte, len(ipHeader)+int(constants.BaseTCPHeaderLength)+dataLength)
copy(packet, ipHeader)
copy(packet[len(ipHeader):], tcpHeader)
copy(packet[len(ipHeader)+int(constants.BaseTCPHeaderLength):], data)
return packet
}
+196
View File
@@ -0,0 +1,196 @@
package tcp
import (
"context"
"errors"
"fmt"
"net/netip"
"syscall"
"github.com/qdm12/gluetun/internal/pmtud/constants"
"github.com/qdm12/gluetun/internal/pmtud/ip"
)
func startRawSocket(family int) (fd fileDescriptor, stop func(), err error) {
fdPlatform, err := syscall.Socket(family, syscall.SOCK_RAW, syscall.IPPROTO_TCP)
if err != nil {
return 0, nil, fmt.Errorf("creating raw socket: %w", err)
}
if family == syscall.AF_INET {
err = ip.SetIPv4HeaderIncluded(fdPlatform)
} else {
err = ip.SetIPv6HeaderIncluded(fdPlatform)
}
if err != nil {
_ = syscall.Close(fdPlatform)
return 0, nil, fmt.Errorf("setting header option on raw socket: %w", err)
}
// Allow sending packets larger than cached PMTU (for PMTUD probing)
err = setMTUDiscovery(fdPlatform)
if err != nil {
_ = syscall.Close(fdPlatform)
return 0, nil, fmt.Errorf("setting IP_MTU_DISCOVER: %w", err)
}
// use polling because some Linux systems do not cancel
// blocking syscalls such as recvfrom when the socket is closed,
// which would cause things to hang indefinitely.
err = setNonBlock(fdPlatform)
if err != nil {
_ = syscall.Close(fdPlatform)
return 0, nil, fmt.Errorf("setting non-blocking mode: %w", err)
}
stop = func() {
_ = syscall.Close(fdPlatform)
}
return fileDescriptor(fdPlatform), stop, nil
}
var (
errTCPPacketNotSynAck = errors.New("TCP packet is not a SYN-ACK")
errTCPSynAckAckMismatch = errors.New("TCP SYN-ACK ACK number does not match expected value")
errFinalPacketTypeUnexpected = errors.New("final TCP packet type is unexpected")
)
// Craft and send a raw TCP packet to test the MTU.
// It expects either an RST reply (if no server is listening)
// or a SYN-ACK/ACK reply (if a server is listening).
func runTest(ctx context.Context, fd fileDescriptor,
tracker *tracker, dst netip.AddrPort, mtu uint32,
) error {
const proto = syscall.IPPROTO_TCP
src, cleanup, err := ip.SrcAddr(dst, proto)
if err != nil {
return fmt.Errorf("getting source address: %w", err)
}
defer cleanup()
ch := make(chan []byte)
abort := make(chan struct{})
defer close(abort)
tracker.register(src.Port(), dst.Port(), ch, abort)
defer tracker.unregister(src.Port(), dst.Port())
dstSockAddr := makeSockAddr(dst)
synPacket, synSeq := createSYNPacket(src, dst, 0)
const sendToFlags = 0
err = sendTo(fd, synPacket, sendToFlags, dstSockAddr)
if err != nil {
return fmt.Errorf("sending SYN packet: %w", err)
}
var reply []byte
select {
case <-ctx.Done():
return ctx.Err()
case reply = <-ch:
}
packetType, synAckSeq, synAckAck, err := parseTCPHeader(reply[:constants.BaseTCPHeaderLength])
switch {
case err != nil:
return fmt.Errorf("parsing first reply TCP header: %w", err)
case packetType == packetTypeRST:
// server actively closed the connection, try sending a SYN with data
return handleRSTReply(ctx, fd, ch, src, dst, mtu)
case packetType != packetTypeSYNACK:
return fmt.Errorf("%w: unexpected packet type %s", errTCPPacketNotSynAck, packetType)
case synAckAck != synSeq+1:
return fmt.Errorf("%w: expected %d, got %d", errTCPSynAckAckMismatch, synSeq+1, synAckAck)
}
// Send a no-data ACK packet to finish the 3-way handshake.
const ackMTU = 0 // no data payload initially
ackPacket := createACKPacket(src, dst, synAckAck, synAckSeq+1, ackMTU)
err = sendTo(fd, ackPacket, sendToFlags, dstSockAddr)
if err != nil {
return fmt.Errorf("sending ACK-without-data packet: %w", err)
}
// Send a data ACK packet to test the MTU given.
ackPacket = createACKPacket(src, dst, synAckAck, synAckSeq+1, mtu)
err = sendTo(fd, ackPacket, sendToFlags, dstSockAddr)
if err != nil {
return fmt.Errorf("sending ACK-with-data packet: %w", err)
}
select {
case <-ctx.Done():
return ctx.Err()
case reply = <-ch:
}
packetType, _, ack, err := parseTCPHeader(reply[:constants.BaseTCPHeaderLength])
if err != nil {
return fmt.Errorf("parsing second reply TCP header: %w", err)
}
switch packetType { //nolint:exhaustive
case packetTypeRST:
return nil
case packetTypeACK:
err = sendRST(fd, src, dst, ack)
if err != nil {
return fmt.Errorf("sending RST packet: %w", err)
}
return nil
default:
_ = sendRST(fd, src, dst, ack)
return fmt.Errorf("%w: %s", errFinalPacketTypeUnexpected, packetType)
}
}
func makeSockAddr(addr netip.AddrPort) syscall.Sockaddr {
if addr.Addr().Is4() {
return &syscall.SockaddrInet4{
Port: int(addr.Port()),
Addr: addr.Addr().As4(),
}
}
return &syscall.SockaddrInet6{
Port: int(addr.Port()),
Addr: addr.Addr().As16(),
}
}
var errTCPPacketNotRST = errors.New("TCP packet is not an RST")
func handleRSTReply(ctx context.Context, fd fileDescriptor, ch <-chan []byte,
src, dst netip.AddrPort, mtu uint32,
) error {
packet, _ := createSYNPacket(src, dst, mtu)
const sendToFlags = 0
err := sendTo(fd, packet, sendToFlags, makeSockAddr(dst))
if err != nil {
return fmt.Errorf("sending SYN MTU-test packet: %w", err)
}
var reply []byte
select {
case <-ctx.Done():
return ctx.Err() // timeout: the MTU test SYN packet was too big
case reply = <-ch:
}
packetType, _, _, err := parseTCPHeader(reply[:constants.BaseTCPHeaderLength])
if err != nil {
return fmt.Errorf("parsing reply TCP header: %w", err)
} else if packetType != packetTypeRST {
return fmt.Errorf("%w: %s", errTCPPacketNotRST, packetType)
}
return nil
}
func sendRST(fd fileDescriptor, src, dst netip.AddrPort,
previousACK uint32,
) error {
seq := previousACK
const ack = 0
rstPacket := createRSTPacket(src, dst, seq, ack)
const sendToFlags = 0
return sendTo(fd, rstPacket, sendToFlags, makeSockAddr(dst))
}
+5
View File
@@ -0,0 +1,5 @@
package tcp
func stripIPv4Header(reply []byte) (result []byte, ok bool) {
return reply, true
}
+7
View File
@@ -0,0 +1,7 @@
package tcp
import "syscall"
func setMTUDiscovery(fd int) error {
return syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_MTU_DISCOVER, syscall.IP_PMTUDISC_PROBE)
}
+30
View File
@@ -0,0 +1,30 @@
//go:build !darwin
package tcp
import (
"github.com/qdm12/gluetun/internal/pmtud/constants"
)
func stripIPv4Header(reply []byte) (result []byte, ok bool) {
if len(reply) < int(constants.IPv4HeaderLength) {
return nil, false // not an IPv4 packet
}
version := reply[0] >> 4 //nolint:mnd
const ipv4Version = 4
if version != ipv4Version {
return nil, false
}
// For IPv4 we need to skip the IP header, which is at least
// 20B and can be up to 60B.
// The Internet Header Length is the lower 4 bits of the first byte and
// represents the number of 32-bit words of the header length.
const ihlMask byte = 0x0F
const bytesInWord = 4
headerLength := int((reply[0] & ihlMask)) * bytesInWord
if len(reply) < headerLength {
return nil, false // not enough data for full IPv4 header
}
return reply[headerLength:], true
}
+199
View File
@@ -0,0 +1,199 @@
package tcp
import (
"context"
"errors"
"fmt"
"net/netip"
"syscall"
"testing"
"time"
"github.com/qdm12/gluetun/internal/netlink"
"github.com/qdm12/gluetun/internal/routing"
"github.com/qdm12/log"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_runTest(t *testing.T) {
t.Parallel()
noopLogger := &noopLogger{}
netlinker := netlink.New(noopLogger)
loopbackMTU, err := findLoopbackMTU(netlinker)
require.NoError(t, err, "finding loopback IPv4 MTU")
defaultIPv4MTU, err := findDefaultIPv4RouteMTU(netlinker)
require.NoError(t, err, "finding default IPv4 route MTU")
ctx, cancel := context.WithCancel(t.Context())
const family = syscall.AF_INET
fd, stop, err := startRawSocket(family)
require.NoError(t, err)
const ipv4 = true
tracker := newTracker(fd, ipv4)
trackerCh := make(chan error)
go func() {
trackerCh <- tracker.listen(ctx)
}()
t.Cleanup(func() {
stop()
cancel() // stop listening
err = <-trackerCh
require.NoError(t, err)
})
testCases := map[string]struct {
timeout time.Duration
dst func(t *testing.T) netip.AddrPort
mtu uint32
success bool
}{
"local_not_listening": {
timeout: time.Hour,
dst: func(t *testing.T) netip.AddrPort {
t.Helper()
port := reserveClosedPort(t)
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), port)
},
mtu: loopbackMTU,
success: true,
},
"remote_not_listening": {
timeout: 50 * time.Millisecond,
dst: func(_ *testing.T) netip.AddrPort {
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 12345)
},
mtu: defaultIPv4MTU,
},
"1.1.1.1:443": {
timeout: time.Second,
dst: func(_ *testing.T) netip.AddrPort {
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443)
},
mtu: defaultIPv4MTU,
success: true,
},
"1.1.1.1:80": {
timeout: time.Second,
dst: func(_ *testing.T) netip.AddrPort {
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 80)
},
mtu: defaultIPv4MTU,
success: true,
},
"8.8.8.8:443": {
timeout: time.Second,
dst: func(_ *testing.T) netip.AddrPort {
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 443)
},
mtu: defaultIPv4MTU,
success: true,
},
}
for name, testCase := range testCases {
t.Run(name, func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(t.Context(), testCase.timeout)
defer cancel()
dst := testCase.dst(t)
err := runTest(ctx, fd, tracker, dst, testCase.mtu)
if testCase.success {
require.NoError(t, err)
} else {
require.Error(t, err)
}
})
}
}
var errRouteNotFound = errors.New("route not found")
func findLoopbackMTU(netlinker *netlink.NetLink) (mtu uint32, err error) {
routes, err := netlinker.RouteList(netlink.FamilyV4)
if err != nil {
return 0, fmt.Errorf("getting routes list: %w", err)
}
for _, route := range routes {
if route.Dst.IsValid() && route.Dst.Addr().IsLoopback() {
link, err := netlinker.LinkByIndex(route.LinkIndex)
if err != nil {
return 0, fmt.Errorf("getting link by index: %w", err)
}
// Quirk: make sure it is maximum 65535, and not i.e. 65536
// or the IP header 16 bits will fail to fit that packet length value.
const maxMTU = 65535
return min(link.MTU, maxMTU), nil
}
}
return 0, fmt.Errorf("%w: no loopback route found", errRouteNotFound)
}
func findDefaultIPv4RouteMTU(netlinker *netlink.NetLink) (mtu uint32, err error) {
noopLogger := &noopLogger{}
routing := routing.New(netlinker, noopLogger)
defaultRoutes, err := routing.DefaultRoutes()
if err != nil {
return 0, fmt.Errorf("getting default routes: %w", err)
}
for _, route := range defaultRoutes {
if route.Family != netlink.FamilyV4 {
continue
}
link, err := netlinker.LinkByName(defaultRoutes[0].NetInterface)
if err != nil {
return 0, fmt.Errorf("getting link by name: %w", err)
}
return link.MTU, nil
}
return 0, fmt.Errorf("%w: no default route found", errRouteNotFound)
}
func reserveClosedPort(t *testing.T) (port uint16) {
t.Helper()
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM, syscall.IPPROTO_TCP)
require.NoError(t, err)
t.Cleanup(func() {
err := syscall.Close(fd)
assert.NoError(t, err)
})
addr := &syscall.SockaddrInet4{
Port: 0,
Addr: [4]byte{127, 0, 0, 1},
}
err = syscall.Bind(fd, addr)
if err != nil {
_ = syscall.Close(fd)
t.Fatal(err)
}
sockAddr, err := syscall.Getsockname(fd)
if err != nil {
_ = syscall.Close(fd)
t.Fatal(err)
}
sockAddr4, ok := sockAddr.(*syscall.SockaddrInet4)
if !ok {
_ = syscall.Close(fd)
t.Fatal("not an IPv4 address")
}
return uint16(sockAddr4.Port) //nolint:gosec
}
type noopLogger struct{}
func (l *noopLogger) Patch(_ ...log.Option) {}
func (l *noopLogger) Debug(_ string) {}
func (l *noopLogger) Debugf(_ string, _ ...any) {}
func (l *noopLogger) Info(_ string) {}
func (l *noopLogger) Warn(_ string) {}
func (l *noopLogger) Error(_ string) {}
+28
View File
@@ -0,0 +1,28 @@
//go:build linux || darwin
package tcp
import (
"syscall"
"time"
)
// fileDescriptor is a platform-independent type for socket file descriptors.
type fileDescriptor int
func sendTo(fd fileDescriptor, p []byte, flags int, to syscall.Sockaddr) (err error) {
return syscall.Sendto(int(fd), p, flags, to)
}
func setSocketTimeout(fd fileDescriptor, timeout time.Duration) (err error) {
timeval := syscall.NsecToTimeval(timeout.Nanoseconds())
return syscall.SetsockoptTimeval(int(fd), syscall.SOL_SOCKET, syscall.SO_RCVTIMEO, &timeval)
}
func recvFrom(fd fileDescriptor, p []byte, flags int) (n int, from syscall.Sockaddr, err error) {
return syscall.Recvfrom(int(fd), p, flags)
}
func setNonBlock(fd int) error {
return syscall.SetNonblock(fd, true)
}
+7
View File
@@ -0,0 +1,7 @@
//go:build !linux && !windows
package tcp
func setMTUDiscovery(fd int) error {
panic("not implemented")
}
+37
View File
@@ -0,0 +1,37 @@
package tcp
import (
"syscall"
"time"
"unsafe"
"golang.org/x/sys/windows"
)
type fileDescriptor syscall.Handle
func sendTo(fd fileDescriptor, p []byte, flags int, to syscall.Sockaddr) (err error) {
return syscall.Sendto(syscall.Handle(fd), p, flags, to)
}
func setSocketTimeout(fd fileDescriptor, timeout time.Duration) (err error) {
timeval := int(timeout.Milliseconds())
return syscall.SetsockoptInt(syscall.Handle(fd), syscall.SOL_SOCKET, windows.SO_RCVTIMEO, timeval)
}
func recvFrom(fd fileDescriptor, p []byte, flags int) (n int, from syscall.Sockaddr, err error) {
return syscall.Recvfrom(syscall.Handle(fd), p, flags)
}
func setMTUDiscovery(fd syscall.Handle) error {
panic("not implemented")
}
func setNonBlock(fd syscall.Handle) error {
// Windows: Use ioctlsocket with FIONBIO
var arg uint32 = 1 // 1 to enable non-blocking mode
var bytesReturned uint32
const FIONBIO = 0x8004667e
return syscall.WSAIoctl(fd, FIONBIO, (*byte)(unsafe.Pointer(&arg)),
uint32(unsafe.Sizeof(arg)), nil, 0, &bytesReturned, nil, 0)
}
+124
View File
@@ -0,0 +1,124 @@
package tcp
import (
"encoding/binary"
"errors"
"fmt"
"github.com/qdm12/gluetun/internal/pmtud/constants"
)
// For SYN, ack is 0.
// For SYN-ACK, ack is the sequence number + 1 of the SYN.
func makeTCPHeader(srcPort, dstPort uint16, seq, ack uint32, flags byte) []byte {
header := make([]byte, constants.BaseTCPHeaderLength)
binary.BigEndian.PutUint16(header[0:], srcPort)
binary.BigEndian.PutUint16(header[2:], dstPort)
binary.BigEndian.PutUint32(header[4:], seq)
binary.BigEndian.PutUint32(header[8:], ack)
//nolint:mnd
header[12] = byte(constants.BaseTCPHeaderLength) << 2 // data offset
header[13] = flags
// windowSize can be left to 5840 even for IPv6, it doesn't matter.
const windowSize = 5840
binary.BigEndian.PutUint16(header[14:], windowSize)
// header[16:17] is the checksum, set later
// header[18:19] is urgent pointer, not needed for our use case
return header
}
//nolint:mnd
func tcpChecksum(ipHeader, tcpHeader, payload []byte) uint16 {
var pseudoHeader []byte
isIPv6 := len(ipHeader) >= 40 && (ipHeader[0]>>4) == 6
if isIPv6 {
pseudoHeader = make([]byte, 40)
copy(pseudoHeader[0:16], ipHeader[8:24]) // Source Address
copy(pseudoHeader[16:32], ipHeader[24:40]) // Destination Address
totalLength := uint32(len(tcpHeader) + len(payload)) //nolint:gosec
binary.BigEndian.PutUint32(pseudoHeader[32:], totalLength)
pseudoHeader[39] = 6 // Next Header (TCP)
} else {
pseudoHeader = make([]byte, 12)
copy(pseudoHeader[0:4], ipHeader[12:16])
copy(pseudoHeader[4:8], ipHeader[16:20])
pseudoHeader[9] = 6
totalLength := uint16(len(tcpHeader) + len(payload)) //nolint:gosec
binary.BigEndian.PutUint16(pseudoHeader[10:], totalLength)
}
sum := uint32(0)
for _, slice := range [][]byte{pseudoHeader, tcpHeader, payload} {
for i := 0; i < len(slice)-1; i += 2 {
sum += uint32(binary.BigEndian.Uint16(slice[i : i+2]))
}
if len(slice)%2 != 0 {
sum += uint32(slice[len(slice)-1]) << 8
}
}
for (sum >> 16) > 0 {
sum = (sum & 0xFFFF) + (sum >> 16)
}
return ^uint16(sum) //nolint:gosec
}
const (
tcpFlagsOffset = 13
rstFlag byte = 0x04
synFlag byte = 0x02
ackFlag byte = 0x10
pshFlag byte = 0x08
)
type packetType uint8
const (
packetTypeSYN packetType = iota + 1
packetTypeSYNACK
packetTypeACK
packetTypeRST
)
func (p packetType) String() string {
switch p {
case packetTypeSYN:
return "SYN"
case packetTypeSYNACK:
return "SYN-ACK"
case packetTypeACK:
return "ACK"
case packetTypeRST:
return "RST"
default:
panic("unknown packet type")
}
}
var (
errTCPHeaderTooShort = errors.New("TCP header is too short")
errTCPPacketTypeUnknown = errors.New("TCP packet type is unknown")
)
// parseTCPHeader parses some elements from the TCP header.
func parseTCPHeader(header []byte) (packetType packetType, seq, ack uint32, err error) {
if len(header) < int(constants.BaseTCPHeaderLength) {
return 0, 0, 0, fmt.Errorf("%w: %d bytes", errTCPHeaderTooShort, len(header))
}
flags := header[tcpFlagsOffset]
switch {
case (flags&synFlag) != 0 && (flags&ackFlag) == 0:
packetType = packetTypeSYN
case (flags&synFlag) != 0 && (flags&ackFlag) != 0:
packetType = packetTypeSYNACK
case (flags & rstFlag) != 0:
packetType = packetTypeRST
case (flags & ackFlag) != 0:
packetType = packetTypeACK
default:
return 0, 0, 0, fmt.Errorf("%w: flags are 0x%02x", errTCPPacketTypeUnknown, flags)
}
seq = binary.BigEndian.Uint32(header[4:8])
ack = binary.BigEndian.Uint32(header[8:12])
return packetType, seq, ack, nil
}
+134
View File
@@ -0,0 +1,134 @@
package tcp
import (
"context"
"encoding/binary"
"errors"
"fmt"
"sync"
"syscall"
"time"
"github.com/qdm12/gluetun/internal/pmtud/constants"
)
type tracker struct {
fd fileDescriptor
ipv4 bool
mutex sync.RWMutex
portsToDispatch map[uint32]dispatch
}
type dispatch struct {
replyCh chan<- []byte
abort <-chan struct{}
}
func newTracker(fd fileDescriptor, ipv4 bool) *tracker {
return &tracker{
fd: fd,
ipv4: ipv4,
portsToDispatch: make(map[uint32]dispatch),
}
}
func (t *tracker) constructKey(localPort, remotePort uint16) uint32 {
buf := make([]byte, 4) //nolint:mnd
binary.BigEndian.PutUint16(buf[0:2], localPort)
binary.BigEndian.PutUint16(buf[2:4], remotePort)
return binary.BigEndian.Uint32(buf)
}
func (t *tracker) register(localPort, remotePort uint16,
ch chan<- []byte, abort <-chan struct{},
) {
key := t.constructKey(localPort, remotePort)
t.mutex.Lock()
defer t.mutex.Unlock()
t.portsToDispatch[key] = dispatch{
replyCh: ch,
abort: abort,
}
}
func (t *tracker) unregister(localPort, remotePort uint16) {
key := t.constructKey(localPort, remotePort)
t.mutex.Lock()
defer t.mutex.Unlock()
delete(t.portsToDispatch, key)
}
// listen listens for incoming TCP packets and dispatches them to the
// correct channel based on the source and destination port.
// If the context has a deadline associated, this one is used on the socket.
// Note it returns a nil error on context cancellation.
func (t *tracker) listen(ctx context.Context) error {
deadline, hasDeadline := ctx.Deadline()
for ctx.Err() == nil {
if hasDeadline {
remaining := time.Until(deadline)
if remaining <= 0 {
return nil
}
err := setSocketTimeout(t.fd, remaining)
if err != nil {
return fmt.Errorf("setting socket receive timeout: %w", err)
}
}
reply := make([]byte, constants.MaxEthernetFrameSize)
n, _, err := recvFrom(t.fd, reply, 0)
if err != nil {
switch {
case errors.Is(err, syscall.EAGAIN),
errors.Is(err, syscall.EWOULDBLOCK):
pollSleep(ctx)
continue
case ctx.Err() != nil:
// context canceled, stop listening so exit cleanly with no error
return nil //nolint:nilerr
default:
return fmt.Errorf("receiving on socket: %w", err)
}
}
reply = reply[:n]
if t.ipv4 {
var ok bool
reply, ok = stripIPv4Header(reply)
if !ok {
continue // not an IPv4 packet
}
}
const minTCPHeaderLength = 20
if len(reply) < minTCPHeaderLength {
continue
}
srcPort := binary.BigEndian.Uint16(reply[0:2])
dstPort := binary.BigEndian.Uint16(reply[2:4])
key := t.constructKey(dstPort, srcPort)
t.mutex.RLock()
dispatch, exists := t.portsToDispatch[key]
t.mutex.RUnlock()
if !exists {
continue
}
select {
case dispatch.replyCh <- reply:
case <-dispatch.abort:
}
}
return nil
}
func pollSleep(ctx context.Context) {
const sleepBetweenPolls = 10 * time.Millisecond
timer := time.NewTimer(sleepBetweenPolls)
select {
case <-ctx.Done():
timer.Stop()
case <-timer.C:
}
}
+36
View File
@@ -0,0 +1,36 @@
package test
import "math"
// MakeMTUsToTest determines a slice of MTU values to test
// between minMTU and maxMTU inclusive. It creates an MTU
// slice of length up to 11 MTUs such that:
// - the first element is the minMTU
// - the last element is the maxMTU
// - elements in-between are separated as close to each other
// The number 11 is chosen to find the final MTU in 3 searches,
// with a total search space of 1728 MTUs which is enough;
// to find it in 2 searches requires 37 parallel queries which
// could be blocked by firewalls.
func MakeMTUsToTest(minMTU, maxMTU uint32) (mtus []uint32) {
const mtusLength = 11 // find the final MTU in 3 searches
diff := maxMTU - minMTU
switch {
case minMTU > maxMTU:
panic("minMTU > maxMTU")
case diff <= mtusLength:
mtus = make([]uint32, 0, diff)
for mtu := minMTU; mtu <= maxMTU; mtu++ {
mtus = append(mtus, mtu)
}
default:
step := float64(diff) / float64(mtusLength-1)
mtus = make([]uint32, 0, mtusLength)
for mtu := float64(minMTU); len(mtus) < mtusLength-1; mtu += step {
mtus = append(mtus, uint32(math.Round(mtu)))
}
mtus = append(mtus, maxMTU) // last element is the maxMTU
}
return mtus
}
@@ -1,4 +1,4 @@
package pmtud
package test
import (
"testing"
@@ -6,7 +6,7 @@ import (
"github.com/stretchr/testify/assert"
)
func Test_makeMTUsToTest(t *testing.T) {
func Test_MakeMTUsToTest(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
@@ -48,7 +48,7 @@ func Test_makeMTUsToTest(t *testing.T) {
t.Run(name, func(t *testing.T) {
t.Parallel()
mtus := makeMTUsToTest(testCase.minMTU, testCase.maxMTU)
mtus := MakeMTUsToTest(testCase.minMTU, testCase.maxMTU)
assert.Equal(t, testCase.mtus, mtus)
})
}
+40
View File
@@ -0,0 +1,40 @@
package pmtud
import (
"net/netip"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/constants/vpn"
pconstants "github.com/qdm12/gluetun/internal/pmtud/constants"
)
// MaxTheoreticalVPNMTU returns the theoretical maximum MTU for a VPN tunnel
// given the VPN type, network protocol, and VPN gateway IP address.
// This is notably useful to skip testing MTU values higher than this value.
// The function panics if the network or VPN type is unknown.
func MaxTheoreticalVPNMTU(vpnType, network string, vpnGateway netip.Addr) uint32 {
const physicalLinkMTU = pconstants.MaxEthernetFrameSize
vpnLinkMTU := physicalLinkMTU
if vpnGateway.Is4() {
vpnLinkMTU -= pconstants.IPv4HeaderLength
} else {
vpnLinkMTU -= pconstants.IPv6HeaderLength
}
switch network {
case constants.TCP:
vpnLinkMTU -= pconstants.BaseTCPHeaderLength
case constants.UDP:
vpnLinkMTU -= pconstants.UDPHeaderLength
default:
panic("unknown network protocol: " + network)
}
switch vpnType {
case vpn.Wireguard:
vpnLinkMTU -= pconstants.WireguardHeaderLength
case vpn.OpenVPN:
vpnLinkMTU -= pconstants.OpenVPNHeaderMaxLength
default:
panic("unknown VPN type: " + vpnType)
}
return vpnLinkMTU
}
+11 -1
View File
@@ -16,7 +16,17 @@ func BuildWireguardSettings(connection models.Connection,
settings.PreSharedKey = *userSettings.PreSharedKey
settings.InterfaceName = userSettings.Interface
settings.Implementation = userSettings.Implementation
settings.MTU = userSettings.MTU
if *userSettings.MTU > 0 {
settings.MTU = *userSettings.MTU
} else {
// The default is 1320 which is NOT the wireguard-go default
// of 1420 because this impacts bandwidth a lot on some
// VPN providers, see https://github.com/qdm12/gluetun/issues/1650.
// It has been lowered to 1320 following quite a bit of
// investigation in the issue: https://github.com/qdm12/gluetun/issues/2533.
const defaultMTU = 1320
settings.MTU = defaultMTU
}
settings.IPv6 = &ipv6Supported
const rulePriority = 101 // 100 is to receive external connections
+3 -1
View File
@@ -22,7 +22,7 @@ func Test_BuildWireguardSettings(t *testing.T) {
ipv6Supported bool
settings wireguard.Settings
}{
"some settings": {
"some_settings": {
connection: models.Connection{
IP: netip.AddrFrom4([4]byte{1, 2, 3, 4}),
Port: 51821,
@@ -41,6 +41,7 @@ func Test_BuildWireguardSettings(t *testing.T) {
},
PersistentKeepaliveInterval: ptrTo(time.Hour),
Interface: "wg1",
MTU: ptrTo(uint32(1000)),
},
ipv6Supported: false,
settings: wireguard.Settings{
@@ -58,6 +59,7 @@ func Test_BuildWireguardSettings(t *testing.T) {
PersistentKeepaliveInterval: time.Hour,
RulePriority: 101,
IPv6: boolPtr(false),
MTU: 1000,
},
},
}
+23
View File
@@ -47,3 +47,26 @@ func (r *Routing) VPNLocalGatewayIP(vpnIntf string) (ip netip.Addr, err error) {
}
return ip, fmt.Errorf("%w: in %d routes", ErrVPNLocalGatewayIPNotFound, len(routes))
}
var ErrVPNRouteNotFound = errors.New("VPN route not found")
func (r *Routing) VPNRoute(vpnIntf string) (route netlink.Route, err error) {
vpnLink, err := r.netLinker.LinkByName(vpnIntf)
if err != nil {
return route, fmt.Errorf("finding link %s: %w", vpnIntf, err)
}
vpnLinkIndex := vpnLink.Index
routes, err := r.netLinker.RouteList(netlink.FamilyAll)
if err != nil {
return route, fmt.Errorf("listing routes: %w", err)
}
for _, route := range routes {
if route.LinkIndex == vpnLinkIndex &&
!route.Dst.IsValid() {
return route, nil
}
}
return route, fmt.Errorf("%w: for interface %s in %d routes",
ErrVPNRouteNotFound, vpnIntf, len(routes))
}
+2
View File
@@ -21,6 +21,7 @@ type Firewall interface {
type Routing interface {
VPNLocalGatewayIP(vpnInterface string) (gateway netip.Addr, err error)
VPNRoute(vpnIntf string) (route netlink.Route, err error)
}
type PortForward interface {
@@ -67,6 +68,7 @@ type NetLinker interface {
type Router interface {
RouteList(family uint8) (routes []netlink.Route, err error)
RouteAdd(route netlink.Route) error
RouteReplace(route netlink.Route) error
}
type Ruler interface {
+7 -1
View File
@@ -47,7 +47,13 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
continue
}
tunnelUpData := tunnelUpData{
vpnType: settings.Type,
pmtud: tunnelUpPMTUDData{
enabled: settings.Type != vpn.Wireguard || *settings.Wireguard.MTU == 0,
vpnType: settings.Type,
network: connection.Protocol,
icmpAddrs: settings.PMTUD.ICMPAddresses,
tcpAddrs: settings.PMTUD.TCPAddresses,
},
serverIP: connection.IP,
serverName: connection.ServerName,
canPortForward: connection.PortForward,
+64 -32
View File
@@ -2,7 +2,6 @@ package vpn
import (
"context"
"errors"
"fmt"
"net/netip"
"time"
@@ -10,6 +9,7 @@ import (
"github.com/qdm12/dns/v2/pkg/check"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/pmtud"
pconstants "github.com/qdm12/gluetun/internal/pmtud/constants"
"github.com/qdm12/gluetun/internal/version"
"github.com/qdm12/log"
)
@@ -17,9 +17,7 @@ import (
type tunnelUpData struct {
// Healthcheck
serverIP netip.Addr
// vpnType is used for path MTU discovery to find the protocol overhead.
// It can be "wireguard" or "openvpn".
vpnType string
pmtud tunnelUpPMTUDData
// Port forwarding
vpnIntf string
serverName string // used for PIA
@@ -29,6 +27,23 @@ type tunnelUpData struct {
portForwarder PortForwarder
}
type tunnelUpPMTUDData struct {
// enabled is notably false if the user specifies a custom MTU.
enabled bool
// vpnType is used to find the maximum VPN header overhead.
// It can be [vpn.Wireguard] or [vpn.OpenVPN].
vpnType string
// network is used to find the network level header overhead.
// It can be [constants.UDP] or [constants.TCP].
network string
// icmpAddrs is the list of addresses to use for ICMP path MTU discovery.
// Each address should handle ICMP packets for PMTUD to work.
icmpAddrs []netip.Addr
// tcpAddrs is the list of addresses to use for TCP path MTU discovery.
// Each address should have a listening TCP server on the port specified.
tcpAddrs []netip.AddrPort
}
func (l *Loop) onTunnelUp(ctx, loopCtx context.Context, data tunnelUpData) {
l.client.CloseIdleConnections()
@@ -39,11 +54,14 @@ func (l *Loop) onTunnelUp(ctx, loopCtx context.Context, data tunnelUpData) {
}
}
mtuLogger := l.logger.New(log.SetComponent("MTU discovery"))
err := updateToMaxMTU(ctx, data.vpnIntf, data.vpnType,
l.netLinker, l.routing, mtuLogger)
if err != nil {
mtuLogger.Error(err.Error())
if data.pmtud.enabled {
mtuLogger := l.logger.New(log.SetComponent("MTU discovery"))
err := updateToMaxMTU(ctx, data.vpnIntf, data.pmtud.vpnType,
data.pmtud.network, data.pmtud.icmpAddrs, data.pmtud.tcpAddrs,
l.netLinker, l.routing, mtuLogger)
if err != nil {
mtuLogger.Error(err.Error())
}
}
icmpTargetIPs := l.healthSettings.ICMPTargetIPs
@@ -136,12 +154,11 @@ func (l *Loop) restartVPN(ctx context.Context, healthErr error) {
_, _ = l.ApplyStatus(ctx, constants.Running)
}
var errVPNTypeUnknown = errors.New("unknown VPN type")
func updateToMaxMTU(ctx context.Context, vpnInterface string,
vpnType string, netlinker NetLinker, routing Routing, logger *log.Logger,
vpnType, network string, icmpAddrs []netip.Addr, tcpAddrs []netip.AddrPort,
netlinker NetLinker, routing Routing, logger *log.Logger,
) error {
logger.Info("finding maximum MTU, this can take up to 4 seconds")
logger.Info("finding maximum MTU, this can take up to 6 seconds")
vpnGatewayIP, err := routing.VPNLocalGatewayIP(vpnInterface)
if err != nil {
@@ -155,18 +172,7 @@ func updateToMaxMTU(ctx context.Context, vpnInterface string,
originalMTU := link.MTU
// Note: no point testing for an MTU of 1500, it will never work due to the VPN
// protocol overhead, so start lower than 1500 according to the protocol used.
const physicalLinkMTU uint32 = 1500
vpnLinkMTU := physicalLinkMTU
switch vpnType {
case "wireguard":
vpnLinkMTU -= 60 // Wireguard overhead
case "openvpn":
vpnLinkMTU -= 41 // OpenVPN overhead
default:
return fmt.Errorf("%w: %q", errVPNTypeUnknown, vpnType)
}
vpnLinkMTU := pmtud.MaxTheoreticalVPNMTU(vpnType, network, vpnGatewayIP)
// Setting the VPN link MTU to 1500 might interrupt the connection until
// the new MTU is set again, but this is necessary to find the highest valid MTU.
@@ -178,16 +184,14 @@ func updateToMaxMTU(ctx context.Context, vpnInterface string,
}
const pingTimeout = time.Second
vpnLinkMTU, err = pmtud.PathMTUDiscover(ctx, vpnGatewayIP, vpnLinkMTU, pingTimeout, logger)
switch {
case err == nil:
logger.Infof("setting VPN interface %s MTU to maximum valid MTU %d", vpnInterface, vpnLinkMTU)
case errors.Is(err, pmtud.ErrMTUNotFound) || errors.Is(err, pmtud.ErrICMPNotPermitted):
vpnLinkMTU, err = pmtud.PathMTUDiscover(ctx, icmpAddrs, tcpAddrs,
vpnLinkMTU, pingTimeout, logger)
if err != nil {
vpnLinkMTU = originalMTU
logger.Infof("reverting VPN interface %s MTU to %d (due to: %s)",
vpnInterface, originalMTU, err)
default:
return fmt.Errorf("path MTU discovering: %w", err)
} else {
logger.Infof("setting VPN interface %s MTU to maximum valid MTU %d", vpnInterface, vpnLinkMTU)
}
err = netlinker.LinkSetMTU(link.Index, vpnLinkMTU)
@@ -195,5 +199,33 @@ func updateToMaxMTU(ctx context.Context, vpnInterface string,
return fmt.Errorf("setting VPN interface %s MTU to %d: %w", vpnInterface, vpnLinkMTU, err)
}
err = setTCPMSSOnVPNRoute(vpnInterface, vpnLinkMTU, routing, netlinker)
if err != nil {
return fmt.Errorf("setting safe TCP MSS for MTU %d: %w", vpnLinkMTU, err)
}
return nil
}
func setTCPMSSOnVPNRoute(vpnIntf string, mtu uint32,
routing Routing, netlinker NetLinker,
) error {
route, err := routing.VPNRoute(vpnIntf)
if err != nil {
return fmt.Errorf("getting VPN route: %w", err)
}
ipHeaderLength := pconstants.IPv4HeaderLength
if route.Dst.Addr().Is6() {
ipHeaderLength = pconstants.IPv6HeaderLength
}
const mysteriousOverhead = 20 // most likely TCP options, such as the 12B of timestamps
overhead := ipHeaderLength + pconstants.BaseTCPHeaderLength + mysteriousOverhead
mss := mtu - overhead
route.AdvMSS = mss
err = netlinker.RouteReplace(route)
if err != nil {
return fmt.Errorf("replacing VPN route with MSS changed to %d: %w", mss, err)
}
return nil
}