mirror of
https://github.com/qdm12/gluetun.git
synced 2026-05-07 04:20:12 +02:00
chore(pmtud/tcp): restrict temp firewall rules to source ip and source port
This commit is contained in:
@@ -70,7 +70,7 @@ func PathMTUDiscover(ctx context.Context, icmpAddrs []netip.Addr, tcpAddrs []net
|
||||
const mtuMargin = 150
|
||||
minMTU = max(maxPossibleMTU-mtuMargin, minMTU)
|
||||
}
|
||||
mtu, err = tcp.PathMTUDiscover(ctx, addrPort, minMTU, maxPossibleMTU, fw, logger)
|
||||
mtu, err = tcp.PathMTUDiscover(ctx, addrPort, minMTU, maxPossibleMTU, tryTimeout, fw, logger)
|
||||
if err != nil {
|
||||
if errors.Is(err, firewall.ErrMarkMatchModuleMissing) {
|
||||
logger.Debugf("aborting TCP path MTU discovery: %s", err)
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
)
|
||||
|
||||
type Firewall interface {
|
||||
TempDropOutputTCPRST(ctx context.Context, addrPort netip.AddrPort,
|
||||
TempDropOutputTCPRST(ctx context.Context, src, dst netip.AddrPort,
|
||||
excludeMark int) (revert func(ctx context.Context) error, err error)
|
||||
}
|
||||
|
||||
|
||||
+70
-45
@@ -18,26 +18,61 @@ type testUnit struct {
|
||||
ok bool
|
||||
}
|
||||
|
||||
func PathMTUDiscover(ctx context.Context, addrPort netip.AddrPort,
|
||||
minMTU, maxPossibleMTU uint32, firewall Firewall, logger Logger,
|
||||
func PathMTUDiscover(ctx context.Context, dst netip.AddrPort,
|
||||
minMTU, maxPossibleMTU uint32, tryTimeout time.Duration,
|
||||
firewall Firewall, logger Logger,
|
||||
) (mtu uint32, err error) {
|
||||
const excludeMark = 4325
|
||||
revert, err := firewall.TempDropOutputTCPRST(ctx, addrPort, excludeMark)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("temporarily dropping outgoing TCP RST packets: %w", err)
|
||||
family := constants.AF_INET
|
||||
if dst.Addr().Is6() {
|
||||
family = constants.AF_INET6
|
||||
}
|
||||
defer func() {
|
||||
err := revert(ctx)
|
||||
if err != nil {
|
||||
logger.Warnf("reverting firewall changes: %s", err)
|
||||
}
|
||||
const excludeMark = 4325
|
||||
fd, stop, err := startRawSocket(family, excludeMark)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("starting raw socket: %w", err)
|
||||
}
|
||||
defer stop()
|
||||
|
||||
tracker := newTracker(fd, dst.Addr().Is4())
|
||||
|
||||
trackerCtx, trackerCancel := context.WithCancel(ctx)
|
||||
defer trackerCancel()
|
||||
trackerErrCh := make(chan error)
|
||||
go func() {
|
||||
trackerErrCh <- tracker.listen(trackerCtx)
|
||||
}()
|
||||
|
||||
return pathMTUDiscover(ctx, addrPort, minMTU, maxPossibleMTU, excludeMark, logger)
|
||||
pmtudCtx, pmtudCancel := context.WithCancel(ctx)
|
||||
defer pmtudCancel()
|
||||
type result struct {
|
||||
mtu uint32
|
||||
err error
|
||||
}
|
||||
pmtudResultCh := make(chan result)
|
||||
go func() {
|
||||
mtu, err := pathMTUDiscover(pmtudCtx, fd, dst, minMTU, maxPossibleMTU,
|
||||
excludeMark, tryTimeout, tracker, firewall, logger)
|
||||
pmtudResultCh <- result{mtu: mtu, err: err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case err = <-trackerErrCh:
|
||||
pmtudCancel()
|
||||
<-pmtudResultCh
|
||||
return 0, fmt.Errorf("listening for TCP replies: %w", err)
|
||||
case res := <-pmtudResultCh:
|
||||
trackerCancel()
|
||||
<-trackerErrCh
|
||||
return res.mtu, res.err
|
||||
}
|
||||
}
|
||||
|
||||
func pathMTUDiscover(ctx context.Context, addrPort netip.AddrPort,
|
||||
minMTU, maxPossibleMTU uint32, excludeMark int, logger Logger,
|
||||
var errTimedOut = errors.New("timed out")
|
||||
|
||||
func pathMTUDiscover(ctx context.Context, fd fileDescriptor,
|
||||
dst netip.AddrPort, minMTU, maxPossibleMTU uint32, excludeMark int,
|
||||
tryTimeout time.Duration, tracker *tracker, firewall Firewall,
|
||||
logger Logger,
|
||||
) (mtu uint32, err error) {
|
||||
mtusToTest := test.MakeMTUsToTest(minMTU, maxPossibleMTU)
|
||||
if len(mtusToTest) == 1 { // only minMTU because minMTU == maxPossibleMTU
|
||||
@@ -50,30 +85,14 @@ func pathMTUDiscover(ctx context.Context, addrPort netip.AddrPort,
|
||||
tests[i] = testUnit{mtu: mtusToTest[i]}
|
||||
}
|
||||
|
||||
family := constants.AF_INET
|
||||
if addrPort.Addr().Is6() {
|
||||
family = constants.AF_INET6
|
||||
}
|
||||
fd, stop, err := startRawSocket(family, excludeMark)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("starting raw socket: %w", err)
|
||||
}
|
||||
defer stop()
|
||||
|
||||
tracker := newTracker(fd, addrPort.Addr().Is4())
|
||||
|
||||
const timeout = time.Second
|
||||
runCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
errCh <- tracker.listen(runCtx)
|
||||
}()
|
||||
|
||||
errCause := fmt.Errorf("%w: after %s", errTimedOut, tryTimeout)
|
||||
runCtx, runCancel := context.WithTimeoutCause(ctx, tryTimeout, errCause)
|
||||
defer runCancel()
|
||||
doneCh := make(chan struct{})
|
||||
for i := range tests {
|
||||
go func(i int) {
|
||||
err := runTest(runCtx, fd, tracker, src, dst, tests[i].mtu)
|
||||
err := runTest(runCtx, dst, tests[i].mtu, excludeMark,
|
||||
fd, tracker, firewall, logger)
|
||||
tests[i].ok = err == nil
|
||||
doneCh <- struct{}{}
|
||||
}(i)
|
||||
@@ -82,27 +101,33 @@ func pathMTUDiscover(ctx context.Context, addrPort netip.AddrPort,
|
||||
i := 0
|
||||
for i < len(tests) {
|
||||
select {
|
||||
case <-runCtx.Done(): // timeout or parent context canceled
|
||||
err = context.Cause(runCtx)
|
||||
// collect remaining done signals
|
||||
for i < len(tests) {
|
||||
<-doneCh
|
||||
i++
|
||||
}
|
||||
case <-doneCh:
|
||||
i++
|
||||
case err := <-errCh:
|
||||
if err == nil { // timeout
|
||||
cancel()
|
||||
continue
|
||||
}
|
||||
return 0, fmt.Errorf("listening for TCP replies: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil && !errors.Is(err, errTimedOut) {
|
||||
// context is canceled but did not timeout after tryTimeout
|
||||
return 0, fmt.Errorf("running MTU tests: %w", err)
|
||||
}
|
||||
|
||||
if tests[len(tests)-1].ok {
|
||||
return tests[len(tests)-1].mtu, nil
|
||||
}
|
||||
|
||||
for i := len(tests) - 2; i >= 0; i-- { //nolint:mnd
|
||||
if tests[i].ok {
|
||||
stop()
|
||||
cancel()
|
||||
return pathMTUDiscover(ctx, addrPort,
|
||||
tests[i].mtu, tests[i+1].mtu-1, excludeMark, logger)
|
||||
runCancel() // just to release resources although runCtx is no longer used
|
||||
return pathMTUDiscover(ctx, fd, dst,
|
||||
tests[i].mtu, tests[i+1].mtu-1, excludeMark,
|
||||
tryTimeout, tracker, firewall, logger)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -64,8 +64,9 @@ var (
|
||||
// Craft and send a raw TCP packet to test the MTU.
|
||||
// It expects either an RST reply (if no server is listening)
|
||||
// or a SYN-ACK/ACK reply (if a server is listening).
|
||||
func runTest(ctx context.Context, fd fileDescriptor,
|
||||
tracker *tracker, dst netip.AddrPort, mtu uint32,
|
||||
func runTest(ctx context.Context, dst netip.AddrPort, mtu uint32,
|
||||
excludeMark int, fd fileDescriptor, tracker *tracker,
|
||||
firewall Firewall, logger Logger,
|
||||
) error {
|
||||
const proto = constants.IPPROTO_TCP
|
||||
src, cleanup, err := ip.SrcAddr(dst, proto)
|
||||
@@ -74,6 +75,20 @@ func runTest(ctx context.Context, fd fileDescriptor,
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
revert, err := firewall.TempDropOutputTCPRST(ctx, src, dst, excludeMark)
|
||||
if err != nil {
|
||||
return fmt.Errorf("temporarily dropping outgoing TCP RST packets: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
// we don't want to skip reverting the firewall changes
|
||||
// even if the context is already expired, so we use a
|
||||
// background context here.
|
||||
err := revert(context.Background())
|
||||
if err != nil {
|
||||
logger.Warnf("reverting firewall changes: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
ch := make(chan []byte)
|
||||
abort := make(chan struct{})
|
||||
defer close(abort)
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
//go:build integration
|
||||
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/command"
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||
"github.com/qdm12/log"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_PathMTUDiscover(t *testing.T) {
|
||||
t.Parallel()
|
||||
noopLogger := log.New(log.SetLevel(log.LevelDebug))
|
||||
|
||||
cmder := command.New()
|
||||
fw, err := firewall.NewConfig(t.Context(), noopLogger, cmder, nil, nil)
|
||||
if errors.Is(err, firewall.ErrIPTablesNotSupported) {
|
||||
t.Skip("iptables not installed, skipping TCP PMTUD tests")
|
||||
}
|
||||
require.NoError(t, err, "creating firewall config")
|
||||
|
||||
dst := netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 80)
|
||||
const minMTU = constants.MinIPv6MTU
|
||||
const maxMTU = constants.MaxEthernetFrameSize
|
||||
const tryTimeout = time.Second
|
||||
mtu, err := PathMTUDiscover(t.Context(), dst, minMTU, maxMTU, tryTimeout, fw, noopLogger)
|
||||
require.NoError(t, err, "discovering path MTU")
|
||||
assert.Greater(t, mtu, uint32(0), "MTU should be greater than 0")
|
||||
t.Logf("discovered path MTU to %s is %d", dst, mtu)
|
||||
}
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/gluetun/internal/netlink"
|
||||
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||
"github.com/qdm12/gluetun/internal/pmtud/ip"
|
||||
"github.com/qdm12/gluetun/internal/routing"
|
||||
"github.com/qdm12/log"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -24,11 +25,7 @@ import (
|
||||
func Test_runTest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serverAddrs := map[string]netip.AddrPort{
|
||||
"cloudflare-http": netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 80),
|
||||
"cloudflare-https": netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443),
|
||||
"google-https": netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 443),
|
||||
}
|
||||
localNonListenPort := reserveClosedPort(t)
|
||||
|
||||
noopLogger := &noopLogger{}
|
||||
|
||||
@@ -39,17 +36,6 @@ func Test_runTest(t *testing.T) {
|
||||
}
|
||||
require.NoError(t, err, "creating firewall config")
|
||||
|
||||
// Prevent Kernel from sending RST packets back to servers
|
||||
const excludeMark = 4324
|
||||
for _, addrPort := range serverAddrs {
|
||||
revert, err := fw.TempDropOutputTCPRST(t.Context(), addrPort, excludeMark)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
err := revert(context.Background())
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
netlinker := netlink.New(noopLogger)
|
||||
loopbackMTU, err := findLoopbackMTU(netlinker)
|
||||
require.NoError(t, err, "finding loopback IPv4 MTU")
|
||||
@@ -59,6 +45,7 @@ func Test_runTest(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
|
||||
const family = constants.AF_INET
|
||||
const excludeMark = 4545
|
||||
fd, stop, err := startRawSocket(family, excludeMark)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -69,6 +56,11 @@ func Test_runTest(t *testing.T) {
|
||||
trackerCh <- tracker.listen(ctx)
|
||||
}()
|
||||
|
||||
// Our local ethernet MTU could be 1500, and the server could advertise
|
||||
// an MSS of 1400, but the real link to the server could have an MTU of 1300,
|
||||
// so we need to adjust our test so it passes. We are not actually path MTU
|
||||
// discovering here, just testing that we can receive the expected TCP packets
|
||||
// for a given MTU.
|
||||
const mtuSafetyBuffer = 200
|
||||
|
||||
t.Cleanup(func() {
|
||||
@@ -80,48 +72,36 @@ func Test_runTest(t *testing.T) {
|
||||
|
||||
testCases := map[string]struct {
|
||||
timeout time.Duration
|
||||
dst func(t *testing.T) netip.AddrPort
|
||||
server netip.AddrPort
|
||||
mtu uint32
|
||||
success bool
|
||||
}{
|
||||
"local_not_listening": {
|
||||
timeout: time.Hour,
|
||||
dst: func(t *testing.T) netip.AddrPort {
|
||||
t.Helper()
|
||||
port := reserveClosedPort(t)
|
||||
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), port)
|
||||
},
|
||||
server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), localNonListenPort),
|
||||
mtu: loopbackMTU,
|
||||
success: true,
|
||||
},
|
||||
"remote_not_listening": {
|
||||
timeout: 50 * time.Millisecond,
|
||||
dst: func(_ *testing.T) netip.AddrPort {
|
||||
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 12345)
|
||||
},
|
||||
mtu: defaultIPv4MTU - mtuSafetyBuffer,
|
||||
server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 12345),
|
||||
mtu: defaultIPv4MTU - mtuSafetyBuffer,
|
||||
},
|
||||
"1.1.1.1:443": {
|
||||
timeout: 5 * time.Second,
|
||||
dst: func(_ *testing.T) netip.AddrPort {
|
||||
return serverAddrs["cloudflare-https"]
|
||||
},
|
||||
server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443),
|
||||
mtu: defaultIPv4MTU - mtuSafetyBuffer,
|
||||
success: true,
|
||||
},
|
||||
"1.1.1.1:80": {
|
||||
timeout: 5 * time.Second,
|
||||
dst: func(_ *testing.T) netip.AddrPort {
|
||||
return serverAddrs["cloudflare-http"]
|
||||
},
|
||||
server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 80),
|
||||
mtu: defaultIPv4MTU - mtuSafetyBuffer,
|
||||
success: true,
|
||||
},
|
||||
"8.8.8.8:443": {
|
||||
timeout: 5 * time.Second,
|
||||
dst: func(_ *testing.T) netip.AddrPort {
|
||||
return serverAddrs["google-https"]
|
||||
},
|
||||
server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 443),
|
||||
mtu: defaultIPv4MTU - mtuSafetyBuffer,
|
||||
success: true,
|
||||
},
|
||||
@@ -131,11 +111,24 @@ func Test_runTest(t *testing.T) {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dst := testCase.dst(t)
|
||||
dst := testCase.server
|
||||
|
||||
const proto = constants.IPPROTO_TCP
|
||||
src, cleanup, err := ip.SrcAddr(dst, proto)
|
||||
require.NoError(t, err, "getting source address to reach remote server %s", dst)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
revert, err := fw.TempDropOutputTCPRST(t.Context(), src, dst, excludeMark)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
err := revert(context.Background())
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), testCase.timeout)
|
||||
defer cancel()
|
||||
err := runTest(ctx, fd, tracker, dst, testCase.mtu)
|
||||
err = runTest(ctx, dst, testCase.mtu, excludeMark,
|
||||
fd, tracker, fw, noopLogger)
|
||||
if testCase.success {
|
||||
require.NoError(t, err)
|
||||
} else {
|
||||
@@ -230,4 +223,5 @@ func (l *noopLogger) Debug(_ string) {}
|
||||
func (l *noopLogger) Debugf(_ string, _ ...any) {}
|
||||
func (l *noopLogger) Info(_ string) {}
|
||||
func (l *noopLogger) Warn(_ string) {}
|
||||
func (l *noopLogger) Warnf(_ string, _ ...any) {}
|
||||
func (l *noopLogger) Error(_ string) {}
|
||||
|
||||
Reference in New Issue
Block a user