Path MTU discovery fixes and improvements (#3109)

- Existing option `WIREGUARD_MTU` , if set, disables PMTUD and is used
- New option `PMTUD_ICMP_ADDRESSES=1.1.1.1,8.8.8.8` and `PMTUD_TCP_ADDRESSES=1.1.1.1:443,8.8.8.8:443`
- ICMP PMTUD now targets external-by-default IP addresses
- New TCP PMTUD (binary search only) as a second MTU confirmation and fallback mechanism.
- Force set TCP MSS to MTU - IP header - TCP base header - "magic 20 bytes" 🎆
- Fix #3108
This commit is contained in:
Quentin McGaw
2026-02-15 01:40:34 +01:00
committed by GitHub
parent 8f1fda7646
commit be92aa2ac4
59 changed files with 2050 additions and 376 deletions
+49
View File
@@ -0,0 +1,49 @@
package icmp
import (
"net"
"time"
"golang.org/x/net/ipv4"
)
var _ net.PacketConn = &ipv4Wrapper{}
// ipv4Wrapper is a wrapper around ipv4.PacketConn to implement
// the net.PacketConn interface. It's only used for Darwin or iOS.
type ipv4Wrapper struct {
ipv4Conn *ipv4.PacketConn
}
func ipv4ToNetPacketConn(ipv4 *ipv4.PacketConn) *ipv4Wrapper {
return &ipv4Wrapper{ipv4Conn: ipv4}
}
func (i *ipv4Wrapper) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
n, _, addr, err = i.ipv4Conn.ReadFrom(p)
return n, addr, err
}
func (i *ipv4Wrapper) WriteTo(p []byte, addr net.Addr) (n int, err error) {
return i.ipv4Conn.WriteTo(p, nil, addr)
}
func (i *ipv4Wrapper) Close() error {
return i.ipv4Conn.Close()
}
func (i *ipv4Wrapper) LocalAddr() net.Addr {
return i.ipv4Conn.LocalAddr()
}
func (i *ipv4Wrapper) SetDeadline(t time.Time) error {
return i.ipv4Conn.SetDeadline(t)
}
func (i *ipv4Wrapper) SetReadDeadline(t time.Time) error {
return i.ipv4Conn.SetReadDeadline(t)
}
func (i *ipv4Wrapper) SetWriteDeadline(t time.Time) error {
return i.ipv4Conn.SetWriteDeadline(t)
}
+83
View File
@@ -0,0 +1,83 @@
package icmp
import (
"bytes"
"errors"
"fmt"
"golang.org/x/net/icmp"
)
var (
ErrNextHopMTUTooLow = errors.New("ICMP Next Hop MTU is too low")
ErrNextHopMTUTooHigh = errors.New("ICMP Next Hop MTU is too high")
)
func checkMTU(mtu, minMTU, physicalLinkMTU uint32) (err error) {
switch {
case mtu < minMTU:
return fmt.Errorf("%w: %d", ErrNextHopMTUTooLow, mtu)
case mtu > physicalLinkMTU:
return fmt.Errorf("%w: %d is larger than physical link MTU %d",
ErrNextHopMTUTooHigh, mtu, physicalLinkMTU)
default:
return nil
}
}
func checkInvokingReplyIDMatch(icmpProtocol int, received []byte,
outboundMessage *icmp.Message,
) (match bool, err error) {
inboundMessage, err := icmp.ParseMessage(icmpProtocol, received)
if err != nil {
return false, fmt.Errorf("parsing invoking packet: %w", err)
}
inboundBody, ok := inboundMessage.Body.(*icmp.Echo)
if !ok {
return false, fmt.Errorf("%w: %T", ErrBodyUnsupported, inboundMessage.Body)
}
outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert
return inboundBody.ID == outboundBody.ID, nil
}
var ErrIDMismatch = errors.New("ICMP id mismatch")
func checkEchoReply(icmpProtocol int, received []byte,
outboundMessage *icmp.Message, truncatedBody bool,
) (err error) {
inboundMessage, err := icmp.ParseMessage(icmpProtocol, received)
if err != nil {
return fmt.Errorf("parsing invoking packet: %w", err)
}
inboundBody, ok := inboundMessage.Body.(*icmp.Echo)
if !ok {
return fmt.Errorf("%w: %T", ErrBodyUnsupported, inboundMessage.Body)
}
outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert
if inboundBody.ID != outboundBody.ID {
return fmt.Errorf("%w: sent id %d and received id %d",
ErrIDMismatch, outboundBody.ID, inboundBody.ID)
}
err = checkEchoBodies(outboundBody.Data, inboundBody.Data, truncatedBody)
if err != nil {
return fmt.Errorf("checking sent and received bodies: %w", err)
}
return nil
}
var ErrEchoDataMismatch = errors.New("ICMP data mismatch")
func checkEchoBodies(sent, received []byte, receivedTruncated bool) (err error) {
if len(received) > len(sent) {
return fmt.Errorf("%w: sent %d bytes and received %d bytes",
ErrEchoDataMismatch, len(sent), len(received))
}
if receivedTruncated {
sent = sent[:len(received)]
}
if !bytes.Equal(received, sent) {
return fmt.Errorf("%w: sent %x and received %x",
ErrEchoDataMismatch, sent, received)
}
return nil
}
+10
View File
@@ -0,0 +1,10 @@
//go:build !linux && !windows
package icmp
// setDontFragment for platforms other than Linux and Windows
// is not implemented, so we just return assuming the don't
// fragment flag is set on IP packets.
func setDontFragment(fd uintptr) (err error) {
return nil
}
+10
View File
@@ -0,0 +1,10 @@
package icmp
import (
"syscall"
)
func setDontFragment(fd uintptr) (err error) {
return syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP,
syscall.IP_MTU_DISCOVER, syscall.IP_PMTUDISC_PROBE)
}
+11
View File
@@ -0,0 +1,11 @@
package icmp
import (
"syscall"
)
func setDontFragment(fd uintptr) (err error) {
// https://docs.microsoft.com/en-us/troubleshoot/windows/win32/header-library-requirement-socket-ipproto-ip
// #define IP_DONTFRAGMENT 14 /* don't fragment IP datagrams */
return syscall.SetsockoptInt(syscall.Handle(fd), syscall.IPPROTO_IP, 14, 1)
}
+30
View File
@@ -0,0 +1,30 @@
package icmp
import (
"context"
"errors"
"fmt"
"net"
"strings"
"time"
)
var (
ErrNotPermitted = errors.New("ICMP not permitted")
ErrDestinationUnreachable = errors.New("ICMP destination unreachable")
ErrCommunicationAdministrativelyProhibited = errors.New("communication administratively prohibited")
ErrBodyUnsupported = errors.New("ICMP body type is not supported")
ErrMTUNotFound = errors.New("MTU not found")
)
func wrapConnErr(err error, timedCtx context.Context, pingTimeout time.Duration) error { //nolint:revive
switch {
case strings.HasSuffix(err.Error(), "sendto: operation not permitted"):
err = fmt.Errorf("%w", ErrNotPermitted)
case errors.Is(timedCtx.Err(), context.DeadlineExceeded):
err = fmt.Errorf("%w (timed out after %s)", net.ErrClosed, pingTimeout)
case timedCtx.Err() != nil:
err = timedCtx.Err()
}
return err
}
+53
View File
@@ -0,0 +1,53 @@
package icmp
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"time"
"github.com/qdm12/gluetun/internal/pmtud/constants"
)
// PathMTUDiscover discovers the path MTU to the given IP address
// using ICMP.
// It first tries to get the next hop MTU using ICMP messages.
// If that fails, it falls back to sending echo requests with
// different packet sizes to find the maximum MTU.
// The function returns [ErrMTUNotFound] if the MTU could not be determined.
func PathMTUDiscover(ctx context.Context, ip netip.Addr,
physicalLinkMTU uint32, timeout time.Duration, logger Logger,
) (mtu uint32, err error) {
if ip.Is4() {
logger.Debug("finding IPv4 next hop MTU")
mtu, err = findIPv4NextHopMTU(ctx, ip, physicalLinkMTU, timeout, logger)
switch {
case err == nil:
return mtu, nil
case errors.Is(err, net.ErrClosed) || errors.Is(err, ErrCommunicationAdministrativelyProhibited): // blackhole
default:
return 0, fmt.Errorf("finding IPv4 next hop MTU: %w", err)
}
} else {
logger.Debug("requesting IPv6 ICMP packet-too-big reply")
mtu, err = getIPv6PacketTooBig(ctx, ip, physicalLinkMTU, timeout, logger)
switch {
case err == nil:
return mtu, nil
case errors.Is(err, net.ErrClosed): // blackhole
default:
return 0, fmt.Errorf("getting IPv6 packet-too-big message: %w", err)
}
}
// Fall back method: send echo requests with different packet
// sizes and check which ones succeed to find the maximum MTU.
logger.Debug("falling back to sending different sized echo packets")
minMTU := constants.MinIPv4MTU
if ip.Is6() {
minMTU = constants.MinIPv6MTU
}
return pmtudMultiSizes(ctx, ip, minMTU, physicalLinkMTU, timeout, logger)
}
+7
View File
@@ -0,0 +1,7 @@
package icmp
type Logger interface {
Debug(msg string)
Debugf(msg string, args ...any)
Warnf(msg string, args ...any)
}
+163
View File
@@ -0,0 +1,163 @@
package icmp
import (
"context"
"encoding/binary"
"fmt"
"net"
"net/netip"
"runtime"
"strings"
"syscall"
"time"
"github.com/qdm12/gluetun/internal/pmtud/constants"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
)
const (
icmpv4Protocol = 1
)
func listenICMPv4(ctx context.Context) (conn net.PacketConn, err error) {
var listenConfig net.ListenConfig
listenConfig.Control = func(_, _ string, rawConn syscall.RawConn) error {
var setDFErr error
err := rawConn.Control(func(fd uintptr) {
setDFErr = setDontFragment(fd) // runs when calling ListenPacket
})
if err == nil {
err = setDFErr
}
return err
}
const listenAddress = ""
packetConn, err := listenConfig.ListenPacket(ctx, "ip4:icmp", listenAddress)
if err != nil {
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrNotPermitted)
}
return nil, err
}
if runtime.GOOS == "darwin" || runtime.GOOS == "ios" {
packetConn = ipv4ToNetPacketConn(ipv4.NewPacketConn(packetConn))
}
return packetConn, nil
}
func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr,
physicalLinkMTU uint32, pingTimeout time.Duration, logger Logger,
) (mtu uint32, err error) {
if ip.Is6() {
panic("IP address is not v4")
}
conn, err := listenICMPv4(ctx)
if err != nil {
return 0, fmt.Errorf("listening for ICMP packets: %w", err)
}
ctx, cancel := context.WithTimeout(ctx, pingTimeout)
defer cancel()
go func() {
<-ctx.Done()
conn.Close()
}()
// First try to send a packet which is too big to get the maximum MTU
// directly.
outboundID, outboundMessage := buildMessageToSend("v4", physicalLinkMTU)
encodedMessage, err := outboundMessage.Marshal(nil)
if err != nil {
return 0, fmt.Errorf("encoding ICMP message: %w", err)
}
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()})
if err != nil {
err = wrapConnErr(err, ctx, pingTimeout)
return 0, fmt.Errorf("writing ICMP message: %w", err)
}
buffer := make([]byte, physicalLinkMTU)
// for loop in case we read an ICMP message from another ICMP request
// or TCP/UDP traffic triggering an ICMP response.
for {
// Note we need to read the whole packet in one call to ReadFrom, so the buffer
// must be large enough to read the entire reply packet. See:
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
bytesRead, _, err := conn.ReadFrom(buffer)
if err != nil {
err = wrapConnErr(err, ctx, pingTimeout)
return 0, fmt.Errorf("reading from ICMP connection: %w", err)
}
packetBytes := buffer[:bytesRead]
// Side note: echo reply should be at most the number of bytes
// sent, and can be lower, more precisely 576-ipHeader bytes,
// in case the next hop we are reaching replies with a destination
// unreachable and wants to ensure the response makes it way back
// by keeping a low packet size, see:
// https://datatracker.ietf.org/doc/html/rfc1122#page-59
inboundMessage, err := icmp.ParseMessage(icmpv4Protocol, packetBytes)
if err != nil {
return 0, fmt.Errorf("parsing message: %w", err)
}
switch typedBody := inboundMessage.Body.(type) {
case *icmp.DstUnreach:
const fragmentationRequiredAndDFFlagSetCode = 4
const portUnreachable = 3
const communicationAdministrativelyProhibitedCode = 13
switch inboundMessage.Code {
case fragmentationRequiredAndDFFlagSetCode:
case portUnreachable: // triggered by TCP or UDP from applications
continue // ignore and wait for the next message
case communicationAdministrativelyProhibitedCode:
return 0, fmt.Errorf("%w: %w (code %d)",
ErrDestinationUnreachable,
ErrCommunicationAdministrativelyProhibited,
inboundMessage.Code)
default:
return 0, fmt.Errorf("%w: code %d",
ErrDestinationUnreachable, inboundMessage.Code)
}
// See https://datatracker.ietf.org/doc/html/rfc1191#section-4
// Note: the go library does not handle this NextHopMTU section.
nextHopMTU := packetBytes[6:8]
mtu = uint32(binary.BigEndian.Uint16(nextHopMTU))
err = checkMTU(mtu, constants.MinIPv4MTU, physicalLinkMTU)
if err != nil {
return 0, fmt.Errorf("checking next-hop-mtu found: %w", err)
}
// The code below is really for sanity checks
packetBytes = packetBytes[8:]
header, err := ipv4.ParseHeader(packetBytes)
if err != nil {
return 0, fmt.Errorf("parsing IPv4 header: %w", err)
}
packetBytes = packetBytes[header.Len:] // truncated original datagram
const truncated = true
err = checkEchoReply(icmpv4Protocol, packetBytes, outboundMessage, truncated)
if err != nil {
return 0, fmt.Errorf("checking echo reply: %w", err)
}
return mtu, nil
case *icmp.Echo:
inboundID := uint16(typedBody.ID) //nolint:gosec
if inboundID == outboundID {
return physicalLinkMTU, nil
}
logger.Debugf("discarding received ICMP echo reply with id %d mismatching sent id %d",
inboundID, outboundID)
continue
default:
return 0, fmt.Errorf("%w: %T", ErrBodyUnsupported, typedBody)
}
}
}
+122
View File
@@ -0,0 +1,122 @@
package icmp
import (
"context"
"fmt"
"net"
"net/netip"
"strings"
"time"
"github.com/qdm12/gluetun/internal/pmtud/constants"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv6"
)
const (
icmpv6Protocol = 58
)
func listenICMPv6(ctx context.Context) (conn net.PacketConn, err error) {
var listenConfig net.ListenConfig
const listenAddress = ""
packetConn, err := listenConfig.ListenPacket(ctx, "ip6:ipv6-icmp", listenAddress)
if err != nil {
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrNotPermitted)
}
return nil, err
}
return packetConn, nil
}
func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr,
physicalLinkMTU uint32, pingTimeout time.Duration, logger Logger,
) (mtu uint32, err error) {
if ip.Is4() {
panic("IP address is not v6")
}
conn, err := listenICMPv6(ctx)
if err != nil {
return 0, fmt.Errorf("listening for ICMP packets: %w", err)
}
ctx, cancel := context.WithTimeout(ctx, pingTimeout)
defer cancel()
go func() {
<-ctx.Done()
conn.Close()
}()
// First try to send a packet which is too big to get the maximum MTU
// directly.
outboundID, outboundMessage := buildMessageToSend("v6", physicalLinkMTU)
encodedMessage, err := outboundMessage.Marshal(nil)
if err != nil {
return 0, fmt.Errorf("encoding ICMP message: %w", err)
}
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice(), Zone: ip.Zone()})
if err != nil {
err = wrapConnErr(err, ctx, pingTimeout)
return 0, fmt.Errorf("writing ICMP message: %w", err)
}
buffer := make([]byte, physicalLinkMTU)
for { // for loop if we encounter another ICMP packet with an unknown id.
// Note we need to read the whole packet in one call to ReadFrom, so the buffer
// must be large enough to read the entire reply packet. See:
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
bytesRead, _, err := conn.ReadFrom(buffer)
if err != nil {
err = wrapConnErr(err, ctx, pingTimeout)
return 0, fmt.Errorf("reading from ICMP connection: %w", err)
}
packetBytes := buffer[:bytesRead]
packetBytes = packetBytes[ipv6.HeaderLen:]
inboundMessage, err := icmp.ParseMessage(icmpv6Protocol, packetBytes)
if err != nil {
return 0, fmt.Errorf("parsing message: %w", err)
}
switch typedBody := inboundMessage.Body.(type) {
case *icmp.PacketTooBig:
// https://datatracker.ietf.org/doc/html/rfc1885#section-3.2
mtu = uint32(typedBody.MTU) //nolint:gosec
err = checkMTU(mtu, constants.MinIPv6MTU, physicalLinkMTU)
if err != nil {
return 0, fmt.Errorf("checking MTU: %w", err)
}
// Sanity checks
const truncatedBody = true
err = checkEchoReply(icmpv6Protocol, typedBody.Data, outboundMessage, truncatedBody)
if err != nil {
return 0, fmt.Errorf("checking invoking message: %w", err)
}
return uint32(typedBody.MTU), nil //nolint:gosec
case *icmp.DstUnreach:
// https://datatracker.ietf.org/doc/html/rfc1885#section-3.1
idMatch, err := checkInvokingReplyIDMatch(icmpv6Protocol, packetBytes, outboundMessage)
if err != nil {
return 0, fmt.Errorf("checking invoking message id: %w", err)
} else if idMatch {
return 0, fmt.Errorf("%w", ErrDestinationUnreachable)
}
logger.Debug("discarding received ICMP destination unreachable reply with an unknown id")
continue
case *icmp.Echo:
inboundID := uint16(typedBody.ID) //nolint:gosec
if inboundID == outboundID {
return physicalLinkMTU, nil
}
logger.Debugf("discarding received ICMP echo reply with id %d mismatching sent id %d",
inboundID, outboundID)
continue
default:
return 0, fmt.Errorf("%w: %T", ErrBodyUnsupported, typedBody)
}
}
}
+58
View File
@@ -0,0 +1,58 @@
package icmp
import (
cryptorand "crypto/rand"
"encoding/binary"
"fmt"
"math/rand/v2"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
func buildMessageToSend(ipVersion string, mtu uint32) (id uint16, message *icmp.Message) {
var seed [32]byte
_, _ = cryptorand.Read(seed[:])
randomSource := rand.NewChaCha8(seed)
const uint16Bytes = 2
idBytes := make([]byte, uint16Bytes)
_, _ = randomSource.Read(idBytes)
id = binary.BigEndian.Uint16(idBytes)
var ipHeaderLength uint32
var icmpType icmp.Type
switch ipVersion {
case "v4":
ipHeaderLength = ipv4.HeaderLen
icmpType = ipv4.ICMPTypeEcho
case "v6":
ipHeaderLength = ipv6.HeaderLen
icmpType = ipv6.ICMPTypeEchoRequest
default:
panic(fmt.Sprintf("IP version %q not supported", ipVersion))
}
const pingHeaderLength = 0 +
1 + // type
1 + // code
2 + // checksum
2 + // identifier
2 // sequence number
pingBodyDataSize := mtu - ipHeaderLength - pingHeaderLength
messageBodyData := make([]byte, pingBodyDataSize)
_, _ = randomSource.Read(messageBodyData)
// See https://www.iana.org/assignments/icmp-parameters/icmp-parameters.xhtml#icmp-parameters-types
message = &icmp.Message{
Type: icmpType, // echo request
Code: 0, // no code
Checksum: 0, // calculated at encoding (ipv4) or sending (ipv6)
Body: &icmp.Echo{
ID: int(id),
Seq: 0, // only one packet
Data: messageBodyData,
},
}
return id, message
}
+187
View File
@@ -0,0 +1,187 @@
package icmp
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"strings"
"time"
"github.com/qdm12/gluetun/internal/pmtud/test"
"golang.org/x/net/icmp"
)
type icmpTestUnit struct {
mtu uint32
echoID uint16
sentBytes int
ok bool
}
func pmtudMultiSizes(ctx context.Context, ip netip.Addr,
minMTU, maxPossibleMTU uint32, pingTimeout time.Duration,
logger Logger,
) (maxMTU uint32, err error) {
var ipVersion string
var conn net.PacketConn
if ip.Is4() {
ipVersion = "v4"
conn, err = listenICMPv4(ctx)
} else {
ipVersion = "v6"
conn, err = listenICMPv6(ctx)
}
if err != nil {
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrNotPermitted)
}
return 0, fmt.Errorf("listening for ICMP packets: %w", err)
}
mtusToTest := test.MakeMTUsToTest(minMTU, maxPossibleMTU)
if len(mtusToTest) == 1 { // only minMTU because minMTU == maxPossibleMTU
return minMTU, nil
}
logger.Debugf("ICMP testing the following MTUs: %v", mtusToTest)
tests := make([]icmpTestUnit, len(mtusToTest))
for i := range mtusToTest {
tests[i] = icmpTestUnit{mtu: mtusToTest[i]}
}
timedCtx, cancel := context.WithTimeout(ctx, pingTimeout)
defer cancel()
go func() {
<-timedCtx.Done()
conn.Close()
}()
for i := range tests {
id, message := buildMessageToSend(ipVersion, tests[i].mtu)
tests[i].echoID = id
encodedMessage, err := message.Marshal(nil)
if err != nil {
return 0, fmt.Errorf("encoding ICMP message: %w", err)
}
tests[i].sentBytes = len(encodedMessage)
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()})
if err != nil {
if strings.HasSuffix(err.Error(), "sendto: operation not permitted") {
err = fmt.Errorf("%w", ErrNotPermitted)
}
return 0, fmt.Errorf("writing ICMP message: %w", err)
}
}
err = collectReplies(conn, ipVersion, tests, logger)
switch {
case err == nil: // max possible MTU is working
return tests[len(tests)-1].mtu, nil
case err != nil && errors.Is(err, net.ErrClosed):
// we have timeouts (IPv4 testing or IPv6 PMTUD blackholes)
// so find the highest MTU which worked.
// Note we start from index len(tests) - 2 since the max MTU
// cannot be working if we had a timeout.
for i := len(tests) - 2; i >= 0; i-- { //nolint:mnd
if tests[i].ok {
return pmtudMultiSizes(ctx, ip, tests[i].mtu, tests[i+1].mtu-1,
pingTimeout, logger)
}
}
// All MTUs failed.
return 0, fmt.Errorf("%w: ICMP might be blocked", ErrMTUNotFound)
case err != nil:
return 0, fmt.Errorf("collecting ICMP echo replies: %w", err)
default:
panic("unreachable")
}
}
// The theoretical limit is 4GiB for IPv6 MTU path discovery jumbograms, but that would
// create huge buffers which we don't really want to support anyway.
// The standard frame maximum MTU is 1500 bytes, and there are Jumbo frames with
// a conventional maximum of 9000 bytes. However, some manufacturers support up
// 9216-20 = 9196 bytes for the maximum MTU. We thus use buffers of size 9196 to
// match eventual Jumbo frames. More information at:
// https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media
const maxPossibleMTU = 9196
func collectReplies(conn net.PacketConn, ipVersion string,
tests []icmpTestUnit, logger Logger,
) (err error) {
echoIDToTestIndex := make(map[uint16]int, len(tests))
for i, test := range tests {
echoIDToTestIndex[test.echoID] = i
}
buffer := make([]byte, maxPossibleMTU)
idsFound := 0
for idsFound < len(tests) {
// Note we need to read the whole packet in one call to ReadFrom, so the buffer
// must be large enough to read the entire reply packet. See:
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
bytesRead, _, err := conn.ReadFrom(buffer)
if err != nil {
return fmt.Errorf("reading from ICMP connection: %w", err)
}
packetBytes := buffer[:bytesRead]
ipPacketLength := len(packetBytes)
var icmpProtocol int
switch ipVersion {
case "v4":
icmpProtocol = icmpv4Protocol
case "v6":
icmpProtocol = icmpv6Protocol
default:
panic(fmt.Sprintf("unknown IP version: %s", ipVersion))
}
// Parse the ICMP message
// Note: this parsing works for a truncated 556 bytes ICMP reply packet.
message, err := icmp.ParseMessage(icmpProtocol, packetBytes)
if err != nil {
return fmt.Errorf("parsing message: %w", err)
}
echoBody, ok := message.Body.(*icmp.Echo)
if !ok {
return fmt.Errorf("%w: %T", ErrBodyUnsupported, message.Body)
}
id := uint16(echoBody.ID) //nolint:gosec
testIndex, testing := echoIDToTestIndex[id]
if !testing { // not an id we expected so ignore it
logger.Warnf("ignoring ICMP reply with unexpected ID %d (type: %d, code: %d, length: %d)",
echoBody.ID, message.Type, message.Code, ipPacketLength)
continue
}
idsFound++
sentBytes := tests[testIndex].sentBytes
// echo reply should be at most the number of bytes sent,
// and can be lower, more precisely 556 bytes, in case
// the host we are reaching wants to stay out of trouble
// and ensure its echo reply goes through without
// fragmentation, see the following page:
// https://datatracker.ietf.org/doc/html/rfc1122#page-59
const conservativeReplyLength = 556
truncated := ipPacketLength < sentBytes &&
ipPacketLength == conservativeReplyLength
// Check the packet size is the same if the reply is not truncated
if !truncated && sentBytes != ipPacketLength {
return fmt.Errorf("%w: sent %dB and received %dB",
ErrEchoDataMismatch, sentBytes, ipPacketLength)
}
// Truncated reply or matching reply size
tests[testIndex].ok = true
}
return nil
}