Files
gluetun/internal/pmtud/tcp/tracker.go
T
Quentin McGaw 5f903d1fbf chore(pmtud): remove calls to syscall in favor of unix and windows
- syscall is deprecated and is not kept up-to-date
- each OS is inherently different hence the syscall being deprecated
2026-02-17 21:46:04 +00:00

133 lines
3.0 KiB
Go

package tcp
import (
"context"
"encoding/binary"
"errors"
"fmt"
"sync"
"time"
"github.com/qdm12/gluetun/internal/pmtud/constants"
)
// tracker reads TCP packets from a socket file descriptor and routes each
// one to the channel registered for its (local port, remote port) pair.
type tracker struct {
	// fd is the socket that listen reads from.
	fd fileDescriptor
	// ipv4, when true, makes listen strip a leading IPv4 header from each
	// received packet before parsing the TCP header; packets for which
	// stripping fails are skipped.
	ipv4 bool
	// mutex guards portsToDispatch, which register/unregister write and
	// listen reads concurrently.
	mutex sync.RWMutex
	// portsToDispatch maps constructKey(localPort, remotePort) to the
	// delivery target for matching packets.
	portsToDispatch map[uint32]dispatch
}
// dispatch is the delivery target registered for one local/remote port pair.
type dispatch struct {
	// replyCh receives matching packets, starting at the TCP header
	// (listen reads the ports from the first four bytes).
	replyCh chan<- []byte
	// abort lets listen stop waiting on a blocked send to replyCh once it
	// becomes ready to receive (for example when it is closed).
	abort <-chan struct{}
}
// newTracker returns a tracker reading from fd with an empty dispatch table.
// Set ipv4 when packets read from fd begin with an IPv4 header that must be
// stripped before the TCP header can be parsed.
func newTracker(fd fileDescriptor, ipv4 bool) *tracker {
	tr := new(tracker)
	tr.fd = fd
	tr.ipv4 = ipv4
	tr.portsToDispatch = make(map[uint32]dispatch)
	return tr
}
// constructKey packs a local and a remote TCP port into a single map key,
// with the local port in the high 16 bits and the remote port in the low
// 16 bits. This is the same value as writing both ports big-endian into
// four bytes and reading them back as a big-endian uint32, but it does not
// allocate, which matters because listen calls it for every packet received.
func (t *tracker) constructKey(localPort, remotePort uint16) uint32 {
	return uint32(localPort)<<16 | uint32(remotePort) //nolint:mnd
}
// register routes future packets whose destination port is localPort and
// whose source port is remotePort to ch. A send to ch is abandoned once
// abort is ready to receive. Registering the same pair again replaces the
// previous target.
func (t *tracker) register(localPort, remotePort uint16,
	ch chan<- []byte, abort <-chan struct{},
) {
	entry := dispatch{abort: abort, replyCh: ch}
	mapKey := t.constructKey(localPort, remotePort)

	t.mutex.Lock()
	defer t.mutex.Unlock()
	t.portsToDispatch[mapKey] = entry
}
// unregister removes the target registered for the localPort/remotePort
// pair, if any; packets for that pair are dropped afterwards.
func (t *tracker) unregister(localPort, remotePort uint16) {
	t.mutex.Lock()
	defer t.mutex.Unlock()
	delete(t.portsToDispatch, t.constructKey(localPort, remotePort))
}
// listen reads packets from the tracker's socket and dispatches each one to
// the channel registered for its port pair. A packet's destination port is
// treated as the local port and its source port as the remote port.
// Packets that are too short, fail IPv4 header stripping (in IPv4 mode), or
// have no registered target are dropped.
//
// If the context has a deadline, the remaining time is applied as the socket
// receive timeout before each read. listen returns a nil error once the
// context is canceled or its deadline has passed, including while it is
// waiting to hand a packet to a receiver. It returns a non-nil error only
// when setting the socket timeout or reading from the socket fails.
func (t *tracker) listen(ctx context.Context) error {
	deadline, hasDeadline := ctx.Deadline()

	// The read buffer is reused across iterations and replaced only after a
	// packet (which aliases it) has been handed off to a receiver. The
	// socket also delivers packets nobody is waiting for, so allocating a
	// new frame-sized buffer per read would be wasteful.
	buffer := make([]byte, constants.MaxEthernetFrameSize)

	for ctx.Err() == nil {
		if hasDeadline {
			remaining := time.Until(deadline)
			if remaining <= 0 {
				return nil
			}
			err := setSocketTimeout(t.fd, remaining)
			if err != nil {
				return fmt.Errorf("setting socket receive timeout: %w", err)
			}
		}

		n, _, err := recvFrom(t.fd, buffer, 0)
		if err != nil {
			switch {
			case errors.Is(err, constants.EAGAIN), errors.Is(err, constants.EWOULDBLOCK):
				// Nothing to read yet (or the receive timeout elapsed):
				// back off briefly, then re-check the context and deadline.
				pollSleep(ctx)
				continue
			case ctx.Err() != nil:
				// context canceled, stop listening so exit cleanly with no error
				return nil //nolint:nilerr
			default:
				return fmt.Errorf("receiving on socket: %w", err)
			}
		}

		reply := buffer[:n]
		if t.ipv4 {
			var ok bool
			reply, ok = stripIPv4Header(reply)
			if !ok {
				continue // not an IPv4 packet
			}
		}

		const minTCPHeaderLength = 20
		if len(reply) < minTCPHeaderLength {
			continue
		}
		srcPort := binary.BigEndian.Uint16(reply[0:2])
		dstPort := binary.BigEndian.Uint16(reply[2:4])
		// The incoming packet's destination is our local port.
		key := t.constructKey(dstPort, srcPort)

		t.mutex.RLock()
		target, exists := t.portsToDispatch[key]
		t.mutex.RUnlock()
		if !exists {
			continue
		}

		select {
		case target.replyCh <- reply:
			// The receiver now owns reply's backing array, so the next
			// read must go into a fresh buffer.
			buffer = make([]byte, constants.MaxEthernetFrameSize)
		case <-target.abort:
		case <-ctx.Done():
			// Do not stay blocked on a receiver that stopped reading;
			// the loop condition observes the cancellation and exits.
		}
	}
	return nil
}
// pollSleep pauses for a short polling interval, returning early if ctx is
// done first.
func pollSleep(ctx context.Context) {
	const sleepBetweenPolls = 10 * time.Millisecond
	waiter := time.NewTimer(sleepBetweenPolls)
	// Stopping a timer that already fired is a no-op, so this is safe on
	// both exit paths.
	defer waiter.Stop()
	select {
	case <-waiter.C:
	case <-ctx.Done():
	}
}