//go:build linux package tcp import ( "context" "errors" "net/netip" "testing" "time" gomock "github.com/golang/mock/gomock" "github.com/qdm12/gluetun/internal/command" "github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/netlink" "github.com/qdm12/gluetun/internal/pmtud/constants" "github.com/qdm12/gluetun/internal/pmtud/ip" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func Test_runTest(t *testing.T) { t.Parallel() localNonListenPort := reserveClosedPort(t) noopLogger := &noopLogger{} cmder := command.New() fw, err := firewall.NewConfig(t.Context(), noopLogger, cmder, nil, nil) if errors.Is(err, firewall.ErrIPTablesNotSupported) { t.Skip("iptables not installed, skipping TCP PMTUD tests") } require.NoError(t, err, "creating firewall config") netlinker := netlink.New(noopLogger) loopbackMTU, err := findLoopbackMTU(netlinker) require.NoError(t, err, "finding loopback IPv4 MTU") defaultIPv4MTU, err := findDefaultIPv4RouteMTU(netlinker) require.NoError(t, err, "finding default IPv4 route MTU") ctx, cancel := context.WithCancel(t.Context()) const family = constants.AF_INET const excludeMark = 4545 fd, stop, err := startRawSocket(family, excludeMark) require.NoError(t, err) const ipv4 = true tracker := newTracker(fd, ipv4) trackerCh := make(chan error) go func() { trackerCh <- tracker.listen(ctx) }() // Our local ethernet MTU could be 1500, and the server could advertise // an MSS of 1400, but the real link to the server could have an MTU of 1300, // so we need to adjust our test so it passes. We are not actually path MTU // discovering here, just testing that we can receive the expected TCP packets // for a given MTU. const mtuSafetyBuffer = 200 t.Cleanup(func() { stop() cancel() // stop listening err = <-trackerCh require.NoError(t, err) }) testCases := map[string]struct { timeout time.Duration server netip.AddrPort mtu uint32 success bool }{ "local_not_listening": { timeout: time.Hour, server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), localNonListenPort), mtu: loopbackMTU, success: true, }, "remote_not_listening": { timeout: 50 * time.Millisecond, server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 12345), mtu: defaultIPv4MTU - mtuSafetyBuffer, }, "1.1.1.1:443": { timeout: 5 * time.Second, server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443), mtu: defaultIPv4MTU - mtuSafetyBuffer, success: true, }, "1.1.1.1:80": { timeout: 5 * time.Second, server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 80), mtu: defaultIPv4MTU - mtuSafetyBuffer, success: true, }, "8.8.8.8:443": { timeout: 5 * time.Second, server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 443), mtu: defaultIPv4MTU - mtuSafetyBuffer, success: true, }, } for name, testCase := range testCases { t.Run(name, func(t *testing.T) { t.Parallel() ctrl := gomock.NewController(t) dst := testCase.server const proto = constants.IPPROTO_TCP src, cleanup, err := ip.SrcAddr(dst, proto) require.NoError(t, err, "getting source address to reach remote server %s", dst) t.Cleanup(cleanup) revert, err := fw.TempDropOutputTCPRST(t.Context(), src, dst, excludeMark) require.NoError(t, err) t.Cleanup(func() { err := revert(context.Background()) assert.NoError(t, err) }) logger := NewMockLogger(ctrl) ctx, cancel := context.WithTimeout(t.Context(), testCase.timeout) defer cancel() err = runTest(ctx, dst, testCase.mtu, excludeMark, fd, tracker, fw, logger) if testCase.success { require.NoError(t, err) } else { require.Error(t, err) } }) } }