Files
gluetun/internal/pmtud/tcp/tcp_test.go
T

228 lines
6.1 KiB
Go

//go:build linux
package tcp
import (
"context"
"errors"
"fmt"
"net/netip"
"testing"
"time"
"github.com/qdm12/gluetun/internal/command"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/netlink"
"github.com/qdm12/gluetun/internal/pmtud/constants"
"github.com/qdm12/gluetun/internal/pmtud/ip"
"github.com/qdm12/gluetun/internal/routing"
"github.com/qdm12/log"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sys/unix"
)
func Test_runTest(t *testing.T) {
t.Parallel()
localNonListenPort := reserveClosedPort(t)
noopLogger := &noopLogger{}
cmder := command.New()
fw, err := firewall.NewConfig(t.Context(), noopLogger, cmder, nil, nil)
if errors.Is(err, firewall.ErrIPTablesNotSupported) {
t.Skip("iptables not installed, skipping TCP PMTUD tests")
}
require.NoError(t, err, "creating firewall config")
netlinker := netlink.New(noopLogger)
loopbackMTU, err := findLoopbackMTU(netlinker)
require.NoError(t, err, "finding loopback IPv4 MTU")
defaultIPv4MTU, err := findDefaultIPv4RouteMTU(netlinker)
require.NoError(t, err, "finding default IPv4 route MTU")
ctx, cancel := context.WithCancel(t.Context())
const family = constants.AF_INET
const excludeMark = 4545
fd, stop, err := startRawSocket(family, excludeMark)
require.NoError(t, err)
const ipv4 = true
tracker := newTracker(fd, ipv4)
trackerCh := make(chan error)
go func() {
trackerCh <- tracker.listen(ctx)
}()
// Our local ethernet MTU could be 1500, and the server could advertise
// an MSS of 1400, but the real link to the server could have an MTU of 1300,
// so we need to adjust our test so it passes. We are not actually path MTU
// discovering here, just testing that we can receive the expected TCP packets
// for a given MTU.
const mtuSafetyBuffer = 200
t.Cleanup(func() {
stop()
cancel() // stop listening
err = <-trackerCh
require.NoError(t, err)
})
testCases := map[string]struct {
timeout time.Duration
server netip.AddrPort
mtu uint32
success bool
}{
"local_not_listening": {
timeout: time.Hour,
server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), localNonListenPort),
mtu: loopbackMTU,
success: true,
},
"remote_not_listening": {
timeout: 50 * time.Millisecond,
server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 12345),
mtu: defaultIPv4MTU - mtuSafetyBuffer,
},
"1.1.1.1:443": {
timeout: 5 * time.Second,
server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443),
mtu: defaultIPv4MTU - mtuSafetyBuffer,
success: true,
},
"1.1.1.1:80": {
timeout: 5 * time.Second,
server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 80),
mtu: defaultIPv4MTU - mtuSafetyBuffer,
success: true,
},
"8.8.8.8:443": {
timeout: 5 * time.Second,
server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 443),
mtu: defaultIPv4MTU - mtuSafetyBuffer,
success: true,
},
}
for name, testCase := range testCases {
t.Run(name, func(t *testing.T) {
t.Parallel()
dst := testCase.server
const proto = constants.IPPROTO_TCP
src, cleanup, err := ip.SrcAddr(dst, proto)
require.NoError(t, err, "getting source address to reach remote server %s", dst)
t.Cleanup(cleanup)
revert, err := fw.TempDropOutputTCPRST(t.Context(), src, dst, excludeMark)
require.NoError(t, err)
t.Cleanup(func() {
err := revert(context.Background())
assert.NoError(t, err)
})
ctx, cancel := context.WithTimeout(t.Context(), testCase.timeout)
defer cancel()
err = runTest(ctx, dst, testCase.mtu, excludeMark,
fd, tracker, fw, noopLogger)
if testCase.success {
require.NoError(t, err)
} else {
require.Error(t, err)
}
})
}
}
var errRouteNotFound = errors.New("route not found")
func findLoopbackMTU(netlinker *netlink.NetLink) (mtu uint32, err error) {
routes, err := netlinker.RouteList(netlink.FamilyV4)
if err != nil {
return 0, fmt.Errorf("getting routes list: %w", err)
}
for _, route := range routes {
if route.Dst.IsValid() && route.Dst.Addr().IsLoopback() {
link, err := netlinker.LinkByIndex(route.LinkIndex)
if err != nil {
return 0, fmt.Errorf("getting link by index: %w", err)
}
// Quirk: make sure it is maximum 65535, and not i.e. 65536
// or the IP header 16 bits will fail to fit that packet length value.
const maxMTU = 65535
return min(link.MTU, maxMTU), nil
}
}
return 0, fmt.Errorf("%w: no loopback route found", errRouteNotFound)
}
func findDefaultIPv4RouteMTU(netlinker *netlink.NetLink) (mtu uint32, err error) {
noopLogger := &noopLogger{}
routing := routing.New(netlinker, noopLogger)
defaultRoutes, err := routing.DefaultRoutes()
if err != nil {
return 0, fmt.Errorf("getting default routes: %w", err)
}
for _, route := range defaultRoutes {
if route.Family != netlink.FamilyV4 {
continue
}
link, err := netlinker.LinkByName(defaultRoutes[0].NetInterface)
if err != nil {
return 0, fmt.Errorf("getting link by name: %w", err)
}
return link.MTU, nil
}
return 0, fmt.Errorf("%w: no default route found", errRouteNotFound)
}
func reserveClosedPort(t *testing.T) (port uint16) {
t.Helper()
fd, err := unix.Socket(constants.AF_INET, constants.SOCK_STREAM, constants.IPPROTO_TCP)
require.NoError(t, err)
t.Cleanup(func() {
err := unix.Close(fd)
assert.NoError(t, err)
})
addr := &unix.SockaddrInet4{
Port: 0,
Addr: [4]byte{127, 0, 0, 1},
}
err = unix.Bind(fd, addr)
if err != nil {
_ = unix.Close(fd)
t.Fatal(err)
}
sockAddr, err := unix.Getsockname(fd)
if err != nil {
_ = unix.Close(fd)
t.Fatal(err)
}
sockAddr4, ok := sockAddr.(*unix.SockaddrInet4)
if !ok {
_ = unix.Close(fd)
t.Fatal("not an IPv4 address")
}
return uint16(sockAddr4.Port) //nolint:gosec
}
type noopLogger struct{}
func (l *noopLogger) Patch(_ ...log.Option) {}
func (l *noopLogger) Debug(_ string) {}
func (l *noopLogger) Debugf(_ string, _ ...any) {}
func (l *noopLogger) Info(_ string) {}
func (l *noopLogger) Warn(_ string) {}
func (l *noopLogger) Warnf(_ string, _ ...any) {}
func (l *noopLogger) Error(_ string) {}