chore(pmtud/tcp): restrict temp firewall rules to source ip and source port

This commit is contained in:
Quentin McGaw
2026-02-18 22:26:57 +00:00
parent 1c56189abc
commit bc79901f1e
9 changed files with 274 additions and 145 deletions
+31 -37
View File
@@ -14,6 +14,7 @@ import (
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/netlink"
"github.com/qdm12/gluetun/internal/pmtud/constants"
"github.com/qdm12/gluetun/internal/pmtud/ip"
"github.com/qdm12/gluetun/internal/routing"
"github.com/qdm12/log"
"github.com/stretchr/testify/assert"
@@ -24,11 +25,7 @@ import (
func Test_runTest(t *testing.T) {
t.Parallel()
serverAddrs := map[string]netip.AddrPort{
"cloudflare-http": netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 80),
"cloudflare-https": netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443),
"google-https": netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 443),
}
localNonListenPort := reserveClosedPort(t)
noopLogger := &noopLogger{}
@@ -39,17 +36,6 @@ func Test_runTest(t *testing.T) {
}
require.NoError(t, err, "creating firewall config")
// Prevent Kernel from sending RST packets back to servers
const excludeMark = 4324
for _, addrPort := range serverAddrs {
revert, err := fw.TempDropOutputTCPRST(t.Context(), addrPort, excludeMark)
require.NoError(t, err)
t.Cleanup(func() {
err := revert(context.Background())
assert.NoError(t, err)
})
}
netlinker := netlink.New(noopLogger)
loopbackMTU, err := findLoopbackMTU(netlinker)
require.NoError(t, err, "finding loopback IPv4 MTU")
@@ -59,6 +45,7 @@ func Test_runTest(t *testing.T) {
ctx, cancel := context.WithCancel(t.Context())
const family = constants.AF_INET
const excludeMark = 4545
fd, stop, err := startRawSocket(family, excludeMark)
require.NoError(t, err)
@@ -69,6 +56,11 @@ func Test_runTest(t *testing.T) {
trackerCh <- tracker.listen(ctx)
}()
// Our local ethernet MTU could be 1500, and the server could advertise
// an MSS of 1400, but the real link to the server could have an MTU of 1300,
// so we need to adjust our test so it passes. We are not actually path MTU
// discovering here, just testing that we can receive the expected TCP packets
// for a given MTU.
const mtuSafetyBuffer = 200
t.Cleanup(func() {
@@ -80,48 +72,36 @@ func Test_runTest(t *testing.T) {
testCases := map[string]struct {
timeout time.Duration
dst func(t *testing.T) netip.AddrPort
server netip.AddrPort
mtu uint32
success bool
}{
"local_not_listening": {
timeout: time.Hour,
dst: func(t *testing.T) netip.AddrPort {
t.Helper()
port := reserveClosedPort(t)
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), port)
},
server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), localNonListenPort),
mtu: loopbackMTU,
success: true,
},
"remote_not_listening": {
timeout: 50 * time.Millisecond,
dst: func(_ *testing.T) netip.AddrPort {
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 12345)
},
mtu: defaultIPv4MTU - mtuSafetyBuffer,
server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 12345),
mtu: defaultIPv4MTU - mtuSafetyBuffer,
},
"1.1.1.1:443": {
timeout: 5 * time.Second,
dst: func(_ *testing.T) netip.AddrPort {
return serverAddrs["cloudflare-https"]
},
server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443),
mtu: defaultIPv4MTU - mtuSafetyBuffer,
success: true,
},
"1.1.1.1:80": {
timeout: 5 * time.Second,
dst: func(_ *testing.T) netip.AddrPort {
return serverAddrs["cloudflare-http"]
},
server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 80),
mtu: defaultIPv4MTU - mtuSafetyBuffer,
success: true,
},
"8.8.8.8:443": {
timeout: 5 * time.Second,
dst: func(_ *testing.T) netip.AddrPort {
return serverAddrs["google-https"]
},
server: netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 443),
mtu: defaultIPv4MTU - mtuSafetyBuffer,
success: true,
},
@@ -131,11 +111,24 @@ func Test_runTest(t *testing.T) {
t.Run(name, func(t *testing.T) {
t.Parallel()
dst := testCase.dst(t)
dst := testCase.server
const proto = constants.IPPROTO_TCP
src, cleanup, err := ip.SrcAddr(dst, proto)
require.NoError(t, err, "getting source address to reach remote server %s", dst)
t.Cleanup(cleanup)
revert, err := fw.TempDropOutputTCPRST(t.Context(), src, dst, excludeMark)
require.NoError(t, err)
t.Cleanup(func() {
err := revert(context.Background())
assert.NoError(t, err)
})
ctx, cancel := context.WithTimeout(t.Context(), testCase.timeout)
defer cancel()
err := runTest(ctx, fd, tracker, dst, testCase.mtu)
err = runTest(ctx, dst, testCase.mtu, excludeMark,
fd, tracker, fw, noopLogger)
if testCase.success {
require.NoError(t, err)
} else {
@@ -230,4 +223,5 @@ func (l *noopLogger) Debug(_ string) {}
func (l *noopLogger) Debugf(_ string, _ ...any) {}
func (l *noopLogger) Info(_ string) {}
func (l *noopLogger) Warn(_ string) {}
func (l *noopLogger) Warnf(_ string, _ ...any) {}
func (l *noopLogger) Error(_ string) {}