Path MTU discovery fixes and improvements (#3109)

- Existing option `WIREGUARD_MTU`, if set, disables PMTUD and its value is used as the MTU
- New options `PMTUD_ICMP_ADDRESSES=1.1.1.1,8.8.8.8` and `PMTUD_TCP_ADDRESSES=1.1.1.1:443,8.8.8.8:443`
- ICMP PMTUD now targets external-by-default IP addresses
- New TCP PMTUD (binary search only) as a second MTU confirmation and fallback mechanism.
- Force set TCP MSS to MTU - IP header - TCP base header - "magic 20 bytes" 🎆
- Fix #3108
This commit is contained in:
Quentin McGaw
2026-02-15 01:40:34 +01:00
committed by GitHub
parent 8f1fda7646
commit be92aa2ac4
59 changed files with 2050 additions and 376 deletions
+7
View File
@@ -0,0 +1,7 @@
package tcp
// Logger is the logging dependency accepted by the TCP path MTU discovery
// code of this package (PathMTUDiscover takes one as its last argument).
type Logger interface {
	// Debug logs the message given as is.
	Debug(msg string)
	// Debugf logs a message built from msg and args
	// (presumably fmt-style formatting - confirm against the implementation).
	Debugf(msg string, args ...any)
	// Warnf logs a warning built from msg and args
	// (presumably fmt-style formatting - confirm against the implementation).
	Warnf(msg string, args ...any)
}
@@ -0,0 +1,3 @@
package tcp
//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . Logger
+80
View File
@@ -0,0 +1,80 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/qdm12/gluetun/internal/pmtud/tcp (interfaces: Logger)
// Package tcp is a generated GoMock package.
package tcp
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
)
// MockLogger is a mock of Logger interface.
// This file is generated by MockGen: regenerate it instead of editing it by hand.
type MockLogger struct {
	ctrl     *gomock.Controller
	recorder *MockLoggerMockRecorder
}

// MockLoggerMockRecorder is the mock recorder for MockLogger.
type MockLoggerMockRecorder struct {
	mock *MockLogger
}

// NewMockLogger creates a new mock instance.
func NewMockLogger(ctrl *gomock.Controller) *MockLogger {
	mock := &MockLogger{ctrl: ctrl}
	mock.recorder = &MockLoggerMockRecorder{mock}
	return mock
}

// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockLogger) EXPECT() *MockLoggerMockRecorder {
	return m.recorder
}

// Debug mocks base method.
func (m *MockLogger) Debug(arg0 string) {
	m.ctrl.T.Helper()
	m.ctrl.Call(m, "Debug", arg0)
}

// Debug indicates an expected call of Debug.
func (mr *MockLoggerMockRecorder) Debug(arg0 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockLogger)(nil).Debug), arg0)
}

// Debugf mocks base method.
func (m *MockLogger) Debugf(arg0 string, arg1 ...interface{}) {
	m.ctrl.T.Helper()
	// Flatten the format string and its variadic arguments into one slice.
	varargs := []interface{}{arg0}
	for _, a := range arg1 {
		varargs = append(varargs, a)
	}
	m.ctrl.Call(m, "Debugf", varargs...)
}

// Debugf indicates an expected call of Debugf.
func (mr *MockLoggerMockRecorder) Debugf(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	varargs := append([]interface{}{arg0}, arg1...)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debugf", reflect.TypeOf((*MockLogger)(nil).Debugf), varargs...)
}

// Warnf mocks base method.
func (m *MockLogger) Warnf(arg0 string, arg1 ...interface{}) {
	m.ctrl.T.Helper()
	// Flatten the format string and its variadic arguments into one slice.
	varargs := []interface{}{arg0}
	for _, a := range arg1 {
		varargs = append(varargs, a)
	}
	m.ctrl.Call(m, "Warnf", varargs...)
}

// Warnf indicates an expected call of Warnf.
func (mr *MockLoggerMockRecorder) Warnf(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	varargs := append([]interface{}{arg0}, arg1...)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Warnf", reflect.TypeOf((*MockLogger)(nil).Warnf), varargs...)
}
+89
View File
@@ -0,0 +1,89 @@
package tcp
import (
"context"
"errors"
"fmt"
"net/netip"
"syscall"
"time"
"github.com/qdm12/gluetun/internal/pmtud/test"
)
// ErrMTUNotFound is returned (wrapped) by PathMTUDiscover when none of the
// MTUs it tested succeeded.
var ErrMTUNotFound = errors.New("MTU not found")
// testUnit holds one MTU candidate and the outcome of its test.
type testUnit struct {
	// mtu is the MTU value being tested.
	mtu uint32
	// ok is set to true when the test for mtu finished without error.
	ok bool
}
// PathMTUDiscover searches for the largest MTU between minMTU and
// maxPossibleMTU for which a raw TCP exchange with addrPort succeeds.
//
// It tests a set of candidate MTUs concurrently over a single raw socket,
// each round being bounded by a one second timeout. If the largest candidate
// succeeds it is returned. Otherwise the search recurses on the range between
// the largest succeeding candidate and the candidate right above it.
// It returns an error wrapping ErrMTUNotFound if no candidate succeeds.
func PathMTUDiscover(ctx context.Context, addrPort netip.AddrPort,
	minMTU, maxPossibleMTU uint32, logger Logger,
) (mtu uint32, err error) {
	mtusToTest := test.MakeMTUsToTest(minMTU, maxPossibleMTU)
	if len(mtusToTest) == 1 { // only minMTU because minMTU == maxPossibleMTU
		return minMTU, nil
	}
	logger.Debugf("TCP testing the following MTUs: %v", mtusToTest)

	tests := make([]testUnit, len(mtusToTest))
	for i := range mtusToTest {
		tests[i] = testUnit{mtu: mtusToTest[i]}
	}

	family := syscall.AF_INET
	if addrPort.Addr().Is6() {
		family = syscall.AF_INET6
	}
	fd, stop, err := startRawSocket(family)
	if err != nil {
		return 0, fmt.Errorf("starting raw socket: %w", err)
	}
	// The socket is closed explicitly before recursing and again through the
	// defer: guard against closing the same descriptor twice, since its number
	// may have been reused by another socket in the meantime.
	// closeSocket is only ever called from this goroutine.
	socketClosed := false
	closeSocket := func() {
		if socketClosed {
			return
		}
		socketClosed = true
		stop()
	}
	defer closeSocket()

	tracker := newTracker(fd, addrPort.Addr().Is4())

	const timeout = time.Second
	runCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel() // runs before closeSocket, so the listener is told to stop first

	// Buffered so the listener goroutine can always deliver its result and
	// exit, even when this function returns without reading it.
	listenErrCh := make(chan error, 1)
	go func() {
		listenErrCh <- tracker.listen(runCtx)
	}()

	// Buffered so test goroutines never block on reporting completion,
	// even when this function returns early on a listener error.
	doneCh := make(chan struct{}, len(tests))
	for i := range tests {
		go func(i int) {
			err := runTest(runCtx, fd, tracker, addrPort, tests[i].mtu)
			tests[i].ok = err == nil
			doneCh <- struct{}{}
		}(i)
	}

	// Wait for every test to finish before reading the results.
	// The listener stopping must not count as a finished test.
	listenerRunning := true
	for pending := len(tests); pending > 0; {
		if !listenerRunning {
			// Remaining tests end through runCtx at the latest.
			<-doneCh
			pending--
			continue
		}
		select {
		case <-doneCh:
			pending--
		case err := <-listenErrCh:
			listenerRunning = false
			if err != nil {
				return 0, fmt.Errorf("listening for TCP replies: %w", err)
			}
			// nil error: the listener stopped because runCtx timed out.
		}
	}

	if tests[len(tests)-1].ok {
		return tests[len(tests)-1].mtu, nil
	}

	for i := len(tests) - 2; i >= 0; i-- { //nolint:mnd
		if tests[i].ok {
			// Release this round's resources before recursing: stop the
			// listener, wait for it to be done with the socket, then close it.
			cancel()
			if listenerRunning {
				<-listenErrCh
			}
			closeSocket()
			return PathMTUDiscover(ctx, addrPort,
				tests[i].mtu, tests[i+1].mtu-1, logger)
		}
	}

	return 0, fmt.Errorf("%w: your connection might not be working at all", ErrMTUNotFound)
}
+89
View File
@@ -0,0 +1,89 @@
package tcp
import (
"math/rand/v2"
"net/netip"
"syscall"
"github.com/qdm12/gluetun/internal/pmtud/constants"
"github.com/qdm12/gluetun/internal/pmtud/ip"
)
// createSYNPacket builds a raw TCP SYN packet from src to dst using a random
// sequence number, which is returned alongside the packet.
// A SYN normally carries no data, so mtu SHOULD be 0. A non zero mtu pads the
// packet with data so that it matches that MTU, which is useful when the
// server answers with an immediate RST and one wants to check it still
// replies to a larger SYN. Only use a non zero mtu if you know what you are doing.
func createSYNPacket(src, dst netip.AddrPort, mtu uint32) (packet []byte, seq uint32) {
	const noACK = 0 // a SYN does not acknowledge anything
	seq = rand.Uint32() //nolint:gosec

	var length uint32
	if mtu == 0 {
		length = constants.BaseTCPHeaderLength // TCP header only
	} else {
		length = getPayloadLength(mtu, dst)
	}

	packet = createPacket(src, dst, seq, noACK, length, synFlag)
	return packet, seq
}
// createACKPacket builds a raw TCP packet with the ACK and PSH flags set.
// With mtu set to 0 the packet carries no data; otherwise it is padded
// with data so that the packet size matches the MTU to test.
func createACKPacket(src, dst netip.AddrPort, seq, ack uint32, mtu uint32) []byte {
	var length uint32
	if mtu == 0 {
		length = constants.BaseTCPHeaderLength // TCP header only
	} else {
		length = getPayloadLength(mtu, dst)
	}
	return createPacket(src, dst, seq, ack, length, ackFlag|pshFlag)
}
func createRSTPacket(src, dst netip.AddrPort, seq, ack uint32) []byte {
const payloadLength = constants.BaseTCPHeaderLength // no data payload
return createPacket(src, dst, seq, ack, payloadLength, rstFlag)
}
// getPayloadLength returns the IP payload length (TCP header plus data)
// of a packet of total size mtu sent to dst, that is mtu minus the IPv4
// or IPv6 header length depending on dst.
// It panics if mtu cannot hold the IP header and the base TCP header.
func getPayloadLength(mtu uint32, dst netip.AddrPort) uint32 {
	var headerLength uint32 = constants.IPv6HeaderLength
	if dst.Addr().Is4() {
		headerLength = constants.IPv4HeaderLength
	}

	smallest := headerLength + constants.BaseTCPHeaderLength
	if mtu >= smallest {
		return mtu - headerLength
	}
	panic("MTU too small to hold IP and TCP headers")
}
// createPacket assembles a raw IP + TCP packet from src to dst.
// payloadLength is the IP payload length: the base TCP header plus
// the zero-filled data following it. It panics if payloadLength is
// smaller than the base TCP header length.
func createPacket(src, dst netip.AddrPort,
	seq, ack, payloadLength uint32, flags byte,
) []byte {
	if payloadLength < constants.BaseTCPHeaderLength {
		panic("payload length is too small to hold TCP header")
	}

	var ipHeader []byte
	switch {
	case dst.Addr().Is4():
		ipHeader = ip.HeaderV4(src.Addr(), dst.Addr(), payloadLength)
	default:
		ipHeader = ip.HeaderV6(src.Addr(), dst.Addr(),
			uint16(payloadLength), byte(syscall.IPPROTO_TCP)) //nolint:gosec
	}

	tcpHeader := makeTCPHeader(src.Port(), dst.Port(), seq, ack, flags)

	// The data following the TCP header is all zeroes.
	var data []byte
	dataLength := int(payloadLength) - int(constants.BaseTCPHeaderLength)
	if dataLength > 0 {
		data = make([]byte, dataLength)
	}

	// The checksum covers the pseudo header, the TCP header and the data;
	// it is written in bytes 16 and 17 of the TCP header.
	checksum := tcpChecksum(ipHeader, tcpHeader, data)
	tcpHeader[16] = byte(checksum >> 8)   //nolint:mnd
	tcpHeader[17] = byte(checksum & 0xff) //nolint:mnd

	tcpOffset := len(ipHeader)
	dataOffset := tcpOffset + int(constants.BaseTCPHeaderLength)
	packet := make([]byte, dataOffset+dataLength)
	copy(packet[:tcpOffset], ipHeader)
	copy(packet[tcpOffset:], tcpHeader)
	copy(packet[dataOffset:], data)
	return packet
}
+196
View File
@@ -0,0 +1,196 @@
package tcp
import (
"context"
"errors"
"fmt"
"net/netip"
"syscall"
"github.com/qdm12/gluetun/internal/pmtud/constants"
"github.com/qdm12/gluetun/internal/pmtud/ip"
)
// startRawSocket opens a raw TCP socket for the address family given and
// configures it: IP header included by the caller, MTU discovery option set
// and non-blocking mode. The stop function returned closes the socket.
// On any configuration failure the socket is closed and an error is returned.
//
// NOTE(review): setMTUDiscovery is called whatever the family is; confirm the
// option it sets is accepted on AF_INET6 sockets.
func startRawSocket(family int) (fd fileDescriptor, stop func(), err error) {
	fdPlatform, err := syscall.Socket(family, syscall.SOCK_RAW, syscall.IPPROTO_TCP)
	if err != nil {
		return 0, nil, fmt.Errorf("creating raw socket: %w", err)
	}

	// closeAndFail releases the socket and wraps cause with the step that failed.
	closeAndFail := func(step string, cause error) (fileDescriptor, func(), error) {
		_ = syscall.Close(fdPlatform)
		return 0, nil, fmt.Errorf("%s: %w", step, cause)
	}

	switch family {
	case syscall.AF_INET:
		err = ip.SetIPv4HeaderIncluded(fdPlatform)
	default:
		err = ip.SetIPv6HeaderIncluded(fdPlatform)
	}
	if err != nil {
		return closeAndFail("setting header option on raw socket", err)
	}

	// Allow sending packets larger than cached PMTU (for PMTUD probing)
	if err = setMTUDiscovery(fdPlatform); err != nil {
		return closeAndFail("setting IP_MTU_DISCOVER", err)
	}

	// Poll instead of blocking: some Linux systems do not interrupt
	// blocking syscalls such as recvfrom when the socket is closed,
	// which would make things hang indefinitely.
	if err = setNonBlock(fdPlatform); err != nil {
		return closeAndFail("setting non-blocking mode", err)
	}

	stop = func() {
		_ = syscall.Close(fdPlatform)
	}
	return fileDescriptor(fdPlatform), stop, nil
}
// Sentinel errors wrapped by runTest.
var (
	// errTCPPacketNotSynAck: the first reply is neither an RST nor a SYN-ACK.
	errTCPPacketNotSynAck = errors.New("TCP packet is not a SYN-ACK")
	// errTCPSynAckAckMismatch: the SYN-ACK does not acknowledge the SYN sequence number + 1.
	errTCPSynAckAckMismatch = errors.New("TCP SYN-ACK ACK number does not match expected value")
	// errFinalPacketTypeUnexpected: the reply to the data packet is neither an RST nor an ACK.
	errFinalPacketTypeUnexpected = errors.New("final TCP packet type is unexpected")
)
// Craft and send a raw TCP packet to test the MTU.
// It expects either an RST reply (if no server is listening)
// or a SYN-ACK/ACK reply (if a server is listening).
//
// Sequence of operations:
//  1. register a reply channel on the tracker for the source/destination ports,
//  2. send a SYN without data and wait for the first reply,
//  3. on an RST reply, delegate to handleRSTReply which retries with an
//     MTU-sized SYN; on a SYN-ACK, complete the handshake with an empty ACK,
//  4. send an ACK padded to the MTU to test and wait for the second reply,
//  5. an RST or an ACK reply means the MTU works; an RST is sent back to the
//     server unless the server itself replied with an RST.
//
// It returns ctx.Err() if ctx is done while waiting for a reply.
func runTest(ctx context.Context, fd fileDescriptor,
	tracker *tracker, dst netip.AddrPort, mtu uint32,
) error {
	const proto = syscall.IPPROTO_TCP
	src, cleanup, err := ip.SrcAddr(dst, proto)
	if err != nil {
		return fmt.Errorf("getting source address: %w", err)
	}
	defer cleanup()

	// ch receives the replies dispatched by the tracker; closing abort on
	// return unblocks the tracker if it is trying to send to ch.
	ch := make(chan []byte)
	abort := make(chan struct{})
	defer close(abort)
	tracker.register(src.Port(), dst.Port(), ch, abort)
	defer tracker.unregister(src.Port(), dst.Port())

	dstSockAddr := makeSockAddr(dst)

	synPacket, synSeq := createSYNPacket(src, dst, 0)
	const sendToFlags = 0
	err = sendTo(fd, synPacket, sendToFlags, dstSockAddr)
	if err != nil {
		return fmt.Errorf("sending SYN packet: %w", err)
	}

	var reply []byte
	select {
	case <-ctx.Done():
		return ctx.Err()
	case reply = <-ch:
	}

	// The tracker only dispatches replies of at least 20 bytes.
	packetType, synAckSeq, synAckAck, err := parseTCPHeader(reply[:constants.BaseTCPHeaderLength])
	switch {
	case err != nil:
		return fmt.Errorf("parsing first reply TCP header: %w", err)
	case packetType == packetTypeRST:
		// server actively closed the connection, try sending a SYN with data
		return handleRSTReply(ctx, fd, ch, src, dst, mtu)
	case packetType != packetTypeSYNACK:
		return fmt.Errorf("%w: unexpected packet type %s", errTCPPacketNotSynAck, packetType)
	case synAckAck != synSeq+1:
		return fmt.Errorf("%w: expected %d, got %d", errTCPSynAckAckMismatch, synSeq+1, synAckAck)
	}

	// Send a no-data ACK packet to finish the 3-way handshake.
	const ackMTU = 0 // no data payload initially
	ackPacket := createACKPacket(src, dst, synAckAck, synAckSeq+1, ackMTU)
	err = sendTo(fd, ackPacket, sendToFlags, dstSockAddr)
	if err != nil {
		return fmt.Errorf("sending ACK-without-data packet: %w", err)
	}

	// Send a data ACK packet to test the MTU given.
	ackPacket = createACKPacket(src, dst, synAckAck, synAckSeq+1, mtu)
	err = sendTo(fd, ackPacket, sendToFlags, dstSockAddr)
	if err != nil {
		return fmt.Errorf("sending ACK-with-data packet: %w", err)
	}

	select {
	case <-ctx.Done():
		return ctx.Err()
	case reply = <-ch:
	}

	packetType, _, ack, err := parseTCPHeader(reply[:constants.BaseTCPHeaderLength])
	if err != nil {
		return fmt.Errorf("parsing second reply TCP header: %w", err)
	}

	switch packetType { //nolint:exhaustive
	case packetTypeRST:
		// The server replied to the data packet, so it went through.
		return nil
	case packetTypeACK:
		// The data packet was acknowledged; reset the connection we opened.
		err = sendRST(fd, src, dst, ack)
		if err != nil {
			return fmt.Errorf("sending RST packet: %w", err)
		}
		return nil
	default:
		// Best effort reset, the unexpected packet type is the error reported.
		_ = sendRST(fd, src, dst, ack)
		return fmt.Errorf("%w: %s", errFinalPacketTypeUnexpected, packetType)
	}
}
// makeSockAddr converts addr to the syscall socket address of its family:
// a SockaddrInet4 for an IPv4 address and a SockaddrInet6 otherwise.
func makeSockAddr(addr netip.AddrPort) syscall.Sockaddr {
	port := int(addr.Port())
	ipAddr := addr.Addr()
	if !ipAddr.Is4() {
		return &syscall.SockaddrInet6{Port: port, Addr: ipAddr.As16()}
	}
	return &syscall.SockaddrInet4{Port: port, Addr: ipAddr.As4()}
}
// errTCPPacketNotRST is wrapped by handleRSTReply when the reply to
// the MTU-test SYN packet is not an RST.
var errTCPPacketNotRST = errors.New("TCP packet is not an RST")
// handleRSTReply tests the MTU against a server which answers SYN packets
// with an RST: it sends a SYN packet padded to mtu and expects another RST
// back on ch. It returns ctx.Err() if ctx is done before a reply arrives,
// and an error wrapping errTCPPacketNotRST if the reply is not an RST.
func handleRSTReply(ctx context.Context, fd fileDescriptor, ch <-chan []byte,
	src, dst netip.AddrPort, mtu uint32,
) error {
	const sendToFlags = 0
	probe, _ := createSYNPacket(src, dst, mtu)
	if err := sendTo(fd, probe, sendToFlags, makeSockAddr(dst)); err != nil {
		return fmt.Errorf("sending SYN MTU-test packet: %w", err)
	}

	var reply []byte
	select {
	case reply = <-ch:
	case <-ctx.Done():
		// timeout: the MTU test SYN packet was too big
		return ctx.Err()
	}

	kind, _, _, err := parseTCPHeader(reply[:constants.BaseTCPHeaderLength])
	switch {
	case err != nil:
		return fmt.Errorf("parsing reply TCP header: %w", err)
	case kind != packetTypeRST:
		return fmt.Errorf("%w: %s", errTCPPacketNotRST, kind)
	}
	return nil
}
// sendRST sends an RST packet from src to dst, using the acknowledgment
// number previously received from dst as the sequence number.
func sendRST(fd fileDescriptor, src, dst netip.AddrPort,
	previousACK uint32,
) error {
	const (
		noACK       = 0
		sendToFlags = 0
	)
	packet := createRSTPacket(src, dst, previousACK, noACK)
	return sendTo(fd, packet, sendToFlags, makeSockAddr(dst))
}
+5
View File
@@ -0,0 +1,5 @@
package tcp
// stripIPv4Header returns reply unchanged and always reports ok.
// NOTE(review): presumably the platform-specific variant for systems where
// the reply needs no header stripping - confirm against the file's build target.
func stripIPv4Header(reply []byte) (result []byte, ok bool) {
	return reply, true
}
+7
View File
@@ -0,0 +1,7 @@
package tcp
import "syscall"
// setMTUDiscovery sets the IP_MTU_DISCOVER socket option of fd to
// IP_PMTUDISC_PROBE at the IPPROTO_IP level.
// NOTE(review): the option is set at the IPv4 level whatever the socket
// family is; confirm this works for AF_INET6 raw sockets, which may need
// the IPv6 equivalent of this option.
func setMTUDiscovery(fd int) error {
	return syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_MTU_DISCOVER, syscall.IP_PMTUDISC_PROBE)
}
+30
View File
@@ -0,0 +1,30 @@
//go:build !darwin
package tcp
import (
"github.com/qdm12/gluetun/internal/pmtud/constants"
)
// stripIPv4Header removes the IPv4 header from reply and returns what
// follows it. ok is false if reply is too short, is not an IPv4 packet or
// declares an invalid header length.
func stripIPv4Header(reply []byte) (result []byte, ok bool) {
	if len(reply) < int(constants.IPv4HeaderLength) {
		return nil, false // not an IPv4 packet
	}

	version := reply[0] >> 4 //nolint:mnd
	const ipv4Version = 4
	if version != ipv4Version {
		return nil, false
	}

	// For IPv4 we need to skip the IP header, which is at least
	// 20B and can be up to 60B.
	// The Internet Header Length is the lower 4 bits of the first byte and
	// represents the number of 32-bit words of the header length.
	const ihlMask byte = 0x0F
	const bytesInWord = 4
	headerLength := int(reply[0]&ihlMask) * bytesInWord
	if headerLength < int(constants.IPv4HeaderLength) {
		// Malformed header: the declared length is below the minimum IPv4
		// header length, so stripping it would leave IP header bytes in front
		// of the TCP segment.
		return nil, false
	}
	if len(reply) < headerLength {
		return nil, false // not enough data for full IPv4 header
	}
	return reply[headerLength:], true
}
+199
View File
@@ -0,0 +1,199 @@
package tcp
import (
"context"
"errors"
"fmt"
"net/netip"
"syscall"
"testing"
"time"
"github.com/qdm12/gluetun/internal/netlink"
"github.com/qdm12/gluetun/internal/routing"
"github.com/qdm12/log"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Test_runTest runs runTest against a local closed port and against remote
// servers, all sub tests sharing one raw IPv4 socket and one tracker.
// It needs permission to open raw sockets, and the remote cases need
// connectivity to the addresses listed in the test cases.
func Test_runTest(t *testing.T) {
	t.Parallel()

	noopLogger := &noopLogger{}
	netlinker := netlink.New(noopLogger)
	loopbackMTU, err := findLoopbackMTU(netlinker)
	require.NoError(t, err, "finding loopback IPv4 MTU")
	defaultIPv4MTU, err := findDefaultIPv4RouteMTU(netlinker)
	require.NoError(t, err, "finding default IPv4 route MTU")

	ctx, cancel := context.WithCancel(t.Context())

	const family = syscall.AF_INET
	fd, stop, err := startRawSocket(family)
	require.NoError(t, err)

	const ipv4 = true
	tracker := newTracker(fd, ipv4)

	// The tracker dispatches replies to the sub tests for the whole test duration.
	trackerCh := make(chan error)
	go func() {
		trackerCh <- tracker.listen(ctx)
	}()

	// Runs once all the parallel sub tests are done.
	t.Cleanup(func() {
		stop()
		cancel() // stop listening
		err = <-trackerCh
		require.NoError(t, err)
	})

	testCases := map[string]struct {
		timeout time.Duration
		// dst returns the destination to test against.
		dst func(t *testing.T) netip.AddrPort
		mtu uint32
		// success is true if runTest is expected to return no error.
		success bool
	}{
		"local_not_listening": {
			timeout: time.Hour,
			dst: func(t *testing.T) netip.AddrPort {
				t.Helper()
				port := reserveClosedPort(t)
				return netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), port)
			},
			mtu:     loopbackMTU,
			success: true,
		},
		"remote_not_listening": {
			// No reply is expected, so runTest should fail on the timeout.
			timeout: 50 * time.Millisecond,
			dst: func(_ *testing.T) netip.AddrPort {
				return netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 12345)
			},
			mtu: defaultIPv4MTU,
		},
		"1.1.1.1:443": {
			timeout: time.Second,
			dst: func(_ *testing.T) netip.AddrPort {
				return netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443)
			},
			mtu:     defaultIPv4MTU,
			success: true,
		},
		"1.1.1.1:80": {
			timeout: time.Second,
			dst: func(_ *testing.T) netip.AddrPort {
				return netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 80)
			},
			mtu:     defaultIPv4MTU,
			success: true,
		},
		"8.8.8.8:443": {
			timeout: time.Second,
			dst: func(_ *testing.T) netip.AddrPort {
				return netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 443)
			},
			mtu:     defaultIPv4MTU,
			success: true,
		},
	}

	for name, testCase := range testCases {
		t.Run(name, func(t *testing.T) {
			t.Parallel()

			// Per sub test context, bounded by the test case timeout.
			ctx, cancel := context.WithTimeout(t.Context(), testCase.timeout)
			defer cancel()
			dst := testCase.dst(t)
			err := runTest(ctx, fd, tracker, dst, testCase.mtu)
			if testCase.success {
				require.NoError(t, err)
			} else {
				require.Error(t, err)
			}
		})
	}
}
// errRouteNotFound is wrapped by the route lookup helpers of this test file
// when no matching route exists.
var errRouteNotFound = errors.New("route not found")
// findLoopbackMTU returns the MTU of the link of the first IPv4 route with a
// loopback destination, capped at 65535. It returns an error wrapping
// errRouteNotFound if there is no such route.
func findLoopbackMTU(netlinker *netlink.NetLink) (mtu uint32, err error) {
	// Quirk: make sure it is maximum 65535, and not i.e. 65536
	// or the IP header 16 bits will fail to fit that packet length value.
	const maxMTU = 65535

	routes, err := netlinker.RouteList(netlink.FamilyV4)
	if err != nil {
		return 0, fmt.Errorf("getting routes list: %w", err)
	}

	for _, route := range routes {
		isLoopback := route.Dst.IsValid() && route.Dst.Addr().IsLoopback()
		if !isLoopback {
			continue
		}
		link, linkErr := netlinker.LinkByIndex(route.LinkIndex)
		if linkErr != nil {
			return 0, fmt.Errorf("getting link by index: %w", linkErr)
		}
		return min(link.MTU, maxMTU), nil
	}

	return 0, fmt.Errorf("%w: no loopback route found", errRouteNotFound)
}
// findDefaultIPv4RouteMTU returns the MTU of the network interface of the
// first IPv4 default route. It returns an error wrapping errRouteNotFound
// if there is no IPv4 default route.
func findDefaultIPv4RouteMTU(netlinker *netlink.NetLink) (mtu uint32, err error) {
	noopLogger := &noopLogger{}
	// Not named routing to keep the routing package accessible.
	router := routing.New(netlinker, noopLogger)
	defaultRoutes, err := router.DefaultRoutes()
	if err != nil {
		return 0, fmt.Errorf("getting default routes: %w", err)
	}

	for _, route := range defaultRoutes {
		if route.Family != netlink.FamilyV4 {
			continue
		}
		// Use the interface of the IPv4 route found, and not the one of the
		// first default route which can be an IPv6 route.
		link, err := netlinker.LinkByName(route.NetInterface)
		if err != nil {
			return 0, fmt.Errorf("getting link by name: %w", err)
		}
		return link.MTU, nil
	}

	return 0, fmt.Errorf("%w: no default route found", errRouteNotFound)
}
// reserveClosedPort binds a TCP socket to 127.0.0.1 on a port picked by the
// system, without listening on it, and returns that port. The socket stays
// bound until the test finishes so the port is reserved, but nothing accepts
// connections on it.
func reserveClosedPort(t *testing.T) (port uint16) {
	t.Helper()

	fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM, syscall.IPPROTO_TCP)
	require.NoError(t, err)

	// The cleanup is the single place closing the socket, including when the
	// test is failed below: closing it there as well would close it twice and
	// make the cleanup assertion fail.
	t.Cleanup(func() {
		err := syscall.Close(fd)
		assert.NoError(t, err)
	})

	addr := &syscall.SockaddrInet4{
		Port: 0, // let the system pick a free port
		Addr: [4]byte{127, 0, 0, 1},
	}
	err = syscall.Bind(fd, addr)
	require.NoError(t, err)

	sockAddr, err := syscall.Getsockname(fd)
	require.NoError(t, err)

	sockAddr4, ok := sockAddr.(*syscall.SockaddrInet4)
	require.True(t, ok, "not an IPv4 address")

	return uint16(sockAddr4.Port) //nolint:gosec
}
// noopLogger is a logger discarding everything; the tests pass it to
// netlink.New and routing.New.
type noopLogger struct{}

// Every method below does nothing.
func (l *noopLogger) Patch(_ ...log.Option)     {}
func (l *noopLogger) Debug(_ string)            {}
func (l *noopLogger) Debugf(_ string, _ ...any) {}
func (l *noopLogger) Info(_ string)             {}
func (l *noopLogger) Warn(_ string)             {}
func (l *noopLogger) Error(_ string)            {}
+28
View File
@@ -0,0 +1,28 @@
//go:build linux || darwin
package tcp
import (
"syscall"
"time"
)
// fileDescriptor is a platform-independent type for socket file descriptors.
// In this build (linux or darwin) it is a plain integer descriptor.
type fileDescriptor int
// sendTo sends p to the address to on the socket fd, wrapping syscall.Sendto.
func sendTo(fd fileDescriptor, p []byte, flags int, to syscall.Sockaddr) (err error) {
	return syscall.Sendto(int(fd), p, flags, to)
}
// setSocketTimeout sets the receive timeout (SO_RCVTIMEO) of the socket fd.
func setSocketTimeout(fd fileDescriptor, timeout time.Duration) (err error) {
	// The option takes a timeval on these platforms.
	timeval := syscall.NsecToTimeval(timeout.Nanoseconds())
	return syscall.SetsockoptTimeval(int(fd), syscall.SOL_SOCKET, syscall.SO_RCVTIMEO, &timeval)
}
// recvFrom reads from the socket fd into p, wrapping syscall.Recvfrom.
// It returns the number of bytes read and the address of the sender.
func recvFrom(fd fileDescriptor, p []byte, flags int) (n int, from syscall.Sockaddr, err error) {
	return syscall.Recvfrom(int(fd), p, flags)
}
// setNonBlock puts the descriptor fd in non-blocking mode.
func setNonBlock(fd int) error {
	return syscall.SetNonblock(fd, true)
}
+7
View File
@@ -0,0 +1,7 @@
//go:build !linux && !windows
package tcp
// setMTUDiscovery is not implemented on platforms other than linux and
// windows and always panics.
// NOTE(review): startRawSocket calls setMTUDiscovery unconditionally, so
// opening a raw socket panics on these platforms - confirm this is intended.
func setMTUDiscovery(fd int) error {
	panic("not implemented")
}
+37
View File
@@ -0,0 +1,37 @@
package tcp
import (
"syscall"
"time"
"unsafe"
"golang.org/x/sys/windows"
)
// fileDescriptor is the socket descriptor type; on Windows it is a syscall.Handle.
type fileDescriptor syscall.Handle
// sendTo sends p to the address to on the socket fd, wrapping syscall.Sendto.
func sendTo(fd fileDescriptor, p []byte, flags int, to syscall.Sockaddr) (err error) {
	return syscall.Sendto(syscall.Handle(fd), p, flags, to)
}
// setSocketTimeout sets the receive timeout (SO_RCVTIMEO) of the socket fd.
func setSocketTimeout(fd fileDescriptor, timeout time.Duration) (err error) {
	// The option value is passed as an integer number of milliseconds here.
	timeval := int(timeout.Milliseconds())
	return syscall.SetsockoptInt(syscall.Handle(fd), syscall.SOL_SOCKET, windows.SO_RCVTIMEO, timeval)
}
// recvFrom reads from the socket fd into p, wrapping syscall.Recvfrom.
// It returns the number of bytes read and the address of the sender.
func recvFrom(fd fileDescriptor, p []byte, flags int) (n int, from syscall.Sockaddr, err error) {
	return syscall.Recvfrom(syscall.Handle(fd), p, flags)
}
// setMTUDiscovery is not implemented on Windows and always panics.
// NOTE(review): startRawSocket calls setMTUDiscovery unconditionally, so
// opening a raw socket panics on Windows - confirm this is intended.
func setMTUDiscovery(fd syscall.Handle) error {
	panic("not implemented")
}
// setNonBlock puts the socket fd in non-blocking mode.
func setNonBlock(fd syscall.Handle) error {
	// Windows: issue the FIONBIO control code through WSAIoctl,
	// with a non-zero argument to enable non-blocking mode.
	var arg uint32 = 1 // 1 to enable non-blocking mode
	var bytesReturned uint32
	const FIONBIO = 0x8004667e
	return syscall.WSAIoctl(fd, FIONBIO, (*byte)(unsafe.Pointer(&arg)),
		uint32(unsafe.Sizeof(arg)), nil, 0, &bytesReturned, nil, 0)
}
+124
View File
@@ -0,0 +1,124 @@
package tcp
import (
"encoding/binary"
"errors"
"fmt"
"github.com/qdm12/gluetun/internal/pmtud/constants"
)
// makeTCPHeader builds a base TCP header, without options, with the ports,
// sequence number, acknowledgment number and flags given.
// The checksum (bytes 16 and 17) is left to zero for the caller to set, and
// the urgent pointer (bytes 18 and 19) is left to zero.
// For SYN, ack is 0.
// For SYN-ACK, ack is the sequence number + 1 of the SYN.
func makeTCPHeader(srcPort, dstPort uint16, seq, ack uint32, flags byte) []byte {
	// windowSize can be left to 5840 even for IPv6, it doesn't matter.
	const windowSize = 5840

	header := make([]byte, constants.BaseTCPHeaderLength)
	bigEndian := binary.BigEndian
	bigEndian.PutUint16(header[0:2], srcPort)
	bigEndian.PutUint16(header[2:4], dstPort)
	bigEndian.PutUint32(header[4:8], seq)
	bigEndian.PutUint32(header[8:12], ack)
	// Data offset: the header length in 32-bit words, in the upper 4 bits.
	header[12] = byte(constants.BaseTCPHeaderLength) << 2 //nolint:mnd
	header[13] = flags
	bigEndian.PutUint16(header[14:16], windowSize)
	return header
}
// tcpChecksum computes the TCP checksum over the pseudo header derived from
// ipHeader, the TCP header and the payload. ipHeader is handled as IPv6 if
// it is at least 40 bytes long with an IP version of 6, and as IPv4 otherwise.
// The checksum field of tcpHeader should be zero when calling this function.
//
//nolint:mnd
func tcpChecksum(ipHeader, tcpHeader, payload []byte) uint16 {
	const tcpProtocol = 6
	segmentLength := len(tcpHeader) + len(payload)

	var pseudoHeader []byte
	if len(ipHeader) >= 40 && ipHeader[0]>>4 == 6 {
		// IPv6: source (16B), destination (16B), length (4B), zeroes (3B), next header (1B).
		pseudoHeader = make([]byte, 40)
		copy(pseudoHeader[:16], ipHeader[8:24])
		copy(pseudoHeader[16:32], ipHeader[24:40])
		binary.BigEndian.PutUint32(pseudoHeader[32:36], uint32(segmentLength)) //nolint:gosec
		pseudoHeader[39] = tcpProtocol
	} else {
		// IPv4: source (4B), destination (4B), zero (1B), protocol (1B), length (2B).
		pseudoHeader = make([]byte, 12)
		copy(pseudoHeader[:4], ipHeader[12:16])
		copy(pseudoHeader[4:8], ipHeader[16:20])
		pseudoHeader[9] = tcpProtocol
		binary.BigEndian.PutUint16(pseudoHeader[10:12], uint16(segmentLength)) //nolint:gosec
	}

	// Sum each section as big endian 16-bit words, a trailing odd byte
	// being padded with a zero byte on its right.
	var sum uint32
	addWords := func(section []byte) {
		for len(section) >= 2 {
			sum += uint32(binary.BigEndian.Uint16(section))
			section = section[2:]
		}
		if len(section) == 1 {
			sum += uint32(section[0]) << 8
		}
	}
	addWords(pseudoHeader)
	addWords(tcpHeader)
	addWords(payload)

	// Fold the carries back into the lower 16 bits.
	for sum>>16 != 0 {
		sum = (sum & 0xFFFF) + (sum >> 16)
	}
	return ^uint16(sum) //nolint:gosec
}
const (
	// tcpFlagsOffset is the index of the flags byte in the TCP header.
	tcpFlagsOffset = 13
	// Bit masks of the TCP flags within the flags byte.
	rstFlag byte = 0x04
	synFlag byte = 0x02
	ackFlag byte = 0x10
	pshFlag byte = 0x08
)
// packetType is the kind of a TCP packet, derived from its flags.
// The zero value is not a valid packet type.
type packetType uint8

const (
	packetTypeSYN packetType = iota + 1
	packetTypeSYNACK
	packetTypeACK
	packetTypeRST
)

// String returns the name of the packet type.
// It panics if the packet type is not one of the constants defined.
func (p packetType) String() string {
	names := [...]string{
		packetTypeSYN:    "SYN",
		packetTypeSYNACK: "SYN-ACK",
		packetTypeACK:    "ACK",
		packetTypeRST:    "RST",
	}
	if p == 0 || int(p) >= len(names) {
		panic("unknown packet type")
	}
	return names[p]
}
// Sentinel errors wrapped by parseTCPHeader.
var (
	// errTCPHeaderTooShort: the header is shorter than the base TCP header length.
	errTCPHeaderTooShort = errors.New("TCP header is too short")
	// errTCPPacketTypeUnknown: the flags match none of the packet types handled.
	errTCPPacketTypeUnknown = errors.New("TCP packet type is unknown")
)
// parseTCPHeader extracts the packet type, the sequence number and the
// acknowledgment number from a TCP header.
// The packet type is determined from the flags in this order: SYN with ACK,
// SYN alone, RST, then ACK. It returns an error wrapping errTCPHeaderTooShort
// if header is shorter than the base TCP header, and one wrapping
// errTCPPacketTypeUnknown if the flags match none of these.
func parseTCPHeader(header []byte) (packetType packetType, seq, ack uint32, err error) {
	if len(header) < int(constants.BaseTCPHeaderLength) {
		return 0, 0, 0, fmt.Errorf("%w: %d bytes", errTCPHeaderTooShort, len(header))
	}

	flags := header[tcpFlagsOffset]
	hasSYN := flags&synFlag != 0
	hasACK := flags&ackFlag != 0
	hasRST := flags&rstFlag != 0

	switch {
	case hasSYN && hasACK:
		packetType = packetTypeSYNACK
	case hasSYN:
		packetType = packetTypeSYN
	case hasRST:
		packetType = packetTypeRST
	case hasACK:
		packetType = packetTypeACK
	default:
		return 0, 0, 0, fmt.Errorf("%w: flags are 0x%02x", errTCPPacketTypeUnknown, flags)
	}

	seq = binary.BigEndian.Uint32(header[4:8])
	ack = binary.BigEndian.Uint32(header[8:12])
	return packetType, seq, ack, nil
}
+134
View File
@@ -0,0 +1,134 @@
package tcp
import (
"context"
"encoding/binary"
"errors"
"fmt"
"sync"
"syscall"
"time"
"github.com/qdm12/gluetun/internal/pmtud/constants"
)
// tracker reads TCP packets from a raw socket and dispatches each one
// to the channel registered for its pair of ports.
type tracker struct {
	// fd is the raw socket to read from.
	fd fileDescriptor
	// ipv4 is true if the packets read start with an IPv4 header to strip.
	ipv4 bool
	// mutex protects portsToDispatch.
	mutex sync.RWMutex
	// portsToDispatch maps a key built from the local and remote ports
	// (see constructKey) to where replies should be sent.
	portsToDispatch map[uint32]dispatch
}
// dispatch is the destination of the replies for one pair of ports.
type dispatch struct {
	// replyCh receives the replies, without their IP header for IPv4.
	replyCh chan<- []byte
	// abort unblocks the tracker when nobody reads from replyCh anymore.
	abort <-chan struct{}
}
// newTracker creates a tracker reading from the socket fd, without any port
// registered. ipv4 indicates whether the packets read carry an IPv4 header.
func newTracker(fd fileDescriptor, ipv4 bool) *tracker {
	t := new(tracker)
	t.fd = fd
	t.ipv4 = ipv4
	t.portsToDispatch = make(map[uint32]dispatch)
	return t
}
// constructKey packs the local port in the upper 16 bits and the
// remote port in the lower 16 bits of the key returned.
func (t *tracker) constructKey(localPort, remotePort uint16) uint32 {
	var packed [4]byte //nolint:mnd
	binary.BigEndian.PutUint16(packed[:2], localPort)
	binary.BigEndian.PutUint16(packed[2:], remotePort)
	return binary.BigEndian.Uint32(packed[:])
}
// register makes the tracker send the replies for the pair of ports given to
// ch. The tracker gives up sending a reply once abort is closed.
// It replaces any previous registration for the same pair of ports.
func (t *tracker) register(localPort, remotePort uint16,
	ch chan<- []byte, abort <-chan struct{},
) {
	entry := dispatch{replyCh: ch, abort: abort}
	key := t.constructKey(localPort, remotePort)

	t.mutex.Lock()
	defer t.mutex.Unlock()
	t.portsToDispatch[key] = entry
}
// unregister removes the registration for the pair of ports given, if any.
func (t *tracker) unregister(localPort, remotePort uint16) {
	portsKey := t.constructKey(localPort, remotePort)

	t.mutex.Lock()
	defer t.mutex.Unlock()
	delete(t.portsToDispatch, portsKey)
}
// listen listens for incoming TCP packets and dispatches them to the
// correct channel based on the source and destination port.
// If the context has a deadline associated, this one is used on the socket.
// Note it returns a nil error on context cancellation.
// Packets which are too short, are not IPv4 when IPv4 is expected, or have
// no registration matching their ports are dropped.
func (t *tracker) listen(ctx context.Context) error {
	deadline, hasDeadline := ctx.Deadline()
	for ctx.Err() == nil {
		if hasDeadline {
			// Refresh the receive timeout with the time left until the deadline.
			remaining := time.Until(deadline)
			if remaining <= 0 {
				return nil
			}
			err := setSocketTimeout(t.fd, remaining)
			if err != nil {
				return fmt.Errorf("setting socket receive timeout: %w", err)
			}
		}

		// A new buffer is needed at each iteration since it is handed over
		// to the receiver of the reply.
		reply := make([]byte, constants.MaxEthernetFrameSize)
		n, _, err := recvFrom(t.fd, reply, 0)
		if err != nil {
			switch {
			case errors.Is(err, syscall.EAGAIN),
				errors.Is(err, syscall.EWOULDBLOCK):
				// Nothing to read yet: wait a bit and poll again.
				pollSleep(ctx)
				continue
			case ctx.Err() != nil:
				// context canceled, stop listening so exit cleanly with no error
				return nil //nolint:nilerr
			default:
				return fmt.Errorf("receiving on socket: %w", err)
			}
		}
		reply = reply[:n]

		if t.ipv4 {
			var ok bool
			reply, ok = stripIPv4Header(reply)
			if !ok {
				continue // not an IPv4 packet
			}
		}

		const minTCPHeaderLength = 20
		if len(reply) < minTCPHeaderLength {
			continue
		}

		// The reply source port is our remote port and
		// its destination port is our local port.
		srcPort := binary.BigEndian.Uint16(reply[0:2])
		dstPort := binary.BigEndian.Uint16(reply[2:4])
		key := t.constructKey(dstPort, srcPort)
		t.mutex.RLock()
		dispatch, exists := t.portsToDispatch[key]
		t.mutex.RUnlock()
		if !exists {
			continue
		}

		// Blocks until the reply is received or the receiver aborts.
		select {
		case dispatch.replyCh <- reply:
		case <-dispatch.abort:
		}
	}
	return nil
}
// pollSleep waits for the polling interval of 10 milliseconds,
// or returns earlier if ctx is done.
func pollSleep(ctx context.Context) {
	const sleepBetweenPolls = 10 * time.Millisecond
	wait := time.NewTimer(sleepBetweenPolls)
	defer wait.Stop() // no-op if the timer already fired
	select {
	case <-wait.C:
	case <-ctx.Done():
	}
}