feat(pmtud/tcp): support mixed IPv4 and IPv6 TCP servers

- Add default cloudflare and google tls ipv6 servers to default tcp servers
- update integration test to try against both ipv4 and ipv6 servers
This commit is contained in:
Quentin McGaw
2026-02-19 17:11:16 +00:00
parent 1c43a045d1
commit c6b211ef9b
15 changed files with 175 additions and 70 deletions
+1 -1
View File
@@ -114,7 +114,7 @@ ENV VPN_SERVICE_PROVIDER=pia \
WIREGUARD_IMPLEMENTATION=auto \ WIREGUARD_IMPLEMENTATION=auto \
# PMTUD # PMTUD
PMTUD_ICMP_ADDRESSES=1.1.1.1,8.8.8.8 \ PMTUD_ICMP_ADDRESSES=1.1.1.1,8.8.8.8 \
PMTUD_TCP_ADDRESSES=1.1.1.1:443,8.8.8.8:443,1.1.1.1:53,8.8.8.8:53 \ PMTUD_TCP_ADDRESSES=1.1.1.1:443,8.8.8.8:443,1.1.1.1:53,8.8.8.8:53,[2606:4700:4700::1111]:53,[2001:4860:4860::8888]:53,[2606:4700:4700::1111]:443,[2001:4860:4860::8888]:443 \
# VPN server filtering # VPN server filtering
SERVER_REGIONS= \ SERVER_REGIONS= \
SERVER_COUNTRIES= \ SERVER_COUNTRIES= \
+10 -9
View File
@@ -4,7 +4,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/netip" "net/netip"
"strings"
"github.com/qdm12/gosettings" "github.com/qdm12/gosettings"
"github.com/qdm12/gosettings/reader" "github.com/qdm12/gosettings/reader"
@@ -70,6 +69,10 @@ func (p *PMTUD) setDefaults() {
netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), dnsPort), netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), dnsPort),
netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), tlsPort), netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), tlsPort),
netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), tlsPort), netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), tlsPort),
netip.AddrPortFrom(netip.MustParseAddr("2606:4700:4700::1111"), dnsPort),
netip.AddrPortFrom(netip.MustParseAddr("2001:4860:4860::8888"), dnsPort),
netip.AddrPortFrom(netip.MustParseAddr("2606:4700:4700::1111"), tlsPort),
netip.AddrPortFrom(netip.MustParseAddr("2001:4860:4860::8888"), tlsPort),
} }
p.TCPAddresses = gosettings.DefaultSlice(p.TCPAddresses, defaultTCPAddresses) p.TCPAddresses = gosettings.DefaultSlice(p.TCPAddresses, defaultTCPAddresses)
} }
@@ -81,17 +84,15 @@ func (p PMTUD) String() string {
func (p PMTUD) toLinesNode() (node *gotree.Node) { func (p PMTUD) toLinesNode() (node *gotree.Node) {
node = gotree.New("Path MTU discovery:") node = gotree.New("Path MTU discovery:")
addrs := make([]string, len(p.ICMPAddresses)) icmpAddrNode := node.Append("ICMP addresses:")
for i, addr := range p.ICMPAddresses { for _, addr := range p.ICMPAddresses {
addrs[i] = addr.String() icmpAddrNode.Append(addr.String())
} }
node.Appendf("ICMP addresses: %s", strings.Join(addrs, ", "))
addrs = make([]string, len(p.TCPAddresses)) tcpAddrNode := node.Append("TCP addresses:")
for i, addr := range p.TCPAddresses { for _, addr := range p.TCPAddresses {
addrs[i] = addr.String() tcpAddrNode.Append(addr.String())
} }
node.Appendf("TCP addresses: %s", strings.Join(addrs, ", "))
return node return node
} }
@@ -38,8 +38,18 @@ func Test_Settings_String(t *testing.T) {
| | ├── Run OpenVPN as: root | | ├── Run OpenVPN as: root
| | └── Verbosity level: 1 | | └── Verbosity level: 1
| └── Path MTU discovery: | └── Path MTU discovery:
| ├── ICMP addresses: 1.1.1.1, 8.8.8.8 | ├── ICMP addresses:
| └── TCP addresses: 1.1.1.1:53, 8.8.8.8:53, 1.1.1.1:443, 8.8.8.8:443 | | ├── 1.1.1.1
| | └── 8.8.8.8
| └── TCP addresses:
| ├── 1.1.1.1:53
| ├── 8.8.8.8:53
| ├── 1.1.1.1:443
| ├── 8.8.8.8:443
| ├── [2606:4700:4700::1111]:53
| ├── [2001:4860:4860::8888]:53
| ├── [2606:4700:4700::1111]:443
| └── [2001:4860:4860::8888]:443
├── DNS settings: ├── DNS settings:
| ├── Keep existing nameserver(s): no | ├── Keep existing nameserver(s): no
| ├── DNS server address to use: 127.0.0.1 | ├── DNS server address to use: 127.0.0.1
+27
View File
@@ -0,0 +1,27 @@
package ip
import (
"net/netip"
"slices"
"github.com/qdm12/gluetun/internal/pmtud/constants"
)
func GetFamilies(dsts []netip.AddrPort) (families []int) {
const maxFamilies = 2
families = make([]int, 0, maxFamilies)
for _, dst := range dsts {
family := GetFamily(dst)
if !slices.Contains(families, family) {
families = append(families, family)
}
}
return families
}
func GetFamily(dst netip.AddrPort) int {
if dst.Addr().Is4() {
return constants.AF_INET
}
return constants.AF_INET6
}
+16 -10
View File
@@ -78,24 +78,30 @@ func findLoopbackMTU(netlinker *netlink.NetLink) (mtu uint32, err error) {
return 0, fmt.Errorf("%w: no loopback route found", errRouteNotFound) return 0, fmt.Errorf("%w: no loopback route found", errRouteNotFound)
} }
func findDefaultIPv4RouteMTU(netlinker *netlink.NetLink) (mtu uint32, err error) { func findDefaultRouteMTU(netlinker *netlink.NetLink) (mtu uint32, err error) {
noopLogger := &noopLogger{} noopLogger := &noopLogger{}
routing := routing.New(netlinker, noopLogger) routing := routing.New(netlinker, noopLogger)
defaultRoutes, err := routing.DefaultRoutes() defaultRoutes, err := routing.DefaultRoutes()
if err != nil { if err != nil {
return 0, fmt.Errorf("getting default routes: %w", err) return 0, fmt.Errorf("getting default routes: %w", err)
} }
for _, route := range defaultRoutes { families := []uint8{constants.AF_INET, constants.AF_INET6}
if route.Family != netlink.FamilyV4 { for _, family := range families {
continue for _, route := range defaultRoutes {
if route.Family != family {
continue
}
link, err := netlinker.LinkByName(route.NetInterface)
if err != nil {
return 0, fmt.Errorf("getting link by name: %w", err)
}
mtu = max(mtu, link.MTU)
} }
link, err := netlinker.LinkByName(defaultRoutes[0].NetInterface)
if err != nil {
return 0, fmt.Errorf("getting link by name: %w", err)
}
return link.MTU, nil
} }
return 0, fmt.Errorf("%w: no default route found", errRouteNotFound) if mtu == 0 {
return 0, fmt.Errorf("%w: no default route found", errRouteNotFound)
}
return mtu, nil
} }
func reserveClosedPort(t *testing.T) (port uint16) { func reserveClosedPort(t *testing.T) (port uint16) {
+2 -1
View File
@@ -14,7 +14,7 @@ import (
// findHighestMSSDestination finds the destination with the highest // findHighestMSSDestination finds the destination with the highest
// MSS amongst the provided destinations. // MSS amongst the provided destinations.
func findHighestMSSDestination(ctx context.Context, fd fileDescriptor, func findHighestMSSDestination(ctx context.Context, familyToFD map[int]fileDescriptor,
dsts []netip.AddrPort, excludeMark int, maxPossibleMTU uint32, dsts []netip.AddrPort, excludeMark int, maxPossibleMTU uint32,
timeout time.Duration, tracker *tracker, fw Firewall, logger Logger) ( timeout time.Duration, tracker *tracker, fw Firewall, logger Logger) (
dst netip.AddrPort, mss uint32, err error, dst netip.AddrPort, mss uint32, err error,
@@ -30,6 +30,7 @@ func findHighestMSSDestination(ctx context.Context, fd fileDescriptor,
defer cancel() defer cancel()
for _, dst := range dsts { for _, dst := range dsts {
go func(dst netip.AddrPort) { go func(dst netip.AddrPort) {
fd := familyToFD[ip.GetFamily(dst)]
mss, err := findMSS(ctx, fd, dst, excludeMark, tracker, fw, logger) mss, err := findMSS(ctx, fd, dst, excludeMark, tracker, fw, logger)
resultCh <- result{dst: dst, mss: mss, err: err} resultCh <- result{dst: dst, mss: mss, err: err}
}(dst) }(dst)
+9 -8
View File
@@ -18,17 +18,16 @@ func Test_findHighestMSSDestination(t *testing.T) {
t.Parallel() t.Parallel()
netlinker := netlink.New(&noopLogger{}) netlinker := netlink.New(&noopLogger{})
defaultIPv4MTU, err := findDefaultIPv4RouteMTU(netlinker) defaultMTU, err := findDefaultRouteMTU(netlinker)
require.NoError(t, err, "finding default IPv4 route MTU") require.NoError(t, err, "finding default route MTU")
ctx, cancel := context.WithCancel(t.Context()) ctx, cancel := context.WithCancel(t.Context())
const family = constants.AF_INET families := []int{constants.AF_INET, constants.AF_INET6}
fd, stop, err := startRawSocket(family, excludeMark) familyToFD, stop, err := startRawSockets(families, excludeMark)
require.NoError(t, err) require.NoError(t, err)
const ipv4 = true tracker := newTracker(familyToFD)
tracker := newTracker(fd, ipv4)
trackerCh := make(chan error) trackerCh := make(chan error)
go func() { go func() {
trackerCh <- tracker.listen(ctx) trackerCh <- tracker.listen(ctx)
@@ -44,13 +43,15 @@ func Test_findHighestMSSDestination(t *testing.T) {
dsts := []netip.AddrPort{ dsts := []netip.AddrPort{
netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443), netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443),
netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 443), netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 443),
netip.AddrPortFrom(netip.MustParseAddr("2606:4700:4700::1111"), 443),
netip.AddrPortFrom(netip.MustParseAddr("2001:4860:4860::8888"), 443),
} }
const timeout = time.Second const timeout = time.Second
fw := getFirewall(t) fw := getFirewall(t)
logger := &noopLogger{} logger := &noopLogger{}
dst, mss, err := findHighestMSSDestination(t.Context(), fd, dsts, dst, mss, err := findHighestMSSDestination(t.Context(), familyToFD, dsts,
excludeMark, defaultIPv4MTU, timeout, tracker, fw, logger) excludeMark, defaultMTU, timeout, tracker, fw, logger)
require.NoError(t, err, "finding highest MSS destination") require.NoError(t, err, "finding highest MSS destination")
assert.Contains(t, dsts, dst, "destination should be in the provided list") assert.Contains(t, dsts, dst, "destination should be in the provided list")
assert.Greater(t, mss, uint32(1000), "MSS should be greater than 1000") assert.Greater(t, mss, uint32(1000), "MSS should be greater than 1000")
+7 -8
View File
@@ -33,17 +33,14 @@ func PathMTUDiscover(ctx context.Context, dsts []netip.AddrPort,
minMTU, maxPossibleMTU uint32, tryTimeout time.Duration, minMTU, maxPossibleMTU uint32, tryTimeout time.Duration,
firewall Firewall, logger Logger, firewall Firewall, logger Logger,
) (mtu uint32, err error) { ) (mtu uint32, err error) {
family := constants.AF_INET families := ip.GetFamilies(dsts)
if dsts[0].Addr().Is6() { familyToFD, stop, err := startRawSockets(families, excludeMark)
family = constants.AF_INET6
}
fd, stop, err := startRawSocket(family, excludeMark)
if err != nil { if err != nil {
return 0, fmt.Errorf("starting raw socket: %w", err) return 0, fmt.Errorf("starting raw sockets: %w", err)
} }
defer stop() defer stop()
tracker := newTracker(fd, family == constants.AF_INET) tracker := newTracker(familyToFD)
trackerCtx, trackerCancel := context.WithCancel(ctx) trackerCtx, trackerCancel := context.WithCancel(ctx)
defer trackerCancel() defer trackerCancel()
@@ -62,7 +59,7 @@ func PathMTUDiscover(ctx context.Context, dsts []netip.AddrPort,
mssCtx, mssCancel := context.WithTimeout(ctx, tryTimeout) mssCtx, mssCancel := context.WithTimeout(ctx, tryTimeout)
defer mssCancel() defer mssCancel()
go func() { go func() {
dst, mss, err := findHighestMSSDestination(mssCtx, fd, dsts, excludeMark, dst, mss, err := findHighestMSSDestination(mssCtx, familyToFD, dsts, excludeMark,
maxPossibleMTU, tryTimeout, tracker, firewall, logger) maxPossibleMTU, tryTimeout, tracker, firewall, logger)
mssResultCh <- mssResult{dst: dst, mss: mss, err: err} mssResultCh <- mssResult{dst: dst, mss: mss, err: err}
}() }()
@@ -83,6 +80,8 @@ func PathMTUDiscover(ctx context.Context, dsts []netip.AddrPort,
maxPossibleMTU = ipHeaderLength + constants.BaseTCPHeaderLength + result.mss maxPossibleMTU = ipHeaderLength + constants.BaseTCPHeaderLength + result.mss
} }
fd := familyToFD[ip.GetFamily(highestMSSDst)]
type pmtudResult struct { type pmtudResult struct {
mtu uint32 mtu uint32
err error err error
+25 -2
View File
@@ -10,6 +10,29 @@ import (
"github.com/qdm12/gluetun/internal/pmtud/ip" "github.com/qdm12/gluetun/internal/pmtud/ip"
) )
func startRawSockets(families []int, excludeMark int) (familyToSocket map[int]fileDescriptor, stop func(), err error) {
familyToSocket = make(map[int]fileDescriptor, len(families))
stops := make([]func(), 0, len(families))
for _, family := range families {
fd, stop, err := startRawSocket(family, excludeMark)
if err != nil {
for _, stop := range stops {
stop()
}
return nil, nil, fmt.Errorf("starting raw socket for family %d: %w", family, err)
}
stops = append(stops, stop)
familyToSocket[family] = fd
}
stop = func() {
for _, stop := range stops {
stop()
}
}
return familyToSocket, stop, nil
}
func startRawSocket(family, excludeMark int) (fd fileDescriptor, stop func(), err error) { func startRawSocket(family, excludeMark int) (fd fileDescriptor, stop func(), err error) {
fdPlatform, err := socket(family, constants.SOCK_RAW, constants.IPPROTO_TCP) fdPlatform, err := socket(family, constants.SOCK_RAW, constants.IPPROTO_TCP)
if err != nil { if err != nil {
@@ -33,10 +56,10 @@ func startRawSocket(family, excludeMark int) (fd fileDescriptor, stop func(), er
} }
// Allow sending packets larger than cached PMTU (for PMTUD probing) // Allow sending packets larger than cached PMTU (for PMTUD probing)
err = setMTUDiscovery(fdPlatform) err = setMTUDiscovery(fdPlatform, family == constants.AF_INET)
if err != nil { if err != nil {
_ = closeSocket(fdPlatform) _ = closeSocket(fdPlatform)
return 0, nil, fmt.Errorf("setting IP_MTU_DISCOVER: %w", err) return 0, nil, fmt.Errorf("setting MTU discovery options: %w", err)
} }
// use polling because some Linux systems do not cancel // use polling because some Linux systems do not cancel
+16 -4
View File
@@ -18,10 +18,21 @@ import (
func Test_PathMTUDiscover(t *testing.T) { func Test_PathMTUDiscover(t *testing.T) {
t.Parallel() t.Parallel()
noopLogger := log.New(log.SetLevel(log.LevelDebug))
const tryTimeout = time.Second
deadline, ok := t.Deadline()
if ok {
timeLeft := time.Until(deadline)
const maxTimeNeeded = tryTimeout * 4 // MSS discovery + 3 MTU tries
require.GreaterOrEqual(t, timeLeft, maxTimeNeeded,
"not enough time remaining for TCP PMTUD test, need %s and got %s",
maxTimeNeeded, timeLeft)
}
logger := log.New(log.SetLevel(log.LevelDebug))
cmder := command.New() cmder := command.New()
fw, err := firewall.NewConfig(t.Context(), noopLogger, cmder, nil, nil) fw, err := firewall.NewConfig(t.Context(), logger, cmder, nil, nil)
if errors.Is(err, firewall.ErrIPTablesNotSupported) { if errors.Is(err, firewall.ErrIPTablesNotSupported) {
t.Skip("iptables not installed, skipping TCP PMTUD tests") t.Skip("iptables not installed, skipping TCP PMTUD tests")
} }
@@ -32,11 +43,12 @@ func Test_PathMTUDiscover(t *testing.T) {
netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443), netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443),
netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 53), netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 53),
netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 443), netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 443),
netip.AddrPortFrom(netip.MustParseAddr("2606:4700:4700::1111"), 443),
netip.AddrPortFrom(netip.MustParseAddr("2001:4860:4860::8888"), 443),
} }
const minMTU = constants.MinIPv6MTU const minMTU = constants.MinIPv6MTU
const maxMTU = constants.MaxEthernetFrameSize const maxMTU = constants.MaxEthernetFrameSize
const tryTimeout = time.Second mtu, err := PathMTUDiscover(t.Context(), dsts, minMTU, maxMTU, tryTimeout, fw, logger)
mtu, err := PathMTUDiscover(t.Context(), dsts, minMTU, maxMTU, tryTimeout, fw, noopLogger)
require.NoError(t, err, "discovering path MTU") require.NoError(t, err, "discovering path MTU")
assert.Greater(t, mtu, uint32(0), "MTU should be greater than 0") assert.Greater(t, mtu, uint32(0), "MTU should be greater than 0")
t.Logf("discovered path MTU is %d", mtu) t.Logf("discovered path MTU is %d", mtu)
+5 -2
View File
@@ -13,6 +13,9 @@ func setMark(fd, excludeMark int) error {
return unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_MARK, excludeMark) return unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_MARK, excludeMark)
} }
func setMTUDiscovery(fd int) error { func setMTUDiscovery(fd int, ipv4 bool) error {
return unix.SetsockoptInt(fd, unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_PROBE) if ipv4 {
return unix.SetsockoptInt(fd, unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_PROBE)
}
return unix.SetsockoptInt(fd, unix.IPPROTO_IPV6, unix.IPV6_MTU_DISCOVER, unix.IPV6_PMTUDISC_PROBE)
} }
+9 -10
View File
@@ -26,17 +26,15 @@ func Test_runTest(t *testing.T) {
netlinker := netlink.New(noopLogger) netlinker := netlink.New(noopLogger)
loopbackMTU, err := findLoopbackMTU(netlinker) loopbackMTU, err := findLoopbackMTU(netlinker)
require.NoError(t, err, "finding loopback IPv4 MTU") require.NoError(t, err, "finding loopback IPv4 MTU")
defaultIPv4MTU, err := findDefaultIPv4RouteMTU(netlinker) defaultMTU, err := findDefaultRouteMTU(netlinker)
require.NoError(t, err, "finding default IPv4 route MTU") require.NoError(t, err, "finding default route MTU")
ctx, cancel := context.WithCancel(t.Context()) ctx, cancel := context.WithCancel(t.Context())
const family = constants.AF_INET familyToFD, stop, err := startRawSockets([]int{constants.AF_INET, constants.AF_INET6}, excludeMark)
fd, stop, err := startRawSocket(family, excludeMark)
require.NoError(t, err) require.NoError(t, err)
const ipv4 = true tracker := newTracker(familyToFD)
tracker := newTracker(fd, ipv4)
trackerCh := make(chan error) trackerCh := make(chan error)
go func() { go func() {
trackerCh <- tracker.listen(ctx) trackerCh <- tracker.listen(ctx)
@@ -71,24 +69,24 @@ func Test_runTest(t *testing.T) {
"remote_not_listening": { "remote_not_listening": {
timeout: 50 * time.Millisecond, timeout: 50 * time.Millisecond,
server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 12345), server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 12345),
mtu: defaultIPv4MTU - mtuSafetyBuffer, mtu: defaultMTU - mtuSafetyBuffer,
}, },
"1.1.1.1:443": { "1.1.1.1:443": {
timeout: 5 * time.Second, timeout: 5 * time.Second,
server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443), server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443),
mtu: defaultIPv4MTU - mtuSafetyBuffer, mtu: defaultMTU - mtuSafetyBuffer,
success: true, success: true,
}, },
"1.1.1.1:80": { "1.1.1.1:80": {
timeout: 5 * time.Second, timeout: 5 * time.Second,
server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 80), server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 80),
mtu: defaultIPv4MTU - mtuSafetyBuffer, mtu: defaultMTU - mtuSafetyBuffer,
success: true, success: true,
}, },
"8.8.8.8:443": { "8.8.8.8:443": {
timeout: 5 * time.Second, timeout: 5 * time.Second,
server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 443), server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 443),
mtu: defaultIPv4MTU - mtuSafetyBuffer, mtu: defaultMTU - mtuSafetyBuffer,
success: true, success: true,
}, },
} }
@@ -99,6 +97,7 @@ func Test_runTest(t *testing.T) {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
dst := testCase.server dst := testCase.server
fd := familyToFD[ip.GetFamily(dst)]
const proto = constants.IPPROTO_TCP const proto = constants.IPPROTO_TCP
src, cleanup, err := ip.SrcAddr(dst, proto) src, cleanup, err := ip.SrcAddr(dst, proto)
+1 -1
View File
@@ -6,6 +6,6 @@ func setMark(fd, excludeMark int) error {
panic("not implemented") panic("not implemented")
} }
func setMTUDiscovery(fd int) error { func setMTUDiscovery(fd int, ipv4 bool) error {
panic("not implemented") panic("not implemented")
} }
+1 -1
View File
@@ -35,7 +35,7 @@ func setMark(fd windows.Handle, _ int) error {
panic("not implemented") panic("not implemented")
} }
func setMTUDiscovery(fd windows.Handle) error { func setMTUDiscovery(fd windows.Handle, ipv4 bool) error {
panic("not implemented") panic("not implemented")
} }
+34 -11
View File
@@ -12,8 +12,7 @@ import (
) )
type tracker struct { type tracker struct {
fd fileDescriptor familyToFD map[int]fileDescriptor
ipv4 bool
mutex sync.RWMutex mutex sync.RWMutex
portsToDispatch map[uint32]dispatch portsToDispatch map[uint32]dispatch
} }
@@ -23,10 +22,9 @@ type dispatch struct {
abort <-chan struct{} abort <-chan struct{}
} }
func newTracker(fd fileDescriptor, ipv4 bool) *tracker { func newTracker(familyToFD map[int]fileDescriptor) *tracker {
return &tracker{ return &tracker{
fd: fd, familyToFD: familyToFD,
ipv4: ipv4,
portsToDispatch: make(map[uint32]dispatch), portsToDispatch: make(map[uint32]dispatch),
} }
} }
@@ -57,11 +55,36 @@ func (t *tracker) unregister(localPort, remotePort uint16) {
delete(t.portsToDispatch, key) delete(t.portsToDispatch, key)
} }
// listen listens for incoming TCP packets and dispatches them to the func (t *tracker) listen(ctx context.Context) (err error) {
// correct channel based on the source and destination port. ctx, cancel := context.WithCancel(ctx)
defer cancel()
type result struct {
family int
err error
}
resultCh := make(chan result)
for family, fd := range t.familyToFD {
go func(family int, fd fileDescriptor) {
err := t.listenFD(ctx, fd, family == constants.AF_INET)
resultCh <- result{family: family, err: err}
}(family, fd)
}
for range t.familyToFD {
result := <-resultCh
if err == nil && result.err != nil {
cancel() // stop the other listener if it is still running
err = fmt.Errorf("listening for family %d: %w", result.family, result.err)
}
}
return err
}
// listenFD listens for incoming TCP packets on the given file descriptor,
// and dispatches them to the correct channel based on the source and destination port.
// If the context has a deadline associated, this one is used on the socket. // If the context has a deadline associated, this one is used on the socket.
// Note it returns a nil error on context cancellation. // Note it returns a nil error on context cancellation.
func (t *tracker) listen(ctx context.Context) error { func (t *tracker) listenFD(ctx context.Context, fd fileDescriptor, ipv4 bool) error {
deadline, hasDeadline := ctx.Deadline() deadline, hasDeadline := ctx.Deadline()
for ctx.Err() == nil { for ctx.Err() == nil {
if hasDeadline { if hasDeadline {
@@ -69,14 +92,14 @@ func (t *tracker) listen(ctx context.Context) error {
if remaining <= 0 { if remaining <= 0 {
return nil return nil
} }
err := setSocketTimeout(t.fd, remaining) err := setSocketTimeout(fd, remaining)
if err != nil { if err != nil {
return fmt.Errorf("setting socket receive timeout: %w", err) return fmt.Errorf("setting socket receive timeout: %w", err)
} }
} }
reply := make([]byte, constants.MaxEthernetFrameSize) reply := make([]byte, constants.MaxEthernetFrameSize)
n, _, err := recvFrom(t.fd, reply, 0) n, _, err := recvFrom(fd, reply, 0)
if err != nil { if err != nil {
switch { switch {
case errors.Is(err, constants.EAGAIN), errors.Is(err, constants.EWOULDBLOCK): case errors.Is(err, constants.EAGAIN), errors.Is(err, constants.EWOULDBLOCK):
@@ -91,7 +114,7 @@ func (t *tracker) listen(ctx context.Context) error {
} }
reply = reply[:n] reply = reply[:n]
if t.ipv4 { if ipv4 {
var ok bool var ok bool
reply, ok = stripIPv4Header(reply) reply, ok = stripIPv4Header(reply)
if !ok { if !ok {