Path MTU discovery fixes and improvements (#3109)

- Existing option `WIREGUARD_MTU` , if set, disables PMTUD and is used
- New option `PMTUD_ICMP_ADDRESSES=1.1.1.1,8.8.8.8` and `PMTUD_TCP_ADDRESSES=1.1.1.1:443,8.8.8.8:443`
- ICMP PMTUD now targets external-by-default IP addresses
- New TCP PMTUD (binary search only) as a second MTU confirmation and fallback mechanism.
- Force set TCP MSS to MTU - IP header - TCP base header - "magic 20 bytes" 🎆
- Fix #3108
This commit is contained in:
Quentin McGaw
2026-02-15 01:40:34 +01:00
committed by GitHub
parent 8f1fda7646
commit be92aa2ac4
59 changed files with 2050 additions and 376 deletions
+24
View File
@@ -0,0 +1,24 @@
package constants
const (
MaxEthernetFrameSize uint32 = 1500
// MinIPv4MTU is defined according to
// https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media
MinIPv4MTU uint32 = 68
MinIPv6MTU uint32 = 1280
IPv4HeaderLength uint32 = 20
IPv6HeaderLength uint32 = 40
UDPHeaderLength uint32 = 8
// BaseTCPHeaderLength is the TCP header length without options,
// which is the minimum TCP header length.
BaseTCPHeaderLength uint32 = 20
// MaxTCPHeaderLength is the TCP header length with the maximum options length of 40 bytes.
// Note this is a hard maximum because of the 4-bit data offset field in the TCP header (15x4=60).
MaxTCPHeaderLength uint32 = 60
WireguardHeaderLength uint32 = 32
OpenVPNHeaderMaxLength uint32 = 1 + // opcode
8 + // session id
4 + // packet id
28 // max possible auth tag/iv
)
-29
View File
@@ -1,29 +0,0 @@
package pmtud
import (
"context"
"errors"
"fmt"
"net"
"strings"
"time"
)
var (
ErrICMPNotPermitted = errors.New("ICMP not permitted")
ErrICMPDestinationUnreachable = errors.New("ICMP destination unreachable")
ErrICMPCommunicationAdministrativelyProhibited = errors.New("communication administratively prohibited")
ErrICMPBodyUnsupported = errors.New("ICMP body type is not supported")
)
func wrapConnErr(err error, timedCtx context.Context, pingTimeout time.Duration) error { //nolint:revive
switch {
case strings.HasSuffix(err.Error(), "sendto: operation not permitted"):
err = fmt.Errorf("%w", ErrICMPNotPermitted)
case errors.Is(timedCtx.Err(), context.DeadlineExceeded):
err = fmt.Errorf("%w (timed out after %s)", net.ErrClosed, pingTimeout)
case timedCtx.Err() != nil:
err = timedCtx.Err()
}
return err
}
@@ -1,4 +1,4 @@
package pmtud
package icmp
import (
"net"
@@ -1,4 +1,4 @@
package pmtud
package icmp
import (
"bytes"
@@ -9,17 +9,17 @@ import (
)
var (
ErrICMPNextHopMTUTooLow = errors.New("ICMP Next Hop MTU is too low")
ErrICMPNextHopMTUTooHigh = errors.New("ICMP Next Hop MTU is too high")
ErrNextHopMTUTooLow = errors.New("ICMP Next Hop MTU is too low")
ErrNextHopMTUTooHigh = errors.New("ICMP Next Hop MTU is too high")
)
func checkMTU(mtu, minMTU, physicalLinkMTU uint32) (err error) {
switch {
case mtu < minMTU:
return fmt.Errorf("%w: %d", ErrICMPNextHopMTUTooLow, mtu)
return fmt.Errorf("%w: %d", ErrNextHopMTUTooLow, mtu)
case mtu > physicalLinkMTU:
return fmt.Errorf("%w: %d is larger than physical link MTU %d",
ErrICMPNextHopMTUTooHigh, mtu, physicalLinkMTU)
ErrNextHopMTUTooHigh, mtu, physicalLinkMTU)
default:
return nil
}
@@ -34,13 +34,13 @@ func checkInvokingReplyIDMatch(icmpProtocol int, received []byte,
}
inboundBody, ok := inboundMessage.Body.(*icmp.Echo)
if !ok {
return false, fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, inboundMessage.Body)
return false, fmt.Errorf("%w: %T", ErrBodyUnsupported, inboundMessage.Body)
}
outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert
return inboundBody.ID == outboundBody.ID, nil
}
var ErrICMPIDMismatch = errors.New("ICMP id mismatch")
var ErrIDMismatch = errors.New("ICMP id mismatch")
func checkEchoReply(icmpProtocol int, received []byte,
outboundMessage *icmp.Message, truncatedBody bool,
@@ -51,12 +51,12 @@ func checkEchoReply(icmpProtocol int, received []byte,
}
inboundBody, ok := inboundMessage.Body.(*icmp.Echo)
if !ok {
return fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, inboundMessage.Body)
return fmt.Errorf("%w: %T", ErrBodyUnsupported, inboundMessage.Body)
}
outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert
if inboundBody.ID != outboundBody.ID {
return fmt.Errorf("%w: sent id %d and received id %d",
ErrICMPIDMismatch, outboundBody.ID, inboundBody.ID)
ErrIDMismatch, outboundBody.ID, inboundBody.ID)
}
err = checkEchoBodies(outboundBody.Data, inboundBody.Data, truncatedBody)
if err != nil {
@@ -65,19 +65,19 @@ func checkEchoReply(icmpProtocol int, received []byte,
return nil
}
var ErrICMPEchoDataMismatch = errors.New("ICMP data mismatch")
var ErrEchoDataMismatch = errors.New("ICMP data mismatch")
func checkEchoBodies(sent, received []byte, receivedTruncated bool) (err error) {
if len(received) > len(sent) {
return fmt.Errorf("%w: sent %d bytes and received %d bytes",
ErrICMPEchoDataMismatch, len(sent), len(received))
ErrEchoDataMismatch, len(sent), len(received))
}
if receivedTruncated {
sent = sent[:len(received)]
}
if !bytes.Equal(received, sent) {
return fmt.Errorf("%w: sent %x and received %x",
ErrICMPEchoDataMismatch, sent, received)
ErrEchoDataMismatch, sent, received)
}
return nil
}
@@ -1,6 +1,6 @@
//go:build !linux && !windows
package pmtud
package icmp
// setDontFragment for platforms other than Linux and Windows
// is not implemented, so we just return assuming the don't
@@ -1,4 +1,4 @@
package pmtud
package icmp
import (
"syscall"
@@ -1,6 +1,4 @@
//go:build windows
package pmtud
package icmp
import (
"syscall"
+30
View File
@@ -0,0 +1,30 @@
package icmp
import (
"context"
"errors"
"fmt"
"net"
"strings"
"time"
)
var (
ErrNotPermitted = errors.New("ICMP not permitted")
ErrDestinationUnreachable = errors.New("ICMP destination unreachable")
ErrCommunicationAdministrativelyProhibited = errors.New("communication administratively prohibited")
ErrBodyUnsupported = errors.New("ICMP body type is not supported")
ErrMTUNotFound = errors.New("MTU not found")
)
func wrapConnErr(err error, timedCtx context.Context, pingTimeout time.Duration) error { //nolint:revive
switch {
case strings.HasSuffix(err.Error(), "sendto: operation not permitted"):
err = fmt.Errorf("%w", ErrNotPermitted)
case errors.Is(timedCtx.Err(), context.DeadlineExceeded):
err = fmt.Errorf("%w (timed out after %s)", net.ErrClosed, pingTimeout)
case timedCtx.Err() != nil:
err = timedCtx.Err()
}
return err
}
+53
View File
@@ -0,0 +1,53 @@
package icmp
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"time"
"github.com/qdm12/gluetun/internal/pmtud/constants"
)
// PathMTUDiscover discovers the path MTU to the given IP address
// using ICMP.
// It first tries to get the next hop MTU using ICMP messages.
// If that fails, it falls back to sending echo requests with
// different packet sizes to find the maximum MTU.
// The function returns [ErrMTUNotFound] if the MTU could not be determined.
func PathMTUDiscover(ctx context.Context, ip netip.Addr,
physicalLinkMTU uint32, timeout time.Duration, logger Logger,
) (mtu uint32, err error) {
if ip.Is4() {
logger.Debug("finding IPv4 next hop MTU")
mtu, err = findIPv4NextHopMTU(ctx, ip, physicalLinkMTU, timeout, logger)
switch {
case err == nil:
return mtu, nil
case errors.Is(err, net.ErrClosed) || errors.Is(err, ErrCommunicationAdministrativelyProhibited): // blackhole
default:
return 0, fmt.Errorf("finding IPv4 next hop MTU: %w", err)
}
} else {
logger.Debug("requesting IPv6 ICMP packet-too-big reply")
mtu, err = getIPv6PacketTooBig(ctx, ip, physicalLinkMTU, timeout, logger)
switch {
case err == nil:
return mtu, nil
case errors.Is(err, net.ErrClosed): // blackhole
default:
return 0, fmt.Errorf("getting IPv6 packet-too-big message: %w", err)
}
}
// Fall back method: send echo requests with different packet
// sizes and check which ones succeed to find the maximum MTU.
logger.Debug("falling back to sending different sized echo packets")
minMTU := constants.MinIPv4MTU
if ip.Is6() {
minMTU = constants.MinIPv6MTU
}
return pmtudMultiSizes(ctx, ip, minMTU, physicalLinkMTU, timeout, logger)
}
+7
View File
@@ -0,0 +1,7 @@
package icmp
type Logger interface {
Debug(msg string)
Debugf(msg string, args ...any)
Warnf(msg string, args ...any)
}
@@ -1,4 +1,4 @@
package pmtud
package icmp
import (
"context"
@@ -11,14 +11,13 @@ import (
"syscall"
"time"
"github.com/qdm12/gluetun/internal/pmtud/constants"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
)
const (
// see https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media
minIPv4MTU uint32 = 68
icmpv4Protocol int = 1
icmpv4Protocol = 1
)
func listenICMPv4(ctx context.Context) (conn net.PacketConn, err error) {
@@ -38,7 +37,7 @@ func listenICMPv4(ctx context.Context) (conn net.PacketConn, err error) {
packetConn, err := listenConfig.ListenPacket(ctx, "ip4:icmp", listenAddress)
if err != nil {
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrICMPNotPermitted)
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrNotPermitted)
}
return nil, err
}
@@ -83,7 +82,9 @@ func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr,
buffer := make([]byte, physicalLinkMTU)
for { // for loop in case we read an echo reply for another ICMP request
// for loop in case we read an ICMP message from another ICMP request
// or TCP/UDP traffic triggering an ICMP response.
for {
// Note we need to read the whole packet in one call to ReadFrom, so the buffer
// must be large enough to read the entire reply packet. See:
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
@@ -108,24 +109,27 @@ func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr,
switch typedBody := inboundMessage.Body.(type) {
case *icmp.DstUnreach:
const fragmentationRequiredAndDFFlagSetCode = 4
const portUnreachable = 3
const communicationAdministrativelyProhibitedCode = 13
switch inboundMessage.Code {
case fragmentationRequiredAndDFFlagSetCode:
case portUnreachable: // triggered by TCP or UDP from applications
continue // ignore and wait for the next message
case communicationAdministrativelyProhibitedCode:
return 0, fmt.Errorf("%w: %w (code %d)",
ErrICMPDestinationUnreachable,
ErrICMPCommunicationAdministrativelyProhibited,
ErrDestinationUnreachable,
ErrCommunicationAdministrativelyProhibited,
inboundMessage.Code)
default:
return 0, fmt.Errorf("%w: code %d",
ErrICMPDestinationUnreachable, inboundMessage.Code)
ErrDestinationUnreachable, inboundMessage.Code)
}
// See https://datatracker.ietf.org/doc/html/rfc1191#section-4
// Note: the go library does not handle this NextHopMTU section.
nextHopMTU := packetBytes[6:8]
mtu = uint32(binary.BigEndian.Uint16(nextHopMTU))
err = checkMTU(mtu, minIPv4MTU, physicalLinkMTU)
err = checkMTU(mtu, constants.MinIPv4MTU, physicalLinkMTU)
if err != nil {
return 0, fmt.Errorf("checking next-hop-mtu found: %w", err)
}
@@ -153,7 +157,7 @@ func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr,
inboundID, outboundID)
continue
default:
return 0, fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, typedBody)
return 0, fmt.Errorf("%w: %T", ErrBodyUnsupported, typedBody)
}
}
}
@@ -1,4 +1,4 @@
package pmtud
package icmp
import (
"context"
@@ -8,12 +8,12 @@ import (
"strings"
"time"
"github.com/qdm12/gluetun/internal/pmtud/constants"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv6"
)
const (
minIPv6MTU = 1280
icmpv6Protocol = 58
)
@@ -23,7 +23,7 @@ func listenICMPv6(ctx context.Context) (conn net.PacketConn, err error) {
packetConn, err := listenConfig.ListenPacket(ctx, "ip6:ipv6-icmp", listenAddress)
if err != nil {
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrICMPNotPermitted)
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrNotPermitted)
}
return nil, err
}
@@ -85,7 +85,7 @@ func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr,
case *icmp.PacketTooBig:
// https://datatracker.ietf.org/doc/html/rfc1885#section-3.2
mtu = uint32(typedBody.MTU) //nolint:gosec
err = checkMTU(mtu, minIPv6MTU, physicalLinkMTU)
err = checkMTU(mtu, constants.MinIPv6MTU, physicalLinkMTU)
if err != nil {
return 0, fmt.Errorf("checking MTU: %w", err)
}
@@ -103,7 +103,7 @@ func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr,
if err != nil {
return 0, fmt.Errorf("checking invoking message id: %w", err)
} else if idMatch {
return 0, fmt.Errorf("%w", ErrICMPDestinationUnreachable)
return 0, fmt.Errorf("%w", ErrDestinationUnreachable)
}
logger.Debug("discarding received ICMP destination unreachable reply with an unknown id")
continue
@@ -116,7 +116,7 @@ func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr,
inboundID, outboundID)
continue
default:
return 0, fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, typedBody)
return 0, fmt.Errorf("%w: %T", ErrBodyUnsupported, typedBody)
}
}
}
@@ -1,4 +1,4 @@
package pmtud
package icmp
import (
cryptorand "crypto/rand"
+187
View File
@@ -0,0 +1,187 @@
package icmp
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"strings"
"time"
"github.com/qdm12/gluetun/internal/pmtud/test"
"golang.org/x/net/icmp"
)
type icmpTestUnit struct {
mtu uint32
echoID uint16
sentBytes int
ok bool
}
func pmtudMultiSizes(ctx context.Context, ip netip.Addr,
minMTU, maxPossibleMTU uint32, pingTimeout time.Duration,
logger Logger,
) (maxMTU uint32, err error) {
var ipVersion string
var conn net.PacketConn
if ip.Is4() {
ipVersion = "v4"
conn, err = listenICMPv4(ctx)
} else {
ipVersion = "v6"
conn, err = listenICMPv6(ctx)
}
if err != nil {
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrNotPermitted)
}
return 0, fmt.Errorf("listening for ICMP packets: %w", err)
}
mtusToTest := test.MakeMTUsToTest(minMTU, maxPossibleMTU)
if len(mtusToTest) == 1 { // only minMTU because minMTU == maxPossibleMTU
return minMTU, nil
}
logger.Debugf("ICMP testing the following MTUs: %v", mtusToTest)
tests := make([]icmpTestUnit, len(mtusToTest))
for i := range mtusToTest {
tests[i] = icmpTestUnit{mtu: mtusToTest[i]}
}
timedCtx, cancel := context.WithTimeout(ctx, pingTimeout)
defer cancel()
go func() {
<-timedCtx.Done()
conn.Close()
}()
for i := range tests {
id, message := buildMessageToSend(ipVersion, tests[i].mtu)
tests[i].echoID = id
encodedMessage, err := message.Marshal(nil)
if err != nil {
return 0, fmt.Errorf("encoding ICMP message: %w", err)
}
tests[i].sentBytes = len(encodedMessage)
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()})
if err != nil {
if strings.HasSuffix(err.Error(), "sendto: operation not permitted") {
err = fmt.Errorf("%w", ErrNotPermitted)
}
return 0, fmt.Errorf("writing ICMP message: %w", err)
}
}
err = collectReplies(conn, ipVersion, tests, logger)
switch {
case err == nil: // max possible MTU is working
return tests[len(tests)-1].mtu, nil
case err != nil && errors.Is(err, net.ErrClosed):
// we have timeouts (IPv4 testing or IPv6 PMTUD blackholes)
// so find the highest MTU which worked.
// Note we start from index len(tests) - 2 since the max MTU
// cannot be working if we had a timeout.
for i := len(tests) - 2; i >= 0; i-- { //nolint:mnd
if tests[i].ok {
return pmtudMultiSizes(ctx, ip, tests[i].mtu, tests[i+1].mtu-1,
pingTimeout, logger)
}
}
// All MTUs failed.
return 0, fmt.Errorf("%w: ICMP might be blocked", ErrMTUNotFound)
case err != nil:
return 0, fmt.Errorf("collecting ICMP echo replies: %w", err)
default:
panic("unreachable")
}
}
// The theoretical limit is 4GiB for IPv6 MTU path discovery jumbograms, but that would
// create huge buffers which we don't really want to support anyway.
// The standard frame maximum MTU is 1500 bytes, and there are Jumbo frames with
// a conventional maximum of 9000 bytes. However, some manufacturers support up
// 9216-20 = 9196 bytes for the maximum MTU. We thus use buffers of size 9196 to
// match eventual Jumbo frames. More information at:
// https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media
const maxPossibleMTU = 9196
func collectReplies(conn net.PacketConn, ipVersion string,
tests []icmpTestUnit, logger Logger,
) (err error) {
echoIDToTestIndex := make(map[uint16]int, len(tests))
for i, test := range tests {
echoIDToTestIndex[test.echoID] = i
}
buffer := make([]byte, maxPossibleMTU)
idsFound := 0
for idsFound < len(tests) {
// Note we need to read the whole packet in one call to ReadFrom, so the buffer
// must be large enough to read the entire reply packet. See:
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
bytesRead, _, err := conn.ReadFrom(buffer)
if err != nil {
return fmt.Errorf("reading from ICMP connection: %w", err)
}
packetBytes := buffer[:bytesRead]
ipPacketLength := len(packetBytes)
var icmpProtocol int
switch ipVersion {
case "v4":
icmpProtocol = icmpv4Protocol
case "v6":
icmpProtocol = icmpv6Protocol
default:
panic(fmt.Sprintf("unknown IP version: %s", ipVersion))
}
// Parse the ICMP message
// Note: this parsing works for a truncated 556 bytes ICMP reply packet.
message, err := icmp.ParseMessage(icmpProtocol, packetBytes)
if err != nil {
return fmt.Errorf("parsing message: %w", err)
}
echoBody, ok := message.Body.(*icmp.Echo)
if !ok {
return fmt.Errorf("%w: %T", ErrBodyUnsupported, message.Body)
}
id := uint16(echoBody.ID) //nolint:gosec
testIndex, testing := echoIDToTestIndex[id]
if !testing { // not an id we expected so ignore it
logger.Warnf("ignoring ICMP reply with unexpected ID %d (type: %d, code: %d, length: %d)",
echoBody.ID, message.Type, message.Code, ipPacketLength)
continue
}
idsFound++
sentBytes := tests[testIndex].sentBytes
// echo reply should be at most the number of bytes sent,
// and can be lower, more precisely 556 bytes, in case
// the host we are reaching wants to stay out of trouble
// and ensure its echo reply goes through without
// fragmentation, see the following page:
// https://datatracker.ietf.org/doc/html/rfc1122#page-59
const conservativeReplyLength = 556
truncated := ipPacketLength < sentBytes &&
ipPacketLength == conservativeReplyLength
// Check the packet size is the same if the reply is not truncated
if !truncated && sentBytes != ipPacketLength {
return fmt.Errorf("%w: sent %dB and received %dB",
ErrEchoDataMismatch, sentBytes, ipPacketLength)
}
// Truncated reply or matching reply size
tests[testIndex].ok = true
}
return nil
}
+73
View File
@@ -0,0 +1,73 @@
package ip
import (
"encoding/binary"
"net/netip"
"syscall"
"github.com/qdm12/gluetun/internal/pmtud/constants"
)
func HeaderV4(srcIP, dstIP netip.Addr, payloadLength uint32) []byte {
ipHeader := make([]byte, constants.IPv4HeaderLength)
const version byte = 4
const headerLength byte = 20 / 4 // in 32-bit words
ipHeader[0] = (version << 4) | headerLength //nolint:mnd
ipHeader[1] = 0 // type of Service
putUint16(ipHeader[2:], uint16(constants.IPv4HeaderLength+payloadLength)) //nolint:gosec
ipHeader[4], ipHeader[5] = 0, 0 // identification
const flagsAndOffset uint16 = 0x4000 // DF bit set
putUint16(ipHeader[6:], flagsAndOffset)
ipHeader[8] = 64 // ttl
ipHeader[9] = syscall.IPPROTO_TCP
srcIPBytes := srcIP.As4()
copy(ipHeader[12:16], srcIPBytes[:])
dstIPBytes := dstIP.As4()
copy(ipHeader[16:20], dstIPBytes[:])
checksum := ipChecksum(ipHeader)
ipHeader[10] = byte(checksum >> 8) //nolint:mnd
ipHeader[11] = byte(checksum & 0xff) //nolint:mnd
return ipHeader
}
// ipChecksum calculates the checksum for the IP header.
//
//nolint:mnd
func ipChecksum(header []byte) uint16 {
sum := uint32(0)
for i := 0; i < len(header)-1; i += 2 {
sum += uint32(header[i])<<8 + uint32(header[i+1])
}
if len(header)%2 != 0 {
sum += uint32(header[len(header)-1]) << 8
}
for (sum >> 16) > 0 {
sum = (sum & 0xFFFF) + (sum >> 16)
}
return ^uint16(sum) //nolint:gosec
}
// HeaderV6 makes an IPv6 header.
// payloadLen is the length of the payload following the header.
// nextHeader can be byte([syscall.IPPROTO_TCP]) for example.
func HeaderV6(srcIP, dstIP netip.Addr,
payloadLen uint16, nextHeader byte,
) []byte {
ipv6Header := make([]byte, constants.IPv6HeaderLength)
ipv6Header[0] = 0x60 // version (4 bits) | traffic Class (4 bits)
ipv6Header[1] = 0x00 // traffic Class (4 bits) | flow label (4 bits)
// Flow Label (remaining 16 bits)
ipv6Header[2] = 0x00
ipv6Header[3] = 0x00
binary.BigEndian.PutUint16(ipv6Header[4:], payloadLen)
ipv6Header[6] = nextHeader
const hopLimit = 64
ipv6Header[7] = hopLimit
copy(ipv6Header[8:24], srcIP.AsSlice())
copy(ipv6Header[24:40], dstIP.AsSlice())
return ipv6Header
}
+9
View File
@@ -0,0 +1,9 @@
package ip
import (
"encoding/binary"
)
func putUint16(b []byte, v uint16) {
binary.NativeEndian.PutUint16(b, v)
}
@@ -0,0 +1,9 @@
//go:build !darwin
package ip
import "encoding/binary"
func putUint16(b []byte, v uint16) {
binary.BigEndian.PutUint16(b, v)
}
+9
View File
@@ -0,0 +1,9 @@
//go:build linux || darwin
package ip
import "syscall"
func SetIPv4HeaderIncluded(fd int) error {
return syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
}
+7
View File
@@ -0,0 +1,7 @@
//go:build !linux && !windows && !darwin
package ip
func SetIPv4HeaderIncluded(fd int) error {
panic("not implemented")
}
+12
View File
@@ -0,0 +1,12 @@
package ip
import (
"syscall"
"golang.org/x/sys/windows"
)
func SetIPv4HeaderIncluded(handle syscall.Handle) error {
const ipHdrIncluded = windows.IP_HDRINCL
return syscall.SetsockoptInt(handle, syscall.IPPROTO_IP, ipHdrIncluded, 1)
}
+5
View File
@@ -0,0 +1,5 @@
package ip
func SetIPv6HeaderIncluded(fd int) error {
panic("darwin does not allow an application to build IPv6 headers")
}
+8
View File
@@ -0,0 +1,8 @@
package ip
import "syscall"
func SetIPv6HeaderIncluded(fd int) error {
const ipv6HdrIncluded = 36 // IPV6_HDRINCL
return syscall.SetsockoptInt(fd, syscall.IPPROTO_IPV6, ipv6HdrIncluded, 1)
}
+7
View File
@@ -0,0 +1,7 @@
//go:build !linux && !windows && !darwin
package ip
func SetIPv6HeaderIncluded(fd int) error {
panic("not implemented")
}
+7
View File
@@ -0,0 +1,7 @@
package ip
import "syscall"
func SetIPv6HeaderIncluded(fd syscall.Handle) error {
panic("windows does not allow an application to build IPv6 headers")
}
+123
View File
@@ -0,0 +1,123 @@
package ip
import (
"fmt"
"net/netip"
"syscall"
"github.com/jsimonetti/rtnetlink"
)
// SrcAddr determines the appropriate source IP address to use when sending a packet to the
// specified destination. It also reserves an ephemeral source port for the specified protocol
// to ensure that the port is not used by other processes. The cleanup function returned should
// be called to release the reserved port when done.
func SrcAddr(dst netip.AddrPort, proto int) (src netip.AddrPort, cleanup func(), err error) {
srcAddr, err := srcIP(dst.Addr())
if err != nil {
return netip.AddrPort{}, nil, fmt.Errorf("finding source IP: %w", err)
}
srcPort, cleanup, err := srcPort(srcAddr, proto)
if err != nil {
return netip.AddrPort{}, nil, fmt.Errorf("reserving source port: %w", err)
}
return netip.AddrPortFrom(srcAddr, srcPort), cleanup, nil
}
var errNoRoute = fmt.Errorf("no route to destination")
func srcIP(dst netip.Addr) (netip.Addr, error) {
conn, err := rtnetlink.Dial(nil)
if err != nil {
return netip.Addr{}, err
}
defer conn.Close()
family := uint8(syscall.AF_INET)
if dst.Is6() {
family = syscall.AF_INET6
}
// Request route to destination
requestMessage := &rtnetlink.RouteMessage{
Family: family,
Attributes: rtnetlink.RouteAttributes{
Dst: dst.AsSlice(),
},
}
messages, err := conn.Route.Get(requestMessage)
if err != nil {
return netip.Addr{}, fmt.Errorf("getting routes to %s: %w", dst, err)
}
for _, message := range messages {
if message.Attributes.Src == nil {
continue
}
ipv6 := message.Attributes.Src.To4() == nil
if ipv6 {
return netip.AddrFrom16([16]byte(message.Attributes.Src)), nil
}
return netip.AddrFrom4([4]byte(message.Attributes.Src)), nil
}
return netip.Addr{}, fmt.Errorf("%w: in %d route(s)", errNoRoute, len(messages))
}
// srcPort reserves an ephemeral source port by opening a socket for the
// protocol specified and binds it to the provided source address.
// It doesn't actually listen on the port.
// The cleanup function returned should be called to release the port when done.
func srcPort(srcAddr netip.Addr, proto int) (srcPort uint16, cleanup func(), err error) {
family := syscall.AF_INET
if srcAddr.Is6() {
family = syscall.AF_INET6
}
fd, err := syscall.Socket(family, syscall.SOCK_STREAM, proto)
if err != nil {
return 0, nil, fmt.Errorf("creating reservation socket: %w", err)
}
cleanup = func() {
_ = syscall.Close(fd)
}
// Bind to port 0 to get an ephemeral port
const port = 0
var bindAddr syscall.Sockaddr
if srcAddr.Is4() {
bindAddr = &syscall.SockaddrInet4{
Port: port,
Addr: srcAddr.As4(),
}
} else {
bindAddr = &syscall.SockaddrInet6{
Port: port,
Addr: srcAddr.As16(),
}
}
err = syscall.Bind(fd, bindAddr)
if err != nil {
cleanup()
return 0, nil, fmt.Errorf("binding reservation socket: %w", err)
}
sockAddr, err := syscall.Getsockname(fd)
if err != nil {
cleanup()
return 0, nil, fmt.Errorf("getting bound socket name: %w", err)
}
switch typedSockAddr := sockAddr.(type) {
case *syscall.SockaddrInet4:
srcPort = uint16(typedSockAddr.Port) //nolint:gosec
case *syscall.SockaddrInet6:
srcPort = uint16(typedSockAddr.Port) //nolint:gosec
default:
panic(fmt.Sprintf("unexpected sockaddr type: %T", typedSockAddr))
}
return srcPort, cleanup, nil
}
+38 -233
View File
@@ -4,268 +4,73 @@ import (
"context"
"errors"
"fmt"
"math"
"net"
"net/netip"
"strings"
"time"
"golang.org/x/net/icmp"
"github.com/qdm12/gluetun/internal/pmtud/constants"
"github.com/qdm12/gluetun/internal/pmtud/icmp"
"github.com/qdm12/gluetun/internal/pmtud/tcp"
)
var ErrMTUNotFound = errors.New("path MTU discovery failed to find MTU")
// PathMTUDiscover discovers the maximum MTU for the path to the given ip address.
// PathMTUDiscover discovers the maximum MTU using both ICMP and TCP.
// Multiple ICMP addresses and TCP addresses can be specified for redundancy.
// ICMP PMTUD is run first. If successful, the range of possible MTU values to
// check for TCP PMTUD is reduced to [maxMTU-150, maxMTU] where maxMTU is the
// maximum MTU found with ICMP PMTUD. Otherwise, TCP PMTUD is run with the
// whole range of possible MTU values up to the physical link MTU to check.
// If the physicalLinkMTU is zero, it defaults to 1500 which is the ethernet standard MTU.
// If the pingTimeout is zero, it defaults to 1 second.
// If the logger is nil, a no-op logger is used.
// It returns [ErrMTUNotFound] if the MTU could not be determined.
func PathMTUDiscover(ctx context.Context, ip netip.Addr,
physicalLinkMTU uint32, pingTimeout time.Duration, logger Logger) (
func PathMTUDiscover(ctx context.Context, icmpAddrs []netip.Addr, tcpAddrs []netip.AddrPort,
physicalLinkMTU uint32, tryTimeout time.Duration, logger Logger) (
mtu uint32, err error,
) {
if physicalLinkMTU == 0 {
const ethernetStandardMTU = 1500
physicalLinkMTU = ethernetStandardMTU
}
if pingTimeout == 0 {
pingTimeout = time.Second
if tryTimeout == 0 {
tryTimeout = time.Second
}
if logger == nil {
logger = &noopLogger{}
}
if ip.Is4() {
logger.Debug("finding IPv4 next hop MTU")
mtu, err = findIPv4NextHopMTU(ctx, ip, physicalLinkMTU, pingTimeout, logger)
// Try finding the MTU using ICMP
maxPossibleMTU := physicalLinkMTU
icmpSuccess := false
for _, icmpIP := range icmpAddrs {
mtu, err := icmp.PathMTUDiscover(ctx, icmpIP, physicalLinkMTU,
tryTimeout, logger)
switch {
case err == nil:
return mtu, nil
case errors.Is(err, net.ErrClosed) || errors.Is(err, ErrICMPCommunicationAdministrativelyProhibited): // blackhole
logger.Debugf("ICMP path MTU discovery against %s found maximum valid MTU %d", icmpIP, mtu)
icmpSuccess = true
maxPossibleMTU = mtu
case errors.Is(err, icmp.ErrNotPermitted), errors.Is(err, icmp.ErrMTUNotFound):
logger.Debugf("ICMP path MTU discovery failed: %s", err)
default:
return 0, fmt.Errorf("finding IPv4 next hop MTU: %w", err)
}
} else {
logger.Debug("requesting IPv6 ICMP packet-too-big reply")
mtu, err = getIPv6PacketTooBig(ctx, ip, physicalLinkMTU, pingTimeout, logger)
switch {
case err == nil:
return mtu, nil
case errors.Is(err, net.ErrClosed): // blackhole
default:
return 0, fmt.Errorf("getting IPv6 packet-too-big message: %w", err)
return 0, fmt.Errorf("ICMP path MTU discovery: %w", err)
}
}
// Fall back method: send echo requests with different packet
// sizes and check which ones succeed to find the maximum MTU.
logger.Debug("falling back to sending different sized echo packets")
minMTU := minIPv4MTU
if ip.Is6() {
minMTU = minIPv6MTU
}
return pmtudMultiSizes(ctx, ip, minMTU, physicalLinkMTU, pingTimeout, logger)
}
type pmtudTestUnit struct {
mtu uint32
echoID uint16
sentBytes int
ok bool
}
func pmtudMultiSizes(ctx context.Context, ip netip.Addr,
minMTU, maxPossibleMTU uint32, pingTimeout time.Duration,
logger Logger,
) (maxMTU uint32, err error) {
var ipVersion string
var conn net.PacketConn
if ip.Is4() {
ipVersion = "v4"
conn, err = listenICMPv4(ctx)
} else {
ipVersion = "v6"
conn, err = listenICMPv6(ctx)
}
if err != nil {
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrICMPNotPermitted)
for _, addrPort := range tcpAddrs {
minMTU := constants.MinIPv4MTU
if addrPort.Addr().Is6() {
minMTU = constants.MinIPv6MTU
}
return 0, fmt.Errorf("listening for ICMP packets: %w", err)
}
mtusToTest := makeMTUsToTest(minMTU, maxPossibleMTU)
if len(mtusToTest) == 1 { // only minMTU because minMTU == maxPossibleMTU
return minMTU, nil
}
logger.Debugf("testing the following MTUs: %v", mtusToTest)
tests := make([]pmtudTestUnit, len(mtusToTest))
for i := range mtusToTest {
tests[i] = pmtudTestUnit{mtu: mtusToTest[i]}
}
timedCtx, cancel := context.WithTimeout(ctx, pingTimeout)
defer cancel()
go func() {
<-timedCtx.Done()
conn.Close()
}()
for i := range tests {
id, message := buildMessageToSend(ipVersion, tests[i].mtu)
tests[i].echoID = id
encodedMessage, err := message.Marshal(nil)
if icmpSuccess {
const mtuMargin = 150
minMTU = max(maxPossibleMTU-mtuMargin, minMTU)
}
mtu, err = tcp.PathMTUDiscover(ctx, addrPort, minMTU, maxPossibleMTU, logger)
if err != nil {
return 0, fmt.Errorf("encoding ICMP message: %w", err)
}
tests[i].sentBytes = len(encodedMessage)
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()})
if err != nil {
if strings.HasSuffix(err.Error(), "sendto: operation not permitted") {
err = fmt.Errorf("%w", ErrICMPNotPermitted)
}
return 0, fmt.Errorf("writing ICMP message: %w", err)
}
}
err = collectReplies(conn, ipVersion, tests, logger)
switch {
case err == nil: // max possible MTU is working
return tests[len(tests)-1].mtu, nil
case err != nil && errors.Is(err, net.ErrClosed):
// we have timeouts (IPv4 testing or IPv6 PMTUD blackholes)
// so find the highest MTU which worked.
// Note we start from index len(tests) - 2 since the max MTU
// cannot be working if we had a timeout.
for i := len(tests) - 2; i >= 0; i-- { //nolint:mnd
if tests[i].ok {
return pmtudMultiSizes(ctx, ip, tests[i].mtu, tests[i+1].mtu-1,
pingTimeout, logger)
}
}
// All MTUs failed.
return 0, fmt.Errorf("%w: ICMP might be blocked", ErrMTUNotFound)
case err != nil:
return 0, fmt.Errorf("collecting ICMP echo replies: %w", err)
default:
panic("unreachable")
}
}
// Create the MTU slice of length 11 such that:
// - the first element is the minMTU
// - the last element is the maxMTU
// - elements in-between are separated as close to each other
// The number 11 is chosen to find the final MTU in 3 searches,
// with a total search space of 1728 MTUs which is enough;
// to find it in 2 searches requires 37 parallel queries which
// could be blocked by firewalls.
func makeMTUsToTest(minMTU, maxMTU uint32) (mtus []uint32) {
const mtusLength = 11 // find the final MTU in 3 searches
diff := maxMTU - minMTU
switch {
case minMTU > maxMTU:
panic("minMTU > maxMTU")
case diff <= mtusLength:
mtus = make([]uint32, 0, diff)
for mtu := minMTU; mtu <= maxMTU; mtu++ {
mtus = append(mtus, mtu)
}
default:
step := float64(diff) / float64(mtusLength-1)
mtus = make([]uint32, 0, mtusLength)
for mtu := float64(minMTU); len(mtus) < mtusLength-1; mtu += step {
mtus = append(mtus, uint32(math.Round(mtu)))
}
mtus = append(mtus, maxMTU) // last element is the maxMTU
}
return mtus
}
func collectReplies(conn net.PacketConn, ipVersion string,
tests []pmtudTestUnit, logger Logger,
) (err error) {
echoIDToTestIndex := make(map[uint16]int, len(tests))
for i, test := range tests {
echoIDToTestIndex[test.echoID] = i
}
// The theoretical limit is 4GiB for IPv6 MTU path discovery jumbograms, but that would
// create huge buffers which we don't really want to support anyway.
// The standard frame maximum MTU is 1500 bytes, and there are Jumbo frames with
// a conventional maximum of 9000 bytes. However, some manufacturers support up
// 9216-20 = 9196 bytes for the maximum MTU. We thus use buffers of size 9196 to
// match eventual Jumbo frames. More information at:
// https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media
const maxPossibleMTU = 9196
buffer := make([]byte, maxPossibleMTU)
idsFound := 0
for idsFound < len(tests) {
// Note we need to read the whole packet in one call to ReadFrom, so the buffer
// must be large enough to read the entire reply packet. See:
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
bytesRead, _, err := conn.ReadFrom(buffer)
if err != nil {
return fmt.Errorf("reading from ICMP connection: %w", err)
}
packetBytes := buffer[:bytesRead]
ipPacketLength := len(packetBytes)
var icmpProtocol int
switch ipVersion {
case "v4":
icmpProtocol = icmpv4Protocol
case "v6":
icmpProtocol = icmpv6Protocol
default:
panic(fmt.Sprintf("unknown IP version: %s", ipVersion))
}
// Parse the ICMP message
// Note: this parsing works for a truncated 556 bytes ICMP reply packet.
message, err := icmp.ParseMessage(icmpProtocol, packetBytes)
if err != nil {
return fmt.Errorf("parsing message: %w", err)
}
echoBody, ok := message.Body.(*icmp.Echo)
if !ok {
return fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, message.Body)
}
id := uint16(echoBody.ID) //nolint:gosec
testIndex, testing := echoIDToTestIndex[id]
if !testing { // not an id we expected so ignore it
logger.Warnf("ignoring ICMP reply with unexpected ID %d (type: %d, code: %d, length: %d)",
echoBody.ID, message.Type, message.Code, ipPacketLength)
logger.Debugf("TCP path MTU discovery to %s failed: %s", addrPort, err)
continue
}
idsFound++
sentBytes := tests[testIndex].sentBytes
// echo reply should be at most the number of bytes sent,
// and can be lower, more precisely 556 bytes, in case
// the host we are reaching wants to stay out of trouble
// and ensure its echo reply goes through without
// fragmentation, see the following page:
// https://datatracker.ietf.org/doc/html/rfc1122#page-59
const conservativeReplyLength = 556
truncated := ipPacketLength < sentBytes &&
ipPacketLength == conservativeReplyLength
// Check the packet size is the same if the reply is not truncated
if !truncated && sentBytes != ipPacketLength {
return fmt.Errorf("%w: sent %dB and received %dB",
ErrICMPEchoDataMismatch, sentBytes, ipPacketLength)
}
// Truncated reply or matching reply size
tests[testIndex].ok = true
logger.Debugf("TCP path MTU discovery to %s found maximum valid MTU %d", addrPort, mtu)
return mtu, nil
}
return nil
return 0, fmt.Errorf("TCP path MTU discovery: last error: %w", err)
}
+7
View File
@@ -0,0 +1,7 @@
package tcp
type Logger interface {
Debug(msg string)
Debugf(msg string, args ...any)
Warnf(msg string, args ...any)
}
@@ -0,0 +1,3 @@
package tcp
//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . Logger
+80
View File
@@ -0,0 +1,80 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/qdm12/gluetun/internal/pmtud/tcp (interfaces: Logger)
// Package tcp is a generated GoMock package.
package tcp
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
)
// MockLogger is a mock of Logger interface.
type MockLogger struct {
ctrl *gomock.Controller
recorder *MockLoggerMockRecorder
}
// MockLoggerMockRecorder is the mock recorder for MockLogger.
type MockLoggerMockRecorder struct {
mock *MockLogger
}
// NewMockLogger creates a new mock instance.
func NewMockLogger(ctrl *gomock.Controller) *MockLogger {
mock := &MockLogger{ctrl: ctrl}
mock.recorder = &MockLoggerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockLogger) EXPECT() *MockLoggerMockRecorder {
return m.recorder
}
// Debug mocks base method.
func (m *MockLogger) Debug(arg0 string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Debug", arg0)
}
// Debug indicates an expected call of Debug.
func (mr *MockLoggerMockRecorder) Debug(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockLogger)(nil).Debug), arg0)
}
// Debugf mocks base method.
func (m *MockLogger) Debugf(arg0 string, arg1 ...interface{}) {
m.ctrl.T.Helper()
varargs := []interface{}{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
m.ctrl.Call(m, "Debugf", varargs...)
}
// Debugf indicates an expected call of Debugf.
func (mr *MockLoggerMockRecorder) Debugf(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]interface{}{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debugf", reflect.TypeOf((*MockLogger)(nil).Debugf), varargs...)
}
// Warnf mocks base method.
func (m *MockLogger) Warnf(arg0 string, arg1 ...interface{}) {
m.ctrl.T.Helper()
varargs := []interface{}{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
m.ctrl.Call(m, "Warnf", varargs...)
}
// Warnf indicates an expected call of Warnf.
func (mr *MockLoggerMockRecorder) Warnf(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]interface{}{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Warnf", reflect.TypeOf((*MockLogger)(nil).Warnf), varargs...)
}
+89
View File
@@ -0,0 +1,89 @@
package tcp
import (
"context"
"errors"
"fmt"
"net/netip"
"syscall"
"time"
"github.com/qdm12/gluetun/internal/pmtud/test"
)
var ErrMTUNotFound = errors.New("MTU not found")
type testUnit struct {
mtu uint32
ok bool
}
func PathMTUDiscover(ctx context.Context, addrPort netip.AddrPort,
minMTU, maxPossibleMTU uint32, logger Logger,
) (mtu uint32, err error) {
mtusToTest := test.MakeMTUsToTest(minMTU, maxPossibleMTU)
if len(mtusToTest) == 1 { // only minMTU because minMTU == maxPossibleMTU
return minMTU, nil
}
logger.Debugf("TCP testing the following MTUs: %v", mtusToTest)
tests := make([]testUnit, len(mtusToTest))
for i := range mtusToTest {
tests[i] = testUnit{mtu: mtusToTest[i]}
}
family := syscall.AF_INET
if addrPort.Addr().Is6() {
family = syscall.AF_INET6
}
fd, stop, err := startRawSocket(family)
if err != nil {
return 0, fmt.Errorf("starting raw socket: %w", err)
}
defer stop()
tracker := newTracker(fd, addrPort.Addr().Is4())
const timeout = time.Second
runCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
errCh := make(chan error)
go func() {
errCh <- tracker.listen(runCtx)
}()
doneCh := make(chan struct{})
for i := range tests {
go func(i int) {
err := runTest(runCtx, fd, tracker, addrPort, tests[i].mtu)
tests[i].ok = err == nil
doneCh <- struct{}{}
}(i)
}
for range tests {
select {
case <-doneCh:
case err := <-errCh:
if err == nil { // timeout
break
}
return 0, fmt.Errorf("listening for TCP replies: %w", err)
}
}
if tests[len(tests)-1].ok {
return tests[len(tests)-1].mtu, nil
}
for i := len(tests) - 2; i >= 0; i-- { //nolint:mnd
if tests[i].ok {
stop()
cancel()
return PathMTUDiscover(ctx, addrPort,
tests[i].mtu, tests[i+1].mtu-1, logger)
}
}
return 0, fmt.Errorf("%w: your connection might not be working at all", ErrMTUNotFound)
}
+89
View File
@@ -0,0 +1,89 @@
package tcp
import (
"math/rand/v2"
"net/netip"
"syscall"
"github.com/qdm12/gluetun/internal/pmtud/constants"
"github.com/qdm12/gluetun/internal/pmtud/ip"
)
// createSYNPacket creates a TCP SYN packet for initiating a handshake.
// SYN packets have normally no data payload, so you SHOULD set mtu to 0.
// However, in some cases where the server closes the connection with RST immediately,
// it can be useful to add some data payload to a SYN packet and check if the server still
// replies. Only set mtu to a non zero value if you know what you are doing.
func createSYNPacket(src, dst netip.AddrPort, mtu uint32) (packet []byte, seq uint32) {
seq = rand.Uint32() //nolint:gosec
const ack = 0 // SYN has no ACK number
payloadLength := constants.BaseTCPHeaderLength // no data payload
if mtu > 0 {
payloadLength = getPayloadLength(mtu, dst)
}
return createPacket(src, dst, seq, ack, payloadLength, synFlag), seq
}
// createACKPacket creates a TCP ACK packet.
// If the mtu is set to 0, no payload is sent.
// Otherwise, the payload is calculated to test the MTU given.
func createACKPacket(src, dst netip.AddrPort, seq, ack uint32, mtu uint32) []byte {
payloadLength := constants.BaseTCPHeaderLength // no data payload
if mtu > 0 {
payloadLength = getPayloadLength(mtu, dst)
}
const flags = ackFlag | pshFlag
return createPacket(src, dst, seq, ack, payloadLength, flags)
}
func createRSTPacket(src, dst netip.AddrPort, seq, ack uint32) []byte {
const payloadLength = constants.BaseTCPHeaderLength // no data payload
return createPacket(src, dst, seq, ack, payloadLength, rstFlag)
}
func getPayloadLength(mtu uint32, dst netip.AddrPort) uint32 {
var ipHeaderLength uint32
if dst.Addr().Is4() {
ipHeaderLength = constants.IPv4HeaderLength
} else {
ipHeaderLength = constants.IPv6HeaderLength
}
if mtu < ipHeaderLength+constants.BaseTCPHeaderLength {
panic("MTU too small to hold IP and TCP headers")
}
return mtu - ipHeaderLength
}
func createPacket(src, dst netip.AddrPort,
seq, ack, payloadLength uint32, flags byte,
) []byte {
if payloadLength < constants.BaseTCPHeaderLength {
panic("payload length is too small to hold TCP header")
}
var ipHeader []byte
if dst.Addr().Is4() {
ipHeader = ip.HeaderV4(src.Addr(), dst.Addr(), payloadLength)
} else {
ipHeader = ip.HeaderV6(src.Addr(), dst.Addr(),
uint16(payloadLength), byte(syscall.IPPROTO_TCP)) //nolint:gosec
}
tcpHeader := makeTCPHeader(src.Port(), dst.Port(), seq, ack, flags)
// data is just zeroes
dataLength := int(payloadLength) - int(constants.BaseTCPHeaderLength)
var data []byte
if dataLength > 0 {
data = make([]byte, dataLength)
}
checksum := tcpChecksum(ipHeader, tcpHeader, data)
tcpHeader[16] = byte(checksum >> 8) //nolint:mnd
tcpHeader[17] = byte(checksum & 0xff) //nolint:mnd
packet := make([]byte, len(ipHeader)+int(constants.BaseTCPHeaderLength)+dataLength)
copy(packet, ipHeader)
copy(packet[len(ipHeader):], tcpHeader)
copy(packet[len(ipHeader)+int(constants.BaseTCPHeaderLength):], data)
return packet
}
+196
View File
@@ -0,0 +1,196 @@
package tcp
import (
"context"
"errors"
"fmt"
"net/netip"
"syscall"
"github.com/qdm12/gluetun/internal/pmtud/constants"
"github.com/qdm12/gluetun/internal/pmtud/ip"
)
func startRawSocket(family int) (fd fileDescriptor, stop func(), err error) {
fdPlatform, err := syscall.Socket(family, syscall.SOCK_RAW, syscall.IPPROTO_TCP)
if err != nil {
return 0, nil, fmt.Errorf("creating raw socket: %w", err)
}
if family == syscall.AF_INET {
err = ip.SetIPv4HeaderIncluded(fdPlatform)
} else {
err = ip.SetIPv6HeaderIncluded(fdPlatform)
}
if err != nil {
_ = syscall.Close(fdPlatform)
return 0, nil, fmt.Errorf("setting header option on raw socket: %w", err)
}
// Allow sending packets larger than cached PMTU (for PMTUD probing)
err = setMTUDiscovery(fdPlatform)
if err != nil {
_ = syscall.Close(fdPlatform)
return 0, nil, fmt.Errorf("setting IP_MTU_DISCOVER: %w", err)
}
// use polling because some Linux systems do not cancel
// blocking syscalls such as recvfrom when the socket is closed,
// which would cause things to hang indefinitely.
err = setNonBlock(fdPlatform)
if err != nil {
_ = syscall.Close(fdPlatform)
return 0, nil, fmt.Errorf("setting non-blocking mode: %w", err)
}
stop = func() {
_ = syscall.Close(fdPlatform)
}
return fileDescriptor(fdPlatform), stop, nil
}
var (
errTCPPacketNotSynAck = errors.New("TCP packet is not a SYN-ACK")
errTCPSynAckAckMismatch = errors.New("TCP SYN-ACK ACK number does not match expected value")
errFinalPacketTypeUnexpected = errors.New("final TCP packet type is unexpected")
)
// Craft and send a raw TCP packet to test the MTU.
// It expects either an RST reply (if no server is listening)
// or a SYN-ACK/ACK reply (if a server is listening).
func runTest(ctx context.Context, fd fileDescriptor,
tracker *tracker, dst netip.AddrPort, mtu uint32,
) error {
const proto = syscall.IPPROTO_TCP
src, cleanup, err := ip.SrcAddr(dst, proto)
if err != nil {
return fmt.Errorf("getting source address: %w", err)
}
defer cleanup()
ch := make(chan []byte)
abort := make(chan struct{})
defer close(abort)
tracker.register(src.Port(), dst.Port(), ch, abort)
defer tracker.unregister(src.Port(), dst.Port())
dstSockAddr := makeSockAddr(dst)
synPacket, synSeq := createSYNPacket(src, dst, 0)
const sendToFlags = 0
err = sendTo(fd, synPacket, sendToFlags, dstSockAddr)
if err != nil {
return fmt.Errorf("sending SYN packet: %w", err)
}
var reply []byte
select {
case <-ctx.Done():
return ctx.Err()
case reply = <-ch:
}
packetType, synAckSeq, synAckAck, err := parseTCPHeader(reply[:constants.BaseTCPHeaderLength])
switch {
case err != nil:
return fmt.Errorf("parsing first reply TCP header: %w", err)
case packetType == packetTypeRST:
// server actively closed the connection, try sending a SYN with data
return handleRSTReply(ctx, fd, ch, src, dst, mtu)
case packetType != packetTypeSYNACK:
return fmt.Errorf("%w: unexpected packet type %s", errTCPPacketNotSynAck, packetType)
case synAckAck != synSeq+1:
return fmt.Errorf("%w: expected %d, got %d", errTCPSynAckAckMismatch, synSeq+1, synAckAck)
}
// Send a no-data ACK packet to finish the 3-way handshake.
const ackMTU = 0 // no data payload initially
ackPacket := createACKPacket(src, dst, synAckAck, synAckSeq+1, ackMTU)
err = sendTo(fd, ackPacket, sendToFlags, dstSockAddr)
if err != nil {
return fmt.Errorf("sending ACK-without-data packet: %w", err)
}
// Send a data ACK packet to test the MTU given.
ackPacket = createACKPacket(src, dst, synAckAck, synAckSeq+1, mtu)
err = sendTo(fd, ackPacket, sendToFlags, dstSockAddr)
if err != nil {
return fmt.Errorf("sending ACK-with-data packet: %w", err)
}
select {
case <-ctx.Done():
return ctx.Err()
case reply = <-ch:
}
packetType, _, ack, err := parseTCPHeader(reply[:constants.BaseTCPHeaderLength])
if err != nil {
return fmt.Errorf("parsing second reply TCP header: %w", err)
}
switch packetType { //nolint:exhaustive
case packetTypeRST:
return nil
case packetTypeACK:
err = sendRST(fd, src, dst, ack)
if err != nil {
return fmt.Errorf("sending RST packet: %w", err)
}
return nil
default:
_ = sendRST(fd, src, dst, ack)
return fmt.Errorf("%w: %s", errFinalPacketTypeUnexpected, packetType)
}
}
func makeSockAddr(addr netip.AddrPort) syscall.Sockaddr {
if addr.Addr().Is4() {
return &syscall.SockaddrInet4{
Port: int(addr.Port()),
Addr: addr.Addr().As4(),
}
}
return &syscall.SockaddrInet6{
Port: int(addr.Port()),
Addr: addr.Addr().As16(),
}
}
var errTCPPacketNotRST = errors.New("TCP packet is not an RST")
func handleRSTReply(ctx context.Context, fd fileDescriptor, ch <-chan []byte,
src, dst netip.AddrPort, mtu uint32,
) error {
packet, _ := createSYNPacket(src, dst, mtu)
const sendToFlags = 0
err := sendTo(fd, packet, sendToFlags, makeSockAddr(dst))
if err != nil {
return fmt.Errorf("sending SYN MTU-test packet: %w", err)
}
var reply []byte
select {
case <-ctx.Done():
return ctx.Err() // timeout: the MTU test SYN packet was too big
case reply = <-ch:
}
packetType, _, _, err := parseTCPHeader(reply[:constants.BaseTCPHeaderLength])
if err != nil {
return fmt.Errorf("parsing reply TCP header: %w", err)
} else if packetType != packetTypeRST {
return fmt.Errorf("%w: %s", errTCPPacketNotRST, packetType)
}
return nil
}
func sendRST(fd fileDescriptor, src, dst netip.AddrPort,
previousACK uint32,
) error {
seq := previousACK
const ack = 0
rstPacket := createRSTPacket(src, dst, seq, ack)
const sendToFlags = 0
return sendTo(fd, rstPacket, sendToFlags, makeSockAddr(dst))
}
+5
View File
@@ -0,0 +1,5 @@
package tcp
func stripIPv4Header(reply []byte) (result []byte, ok bool) {
return reply, true
}
+7
View File
@@ -0,0 +1,7 @@
package tcp
import "syscall"
func setMTUDiscovery(fd int) error {
return syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_MTU_DISCOVER, syscall.IP_PMTUDISC_PROBE)
}
+30
View File
@@ -0,0 +1,30 @@
//go:build !darwin
package tcp
import (
"github.com/qdm12/gluetun/internal/pmtud/constants"
)
func stripIPv4Header(reply []byte) (result []byte, ok bool) {
if len(reply) < int(constants.IPv4HeaderLength) {
return nil, false // not an IPv4 packet
}
version := reply[0] >> 4 //nolint:mnd
const ipv4Version = 4
if version != ipv4Version {
return nil, false
}
// For IPv4 we need to skip the IP header, which is at least
// 20B and can be up to 60B.
// The Internet Header Length is the lower 4 bits of the first byte and
// represents the number of 32-bit words of the header length.
const ihlMask byte = 0x0F
const bytesInWord = 4
headerLength := int((reply[0] & ihlMask)) * bytesInWord
if len(reply) < headerLength {
return nil, false // not enough data for full IPv4 header
}
return reply[headerLength:], true
}
+199
View File
@@ -0,0 +1,199 @@
package tcp
import (
"context"
"errors"
"fmt"
"net/netip"
"syscall"
"testing"
"time"
"github.com/qdm12/gluetun/internal/netlink"
"github.com/qdm12/gluetun/internal/routing"
"github.com/qdm12/log"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_runTest(t *testing.T) {
t.Parallel()
noopLogger := &noopLogger{}
netlinker := netlink.New(noopLogger)
loopbackMTU, err := findLoopbackMTU(netlinker)
require.NoError(t, err, "finding loopback IPv4 MTU")
defaultIPv4MTU, err := findDefaultIPv4RouteMTU(netlinker)
require.NoError(t, err, "finding default IPv4 route MTU")
ctx, cancel := context.WithCancel(t.Context())
const family = syscall.AF_INET
fd, stop, err := startRawSocket(family)
require.NoError(t, err)
const ipv4 = true
tracker := newTracker(fd, ipv4)
trackerCh := make(chan error)
go func() {
trackerCh <- tracker.listen(ctx)
}()
t.Cleanup(func() {
stop()
cancel() // stop listening
err = <-trackerCh
require.NoError(t, err)
})
testCases := map[string]struct {
timeout time.Duration
dst func(t *testing.T) netip.AddrPort
mtu uint32
success bool
}{
"local_not_listening": {
timeout: time.Hour,
dst: func(t *testing.T) netip.AddrPort {
t.Helper()
port := reserveClosedPort(t)
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), port)
},
mtu: loopbackMTU,
success: true,
},
"remote_not_listening": {
timeout: 50 * time.Millisecond,
dst: func(_ *testing.T) netip.AddrPort {
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 12345)
},
mtu: defaultIPv4MTU,
},
"1.1.1.1:443": {
timeout: time.Second,
dst: func(_ *testing.T) netip.AddrPort {
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443)
},
mtu: defaultIPv4MTU,
success: true,
},
"1.1.1.1:80": {
timeout: time.Second,
dst: func(_ *testing.T) netip.AddrPort {
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 80)
},
mtu: defaultIPv4MTU,
success: true,
},
"8.8.8.8:443": {
timeout: time.Second,
dst: func(_ *testing.T) netip.AddrPort {
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 443)
},
mtu: defaultIPv4MTU,
success: true,
},
}
for name, testCase := range testCases {
t.Run(name, func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(t.Context(), testCase.timeout)
defer cancel()
dst := testCase.dst(t)
err := runTest(ctx, fd, tracker, dst, testCase.mtu)
if testCase.success {
require.NoError(t, err)
} else {
require.Error(t, err)
}
})
}
}
var errRouteNotFound = errors.New("route not found")
func findLoopbackMTU(netlinker *netlink.NetLink) (mtu uint32, err error) {
routes, err := netlinker.RouteList(netlink.FamilyV4)
if err != nil {
return 0, fmt.Errorf("getting routes list: %w", err)
}
for _, route := range routes {
if route.Dst.IsValid() && route.Dst.Addr().IsLoopback() {
link, err := netlinker.LinkByIndex(route.LinkIndex)
if err != nil {
return 0, fmt.Errorf("getting link by index: %w", err)
}
// Quirk: make sure it is maximum 65535, and not i.e. 65536
// or the IP header 16 bits will fail to fit that packet length value.
const maxMTU = 65535
return min(link.MTU, maxMTU), nil
}
}
return 0, fmt.Errorf("%w: no loopback route found", errRouteNotFound)
}
func findDefaultIPv4RouteMTU(netlinker *netlink.NetLink) (mtu uint32, err error) {
noopLogger := &noopLogger{}
routing := routing.New(netlinker, noopLogger)
defaultRoutes, err := routing.DefaultRoutes()
if err != nil {
return 0, fmt.Errorf("getting default routes: %w", err)
}
for _, route := range defaultRoutes {
if route.Family != netlink.FamilyV4 {
continue
}
link, err := netlinker.LinkByName(defaultRoutes[0].NetInterface)
if err != nil {
return 0, fmt.Errorf("getting link by name: %w", err)
}
return link.MTU, nil
}
return 0, fmt.Errorf("%w: no default route found", errRouteNotFound)
}
func reserveClosedPort(t *testing.T) (port uint16) {
t.Helper()
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM, syscall.IPPROTO_TCP)
require.NoError(t, err)
t.Cleanup(func() {
err := syscall.Close(fd)
assert.NoError(t, err)
})
addr := &syscall.SockaddrInet4{
Port: 0,
Addr: [4]byte{127, 0, 0, 1},
}
err = syscall.Bind(fd, addr)
if err != nil {
_ = syscall.Close(fd)
t.Fatal(err)
}
sockAddr, err := syscall.Getsockname(fd)
if err != nil {
_ = syscall.Close(fd)
t.Fatal(err)
}
sockAddr4, ok := sockAddr.(*syscall.SockaddrInet4)
if !ok {
_ = syscall.Close(fd)
t.Fatal("not an IPv4 address")
}
return uint16(sockAddr4.Port) //nolint:gosec
}
type noopLogger struct{}
func (l *noopLogger) Patch(_ ...log.Option) {}
func (l *noopLogger) Debug(_ string) {}
func (l *noopLogger) Debugf(_ string, _ ...any) {}
func (l *noopLogger) Info(_ string) {}
func (l *noopLogger) Warn(_ string) {}
func (l *noopLogger) Error(_ string) {}
+28
View File
@@ -0,0 +1,28 @@
//go:build linux || darwin
package tcp
import (
"syscall"
"time"
)
// fileDescriptor is a platform-independent type for socket file descriptors.
type fileDescriptor int
func sendTo(fd fileDescriptor, p []byte, flags int, to syscall.Sockaddr) (err error) {
return syscall.Sendto(int(fd), p, flags, to)
}
func setSocketTimeout(fd fileDescriptor, timeout time.Duration) (err error) {
timeval := syscall.NsecToTimeval(timeout.Nanoseconds())
return syscall.SetsockoptTimeval(int(fd), syscall.SOL_SOCKET, syscall.SO_RCVTIMEO, &timeval)
}
func recvFrom(fd fileDescriptor, p []byte, flags int) (n int, from syscall.Sockaddr, err error) {
return syscall.Recvfrom(int(fd), p, flags)
}
func setNonBlock(fd int) error {
return syscall.SetNonblock(fd, true)
}
+7
View File
@@ -0,0 +1,7 @@
//go:build !linux && !windows
package tcp
func setMTUDiscovery(fd int) error {
panic("not implemented")
}
+37
View File
@@ -0,0 +1,37 @@
package tcp
import (
"syscall"
"time"
"unsafe"
"golang.org/x/sys/windows"
)
type fileDescriptor syscall.Handle
func sendTo(fd fileDescriptor, p []byte, flags int, to syscall.Sockaddr) (err error) {
return syscall.Sendto(syscall.Handle(fd), p, flags, to)
}
func setSocketTimeout(fd fileDescriptor, timeout time.Duration) (err error) {
timeval := int(timeout.Milliseconds())
return syscall.SetsockoptInt(syscall.Handle(fd), syscall.SOL_SOCKET, windows.SO_RCVTIMEO, timeval)
}
func recvFrom(fd fileDescriptor, p []byte, flags int) (n int, from syscall.Sockaddr, err error) {
return syscall.Recvfrom(syscall.Handle(fd), p, flags)
}
func setMTUDiscovery(fd syscall.Handle) error {
panic("not implemented")
}
func setNonBlock(fd syscall.Handle) error {
// Windows: Use ioctlsocket with FIONBIO
var arg uint32 = 1 // 1 to enable non-blocking mode
var bytesReturned uint32
const FIONBIO = 0x8004667e
return syscall.WSAIoctl(fd, FIONBIO, (*byte)(unsafe.Pointer(&arg)),
uint32(unsafe.Sizeof(arg)), nil, 0, &bytesReturned, nil, 0)
}
+124
View File
@@ -0,0 +1,124 @@
package tcp
import (
"encoding/binary"
"errors"
"fmt"
"github.com/qdm12/gluetun/internal/pmtud/constants"
)
// For SYN, ack is 0.
// For SYN-ACK, ack is the sequence number + 1 of the SYN.
func makeTCPHeader(srcPort, dstPort uint16, seq, ack uint32, flags byte) []byte {
header := make([]byte, constants.BaseTCPHeaderLength)
binary.BigEndian.PutUint16(header[0:], srcPort)
binary.BigEndian.PutUint16(header[2:], dstPort)
binary.BigEndian.PutUint32(header[4:], seq)
binary.BigEndian.PutUint32(header[8:], ack)
//nolint:mnd
header[12] = byte(constants.BaseTCPHeaderLength) << 2 // data offset
header[13] = flags
// windowSize can be left to 5840 even for IPv6, it doesn't matter.
const windowSize = 5840
binary.BigEndian.PutUint16(header[14:], windowSize)
// header[16:17] is the checksum, set later
// header[18:19] is urgent pointer, not needed for our use case
return header
}
//nolint:mnd
func tcpChecksum(ipHeader, tcpHeader, payload []byte) uint16 {
var pseudoHeader []byte
isIPv6 := len(ipHeader) >= 40 && (ipHeader[0]>>4) == 6
if isIPv6 {
pseudoHeader = make([]byte, 40)
copy(pseudoHeader[0:16], ipHeader[8:24]) // Source Address
copy(pseudoHeader[16:32], ipHeader[24:40]) // Destination Address
totalLength := uint32(len(tcpHeader) + len(payload)) //nolint:gosec
binary.BigEndian.PutUint32(pseudoHeader[32:], totalLength)
pseudoHeader[39] = 6 // Next Header (TCP)
} else {
pseudoHeader = make([]byte, 12)
copy(pseudoHeader[0:4], ipHeader[12:16])
copy(pseudoHeader[4:8], ipHeader[16:20])
pseudoHeader[9] = 6
totalLength := uint16(len(tcpHeader) + len(payload)) //nolint:gosec
binary.BigEndian.PutUint16(pseudoHeader[10:], totalLength)
}
sum := uint32(0)
for _, slice := range [][]byte{pseudoHeader, tcpHeader, payload} {
for i := 0; i < len(slice)-1; i += 2 {
sum += uint32(binary.BigEndian.Uint16(slice[i : i+2]))
}
if len(slice)%2 != 0 {
sum += uint32(slice[len(slice)-1]) << 8
}
}
for (sum >> 16) > 0 {
sum = (sum & 0xFFFF) + (sum >> 16)
}
return ^uint16(sum) //nolint:gosec
}
const (
tcpFlagsOffset = 13
rstFlag byte = 0x04
synFlag byte = 0x02
ackFlag byte = 0x10
pshFlag byte = 0x08
)
type packetType uint8
const (
packetTypeSYN packetType = iota + 1
packetTypeSYNACK
packetTypeACK
packetTypeRST
)
func (p packetType) String() string {
switch p {
case packetTypeSYN:
return "SYN"
case packetTypeSYNACK:
return "SYN-ACK"
case packetTypeACK:
return "ACK"
case packetTypeRST:
return "RST"
default:
panic("unknown packet type")
}
}
var (
errTCPHeaderTooShort = errors.New("TCP header is too short")
errTCPPacketTypeUnknown = errors.New("TCP packet type is unknown")
)
// parseTCPHeader parses some elements from the TCP header.
func parseTCPHeader(header []byte) (packetType packetType, seq, ack uint32, err error) {
if len(header) < int(constants.BaseTCPHeaderLength) {
return 0, 0, 0, fmt.Errorf("%w: %d bytes", errTCPHeaderTooShort, len(header))
}
flags := header[tcpFlagsOffset]
switch {
case (flags&synFlag) != 0 && (flags&ackFlag) == 0:
packetType = packetTypeSYN
case (flags&synFlag) != 0 && (flags&ackFlag) != 0:
packetType = packetTypeSYNACK
case (flags & rstFlag) != 0:
packetType = packetTypeRST
case (flags & ackFlag) != 0:
packetType = packetTypeACK
default:
return 0, 0, 0, fmt.Errorf("%w: flags are 0x%02x", errTCPPacketTypeUnknown, flags)
}
seq = binary.BigEndian.Uint32(header[4:8])
ack = binary.BigEndian.Uint32(header[8:12])
return packetType, seq, ack, nil
}
+134
View File
@@ -0,0 +1,134 @@
package tcp
import (
"context"
"encoding/binary"
"errors"
"fmt"
"sync"
"syscall"
"time"
"github.com/qdm12/gluetun/internal/pmtud/constants"
)
type tracker struct {
fd fileDescriptor
ipv4 bool
mutex sync.RWMutex
portsToDispatch map[uint32]dispatch
}
type dispatch struct {
replyCh chan<- []byte
abort <-chan struct{}
}
func newTracker(fd fileDescriptor, ipv4 bool) *tracker {
return &tracker{
fd: fd,
ipv4: ipv4,
portsToDispatch: make(map[uint32]dispatch),
}
}
func (t *tracker) constructKey(localPort, remotePort uint16) uint32 {
buf := make([]byte, 4) //nolint:mnd
binary.BigEndian.PutUint16(buf[0:2], localPort)
binary.BigEndian.PutUint16(buf[2:4], remotePort)
return binary.BigEndian.Uint32(buf)
}
func (t *tracker) register(localPort, remotePort uint16,
ch chan<- []byte, abort <-chan struct{},
) {
key := t.constructKey(localPort, remotePort)
t.mutex.Lock()
defer t.mutex.Unlock()
t.portsToDispatch[key] = dispatch{
replyCh: ch,
abort: abort,
}
}
func (t *tracker) unregister(localPort, remotePort uint16) {
key := t.constructKey(localPort, remotePort)
t.mutex.Lock()
defer t.mutex.Unlock()
delete(t.portsToDispatch, key)
}
// listen listens for incoming TCP packets and dispatches them to the
// correct channel based on the source and destination port.
// If the context has a deadline associated, this one is used on the socket.
// Note it returns a nil error on context cancellation.
func (t *tracker) listen(ctx context.Context) error {
deadline, hasDeadline := ctx.Deadline()
for ctx.Err() == nil {
if hasDeadline {
remaining := time.Until(deadline)
if remaining <= 0 {
return nil
}
err := setSocketTimeout(t.fd, remaining)
if err != nil {
return fmt.Errorf("setting socket receive timeout: %w", err)
}
}
reply := make([]byte, constants.MaxEthernetFrameSize)
n, _, err := recvFrom(t.fd, reply, 0)
if err != nil {
switch {
case errors.Is(err, syscall.EAGAIN),
errors.Is(err, syscall.EWOULDBLOCK):
pollSleep(ctx)
continue
case ctx.Err() != nil:
// context canceled, stop listening so exit cleanly with no error
return nil //nolint:nilerr
default:
return fmt.Errorf("receiving on socket: %w", err)
}
}
reply = reply[:n]
if t.ipv4 {
var ok bool
reply, ok = stripIPv4Header(reply)
if !ok {
continue // not an IPv4 packet
}
}
const minTCPHeaderLength = 20
if len(reply) < minTCPHeaderLength {
continue
}
srcPort := binary.BigEndian.Uint16(reply[0:2])
dstPort := binary.BigEndian.Uint16(reply[2:4])
key := t.constructKey(dstPort, srcPort)
t.mutex.RLock()
dispatch, exists := t.portsToDispatch[key]
t.mutex.RUnlock()
if !exists {
continue
}
select {
case dispatch.replyCh <- reply:
case <-dispatch.abort:
}
}
return nil
}
func pollSleep(ctx context.Context) {
const sleepBetweenPolls = 10 * time.Millisecond
timer := time.NewTimer(sleepBetweenPolls)
select {
case <-ctx.Done():
timer.Stop()
case <-timer.C:
}
}
+36
View File
@@ -0,0 +1,36 @@
package test
import "math"
// MakeMTUsToTest determines a slice of MTU values to test
// between minMTU and maxMTU inclusive. It creates an MTU
// slice of length up to 11 MTUs such that:
// - the first element is the minMTU
// - the last element is the maxMTU
// - elements in-between are separated as close to each other
// The number 11 is chosen to find the final MTU in 3 searches,
// with a total search space of 1728 MTUs which is enough;
// to find it in 2 searches requires 37 parallel queries which
// could be blocked by firewalls.
func MakeMTUsToTest(minMTU, maxMTU uint32) (mtus []uint32) {
const mtusLength = 11 // find the final MTU in 3 searches
diff := maxMTU - minMTU
switch {
case minMTU > maxMTU:
panic("minMTU > maxMTU")
case diff <= mtusLength:
mtus = make([]uint32, 0, diff)
for mtu := minMTU; mtu <= maxMTU; mtu++ {
mtus = append(mtus, mtu)
}
default:
step := float64(diff) / float64(mtusLength-1)
mtus = make([]uint32, 0, mtusLength)
for mtu := float64(minMTU); len(mtus) < mtusLength-1; mtu += step {
mtus = append(mtus, uint32(math.Round(mtu)))
}
mtus = append(mtus, maxMTU) // last element is the maxMTU
}
return mtus
}
@@ -1,4 +1,4 @@
package pmtud
package test
import (
"testing"
@@ -6,7 +6,7 @@ import (
"github.com/stretchr/testify/assert"
)
func Test_makeMTUsToTest(t *testing.T) {
func Test_MakeMTUsToTest(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
@@ -48,7 +48,7 @@ func Test_makeMTUsToTest(t *testing.T) {
t.Run(name, func(t *testing.T) {
t.Parallel()
mtus := makeMTUsToTest(testCase.minMTU, testCase.maxMTU)
mtus := MakeMTUsToTest(testCase.minMTU, testCase.maxMTU)
assert.Equal(t, testCase.mtus, mtus)
})
}
+40
View File
@@ -0,0 +1,40 @@
package pmtud
import (
"net/netip"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/constants/vpn"
pconstants "github.com/qdm12/gluetun/internal/pmtud/constants"
)
// MaxTheoreticalVPNMTU returns the theoretical maximum MTU for a VPN tunnel
// given the VPN type, network protocol, and VPN gateway IP address.
// This is notably useful to skip testing MTU values higher than this value.
// The function panics if the network or VPN type is unknown.
func MaxTheoreticalVPNMTU(vpnType, network string, vpnGateway netip.Addr) uint32 {
const physicalLinkMTU = pconstants.MaxEthernetFrameSize
vpnLinkMTU := physicalLinkMTU
if vpnGateway.Is4() {
vpnLinkMTU -= pconstants.IPv4HeaderLength
} else {
vpnLinkMTU -= pconstants.IPv6HeaderLength
}
switch network {
case constants.TCP:
vpnLinkMTU -= pconstants.BaseTCPHeaderLength
case constants.UDP:
vpnLinkMTU -= pconstants.UDPHeaderLength
default:
panic("unknown network protocol: " + network)
}
switch vpnType {
case vpn.Wireguard:
vpnLinkMTU -= pconstants.WireguardHeaderLength
case vpn.OpenVPN:
vpnLinkMTU -= pconstants.OpenVPNHeaderMaxLength
default:
panic("unknown VPN type: " + vpnType)
}
return vpnLinkMTU
}