mirror of
https://github.com/qdm12/gluetun.git
synced 2026-05-07 04:20:12 +02:00
chore(pmtud/tcp): move test helpers in helpers_test.go
This commit is contained in:
@@ -0,0 +1,103 @@
|
|||||||
|
package tcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/qdm12/gluetun/internal/netlink"
|
||||||
|
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||||
|
"github.com/qdm12/gluetun/internal/routing"
|
||||||
|
"github.com/qdm12/log"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
type noopLogger struct{}
|
||||||
|
|
||||||
|
func (l *noopLogger) Patch(_ ...log.Option) {}
|
||||||
|
func (l *noopLogger) Debug(_ string) {}
|
||||||
|
func (l *noopLogger) Debugf(_ string, _ ...any) {}
|
||||||
|
func (l *noopLogger) Info(_ string) {}
|
||||||
|
func (l *noopLogger) Warn(_ string) {}
|
||||||
|
func (l *noopLogger) Warnf(_ string, _ ...any) {}
|
||||||
|
func (l *noopLogger) Error(_ string) {}
|
||||||
|
|
||||||
|
var errRouteNotFound = errors.New("route not found")
|
||||||
|
|
||||||
|
func findLoopbackMTU(netlinker *netlink.NetLink) (mtu uint32, err error) {
|
||||||
|
routes, err := netlinker.RouteList(netlink.FamilyV4)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("getting routes list: %w", err)
|
||||||
|
}
|
||||||
|
for _, route := range routes {
|
||||||
|
if route.Dst.IsValid() && route.Dst.Addr().IsLoopback() {
|
||||||
|
link, err := netlinker.LinkByIndex(route.LinkIndex)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("getting link by index: %w", err)
|
||||||
|
}
|
||||||
|
// Quirk: make sure it is maximum 65535, and not i.e. 65536
|
||||||
|
// or the IP header 16 bits will fail to fit that packet length value.
|
||||||
|
const maxMTU = 65535
|
||||||
|
return min(link.MTU, maxMTU), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0, fmt.Errorf("%w: no loopback route found", errRouteNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func findDefaultIPv4RouteMTU(netlinker *netlink.NetLink) (mtu uint32, err error) {
|
||||||
|
noopLogger := &noopLogger{}
|
||||||
|
routing := routing.New(netlinker, noopLogger)
|
||||||
|
defaultRoutes, err := routing.DefaultRoutes()
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("getting default routes: %w", err)
|
||||||
|
}
|
||||||
|
for _, route := range defaultRoutes {
|
||||||
|
if route.Family != netlink.FamilyV4 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
link, err := netlinker.LinkByName(defaultRoutes[0].NetInterface)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("getting link by name: %w", err)
|
||||||
|
}
|
||||||
|
return link.MTU, nil
|
||||||
|
}
|
||||||
|
return 0, fmt.Errorf("%w: no default route found", errRouteNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func reserveClosedPort(t *testing.T) (port uint16) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
fd, err := unix.Socket(constants.AF_INET, constants.SOCK_STREAM, constants.IPPROTO_TCP)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err := unix.Close(fd)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
addr := &unix.SockaddrInet4{
|
||||||
|
Port: 0,
|
||||||
|
Addr: [4]byte{127, 0, 0, 1},
|
||||||
|
}
|
||||||
|
|
||||||
|
err = unix.Bind(fd, addr)
|
||||||
|
if err != nil {
|
||||||
|
_ = unix.Close(fd)
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sockAddr, err := unix.Getsockname(fd)
|
||||||
|
if err != nil {
|
||||||
|
_ = unix.Close(fd)
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sockAddr4, ok := sockAddr.(*unix.SockaddrInet4)
|
||||||
|
if !ok {
|
||||||
|
_ = unix.Close(fd)
|
||||||
|
t.Fatal("not an IPv4 address")
|
||||||
|
}
|
||||||
|
|
||||||
|
return uint16(sockAddr4.Port) //nolint:gosec
|
||||||
|
}
|
||||||
@@ -5,21 +5,18 @@ package tcp
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
gomock "github.com/golang/mock/gomock"
|
||||||
"github.com/qdm12/gluetun/internal/command"
|
"github.com/qdm12/gluetun/internal/command"
|
||||||
"github.com/qdm12/gluetun/internal/firewall"
|
"github.com/qdm12/gluetun/internal/firewall"
|
||||||
"github.com/qdm12/gluetun/internal/netlink"
|
"github.com/qdm12/gluetun/internal/netlink"
|
||||||
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||||
"github.com/qdm12/gluetun/internal/pmtud/ip"
|
"github.com/qdm12/gluetun/internal/pmtud/ip"
|
||||||
"github.com/qdm12/gluetun/internal/routing"
|
|
||||||
"github.com/qdm12/log"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_runTest(t *testing.T) {
|
func Test_runTest(t *testing.T) {
|
||||||
@@ -110,6 +107,7 @@ func Test_runTest(t *testing.T) {
|
|||||||
for name, testCase := range testCases {
|
for name, testCase := range testCases {
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
|
||||||
dst := testCase.server
|
dst := testCase.server
|
||||||
|
|
||||||
@@ -124,11 +122,12 @@ func Test_runTest(t *testing.T) {
|
|||||||
err := revert(context.Background())
|
err := revert(context.Background())
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
})
|
})
|
||||||
|
logger := NewMockLogger(ctrl)
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(t.Context(), testCase.timeout)
|
ctx, cancel := context.WithTimeout(t.Context(), testCase.timeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
err = runTest(ctx, dst, testCase.mtu, excludeMark,
|
err = runTest(ctx, dst, testCase.mtu, excludeMark,
|
||||||
fd, tracker, fw, noopLogger)
|
fd, tracker, fw, logger)
|
||||||
if testCase.success {
|
if testCase.success {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
} else {
|
} else {
|
||||||
@@ -137,91 +136,3 @@ func Test_runTest(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var errRouteNotFound = errors.New("route not found")
|
|
||||||
|
|
||||||
func findLoopbackMTU(netlinker *netlink.NetLink) (mtu uint32, err error) {
|
|
||||||
routes, err := netlinker.RouteList(netlink.FamilyV4)
|
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("getting routes list: %w", err)
|
|
||||||
}
|
|
||||||
for _, route := range routes {
|
|
||||||
if route.Dst.IsValid() && route.Dst.Addr().IsLoopback() {
|
|
||||||
link, err := netlinker.LinkByIndex(route.LinkIndex)
|
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("getting link by index: %w", err)
|
|
||||||
}
|
|
||||||
// Quirk: make sure it is maximum 65535, and not i.e. 65536
|
|
||||||
// or the IP header 16 bits will fail to fit that packet length value.
|
|
||||||
const maxMTU = 65535
|
|
||||||
return min(link.MTU, maxMTU), nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return 0, fmt.Errorf("%w: no loopback route found", errRouteNotFound)
|
|
||||||
}
|
|
||||||
|
|
||||||
func findDefaultIPv4RouteMTU(netlinker *netlink.NetLink) (mtu uint32, err error) {
|
|
||||||
noopLogger := &noopLogger{}
|
|
||||||
routing := routing.New(netlinker, noopLogger)
|
|
||||||
defaultRoutes, err := routing.DefaultRoutes()
|
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("getting default routes: %w", err)
|
|
||||||
}
|
|
||||||
for _, route := range defaultRoutes {
|
|
||||||
if route.Family != netlink.FamilyV4 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
link, err := netlinker.LinkByName(defaultRoutes[0].NetInterface)
|
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("getting link by name: %w", err)
|
|
||||||
}
|
|
||||||
return link.MTU, nil
|
|
||||||
}
|
|
||||||
return 0, fmt.Errorf("%w: no default route found", errRouteNotFound)
|
|
||||||
}
|
|
||||||
|
|
||||||
func reserveClosedPort(t *testing.T) (port uint16) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
fd, err := unix.Socket(constants.AF_INET, constants.SOCK_STREAM, constants.IPPROTO_TCP)
|
|
||||||
require.NoError(t, err)
|
|
||||||
t.Cleanup(func() {
|
|
||||||
err := unix.Close(fd)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
})
|
|
||||||
|
|
||||||
addr := &unix.SockaddrInet4{
|
|
||||||
Port: 0,
|
|
||||||
Addr: [4]byte{127, 0, 0, 1},
|
|
||||||
}
|
|
||||||
|
|
||||||
err = unix.Bind(fd, addr)
|
|
||||||
if err != nil {
|
|
||||||
_ = unix.Close(fd)
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
sockAddr, err := unix.Getsockname(fd)
|
|
||||||
if err != nil {
|
|
||||||
_ = unix.Close(fd)
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
sockAddr4, ok := sockAddr.(*unix.SockaddrInet4)
|
|
||||||
if !ok {
|
|
||||||
_ = unix.Close(fd)
|
|
||||||
t.Fatal("not an IPv4 address")
|
|
||||||
}
|
|
||||||
|
|
||||||
return uint16(sockAddr4.Port) //nolint:gosec
|
|
||||||
}
|
|
||||||
|
|
||||||
type noopLogger struct{}
|
|
||||||
|
|
||||||
func (l *noopLogger) Patch(_ ...log.Option) {}
|
|
||||||
func (l *noopLogger) Debug(_ string) {}
|
|
||||||
func (l *noopLogger) Debugf(_ string, _ ...any) {}
|
|
||||||
func (l *noopLogger) Info(_ string) {}
|
|
||||||
func (l *noopLogger) Warn(_ string) {}
|
|
||||||
func (l *noopLogger) Warnf(_ string, _ ...any) {}
|
|
||||||
func (l *noopLogger) Error(_ string) {}
|
|
||||||
|
|||||||
Reference in New Issue
Block a user