diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index df017012..08568d12 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -581,6 +581,7 @@ type Linker interface { LinkDel(link netlink.Link) (err error) LinkSetUp(link netlink.Link) (linkIndex int, err error) LinkSetDown(link netlink.Link) (err error) + LinkSetMTU(link netlink.Link, mtu int) error } type clier interface { diff --git a/internal/configuration/settings/wireguard.go b/internal/configuration/settings/wireguard.go index c861e8e5..3edf133d 100644 --- a/internal/configuration/settings/wireguard.go +++ b/internal/configuration/settings/wireguard.go @@ -45,6 +45,7 @@ type Wireguard struct { // It has been lowered to 1320 following quite a bit of // investigation in the issue: // https://github.com/qdm12/gluetun/issues/2533. + // Note this should now be replaced with the PMTUD feature. MTU uint16 `json:"mtu"` // Implementation is the Wireguard implementation to use. // It can be "auto", "userspace" or "kernelspace". diff --git a/internal/netlink/link.go b/internal/netlink/link.go index d810e47e..b2c96134 100644 --- a/internal/netlink/link.go +++ b/internal/netlink/link.go @@ -62,6 +62,10 @@ func (n *NetLink) LinkSetDown(link Link) (err error) { return netlink.LinkSetDown(linkToNetlinkLink(&link)) } +func (n *NetLink) LinkSetMTU(link Link, mtu int) error { + return netlink.LinkSetMTU(linkToNetlinkLink(&link), mtu) +} + type netlinkLinkImpl struct { attrs *netlink.LinkAttrs linkType string diff --git a/internal/pmtud/apple_ipv4.go b/internal/pmtud/apple_ipv4.go new file mode 100644 index 00000000..6b298d79 --- /dev/null +++ b/internal/pmtud/apple_ipv4.go @@ -0,0 +1,49 @@ +package pmtud + +import ( + "net" + "time" + + "golang.org/x/net/ipv4" +) + +var _ net.PacketConn = &ipv4Wrapper{} + +// ipv4Wrapper is a wrapper around ipv4.PacketConn to implement +// the net.PacketConn interface. It's only used for Darwin or iOS. +type ipv4Wrapper struct { + ipv4Conn *ipv4.PacketConn +} + +func ipv4ToNetPacketConn(ipv4 *ipv4.PacketConn) *ipv4Wrapper { + return &ipv4Wrapper{ipv4Conn: ipv4} +} + +func (i *ipv4Wrapper) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + n, _, addr, err = i.ipv4Conn.ReadFrom(p) + return n, addr, err +} + +func (i *ipv4Wrapper) WriteTo(p []byte, addr net.Addr) (n int, err error) { + return i.ipv4Conn.WriteTo(p, nil, addr) +} + +func (i *ipv4Wrapper) Close() error { + return i.ipv4Conn.Close() +} + +func (i *ipv4Wrapper) LocalAddr() net.Addr { + return i.ipv4Conn.LocalAddr() +} + +func (i *ipv4Wrapper) SetDeadline(t time.Time) error { + return i.ipv4Conn.SetDeadline(t) +} + +func (i *ipv4Wrapper) SetReadDeadline(t time.Time) error { + return i.ipv4Conn.SetReadDeadline(t) +} + +func (i *ipv4Wrapper) SetWriteDeadline(t time.Time) error { + return i.ipv4Conn.SetWriteDeadline(t) +} diff --git a/internal/pmtud/check.go b/internal/pmtud/check.go new file mode 100644 index 00000000..71f8ff1f --- /dev/null +++ b/internal/pmtud/check.go @@ -0,0 +1,83 @@ +package pmtud + +import ( + "bytes" + "errors" + "fmt" + + "golang.org/x/net/icmp" +) + +var ( + ErrICMPNextHopMTUTooLow = errors.New("ICMP Next Hop MTU is too low") + ErrICMPNextHopMTUTooHigh = errors.New("ICMP Next Hop MTU is too high") +) + +func checkMTU(mtu, minMTU, physicalLinkMTU int) (err error) { + switch { + case mtu < minMTU: + return fmt.Errorf("%w: %d", ErrICMPNextHopMTUTooLow, mtu) + case mtu > physicalLinkMTU: + return fmt.Errorf("%w: %d is larger than physical link MTU %d", + ErrICMPNextHopMTUTooHigh, mtu, physicalLinkMTU) + default: + return nil + } +} + +func checkInvokingReplyIDMatch(icmpProtocol int, received []byte, + outboundMessage *icmp.Message, +) (match bool, err error) { + inboundMessage, err := icmp.ParseMessage(icmpProtocol, received) + if err != nil { + return false, fmt.Errorf("parsing invoking packet: %w", err) + } + inboundBody, ok := inboundMessage.Body.(*icmp.Echo) + if !ok { + return false, fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, inboundMessage.Body) + } + outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert + return inboundBody.ID == outboundBody.ID, nil +} + +var ErrICMPIDMismatch = errors.New("ICMP id mismatch") + +func checkEchoReply(icmpProtocol int, received []byte, + outboundMessage *icmp.Message, truncatedBody bool, +) (err error) { + inboundMessage, err := icmp.ParseMessage(icmpProtocol, received) + if err != nil { + return fmt.Errorf("parsing invoking packet: %w", err) + } + inboundBody, ok := inboundMessage.Body.(*icmp.Echo) + if !ok { + return fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, inboundMessage.Body) + } + outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert + if inboundBody.ID != outboundBody.ID { + return fmt.Errorf("%w: sent id %d and received id %d", + ErrICMPIDMismatch, outboundBody.ID, inboundBody.ID) + } + err = checkEchoBodies(outboundBody.Data, inboundBody.Data, truncatedBody) + if err != nil { + return fmt.Errorf("checking sent and received bodies: %w", err) + } + return nil +} + +var ErrICMPEchoDataMismatch = errors.New("ICMP data mismatch") + +func checkEchoBodies(sent, received []byte, receivedTruncated bool) (err error) { + if len(received) > len(sent) { + return fmt.Errorf("%w: sent %d bytes and received %d bytes", + ErrICMPEchoDataMismatch, len(sent), len(received)) + } + if receivedTruncated { + sent = sent[:len(received)] + } + if !bytes.Equal(received, sent) { + return fmt.Errorf("%w: sent %x and received %x", + ErrICMPEchoDataMismatch, sent, received) + } + return nil +} diff --git a/internal/pmtud/df.go b/internal/pmtud/df.go new file mode 100644 index 00000000..9e6ee59d --- /dev/null +++ b/internal/pmtud/df.go @@ -0,0 +1,10 @@ +//go:build !linux && !windows + +package pmtud + +// setDontFragment for platforms other than Linux and Windows +// is not implemented, so we just return assuming the don't +// fragment flag is set on IP packets. +func setDontFragment(fd uintptr) (err error) { + return nil +} diff --git a/internal/pmtud/df_linux.go b/internal/pmtud/df_linux.go new file mode 100644 index 00000000..facf09f1 --- /dev/null +++ b/internal/pmtud/df_linux.go @@ -0,0 +1,12 @@ +//go:build linux + +package pmtud + +import ( + "syscall" +) + +func setDontFragment(fd uintptr) (err error) { + return syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, + syscall.IP_MTU_DISCOVER, syscall.IP_PMTUDISC_PROBE) +} diff --git a/internal/pmtud/df_windows.go b/internal/pmtud/df_windows.go new file mode 100644 index 00000000..a8c98fc4 --- /dev/null +++ b/internal/pmtud/df_windows.go @@ -0,0 +1,13 @@ +//go:build windows + +package pmtud + +import ( + "syscall" +) + +func setDontFragment(fd uintptr) (err error) { + // https://docs.microsoft.com/en-us/troubleshoot/windows/win32/header-library-requirement-socket-ipproto-ip + // #define IP_DONTFRAGMENT 14 /* don't fragment IP datagrams */ + return syscall.SetsockoptInt(syscall.Handle(fd), syscall.IPPROTO_IP, 14, 1) +} diff --git a/internal/pmtud/errors.go b/internal/pmtud/errors.go new file mode 100644 index 00000000..5f6eaa41 --- /dev/null +++ b/internal/pmtud/errors.go @@ -0,0 +1,29 @@ +package pmtud + +import ( + "context" + "errors" + "fmt" + "net" + "strings" + "time" +) + +var ( + ErrICMPNotPermitted = errors.New("ICMP not permitted") + ErrICMPDestinationUnreachable = errors.New("ICMP destination unreachable") + ErrICMPCommunicationAdministrativelyProhibited = errors.New("communication administratively prohibited") + ErrICMPBodyUnsupported = errors.New("ICMP body type is not supported") +) + +func wrapConnErr(err error, timedCtx context.Context, pingTimeout time.Duration) error { //nolint:revive + switch { + case strings.HasSuffix(err.Error(), "sendto: operation not permitted"): + err = fmt.Errorf("%w", ErrICMPNotPermitted) + case errors.Is(timedCtx.Err(), context.DeadlineExceeded): + err = fmt.Errorf("%w (timed out after %s)", net.ErrClosed, pingTimeout) + case timedCtx.Err() != nil: + err = timedCtx.Err() + } + return err +} diff --git a/internal/pmtud/interfaces.go b/internal/pmtud/interfaces.go new file mode 100644 index 00000000..19146e9c --- /dev/null +++ b/internal/pmtud/interfaces.go @@ -0,0 +1,7 @@ +package pmtud + +type Logger interface { + Debug(msg string) + Debugf(msg string, args ...any) + Warnf(msg string, args ...any) +} diff --git a/internal/pmtud/ipv4.go b/internal/pmtud/ipv4.go new file mode 100644 index 00000000..e835a288 --- /dev/null +++ b/internal/pmtud/ipv4.go @@ -0,0 +1,159 @@ +package pmtud + +import ( + "context" + "encoding/binary" + "fmt" + "net" + "net/netip" + "runtime" + "strings" + "syscall" + "time" + + "golang.org/x/net/icmp" + "golang.org/x/net/ipv4" +) + +const ( + // see https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media + minIPv4MTU = 68 + icmpv4Protocol = 1 +) + +func listenICMPv4(ctx context.Context) (conn net.PacketConn, err error) { + var listenConfig net.ListenConfig + listenConfig.Control = func(_, _ string, rawConn syscall.RawConn) error { + var setDFErr error + err := rawConn.Control(func(fd uintptr) { + setDFErr = setDontFragment(fd) // runs when calling ListenPacket + }) + if err == nil { + err = setDFErr + } + return err + } + + const listenAddress = "" + packetConn, err := listenConfig.ListenPacket(ctx, "ip4:icmp", listenAddress) + if err != nil { + if strings.HasSuffix(err.Error(), "socket: operation not permitted") { + err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrICMPNotPermitted) + } + return nil, err + } + + if runtime.GOOS == "darwin" || runtime.GOOS == "ios" { + packetConn = ipv4ToNetPacketConn(ipv4.NewPacketConn(packetConn)) + } + + return packetConn, nil +} + +func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr, + physicalLinkMTU int, pingTimeout time.Duration, logger Logger, +) (mtu int, err error) { + if ip.Is6() { + panic("IP address is not v4") + } + conn, err := listenICMPv4(ctx) + if err != nil { + return 0, fmt.Errorf("listening for ICMP packets: %w", err) + } + ctx, cancel := context.WithTimeout(ctx, pingTimeout) + defer cancel() + go func() { + <-ctx.Done() + conn.Close() + }() + + // First try to send a packet which is too big to get the maximum MTU + // directly. + outboundID, outboundMessage := buildMessageToSend("v4", physicalLinkMTU) + encodedMessage, err := outboundMessage.Marshal(nil) + if err != nil { + return 0, fmt.Errorf("encoding ICMP message: %w", err) + } + + _, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()}) + if err != nil { + err = wrapConnErr(err, ctx, pingTimeout) + return 0, fmt.Errorf("writing ICMP message: %w", err) + } + + buffer := make([]byte, physicalLinkMTU) + + for { // for loop in case we read an echo reply for another ICMP request + // Note we need to read the whole packet in one call to ReadFrom, so the buffer + // must be large enough to read the entire reply packet. See: + // https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J + bytesRead, _, err := conn.ReadFrom(buffer) + if err != nil { + err = wrapConnErr(err, ctx, pingTimeout) + return 0, fmt.Errorf("reading from ICMP connection: %w", err) + } + packetBytes := buffer[:bytesRead] + // Side note: echo reply should be at most the number of bytes + // sent, and can be lower, more precisely 576-ipHeader bytes, + // in case the next hop we are reaching replies with a destination + // unreachable and wants to ensure the response makes it way back + // by keeping a low packet size, see: + // https://datatracker.ietf.org/doc/html/rfc1122#page-59 + + inboundMessage, err := icmp.ParseMessage(icmpv4Protocol, packetBytes) + if err != nil { + return 0, fmt.Errorf("parsing message: %w", err) + } + + switch typedBody := inboundMessage.Body.(type) { + case *icmp.DstUnreach: + const fragmentationRequiredAndDFFlagSetCode = 4 + const communicationAdministrativelyProhibitedCode = 13 + switch inboundMessage.Code { + case fragmentationRequiredAndDFFlagSetCode: + case communicationAdministrativelyProhibitedCode: + return 0, fmt.Errorf("%w: %w (code %d)", + ErrICMPDestinationUnreachable, + ErrICMPCommunicationAdministrativelyProhibited, + inboundMessage.Code) + default: + return 0, fmt.Errorf("%w: code %d", + ErrICMPDestinationUnreachable, inboundMessage.Code) + } + + // See https://datatracker.ietf.org/doc/html/rfc1191#section-4 + // Note: the go library does not handle this NextHopMTU section. + nextHopMTU := packetBytes[6:8] + mtu = int(binary.BigEndian.Uint16(nextHopMTU)) + err = checkMTU(mtu, minIPv4MTU, physicalLinkMTU) + if err != nil { + return 0, fmt.Errorf("checking next-hop-mtu found: %w", err) + } + + // The code below is really for sanity checks + packetBytes = packetBytes[8:] + header, err := ipv4.ParseHeader(packetBytes) + if err != nil { + return 0, fmt.Errorf("parsing IPv4 header: %w", err) + } + packetBytes = packetBytes[header.Len:] // truncated original datagram + + const truncated = true + err = checkEchoReply(icmpv4Protocol, packetBytes, outboundMessage, truncated) + if err != nil { + return 0, fmt.Errorf("checking echo reply: %w", err) + } + return mtu, nil + case *icmp.Echo: + inboundID := uint16(typedBody.ID) //nolint:gosec + if inboundID == outboundID { + return physicalLinkMTU, nil + } + logger.Debugf("discarding received ICMP echo reply with id %d mismatching sent id %d", + inboundID, outboundID) + continue + default: + return 0, fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, typedBody) + } + } +} diff --git a/internal/pmtud/ipv6.go b/internal/pmtud/ipv6.go new file mode 100644 index 00000000..eeafe4d9 --- /dev/null +++ b/internal/pmtud/ipv6.go @@ -0,0 +1,122 @@ +package pmtud + +import ( + "context" + "fmt" + "net" + "net/netip" + "strings" + "time" + + "golang.org/x/net/icmp" + "golang.org/x/net/ipv6" +) + +const ( + minIPv6MTU = 1280 + icmpv6Protocol = 58 +) + +func listenICMPv6(ctx context.Context) (conn net.PacketConn, err error) { + var listenConfig net.ListenConfig + const listenAddress = "" + packetConn, err := listenConfig.ListenPacket(ctx, "ip6:ipv6-icmp", listenAddress) + if err != nil { + if strings.HasSuffix(err.Error(), "socket: operation not permitted") { + err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrICMPNotPermitted) + } + return nil, err + } + return packetConn, nil +} + +func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr, + physicalLinkMTU int, pingTimeout time.Duration, logger Logger, +) (mtu int, err error) { + if ip.Is4() { + panic("IP address is not v6") + } + conn, err := listenICMPv6(ctx) + if err != nil { + return 0, fmt.Errorf("listening for ICMP packets: %w", err) + } + ctx, cancel := context.WithTimeout(ctx, pingTimeout) + defer cancel() + go func() { + <-ctx.Done() + conn.Close() + }() + + // First try to send a packet which is too big to get the maximum MTU + // directly. + outboundID, outboundMessage := buildMessageToSend("v6", physicalLinkMTU) + encodedMessage, err := outboundMessage.Marshal(nil) + if err != nil { + return 0, fmt.Errorf("encoding ICMP message: %w", err) + } + + _, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice(), Zone: ip.Zone()}) + if err != nil { + err = wrapConnErr(err, ctx, pingTimeout) + return 0, fmt.Errorf("writing ICMP message: %w", err) + } + + buffer := make([]byte, physicalLinkMTU) + + for { // for loop if we encounter another ICMP packet with an unknown id. + // Note we need to read the whole packet in one call to ReadFrom, so the buffer + // must be large enough to read the entire reply packet. See: + // https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J + bytesRead, _, err := conn.ReadFrom(buffer) + if err != nil { + err = wrapConnErr(err, ctx, pingTimeout) + return 0, fmt.Errorf("reading from ICMP connection: %w", err) + } + packetBytes := buffer[:bytesRead] + + packetBytes = packetBytes[ipv6.HeaderLen:] + + inboundMessage, err := icmp.ParseMessage(icmpv6Protocol, packetBytes) + if err != nil { + return 0, fmt.Errorf("parsing message: %w", err) + } + + switch typedBody := inboundMessage.Body.(type) { + case *icmp.PacketTooBig: + // https://datatracker.ietf.org/doc/html/rfc1885#section-3.2 + mtu = typedBody.MTU + err = checkMTU(mtu, minIPv6MTU, physicalLinkMTU) + if err != nil { + return 0, fmt.Errorf("checking MTU: %w", err) + } + + // Sanity checks + const truncatedBody = true + err = checkEchoReply(icmpv6Protocol, typedBody.Data, outboundMessage, truncatedBody) + if err != nil { + return 0, fmt.Errorf("checking invoking message: %w", err) + } + return typedBody.MTU, nil + case *icmp.DstUnreach: + // https://datatracker.ietf.org/doc/html/rfc1885#section-3.1 + idMatch, err := checkInvokingReplyIDMatch(icmpv6Protocol, packetBytes, outboundMessage) + if err != nil { + return 0, fmt.Errorf("checking invoking message id: %w", err) + } else if idMatch { + return 0, fmt.Errorf("%w", ErrICMPDestinationUnreachable) + } + logger.Debug("discarding received ICMP destination unreachable reply with an unknown id") + continue + case *icmp.Echo: + inboundID := uint16(typedBody.ID) //nolint:gosec + if inboundID == outboundID { + return physicalLinkMTU, nil + } + logger.Debugf("discarding received ICMP echo reply with id %d mismatching sent id %d", + inboundID, outboundID) + continue + default: + return 0, fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, typedBody) + } + } +} diff --git a/internal/pmtud/message.go b/internal/pmtud/message.go new file mode 100644 index 00000000..f04c7a89 --- /dev/null +++ b/internal/pmtud/message.go @@ -0,0 +1,58 @@ +package pmtud + +import ( + cryptorand "crypto/rand" + "encoding/binary" + "fmt" + "math/rand/v2" + + "golang.org/x/net/icmp" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" +) + +func buildMessageToSend(ipVersion string, mtu int) (id uint16, message *icmp.Message) { + var seed [32]byte + _, _ = cryptorand.Read(seed[:]) + randomSource := rand.NewChaCha8(seed) + + const uint16Bytes = 2 + idBytes := make([]byte, uint16Bytes) + _, _ = randomSource.Read(idBytes) + id = binary.BigEndian.Uint16(idBytes) + + var ipHeaderLength int + var icmpType icmp.Type + switch ipVersion { + case "v4": + ipHeaderLength = ipv4.HeaderLen + icmpType = ipv4.ICMPTypeEcho + case "v6": + ipHeaderLength = ipv6.HeaderLen + icmpType = ipv6.ICMPTypeEchoRequest + default: + panic(fmt.Sprintf("IP version %q not supported", ipVersion)) + } + const pingHeaderLength = 0 + + 1 + // type + 1 + // code + 2 + // checksum + 2 + // identifier + 2 // sequence number + pingBodyDataSize := mtu - ipHeaderLength - pingHeaderLength + messageBodyData := make([]byte, pingBodyDataSize) + _, _ = randomSource.Read(messageBodyData) + + // See https://www.iana.org/assignments/icmp-parameters/icmp-parameters.xhtml#icmp-parameters-types + message = &icmp.Message{ + Type: icmpType, // echo request + Code: 0, // no code + Checksum: 0, // calculated at encoding (ipv4) or sending (ipv6) + Body: &icmp.Echo{ + ID: int(id), + Seq: 0, // only one packet + Data: messageBodyData, + }, + } + return id, message +} diff --git a/internal/pmtud/nooplogger.go b/internal/pmtud/nooplogger.go new file mode 100644 index 00000000..0ab3debf --- /dev/null +++ b/internal/pmtud/nooplogger.go @@ -0,0 +1,7 @@ +package pmtud + +type noopLogger struct{} + +func (noopLogger) Debug(_ string) {} +func (noopLogger) Debugf(_ string, _ ...any) {} +func (noopLogger) Warnf(_ string, _ ...any) {} diff --git a/internal/pmtud/pmtud.go b/internal/pmtud/pmtud.go new file mode 100644 index 00000000..e9cab450 --- /dev/null +++ b/internal/pmtud/pmtud.go @@ -0,0 +1,271 @@ +package pmtud + +import ( + "context" + "errors" + "fmt" + "math" + "net" + "net/netip" + "strings" + "time" + + "golang.org/x/net/icmp" +) + +var ErrMTUNotFound = errors.New("path MTU discovery failed to find MTU") + +// PathMTUDiscover discovers the maximum MTU for the path to the given ip address. +// If the physicalLinkMTU is zero, it defaults to 1500 which is the ethernet standard MTU. +// If the pingTimeout is zero, it defaults to 1 second. +// If the logger is nil, a no-op logger is used. +// It returns [ErrMTUNotFound] if the MTU could not be determined. +func PathMTUDiscover(ctx context.Context, ip netip.Addr, + physicalLinkMTU int, pingTimeout time.Duration, logger Logger) ( + mtu int, err error, +) { + if physicalLinkMTU == 0 { + const ethernetStandardMTU = 1500 + physicalLinkMTU = ethernetStandardMTU + } + if pingTimeout == 0 { + pingTimeout = time.Second + } + if logger == nil { + logger = &noopLogger{} + } + + if ip.Is4() { + logger.Debug("finding IPv4 next hop MTU") + mtu, err = findIPv4NextHopMTU(ctx, ip, physicalLinkMTU, pingTimeout, logger) + switch { + case err == nil: + return mtu, nil + case errors.Is(err, net.ErrClosed) || errors.Is(err, ErrICMPCommunicationAdministrativelyProhibited): // blackhole + default: + return 0, fmt.Errorf("finding IPv4 next hop MTU: %w", err) + } + } else { + logger.Debug("requesting IPv6 ICMP packet-too-big reply") + mtu, err = getIPv6PacketTooBig(ctx, ip, physicalLinkMTU, pingTimeout, logger) + switch { + case err == nil: + return mtu, nil + case errors.Is(err, net.ErrClosed): // blackhole + default: + return 0, fmt.Errorf("getting IPv6 packet-too-big message: %w", err) + } + } + + // Fall back method: send echo requests with different packet + // sizes and check which ones succeed to find the maximum MTU. + logger.Debug("falling back to sending different sized echo packets") + minMTU := minIPv4MTU + if ip.Is6() { + minMTU = minIPv6MTU + } + return pmtudMultiSizes(ctx, ip, minMTU, physicalLinkMTU, pingTimeout, logger) +} + +type pmtudTestUnit struct { + mtu int + echoID uint16 + sentBytes int + ok bool +} + +func pmtudMultiSizes(ctx context.Context, ip netip.Addr, + minMTU, maxPossibleMTU int, pingTimeout time.Duration, + logger Logger, +) (maxMTU int, err error) { + var ipVersion string + var conn net.PacketConn + if ip.Is4() { + ipVersion = "v4" + conn, err = listenICMPv4(ctx) + } else { + ipVersion = "v6" + conn, err = listenICMPv6(ctx) + } + if err != nil { + if strings.HasSuffix(err.Error(), "socket: operation not permitted") { + err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrICMPNotPermitted) + } + return 0, fmt.Errorf("listening for ICMP packets: %w", err) + } + + mtusToTest := makeMTUsToTest(minMTU, maxPossibleMTU) + if len(mtusToTest) == 1 { // only minMTU because minMTU == maxPossibleMTU + return minMTU, nil + } + logger.Debugf("testing the following MTUs: %v", mtusToTest) + + tests := make([]pmtudTestUnit, len(mtusToTest)) + for i := range mtusToTest { + tests[i] = pmtudTestUnit{mtu: mtusToTest[i]} + } + + timedCtx, cancel := context.WithTimeout(ctx, pingTimeout) + defer cancel() + go func() { + <-timedCtx.Done() + conn.Close() + }() + + for i := range tests { + id, message := buildMessageToSend(ipVersion, tests[i].mtu) + tests[i].echoID = id + + encodedMessage, err := message.Marshal(nil) + if err != nil { + return 0, fmt.Errorf("encoding ICMP message: %w", err) + } + tests[i].sentBytes = len(encodedMessage) + + _, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()}) + if err != nil { + if strings.HasSuffix(err.Error(), "sendto: operation not permitted") { + err = fmt.Errorf("%w", ErrICMPNotPermitted) + } + return 0, fmt.Errorf("writing ICMP message: %w", err) + } + } + + err = collectReplies(conn, ipVersion, tests, logger) + switch { + case err == nil: // max possible MTU is working + return tests[len(tests)-1].mtu, nil + case err != nil && errors.Is(err, net.ErrClosed): + // we have timeouts (IPv4 testing or IPv6 PMTUD blackholes) + // so find the highest MTU which worked. + // Note we start from index len(tests) - 2 since the max MTU + // cannot be working if we had a timeout. + for i := len(tests) - 2; i >= 0; i-- { //nolint:mnd + if tests[i].ok { + return pmtudMultiSizes(ctx, ip, tests[i].mtu, tests[i+1].mtu-1, + pingTimeout, logger) + } + } + + // All MTUs failed. + return 0, fmt.Errorf("%w: ICMP might be blocked", ErrMTUNotFound) + case err != nil: + return 0, fmt.Errorf("collecting ICMP echo replies: %w", err) + default: + panic("unreachable") + } +} + +// Create the MTU slice of length 11 such that: +// - the first element is the minMTU +// - the last element is the maxMTU +// - elements in-between are separated as close to each other +// The number 11 is chosen to find the final MTU in 3 searches, +// with a total search space of 1728 MTUs which is enough; +// to find it in 2 searches requires 37 parallel queries which +// could be blocked by firewalls. +func makeMTUsToTest(minMTU, maxMTU int) (mtus []int) { + const mtusLength = 11 // find the final MTU in 3 searches + diff := maxMTU - minMTU + switch { + case minMTU > maxMTU: + panic("minMTU > maxMTU") + case diff <= mtusLength: + mtus = make([]int, 0, diff) + for mtu := minMTU; mtu <= maxMTU; mtu++ { + mtus = append(mtus, mtu) + } + default: + step := float64(diff) / float64(mtusLength-1) + mtus = make([]int, 0, mtusLength) + for mtu := float64(minMTU); len(mtus) < mtusLength-1; mtu += step { + mtus = append(mtus, int(math.Round(mtu))) + } + mtus = append(mtus, maxMTU) // last element is the maxMTU + } + + return mtus +} + +func collectReplies(conn net.PacketConn, ipVersion string, + tests []pmtudTestUnit, logger Logger, +) (err error) { + echoIDToTestIndex := make(map[uint16]int, len(tests)) + for i, test := range tests { + echoIDToTestIndex[test.echoID] = i + } + + // The theoretical limit is 4GiB for IPv6 MTU path discovery jumbograms, but that would + // create huge buffers which we don't really want to support anyway. + // The standard frame maximum MTU is 1500 bytes, and there are Jumbo frames with + // a conventional maximum of 9000 bytes. However, some manufacturers support up + // 9216-20 = 9196 bytes for the maximum MTU. We thus use buffers of size 9196 to + // match eventual Jumbo frames. More information at: + // https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media + const maxPossibleMTU = 9196 + buffer := make([]byte, maxPossibleMTU) + + idsFound := 0 + for idsFound < len(tests) { + // Note we need to read the whole packet in one call to ReadFrom, so the buffer + // must be large enough to read the entire reply packet. See: + // https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J + bytesRead, _, err := conn.ReadFrom(buffer) + if err != nil { + return fmt.Errorf("reading from ICMP connection: %w", err) + } + packetBytes := buffer[:bytesRead] + + ipPacketLength := len(packetBytes) + + var icmpProtocol int + switch ipVersion { + case "v4": + icmpProtocol = icmpv4Protocol + case "v6": + icmpProtocol = icmpv6Protocol + default: + panic(fmt.Sprintf("unknown IP version: %s", ipVersion)) + } + + // Parse the ICMP message + // Note: this parsing works for a truncated 556 bytes ICMP reply packet. + message, err := icmp.ParseMessage(icmpProtocol, packetBytes) + if err != nil { + return fmt.Errorf("parsing message: %w", err) + } + + echoBody, ok := message.Body.(*icmp.Echo) + if !ok { + return fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, message.Body) + } + + id := uint16(echoBody.ID) //nolint:gosec + testIndex, testing := echoIDToTestIndex[id] + if !testing { // not an id we expected so ignore it + logger.Warnf("ignoring ICMP reply with unexpected ID %d (type: %d, code: %d, length: %d)", + echoBody.ID, message.Type, message.Code, ipPacketLength) + continue + } + idsFound++ + sentBytes := tests[testIndex].sentBytes + + // echo reply should be at most the number of bytes sent, + // and can be lower, more precisely 556 bytes, in case + // the host we are reaching wants to stay out of trouble + // and ensure its echo reply goes through without + // fragmentation, see the following page: + // https://datatracker.ietf.org/doc/html/rfc1122#page-59 + const conservativeReplyLength = 556 + truncated := ipPacketLength < sentBytes && + ipPacketLength == conservativeReplyLength + // Check the packet size is the same if the reply is not truncated + if !truncated && sentBytes != ipPacketLength { + return fmt.Errorf("%w: sent %dB and received %dB", + ErrICMPEchoDataMismatch, sentBytes, ipPacketLength) + } + // Truncated reply or matching reply size + tests[testIndex].ok = true + } + return nil +} diff --git a/internal/pmtud/pmtud_integration_test.go b/internal/pmtud/pmtud_integration_test.go new file mode 100644 index 00000000..468da1e6 --- /dev/null +++ b/internal/pmtud/pmtud_integration_test.go @@ -0,0 +1,22 @@ +//go:build integration + +package pmtud + +import ( + "context" + "net/netip" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func Test_PathMTUDiscover(t *testing.T) { + t.Parallel() + const physicalLinkMTU = 1500 + const timeout = time.Second + mtu, err := PathMTUDiscover(context.Background(), netip.MustParseAddr("1.1.1.1"), + physicalLinkMTU, timeout, nil) + require.NoError(t, err) + t.Log("MTU found:", mtu) +} diff --git a/internal/pmtud/pmtud_test.go b/internal/pmtud/pmtud_test.go new file mode 100644 index 00000000..db10d924 --- /dev/null +++ b/internal/pmtud/pmtud_test.go @@ -0,0 +1,55 @@ +package pmtud + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_makeMTUsToTest(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + minMTU int + maxMTU int + mtus []int + }{ + "0_0": { + mtus: []int{0}, + }, + "0_1": { + maxMTU: 1, + mtus: []int{0, 1}, + }, + "0_8": { + maxMTU: 8, + mtus: []int{0, 1, 2, 3, 4, 5, 6, 7, 8}, + }, + "0_12": { + maxMTU: 12, + mtus: []int{0, 1, 2, 4, 5, 6, 7, 8, 10, 11, 12}, + }, + "0_80": { + maxMTU: 80, + mtus: []int{0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80}, + }, + "0_100": { + maxMTU: 100, + mtus: []int{0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100}, + }, + "1280_1500": { + minMTU: 1280, + maxMTU: 1500, + mtus: []int{1280, 1302, 1324, 1346, 1368, 1390, 1412, 1434, 1456, 1478, 1500}, + }, + } + + for name, testCase := range testCases { + t.Run(name, func(t *testing.T) { + t.Parallel() + + mtus := makeMTUsToTest(testCase.minMTU, testCase.maxMTU) + assert.Equal(t, testCase.mtus, mtus) + }) + } +} diff --git a/internal/vpn/interfaces.go b/internal/vpn/interfaces.go index 9fed9867..f36bb6ec 100644 --- a/internal/vpn/interfaces.go +++ b/internal/vpn/interfaces.go @@ -81,6 +81,7 @@ type Linker interface { LinkDel(link netlink.Link) (err error) LinkSetUp(link netlink.Link) (linkIndex int, err error) LinkSetDown(link netlink.Link) (err error) + LinkSetMTU(link netlink.Link, mtu int) (err error) } type DNSLoop interface { diff --git a/internal/vpn/run.go b/internal/vpn/run.go index e9ab780e..17c5e0a7 100644 --- a/internal/vpn/run.go +++ b/internal/vpn/run.go @@ -47,6 +47,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) { continue } tunnelUpData := tunnelUpData{ + vpnType: settings.Type, serverIP: connection.IP, serverName: connection.ServerName, canPortForward: connection.PortForward, diff --git a/internal/vpn/tunnelup.go b/internal/vpn/tunnelup.go index 60aac1e0..d42e44ab 100644 --- a/internal/vpn/tunnelup.go +++ b/internal/vpn/tunnelup.go @@ -2,16 +2,24 @@ package vpn import ( "context" + "errors" + "fmt" "net/netip" + "time" "github.com/qdm12/dns/v2/pkg/check" "github.com/qdm12/gluetun/internal/constants" + "github.com/qdm12/gluetun/internal/pmtud" "github.com/qdm12/gluetun/internal/version" + "github.com/qdm12/log" ) type tunnelUpData struct { // Healthcheck serverIP netip.Addr + // vpnType is used for path MTU discovery to find the protocol overhead. + // It can be "wireguard" or "openvpn". + vpnType string // Port forwarding vpnIntf string serverName string // used for PIA @@ -31,6 +39,13 @@ func (l *Loop) onTunnelUp(ctx, loopCtx context.Context, data tunnelUpData) { } } + mtuLogger := l.logger.New(log.SetComponent("MTU discovery")) + err := updateToMaxMTU(ctx, data.vpnIntf, data.vpnType, + l.netLinker, l.routing, mtuLogger) + if err != nil { + mtuLogger.Error(err.Error()) + } + icmpTargetIPs := l.healthSettings.ICMPTargetIPs if len(icmpTargetIPs) == 1 && icmpTargetIPs[0].IsUnspecified() { icmpTargetIPs = []netip.Addr{data.serverIP} @@ -120,3 +135,65 @@ func (l *Loop) restartVPN(ctx context.Context, healthErr error) { _, _ = l.ApplyStatus(ctx, constants.Stopped) _, _ = l.ApplyStatus(ctx, constants.Running) } + +var errVPNTypeUnknown = errors.New("unknown VPN type") + +func updateToMaxMTU(ctx context.Context, vpnInterface string, + vpnType string, netlinker NetLinker, routing Routing, logger *log.Logger, +) error { + logger.Info("finding maximum MTU, this can take up to 4 seconds") + + vpnGatewayIP, err := routing.VPNLocalGatewayIP(vpnInterface) + if err != nil { + return fmt.Errorf("getting VPN gateway IP address: %w", err) + } + + link, err := netlinker.LinkByName(vpnInterface) + if err != nil { + return fmt.Errorf("getting VPN interface by name: %w", err) + } + + originalMTU := link.MTU + + // Note: no point testing for an MTU of 1500, it will never work due to the VPN + // protocol overhead, so start lower than 1500 according to the protocol used. + const physicalLinkMTU = 1500 + vpnLinkMTU := physicalLinkMTU + switch vpnType { + case "wireguard": + vpnLinkMTU -= 60 // Wireguard overhead + case "openvpn": + vpnLinkMTU -= 41 // OpenVPN overhead + default: + return fmt.Errorf("%w: %q", errVPNTypeUnknown, vpnType) + } + + // Setting the VPN link MTU to 1500 might interrupt the connection until + // the new MTU is set again, but this is necessary to find the highest valid MTU. + logger.Debugf("VPN interface %s MTU temporarily set to %d", vpnInterface, vpnLinkMTU) + + err = netlinker.LinkSetMTU(link, vpnLinkMTU) + if err != nil { + return fmt.Errorf("setting VPN interface %s MTU to %d: %w", vpnInterface, vpnLinkMTU, err) + } + + const pingTimeout = time.Second + vpnLinkMTU, err = pmtud.PathMTUDiscover(ctx, vpnGatewayIP, vpnLinkMTU, pingTimeout, logger) + switch { + case err == nil: + logger.Infof("setting VPN interface %s MTU to maximum valid MTU %d", vpnInterface, vpnLinkMTU) + case errors.Is(err, pmtud.ErrMTUNotFound) || errors.Is(err, pmtud.ErrICMPNotPermitted): + vpnLinkMTU = int(originalMTU) + logger.Infof("reverting VPN interface %s MTU to %d (due to: %s)", + vpnInterface, originalMTU, err) + default: + return fmt.Errorf("path MTU discovering: %w", err) + } + + err = netlinker.LinkSetMTU(link, vpnLinkMTU) + if err != nil { + return fmt.Errorf("setting VPN interface %s MTU to %d: %w", vpnInterface, vpnLinkMTU, err) + } + + return nil +}