mirror of
https://github.com/qdm12/gluetun.git
synced 2026-05-07 04:20:12 +02:00
feat(pmtud/tcp): support mixed IPv4 and IPv6 TCP servers
- Add default cloudflare and google tls ipv6 servers to default tcp servers - update integration test to try against both ipv4 and ipv6 servers
This commit is contained in:
@@ -0,0 +1,27 @@
|
||||
package ip
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"slices"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||
)
|
||||
|
||||
func GetFamilies(dsts []netip.AddrPort) (families []int) {
|
||||
const maxFamilies = 2
|
||||
families = make([]int, 0, maxFamilies)
|
||||
for _, dst := range dsts {
|
||||
family := GetFamily(dst)
|
||||
if !slices.Contains(families, family) {
|
||||
families = append(families, family)
|
||||
}
|
||||
}
|
||||
return families
|
||||
}
|
||||
|
||||
func GetFamily(dst netip.AddrPort) int {
|
||||
if dst.Addr().Is4() {
|
||||
return constants.AF_INET
|
||||
}
|
||||
return constants.AF_INET6
|
||||
}
|
||||
@@ -78,24 +78,30 @@ func findLoopbackMTU(netlinker *netlink.NetLink) (mtu uint32, err error) {
|
||||
return 0, fmt.Errorf("%w: no loopback route found", errRouteNotFound)
|
||||
}
|
||||
|
||||
func findDefaultIPv4RouteMTU(netlinker *netlink.NetLink) (mtu uint32, err error) {
|
||||
func findDefaultRouteMTU(netlinker *netlink.NetLink) (mtu uint32, err error) {
|
||||
noopLogger := &noopLogger{}
|
||||
routing := routing.New(netlinker, noopLogger)
|
||||
defaultRoutes, err := routing.DefaultRoutes()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("getting default routes: %w", err)
|
||||
}
|
||||
for _, route := range defaultRoutes {
|
||||
if route.Family != netlink.FamilyV4 {
|
||||
continue
|
||||
families := []uint8{constants.AF_INET, constants.AF_INET6}
|
||||
for _, family := range families {
|
||||
for _, route := range defaultRoutes {
|
||||
if route.Family != family {
|
||||
continue
|
||||
}
|
||||
link, err := netlinker.LinkByName(route.NetInterface)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("getting link by name: %w", err)
|
||||
}
|
||||
mtu = max(mtu, link.MTU)
|
||||
}
|
||||
link, err := netlinker.LinkByName(defaultRoutes[0].NetInterface)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("getting link by name: %w", err)
|
||||
}
|
||||
return link.MTU, nil
|
||||
}
|
||||
return 0, fmt.Errorf("%w: no default route found", errRouteNotFound)
|
||||
if mtu == 0 {
|
||||
return 0, fmt.Errorf("%w: no default route found", errRouteNotFound)
|
||||
}
|
||||
return mtu, nil
|
||||
}
|
||||
|
||||
func reserveClosedPort(t *testing.T) (port uint16) {
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
|
||||
// findHighestMSSDestination finds the destination with the highest
|
||||
// MSS amongst the provided destinations.
|
||||
func findHighestMSSDestination(ctx context.Context, fd fileDescriptor,
|
||||
func findHighestMSSDestination(ctx context.Context, familyToFD map[int]fileDescriptor,
|
||||
dsts []netip.AddrPort, excludeMark int, maxPossibleMTU uint32,
|
||||
timeout time.Duration, tracker *tracker, fw Firewall, logger Logger) (
|
||||
dst netip.AddrPort, mss uint32, err error,
|
||||
@@ -30,6 +30,7 @@ func findHighestMSSDestination(ctx context.Context, fd fileDescriptor,
|
||||
defer cancel()
|
||||
for _, dst := range dsts {
|
||||
go func(dst netip.AddrPort) {
|
||||
fd := familyToFD[ip.GetFamily(dst)]
|
||||
mss, err := findMSS(ctx, fd, dst, excludeMark, tracker, fw, logger)
|
||||
resultCh <- result{dst: dst, mss: mss, err: err}
|
||||
}(dst)
|
||||
|
||||
@@ -18,17 +18,16 @@ func Test_findHighestMSSDestination(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
netlinker := netlink.New(&noopLogger{})
|
||||
defaultIPv4MTU, err := findDefaultIPv4RouteMTU(netlinker)
|
||||
require.NoError(t, err, "finding default IPv4 route MTU")
|
||||
defaultMTU, err := findDefaultRouteMTU(netlinker)
|
||||
require.NoError(t, err, "finding default route MTU")
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
|
||||
const family = constants.AF_INET
|
||||
fd, stop, err := startRawSocket(family, excludeMark)
|
||||
families := []int{constants.AF_INET, constants.AF_INET6}
|
||||
familyToFD, stop, err := startRawSockets(families, excludeMark)
|
||||
require.NoError(t, err)
|
||||
|
||||
const ipv4 = true
|
||||
tracker := newTracker(fd, ipv4)
|
||||
tracker := newTracker(familyToFD)
|
||||
trackerCh := make(chan error)
|
||||
go func() {
|
||||
trackerCh <- tracker.listen(ctx)
|
||||
@@ -44,13 +43,15 @@ func Test_findHighestMSSDestination(t *testing.T) {
|
||||
dsts := []netip.AddrPort{
|
||||
netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443),
|
||||
netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 443),
|
||||
netip.AddrPortFrom(netip.MustParseAddr("2606:4700:4700::1111"), 443),
|
||||
netip.AddrPortFrom(netip.MustParseAddr("2001:4860:4860::8888"), 443),
|
||||
}
|
||||
const timeout = time.Second
|
||||
fw := getFirewall(t)
|
||||
logger := &noopLogger{}
|
||||
|
||||
dst, mss, err := findHighestMSSDestination(t.Context(), fd, dsts,
|
||||
excludeMark, defaultIPv4MTU, timeout, tracker, fw, logger)
|
||||
dst, mss, err := findHighestMSSDestination(t.Context(), familyToFD, dsts,
|
||||
excludeMark, defaultMTU, timeout, tracker, fw, logger)
|
||||
require.NoError(t, err, "finding highest MSS destination")
|
||||
assert.Contains(t, dsts, dst, "destination should be in the provided list")
|
||||
assert.Greater(t, mss, uint32(1000), "MSS should be greater than 1000")
|
||||
|
||||
@@ -33,17 +33,14 @@ func PathMTUDiscover(ctx context.Context, dsts []netip.AddrPort,
|
||||
minMTU, maxPossibleMTU uint32, tryTimeout time.Duration,
|
||||
firewall Firewall, logger Logger,
|
||||
) (mtu uint32, err error) {
|
||||
family := constants.AF_INET
|
||||
if dsts[0].Addr().Is6() {
|
||||
family = constants.AF_INET6
|
||||
}
|
||||
fd, stop, err := startRawSocket(family, excludeMark)
|
||||
families := ip.GetFamilies(dsts)
|
||||
familyToFD, stop, err := startRawSockets(families, excludeMark)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("starting raw socket: %w", err)
|
||||
return 0, fmt.Errorf("starting raw sockets: %w", err)
|
||||
}
|
||||
defer stop()
|
||||
|
||||
tracker := newTracker(fd, family == constants.AF_INET)
|
||||
tracker := newTracker(familyToFD)
|
||||
|
||||
trackerCtx, trackerCancel := context.WithCancel(ctx)
|
||||
defer trackerCancel()
|
||||
@@ -62,7 +59,7 @@ func PathMTUDiscover(ctx context.Context, dsts []netip.AddrPort,
|
||||
mssCtx, mssCancel := context.WithTimeout(ctx, tryTimeout)
|
||||
defer mssCancel()
|
||||
go func() {
|
||||
dst, mss, err := findHighestMSSDestination(mssCtx, fd, dsts, excludeMark,
|
||||
dst, mss, err := findHighestMSSDestination(mssCtx, familyToFD, dsts, excludeMark,
|
||||
maxPossibleMTU, tryTimeout, tracker, firewall, logger)
|
||||
mssResultCh <- mssResult{dst: dst, mss: mss, err: err}
|
||||
}()
|
||||
@@ -83,6 +80,8 @@ func PathMTUDiscover(ctx context.Context, dsts []netip.AddrPort,
|
||||
maxPossibleMTU = ipHeaderLength + constants.BaseTCPHeaderLength + result.mss
|
||||
}
|
||||
|
||||
fd := familyToFD[ip.GetFamily(highestMSSDst)]
|
||||
|
||||
type pmtudResult struct {
|
||||
mtu uint32
|
||||
err error
|
||||
|
||||
@@ -10,6 +10,29 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/pmtud/ip"
|
||||
)
|
||||
|
||||
func startRawSockets(families []int, excludeMark int) (familyToSocket map[int]fileDescriptor, stop func(), err error) {
|
||||
familyToSocket = make(map[int]fileDescriptor, len(families))
|
||||
stops := make([]func(), 0, len(families))
|
||||
for _, family := range families {
|
||||
fd, stop, err := startRawSocket(family, excludeMark)
|
||||
if err != nil {
|
||||
for _, stop := range stops {
|
||||
stop()
|
||||
}
|
||||
return nil, nil, fmt.Errorf("starting raw socket for family %d: %w", family, err)
|
||||
}
|
||||
stops = append(stops, stop)
|
||||
familyToSocket[family] = fd
|
||||
}
|
||||
|
||||
stop = func() {
|
||||
for _, stop := range stops {
|
||||
stop()
|
||||
}
|
||||
}
|
||||
return familyToSocket, stop, nil
|
||||
}
|
||||
|
||||
func startRawSocket(family, excludeMark int) (fd fileDescriptor, stop func(), err error) {
|
||||
fdPlatform, err := socket(family, constants.SOCK_RAW, constants.IPPROTO_TCP)
|
||||
if err != nil {
|
||||
@@ -33,10 +56,10 @@ func startRawSocket(family, excludeMark int) (fd fileDescriptor, stop func(), er
|
||||
}
|
||||
|
||||
// Allow sending packets larger than cached PMTU (for PMTUD probing)
|
||||
err = setMTUDiscovery(fdPlatform)
|
||||
err = setMTUDiscovery(fdPlatform, family == constants.AF_INET)
|
||||
if err != nil {
|
||||
_ = closeSocket(fdPlatform)
|
||||
return 0, nil, fmt.Errorf("setting IP_MTU_DISCOVER: %w", err)
|
||||
return 0, nil, fmt.Errorf("setting MTU discovery options: %w", err)
|
||||
}
|
||||
|
||||
// use polling because some Linux systems do not cancel
|
||||
|
||||
@@ -18,10 +18,21 @@ import (
|
||||
|
||||
func Test_PathMTUDiscover(t *testing.T) {
|
||||
t.Parallel()
|
||||
noopLogger := log.New(log.SetLevel(log.LevelDebug))
|
||||
|
||||
const tryTimeout = time.Second
|
||||
deadline, ok := t.Deadline()
|
||||
if ok {
|
||||
timeLeft := time.Until(deadline)
|
||||
const maxTimeNeeded = tryTimeout * 4 // MSS discovery + 3 MTU tries
|
||||
require.GreaterOrEqual(t, timeLeft, maxTimeNeeded,
|
||||
"not enough time remaining for TCP PMTUD test, need %s and got %s",
|
||||
maxTimeNeeded, timeLeft)
|
||||
}
|
||||
|
||||
logger := log.New(log.SetLevel(log.LevelDebug))
|
||||
|
||||
cmder := command.New()
|
||||
fw, err := firewall.NewConfig(t.Context(), noopLogger, cmder, nil, nil)
|
||||
fw, err := firewall.NewConfig(t.Context(), logger, cmder, nil, nil)
|
||||
if errors.Is(err, firewall.ErrIPTablesNotSupported) {
|
||||
t.Skip("iptables not installed, skipping TCP PMTUD tests")
|
||||
}
|
||||
@@ -32,11 +43,12 @@ func Test_PathMTUDiscover(t *testing.T) {
|
||||
netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443),
|
||||
netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 53),
|
||||
netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 443),
|
||||
netip.AddrPortFrom(netip.MustParseAddr("2606:4700:4700::1111"), 443),
|
||||
netip.AddrPortFrom(netip.MustParseAddr("2001:4860:4860::8888"), 443),
|
||||
}
|
||||
const minMTU = constants.MinIPv6MTU
|
||||
const maxMTU = constants.MaxEthernetFrameSize
|
||||
const tryTimeout = time.Second
|
||||
mtu, err := PathMTUDiscover(t.Context(), dsts, minMTU, maxMTU, tryTimeout, fw, noopLogger)
|
||||
mtu, err := PathMTUDiscover(t.Context(), dsts, minMTU, maxMTU, tryTimeout, fw, logger)
|
||||
require.NoError(t, err, "discovering path MTU")
|
||||
assert.Greater(t, mtu, uint32(0), "MTU should be greater than 0")
|
||||
t.Logf("discovered path MTU is %d", mtu)
|
||||
|
||||
@@ -13,6 +13,9 @@ func setMark(fd, excludeMark int) error {
|
||||
return unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_MARK, excludeMark)
|
||||
}
|
||||
|
||||
func setMTUDiscovery(fd int) error {
|
||||
return unix.SetsockoptInt(fd, unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_PROBE)
|
||||
func setMTUDiscovery(fd int, ipv4 bool) error {
|
||||
if ipv4 {
|
||||
return unix.SetsockoptInt(fd, unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_PROBE)
|
||||
}
|
||||
return unix.SetsockoptInt(fd, unix.IPPROTO_IPV6, unix.IPV6_MTU_DISCOVER, unix.IPV6_PMTUDISC_PROBE)
|
||||
}
|
||||
|
||||
@@ -26,17 +26,15 @@ func Test_runTest(t *testing.T) {
|
||||
netlinker := netlink.New(noopLogger)
|
||||
loopbackMTU, err := findLoopbackMTU(netlinker)
|
||||
require.NoError(t, err, "finding loopback IPv4 MTU")
|
||||
defaultIPv4MTU, err := findDefaultIPv4RouteMTU(netlinker)
|
||||
require.NoError(t, err, "finding default IPv4 route MTU")
|
||||
defaultMTU, err := findDefaultRouteMTU(netlinker)
|
||||
require.NoError(t, err, "finding default route MTU")
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
|
||||
const family = constants.AF_INET
|
||||
fd, stop, err := startRawSocket(family, excludeMark)
|
||||
familyToFD, stop, err := startRawSockets([]int{constants.AF_INET, constants.AF_INET6}, excludeMark)
|
||||
require.NoError(t, err)
|
||||
|
||||
const ipv4 = true
|
||||
tracker := newTracker(fd, ipv4)
|
||||
tracker := newTracker(familyToFD)
|
||||
trackerCh := make(chan error)
|
||||
go func() {
|
||||
trackerCh <- tracker.listen(ctx)
|
||||
@@ -71,24 +69,24 @@ func Test_runTest(t *testing.T) {
|
||||
"remote_not_listening": {
|
||||
timeout: 50 * time.Millisecond,
|
||||
server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 12345),
|
||||
mtu: defaultIPv4MTU - mtuSafetyBuffer,
|
||||
mtu: defaultMTU - mtuSafetyBuffer,
|
||||
},
|
||||
"1.1.1.1:443": {
|
||||
timeout: 5 * time.Second,
|
||||
server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443),
|
||||
mtu: defaultIPv4MTU - mtuSafetyBuffer,
|
||||
mtu: defaultMTU - mtuSafetyBuffer,
|
||||
success: true,
|
||||
},
|
||||
"1.1.1.1:80": {
|
||||
timeout: 5 * time.Second,
|
||||
server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 80),
|
||||
mtu: defaultIPv4MTU - mtuSafetyBuffer,
|
||||
mtu: defaultMTU - mtuSafetyBuffer,
|
||||
success: true,
|
||||
},
|
||||
"8.8.8.8:443": {
|
||||
timeout: 5 * time.Second,
|
||||
server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 443),
|
||||
mtu: defaultIPv4MTU - mtuSafetyBuffer,
|
||||
mtu: defaultMTU - mtuSafetyBuffer,
|
||||
success: true,
|
||||
},
|
||||
}
|
||||
@@ -99,6 +97,7 @@ func Test_runTest(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
dst := testCase.server
|
||||
fd := familyToFD[ip.GetFamily(dst)]
|
||||
|
||||
const proto = constants.IPPROTO_TCP
|
||||
src, cleanup, err := ip.SrcAddr(dst, proto)
|
||||
|
||||
@@ -6,6 +6,6 @@ func setMark(fd, excludeMark int) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func setMTUDiscovery(fd int) error {
|
||||
func setMTUDiscovery(fd int, ipv4 bool) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
@@ -35,7 +35,7 @@ func setMark(fd windows.Handle, _ int) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func setMTUDiscovery(fd windows.Handle) error {
|
||||
func setMTUDiscovery(fd windows.Handle, ipv4 bool) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
|
||||
@@ -12,8 +12,7 @@ import (
|
||||
)
|
||||
|
||||
type tracker struct {
|
||||
fd fileDescriptor
|
||||
ipv4 bool
|
||||
familyToFD map[int]fileDescriptor
|
||||
mutex sync.RWMutex
|
||||
portsToDispatch map[uint32]dispatch
|
||||
}
|
||||
@@ -23,10 +22,9 @@ type dispatch struct {
|
||||
abort <-chan struct{}
|
||||
}
|
||||
|
||||
func newTracker(fd fileDescriptor, ipv4 bool) *tracker {
|
||||
func newTracker(familyToFD map[int]fileDescriptor) *tracker {
|
||||
return &tracker{
|
||||
fd: fd,
|
||||
ipv4: ipv4,
|
||||
familyToFD: familyToFD,
|
||||
portsToDispatch: make(map[uint32]dispatch),
|
||||
}
|
||||
}
|
||||
@@ -57,11 +55,36 @@ func (t *tracker) unregister(localPort, remotePort uint16) {
|
||||
delete(t.portsToDispatch, key)
|
||||
}
|
||||
|
||||
// listen listens for incoming TCP packets and dispatches them to the
|
||||
// correct channel based on the source and destination port.
|
||||
func (t *tracker) listen(ctx context.Context) (err error) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
type result struct {
|
||||
family int
|
||||
err error
|
||||
}
|
||||
resultCh := make(chan result)
|
||||
for family, fd := range t.familyToFD {
|
||||
go func(family int, fd fileDescriptor) {
|
||||
err := t.listenFD(ctx, fd, family == constants.AF_INET)
|
||||
resultCh <- result{family: family, err: err}
|
||||
}(family, fd)
|
||||
}
|
||||
|
||||
for range t.familyToFD {
|
||||
result := <-resultCh
|
||||
if err == nil && result.err != nil {
|
||||
cancel() // stop the other listener if it is still running
|
||||
err = fmt.Errorf("listening for family %d: %w", result.family, result.err)
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// listenFD listens for incoming TCP packets on the given file descriptor,
|
||||
// and dispatches them to the correct channel based on the source and destination port.
|
||||
// If the context has a deadline associated, this one is used on the socket.
|
||||
// Note it returns a nil error on context cancellation.
|
||||
func (t *tracker) listen(ctx context.Context) error {
|
||||
func (t *tracker) listenFD(ctx context.Context, fd fileDescriptor, ipv4 bool) error {
|
||||
deadline, hasDeadline := ctx.Deadline()
|
||||
for ctx.Err() == nil {
|
||||
if hasDeadline {
|
||||
@@ -69,14 +92,14 @@ func (t *tracker) listen(ctx context.Context) error {
|
||||
if remaining <= 0 {
|
||||
return nil
|
||||
}
|
||||
err := setSocketTimeout(t.fd, remaining)
|
||||
err := setSocketTimeout(fd, remaining)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting socket receive timeout: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
reply := make([]byte, constants.MaxEthernetFrameSize)
|
||||
n, _, err := recvFrom(t.fd, reply, 0)
|
||||
n, _, err := recvFrom(fd, reply, 0)
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, constants.EAGAIN), errors.Is(err, constants.EWOULDBLOCK):
|
||||
@@ -91,7 +114,7 @@ func (t *tracker) listen(ctx context.Context) error {
|
||||
}
|
||||
reply = reply[:n]
|
||||
|
||||
if t.ipv4 {
|
||||
if ipv4 {
|
||||
var ok bool
|
||||
reply, ok = stripIPv4Header(reply)
|
||||
if !ok {
|
||||
|
||||
Reference in New Issue
Block a user