diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index 8ac79bf4..5aa8f60d 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -1,2 +1,2 @@ FROM ghcr.io/qdm12/godevcontainer:v0.21-alpine -RUN apk add wireguard-tools htop openssl tcpdump +RUN apk add wireguard-tools htop openssl tcpdump iptables diff --git a/Dockerfile b/Dockerfile index be4f3f82..e5dcf2cb 100644 --- a/Dockerfile +++ b/Dockerfile @@ -13,7 +13,7 @@ FROM --platform=${BUILDPLATFORM} ghcr.io/qdm12/binpot:mockgen-${MOCKGEN_VERSION} FROM --platform=${BUILDPLATFORM} golang:${GO_VERSION}-alpine${GO_ALPINE_VERSION} AS base COPY --from=xcputranslate /xcputranslate /usr/local/bin/xcputranslate # Note: findutils needed to have xargs support `-d` flag for mocks stage. -RUN apk --update add git g++ findutils +RUN apk --update add git g++ findutils iptables ENV CGO_ENABLED=0 COPY --from=golangci-lint /bin /go/bin/golangci-lint COPY --from=mockgen /bin /go/bin/mockgen diff --git a/internal/firewall/list.go b/internal/firewall/list.go index 75e1955a..eb618f32 100644 --- a/internal/firewall/list.go +++ b/internal/firewall/list.go @@ -31,6 +31,12 @@ type chainRule struct { redirPorts []uint16 // Not specified if empty. ctstate []string // for example ["RELATED","ESTABLISHED"]. Can be empty. 
tcpFlags tcpFlags + mark mark +} + +type mark struct { + invert bool + value uint } var ErrChainListMalformed = errors.New("iptables chain list output is malformed") @@ -278,6 +284,14 @@ func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err i++ rule.ctstate = strings.Split(optionalFields[i], ",") i++ + case "mark": + i++ + mark, consumed, err := parseMark(optionalFields[i:]) + if err != nil { + return fmt.Errorf("parsing mark: %w", err) + } + rule.mark = mark + i += consumed default: return fmt.Errorf("%w: unexpected optional field: %s", ErrChainRuleMalformed, optionalFields[i]) @@ -397,6 +411,40 @@ func parsePortsCSV(s string) (ports []uint16, err error) { return ports, nil } +var errMarkValueMalformed = errors.New("mark value is malformed") + +// parseMark parses the fields following the "mark" keyword of a listed rule, +// for example ["match", "!", "0x1"], and returns the mark and the number of fields consumed. +func parseMark(optionalFields []string) (m mark, consumed int, err error) { + if len(optionalFields) == 0 { + return mark{}, 0, fmt.Errorf("%w: missing mark mode field", ErrChainRuleMalformed) + } + switch optionalFields[consumed] { + case "match": + consumed++ + if consumed < len(optionalFields) && optionalFields[consumed] == "!" { + m.invert = true + consumed++ + } + if consumed >= len(optionalFields) { + return mark{}, 0, fmt.Errorf("%w: value is missing", errMarkValueMalformed) + } + + const base = 0 // auto-detect + const bits = 32 + value, err := strconv.ParseUint(optionalFields[consumed], base, bits) + if err != nil { + return mark{}, 0, fmt.Errorf("%w: %s", errMarkValueMalformed, optionalFields[consumed]) + } + m.value = uint(value) + consumed++ + default: + return mark{}, 0, fmt.Errorf("%w: unexpected mark mode field: %s", + ErrChainRuleMalformed, optionalFields[consumed]) + } + return m, consumed, nil +} + var ErrLineNumberIsZero = errors.New("line number is zero") func parseLineNumber(s string) (n uint16, err error) { diff --git a/internal/firewall/parse.go b/internal/firewall/parse.go index c5cb2d88..be454f57 100644 --- a/internal/firewall/parse.go +++ b/internal/firewall/parse.go @@ -23,6 +23,7 @@ type iptablesInstruction struct { toPorts []uint16 // if empty, there is no redirection ctstate []string // if empty, there is no ctstate tcpFlags tcpFlags + mark mark } func (i *iptablesInstruction) setDefaults() { @@ -59,6 +60,8 @@ func (i
*iptablesInstruction) equalToRule(table, chain string, rule chainRule) ( case !slices.Equal(i.tcpFlags.mask, rule.tcpFlags.mask) || !slices.Equal(i.tcpFlags.comparison, rule.tcpFlags.comparison): return false + case i.mark != rule.mark: + return false default: return true } @@ -100,7 +103,7 @@ func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (co // All flags use one value after the flag, except the following: switch flag { - case "--tcp-flags": + case "--tcp-flags": // --tcp-flags takes 2 values (mask and comparison) const expected = 3 if len(fields) < expected { return 0, fmt.Errorf("%w: flag %q requires at least 2 values, but got %s", @@ -130,7 +133,34 @@ func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (co instruction.target = value case "-p", "--protocol": instruction.protocol = value - case "-m", "--match": // ignore match + case "-m", "--match": + consumed = 2 // -m can have 1 or 2 values, so it consumes 2 or 3 fields. + switch value { + case "tcp", "udp": // for now ignore the protocol match since it's auto-loaded + case "mark": + const expected = 3 + if len(fields) < expected { + return 0, fmt.Errorf("%w: match mark requires a value after it", ErrIptablesCommandMalformed) + } + switch fields[2] { + case "!": + consumed++ + instruction.mark.invert = true + default: + return 0, fmt.Errorf("%w: unsupported match mark with value: %s", + ErrIptablesCommandMalformed, fields[2]) + } + default: + return 0, fmt.Errorf("%w: unknown match value: %s", ErrIptablesCommandMalformed, value) + } + case "--mark": + const base = 0 // auto-detect + const bits = 32 + value, err := strconv.ParseUint(value, base, bits) + if err != nil { + return 0, fmt.Errorf("parsing mark value %q: %w", fields[1], err) + } + instruction.mark.value = uint(value) case "-i", "--in-interface": instruction.inputInterface = value case "-o", "--out-interface": diff --git a/internal/firewall/tcp.go b/internal/firewall/tcp.go index 7c38f9ca..07f276ba 100644 --- a/internal/firewall/tcp.go +++ b/internal/firewall/tcp.go @@ -1,8 +1,11 @@ package firewall import ( + "context" "errors" "fmt" + "net/netip" + "os" ) type
tcpFlags struct { @@ -60,3 +63,35 @@ func parseTCPFlag(s string) (tcpFlag, error) { } return 0, fmt.Errorf("%w: %s", errTCPFlagUnknown, s) } + +var ErrMarkMatchModuleMissing = errors.New("kernel is missing the mark module libxt_mark.so") + +// TempDropOutputTCPRST temporarily drops outgoing TCP RST packets to the specified address and port, +// for any TCP packets not marked with the excludeMark given. +// This is necessary for TCP path MTU discovery to work, as the kernel will try to terminate the connection +// by sending a TCP RST packet, although we want to handle the connection manually. +func (c *Config) TempDropOutputTCPRST(ctx context.Context, + addrPort netip.AddrPort, excludeMark int) ( + revert func(ctx context.Context) error, err error, +) { + _, err = os.Stat("/usr/lib/xtables/libxt_mark.so") + if err != nil && errors.Is(err, os.ErrNotExist) { + return nil, fmt.Errorf("%w", ErrMarkMatchModuleMissing) + } + + const template = "%s OUTPUT -p tcp -d %s --dport %d --tcp-flags RST RST -m mark ! 
--mark %d -j DROP" //nolint:dupword + instruction := fmt.Sprintf(template, "--append", addrPort.Addr(), addrPort.Port(), excludeMark) + revertInstruction := fmt.Sprintf(template, "--delete", addrPort.Addr(), addrPort.Port(), excludeMark) + run := c.runIptablesInstruction + if addrPort.Addr().Is6() { + run = c.runIP6tablesInstruction + } + revert = func(ctx context.Context) error { + return run(ctx, revertInstruction) + } + err = run(ctx, instruction) + if err != nil { + return nil, fmt.Errorf("running instruction: %w", err) + } + return revert, nil +} diff --git a/internal/pmtud/pmtud.go b/internal/pmtud/pmtud.go index 7d07f427..b102d552 100644 --- a/internal/pmtud/pmtud.go +++ b/internal/pmtud/pmtud.go @@ -7,11 +7,14 @@ import ( "net/netip" "time" + "github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/pmtud/constants" "github.com/qdm12/gluetun/internal/pmtud/icmp" "github.com/qdm12/gluetun/internal/pmtud/tcp" ) +var ErrPMTUDFailICMPAndTCP = errors.New("PMTUD failed with both ICMP and TCP") + // PathMTUDiscover discovers the maximum MTU using both ICMP and TCP. // Multiple ICMP addresses and TCP addresses can be specified for redundancy. // ICMP PMTUD is run first. If successful, the range of possible MTU values to @@ -23,7 +26,7 @@ import ( // If the logger is nil, a no-op logger is used. // It returns [ErrMTUNotFound] if the MTU could not be determined. 
func PathMTUDiscover(ctx context.Context, icmpAddrs []netip.Addr, tcpAddrs []netip.AddrPort, - physicalLinkMTU uint32, tryTimeout time.Duration, logger Logger) ( + physicalLinkMTU uint32, tryTimeout time.Duration, fw tcp.Firewall, logger Logger) ( mtu uint32, err error, ) { if physicalLinkMTU == 0 { @@ -67,13 +70,23 @@ func PathMTUDiscover(ctx context.Context, icmpAddrs []netip.Addr, tcpAddrs []net const mtuMargin = 150 minMTU = max(maxPossibleMTU-mtuMargin, minMTU) } - mtu, err = tcp.PathMTUDiscover(ctx, addrPort, minMTU, maxPossibleMTU, logger) + mtu, err = tcp.PathMTUDiscover(ctx, addrPort, minMTU, maxPossibleMTU, fw, logger) if err != nil { + if errors.Is(err, firewall.ErrMarkMatchModuleMissing) { + logger.Debugf("aborting TCP path MTU discovery: %s", err) + if icmpSuccess { + return maxPossibleMTU, nil // only rely on ICMP PMTUD results + } + return 0, fmt.Errorf("%w", ErrPMTUDFailICMPAndTCP) + } logger.Debugf("TCP path MTU discovery to %s failed: %s", addrPort, err) continue } logger.Debugf("TCP path MTU discovery to %s found maximum valid MTU %d", addrPort, mtu) return mtu, nil } + + // TCP PMTUD failed for all addresses for external reasons, + // so do not take the risk and return an error. 
return 0, fmt.Errorf("TCP path MTU discovery: last error: %w", err) } diff --git a/internal/pmtud/tcp/interfaces.go b/internal/pmtud/tcp/interfaces.go index 2709d75f..a54fd496 100644 --- a/internal/pmtud/tcp/interfaces.go +++ b/internal/pmtud/tcp/interfaces.go @@ -1,5 +1,15 @@ package tcp +import ( + "context" + "net/netip" +) + +type Firewall interface { + TempDropOutputTCPRST(ctx context.Context, addrPort netip.AddrPort, + excludeMark int) (revert func(ctx context.Context) error, err error) +} + type Logger interface { Debug(msg string) Debugf(msg string, args ...any) diff --git a/internal/pmtud/tcp/multi.go b/internal/pmtud/tcp/multi.go index 0849daca..1a150865 100644 --- a/internal/pmtud/tcp/multi.go +++ b/internal/pmtud/tcp/multi.go @@ -19,7 +19,25 @@ type testUnit struct { } func PathMTUDiscover(ctx context.Context, addrPort netip.AddrPort, - minMTU, maxPossibleMTU uint32, logger Logger, + minMTU, maxPossibleMTU uint32, firewall Firewall, logger Logger, +) (mtu uint32, err error) { + const excludeMark = 4325 + revert, err := firewall.TempDropOutputTCPRST(ctx, addrPort, excludeMark) + if err != nil { + return 0, fmt.Errorf("temporarily dropping outgoing TCP RST packets: %w", err) + } + defer func() { + err := revert(context.WithoutCancel(ctx)) // ctx may already be canceled; the DROP rule must still be removed + if err != nil { + logger.Warnf("reverting firewall changes: %s", err) + } + }() + + return pathMTUDiscover(ctx, addrPort, minMTU, maxPossibleMTU, excludeMark, logger) +} + +func pathMTUDiscover(ctx context.Context, addrPort netip.AddrPort, + minMTU, maxPossibleMTU uint32, excludeMark int, logger Logger, ) (mtu uint32, err error) { mtusToTest := test.MakeMTUsToTest(minMTU, maxPossibleMTU) if len(mtusToTest) == 1 { // only minMTU because minMTU == maxPossibleMTU @@ -36,7 +54,7 @@ func PathMTUDiscover(ctx context.Context, addrPort netip.AddrPort, if addrPort.Addr().Is6() { family = constants.AF_INET6 } - fd, stop, err := startRawSocket(family) + fd, stop, err := startRawSocket(family, excludeMark) if err != nil { return 0, fmt.Errorf("starting
raw socket: %w", err) } @@ -80,8 +98,8 @@ func PathMTUDiscover(ctx context.Context, addrPort netip.AddrPort, if tests[i].ok { stop() cancel() - return PathMTUDiscover(ctx, addrPort, - tests[i].mtu, tests[i+1].mtu-1, logger) + return pathMTUDiscover(ctx, addrPort, + tests[i].mtu, tests[i+1].mtu-1, excludeMark, logger) } } diff --git a/internal/pmtud/tcp/tcp.go b/internal/pmtud/tcp/tcp.go index 0e5c128f..c0676e16 100644 --- a/internal/pmtud/tcp/tcp.go +++ b/internal/pmtud/tcp/tcp.go @@ -10,12 +10,18 @@ import ( "github.com/qdm12/gluetun/internal/pmtud/ip" ) -func startRawSocket(family int) (fd fileDescriptor, stop func(), err error) { +func startRawSocket(family, excludeMark int) (fd fileDescriptor, stop func(), err error) { fdPlatform, err := socket(family, constants.SOCK_RAW, constants.IPPROTO_TCP) if err != nil { return 0, nil, fmt.Errorf("creating raw socket: %w", err) } + err = setMark(fdPlatform, excludeMark) + if err != nil { + _ = closeSocket(fdPlatform) + return 0, nil, fmt.Errorf("setting mark option on raw socket: %w", err) + } + if family == constants.AF_INET { err = ip.SetIPv4HeaderIncluded(fdPlatform) } else { diff --git a/internal/pmtud/tcp/tcp_linux.go b/internal/pmtud/tcp/tcp_linux.go index 69bfa902..b751fe1e 100644 --- a/internal/pmtud/tcp/tcp_linux.go +++ b/internal/pmtud/tcp/tcp_linux.go @@ -2,6 +2,17 @@ package tcp import "golang.org/x/sys/unix" +// setMark sets a mark on each packet sent through this socket. +// This is used in conjunction with iptables to block outgoing kernel-generated +// RST packets, since the kernel is not aware of us handling the connection manually. +// For example: +// iptables -A OUTPUT -p tcp --tcp-flags RST RST -m mark !
--mark 123 -j DROP +// +//nolint:dupword +func setMark(fd, excludeMark int) error { + return unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_MARK, excludeMark) +} + func setMTUDiscovery(fd int) error { return unix.SetsockoptInt(fd, unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_PROBE) } diff --git a/internal/pmtud/tcp/tcp_test.go b/internal/pmtud/tcp/tcp_test.go index 2abe72a4..81734214 100644 --- a/internal/pmtud/tcp/tcp_test.go +++ b/internal/pmtud/tcp/tcp_test.go @@ -10,6 +10,8 @@ import ( "testing" "time" + "github.com/qdm12/gluetun/internal/command" + "github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/netlink" "github.com/qdm12/gluetun/internal/pmtud/constants" "github.com/qdm12/gluetun/internal/routing" @@ -22,9 +24,32 @@ import ( func Test_runTest(t *testing.T) { t.Parallel() - t.Skipf("temporarily skipping test") + serverAddrs := map[string]netip.AddrPort{ + "cloudflare-http": netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 80), + "cloudflare-https": netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443), + "google-https": netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 443), + } noopLogger := &noopLogger{} + + cmder := command.New() + fw, err := firewall.NewConfig(t.Context(), noopLogger, cmder, nil, nil) + if errors.Is(err, firewall.ErrIPTablesNotSupported) { + t.Skip("iptables not installed, skipping TCP PMTUD tests") + } + require.NoError(t, err, "creating firewall config") + + // Prevent Kernel from sending RST packets back to servers + const excludeMark = 4324 + for _, addrPort := range serverAddrs { + revert, err := fw.TempDropOutputTCPRST(t.Context(), addrPort, excludeMark) + require.NoError(t, err) + t.Cleanup(func() { + err := revert(context.Background()) + assert.NoError(t, err) + }) + } + netlinker := netlink.New(noopLogger) loopbackMTU, err := findLoopbackMTU(netlinker) require.NoError(t, err, "finding loopback IPv4 MTU") @@ -34,7 +59,7 @@ func Test_runTest(t *testing.T) { ctx, 
cancel := context.WithCancel(t.Context()) const family = constants.AF_INET - fd, stop, err := startRawSocket(family) + fd, stop, err := startRawSocket(family, excludeMark) require.NoError(t, err) const ipv4 = true @@ -44,6 +69,8 @@ func Test_runTest(t *testing.T) { trackerCh <- tracker.listen(ctx) }() + const mtuSafetyBuffer = 200 + t.Cleanup(func() { stop() cancel() // stop listening @@ -72,30 +99,30 @@ func Test_runTest(t *testing.T) { dst: func(_ *testing.T) netip.AddrPort { return netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 12345) }, - mtu: defaultIPv4MTU, + mtu: defaultIPv4MTU - mtuSafetyBuffer, }, "1.1.1.1:443": { - timeout: time.Second, + timeout: 5 * time.Second, dst: func(_ *testing.T) netip.AddrPort { - return netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443) + return serverAddrs["cloudflare-https"] }, - mtu: defaultIPv4MTU, + mtu: defaultIPv4MTU - mtuSafetyBuffer, success: true, }, "1.1.1.1:80": { - timeout: time.Second, + timeout: 5 * time.Second, dst: func(_ *testing.T) netip.AddrPort { - return netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 80) + return serverAddrs["cloudflare-http"] }, - mtu: defaultIPv4MTU, + mtu: defaultIPv4MTU - mtuSafetyBuffer, success: true, }, "8.8.8.8:443": { - timeout: time.Second, + timeout: 5 * time.Second, dst: func(_ *testing.T) netip.AddrPort { - return netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 443) + return serverAddrs["google-https"] }, - mtu: defaultIPv4MTU, + mtu: defaultIPv4MTU - mtuSafetyBuffer, success: true, }, } @@ -103,9 +130,11 @@ func Test_runTest(t *testing.T) { for name, testCase := range testCases { t.Run(name, func(t *testing.T) { t.Parallel() + + dst := testCase.dst(t) + ctx, cancel := context.WithTimeout(t.Context(), testCase.timeout) defer cancel() - dst := testCase.dst(t) err := runTest(ctx, fd, tracker, dst, testCase.mtu) if testCase.success { require.NoError(t, err) diff --git a/internal/pmtud/tcp/tcp_unspecified.go 
b/internal/pmtud/tcp/tcp_unspecified.go index ff50b22e..8943e9d1 100644 --- a/internal/pmtud/tcp/tcp_unspecified.go +++ b/internal/pmtud/tcp/tcp_unspecified.go @@ -2,6 +2,10 @@ package tcp +func setMark(fd, excludeMark int) error { + panic("not implemented") +} + func setMTUDiscovery(fd int) error { panic("not implemented") } diff --git a/internal/pmtud/tcp/tcp_windows.go b/internal/pmtud/tcp/tcp_windows.go index 842093f5..eec7ff8f 100644 --- a/internal/pmtud/tcp/tcp_windows.go +++ b/internal/pmtud/tcp/tcp_windows.go @@ -31,6 +31,10 @@ func recvFrom(fd fileDescriptor, p []byte, flags int) (n int, from windows.Socka return windows.Recvfrom(windows.Handle(fd), p, flags) } +func setMark(fd windows.Handle, _ int) error { + panic("not implemented") +} + func setMTUDiscovery(fd windows.Handle) error { panic("not implemented") } diff --git a/internal/vpn/interfaces.go b/internal/vpn/interfaces.go index b2a3618c..3238c2e2 100644 --- a/internal/vpn/interfaces.go +++ b/internal/vpn/interfaces.go @@ -8,6 +8,7 @@ import ( "github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/netlink" + "github.com/qdm12/gluetun/internal/pmtud/tcp" portforward "github.com/qdm12/gluetun/internal/portforward" "github.com/qdm12/gluetun/internal/provider" "github.com/qdm12/gluetun/internal/provider/utils" @@ -17,6 +18,7 @@ type Firewall interface { SetVPNConnection(ctx context.Context, connection models.Connection, interfaceName string) error SetAllowedPort(ctx context.Context, port uint16, interfaceName string) error RemoveAllowedPort(ctx context.Context, port uint16) error + tcp.Firewall } type Routing interface { diff --git a/internal/vpn/tunnelup.go b/internal/vpn/tunnelup.go index 6af4df60..a6be2aef 100644 --- a/internal/vpn/tunnelup.go +++ b/internal/vpn/tunnelup.go @@ -10,6 +10,7 @@ import ( "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/pmtud" pconstants 
"github.com/qdm12/gluetun/internal/pmtud/constants" + "github.com/qdm12/gluetun/internal/pmtud/tcp" "github.com/qdm12/gluetun/internal/version" "github.com/qdm12/log" ) @@ -58,7 +59,7 @@ func (l *Loop) onTunnelUp(ctx, loopCtx context.Context, data tunnelUpData) { mtuLogger := l.logger.New(log.SetComponent("MTU discovery")) err := updateToMaxMTU(ctx, data.vpnIntf, data.pmtud.vpnType, data.pmtud.network, data.pmtud.icmpAddrs, data.pmtud.tcpAddrs, - l.netLinker, l.routing, mtuLogger) + l.netLinker, l.routing, l.fw, mtuLogger) if err != nil { mtuLogger.Error(err.Error()) } @@ -156,7 +157,7 @@ func (l *Loop) restartVPN(ctx context.Context, healthErr error) { func updateToMaxMTU(ctx context.Context, vpnInterface string, vpnType, network string, icmpAddrs []netip.Addr, tcpAddrs []netip.AddrPort, - netlinker NetLinker, routing Routing, logger *log.Logger, + netlinker NetLinker, routing Routing, firewall tcp.Firewall, logger *log.Logger, ) error { logger.Info("finding maximum MTU, this can take up to 6 seconds") @@ -185,7 +186,7 @@ func updateToMaxMTU(ctx context.Context, vpnInterface string, const pingTimeout = time.Second vpnLinkMTU, err = pmtud.PathMTUDiscover(ctx, icmpAddrs, tcpAddrs, - vpnLinkMTU, pingTimeout, logger) + vpnLinkMTU, pingTimeout, firewall, logger) if err != nil { vpnLinkMTU = originalMTU logger.Infof("reverting VPN interface %s MTU to %d (due to: %s)",