chore(pmtud/tcp): restrict temp firewall rules to source ip and source port

This commit is contained in:
Quentin McGaw
2026-02-18 22:26:57 +00:00
parent 1c56189abc
commit bc79901f1e
9 changed files with 274 additions and 145 deletions
+70 -45
View File
@@ -18,26 +18,61 @@ type testUnit struct {
ok bool
}
func PathMTUDiscover(ctx context.Context, addrPort netip.AddrPort,
minMTU, maxPossibleMTU uint32, firewall Firewall, logger Logger,
func PathMTUDiscover(ctx context.Context, dst netip.AddrPort,
minMTU, maxPossibleMTU uint32, tryTimeout time.Duration,
firewall Firewall, logger Logger,
) (mtu uint32, err error) {
const excludeMark = 4325
revert, err := firewall.TempDropOutputTCPRST(ctx, addrPort, excludeMark)
if err != nil {
return 0, fmt.Errorf("temporarily dropping outgoing TCP RST packets: %w", err)
family := constants.AF_INET
if dst.Addr().Is6() {
family = constants.AF_INET6
}
defer func() {
err := revert(ctx)
if err != nil {
logger.Warnf("reverting firewall changes: %s", err)
}
const excludeMark = 4325
fd, stop, err := startRawSocket(family, excludeMark)
if err != nil {
return 0, fmt.Errorf("starting raw socket: %w", err)
}
defer stop()
tracker := newTracker(fd, dst.Addr().Is4())
trackerCtx, trackerCancel := context.WithCancel(ctx)
defer trackerCancel()
trackerErrCh := make(chan error)
go func() {
trackerErrCh <- tracker.listen(trackerCtx)
}()
return pathMTUDiscover(ctx, addrPort, minMTU, maxPossibleMTU, excludeMark, logger)
pmtudCtx, pmtudCancel := context.WithCancel(ctx)
defer pmtudCancel()
type result struct {
mtu uint32
err error
}
pmtudResultCh := make(chan result)
go func() {
mtu, err := pathMTUDiscover(pmtudCtx, fd, dst, minMTU, maxPossibleMTU,
excludeMark, tryTimeout, tracker, firewall, logger)
pmtudResultCh <- result{mtu: mtu, err: err}
}()
select {
case err = <-trackerErrCh:
pmtudCancel()
<-pmtudResultCh
return 0, fmt.Errorf("listening for TCP replies: %w", err)
case res := <-pmtudResultCh:
trackerCancel()
<-trackerErrCh
return res.mtu, res.err
}
}
func pathMTUDiscover(ctx context.Context, addrPort netip.AddrPort,
minMTU, maxPossibleMTU uint32, excludeMark int, logger Logger,
var errTimedOut = errors.New("timed out")
func pathMTUDiscover(ctx context.Context, fd fileDescriptor,
dst netip.AddrPort, minMTU, maxPossibleMTU uint32, excludeMark int,
tryTimeout time.Duration, tracker *tracker, firewall Firewall,
logger Logger,
) (mtu uint32, err error) {
mtusToTest := test.MakeMTUsToTest(minMTU, maxPossibleMTU)
if len(mtusToTest) == 1 { // only minMTU because minMTU == maxPossibleMTU
@@ -50,30 +85,14 @@ func pathMTUDiscover(ctx context.Context, addrPort netip.AddrPort,
tests[i] = testUnit{mtu: mtusToTest[i]}
}
family := constants.AF_INET
if addrPort.Addr().Is6() {
family = constants.AF_INET6
}
fd, stop, err := startRawSocket(family, excludeMark)
if err != nil {
return 0, fmt.Errorf("starting raw socket: %w", err)
}
defer stop()
tracker := newTracker(fd, addrPort.Addr().Is4())
const timeout = time.Second
runCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
errCh := make(chan error)
go func() {
errCh <- tracker.listen(runCtx)
}()
errCause := fmt.Errorf("%w: after %s", errTimedOut, tryTimeout)
runCtx, runCancel := context.WithTimeoutCause(ctx, tryTimeout, errCause)
defer runCancel()
doneCh := make(chan struct{})
for i := range tests {
go func(i int) {
err := runTest(runCtx, fd, tracker, src, dst, tests[i].mtu)
err := runTest(runCtx, dst, tests[i].mtu, excludeMark,
fd, tracker, firewall, logger)
tests[i].ok = err == nil
doneCh <- struct{}{}
}(i)
@@ -82,27 +101,33 @@ func pathMTUDiscover(ctx context.Context, addrPort netip.AddrPort,
i := 0
for i < len(tests) {
select {
case <-runCtx.Done(): // timeout or parent context canceled
err = context.Cause(runCtx)
// collect remaining done signals
for i < len(tests) {
<-doneCh
i++
}
case <-doneCh:
i++
case err := <-errCh:
if err == nil { // timeout
cancel()
continue
}
return 0, fmt.Errorf("listening for TCP replies: %w", err)
}
}
if err != nil && !errors.Is(err, errTimedOut) {
// context is canceled but did not timeout after tryTimeout
return 0, fmt.Errorf("running MTU tests: %w", err)
}
if tests[len(tests)-1].ok {
return tests[len(tests)-1].mtu, nil
}
for i := len(tests) - 2; i >= 0; i-- { //nolint:mnd
if tests[i].ok {
stop()
cancel()
return pathMTUDiscover(ctx, addrPort,
tests[i].mtu, tests[i+1].mtu-1, excludeMark, logger)
runCancel() // just to release resources although runCtx is no longer used
return pathMTUDiscover(ctx, fd, dst,
tests[i].mtu, tests[i+1].mtu-1, excludeMark,
tryTimeout, tracker, firewall, logger)
}
}