From fb85ae79d1bcb050f24a6aaac0dfb03dc3eaf788 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Thu, 19 Feb 2026 13:07:15 +0000 Subject: [PATCH] chore(pmtud/tcp): move test helpers in helpers_test.go --- internal/pmtud/tcp/helpers_test.go | 103 +++++++++++++++++++++++++++++ internal/pmtud/tcp/tcp_test.go | 97 ++------------------------- 2 files changed, 107 insertions(+), 93 deletions(-) create mode 100644 internal/pmtud/tcp/helpers_test.go diff --git a/internal/pmtud/tcp/helpers_test.go b/internal/pmtud/tcp/helpers_test.go new file mode 100644 index 00000000..53d4fbb3 --- /dev/null +++ b/internal/pmtud/tcp/helpers_test.go @@ -0,0 +1,103 @@ +package tcp + +import ( + "errors" + "fmt" + "testing" + + "github.com/qdm12/gluetun/internal/netlink" + "github.com/qdm12/gluetun/internal/pmtud/constants" + "github.com/qdm12/gluetun/internal/routing" + "github.com/qdm12/log" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sys/unix" +) + +type noopLogger struct{} + +func (l *noopLogger) Patch(_ ...log.Option) {} +func (l *noopLogger) Debug(_ string) {} +func (l *noopLogger) Debugf(_ string, _ ...any) {} +func (l *noopLogger) Info(_ string) {} +func (l *noopLogger) Warn(_ string) {} +func (l *noopLogger) Warnf(_ string, _ ...any) {} +func (l *noopLogger) Error(_ string) {} + +var errRouteNotFound = errors.New("route not found") + +func findLoopbackMTU(netlinker *netlink.NetLink) (mtu uint32, err error) { + routes, err := netlinker.RouteList(netlink.FamilyV4) + if err != nil { + return 0, fmt.Errorf("getting routes list: %w", err) + } + for _, route := range routes { + if route.Dst.IsValid() && route.Dst.Addr().IsLoopback() { + link, err := netlinker.LinkByIndex(route.LinkIndex) + if err != nil { + return 0, fmt.Errorf("getting link by index: %w", err) + } + // Quirk: make sure it is maximum 65535, and not i.e. 65536 + // or the IP header 16 bits will fail to fit that packet length value. + const maxMTU = 65535 + return min(link.MTU, maxMTU), nil + } + } + return 0, fmt.Errorf("%w: no loopback route found", errRouteNotFound) +} + +func findDefaultIPv4RouteMTU(netlinker *netlink.NetLink) (mtu uint32, err error) { + noopLogger := &noopLogger{} + routing := routing.New(netlinker, noopLogger) + defaultRoutes, err := routing.DefaultRoutes() + if err != nil { + return 0, fmt.Errorf("getting default routes: %w", err) + } + for _, route := range defaultRoutes { + if route.Family != netlink.FamilyV4 { + continue + } + link, err := netlinker.LinkByName(defaultRoutes[0].NetInterface) + if err != nil { + return 0, fmt.Errorf("getting link by name: %w", err) + } + return link.MTU, nil + } + return 0, fmt.Errorf("%w: no default route found", errRouteNotFound) +} + +func reserveClosedPort(t *testing.T) (port uint16) { + t.Helper() + + fd, err := unix.Socket(constants.AF_INET, constants.SOCK_STREAM, constants.IPPROTO_TCP) + require.NoError(t, err) + t.Cleanup(func() { + err := unix.Close(fd) + assert.NoError(t, err) + }) + + addr := &unix.SockaddrInet4{ + Port: 0, + Addr: [4]byte{127, 0, 0, 1}, + } + + err = unix.Bind(fd, addr) + if err != nil { + _ = unix.Close(fd) + t.Fatal(err) + } + + sockAddr, err := unix.Getsockname(fd) + if err != nil { + _ = unix.Close(fd) + t.Fatal(err) + } + + sockAddr4, ok := sockAddr.(*unix.SockaddrInet4) + if !ok { + _ = unix.Close(fd) + t.Fatal("not an IPv4 address") + } + + return uint16(sockAddr4.Port) //nolint:gosec +} diff --git a/internal/pmtud/tcp/tcp_test.go b/internal/pmtud/tcp/tcp_test.go index 948965a0..5947436c 100644 --- a/internal/pmtud/tcp/tcp_test.go +++ b/internal/pmtud/tcp/tcp_test.go @@ -5,21 +5,18 @@ package tcp import ( "context" "errors" - "fmt" "net/netip" "testing" "time" + gomock "github.com/golang/mock/gomock" "github.com/qdm12/gluetun/internal/command" "github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/netlink" "github.com/qdm12/gluetun/internal/pmtud/constants" "github.com/qdm12/gluetun/internal/pmtud/ip" - "github.com/qdm12/gluetun/internal/routing" - "github.com/qdm12/log" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "golang.org/x/sys/unix" ) func Test_runTest(t *testing.T) { @@ -110,6 +107,7 @@ func Test_runTest(t *testing.T) { for name, testCase := range testCases { t.Run(name, func(t *testing.T) { t.Parallel() + ctrl := gomock.NewController(t) dst := testCase.server @@ -124,11 +122,12 @@ func Test_runTest(t *testing.T) { err := revert(context.Background()) assert.NoError(t, err) }) + logger := NewMockLogger(ctrl) ctx, cancel := context.WithTimeout(t.Context(), testCase.timeout) defer cancel() err = runTest(ctx, dst, testCase.mtu, excludeMark, - fd, tracker, fw, noopLogger) + fd, tracker, fw, logger) if testCase.success { require.NoError(t, err) } else { @@ -137,91 +136,3 @@ func Test_runTest(t *testing.T) { }) } } - -var errRouteNotFound = errors.New("route not found") - -func findLoopbackMTU(netlinker *netlink.NetLink) (mtu uint32, err error) { - routes, err := netlinker.RouteList(netlink.FamilyV4) - if err != nil { - return 0, fmt.Errorf("getting routes list: %w", err) - } - for _, route := range routes { - if route.Dst.IsValid() && route.Dst.Addr().IsLoopback() { - link, err := netlinker.LinkByIndex(route.LinkIndex) - if err != nil { - return 0, fmt.Errorf("getting link by index: %w", err) - } - // Quirk: make sure it is maximum 65535, and not i.e. 65536 - // or the IP header 16 bits will fail to fit that packet length value. - const maxMTU = 65535 - return min(link.MTU, maxMTU), nil - } - } - return 0, fmt.Errorf("%w: no loopback route found", errRouteNotFound) -} - -func findDefaultIPv4RouteMTU(netlinker *netlink.NetLink) (mtu uint32, err error) { - noopLogger := &noopLogger{} - routing := routing.New(netlinker, noopLogger) - defaultRoutes, err := routing.DefaultRoutes() - if err != nil { - return 0, fmt.Errorf("getting default routes: %w", err) - } - for _, route := range defaultRoutes { - if route.Family != netlink.FamilyV4 { - continue - } - link, err := netlinker.LinkByName(defaultRoutes[0].NetInterface) - if err != nil { - return 0, fmt.Errorf("getting link by name: %w", err) - } - return link.MTU, nil - } - return 0, fmt.Errorf("%w: no default route found", errRouteNotFound) -} - -func reserveClosedPort(t *testing.T) (port uint16) { - t.Helper() - - fd, err := unix.Socket(constants.AF_INET, constants.SOCK_STREAM, constants.IPPROTO_TCP) - require.NoError(t, err) - t.Cleanup(func() { - err := unix.Close(fd) - assert.NoError(t, err) - }) - - addr := &unix.SockaddrInet4{ - Port: 0, - Addr: [4]byte{127, 0, 0, 1}, - } - - err = unix.Bind(fd, addr) - if err != nil { - _ = unix.Close(fd) - t.Fatal(err) - } - - sockAddr, err := unix.Getsockname(fd) - if err != nil { - _ = unix.Close(fd) - t.Fatal(err) - } - - sockAddr4, ok := sockAddr.(*unix.SockaddrInet4) - if !ok { - _ = unix.Close(fd) - t.Fatal("not an IPv4 address") - } - - return uint16(sockAddr4.Port) //nolint:gosec -} - -type noopLogger struct{} - -func (l *noopLogger) Patch(_ ...log.Option) {} -func (l *noopLogger) Debug(_ string) {} -func (l *noopLogger) Debugf(_ string, _ ...any) {} -func (l *noopLogger) Info(_ string) {} -func (l *noopLogger) Warn(_ string) {} -func (l *noopLogger) Warnf(_ string, _ ...any) {} -func (l *noopLogger) Error(_ string) {}