hotfix(pmtud/tcp): block kernel from racing to send RST packets

- this makes PMTUD TCP reliable
- this only works on kernels with the mark module
- on kernels without the mark module, the icmp pmtud mtu found is used
This commit is contained in:
Quentin McGaw
2026-02-17 19:33:51 +00:00
parent 5f903d1fbf
commit 04d7cef294
15 changed files with 226 additions and 27 deletions
+1 -1
View File
@@ -1,2 +1,2 @@
FROM ghcr.io/qdm12/godevcontainer:v0.21-alpine FROM ghcr.io/qdm12/godevcontainer:v0.21-alpine
RUN apk add wireguard-tools htop openssl tcpdump RUN apk add wireguard-tools htop openssl tcpdump iptables
+1 -1
View File
@@ -13,7 +13,7 @@ FROM --platform=${BUILDPLATFORM} ghcr.io/qdm12/binpot:mockgen-${MOCKGEN_VERSION}
FROM --platform=${BUILDPLATFORM} golang:${GO_VERSION}-alpine${GO_ALPINE_VERSION} AS base FROM --platform=${BUILDPLATFORM} golang:${GO_VERSION}-alpine${GO_ALPINE_VERSION} AS base
COPY --from=xcputranslate /xcputranslate /usr/local/bin/xcputranslate COPY --from=xcputranslate /xcputranslate /usr/local/bin/xcputranslate
# Note: findutils needed to have xargs support `-d` flag for mocks stage. # Note: findutils needed to have xargs support `-d` flag for mocks stage.
RUN apk --update add git g++ findutils RUN apk --update add git g++ findutils iptables
ENV CGO_ENABLED=0 ENV CGO_ENABLED=0
COPY --from=golangci-lint /bin /go/bin/golangci-lint COPY --from=golangci-lint /bin /go/bin/golangci-lint
COPY --from=mockgen /bin /go/bin/mockgen COPY --from=mockgen /bin /go/bin/mockgen
+40
View File
@@ -31,6 +31,12 @@ type chainRule struct {
redirPorts []uint16 // Not specified if empty. redirPorts []uint16 // Not specified if empty.
ctstate []string // for example ["RELATED","ESTABLISHED"]. Can be empty. ctstate []string // for example ["RELATED","ESTABLISHED"]. Can be empty.
tcpFlags tcpFlags tcpFlags tcpFlags
mark mark
}
type mark struct {
invert bool
value uint
} }
var ErrChainListMalformed = errors.New("iptables chain list output is malformed") var ErrChainListMalformed = errors.New("iptables chain list output is malformed")
@@ -278,6 +284,14 @@ func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err
i++ i++
rule.ctstate = strings.Split(optionalFields[i], ",") rule.ctstate = strings.Split(optionalFields[i], ",")
i++ i++
case "mark":
i++
mark, consumed, err := parseMark(optionalFields[i:])
if err != nil {
return fmt.Errorf("parsing mark: %w", err)
}
rule.mark = mark
i += consumed
default: default:
return fmt.Errorf("%w: unexpected optional field: %s", return fmt.Errorf("%w: unexpected optional field: %s",
ErrChainRuleMalformed, optionalFields[i]) ErrChainRuleMalformed, optionalFields[i])
@@ -397,6 +411,32 @@ func parsePortsCSV(s string) (ports []uint16, err error) {
return ports, nil return ports, nil
} }
var errMarkValueMalformed = errors.New("mark value is malformed")
func parseMark(optionalFields []string) (m mark, consumed int, err error) {
switch optionalFields[consumed] {
case "match":
consumed++
if optionalFields[consumed] == "!" {
m.invert = true
consumed++
}
const base = 0 // auto-detect
const bits = 32
value, err := strconv.ParseUint(optionalFields[consumed], base, bits)
if err != nil {
return mark{}, 0, fmt.Errorf("%w: %s", errMarkValueMalformed, optionalFields[consumed])
}
m.value = uint(value)
consumed++
default:
return mark{}, 0, fmt.Errorf("%w: unexpected mark mode field: %s",
ErrChainRuleMalformed, optionalFields[consumed])
}
return m, consumed, nil
}
var ErrLineNumberIsZero = errors.New("line number is zero") var ErrLineNumberIsZero = errors.New("line number is zero")
func parseLineNumber(s string) (n uint16, err error) { func parseLineNumber(s string) (n uint16, err error) {
+28 -2
View File
@@ -23,6 +23,7 @@ type iptablesInstruction struct {
toPorts []uint16 // if empty, there is no redirection toPorts []uint16 // if empty, there is no redirection
ctstate []string // if empty, there is no ctstate ctstate []string // if empty, there is no ctstate
tcpFlags tcpFlags tcpFlags tcpFlags
mark mark
} }
func (i *iptablesInstruction) setDefaults() { func (i *iptablesInstruction) setDefaults() {
@@ -59,6 +60,8 @@ func (i *iptablesInstruction) equalToRule(table, chain string, rule chainRule) (
case !slices.Equal(i.tcpFlags.mask, rule.tcpFlags.mask) || case !slices.Equal(i.tcpFlags.mask, rule.tcpFlags.mask) ||
!slices.Equal(i.tcpFlags.comparison, rule.tcpFlags.comparison): !slices.Equal(i.tcpFlags.comparison, rule.tcpFlags.comparison):
return false return false
case i.mark != rule.mark:
return false
default: default:
return true return true
} }
@@ -100,7 +103,7 @@ func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (co
// All flags use one value after the flag, except the following: // All flags use one value after the flag, except the following:
switch flag { switch flag {
case "--tcp-flags": case "--tcp-flags": // -m can have 1 or 2 values
const expected = 3 const expected = 3
if len(fields) < expected { if len(fields) < expected {
return 0, fmt.Errorf("%w: flag %q requires at least 2 values, but got %s", return 0, fmt.Errorf("%w: flag %q requires at least 2 values, but got %s",
@@ -130,7 +133,30 @@ func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (co
instruction.target = value instruction.target = value
case "-p", "--protocol": case "-p", "--protocol":
instruction.protocol = value instruction.protocol = value
case "-m", "--match": // ignore match case "-m", "--match":
consumed = 2 // -m can have 1 or 2 values, so it consumes 2 or 3 fields.
switch value {
case "tcp", "udp": // for now ignore the protocol match since it's auto-loaded
case "mark":
switch fields[2] {
case "!":
consumed++
instruction.mark.invert = true
default:
return 0, fmt.Errorf("%w: unsupported match mark with value: %s",
ErrIptablesCommandMalformed, fields[2])
}
default:
return 0, fmt.Errorf("%w: unknown match value: %s", ErrIptablesCommandMalformed, value)
}
case "--mark":
const base = 0 // auto-detect
const bits = 32
value, err := strconv.ParseUint(value, base, bits)
if err != nil {
return 0, fmt.Errorf("parsing mark value %q: %w", fields[2], err)
}
instruction.mark.value = uint(value)
case "-i", "--in-interface": case "-i", "--in-interface":
instruction.inputInterface = value instruction.inputInterface = value
case "-o", "--out-interface": case "-o", "--out-interface":
+35
View File
@@ -1,8 +1,11 @@
package firewall package firewall
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"net/netip"
"os"
) )
type tcpFlags struct { type tcpFlags struct {
@@ -60,3 +63,35 @@ func parseTCPFlag(s string) (tcpFlag, error) {
} }
return 0, fmt.Errorf("%w: %s", errTCPFlagUnknown, s) return 0, fmt.Errorf("%w: %s", errTCPFlagUnknown, s)
} }
var ErrMarkMatchModuleMissing = errors.New("kernel is missing the mark module libxt_mark.so")
// TempDropOutputTCPRST temporarily drops outgoing TCP RST packets to the specified address and port,
// for any TCP packets not marked with the excludeMark given.
// This is necessary for TCP path MTU discovery to work, as the kernel will try to terminate the connection
// by sending a TCP RST packet, although we want to handle the connection manually.
func (c *Config) TempDropOutputTCPRST(ctx context.Context,
addrPort netip.AddrPort, excludeMark int) (
revert func(ctx context.Context) error, err error,
) {
_, err = os.Stat("/usr/lib/xtables/libxt_mark.so")
if err != nil && errors.Is(err, os.ErrNotExist) {
return nil, fmt.Errorf("%w", ErrMarkMatchModuleMissing)
}
const template = "%s OUTPUT -p tcp -d %s --dport %d --tcp-flags RST RST -m mark ! --mark %d -j DROP" //nolint:dupword
instruction := fmt.Sprintf(template, "--append", addrPort.Addr(), addrPort.Port(), excludeMark)
revertInstruction := fmt.Sprintf(template, "--delete", addrPort.Addr(), addrPort.Port(), excludeMark)
run := c.runIptablesInstruction
if addrPort.Addr().Is6() {
run = c.runIP6tablesInstruction
}
revert = func(ctx context.Context) error {
return run(ctx, revertInstruction)
}
err = run(ctx, instruction)
if err != nil {
return nil, fmt.Errorf("running instruction: %w", err)
}
return revert, nil
}
+15 -2
View File
@@ -7,11 +7,14 @@ import (
"net/netip" "net/netip"
"time" "time"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/pmtud/constants" "github.com/qdm12/gluetun/internal/pmtud/constants"
"github.com/qdm12/gluetun/internal/pmtud/icmp" "github.com/qdm12/gluetun/internal/pmtud/icmp"
"github.com/qdm12/gluetun/internal/pmtud/tcp" "github.com/qdm12/gluetun/internal/pmtud/tcp"
) )
var ErrPMTUDFailICMPAndTCP = errors.New("PMTUD failed with both ICMP and TCP")
// PathMTUDiscover discovers the maximum MTU using both ICMP and TCP. // PathMTUDiscover discovers the maximum MTU using both ICMP and TCP.
// Multiple ICMP addresses and TCP addresses can be specified for redundancy. // Multiple ICMP addresses and TCP addresses can be specified for redundancy.
// ICMP PMTUD is run first. If successful, the range of possible MTU values to // ICMP PMTUD is run first. If successful, the range of possible MTU values to
@@ -23,7 +26,7 @@ import (
// If the logger is nil, a no-op logger is used. // If the logger is nil, a no-op logger is used.
// It returns [ErrMTUNotFound] if the MTU could not be determined. // It returns [ErrMTUNotFound] if the MTU could not be determined.
func PathMTUDiscover(ctx context.Context, icmpAddrs []netip.Addr, tcpAddrs []netip.AddrPort, func PathMTUDiscover(ctx context.Context, icmpAddrs []netip.Addr, tcpAddrs []netip.AddrPort,
physicalLinkMTU uint32, tryTimeout time.Duration, logger Logger) ( physicalLinkMTU uint32, tryTimeout time.Duration, fw tcp.Firewall, logger Logger) (
mtu uint32, err error, mtu uint32, err error,
) { ) {
if physicalLinkMTU == 0 { if physicalLinkMTU == 0 {
@@ -67,13 +70,23 @@ func PathMTUDiscover(ctx context.Context, icmpAddrs []netip.Addr, tcpAddrs []net
const mtuMargin = 150 const mtuMargin = 150
minMTU = max(maxPossibleMTU-mtuMargin, minMTU) minMTU = max(maxPossibleMTU-mtuMargin, minMTU)
} }
mtu, err = tcp.PathMTUDiscover(ctx, addrPort, minMTU, maxPossibleMTU, logger) mtu, err = tcp.PathMTUDiscover(ctx, addrPort, minMTU, maxPossibleMTU, fw, logger)
if err != nil { if err != nil {
if errors.Is(err, firewall.ErrMarkMatchModuleMissing) {
logger.Debugf("aborting TCP path MTU discovery: %s", err)
if icmpSuccess {
return maxPossibleMTU, nil // only rely on ICMP PMTUD results
}
return 0, fmt.Errorf("%w", ErrPMTUDFailICMPAndTCP)
}
logger.Debugf("TCP path MTU discovery to %s failed: %s", addrPort, err) logger.Debugf("TCP path MTU discovery to %s failed: %s", addrPort, err)
continue continue
} }
logger.Debugf("TCP path MTU discovery to %s found maximum valid MTU %d", addrPort, mtu) logger.Debugf("TCP path MTU discovery to %s found maximum valid MTU %d", addrPort, mtu)
return mtu, nil return mtu, nil
} }
// TCP PMTUD failed for all addresses for external reasons,
// so do not take the risk and return an error.
return 0, fmt.Errorf("TCP path MTU discovery: last error: %w", err) return 0, fmt.Errorf("TCP path MTU discovery: last error: %w", err)
} }
+10
View File
@@ -1,5 +1,15 @@
package tcp package tcp
import (
"context"
"net/netip"
)
type Firewall interface {
TempDropOutputTCPRST(ctx context.Context, addrPort netip.AddrPort,
excludeMark int) (revert func(ctx context.Context) error, err error)
}
type Logger interface { type Logger interface {
Debug(msg string) Debug(msg string)
Debugf(msg string, args ...any) Debugf(msg string, args ...any)
+22 -4
View File
@@ -19,7 +19,25 @@ type testUnit struct {
} }
func PathMTUDiscover(ctx context.Context, addrPort netip.AddrPort, func PathMTUDiscover(ctx context.Context, addrPort netip.AddrPort,
minMTU, maxPossibleMTU uint32, logger Logger, minMTU, maxPossibleMTU uint32, firewall Firewall, logger Logger,
) (mtu uint32, err error) {
const excludeMark = 4325
revert, err := firewall.TempDropOutputTCPRST(ctx, addrPort, excludeMark)
if err != nil {
return 0, fmt.Errorf("temporarily dropping outgoing TCP RST packets: %w", err)
}
defer func() {
err := revert(ctx)
if err != nil {
logger.Warnf("reverting firewall changes: %s", err)
}
}()
return pathMTUDiscover(ctx, addrPort, minMTU, maxPossibleMTU, excludeMark, logger)
}
func pathMTUDiscover(ctx context.Context, addrPort netip.AddrPort,
minMTU, maxPossibleMTU uint32, excludeMark int, logger Logger,
) (mtu uint32, err error) { ) (mtu uint32, err error) {
mtusToTest := test.MakeMTUsToTest(minMTU, maxPossibleMTU) mtusToTest := test.MakeMTUsToTest(minMTU, maxPossibleMTU)
if len(mtusToTest) == 1 { // only minMTU because minMTU == maxPossibleMTU if len(mtusToTest) == 1 { // only minMTU because minMTU == maxPossibleMTU
@@ -36,7 +54,7 @@ func PathMTUDiscover(ctx context.Context, addrPort netip.AddrPort,
if addrPort.Addr().Is6() { if addrPort.Addr().Is6() {
family = constants.AF_INET6 family = constants.AF_INET6
} }
fd, stop, err := startRawSocket(family) fd, stop, err := startRawSocket(family, excludeMark)
if err != nil { if err != nil {
return 0, fmt.Errorf("starting raw socket: %w", err) return 0, fmt.Errorf("starting raw socket: %w", err)
} }
@@ -80,8 +98,8 @@ func PathMTUDiscover(ctx context.Context, addrPort netip.AddrPort,
if tests[i].ok { if tests[i].ok {
stop() stop()
cancel() cancel()
return PathMTUDiscover(ctx, addrPort, return pathMTUDiscover(ctx, addrPort,
tests[i].mtu, tests[i+1].mtu-1, logger) tests[i].mtu, tests[i+1].mtu-1, excludeMark, logger)
} }
} }
+7 -1
View File
@@ -10,12 +10,18 @@ import (
"github.com/qdm12/gluetun/internal/pmtud/ip" "github.com/qdm12/gluetun/internal/pmtud/ip"
) )
func startRawSocket(family int) (fd fileDescriptor, stop func(), err error) { func startRawSocket(family, excludeMark int) (fd fileDescriptor, stop func(), err error) {
fdPlatform, err := socket(family, constants.SOCK_RAW, constants.IPPROTO_TCP) fdPlatform, err := socket(family, constants.SOCK_RAW, constants.IPPROTO_TCP)
if err != nil { if err != nil {
return 0, nil, fmt.Errorf("creating raw socket: %w", err) return 0, nil, fmt.Errorf("creating raw socket: %w", err)
} }
err = setMark(fdPlatform, excludeMark)
if err != nil {
_ = closeSocket(fdPlatform)
return 0, nil, fmt.Errorf("setting mark option on raw socket: %w", err)
}
if family == constants.AF_INET { if family == constants.AF_INET {
err = ip.SetIPv4HeaderIncluded(fdPlatform) err = ip.SetIPv4HeaderIncluded(fdPlatform)
} else { } else {
+11
View File
@@ -2,6 +2,17 @@ package tcp
import "golang.org/x/sys/unix" import "golang.org/x/sys/unix"
// setMark sets a mark on each packets sent through this socket.
// This is used in conjunction with iptables to block outgoing kernel automated
// RST packets, since the kernel is not aware of us handling the connection manually.
// For example:
// iptables -A OUTPUT -p tcp --tcp-flags RST RST -m mark ! --mark 123 -j DROP
//
//nolint:dupword
func setMark(fd, excludeMark int) error {
return unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_MARK, excludeMark)
}
func setMTUDiscovery(fd int) error { func setMTUDiscovery(fd int) error {
return unix.SetsockoptInt(fd, unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_PROBE) return unix.SetsockoptInt(fd, unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_PROBE)
} }
+42 -13
View File
@@ -10,6 +10,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/qdm12/gluetun/internal/command"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/netlink" "github.com/qdm12/gluetun/internal/netlink"
"github.com/qdm12/gluetun/internal/pmtud/constants" "github.com/qdm12/gluetun/internal/pmtud/constants"
"github.com/qdm12/gluetun/internal/routing" "github.com/qdm12/gluetun/internal/routing"
@@ -22,9 +24,32 @@ import (
func Test_runTest(t *testing.T) { func Test_runTest(t *testing.T) {
t.Parallel() t.Parallel()
t.Skipf("temporarily skipping test") serverAddrs := map[string]netip.AddrPort{
"cloudflare-http": netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 80),
"cloudflare-https": netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443),
"google-https": netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 443),
}
noopLogger := &noopLogger{} noopLogger := &noopLogger{}
cmder := command.New()
fw, err := firewall.NewConfig(t.Context(), noopLogger, cmder, nil, nil)
if errors.Is(err, firewall.ErrIPTablesNotSupported) {
t.Skip("iptables not installed, skipping TCP PMTUD tests")
}
require.NoError(t, err, "creating firewall config")
// Prevent Kernel from sending RST packets back to servers
const excludeMark = 4324
for _, addrPort := range serverAddrs {
revert, err := fw.TempDropOutputTCPRST(t.Context(), addrPort, excludeMark)
require.NoError(t, err)
t.Cleanup(func() {
err := revert(context.Background())
assert.NoError(t, err)
})
}
netlinker := netlink.New(noopLogger) netlinker := netlink.New(noopLogger)
loopbackMTU, err := findLoopbackMTU(netlinker) loopbackMTU, err := findLoopbackMTU(netlinker)
require.NoError(t, err, "finding loopback IPv4 MTU") require.NoError(t, err, "finding loopback IPv4 MTU")
@@ -34,7 +59,7 @@ func Test_runTest(t *testing.T) {
ctx, cancel := context.WithCancel(t.Context()) ctx, cancel := context.WithCancel(t.Context())
const family = constants.AF_INET const family = constants.AF_INET
fd, stop, err := startRawSocket(family) fd, stop, err := startRawSocket(family, excludeMark)
require.NoError(t, err) require.NoError(t, err)
const ipv4 = true const ipv4 = true
@@ -44,6 +69,8 @@ func Test_runTest(t *testing.T) {
trackerCh <- tracker.listen(ctx) trackerCh <- tracker.listen(ctx)
}() }()
const mtuSafetyBuffer = 200
t.Cleanup(func() { t.Cleanup(func() {
stop() stop()
cancel() // stop listening cancel() // stop listening
@@ -72,30 +99,30 @@ func Test_runTest(t *testing.T) {
dst: func(_ *testing.T) netip.AddrPort { dst: func(_ *testing.T) netip.AddrPort {
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 12345) return netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 12345)
}, },
mtu: defaultIPv4MTU, mtu: defaultIPv4MTU - mtuSafetyBuffer,
}, },
"1.1.1.1:443": { "1.1.1.1:443": {
timeout: time.Second, timeout: 5 * time.Second,
dst: func(_ *testing.T) netip.AddrPort { dst: func(_ *testing.T) netip.AddrPort {
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443) return serverAddrs["cloudflare-https"]
}, },
mtu: defaultIPv4MTU, mtu: defaultIPv4MTU - mtuSafetyBuffer,
success: true, success: true,
}, },
"1.1.1.1:80": { "1.1.1.1:80": {
timeout: time.Second, timeout: 5 * time.Second,
dst: func(_ *testing.T) netip.AddrPort { dst: func(_ *testing.T) netip.AddrPort {
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 80) return serverAddrs["cloudflare-http"]
}, },
mtu: defaultIPv4MTU, mtu: defaultIPv4MTU - mtuSafetyBuffer,
success: true, success: true,
}, },
"8.8.8.8:443": { "8.8.8.8:443": {
timeout: time.Second, timeout: 5 * time.Second,
dst: func(_ *testing.T) netip.AddrPort { dst: func(_ *testing.T) netip.AddrPort {
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 443) return serverAddrs["google-https"]
}, },
mtu: defaultIPv4MTU, mtu: defaultIPv4MTU - mtuSafetyBuffer,
success: true, success: true,
}, },
} }
@@ -103,9 +130,11 @@ func Test_runTest(t *testing.T) {
for name, testCase := range testCases { for name, testCase := range testCases {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
t.Parallel() t.Parallel()
dst := testCase.dst(t)
ctx, cancel := context.WithTimeout(t.Context(), testCase.timeout) ctx, cancel := context.WithTimeout(t.Context(), testCase.timeout)
defer cancel() defer cancel()
dst := testCase.dst(t)
err := runTest(ctx, fd, tracker, dst, testCase.mtu) err := runTest(ctx, fd, tracker, dst, testCase.mtu)
if testCase.success { if testCase.success {
require.NoError(t, err) require.NoError(t, err)
+4
View File
@@ -2,6 +2,10 @@
package tcp package tcp
func setMark(fd, excludeMark int) error {
panic("not implemented")
}
func setMTUDiscovery(fd int) error { func setMTUDiscovery(fd int) error {
panic("not implemented") panic("not implemented")
} }
+4
View File
@@ -31,6 +31,10 @@ func recvFrom(fd fileDescriptor, p []byte, flags int) (n int, from windows.Socka
return windows.Recvfrom(windows.Handle(fd), p, flags) return windows.Recvfrom(windows.Handle(fd), p, flags)
} }
func setMark(fd windows.Handle, _ int) error {
panic("not implemented")
}
func setMTUDiscovery(fd windows.Handle) error { func setMTUDiscovery(fd windows.Handle) error {
panic("not implemented") panic("not implemented")
} }
+2
View File
@@ -8,6 +8,7 @@ import (
"github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/netlink" "github.com/qdm12/gluetun/internal/netlink"
"github.com/qdm12/gluetun/internal/pmtud/tcp"
portforward "github.com/qdm12/gluetun/internal/portforward" portforward "github.com/qdm12/gluetun/internal/portforward"
"github.com/qdm12/gluetun/internal/provider" "github.com/qdm12/gluetun/internal/provider"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
@@ -17,6 +18,7 @@ type Firewall interface {
SetVPNConnection(ctx context.Context, connection models.Connection, interfaceName string) error SetVPNConnection(ctx context.Context, connection models.Connection, interfaceName string) error
SetAllowedPort(ctx context.Context, port uint16, interfaceName string) error SetAllowedPort(ctx context.Context, port uint16, interfaceName string) error
RemoveAllowedPort(ctx context.Context, port uint16) error RemoveAllowedPort(ctx context.Context, port uint16) error
tcp.Firewall
} }
type Routing interface { type Routing interface {
+4 -3
View File
@@ -10,6 +10,7 @@ import (
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/pmtud" "github.com/qdm12/gluetun/internal/pmtud"
pconstants "github.com/qdm12/gluetun/internal/pmtud/constants" pconstants "github.com/qdm12/gluetun/internal/pmtud/constants"
"github.com/qdm12/gluetun/internal/pmtud/tcp"
"github.com/qdm12/gluetun/internal/version" "github.com/qdm12/gluetun/internal/version"
"github.com/qdm12/log" "github.com/qdm12/log"
) )
@@ -58,7 +59,7 @@ func (l *Loop) onTunnelUp(ctx, loopCtx context.Context, data tunnelUpData) {
mtuLogger := l.logger.New(log.SetComponent("MTU discovery")) mtuLogger := l.logger.New(log.SetComponent("MTU discovery"))
err := updateToMaxMTU(ctx, data.vpnIntf, data.pmtud.vpnType, err := updateToMaxMTU(ctx, data.vpnIntf, data.pmtud.vpnType,
data.pmtud.network, data.pmtud.icmpAddrs, data.pmtud.tcpAddrs, data.pmtud.network, data.pmtud.icmpAddrs, data.pmtud.tcpAddrs,
l.netLinker, l.routing, mtuLogger) l.netLinker, l.routing, l.fw, mtuLogger)
if err != nil { if err != nil {
mtuLogger.Error(err.Error()) mtuLogger.Error(err.Error())
} }
@@ -156,7 +157,7 @@ func (l *Loop) restartVPN(ctx context.Context, healthErr error) {
func updateToMaxMTU(ctx context.Context, vpnInterface string, func updateToMaxMTU(ctx context.Context, vpnInterface string,
vpnType, network string, icmpAddrs []netip.Addr, tcpAddrs []netip.AddrPort, vpnType, network string, icmpAddrs []netip.Addr, tcpAddrs []netip.AddrPort,
netlinker NetLinker, routing Routing, logger *log.Logger, netlinker NetLinker, routing Routing, firewall tcp.Firewall, logger *log.Logger,
) error { ) error {
logger.Info("finding maximum MTU, this can take up to 6 seconds") logger.Info("finding maximum MTU, this can take up to 6 seconds")
@@ -185,7 +186,7 @@ func updateToMaxMTU(ctx context.Context, vpnInterface string,
const pingTimeout = time.Second const pingTimeout = time.Second
vpnLinkMTU, err = pmtud.PathMTUDiscover(ctx, icmpAddrs, tcpAddrs, vpnLinkMTU, err = pmtud.PathMTUDiscover(ctx, icmpAddrs, tcpAddrs,
vpnLinkMTU, pingTimeout, logger) vpnLinkMTU, pingTimeout, firewall, logger)
if err != nil { if err != nil {
vpnLinkMTU = originalMTU vpnLinkMTU = originalMTU
logger.Infof("reverting VPN interface %s MTU to %d (due to: %s)", logger.Infof("reverting VPN interface %s MTU to %d (due to: %s)",