mirror of
https://github.com/qdm12/gluetun.git
synced 2026-05-07 04:20:12 +02:00
hotfix(pmtud/tcp): block kernel from racing to send RST packets
- this makes PMTUD TCP reliable - this only works on kernels with the mark module - on kernels without the mark module, the icmp pmtud mtu found is used
This commit is contained in:
@@ -1,5 +1,15 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
type Firewall interface {
|
||||
TempDropOutputTCPRST(ctx context.Context, addrPort netip.AddrPort,
|
||||
excludeMark int) (revert func(ctx context.Context) error, err error)
|
||||
}
|
||||
|
||||
type Logger interface {
|
||||
Debug(msg string)
|
||||
Debugf(msg string, args ...any)
|
||||
|
||||
@@ -19,7 +19,25 @@ type testUnit struct {
|
||||
}
|
||||
|
||||
func PathMTUDiscover(ctx context.Context, addrPort netip.AddrPort,
|
||||
minMTU, maxPossibleMTU uint32, logger Logger,
|
||||
minMTU, maxPossibleMTU uint32, firewall Firewall, logger Logger,
|
||||
) (mtu uint32, err error) {
|
||||
const excludeMark = 4325
|
||||
revert, err := firewall.TempDropOutputTCPRST(ctx, addrPort, excludeMark)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("temporarily dropping outgoing TCP RST packets: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
err := revert(ctx)
|
||||
if err != nil {
|
||||
logger.Warnf("reverting firewall changes: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return pathMTUDiscover(ctx, addrPort, minMTU, maxPossibleMTU, excludeMark, logger)
|
||||
}
|
||||
|
||||
func pathMTUDiscover(ctx context.Context, addrPort netip.AddrPort,
|
||||
minMTU, maxPossibleMTU uint32, excludeMark int, logger Logger,
|
||||
) (mtu uint32, err error) {
|
||||
mtusToTest := test.MakeMTUsToTest(minMTU, maxPossibleMTU)
|
||||
if len(mtusToTest) == 1 { // only minMTU because minMTU == maxPossibleMTU
|
||||
@@ -36,7 +54,7 @@ func PathMTUDiscover(ctx context.Context, addrPort netip.AddrPort,
|
||||
if addrPort.Addr().Is6() {
|
||||
family = constants.AF_INET6
|
||||
}
|
||||
fd, stop, err := startRawSocket(family)
|
||||
fd, stop, err := startRawSocket(family, excludeMark)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("starting raw socket: %w", err)
|
||||
}
|
||||
@@ -80,8 +98,8 @@ func PathMTUDiscover(ctx context.Context, addrPort netip.AddrPort,
|
||||
if tests[i].ok {
|
||||
stop()
|
||||
cancel()
|
||||
return PathMTUDiscover(ctx, addrPort,
|
||||
tests[i].mtu, tests[i+1].mtu-1, logger)
|
||||
return pathMTUDiscover(ctx, addrPort,
|
||||
tests[i].mtu, tests[i+1].mtu-1, excludeMark, logger)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -10,12 +10,18 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/pmtud/ip"
|
||||
)
|
||||
|
||||
func startRawSocket(family int) (fd fileDescriptor, stop func(), err error) {
|
||||
func startRawSocket(family, excludeMark int) (fd fileDescriptor, stop func(), err error) {
|
||||
fdPlatform, err := socket(family, constants.SOCK_RAW, constants.IPPROTO_TCP)
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("creating raw socket: %w", err)
|
||||
}
|
||||
|
||||
err = setMark(fdPlatform, excludeMark)
|
||||
if err != nil {
|
||||
_ = closeSocket(fdPlatform)
|
||||
return 0, nil, fmt.Errorf("setting mark option on raw socket: %w", err)
|
||||
}
|
||||
|
||||
if family == constants.AF_INET {
|
||||
err = ip.SetIPv4HeaderIncluded(fdPlatform)
|
||||
} else {
|
||||
|
||||
@@ -2,6 +2,17 @@ package tcp
|
||||
|
||||
import "golang.org/x/sys/unix"
|
||||
|
||||
// setMark sets a mark on each packets sent through this socket.
|
||||
// This is used in conjunction with iptables to block outgoing kernel automated
|
||||
// RST packets, since the kernel is not aware of us handling the connection manually.
|
||||
// For example:
|
||||
// iptables -A OUTPUT -p tcp --tcp-flags RST RST -m mark ! --mark 123 -j DROP
|
||||
//
|
||||
//nolint:dupword
|
||||
func setMark(fd, excludeMark int) error {
|
||||
return unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_MARK, excludeMark)
|
||||
}
|
||||
|
||||
func setMTUDiscovery(fd int) error {
|
||||
return unix.SetsockoptInt(fd, unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_PROBE)
|
||||
}
|
||||
|
||||
@@ -10,6 +10,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/command"
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/gluetun/internal/netlink"
|
||||
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||
"github.com/qdm12/gluetun/internal/routing"
|
||||
@@ -22,9 +24,32 @@ import (
|
||||
func Test_runTest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Skipf("temporarily skipping test")
|
||||
serverAddrs := map[string]netip.AddrPort{
|
||||
"cloudflare-http": netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 80),
|
||||
"cloudflare-https": netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443),
|
||||
"google-https": netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 443),
|
||||
}
|
||||
|
||||
noopLogger := &noopLogger{}
|
||||
|
||||
cmder := command.New()
|
||||
fw, err := firewall.NewConfig(t.Context(), noopLogger, cmder, nil, nil)
|
||||
if errors.Is(err, firewall.ErrIPTablesNotSupported) {
|
||||
t.Skip("iptables not installed, skipping TCP PMTUD tests")
|
||||
}
|
||||
require.NoError(t, err, "creating firewall config")
|
||||
|
||||
// Prevent Kernel from sending RST packets back to servers
|
||||
const excludeMark = 4324
|
||||
for _, addrPort := range serverAddrs {
|
||||
revert, err := fw.TempDropOutputTCPRST(t.Context(), addrPort, excludeMark)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
err := revert(context.Background())
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
netlinker := netlink.New(noopLogger)
|
||||
loopbackMTU, err := findLoopbackMTU(netlinker)
|
||||
require.NoError(t, err, "finding loopback IPv4 MTU")
|
||||
@@ -34,7 +59,7 @@ func Test_runTest(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
|
||||
const family = constants.AF_INET
|
||||
fd, stop, err := startRawSocket(family)
|
||||
fd, stop, err := startRawSocket(family, excludeMark)
|
||||
require.NoError(t, err)
|
||||
|
||||
const ipv4 = true
|
||||
@@ -44,6 +69,8 @@ func Test_runTest(t *testing.T) {
|
||||
trackerCh <- tracker.listen(ctx)
|
||||
}()
|
||||
|
||||
const mtuSafetyBuffer = 200
|
||||
|
||||
t.Cleanup(func() {
|
||||
stop()
|
||||
cancel() // stop listening
|
||||
@@ -72,30 +99,30 @@ func Test_runTest(t *testing.T) {
|
||||
dst: func(_ *testing.T) netip.AddrPort {
|
||||
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 12345)
|
||||
},
|
||||
mtu: defaultIPv4MTU,
|
||||
mtu: defaultIPv4MTU - mtuSafetyBuffer,
|
||||
},
|
||||
"1.1.1.1:443": {
|
||||
timeout: time.Second,
|
||||
timeout: 5 * time.Second,
|
||||
dst: func(_ *testing.T) netip.AddrPort {
|
||||
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443)
|
||||
return serverAddrs["cloudflare-https"]
|
||||
},
|
||||
mtu: defaultIPv4MTU,
|
||||
mtu: defaultIPv4MTU - mtuSafetyBuffer,
|
||||
success: true,
|
||||
},
|
||||
"1.1.1.1:80": {
|
||||
timeout: time.Second,
|
||||
timeout: 5 * time.Second,
|
||||
dst: func(_ *testing.T) netip.AddrPort {
|
||||
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 80)
|
||||
return serverAddrs["cloudflare-http"]
|
||||
},
|
||||
mtu: defaultIPv4MTU,
|
||||
mtu: defaultIPv4MTU - mtuSafetyBuffer,
|
||||
success: true,
|
||||
},
|
||||
"8.8.8.8:443": {
|
||||
timeout: time.Second,
|
||||
timeout: 5 * time.Second,
|
||||
dst: func(_ *testing.T) netip.AddrPort {
|
||||
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 443)
|
||||
return serverAddrs["google-https"]
|
||||
},
|
||||
mtu: defaultIPv4MTU,
|
||||
mtu: defaultIPv4MTU - mtuSafetyBuffer,
|
||||
success: true,
|
||||
},
|
||||
}
|
||||
@@ -103,9 +130,11 @@ func Test_runTest(t *testing.T) {
|
||||
for name, testCase := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dst := testCase.dst(t)
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), testCase.timeout)
|
||||
defer cancel()
|
||||
dst := testCase.dst(t)
|
||||
err := runTest(ctx, fd, tracker, dst, testCase.mtu)
|
||||
if testCase.success {
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -2,6 +2,10 @@
|
||||
|
||||
package tcp
|
||||
|
||||
func setMark(fd, excludeMark int) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func setMTUDiscovery(fd int) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
@@ -31,6 +31,10 @@ func recvFrom(fd fileDescriptor, p []byte, flags int) (n int, from windows.Socka
|
||||
return windows.Recvfrom(windows.Handle(fd), p, flags)
|
||||
}
|
||||
|
||||
func setMark(fd windows.Handle, _ int) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func setMTUDiscovery(fd windows.Handle) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user