Files
gluetun/internal/pmtud/tcp/helpers_test.go
T
Quentin McGaw 4a78989d9d chore: do not use sentinel errors when unneeded
- main reason being it's a burden to always define sentinel errors at global scope, wrap them with `%w` instead of using a string directly
- only use sentinel errors when it has to be checked using `errors.Is`
- replace all usage of these sentinel errors in `fmt.Errorf` with direct strings that were in the sentinel error
- exclude the sentinel error definition requirement from .golangci.yml
- update unit tests to use ContainersError instead of ErrorIs so it stays as a "not a change detector test" without requiring a sentinel error
2026-05-02 03:29:46 +00:00

141 lines
3.9 KiB
Go

package tcp
import (
"errors"
"fmt"
"sync"
"testing"
"github.com/qdm12/gluetun/internal/command"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/firewall/iptables"
"github.com/qdm12/gluetun/internal/netlink"
"github.com/qdm12/gluetun/internal/pmtud/constants"
"github.com/qdm12/gluetun/internal/routing"
"github.com/qdm12/log"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sys/unix"
)
// testFirewall must be global to prevent parallel tests from interfering
// with each other since they would interact with the same filter table.
// The first test to use should initialize it, and the rest will reuse it.
var (
testFirewall *firewall.Config //nolint:gochecknoglobals
testFirewallOnce sync.Once //nolint:gochecknoglobals
)
// getFirewall returns a Firewall instance, initializing it if needed. If
// iptables is not supported, it skips the test.
func getFirewall(t *testing.T) *firewall.Config {
t.Helper()
testFirewallOnce.Do(func() {
noopLogger := &noopLogger{}
cmder := command.New()
var err error
testFirewall, err = firewall.NewConfig(t.Context(), noopLogger, noopLogger, cmder, nil, nil)
if errors.Is(err, iptables.ErrNotSupported) {
t.Skip("iptables not installed, skipping TCP PMTUD tests")
}
require.NoError(t, err, "creating firewall config")
})
if testFirewall == nil {
t.Skip("iptables not installed, skipping TCP PMTUD tests")
}
return testFirewall
}
type noopLogger struct{}
func (l *noopLogger) Patch(_ ...log.Option) {}
func (l *noopLogger) Debug(_ string) {}
func (l *noopLogger) Debugf(_ string, _ ...any) {}
func (l *noopLogger) Info(_ string) {}
func (l *noopLogger) Warn(_ string) {}
func (l *noopLogger) Warnf(_ string, _ ...any) {}
func (l *noopLogger) Error(_ string) {}
func findLoopbackMTU(netlinker *netlink.NetLink) (mtu uint32, err error) {
routes, err := netlinker.RouteList(netlink.FamilyV4)
if err != nil {
return 0, fmt.Errorf("getting routes list: %w", err)
}
for _, route := range routes {
if route.Dst.IsValid() && route.Dst.Addr().IsLoopback() {
link, err := netlinker.LinkByIndex(route.LinkIndex)
if err != nil {
return 0, fmt.Errorf("getting link by index: %w", err)
}
// Quirk: make sure it is maximum 65535, and not i.e. 65536
// or the IP header 16 bits will fail to fit that packet length value.
const maxMTU = 65535
return min(link.MTU, maxMTU), nil
}
}
return 0, errors.New("route not found: no loopback route found")
}
func findDefaultRouteMTU(netlinker *netlink.NetLink) (mtu uint32, err error) {
noopLogger := &noopLogger{}
routing := routing.New(netlinker, noopLogger)
defaultRoutes, err := routing.DefaultRoutes()
if err != nil {
return 0, fmt.Errorf("getting default routes: %w", err)
}
families := []uint8{constants.AF_INET, constants.AF_INET6}
for _, family := range families {
for _, route := range defaultRoutes {
if route.Family != family {
continue
}
link, err := netlinker.LinkByName(route.NetInterface)
if err != nil {
return 0, fmt.Errorf("getting link by name: %w", err)
}
mtu = max(mtu, link.MTU)
}
}
if mtu == 0 {
return 0, errors.New("route not found: no default route found")
}
return mtu, nil
}
func reserveClosedPort(t *testing.T) (port uint16) {
t.Helper()
fd, err := unix.Socket(constants.AF_INET, constants.SOCK_STREAM, constants.IPPROTO_TCP)
require.NoError(t, err)
t.Cleanup(func() {
err := unix.Close(fd)
assert.NoError(t, err)
})
addr := &unix.SockaddrInet4{
Port: 0,
Addr: [4]byte{127, 0, 0, 1},
}
err = unix.Bind(fd, addr)
if err != nil {
_ = unix.Close(fd)
t.Fatal(err)
}
sockAddr, err := unix.Getsockname(fd)
if err != nil {
_ = unix.Close(fd)
t.Fatal(err)
}
sockAddr4, ok := sockAddr.(*unix.SockaddrInet4)
if !ok {
_ = unix.Close(fd)
t.Fatal("not an IPv4 address")
}
return uint16(sockAddr4.Port) //nolint:gosec
}