diff --git a/internal/firewall/list.go b/internal/firewall/list.go index eb618f32..a81510a8 100644 --- a/internal/firewall/list.go +++ b/internal/firewall/list.go @@ -26,6 +26,7 @@ type chainRule struct { inputInterface string // input interface, for example "tun0" or "*"" outputInterface string // output interface, for example "eth0" or "*"" source netip.Prefix // source IP CIDR, for example 0.0.0.0/0. Must be valid. + sourcePort uint16 // Not specified if set to zero. destination netip.Prefix // destination IP CIDR, for example 0.0.0.0/0. Must be valid. destinationPort uint16 // Not specified if set to zero. redirPorts []uint16 // Not specified if empty. @@ -315,6 +316,12 @@ func parseUDPOptional(optionalFields []string, rule *chainRule) (consumed int, e return 0, fmt.Errorf("parsing destination port: %w", err) } consumed++ + case strings.HasPrefix(value, "spt:"): + rule.sourcePort, err = parseSourcePort(value) + if err != nil { + return 0, fmt.Errorf("parsing source port: %w", err) + } + consumed++ default: return 0, fmt.Errorf("%w: %s", errUDPOptionalUnknown, value) } @@ -337,6 +344,12 @@ func parseTCPOptional(optionalFields []string, rule *chainRule) (consumed int, e return 0, fmt.Errorf("parsing destination port: %w", err) } consumed++ + case strings.HasPrefix(value, "spt:"): + rule.sourcePort, err = parseSourcePort(value) + if err != nil { + return 0, fmt.Errorf("parsing source port: %w", err) + } + consumed++ case strings.HasPrefix(value, "flags:"): rule.tcpFlags, err = parseTCPFlags(value) if err != nil { @@ -352,12 +365,12 @@ func parseTCPOptional(optionalFields []string, rule *chainRule) (consumed int, e func parseDestinationPort(value string) (port uint16, err error) { value = strings.TrimPrefix(value, "dpt:") - const base, bitLength = 10, 16 - destinationPort, err := strconv.ParseUint(value, base, bitLength) - if err != nil { - return 0, fmt.Errorf("parsing %q: %w", value, err) - } - return uint16(destinationPort), nil + return parsePort(value) +} + +func parseSourcePort(value string) (port uint16, err error) { + value = strings.TrimPrefix(value, "spt:") + return parsePort(value) } var errTCPFlagsMalformed = errors.New("TCP flags are malformed") @@ -401,12 +414,10 @@ func parsePortsCSV(s string) (ports []uint16, err error) { fields := strings.Split(s, ",") ports = make([]uint16, len(fields)) for i, field := range fields { - const base, bitLength = 10, 16 - port, err := strconv.ParseUint(field, base, bitLength) + ports[i], err = parsePort(field) if err != nil { - return nil, fmt.Errorf("parsing port %q: %w", field, err) + return nil, err } - ports[i] = uint16(port) } return ports, nil } diff --git a/internal/firewall/parse.go b/internal/firewall/parse.go index be454f57..948290d8 100644 --- a/internal/firewall/parse.go +++ b/internal/firewall/parse.go @@ -18,6 +18,7 @@ type iptablesInstruction struct { inputInterface string // for example "tun0" or "" for any interface. outputInterface string // for example "tun0" or "" for any interface. source netip.Prefix // if not valid, then it is unspecified. + sourcePort uint16 // if zero, there is no source port destination netip.Prefix // if not valid, then it is unspecified. destinationPort uint16 // if zero, there is no destination port toPorts []uint16 // if empty, there is no redirection @@ -45,6 +46,8 @@ func (i *iptablesInstruction) equalToRule(table, chain string, rule chainRule) ( return false case i.destinationPort != rule.destinationPort: return false + case i.sourcePort != rule.sourcePort: + return false case !slices.Equal(i.toPorts, rule.redirPorts): return false case !slices.Equal(i.ctstate, rule.ctstate): @@ -99,25 +102,11 @@ func parseIptablesInstruction(s string) (instruction iptablesInstruction, err er } func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (consumed int, err error) { - flag := fields[0] - - // All flags use one value after the flag, except the following: - switch flag { - case "--tcp-flags": // -m can have 1 or 2 values - const expected = 3 - if len(fields) < expected { - return 0, fmt.Errorf("%w: flag %q requires at least 2 values, but got %s", - ErrIptablesCommandMalformed, flag, strings.Join(fields, " ")) - } - consumed = expected - default: - const expected = 2 - if len(fields) < expected { - return 0, fmt.Errorf("%w: flag %q requires a value, but got none", - ErrIptablesCommandMalformed, flag) - } - consumed = expected + consumed, err = preCheckInstructionFields(fields) + if err != nil { + return 0, err } + flag := fields[0] value := fields[1] switch flag { @@ -134,20 +123,9 @@ func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (co case "-p", "--protocol": instruction.protocol = value case "-m", "--match": - consumed = 2 // -m can have 1 or 2 values, so it consumes 2 or 3 fields. - switch value { - case "tcp", "udp": // for now ignore the protocol match since it's auto-loaded - case "mark": - switch fields[2] { - case "!": - consumed++ - instruction.mark.invert = true - default: - return 0, fmt.Errorf("%w: unsupported match mark with value: %s", - ErrIptablesCommandMalformed, fields[2]) - } - default: - return 0, fmt.Errorf("%w: unknown match value: %s", ErrIptablesCommandMalformed, value) + consumed, err = parseMatchModule(fields, instruction) + if err != nil { + return 0, fmt.Errorf("parsing match module: %w", err) } case "--mark": const base = 0 // auto-detect @@ -166,30 +144,27 @@ func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (co if err != nil { return 0, fmt.Errorf("parsing source IP CIDR: %w", err) } + case "--sport": + instruction.sourcePort, err = parsePort(value) + if err != nil { + return 0, fmt.Errorf("parsing source port: %w", err) + } case "-d", "--destination": instruction.destination, err = parseIPPrefix(value) if err != nil { return 0, fmt.Errorf("parsing destination IP CIDR: %w", err) } case "--dport": - const base, bitLength = 10, 16 - destinationPort, err := strconv.ParseUint(value, base, bitLength) + instruction.destinationPort, err = parsePort(value) if err != nil { return 0, fmt.Errorf("parsing destination port: %w", err) } - instruction.destinationPort = uint16(destinationPort) case "--ctstate": instruction.ctstate = strings.Split(value, ",") case "--to-ports": - portStrings := strings.Split(value, ",") - instruction.toPorts = make([]uint16, len(portStrings)) - for i, portString := range portStrings { - const base, bitLength = 10, 16 - port, err := strconv.ParseUint(portString, base, bitLength) - if err != nil { - return 0, fmt.Errorf("parsing port redirection: %w", err) - } - instruction.toPorts[i] = uint16(port) + instruction.toPorts, err = parseToPorts(value) + if err != nil { + return 0, fmt.Errorf("parsing port redirection: %w", err) } case "--tcp-flags": mask, comparison := value, fields[2] @@ -203,6 +178,27 @@ func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (co return consumed, nil } +func preCheckInstructionFields(fields []string) (consumed int, err error) { + flag := fields[0] + // All flags use one value after the flag, except the following: + switch flag { + case "--tcp-flags": // -m can have 1 or 2 values + const expected = 3 + if len(fields) < expected { + return 0, fmt.Errorf("%w: flag %q requires at least 2 values, but got %s", + ErrIptablesCommandMalformed, flag, strings.Join(fields, " ")) + } + return expected, nil + default: + const expected = 2 + if len(fields) < expected { + return 0, fmt.Errorf("%w: flag %q requires a value, but got none", + ErrIptablesCommandMalformed, flag) + } + return expected, nil + } +} + func parseIPPrefix(value string) (prefix netip.Prefix, err error) { slashIndex := strings.Index(value, "/") if slashIndex >= 0 { @@ -215,3 +211,52 @@ func parseIPPrefix(value string) (prefix netip.Prefix, err error) { } return netip.PrefixFrom(ip, ip.BitLen()), nil } + +func parsePort(value string) (port uint16, err error) { + const base, bitLength = 10, 16 + portValue, err := strconv.ParseUint(value, base, bitLength) + if err != nil { + return 0, err + } + return uint16(portValue), nil +} + +func parseMatchModule(fields []string, instruction *iptablesInstruction) ( + consumed int, err error, +) { + _ = fields[consumed] // -m or --match flag already detected + consumed++ + switch fields[consumed] { + case "tcp", "udp": + consumed++ + // for now ignore the protocol match since it's auto-loaded + // when parsing the -p/--protocol flag, and we don't need to + // parse it twice. + case "mark": + consumed++ + switch fields[consumed] { + case "!": + consumed++ + instruction.mark.invert = true + default: + return consumed, fmt.Errorf("%w: unsupported match mark with value: %s", + ErrIptablesCommandMalformed, fields[2]) + } + default: + return 0, fmt.Errorf("%w: unknown match value: %s", + ErrIptablesCommandMalformed, fields[consumed]) + } + return consumed, nil +} + +func parseToPorts(value string) (toPorts []uint16, err error) { + portStrings := strings.Split(value, ",") + toPorts = make([]uint16, len(portStrings)) + for i, portString := range portStrings { + toPorts[i], err = parsePort(portString) + if err != nil { + return nil, err + } + } + return toPorts, nil +} diff --git a/internal/firewall/tcp.go b/internal/firewall/tcp.go index 07f276ba..6d21e360 100644 --- a/internal/firewall/tcp.go +++ b/internal/firewall/tcp.go @@ -71,7 +71,7 @@ var ErrMarkMatchModuleMissing = errors.New("kernel is missing the mark module li // This is necessary for TCP path MTU discovery to work, as the kernel will try to terminate the connection // by sending a TCP RST packet, although we want to handle the connection manually. func (c *Config) TempDropOutputTCPRST(ctx context.Context, - addrPort netip.AddrPort, excludeMark int) ( + src, dst netip.AddrPort, excludeMark int) ( revert func(ctx context.Context) error, err error, ) { _, err = os.Stat("/usr/lib/xtables/libxt_mark.so") @@ -79,11 +79,12 @@ func (c *Config) TempDropOutputTCPRST(ctx context.Context, return nil, fmt.Errorf("%w", ErrMarkMatchModuleMissing) } - const template = "%s OUTPUT -p tcp -d %s --dport %d --tcp-flags RST RST -m mark ! --mark %d -j DROP" //nolint:dupword - instruction := fmt.Sprintf(template, "--append", addrPort.Addr(), addrPort.Port(), excludeMark) - revertInstruction := fmt.Sprintf(template, "--delete", addrPort.Addr(), addrPort.Port(), excludeMark) + const template = "%s OUTPUT -p tcp -s %s --sport %d -d %s --dport %d " + + "--tcp-flags RST RST -m mark ! --mark %d -j DROP" //nolint:dupword + instruction := fmt.Sprintf(template, "--append", src.Addr(), src.Port(), dst.Addr(), dst.Port(), excludeMark) + revertInstruction := fmt.Sprintf(template, "--delete", src.Addr(), src.Port(), dst.Addr(), dst.Port(), excludeMark) run := c.runIptablesInstruction - if addrPort.Addr().Is6() { + if dst.Addr().Is6() { run = c.runIP6tablesInstruction } revert = func(ctx context.Context) error { diff --git a/internal/pmtud/pmtud.go b/internal/pmtud/pmtud.go index b102d552..937aad56 100644 --- a/internal/pmtud/pmtud.go +++ b/internal/pmtud/pmtud.go @@ -70,7 +70,7 @@ func PathMTUDiscover(ctx context.Context, icmpAddrs []netip.Addr, tcpAddrs []net const mtuMargin = 150 minMTU = max(maxPossibleMTU-mtuMargin, minMTU) } - mtu, err = tcp.PathMTUDiscover(ctx, addrPort, minMTU, maxPossibleMTU, fw, logger) + mtu, err = tcp.PathMTUDiscover(ctx, addrPort, minMTU, maxPossibleMTU, tryTimeout, fw, logger) if err != nil { if errors.Is(err, firewall.ErrMarkMatchModuleMissing) { logger.Debugf("aborting TCP path MTU discovery: %s", err) diff --git a/internal/pmtud/tcp/interfaces.go b/internal/pmtud/tcp/interfaces.go index a54fd496..e7643089 100644 --- a/internal/pmtud/tcp/interfaces.go +++ b/internal/pmtud/tcp/interfaces.go @@ -6,7 +6,7 @@ import ( ) type Firewall interface { - TempDropOutputTCPRST(ctx context.Context, addrPort netip.AddrPort, + TempDropOutputTCPRST(ctx context.Context, src, dst netip.AddrPort, excludeMark int) (revert func(ctx context.Context) error, err error) } diff --git a/internal/pmtud/tcp/multi.go b/internal/pmtud/tcp/multi.go index 8698a6b1..a36ec259 100644 --- a/internal/pmtud/tcp/multi.go +++ b/internal/pmtud/tcp/multi.go @@ -18,26 +18,61 @@ type testUnit struct { ok bool } -func PathMTUDiscover(ctx context.Context, addrPort netip.AddrPort, - minMTU, maxPossibleMTU uint32, firewall Firewall, logger Logger, +func PathMTUDiscover(ctx context.Context, dst netip.AddrPort, + minMTU, maxPossibleMTU uint32, tryTimeout time.Duration, + firewall Firewall, logger Logger, ) (mtu uint32, err error) { - const excludeMark = 4325 - revert, err := firewall.TempDropOutputTCPRST(ctx, addrPort, excludeMark) - if err != nil { - return 0, fmt.Errorf("temporarily dropping outgoing TCP RST packets: %w", err) + family := constants.AF_INET + if dst.Addr().Is6() { + family = constants.AF_INET6 } - defer func() { - err := revert(ctx) - if err != nil { - logger.Warnf("reverting firewall changes: %s", err) - } + const excludeMark = 4325 + fd, stop, err := startRawSocket(family, excludeMark) + if err != nil { + return 0, fmt.Errorf("starting raw socket: %w", err) + } + defer stop() + + tracker := newTracker(fd, dst.Addr().Is4()) + + trackerCtx, trackerCancel := context.WithCancel(ctx) + defer trackerCancel() + trackerErrCh := make(chan error) + go func() { + trackerErrCh <- tracker.listen(trackerCtx) }() - return pathMTUDiscover(ctx, addrPort, minMTU, maxPossibleMTU, excludeMark, logger) + pmtudCtx, pmtudCancel := context.WithCancel(ctx) + defer pmtudCancel() + type result struct { + mtu uint32 + err error + } + pmtudResultCh := make(chan result) + go func() { + mtu, err := pathMTUDiscover(pmtudCtx, fd, dst, minMTU, maxPossibleMTU, + excludeMark, tryTimeout, tracker, firewall, logger) + pmtudResultCh <- result{mtu: mtu, err: err} + }() + + select { + case err = <-trackerErrCh: + pmtudCancel() + <-pmtudResultCh + return 0, fmt.Errorf("listening for TCP replies: %w", err) + case res := <-pmtudResultCh: + trackerCancel() + <-trackerErrCh + return res.mtu, res.err + } } -func pathMTUDiscover(ctx context.Context, addrPort netip.AddrPort, - minMTU, maxPossibleMTU uint32, excludeMark int, logger Logger, +var errTimedOut = errors.New("timed out") + +func pathMTUDiscover(ctx context.Context, fd fileDescriptor, + dst netip.AddrPort, minMTU, maxPossibleMTU uint32, excludeMark int, + tryTimeout time.Duration, tracker *tracker, firewall Firewall, + logger Logger, ) (mtu uint32, err error) { mtusToTest := test.MakeMTUsToTest(minMTU, maxPossibleMTU) if len(mtusToTest) == 1 { // only minMTU because minMTU == maxPossibleMTU @@ -50,30 +85,14 @@ func pathMTUDiscover(ctx context.Context, addrPort netip.AddrPort, tests[i] = testUnit{mtu: mtusToTest[i]} } - family := constants.AF_INET - if addrPort.Addr().Is6() { - family = constants.AF_INET6 - } - fd, stop, err := startRawSocket(family, excludeMark) - if err != nil { - return 0, fmt.Errorf("starting raw socket: %w", err) - } - defer stop() - - tracker := newTracker(fd, addrPort.Addr().Is4()) - - const timeout = time.Second - runCtx, cancel := context.WithTimeout(ctx, timeout) - defer cancel() - errCh := make(chan error) - go func() { - errCh <- tracker.listen(runCtx) - }() - + errCause := fmt.Errorf("%w: after %s", errTimedOut, tryTimeout) + runCtx, runCancel := context.WithTimeoutCause(ctx, tryTimeout, errCause) + defer runCancel() doneCh := make(chan struct{}) for i := range tests { go func(i int) { - err := runTest(runCtx, fd, tracker, src, dst, tests[i].mtu) + err := runTest(runCtx, dst, tests[i].mtu, excludeMark, + fd, tracker, firewall, logger) tests[i].ok = err == nil doneCh <- struct{}{} }(i) @@ -82,27 +101,33 @@ func pathMTUDiscover(ctx context.Context, addrPort netip.AddrPort, i := 0 for i < len(tests) { select { + case <-runCtx.Done(): // timeout or parent context canceled + err = context.Cause(runCtx) + // collect remaining done signals + for i < len(tests) { + <-doneCh + i++ + } case <-doneCh: i++ - case err := <-errCh: - if err == nil { // timeout - cancel() - continue - } - return 0, fmt.Errorf("listening for TCP replies: %w", err) } } + if err != nil && !errors.Is(err, errTimedOut) { + // context is canceled but did not timeout after tryTimeout + return 0, fmt.Errorf("running MTU tests: %w", err) + } + if tests[len(tests)-1].ok { return tests[len(tests)-1].mtu, nil } for i := len(tests) - 2; i >= 0; i-- { //nolint:mnd if tests[i].ok { - stop() - cancel() - return pathMTUDiscover(ctx, addrPort, - tests[i].mtu, tests[i+1].mtu-1, excludeMark, logger) + runCancel() // just to release resources although runCtx is no longer used + return pathMTUDiscover(ctx, fd, dst, + tests[i].mtu, tests[i+1].mtu-1, excludeMark, + tryTimeout, tracker, firewall, logger) } } diff --git a/internal/pmtud/tcp/tcp.go b/internal/pmtud/tcp/tcp.go index 0c88a841..b7b798ef 100644 --- a/internal/pmtud/tcp/tcp.go +++ b/internal/pmtud/tcp/tcp.go @@ -64,8 +64,9 @@ var ( // Craft and send a raw TCP packet to test the MTU. // It expects either an RST reply (if no server is listening) // or a SYN-ACK/ACK reply (if a server is listening). -func runTest(ctx context.Context, fd fileDescriptor, - tracker *tracker, dst netip.AddrPort, mtu uint32, +func runTest(ctx context.Context, dst netip.AddrPort, mtu uint32, + excludeMark int, fd fileDescriptor, tracker *tracker, + firewall Firewall, logger Logger, ) error { const proto = constants.IPPROTO_TCP src, cleanup, err := ip.SrcAddr(dst, proto) @@ -74,6 +75,20 @@ func runTest(ctx context.Context, fd fileDescriptor, } defer cleanup() + revert, err := firewall.TempDropOutputTCPRST(ctx, src, dst, excludeMark) + if err != nil { + return fmt.Errorf("temporarily dropping outgoing TCP RST packets: %w", err) + } + defer func() { + // we don't want to skip reverting the firewall changes + // even if the context is already expired, so we use a + // background context here. + err := revert(context.Background()) + if err != nil { + logger.Warnf("reverting firewall changes: %s", err) + } + }() + ch := make(chan []byte) abort := make(chan struct{}) defer close(abort) diff --git a/internal/pmtud/tcp/tcp_integration_test.go b/internal/pmtud/tcp/tcp_integration_test.go new file mode 100644 index 00000000..3c05a723 --- /dev/null +++ b/internal/pmtud/tcp/tcp_integration_test.go @@ -0,0 +1,38 @@ +//go:build integration + +package tcp + +import ( + "errors" + "net/netip" + "testing" + "time" + + "github.com/qdm12/gluetun/internal/command" + "github.com/qdm12/gluetun/internal/firewall" + "github.com/qdm12/gluetun/internal/pmtud/constants" + "github.com/qdm12/log" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_PathMTUDiscover(t *testing.T) { + t.Parallel() + noopLogger := log.New(log.SetLevel(log.LevelDebug)) + + cmder := command.New() + fw, err := firewall.NewConfig(t.Context(), noopLogger, cmder, nil, nil) + if errors.Is(err, firewall.ErrIPTablesNotSupported) { + t.Skip("iptables not installed, skipping TCP PMTUD tests") + } + require.NoError(t, err, "creating firewall config") + + dst := netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 80) + const minMTU = constants.MinIPv6MTU + const maxMTU = constants.MaxEthernetFrameSize + const tryTimeout = time.Second + mtu, err := PathMTUDiscover(t.Context(), dst, minMTU, maxMTU, tryTimeout, fw, noopLogger) + require.NoError(t, err, "discovering path MTU") + assert.Greater(t, mtu, uint32(0), "MTU should be greater than 0") + t.Logf("discovered path MTU to %s is %d", dst, mtu) +} diff --git a/internal/pmtud/tcp/tcp_test.go b/internal/pmtud/tcp/tcp_test.go index 81734214..948965a0 100644 --- a/internal/pmtud/tcp/tcp_test.go +++ b/internal/pmtud/tcp/tcp_test.go @@ -14,6 +14,7 @@ import ( "github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/netlink" "github.com/qdm12/gluetun/internal/pmtud/constants" + "github.com/qdm12/gluetun/internal/pmtud/ip" "github.com/qdm12/gluetun/internal/routing" "github.com/qdm12/log" "github.com/stretchr/testify/assert" @@ -24,11 +25,7 @@ import ( func Test_runTest(t *testing.T) { t.Parallel() - serverAddrs := map[string]netip.AddrPort{ - "cloudflare-http": netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 80), - "cloudflare-https": netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443), - "google-https": netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 443), - } + localNonListenPort := reserveClosedPort(t) noopLogger := &noopLogger{} @@ -39,17 +36,6 @@ func Test_runTest(t *testing.T) { } require.NoError(t, err, "creating firewall config") - // Prevent Kernel from sending RST packets back to servers - const excludeMark = 4324 - for _, addrPort := range serverAddrs { - revert, err := fw.TempDropOutputTCPRST(t.Context(), addrPort, excludeMark) - require.NoError(t, err) - t.Cleanup(func() { - err := revert(context.Background()) - assert.NoError(t, err) - }) - } - netlinker := netlink.New(noopLogger) loopbackMTU, err := findLoopbackMTU(netlinker) require.NoError(t, err, "finding loopback IPv4 MTU") @@ -59,6 +45,7 @@ func Test_runTest(t *testing.T) { ctx, cancel := context.WithCancel(t.Context()) const family = constants.AF_INET + const excludeMark = 4545 fd, stop, err := startRawSocket(family, excludeMark) require.NoError(t, err) @@ -69,6 +56,11 @@ func Test_runTest(t *testing.T) { trackerCh <- tracker.listen(ctx) }() + // Our local ethernet MTU could be 1500, and the server could advertise + // an MSS of 1400, but the real link to the server could have an MTU of 1300, + // so we need to adjust our test so it passes. We are not actually path MTU + // discovering here, just testing that we can receive the expected TCP packets + // for a given MTU. const mtuSafetyBuffer = 200 t.Cleanup(func() { @@ -80,48 +72,36 @@ func Test_runTest(t *testing.T) { testCases := map[string]struct { timeout time.Duration - dst func(t *testing.T) netip.AddrPort + server netip.AddrPort mtu uint32 success bool }{ "local_not_listening": { timeout: time.Hour, - dst: func(t *testing.T) netip.AddrPort { - t.Helper() - port := reserveClosedPort(t) - return netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), port) - }, + server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), localNonListenPort), mtu: loopbackMTU, success: true, }, "remote_not_listening": { timeout: 50 * time.Millisecond, - dst: func(_ *testing.T) netip.AddrPort { - return netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 12345) - }, - mtu: defaultIPv4MTU - mtuSafetyBuffer, + server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 12345), + mtu: defaultIPv4MTU - mtuSafetyBuffer, }, "1.1.1.1:443": { timeout: 5 * time.Second, - dst: func(_ *testing.T) netip.AddrPort { - return serverAddrs["cloudflare-https"] - }, + server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443), mtu: defaultIPv4MTU - mtuSafetyBuffer, success: true, }, "1.1.1.1:80": { timeout: 5 * time.Second, - dst: func(_ *testing.T) netip.AddrPort { - return serverAddrs["cloudflare-http"] - }, + server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 80), mtu: defaultIPv4MTU - mtuSafetyBuffer, success: true, }, "8.8.8.8:443": { timeout: 5 * time.Second, - dst: func(_ *testing.T) netip.AddrPort { - return serverAddrs["google-https"] - }, + server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 443), mtu: defaultIPv4MTU - mtuSafetyBuffer, success: true, }, @@ -131,11 +111,24 @@ func Test_runTest(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() - dst := testCase.dst(t) + dst := testCase.server + + const proto = constants.IPPROTO_TCP + src, cleanup, err := ip.SrcAddr(dst, proto) + require.NoError(t, err, "getting source address to reach remote server %s", dst) + t.Cleanup(cleanup) + + revert, err := fw.TempDropOutputTCPRST(t.Context(), src, dst, excludeMark) + require.NoError(t, err) + t.Cleanup(func() { + err := revert(context.Background()) + assert.NoError(t, err) + }) ctx, cancel := context.WithTimeout(t.Context(), testCase.timeout) defer cancel() - err := runTest(ctx, fd, tracker, dst, testCase.mtu) + err = runTest(ctx, dst, testCase.mtu, excludeMark, + fd, tracker, fw, noopLogger) if testCase.success { require.NoError(t, err) } else { @@ -230,4 +223,5 @@ func (l *noopLogger) Debug(_ string) {} func (l *noopLogger) Debugf(_ string, _ ...any) {} func (l *noopLogger) Info(_ string) {} func (l *noopLogger) Warn(_ string) {} +func (l *noopLogger) Warnf(_ string, _ ...any) {} func (l *noopLogger) Error(_ string) {}