mirror of
https://github.com/qdm12/gluetun.git
synced 2026-05-09 20:29:23 +02:00
Path MTU discovery fixes and improvements (#3109)
- Existing option `WIREGUARD_MTU`, if set, disables PMTUD and its value is used instead - New options `PMTUD_ICMP_ADDRESSES=1.1.1.1,8.8.8.8` and `PMTUD_TCP_ADDRESSES=1.1.1.1:443,8.8.8.8:443` - ICMP PMTUD now targets external-by-default IP addresses - New TCP PMTUD (binary search only) as a second MTU confirmation and fallback mechanism - Force set TCP MSS to (MTU minus IP header minus TCP base header minus "magic 20 bytes") 🎆 - Fix #3108
This commit is contained in:
@@ -0,0 +1,7 @@
|
||||
package tcp
|
||||
|
||||
// Logger is the logging dependency of the TCP path MTU discovery code.
// Any type providing these three methods satisfies it.
type Logger interface {
	// Debug receives a plain, already formatted message.
	Debug(message string)
	// Debugf receives a format string and the values to format into it.
	Debugf(format string, values ...any)
	// Warnf receives a format string and the values to format into it.
	Warnf(format string, values ...any)
}
|
||||
@@ -0,0 +1,3 @@
|
||||
package tcp
|
||||
|
||||
//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . Logger
|
||||
@@ -0,0 +1,80 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/qdm12/gluetun/internal/pmtud/tcp (interfaces: Logger)
|
||||
|
||||
// Package tcp is a generated GoMock package.
|
||||
package tcp
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// MockLogger is a mock of Logger interface.
|
||||
type MockLogger struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockLoggerMockRecorder
|
||||
}
|
||||
|
||||
// MockLoggerMockRecorder is the mock recorder for MockLogger.
|
||||
type MockLoggerMockRecorder struct {
|
||||
mock *MockLogger
|
||||
}
|
||||
|
||||
// NewMockLogger creates a new mock instance.
|
||||
func NewMockLogger(ctrl *gomock.Controller) *MockLogger {
|
||||
mock := &MockLogger{ctrl: ctrl}
|
||||
mock.recorder = &MockLoggerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockLogger) EXPECT() *MockLoggerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Debug mocks base method.
|
||||
func (m *MockLogger) Debug(arg0 string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Debug", arg0)
|
||||
}
|
||||
|
||||
// Debug indicates an expected call of Debug.
|
||||
func (mr *MockLoggerMockRecorder) Debug(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockLogger)(nil).Debug), arg0)
|
||||
}
|
||||
|
||||
// Debugf mocks base method.
|
||||
func (m *MockLogger) Debugf(arg0 string, arg1 ...interface{}) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []interface{}{arg0}
|
||||
for _, a := range arg1 {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
m.ctrl.Call(m, "Debugf", varargs...)
|
||||
}
|
||||
|
||||
// Debugf indicates an expected call of Debugf.
|
||||
func (mr *MockLoggerMockRecorder) Debugf(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]interface{}{arg0}, arg1...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debugf", reflect.TypeOf((*MockLogger)(nil).Debugf), varargs...)
|
||||
}
|
||||
|
||||
// Warnf mocks base method.
|
||||
func (m *MockLogger) Warnf(arg0 string, arg1 ...interface{}) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []interface{}{arg0}
|
||||
for _, a := range arg1 {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
m.ctrl.Call(m, "Warnf", varargs...)
|
||||
}
|
||||
|
||||
// Warnf indicates an expected call of Warnf.
|
||||
func (mr *MockLoggerMockRecorder) Warnf(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]interface{}{arg0}, arg1...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Warnf", reflect.TypeOf((*MockLogger)(nil).Warnf), varargs...)
|
||||
}
|
||||
@@ -0,0 +1,89 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/pmtud/test"
|
||||
)
|
||||
|
||||
var ErrMTUNotFound = errors.New("MTU not found")
|
||||
|
||||
// testUnit pairs one candidate MTU with the outcome of probing it.
type testUnit struct {
	// mtu is the candidate MTU that was probed.
	mtu uint32
	// ok is true when the probe for mtu returned no error.
	ok bool
}
|
||||
|
||||
func PathMTUDiscover(ctx context.Context, addrPort netip.AddrPort,
|
||||
minMTU, maxPossibleMTU uint32, logger Logger,
|
||||
) (mtu uint32, err error) {
|
||||
mtusToTest := test.MakeMTUsToTest(minMTU, maxPossibleMTU)
|
||||
if len(mtusToTest) == 1 { // only minMTU because minMTU == maxPossibleMTU
|
||||
return minMTU, nil
|
||||
}
|
||||
logger.Debugf("TCP testing the following MTUs: %v", mtusToTest)
|
||||
|
||||
tests := make([]testUnit, len(mtusToTest))
|
||||
for i := range mtusToTest {
|
||||
tests[i] = testUnit{mtu: mtusToTest[i]}
|
||||
}
|
||||
|
||||
family := syscall.AF_INET
|
||||
if addrPort.Addr().Is6() {
|
||||
family = syscall.AF_INET6
|
||||
}
|
||||
fd, stop, err := startRawSocket(family)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("starting raw socket: %w", err)
|
||||
}
|
||||
defer stop()
|
||||
|
||||
tracker := newTracker(fd, addrPort.Addr().Is4())
|
||||
|
||||
const timeout = time.Second
|
||||
runCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
errCh <- tracker.listen(runCtx)
|
||||
}()
|
||||
|
||||
doneCh := make(chan struct{})
|
||||
for i := range tests {
|
||||
go func(i int) {
|
||||
err := runTest(runCtx, fd, tracker, addrPort, tests[i].mtu)
|
||||
tests[i].ok = err == nil
|
||||
doneCh <- struct{}{}
|
||||
}(i)
|
||||
}
|
||||
|
||||
for range tests {
|
||||
select {
|
||||
case <-doneCh:
|
||||
case err := <-errCh:
|
||||
if err == nil { // timeout
|
||||
break
|
||||
}
|
||||
return 0, fmt.Errorf("listening for TCP replies: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if tests[len(tests)-1].ok {
|
||||
return tests[len(tests)-1].mtu, nil
|
||||
}
|
||||
|
||||
for i := len(tests) - 2; i >= 0; i-- { //nolint:mnd
|
||||
if tests[i].ok {
|
||||
stop()
|
||||
cancel()
|
||||
return PathMTUDiscover(ctx, addrPort,
|
||||
tests[i].mtu, tests[i+1].mtu-1, logger)
|
||||
}
|
||||
}
|
||||
|
||||
return 0, fmt.Errorf("%w: your connection might not be working at all", ErrMTUNotFound)
|
||||
}
|
||||
@@ -0,0 +1,89 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"math/rand/v2"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||
"github.com/qdm12/gluetun/internal/pmtud/ip"
|
||||
)
|
||||
|
||||
// createSYNPacket creates a TCP SYN packet for initiating a handshake.
|
||||
// SYN packets have normally no data payload, so you SHOULD set mtu to 0.
|
||||
// However, in some cases where the server closes the connection with RST immediately,
|
||||
// it can be useful to add some data payload to a SYN packet and check if the server still
|
||||
// replies. Only set mtu to a non zero value if you know what you are doing.
|
||||
func createSYNPacket(src, dst netip.AddrPort, mtu uint32) (packet []byte, seq uint32) {
|
||||
seq = rand.Uint32() //nolint:gosec
|
||||
const ack = 0 // SYN has no ACK number
|
||||
payloadLength := constants.BaseTCPHeaderLength // no data payload
|
||||
if mtu > 0 {
|
||||
payloadLength = getPayloadLength(mtu, dst)
|
||||
}
|
||||
return createPacket(src, dst, seq, ack, payloadLength, synFlag), seq
|
||||
}
|
||||
|
||||
// createACKPacket creates a TCP ACK packet.
|
||||
// If the mtu is set to 0, no payload is sent.
|
||||
// Otherwise, the payload is calculated to test the MTU given.
|
||||
func createACKPacket(src, dst netip.AddrPort, seq, ack uint32, mtu uint32) []byte {
|
||||
payloadLength := constants.BaseTCPHeaderLength // no data payload
|
||||
if mtu > 0 {
|
||||
payloadLength = getPayloadLength(mtu, dst)
|
||||
}
|
||||
const flags = ackFlag | pshFlag
|
||||
return createPacket(src, dst, seq, ack, payloadLength, flags)
|
||||
}
|
||||
|
||||
func createRSTPacket(src, dst netip.AddrPort, seq, ack uint32) []byte {
|
||||
const payloadLength = constants.BaseTCPHeaderLength // no data payload
|
||||
return createPacket(src, dst, seq, ack, payloadLength, rstFlag)
|
||||
}
|
||||
|
||||
func getPayloadLength(mtu uint32, dst netip.AddrPort) uint32 {
|
||||
var ipHeaderLength uint32
|
||||
if dst.Addr().Is4() {
|
||||
ipHeaderLength = constants.IPv4HeaderLength
|
||||
} else {
|
||||
ipHeaderLength = constants.IPv6HeaderLength
|
||||
}
|
||||
if mtu < ipHeaderLength+constants.BaseTCPHeaderLength {
|
||||
panic("MTU too small to hold IP and TCP headers")
|
||||
}
|
||||
return mtu - ipHeaderLength
|
||||
}
|
||||
|
||||
func createPacket(src, dst netip.AddrPort,
|
||||
seq, ack, payloadLength uint32, flags byte,
|
||||
) []byte {
|
||||
if payloadLength < constants.BaseTCPHeaderLength {
|
||||
panic("payload length is too small to hold TCP header")
|
||||
}
|
||||
|
||||
var ipHeader []byte
|
||||
if dst.Addr().Is4() {
|
||||
ipHeader = ip.HeaderV4(src.Addr(), dst.Addr(), payloadLength)
|
||||
} else {
|
||||
ipHeader = ip.HeaderV6(src.Addr(), dst.Addr(),
|
||||
uint16(payloadLength), byte(syscall.IPPROTO_TCP)) //nolint:gosec
|
||||
}
|
||||
|
||||
tcpHeader := makeTCPHeader(src.Port(), dst.Port(), seq, ack, flags)
|
||||
|
||||
// data is just zeroes
|
||||
dataLength := int(payloadLength) - int(constants.BaseTCPHeaderLength)
|
||||
var data []byte
|
||||
if dataLength > 0 {
|
||||
data = make([]byte, dataLength)
|
||||
}
|
||||
checksum := tcpChecksum(ipHeader, tcpHeader, data)
|
||||
tcpHeader[16] = byte(checksum >> 8) //nolint:mnd
|
||||
tcpHeader[17] = byte(checksum & 0xff) //nolint:mnd
|
||||
|
||||
packet := make([]byte, len(ipHeader)+int(constants.BaseTCPHeaderLength)+dataLength)
|
||||
copy(packet, ipHeader)
|
||||
copy(packet[len(ipHeader):], tcpHeader)
|
||||
copy(packet[len(ipHeader)+int(constants.BaseTCPHeaderLength):], data)
|
||||
return packet
|
||||
}
|
||||
@@ -0,0 +1,196 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||
"github.com/qdm12/gluetun/internal/pmtud/ip"
|
||||
)
|
||||
|
||||
func startRawSocket(family int) (fd fileDescriptor, stop func(), err error) {
|
||||
fdPlatform, err := syscall.Socket(family, syscall.SOCK_RAW, syscall.IPPROTO_TCP)
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("creating raw socket: %w", err)
|
||||
}
|
||||
|
||||
if family == syscall.AF_INET {
|
||||
err = ip.SetIPv4HeaderIncluded(fdPlatform)
|
||||
} else {
|
||||
err = ip.SetIPv6HeaderIncluded(fdPlatform)
|
||||
}
|
||||
if err != nil {
|
||||
_ = syscall.Close(fdPlatform)
|
||||
return 0, nil, fmt.Errorf("setting header option on raw socket: %w", err)
|
||||
}
|
||||
|
||||
// Allow sending packets larger than cached PMTU (for PMTUD probing)
|
||||
err = setMTUDiscovery(fdPlatform)
|
||||
if err != nil {
|
||||
_ = syscall.Close(fdPlatform)
|
||||
return 0, nil, fmt.Errorf("setting IP_MTU_DISCOVER: %w", err)
|
||||
}
|
||||
|
||||
// use polling because some Linux systems do not cancel
|
||||
// blocking syscalls such as recvfrom when the socket is closed,
|
||||
// which would cause things to hang indefinitely.
|
||||
err = setNonBlock(fdPlatform)
|
||||
if err != nil {
|
||||
_ = syscall.Close(fdPlatform)
|
||||
return 0, nil, fmt.Errorf("setting non-blocking mode: %w", err)
|
||||
}
|
||||
|
||||
stop = func() {
|
||||
_ = syscall.Close(fdPlatform)
|
||||
}
|
||||
return fileDescriptor(fdPlatform), stop, nil
|
||||
}
|
||||
|
||||
var (
|
||||
errTCPPacketNotSynAck = errors.New("TCP packet is not a SYN-ACK")
|
||||
errTCPSynAckAckMismatch = errors.New("TCP SYN-ACK ACK number does not match expected value")
|
||||
errFinalPacketTypeUnexpected = errors.New("final TCP packet type is unexpected")
|
||||
)
|
||||
|
||||
// Craft and send a raw TCP packet to test the MTU.
|
||||
// It expects either an RST reply (if no server is listening)
|
||||
// or a SYN-ACK/ACK reply (if a server is listening).
|
||||
func runTest(ctx context.Context, fd fileDescriptor,
|
||||
tracker *tracker, dst netip.AddrPort, mtu uint32,
|
||||
) error {
|
||||
const proto = syscall.IPPROTO_TCP
|
||||
src, cleanup, err := ip.SrcAddr(dst, proto)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting source address: %w", err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
ch := make(chan []byte)
|
||||
abort := make(chan struct{})
|
||||
defer close(abort)
|
||||
tracker.register(src.Port(), dst.Port(), ch, abort)
|
||||
defer tracker.unregister(src.Port(), dst.Port())
|
||||
|
||||
dstSockAddr := makeSockAddr(dst)
|
||||
|
||||
synPacket, synSeq := createSYNPacket(src, dst, 0)
|
||||
const sendToFlags = 0
|
||||
err = sendTo(fd, synPacket, sendToFlags, dstSockAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sending SYN packet: %w", err)
|
||||
}
|
||||
|
||||
var reply []byte
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case reply = <-ch:
|
||||
}
|
||||
|
||||
packetType, synAckSeq, synAckAck, err := parseTCPHeader(reply[:constants.BaseTCPHeaderLength])
|
||||
switch {
|
||||
case err != nil:
|
||||
return fmt.Errorf("parsing first reply TCP header: %w", err)
|
||||
case packetType == packetTypeRST:
|
||||
// server actively closed the connection, try sending a SYN with data
|
||||
return handleRSTReply(ctx, fd, ch, src, dst, mtu)
|
||||
case packetType != packetTypeSYNACK:
|
||||
return fmt.Errorf("%w: unexpected packet type %s", errTCPPacketNotSynAck, packetType)
|
||||
case synAckAck != synSeq+1:
|
||||
return fmt.Errorf("%w: expected %d, got %d", errTCPSynAckAckMismatch, synSeq+1, synAckAck)
|
||||
}
|
||||
|
||||
// Send a no-data ACK packet to finish the 3-way handshake.
|
||||
const ackMTU = 0 // no data payload initially
|
||||
ackPacket := createACKPacket(src, dst, synAckAck, synAckSeq+1, ackMTU)
|
||||
err = sendTo(fd, ackPacket, sendToFlags, dstSockAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sending ACK-without-data packet: %w", err)
|
||||
}
|
||||
|
||||
// Send a data ACK packet to test the MTU given.
|
||||
ackPacket = createACKPacket(src, dst, synAckAck, synAckSeq+1, mtu)
|
||||
err = sendTo(fd, ackPacket, sendToFlags, dstSockAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sending ACK-with-data packet: %w", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case reply = <-ch:
|
||||
}
|
||||
|
||||
packetType, _, ack, err := parseTCPHeader(reply[:constants.BaseTCPHeaderLength])
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing second reply TCP header: %w", err)
|
||||
}
|
||||
|
||||
switch packetType { //nolint:exhaustive
|
||||
case packetTypeRST:
|
||||
return nil
|
||||
case packetTypeACK:
|
||||
err = sendRST(fd, src, dst, ack)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sending RST packet: %w", err)
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
_ = sendRST(fd, src, dst, ack)
|
||||
return fmt.Errorf("%w: %s", errFinalPacketTypeUnexpected, packetType)
|
||||
}
|
||||
}
|
||||
|
||||
// makeSockAddr converts addr to the syscall socket address type of its
// family: SockaddrInet4 when the address is IPv4, SockaddrInet6 otherwise.
func makeSockAddr(addr netip.AddrPort) syscall.Sockaddr {
	ipAddr, port := addr.Addr(), int(addr.Port())
	if !ipAddr.Is4() {
		return &syscall.SockaddrInet6{Port: port, Addr: ipAddr.As16()}
	}
	return &syscall.SockaddrInet4{Port: port, Addr: ipAddr.As4()}
}
|
||||
|
||||
var errTCPPacketNotRST = errors.New("TCP packet is not an RST")
|
||||
|
||||
func handleRSTReply(ctx context.Context, fd fileDescriptor, ch <-chan []byte,
|
||||
src, dst netip.AddrPort, mtu uint32,
|
||||
) error {
|
||||
packet, _ := createSYNPacket(src, dst, mtu)
|
||||
const sendToFlags = 0
|
||||
err := sendTo(fd, packet, sendToFlags, makeSockAddr(dst))
|
||||
if err != nil {
|
||||
return fmt.Errorf("sending SYN MTU-test packet: %w", err)
|
||||
}
|
||||
|
||||
var reply []byte
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err() // timeout: the MTU test SYN packet was too big
|
||||
case reply = <-ch:
|
||||
}
|
||||
|
||||
packetType, _, _, err := parseTCPHeader(reply[:constants.BaseTCPHeaderLength])
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing reply TCP header: %w", err)
|
||||
} else if packetType != packetTypeRST {
|
||||
return fmt.Errorf("%w: %s", errTCPPacketNotRST, packetType)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func sendRST(fd fileDescriptor, src, dst netip.AddrPort,
|
||||
previousACK uint32,
|
||||
) error {
|
||||
seq := previousACK
|
||||
const ack = 0
|
||||
rstPacket := createRSTPacket(src, dst, seq, ack)
|
||||
const sendToFlags = 0
|
||||
return sendTo(fd, rstPacket, sendToFlags, makeSockAddr(dst))
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
package tcp
|
||||
|
||||
// stripIPv4Header is the pass-through variant: it returns reply untouched and
// always reports success.
// NOTE(review): presumably built for a platform whose raw sockets hand over
// packets without the IPv4 header; the build constraint of this file is not
// visible here, confirm.
func stripIPv4Header(reply []byte) (result []byte, ok bool) {
	result, ok = reply, true
	return result, ok
}
|
||||
@@ -0,0 +1,7 @@
|
||||
package tcp
|
||||
|
||||
import "syscall"
|
||||
|
||||
// setMTUDiscovery sets the IP_MTU_DISCOVER option of the socket fd to
// IP_PMTUDISC_PROBE at the IPPROTO_IP level.
func setMTUDiscovery(fd int) error {
	const (
		level  = syscall.IPPROTO_IP
		option = syscall.IP_MTU_DISCOVER
		mode   = syscall.IP_PMTUDISC_PROBE
	)
	return syscall.SetsockoptInt(fd, level, option, mode)
}
|
||||
@@ -0,0 +1,30 @@
|
||||
//go:build !darwin
|
||||
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||
)
|
||||
|
||||
func stripIPv4Header(reply []byte) (result []byte, ok bool) {
|
||||
if len(reply) < int(constants.IPv4HeaderLength) {
|
||||
return nil, false // not an IPv4 packet
|
||||
}
|
||||
|
||||
version := reply[0] >> 4 //nolint:mnd
|
||||
const ipv4Version = 4
|
||||
if version != ipv4Version {
|
||||
return nil, false
|
||||
}
|
||||
// For IPv4 we need to skip the IP header, which is at least
|
||||
// 20B and can be up to 60B.
|
||||
// The Internet Header Length is the lower 4 bits of the first byte and
|
||||
// represents the number of 32-bit words of the header length.
|
||||
const ihlMask byte = 0x0F
|
||||
const bytesInWord = 4
|
||||
headerLength := int((reply[0] & ihlMask)) * bytesInWord
|
||||
if len(reply) < headerLength {
|
||||
return nil, false // not enough data for full IPv4 header
|
||||
}
|
||||
return reply[headerLength:], true
|
||||
}
|
||||
@@ -0,0 +1,199 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/netlink"
|
||||
"github.com/qdm12/gluetun/internal/routing"
|
||||
"github.com/qdm12/log"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_runTest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
noopLogger := &noopLogger{}
|
||||
netlinker := netlink.New(noopLogger)
|
||||
loopbackMTU, err := findLoopbackMTU(netlinker)
|
||||
require.NoError(t, err, "finding loopback IPv4 MTU")
|
||||
defaultIPv4MTU, err := findDefaultIPv4RouteMTU(netlinker)
|
||||
require.NoError(t, err, "finding default IPv4 route MTU")
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
|
||||
const family = syscall.AF_INET
|
||||
fd, stop, err := startRawSocket(family)
|
||||
require.NoError(t, err)
|
||||
|
||||
const ipv4 = true
|
||||
tracker := newTracker(fd, ipv4)
|
||||
trackerCh := make(chan error)
|
||||
go func() {
|
||||
trackerCh <- tracker.listen(ctx)
|
||||
}()
|
||||
|
||||
t.Cleanup(func() {
|
||||
stop()
|
||||
cancel() // stop listening
|
||||
err = <-trackerCh
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
testCases := map[string]struct {
|
||||
timeout time.Duration
|
||||
dst func(t *testing.T) netip.AddrPort
|
||||
mtu uint32
|
||||
success bool
|
||||
}{
|
||||
"local_not_listening": {
|
||||
timeout: time.Hour,
|
||||
dst: func(t *testing.T) netip.AddrPort {
|
||||
t.Helper()
|
||||
port := reserveClosedPort(t)
|
||||
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), port)
|
||||
},
|
||||
mtu: loopbackMTU,
|
||||
success: true,
|
||||
},
|
||||
"remote_not_listening": {
|
||||
timeout: 50 * time.Millisecond,
|
||||
dst: func(_ *testing.T) netip.AddrPort {
|
||||
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 12345)
|
||||
},
|
||||
mtu: defaultIPv4MTU,
|
||||
},
|
||||
"1.1.1.1:443": {
|
||||
timeout: time.Second,
|
||||
dst: func(_ *testing.T) netip.AddrPort {
|
||||
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443)
|
||||
},
|
||||
mtu: defaultIPv4MTU,
|
||||
success: true,
|
||||
},
|
||||
"1.1.1.1:80": {
|
||||
timeout: time.Second,
|
||||
dst: func(_ *testing.T) netip.AddrPort {
|
||||
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 80)
|
||||
},
|
||||
mtu: defaultIPv4MTU,
|
||||
success: true,
|
||||
},
|
||||
"8.8.8.8:443": {
|
||||
timeout: time.Second,
|
||||
dst: func(_ *testing.T) netip.AddrPort {
|
||||
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 443)
|
||||
},
|
||||
mtu: defaultIPv4MTU,
|
||||
success: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancel := context.WithTimeout(t.Context(), testCase.timeout)
|
||||
defer cancel()
|
||||
dst := testCase.dst(t)
|
||||
err := runTest(ctx, fd, tracker, dst, testCase.mtu)
|
||||
if testCase.success {
|
||||
require.NoError(t, err)
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
var errRouteNotFound = errors.New("route not found")
|
||||
|
||||
func findLoopbackMTU(netlinker *netlink.NetLink) (mtu uint32, err error) {
|
||||
routes, err := netlinker.RouteList(netlink.FamilyV4)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("getting routes list: %w", err)
|
||||
}
|
||||
for _, route := range routes {
|
||||
if route.Dst.IsValid() && route.Dst.Addr().IsLoopback() {
|
||||
link, err := netlinker.LinkByIndex(route.LinkIndex)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("getting link by index: %w", err)
|
||||
}
|
||||
// Quirk: make sure it is maximum 65535, and not i.e. 65536
|
||||
// or the IP header 16 bits will fail to fit that packet length value.
|
||||
const maxMTU = 65535
|
||||
return min(link.MTU, maxMTU), nil
|
||||
}
|
||||
}
|
||||
return 0, fmt.Errorf("%w: no loopback route found", errRouteNotFound)
|
||||
}
|
||||
|
||||
func findDefaultIPv4RouteMTU(netlinker *netlink.NetLink) (mtu uint32, err error) {
|
||||
noopLogger := &noopLogger{}
|
||||
routing := routing.New(netlinker, noopLogger)
|
||||
defaultRoutes, err := routing.DefaultRoutes()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("getting default routes: %w", err)
|
||||
}
|
||||
for _, route := range defaultRoutes {
|
||||
if route.Family != netlink.FamilyV4 {
|
||||
continue
|
||||
}
|
||||
link, err := netlinker.LinkByName(defaultRoutes[0].NetInterface)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("getting link by name: %w", err)
|
||||
}
|
||||
return link.MTU, nil
|
||||
}
|
||||
return 0, fmt.Errorf("%w: no default route found", errRouteNotFound)
|
||||
}
|
||||
|
||||
func reserveClosedPort(t *testing.T) (port uint16) {
|
||||
t.Helper()
|
||||
|
||||
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM, syscall.IPPROTO_TCP)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
err := syscall.Close(fd)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
addr := &syscall.SockaddrInet4{
|
||||
Port: 0,
|
||||
Addr: [4]byte{127, 0, 0, 1},
|
||||
}
|
||||
|
||||
err = syscall.Bind(fd, addr)
|
||||
if err != nil {
|
||||
_ = syscall.Close(fd)
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
sockAddr, err := syscall.Getsockname(fd)
|
||||
if err != nil {
|
||||
_ = syscall.Close(fd)
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
sockAddr4, ok := sockAddr.(*syscall.SockaddrInet4)
|
||||
if !ok {
|
||||
_ = syscall.Close(fd)
|
||||
t.Fatal("not an IPv4 address")
|
||||
}
|
||||
|
||||
return uint16(sockAddr4.Port) //nolint:gosec
|
||||
}
|
||||
|
||||
type noopLogger struct{}
|
||||
|
||||
func (l *noopLogger) Patch(_ ...log.Option) {}
|
||||
func (l *noopLogger) Debug(_ string) {}
|
||||
func (l *noopLogger) Debugf(_ string, _ ...any) {}
|
||||
func (l *noopLogger) Info(_ string) {}
|
||||
func (l *noopLogger) Warn(_ string) {}
|
||||
func (l *noopLogger) Error(_ string) {}
|
||||
@@ -0,0 +1,28 @@
|
||||
//go:build linux || darwin
|
||||
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
// fileDescriptor is the socket handle type used across the package so that
// callers do not depend on the platform's own type. Here it is an int.
type fileDescriptor int

// sendTo sends p to the address to on the socket fd with the given flags.
func sendTo(fd fileDescriptor, p []byte, flags int, to syscall.Sockaddr) (err error) {
	err = syscall.Sendto(int(fd), p, flags, to)
	return err
}
|
||||
|
||||
func setSocketTimeout(fd fileDescriptor, timeout time.Duration) (err error) {
|
||||
timeval := syscall.NsecToTimeval(timeout.Nanoseconds())
|
||||
return syscall.SetsockoptTimeval(int(fd), syscall.SOL_SOCKET, syscall.SO_RCVTIMEO, &timeval)
|
||||
}
|
||||
|
||||
func recvFrom(fd fileDescriptor, p []byte, flags int) (n int, from syscall.Sockaddr, err error) {
|
||||
return syscall.Recvfrom(int(fd), p, flags)
|
||||
}
|
||||
|
||||
// setNonBlock puts the socket fd in non-blocking mode.
func setNonBlock(fd int) error {
	const nonBlocking = true
	return syscall.SetNonblock(fd, nonBlocking)
}
|
||||
@@ -0,0 +1,7 @@
|
||||
//go:build !linux && !windows
|
||||
|
||||
package tcp
|
||||
|
||||
// setMTUDiscovery is the placeholder for platforms without an
// implementation: it always panics and never returns.
func setMTUDiscovery(fd int) error {
	const message = "not implemented"
	panic(message)
}
|
||||
@@ -0,0 +1,37 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
type fileDescriptor syscall.Handle
|
||||
|
||||
func sendTo(fd fileDescriptor, p []byte, flags int, to syscall.Sockaddr) (err error) {
|
||||
return syscall.Sendto(syscall.Handle(fd), p, flags, to)
|
||||
}
|
||||
|
||||
func setSocketTimeout(fd fileDescriptor, timeout time.Duration) (err error) {
|
||||
timeval := int(timeout.Milliseconds())
|
||||
return syscall.SetsockoptInt(syscall.Handle(fd), syscall.SOL_SOCKET, windows.SO_RCVTIMEO, timeval)
|
||||
}
|
||||
|
||||
func recvFrom(fd fileDescriptor, p []byte, flags int) (n int, from syscall.Sockaddr, err error) {
|
||||
return syscall.Recvfrom(syscall.Handle(fd), p, flags)
|
||||
}
|
||||
|
||||
func setMTUDiscovery(fd syscall.Handle) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func setNonBlock(fd syscall.Handle) error {
|
||||
// Windows: Use ioctlsocket with FIONBIO
|
||||
var arg uint32 = 1 // 1 to enable non-blocking mode
|
||||
var bytesReturned uint32
|
||||
const FIONBIO = 0x8004667e
|
||||
return syscall.WSAIoctl(fd, FIONBIO, (*byte)(unsafe.Pointer(&arg)),
|
||||
uint32(unsafe.Sizeof(arg)), nil, 0, &bytesReturned, nil, 0)
|
||||
}
|
||||
@@ -0,0 +1,124 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||
)
|
||||
|
||||
// For SYN, ack is 0.
|
||||
// For SYN-ACK, ack is the sequence number + 1 of the SYN.
|
||||
func makeTCPHeader(srcPort, dstPort uint16, seq, ack uint32, flags byte) []byte {
|
||||
header := make([]byte, constants.BaseTCPHeaderLength)
|
||||
binary.BigEndian.PutUint16(header[0:], srcPort)
|
||||
binary.BigEndian.PutUint16(header[2:], dstPort)
|
||||
binary.BigEndian.PutUint32(header[4:], seq)
|
||||
binary.BigEndian.PutUint32(header[8:], ack)
|
||||
//nolint:mnd
|
||||
header[12] = byte(constants.BaseTCPHeaderLength) << 2 // data offset
|
||||
header[13] = flags
|
||||
// windowSize can be left to 5840 even for IPv6, it doesn't matter.
|
||||
const windowSize = 5840
|
||||
binary.BigEndian.PutUint16(header[14:], windowSize)
|
||||
// header[16:17] is the checksum, set later
|
||||
// header[18:19] is urgent pointer, not needed for our use case
|
||||
return header
|
||||
}
|
||||
|
||||
// tcpChecksum computes the TCP checksum over the pseudo header derived from
// ipHeader, followed by tcpHeader and payload. ipHeader is treated as IPv6
// when it is at least 40 bytes long and its version nibble is 6, and as IPv4
// otherwise. The checksum field inside tcpHeader must still be zero.
//
//nolint:mnd
func tcpChecksum(ipHeader, tcpHeader, payload []byte) uint16 {
	const (
		tcpProtocol      = 6
		ipv6Version      = 6
		ipv6HeaderLength = 40
	)
	segmentLength := len(tcpHeader) + len(payload)

	var pseudoHeader []byte
	if len(ipHeader) >= ipv6HeaderLength && ipHeader[0]>>4 == ipv6Version {
		// source (16B), destination (16B), length (4B), 3 zero bytes, next header
		pseudoHeader = make([]byte, 40)
		copy(pseudoHeader[:32], ipHeader[8:40])
		binary.BigEndian.PutUint32(pseudoHeader[32:], uint32(segmentLength)) //nolint:gosec
		pseudoHeader[39] = tcpProtocol
	} else {
		// source (4B), destination (4B), zero byte, protocol, length (2B)
		pseudoHeader = make([]byte, 12)
		copy(pseudoHeader[:8], ipHeader[12:20])
		pseudoHeader[9] = tcpProtocol
		binary.BigEndian.PutUint16(pseudoHeader[10:], uint16(segmentLength)) //nolint:gosec
	}

	// Sum each part as big endian 16-bit words; a trailing odd byte counts
	// as the high byte of a final word.
	var sum uint32
	for _, part := range [...][]byte{pseudoHeader, tcpHeader, payload} {
		for len(part) >= 2 {
			sum += uint32(binary.BigEndian.Uint16(part))
			part = part[2:]
		}
		if len(part) == 1 {
			sum += uint32(part[0]) << 8
		}
	}

	// Fold the carries back into the lower 16 bits.
	for sum > 0xFFFF {
		sum = (sum & 0xFFFF) + (sum >> 16)
	}
	return ^uint16(sum) //nolint:gosec
}
|
||||
|
||||
const (
|
||||
tcpFlagsOffset = 13
|
||||
rstFlag byte = 0x04
|
||||
synFlag byte = 0x02
|
||||
ackFlag byte = 0x10
|
||||
pshFlag byte = 0x08
|
||||
)
|
||||
|
||||
// packetType classifies a TCP packet by its flags. The zero value is not a
// valid packet type.
type packetType uint8

const (
	packetTypeSYN packetType = iota + 1
	packetTypeSYNACK
	packetTypeACK
	packetTypeRST
)

// String returns the name of the packet type.
// It panics if p is not one of the declared packet types.
func (p packetType) String() string {
	names := [...]string{
		packetTypeSYN:    "SYN",
		packetTypeSYNACK: "SYN-ACK",
		packetTypeACK:    "ACK",
		packetTypeRST:    "RST",
	}
	if p < packetTypeSYN || int(p) >= len(names) {
		panic("unknown packet type")
	}
	return names[p]
}
|
||||
|
||||
var (
|
||||
errTCPHeaderTooShort = errors.New("TCP header is too short")
|
||||
errTCPPacketTypeUnknown = errors.New("TCP packet type is unknown")
|
||||
)
|
||||
|
||||
// parseTCPHeader parses some elements from the TCP header.
|
||||
func parseTCPHeader(header []byte) (packetType packetType, seq, ack uint32, err error) {
|
||||
if len(header) < int(constants.BaseTCPHeaderLength) {
|
||||
return 0, 0, 0, fmt.Errorf("%w: %d bytes", errTCPHeaderTooShort, len(header))
|
||||
}
|
||||
flags := header[tcpFlagsOffset]
|
||||
switch {
|
||||
case (flags&synFlag) != 0 && (flags&ackFlag) == 0:
|
||||
packetType = packetTypeSYN
|
||||
case (flags&synFlag) != 0 && (flags&ackFlag) != 0:
|
||||
packetType = packetTypeSYNACK
|
||||
case (flags & rstFlag) != 0:
|
||||
packetType = packetTypeRST
|
||||
case (flags & ackFlag) != 0:
|
||||
packetType = packetTypeACK
|
||||
default:
|
||||
return 0, 0, 0, fmt.Errorf("%w: flags are 0x%02x", errTCPPacketTypeUnknown, flags)
|
||||
}
|
||||
|
||||
seq = binary.BigEndian.Uint32(header[4:8])
|
||||
ack = binary.BigEndian.Uint32(header[8:12])
|
||||
return packetType, seq, ack, nil
|
||||
}
|
||||
@@ -0,0 +1,134 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||
)
|
||||
|
||||
type tracker struct {
|
||||
fd fileDescriptor
|
||||
ipv4 bool
|
||||
mutex sync.RWMutex
|
||||
portsToDispatch map[uint32]dispatch
|
||||
}
|
||||
|
||||
// dispatch holds the channels of one registered party.
type dispatch struct {
	// replyCh receives the packets addressed to the party.
	replyCh chan<- []byte
	// abort unblocks a pending delivery on replyCh once the party stops
	// reading, for example by being closed.
	abort <-chan struct{}
}
|
||||
|
||||
func newTracker(fd fileDescriptor, ipv4 bool) *tracker {
|
||||
return &tracker{
|
||||
fd: fd,
|
||||
ipv4: ipv4,
|
||||
portsToDispatch: make(map[uint32]dispatch),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *tracker) constructKey(localPort, remotePort uint16) uint32 {
|
||||
buf := make([]byte, 4) //nolint:mnd
|
||||
binary.BigEndian.PutUint16(buf[0:2], localPort)
|
||||
binary.BigEndian.PutUint16(buf[2:4], remotePort)
|
||||
return binary.BigEndian.Uint32(buf)
|
||||
}
|
||||
|
||||
func (t *tracker) register(localPort, remotePort uint16,
|
||||
ch chan<- []byte, abort <-chan struct{},
|
||||
) {
|
||||
key := t.constructKey(localPort, remotePort)
|
||||
t.mutex.Lock()
|
||||
defer t.mutex.Unlock()
|
||||
t.portsToDispatch[key] = dispatch{
|
||||
replyCh: ch,
|
||||
abort: abort,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *tracker) unregister(localPort, remotePort uint16) {
|
||||
key := t.constructKey(localPort, remotePort)
|
||||
t.mutex.Lock()
|
||||
defer t.mutex.Unlock()
|
||||
delete(t.portsToDispatch, key)
|
||||
}
|
||||
|
||||
// listen listens for incoming TCP packets and dispatches them to the
|
||||
// correct channel based on the source and destination port.
|
||||
// If the context has a deadline associated, this one is used on the socket.
|
||||
// Note it returns a nil error on context cancellation.
|
||||
func (t *tracker) listen(ctx context.Context) error {
|
||||
deadline, hasDeadline := ctx.Deadline()
|
||||
for ctx.Err() == nil {
|
||||
if hasDeadline {
|
||||
remaining := time.Until(deadline)
|
||||
if remaining <= 0 {
|
||||
return nil
|
||||
}
|
||||
err := setSocketTimeout(t.fd, remaining)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting socket receive timeout: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
reply := make([]byte, constants.MaxEthernetFrameSize)
|
||||
n, _, err := recvFrom(t.fd, reply, 0)
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, syscall.EAGAIN),
|
||||
errors.Is(err, syscall.EWOULDBLOCK):
|
||||
pollSleep(ctx)
|
||||
continue
|
||||
case ctx.Err() != nil:
|
||||
// context canceled, stop listening so exit cleanly with no error
|
||||
return nil //nolint:nilerr
|
||||
default:
|
||||
return fmt.Errorf("receiving on socket: %w", err)
|
||||
}
|
||||
}
|
||||
reply = reply[:n]
|
||||
|
||||
if t.ipv4 {
|
||||
var ok bool
|
||||
reply, ok = stripIPv4Header(reply)
|
||||
if !ok {
|
||||
continue // not an IPv4 packet
|
||||
}
|
||||
}
|
||||
|
||||
const minTCPHeaderLength = 20
|
||||
if len(reply) < minTCPHeaderLength {
|
||||
continue
|
||||
}
|
||||
|
||||
srcPort := binary.BigEndian.Uint16(reply[0:2])
|
||||
dstPort := binary.BigEndian.Uint16(reply[2:4])
|
||||
key := t.constructKey(dstPort, srcPort)
|
||||
t.mutex.RLock()
|
||||
dispatch, exists := t.portsToDispatch[key]
|
||||
t.mutex.RUnlock()
|
||||
if !exists {
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case dispatch.replyCh <- reply:
|
||||
case <-dispatch.abort:
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// pollSleep waits for the delay between two socket polls, 10 milliseconds,
// and returns earlier if ctx is done before that.
func pollSleep(ctx context.Context) {
	const sleepBetweenPolls = 10 * time.Millisecond

	timer := time.NewTimer(sleepBetweenPolls)
	defer timer.Stop() // harmless when the timer already fired

	select {
	case <-timer.C:
	case <-ctx.Done():
	}
}
|
||||
Reference in New Issue
Block a user