mirror of
https://github.com/qdm12/gluetun.git
synced 2026-05-07 04:20:12 +02:00
feat(pmtud/tcp): support mixed IPv4 and IPv6 TCP servers
- Add default cloudflare and google tls ipv6 servers to default tcp servers - update integration test to try against both ipv4 and ipv6 servers
This commit is contained in:
+1
-1
@@ -114,7 +114,7 @@ ENV VPN_SERVICE_PROVIDER=pia \
|
|||||||
WIREGUARD_IMPLEMENTATION=auto \
|
WIREGUARD_IMPLEMENTATION=auto \
|
||||||
# PMTUD
|
# PMTUD
|
||||||
PMTUD_ICMP_ADDRESSES=1.1.1.1,8.8.8.8 \
|
PMTUD_ICMP_ADDRESSES=1.1.1.1,8.8.8.8 \
|
||||||
PMTUD_TCP_ADDRESSES=1.1.1.1:443,8.8.8.8:443,1.1.1.1:53,8.8.8.8:53 \
|
PMTUD_TCP_ADDRESSES=1.1.1.1:443,8.8.8.8:443,1.1.1.1:53,8.8.8.8:53,[2606:4700:4700::1111]:53,[2001:4860:4860::8888]:53,[2606:4700:4700::1111]:443,[2001:4860:4860::8888]:443 \
|
||||||
# VPN server filtering
|
# VPN server filtering
|
||||||
SERVER_REGIONS= \
|
SERVER_REGIONS= \
|
||||||
SERVER_COUNTRIES= \
|
SERVER_COUNTRIES= \
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/qdm12/gosettings"
|
"github.com/qdm12/gosettings"
|
||||||
"github.com/qdm12/gosettings/reader"
|
"github.com/qdm12/gosettings/reader"
|
||||||
@@ -70,6 +69,10 @@ func (p *PMTUD) setDefaults() {
|
|||||||
netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), dnsPort),
|
netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), dnsPort),
|
||||||
netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), tlsPort),
|
netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), tlsPort),
|
||||||
netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), tlsPort),
|
netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), tlsPort),
|
||||||
|
netip.AddrPortFrom(netip.MustParseAddr("2606:4700:4700::1111"), dnsPort),
|
||||||
|
netip.AddrPortFrom(netip.MustParseAddr("2001:4860:4860::8888"), dnsPort),
|
||||||
|
netip.AddrPortFrom(netip.MustParseAddr("2606:4700:4700::1111"), tlsPort),
|
||||||
|
netip.AddrPortFrom(netip.MustParseAddr("2001:4860:4860::8888"), tlsPort),
|
||||||
}
|
}
|
||||||
p.TCPAddresses = gosettings.DefaultSlice(p.TCPAddresses, defaultTCPAddresses)
|
p.TCPAddresses = gosettings.DefaultSlice(p.TCPAddresses, defaultTCPAddresses)
|
||||||
}
|
}
|
||||||
@@ -81,17 +84,15 @@ func (p PMTUD) String() string {
|
|||||||
func (p PMTUD) toLinesNode() (node *gotree.Node) {
|
func (p PMTUD) toLinesNode() (node *gotree.Node) {
|
||||||
node = gotree.New("Path MTU discovery:")
|
node = gotree.New("Path MTU discovery:")
|
||||||
|
|
||||||
addrs := make([]string, len(p.ICMPAddresses))
|
icmpAddrNode := node.Append("ICMP addresses:")
|
||||||
for i, addr := range p.ICMPAddresses {
|
for _, addr := range p.ICMPAddresses {
|
||||||
addrs[i] = addr.String()
|
icmpAddrNode.Append(addr.String())
|
||||||
}
|
}
|
||||||
node.Appendf("ICMP addresses: %s", strings.Join(addrs, ", "))
|
|
||||||
|
|
||||||
addrs = make([]string, len(p.TCPAddresses))
|
tcpAddrNode := node.Append("TCP addresses:")
|
||||||
for i, addr := range p.TCPAddresses {
|
for _, addr := range p.TCPAddresses {
|
||||||
addrs[i] = addr.String()
|
tcpAddrNode.Append(addr.String())
|
||||||
}
|
}
|
||||||
node.Appendf("TCP addresses: %s", strings.Join(addrs, ", "))
|
|
||||||
return node
|
return node
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -38,8 +38,18 @@ func Test_Settings_String(t *testing.T) {
|
|||||||
| | ├── Run OpenVPN as: root
|
| | ├── Run OpenVPN as: root
|
||||||
| | └── Verbosity level: 1
|
| | └── Verbosity level: 1
|
||||||
| └── Path MTU discovery:
|
| └── Path MTU discovery:
|
||||||
| ├── ICMP addresses: 1.1.1.1, 8.8.8.8
|
| ├── ICMP addresses:
|
||||||
| └── TCP addresses: 1.1.1.1:53, 8.8.8.8:53, 1.1.1.1:443, 8.8.8.8:443
|
| | ├── 1.1.1.1
|
||||||
|
| | └── 8.8.8.8
|
||||||
|
| └── TCP addresses:
|
||||||
|
| ├── 1.1.1.1:53
|
||||||
|
| ├── 8.8.8.8:53
|
||||||
|
| ├── 1.1.1.1:443
|
||||||
|
| ├── 8.8.8.8:443
|
||||||
|
| ├── [2606:4700:4700::1111]:53
|
||||||
|
| ├── [2001:4860:4860::8888]:53
|
||||||
|
| ├── [2606:4700:4700::1111]:443
|
||||||
|
| └── [2001:4860:4860::8888]:443
|
||||||
├── DNS settings:
|
├── DNS settings:
|
||||||
| ├── Keep existing nameserver(s): no
|
| ├── Keep existing nameserver(s): no
|
||||||
| ├── DNS server address to use: 127.0.0.1
|
| ├── DNS server address to use: 127.0.0.1
|
||||||
|
|||||||
@@ -0,0 +1,27 @@
|
|||||||
|
package ip
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"slices"
|
||||||
|
|
||||||
|
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||||
|
)
|
||||||
|
|
||||||
|
func GetFamilies(dsts []netip.AddrPort) (families []int) {
|
||||||
|
const maxFamilies = 2
|
||||||
|
families = make([]int, 0, maxFamilies)
|
||||||
|
for _, dst := range dsts {
|
||||||
|
family := GetFamily(dst)
|
||||||
|
if !slices.Contains(families, family) {
|
||||||
|
families = append(families, family)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return families
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetFamily(dst netip.AddrPort) int {
|
||||||
|
if dst.Addr().Is4() {
|
||||||
|
return constants.AF_INET
|
||||||
|
}
|
||||||
|
return constants.AF_INET6
|
||||||
|
}
|
||||||
@@ -78,25 +78,31 @@ func findLoopbackMTU(netlinker *netlink.NetLink) (mtu uint32, err error) {
|
|||||||
return 0, fmt.Errorf("%w: no loopback route found", errRouteNotFound)
|
return 0, fmt.Errorf("%w: no loopback route found", errRouteNotFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
func findDefaultIPv4RouteMTU(netlinker *netlink.NetLink) (mtu uint32, err error) {
|
func findDefaultRouteMTU(netlinker *netlink.NetLink) (mtu uint32, err error) {
|
||||||
noopLogger := &noopLogger{}
|
noopLogger := &noopLogger{}
|
||||||
routing := routing.New(netlinker, noopLogger)
|
routing := routing.New(netlinker, noopLogger)
|
||||||
defaultRoutes, err := routing.DefaultRoutes()
|
defaultRoutes, err := routing.DefaultRoutes()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("getting default routes: %w", err)
|
return 0, fmt.Errorf("getting default routes: %w", err)
|
||||||
}
|
}
|
||||||
|
families := []uint8{constants.AF_INET, constants.AF_INET6}
|
||||||
|
for _, family := range families {
|
||||||
for _, route := range defaultRoutes {
|
for _, route := range defaultRoutes {
|
||||||
if route.Family != netlink.FamilyV4 {
|
if route.Family != family {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
link, err := netlinker.LinkByName(defaultRoutes[0].NetInterface)
|
link, err := netlinker.LinkByName(route.NetInterface)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("getting link by name: %w", err)
|
return 0, fmt.Errorf("getting link by name: %w", err)
|
||||||
}
|
}
|
||||||
return link.MTU, nil
|
mtu = max(mtu, link.MTU)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
if mtu == 0 {
|
||||||
return 0, fmt.Errorf("%w: no default route found", errRouteNotFound)
|
return 0, fmt.Errorf("%w: no default route found", errRouteNotFound)
|
||||||
}
|
}
|
||||||
|
return mtu, nil
|
||||||
|
}
|
||||||
|
|
||||||
func reserveClosedPort(t *testing.T) (port uint16) {
|
func reserveClosedPort(t *testing.T) (port uint16) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ import (
|
|||||||
|
|
||||||
// findHighestMSSDestination finds the destination with the highest
|
// findHighestMSSDestination finds the destination with the highest
|
||||||
// MSS amongst the provided destinations.
|
// MSS amongst the provided destinations.
|
||||||
func findHighestMSSDestination(ctx context.Context, fd fileDescriptor,
|
func findHighestMSSDestination(ctx context.Context, familyToFD map[int]fileDescriptor,
|
||||||
dsts []netip.AddrPort, excludeMark int, maxPossibleMTU uint32,
|
dsts []netip.AddrPort, excludeMark int, maxPossibleMTU uint32,
|
||||||
timeout time.Duration, tracker *tracker, fw Firewall, logger Logger) (
|
timeout time.Duration, tracker *tracker, fw Firewall, logger Logger) (
|
||||||
dst netip.AddrPort, mss uint32, err error,
|
dst netip.AddrPort, mss uint32, err error,
|
||||||
@@ -30,6 +30,7 @@ func findHighestMSSDestination(ctx context.Context, fd fileDescriptor,
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
for _, dst := range dsts {
|
for _, dst := range dsts {
|
||||||
go func(dst netip.AddrPort) {
|
go func(dst netip.AddrPort) {
|
||||||
|
fd := familyToFD[ip.GetFamily(dst)]
|
||||||
mss, err := findMSS(ctx, fd, dst, excludeMark, tracker, fw, logger)
|
mss, err := findMSS(ctx, fd, dst, excludeMark, tracker, fw, logger)
|
||||||
resultCh <- result{dst: dst, mss: mss, err: err}
|
resultCh <- result{dst: dst, mss: mss, err: err}
|
||||||
}(dst)
|
}(dst)
|
||||||
|
|||||||
@@ -18,17 +18,16 @@ func Test_findHighestMSSDestination(t *testing.T) {
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
netlinker := netlink.New(&noopLogger{})
|
netlinker := netlink.New(&noopLogger{})
|
||||||
defaultIPv4MTU, err := findDefaultIPv4RouteMTU(netlinker)
|
defaultMTU, err := findDefaultRouteMTU(netlinker)
|
||||||
require.NoError(t, err, "finding default IPv4 route MTU")
|
require.NoError(t, err, "finding default route MTU")
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(t.Context())
|
ctx, cancel := context.WithCancel(t.Context())
|
||||||
|
|
||||||
const family = constants.AF_INET
|
families := []int{constants.AF_INET, constants.AF_INET6}
|
||||||
fd, stop, err := startRawSocket(family, excludeMark)
|
familyToFD, stop, err := startRawSockets(families, excludeMark)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
const ipv4 = true
|
tracker := newTracker(familyToFD)
|
||||||
tracker := newTracker(fd, ipv4)
|
|
||||||
trackerCh := make(chan error)
|
trackerCh := make(chan error)
|
||||||
go func() {
|
go func() {
|
||||||
trackerCh <- tracker.listen(ctx)
|
trackerCh <- tracker.listen(ctx)
|
||||||
@@ -44,13 +43,15 @@ func Test_findHighestMSSDestination(t *testing.T) {
|
|||||||
dsts := []netip.AddrPort{
|
dsts := []netip.AddrPort{
|
||||||
netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443),
|
netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443),
|
||||||
netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 443),
|
netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 443),
|
||||||
|
netip.AddrPortFrom(netip.MustParseAddr("2606:4700:4700::1111"), 443),
|
||||||
|
netip.AddrPortFrom(netip.MustParseAddr("2001:4860:4860::8888"), 443),
|
||||||
}
|
}
|
||||||
const timeout = time.Second
|
const timeout = time.Second
|
||||||
fw := getFirewall(t)
|
fw := getFirewall(t)
|
||||||
logger := &noopLogger{}
|
logger := &noopLogger{}
|
||||||
|
|
||||||
dst, mss, err := findHighestMSSDestination(t.Context(), fd, dsts,
|
dst, mss, err := findHighestMSSDestination(t.Context(), familyToFD, dsts,
|
||||||
excludeMark, defaultIPv4MTU, timeout, tracker, fw, logger)
|
excludeMark, defaultMTU, timeout, tracker, fw, logger)
|
||||||
require.NoError(t, err, "finding highest MSS destination")
|
require.NoError(t, err, "finding highest MSS destination")
|
||||||
assert.Contains(t, dsts, dst, "destination should be in the provided list")
|
assert.Contains(t, dsts, dst, "destination should be in the provided list")
|
||||||
assert.Greater(t, mss, uint32(1000), "MSS should be greater than 1000")
|
assert.Greater(t, mss, uint32(1000), "MSS should be greater than 1000")
|
||||||
|
|||||||
@@ -33,17 +33,14 @@ func PathMTUDiscover(ctx context.Context, dsts []netip.AddrPort,
|
|||||||
minMTU, maxPossibleMTU uint32, tryTimeout time.Duration,
|
minMTU, maxPossibleMTU uint32, tryTimeout time.Duration,
|
||||||
firewall Firewall, logger Logger,
|
firewall Firewall, logger Logger,
|
||||||
) (mtu uint32, err error) {
|
) (mtu uint32, err error) {
|
||||||
family := constants.AF_INET
|
families := ip.GetFamilies(dsts)
|
||||||
if dsts[0].Addr().Is6() {
|
familyToFD, stop, err := startRawSockets(families, excludeMark)
|
||||||
family = constants.AF_INET6
|
|
||||||
}
|
|
||||||
fd, stop, err := startRawSocket(family, excludeMark)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("starting raw socket: %w", err)
|
return 0, fmt.Errorf("starting raw sockets: %w", err)
|
||||||
}
|
}
|
||||||
defer stop()
|
defer stop()
|
||||||
|
|
||||||
tracker := newTracker(fd, family == constants.AF_INET)
|
tracker := newTracker(familyToFD)
|
||||||
|
|
||||||
trackerCtx, trackerCancel := context.WithCancel(ctx)
|
trackerCtx, trackerCancel := context.WithCancel(ctx)
|
||||||
defer trackerCancel()
|
defer trackerCancel()
|
||||||
@@ -62,7 +59,7 @@ func PathMTUDiscover(ctx context.Context, dsts []netip.AddrPort,
|
|||||||
mssCtx, mssCancel := context.WithTimeout(ctx, tryTimeout)
|
mssCtx, mssCancel := context.WithTimeout(ctx, tryTimeout)
|
||||||
defer mssCancel()
|
defer mssCancel()
|
||||||
go func() {
|
go func() {
|
||||||
dst, mss, err := findHighestMSSDestination(mssCtx, fd, dsts, excludeMark,
|
dst, mss, err := findHighestMSSDestination(mssCtx, familyToFD, dsts, excludeMark,
|
||||||
maxPossibleMTU, tryTimeout, tracker, firewall, logger)
|
maxPossibleMTU, tryTimeout, tracker, firewall, logger)
|
||||||
mssResultCh <- mssResult{dst: dst, mss: mss, err: err}
|
mssResultCh <- mssResult{dst: dst, mss: mss, err: err}
|
||||||
}()
|
}()
|
||||||
@@ -83,6 +80,8 @@ func PathMTUDiscover(ctx context.Context, dsts []netip.AddrPort,
|
|||||||
maxPossibleMTU = ipHeaderLength + constants.BaseTCPHeaderLength + result.mss
|
maxPossibleMTU = ipHeaderLength + constants.BaseTCPHeaderLength + result.mss
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fd := familyToFD[ip.GetFamily(highestMSSDst)]
|
||||||
|
|
||||||
type pmtudResult struct {
|
type pmtudResult struct {
|
||||||
mtu uint32
|
mtu uint32
|
||||||
err error
|
err error
|
||||||
|
|||||||
@@ -10,6 +10,29 @@ import (
|
|||||||
"github.com/qdm12/gluetun/internal/pmtud/ip"
|
"github.com/qdm12/gluetun/internal/pmtud/ip"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func startRawSockets(families []int, excludeMark int) (familyToSocket map[int]fileDescriptor, stop func(), err error) {
|
||||||
|
familyToSocket = make(map[int]fileDescriptor, len(families))
|
||||||
|
stops := make([]func(), 0, len(families))
|
||||||
|
for _, family := range families {
|
||||||
|
fd, stop, err := startRawSocket(family, excludeMark)
|
||||||
|
if err != nil {
|
||||||
|
for _, stop := range stops {
|
||||||
|
stop()
|
||||||
|
}
|
||||||
|
return nil, nil, fmt.Errorf("starting raw socket for family %d: %w", family, err)
|
||||||
|
}
|
||||||
|
stops = append(stops, stop)
|
||||||
|
familyToSocket[family] = fd
|
||||||
|
}
|
||||||
|
|
||||||
|
stop = func() {
|
||||||
|
for _, stop := range stops {
|
||||||
|
stop()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return familyToSocket, stop, nil
|
||||||
|
}
|
||||||
|
|
||||||
func startRawSocket(family, excludeMark int) (fd fileDescriptor, stop func(), err error) {
|
func startRawSocket(family, excludeMark int) (fd fileDescriptor, stop func(), err error) {
|
||||||
fdPlatform, err := socket(family, constants.SOCK_RAW, constants.IPPROTO_TCP)
|
fdPlatform, err := socket(family, constants.SOCK_RAW, constants.IPPROTO_TCP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -33,10 +56,10 @@ func startRawSocket(family, excludeMark int) (fd fileDescriptor, stop func(), er
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Allow sending packets larger than cached PMTU (for PMTUD probing)
|
// Allow sending packets larger than cached PMTU (for PMTUD probing)
|
||||||
err = setMTUDiscovery(fdPlatform)
|
err = setMTUDiscovery(fdPlatform, family == constants.AF_INET)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = closeSocket(fdPlatform)
|
_ = closeSocket(fdPlatform)
|
||||||
return 0, nil, fmt.Errorf("setting IP_MTU_DISCOVER: %w", err)
|
return 0, nil, fmt.Errorf("setting MTU discovery options: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// use polling because some Linux systems do not cancel
|
// use polling because some Linux systems do not cancel
|
||||||
|
|||||||
@@ -18,10 +18,21 @@ import (
|
|||||||
|
|
||||||
func Test_PathMTUDiscover(t *testing.T) {
|
func Test_PathMTUDiscover(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
noopLogger := log.New(log.SetLevel(log.LevelDebug))
|
|
||||||
|
const tryTimeout = time.Second
|
||||||
|
deadline, ok := t.Deadline()
|
||||||
|
if ok {
|
||||||
|
timeLeft := time.Until(deadline)
|
||||||
|
const maxTimeNeeded = tryTimeout * 4 // MSS discovery + 3 MTU tries
|
||||||
|
require.GreaterOrEqual(t, timeLeft, maxTimeNeeded,
|
||||||
|
"not enough time remaining for TCP PMTUD test, need %s and got %s",
|
||||||
|
maxTimeNeeded, timeLeft)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger := log.New(log.SetLevel(log.LevelDebug))
|
||||||
|
|
||||||
cmder := command.New()
|
cmder := command.New()
|
||||||
fw, err := firewall.NewConfig(t.Context(), noopLogger, cmder, nil, nil)
|
fw, err := firewall.NewConfig(t.Context(), logger, cmder, nil, nil)
|
||||||
if errors.Is(err, firewall.ErrIPTablesNotSupported) {
|
if errors.Is(err, firewall.ErrIPTablesNotSupported) {
|
||||||
t.Skip("iptables not installed, skipping TCP PMTUD tests")
|
t.Skip("iptables not installed, skipping TCP PMTUD tests")
|
||||||
}
|
}
|
||||||
@@ -32,11 +43,12 @@ func Test_PathMTUDiscover(t *testing.T) {
|
|||||||
netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443),
|
netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443),
|
||||||
netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 53),
|
netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 53),
|
||||||
netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 443),
|
netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 443),
|
||||||
|
netip.AddrPortFrom(netip.MustParseAddr("2606:4700:4700::1111"), 443),
|
||||||
|
netip.AddrPortFrom(netip.MustParseAddr("2001:4860:4860::8888"), 443),
|
||||||
}
|
}
|
||||||
const minMTU = constants.MinIPv6MTU
|
const minMTU = constants.MinIPv6MTU
|
||||||
const maxMTU = constants.MaxEthernetFrameSize
|
const maxMTU = constants.MaxEthernetFrameSize
|
||||||
const tryTimeout = time.Second
|
mtu, err := PathMTUDiscover(t.Context(), dsts, minMTU, maxMTU, tryTimeout, fw, logger)
|
||||||
mtu, err := PathMTUDiscover(t.Context(), dsts, minMTU, maxMTU, tryTimeout, fw, noopLogger)
|
|
||||||
require.NoError(t, err, "discovering path MTU")
|
require.NoError(t, err, "discovering path MTU")
|
||||||
assert.Greater(t, mtu, uint32(0), "MTU should be greater than 0")
|
assert.Greater(t, mtu, uint32(0), "MTU should be greater than 0")
|
||||||
t.Logf("discovered path MTU is %d", mtu)
|
t.Logf("discovered path MTU is %d", mtu)
|
||||||
|
|||||||
@@ -13,6 +13,9 @@ func setMark(fd, excludeMark int) error {
|
|||||||
return unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_MARK, excludeMark)
|
return unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_MARK, excludeMark)
|
||||||
}
|
}
|
||||||
|
|
||||||
func setMTUDiscovery(fd int) error {
|
func setMTUDiscovery(fd int, ipv4 bool) error {
|
||||||
|
if ipv4 {
|
||||||
return unix.SetsockoptInt(fd, unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_PROBE)
|
return unix.SetsockoptInt(fd, unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_PROBE)
|
||||||
}
|
}
|
||||||
|
return unix.SetsockoptInt(fd, unix.IPPROTO_IPV6, unix.IPV6_MTU_DISCOVER, unix.IPV6_PMTUDISC_PROBE)
|
||||||
|
}
|
||||||
|
|||||||
@@ -26,17 +26,15 @@ func Test_runTest(t *testing.T) {
|
|||||||
netlinker := netlink.New(noopLogger)
|
netlinker := netlink.New(noopLogger)
|
||||||
loopbackMTU, err := findLoopbackMTU(netlinker)
|
loopbackMTU, err := findLoopbackMTU(netlinker)
|
||||||
require.NoError(t, err, "finding loopback IPv4 MTU")
|
require.NoError(t, err, "finding loopback IPv4 MTU")
|
||||||
defaultIPv4MTU, err := findDefaultIPv4RouteMTU(netlinker)
|
defaultMTU, err := findDefaultRouteMTU(netlinker)
|
||||||
require.NoError(t, err, "finding default IPv4 route MTU")
|
require.NoError(t, err, "finding default route MTU")
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(t.Context())
|
ctx, cancel := context.WithCancel(t.Context())
|
||||||
|
|
||||||
const family = constants.AF_INET
|
familyToFD, stop, err := startRawSockets([]int{constants.AF_INET, constants.AF_INET6}, excludeMark)
|
||||||
fd, stop, err := startRawSocket(family, excludeMark)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
const ipv4 = true
|
tracker := newTracker(familyToFD)
|
||||||
tracker := newTracker(fd, ipv4)
|
|
||||||
trackerCh := make(chan error)
|
trackerCh := make(chan error)
|
||||||
go func() {
|
go func() {
|
||||||
trackerCh <- tracker.listen(ctx)
|
trackerCh <- tracker.listen(ctx)
|
||||||
@@ -71,24 +69,24 @@ func Test_runTest(t *testing.T) {
|
|||||||
"remote_not_listening": {
|
"remote_not_listening": {
|
||||||
timeout: 50 * time.Millisecond,
|
timeout: 50 * time.Millisecond,
|
||||||
server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 12345),
|
server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 12345),
|
||||||
mtu: defaultIPv4MTU - mtuSafetyBuffer,
|
mtu: defaultMTU - mtuSafetyBuffer,
|
||||||
},
|
},
|
||||||
"1.1.1.1:443": {
|
"1.1.1.1:443": {
|
||||||
timeout: 5 * time.Second,
|
timeout: 5 * time.Second,
|
||||||
server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443),
|
server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443),
|
||||||
mtu: defaultIPv4MTU - mtuSafetyBuffer,
|
mtu: defaultMTU - mtuSafetyBuffer,
|
||||||
success: true,
|
success: true,
|
||||||
},
|
},
|
||||||
"1.1.1.1:80": {
|
"1.1.1.1:80": {
|
||||||
timeout: 5 * time.Second,
|
timeout: 5 * time.Second,
|
||||||
server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 80),
|
server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 80),
|
||||||
mtu: defaultIPv4MTU - mtuSafetyBuffer,
|
mtu: defaultMTU - mtuSafetyBuffer,
|
||||||
success: true,
|
success: true,
|
||||||
},
|
},
|
||||||
"8.8.8.8:443": {
|
"8.8.8.8:443": {
|
||||||
timeout: 5 * time.Second,
|
timeout: 5 * time.Second,
|
||||||
server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 443),
|
server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 443),
|
||||||
mtu: defaultIPv4MTU - mtuSafetyBuffer,
|
mtu: defaultMTU - mtuSafetyBuffer,
|
||||||
success: true,
|
success: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -99,6 +97,7 @@ func Test_runTest(t *testing.T) {
|
|||||||
ctrl := gomock.NewController(t)
|
ctrl := gomock.NewController(t)
|
||||||
|
|
||||||
dst := testCase.server
|
dst := testCase.server
|
||||||
|
fd := familyToFD[ip.GetFamily(dst)]
|
||||||
|
|
||||||
const proto = constants.IPPROTO_TCP
|
const proto = constants.IPPROTO_TCP
|
||||||
src, cleanup, err := ip.SrcAddr(dst, proto)
|
src, cleanup, err := ip.SrcAddr(dst, proto)
|
||||||
|
|||||||
@@ -6,6 +6,6 @@ func setMark(fd, excludeMark int) error {
|
|||||||
panic("not implemented")
|
panic("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func setMTUDiscovery(fd int) error {
|
func setMTUDiscovery(fd int, ipv4 bool) error {
|
||||||
panic("not implemented")
|
panic("not implemented")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ func setMark(fd windows.Handle, _ int) error {
|
|||||||
panic("not implemented")
|
panic("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func setMTUDiscovery(fd windows.Handle) error {
|
func setMTUDiscovery(fd windows.Handle, ipv4 bool) error {
|
||||||
panic("not implemented")
|
panic("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -12,8 +12,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type tracker struct {
|
type tracker struct {
|
||||||
fd fileDescriptor
|
familyToFD map[int]fileDescriptor
|
||||||
ipv4 bool
|
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
portsToDispatch map[uint32]dispatch
|
portsToDispatch map[uint32]dispatch
|
||||||
}
|
}
|
||||||
@@ -23,10 +22,9 @@ type dispatch struct {
|
|||||||
abort <-chan struct{}
|
abort <-chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTracker(fd fileDescriptor, ipv4 bool) *tracker {
|
func newTracker(familyToFD map[int]fileDescriptor) *tracker {
|
||||||
return &tracker{
|
return &tracker{
|
||||||
fd: fd,
|
familyToFD: familyToFD,
|
||||||
ipv4: ipv4,
|
|
||||||
portsToDispatch: make(map[uint32]dispatch),
|
portsToDispatch: make(map[uint32]dispatch),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -57,11 +55,36 @@ func (t *tracker) unregister(localPort, remotePort uint16) {
|
|||||||
delete(t.portsToDispatch, key)
|
delete(t.portsToDispatch, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
// listen listens for incoming TCP packets and dispatches them to the
|
func (t *tracker) listen(ctx context.Context) (err error) {
|
||||||
// correct channel based on the source and destination port.
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
defer cancel()
|
||||||
|
type result struct {
|
||||||
|
family int
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
resultCh := make(chan result)
|
||||||
|
for family, fd := range t.familyToFD {
|
||||||
|
go func(family int, fd fileDescriptor) {
|
||||||
|
err := t.listenFD(ctx, fd, family == constants.AF_INET)
|
||||||
|
resultCh <- result{family: family, err: err}
|
||||||
|
}(family, fd)
|
||||||
|
}
|
||||||
|
|
||||||
|
for range t.familyToFD {
|
||||||
|
result := <-resultCh
|
||||||
|
if err == nil && result.err != nil {
|
||||||
|
cancel() // stop the other listener if it is still running
|
||||||
|
err = fmt.Errorf("listening for family %d: %w", result.family, result.err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// listenFD listens for incoming TCP packets on the given file descriptor,
|
||||||
|
// and dispatches them to the correct channel based on the source and destination port.
|
||||||
// If the context has a deadline associated, this one is used on the socket.
|
// If the context has a deadline associated, this one is used on the socket.
|
||||||
// Note it returns a nil error on context cancellation.
|
// Note it returns a nil error on context cancellation.
|
||||||
func (t *tracker) listen(ctx context.Context) error {
|
func (t *tracker) listenFD(ctx context.Context, fd fileDescriptor, ipv4 bool) error {
|
||||||
deadline, hasDeadline := ctx.Deadline()
|
deadline, hasDeadline := ctx.Deadline()
|
||||||
for ctx.Err() == nil {
|
for ctx.Err() == nil {
|
||||||
if hasDeadline {
|
if hasDeadline {
|
||||||
@@ -69,14 +92,14 @@ func (t *tracker) listen(ctx context.Context) error {
|
|||||||
if remaining <= 0 {
|
if remaining <= 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
err := setSocketTimeout(t.fd, remaining)
|
err := setSocketTimeout(fd, remaining)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("setting socket receive timeout: %w", err)
|
return fmt.Errorf("setting socket receive timeout: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
reply := make([]byte, constants.MaxEthernetFrameSize)
|
reply := make([]byte, constants.MaxEthernetFrameSize)
|
||||||
n, _, err := recvFrom(t.fd, reply, 0)
|
n, _, err := recvFrom(fd, reply, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
switch {
|
switch {
|
||||||
case errors.Is(err, constants.EAGAIN), errors.Is(err, constants.EWOULDBLOCK):
|
case errors.Is(err, constants.EAGAIN), errors.Is(err, constants.EWOULDBLOCK):
|
||||||
@@ -91,7 +114,7 @@ func (t *tracker) listen(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
reply = reply[:n]
|
reply = reply[:n]
|
||||||
|
|
||||||
if t.ipv4 {
|
if ipv4 {
|
||||||
var ok bool
|
var ok bool
|
||||||
reply, ok = stripIPv4Header(reply)
|
reply, ok = stripIPv4Header(reply)
|
||||||
if !ok {
|
if !ok {
|
||||||
|
|||||||
Reference in New Issue
Block a user