mirror of
https://github.com/qdm12/gluetun.git
synced 2026-05-07 04:20:12 +02:00
hotfix(pmtud/tcp): block kernel from racing to send RST packets
- this makes PMTUD TCP reliable - this only works on kernels with the mark module - on kernels without the mark module, the icmp pmtud mtu found is used
This commit is contained in:
@@ -1,2 +1,2 @@
|
|||||||
FROM ghcr.io/qdm12/godevcontainer:v0.21-alpine
|
FROM ghcr.io/qdm12/godevcontainer:v0.21-alpine
|
||||||
RUN apk add wireguard-tools htop openssl tcpdump
|
RUN apk add wireguard-tools htop openssl tcpdump iptables
|
||||||
|
|||||||
+1
-1
@@ -13,7 +13,7 @@ FROM --platform=${BUILDPLATFORM} ghcr.io/qdm12/binpot:mockgen-${MOCKGEN_VERSION}
|
|||||||
FROM --platform=${BUILDPLATFORM} golang:${GO_VERSION}-alpine${GO_ALPINE_VERSION} AS base
|
FROM --platform=${BUILDPLATFORM} golang:${GO_VERSION}-alpine${GO_ALPINE_VERSION} AS base
|
||||||
COPY --from=xcputranslate /xcputranslate /usr/local/bin/xcputranslate
|
COPY --from=xcputranslate /xcputranslate /usr/local/bin/xcputranslate
|
||||||
# Note: findutils needed to have xargs support `-d` flag for mocks stage.
|
# Note: findutils needed to have xargs support `-d` flag for mocks stage.
|
||||||
RUN apk --update add git g++ findutils
|
RUN apk --update add git g++ findutils iptables
|
||||||
ENV CGO_ENABLED=0
|
ENV CGO_ENABLED=0
|
||||||
COPY --from=golangci-lint /bin /go/bin/golangci-lint
|
COPY --from=golangci-lint /bin /go/bin/golangci-lint
|
||||||
COPY --from=mockgen /bin /go/bin/mockgen
|
COPY --from=mockgen /bin /go/bin/mockgen
|
||||||
|
|||||||
@@ -31,6 +31,12 @@ type chainRule struct {
|
|||||||
redirPorts []uint16 // Not specified if empty.
|
redirPorts []uint16 // Not specified if empty.
|
||||||
ctstate []string // for example ["RELATED","ESTABLISHED"]. Can be empty.
|
ctstate []string // for example ["RELATED","ESTABLISHED"]. Can be empty.
|
||||||
tcpFlags tcpFlags
|
tcpFlags tcpFlags
|
||||||
|
mark mark
|
||||||
|
}
|
||||||
|
|
||||||
|
type mark struct {
|
||||||
|
invert bool
|
||||||
|
value uint
|
||||||
}
|
}
|
||||||
|
|
||||||
var ErrChainListMalformed = errors.New("iptables chain list output is malformed")
|
var ErrChainListMalformed = errors.New("iptables chain list output is malformed")
|
||||||
@@ -278,6 +284,14 @@ func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err
|
|||||||
i++
|
i++
|
||||||
rule.ctstate = strings.Split(optionalFields[i], ",")
|
rule.ctstate = strings.Split(optionalFields[i], ",")
|
||||||
i++
|
i++
|
||||||
|
case "mark":
|
||||||
|
i++
|
||||||
|
mark, consumed, err := parseMark(optionalFields[i:])
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parsing mark: %w", err)
|
||||||
|
}
|
||||||
|
rule.mark = mark
|
||||||
|
i += consumed
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("%w: unexpected optional field: %s",
|
return fmt.Errorf("%w: unexpected optional field: %s",
|
||||||
ErrChainRuleMalformed, optionalFields[i])
|
ErrChainRuleMalformed, optionalFields[i])
|
||||||
@@ -397,6 +411,32 @@ func parsePortsCSV(s string) (ports []uint16, err error) {
|
|||||||
return ports, nil
|
return ports, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var errMarkValueMalformed = errors.New("mark value is malformed")
|
||||||
|
|
||||||
|
func parseMark(optionalFields []string) (m mark, consumed int, err error) {
|
||||||
|
switch optionalFields[consumed] {
|
||||||
|
case "match":
|
||||||
|
consumed++
|
||||||
|
if optionalFields[consumed] == "!" {
|
||||||
|
m.invert = true
|
||||||
|
consumed++
|
||||||
|
}
|
||||||
|
|
||||||
|
const base = 0 // auto-detect
|
||||||
|
const bits = 32
|
||||||
|
value, err := strconv.ParseUint(optionalFields[consumed], base, bits)
|
||||||
|
if err != nil {
|
||||||
|
return mark{}, 0, fmt.Errorf("%w: %s", errMarkValueMalformed, optionalFields[consumed])
|
||||||
|
}
|
||||||
|
m.value = uint(value)
|
||||||
|
consumed++
|
||||||
|
default:
|
||||||
|
return mark{}, 0, fmt.Errorf("%w: unexpected mark mode field: %s",
|
||||||
|
ErrChainRuleMalformed, optionalFields[consumed])
|
||||||
|
}
|
||||||
|
return m, consumed, nil
|
||||||
|
}
|
||||||
|
|
||||||
var ErrLineNumberIsZero = errors.New("line number is zero")
|
var ErrLineNumberIsZero = errors.New("line number is zero")
|
||||||
|
|
||||||
func parseLineNumber(s string) (n uint16, err error) {
|
func parseLineNumber(s string) (n uint16, err error) {
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ type iptablesInstruction struct {
|
|||||||
toPorts []uint16 // if empty, there is no redirection
|
toPorts []uint16 // if empty, there is no redirection
|
||||||
ctstate []string // if empty, there is no ctstate
|
ctstate []string // if empty, there is no ctstate
|
||||||
tcpFlags tcpFlags
|
tcpFlags tcpFlags
|
||||||
|
mark mark
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *iptablesInstruction) setDefaults() {
|
func (i *iptablesInstruction) setDefaults() {
|
||||||
@@ -59,6 +60,8 @@ func (i *iptablesInstruction) equalToRule(table, chain string, rule chainRule) (
|
|||||||
case !slices.Equal(i.tcpFlags.mask, rule.tcpFlags.mask) ||
|
case !slices.Equal(i.tcpFlags.mask, rule.tcpFlags.mask) ||
|
||||||
!slices.Equal(i.tcpFlags.comparison, rule.tcpFlags.comparison):
|
!slices.Equal(i.tcpFlags.comparison, rule.tcpFlags.comparison):
|
||||||
return false
|
return false
|
||||||
|
case i.mark != rule.mark:
|
||||||
|
return false
|
||||||
default:
|
default:
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@@ -100,7 +103,7 @@ func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (co
|
|||||||
|
|
||||||
// All flags use one value after the flag, except the following:
|
// All flags use one value after the flag, except the following:
|
||||||
switch flag {
|
switch flag {
|
||||||
case "--tcp-flags":
|
case "--tcp-flags": // -m can have 1 or 2 values
|
||||||
const expected = 3
|
const expected = 3
|
||||||
if len(fields) < expected {
|
if len(fields) < expected {
|
||||||
return 0, fmt.Errorf("%w: flag %q requires at least 2 values, but got %s",
|
return 0, fmt.Errorf("%w: flag %q requires at least 2 values, but got %s",
|
||||||
@@ -130,7 +133,30 @@ func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (co
|
|||||||
instruction.target = value
|
instruction.target = value
|
||||||
case "-p", "--protocol":
|
case "-p", "--protocol":
|
||||||
instruction.protocol = value
|
instruction.protocol = value
|
||||||
case "-m", "--match": // ignore match
|
case "-m", "--match":
|
||||||
|
consumed = 2 // -m can have 1 or 2 values, so it consumes 2 or 3 fields.
|
||||||
|
switch value {
|
||||||
|
case "tcp", "udp": // for now ignore the protocol match since it's auto-loaded
|
||||||
|
case "mark":
|
||||||
|
switch fields[2] {
|
||||||
|
case "!":
|
||||||
|
consumed++
|
||||||
|
instruction.mark.invert = true
|
||||||
|
default:
|
||||||
|
return 0, fmt.Errorf("%w: unsupported match mark with value: %s",
|
||||||
|
ErrIptablesCommandMalformed, fields[2])
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return 0, fmt.Errorf("%w: unknown match value: %s", ErrIptablesCommandMalformed, value)
|
||||||
|
}
|
||||||
|
case "--mark":
|
||||||
|
const base = 0 // auto-detect
|
||||||
|
const bits = 32
|
||||||
|
value, err := strconv.ParseUint(value, base, bits)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("parsing mark value %q: %w", fields[2], err)
|
||||||
|
}
|
||||||
|
instruction.mark.value = uint(value)
|
||||||
case "-i", "--in-interface":
|
case "-i", "--in-interface":
|
||||||
instruction.inputInterface = value
|
instruction.inputInterface = value
|
||||||
case "-o", "--out-interface":
|
case "-o", "--out-interface":
|
||||||
|
|||||||
@@ -1,8 +1,11 @@
|
|||||||
package firewall
|
package firewall
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"os"
|
||||||
)
|
)
|
||||||
|
|
||||||
type tcpFlags struct {
|
type tcpFlags struct {
|
||||||
@@ -60,3 +63,35 @@ func parseTCPFlag(s string) (tcpFlag, error) {
|
|||||||
}
|
}
|
||||||
return 0, fmt.Errorf("%w: %s", errTCPFlagUnknown, s)
|
return 0, fmt.Errorf("%w: %s", errTCPFlagUnknown, s)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var ErrMarkMatchModuleMissing = errors.New("kernel is missing the mark module libxt_mark.so")
|
||||||
|
|
||||||
|
// TempDropOutputTCPRST temporarily drops outgoing TCP RST packets to the specified address and port,
|
||||||
|
// for any TCP packets not marked with the excludeMark given.
|
||||||
|
// This is necessary for TCP path MTU discovery to work, as the kernel will try to terminate the connection
|
||||||
|
// by sending a TCP RST packet, although we want to handle the connection manually.
|
||||||
|
func (c *Config) TempDropOutputTCPRST(ctx context.Context,
|
||||||
|
addrPort netip.AddrPort, excludeMark int) (
|
||||||
|
revert func(ctx context.Context) error, err error,
|
||||||
|
) {
|
||||||
|
_, err = os.Stat("/usr/lib/xtables/libxt_mark.so")
|
||||||
|
if err != nil && errors.Is(err, os.ErrNotExist) {
|
||||||
|
return nil, fmt.Errorf("%w", ErrMarkMatchModuleMissing)
|
||||||
|
}
|
||||||
|
|
||||||
|
const template = "%s OUTPUT -p tcp -d %s --dport %d --tcp-flags RST RST -m mark ! --mark %d -j DROP" //nolint:dupword
|
||||||
|
instruction := fmt.Sprintf(template, "--append", addrPort.Addr(), addrPort.Port(), excludeMark)
|
||||||
|
revertInstruction := fmt.Sprintf(template, "--delete", addrPort.Addr(), addrPort.Port(), excludeMark)
|
||||||
|
run := c.runIptablesInstruction
|
||||||
|
if addrPort.Addr().Is6() {
|
||||||
|
run = c.runIP6tablesInstruction
|
||||||
|
}
|
||||||
|
revert = func(ctx context.Context) error {
|
||||||
|
return run(ctx, revertInstruction)
|
||||||
|
}
|
||||||
|
err = run(ctx, instruction)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("running instruction: %w", err)
|
||||||
|
}
|
||||||
|
return revert, nil
|
||||||
|
}
|
||||||
|
|||||||
+15
-2
@@ -7,11 +7,14 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/qdm12/gluetun/internal/firewall"
|
||||||
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||||
"github.com/qdm12/gluetun/internal/pmtud/icmp"
|
"github.com/qdm12/gluetun/internal/pmtud/icmp"
|
||||||
"github.com/qdm12/gluetun/internal/pmtud/tcp"
|
"github.com/qdm12/gluetun/internal/pmtud/tcp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var ErrPMTUDFailICMPAndTCP = errors.New("PMTUD failed with both ICMP and TCP")
|
||||||
|
|
||||||
// PathMTUDiscover discovers the maximum MTU using both ICMP and TCP.
|
// PathMTUDiscover discovers the maximum MTU using both ICMP and TCP.
|
||||||
// Multiple ICMP addresses and TCP addresses can be specified for redundancy.
|
// Multiple ICMP addresses and TCP addresses can be specified for redundancy.
|
||||||
// ICMP PMTUD is run first. If successful, the range of possible MTU values to
|
// ICMP PMTUD is run first. If successful, the range of possible MTU values to
|
||||||
@@ -23,7 +26,7 @@ import (
|
|||||||
// If the logger is nil, a no-op logger is used.
|
// If the logger is nil, a no-op logger is used.
|
||||||
// It returns [ErrMTUNotFound] if the MTU could not be determined.
|
// It returns [ErrMTUNotFound] if the MTU could not be determined.
|
||||||
func PathMTUDiscover(ctx context.Context, icmpAddrs []netip.Addr, tcpAddrs []netip.AddrPort,
|
func PathMTUDiscover(ctx context.Context, icmpAddrs []netip.Addr, tcpAddrs []netip.AddrPort,
|
||||||
physicalLinkMTU uint32, tryTimeout time.Duration, logger Logger) (
|
physicalLinkMTU uint32, tryTimeout time.Duration, fw tcp.Firewall, logger Logger) (
|
||||||
mtu uint32, err error,
|
mtu uint32, err error,
|
||||||
) {
|
) {
|
||||||
if physicalLinkMTU == 0 {
|
if physicalLinkMTU == 0 {
|
||||||
@@ -67,13 +70,23 @@ func PathMTUDiscover(ctx context.Context, icmpAddrs []netip.Addr, tcpAddrs []net
|
|||||||
const mtuMargin = 150
|
const mtuMargin = 150
|
||||||
minMTU = max(maxPossibleMTU-mtuMargin, minMTU)
|
minMTU = max(maxPossibleMTU-mtuMargin, minMTU)
|
||||||
}
|
}
|
||||||
mtu, err = tcp.PathMTUDiscover(ctx, addrPort, minMTU, maxPossibleMTU, logger)
|
mtu, err = tcp.PathMTUDiscover(ctx, addrPort, minMTU, maxPossibleMTU, fw, logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if errors.Is(err, firewall.ErrMarkMatchModuleMissing) {
|
||||||
|
logger.Debugf("aborting TCP path MTU discovery: %s", err)
|
||||||
|
if icmpSuccess {
|
||||||
|
return maxPossibleMTU, nil // only rely on ICMP PMTUD results
|
||||||
|
}
|
||||||
|
return 0, fmt.Errorf("%w", ErrPMTUDFailICMPAndTCP)
|
||||||
|
}
|
||||||
logger.Debugf("TCP path MTU discovery to %s failed: %s", addrPort, err)
|
logger.Debugf("TCP path MTU discovery to %s failed: %s", addrPort, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
logger.Debugf("TCP path MTU discovery to %s found maximum valid MTU %d", addrPort, mtu)
|
logger.Debugf("TCP path MTU discovery to %s found maximum valid MTU %d", addrPort, mtu)
|
||||||
return mtu, nil
|
return mtu, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TCP PMTUD failed for all addresses for external reasons,
|
||||||
|
// so do not take the risk and return an error.
|
||||||
return 0, fmt.Errorf("TCP path MTU discovery: last error: %w", err)
|
return 0, fmt.Errorf("TCP path MTU discovery: last error: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,15 @@
|
|||||||
package tcp
|
package tcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/netip"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Firewall interface {
|
||||||
|
TempDropOutputTCPRST(ctx context.Context, addrPort netip.AddrPort,
|
||||||
|
excludeMark int) (revert func(ctx context.Context) error, err error)
|
||||||
|
}
|
||||||
|
|
||||||
type Logger interface {
|
type Logger interface {
|
||||||
Debug(msg string)
|
Debug(msg string)
|
||||||
Debugf(msg string, args ...any)
|
Debugf(msg string, args ...any)
|
||||||
|
|||||||
@@ -19,7 +19,25 @@ type testUnit struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func PathMTUDiscover(ctx context.Context, addrPort netip.AddrPort,
|
func PathMTUDiscover(ctx context.Context, addrPort netip.AddrPort,
|
||||||
minMTU, maxPossibleMTU uint32, logger Logger,
|
minMTU, maxPossibleMTU uint32, firewall Firewall, logger Logger,
|
||||||
|
) (mtu uint32, err error) {
|
||||||
|
const excludeMark = 4325
|
||||||
|
revert, err := firewall.TempDropOutputTCPRST(ctx, addrPort, excludeMark)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("temporarily dropping outgoing TCP RST packets: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
err := revert(ctx)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warnf("reverting firewall changes: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return pathMTUDiscover(ctx, addrPort, minMTU, maxPossibleMTU, excludeMark, logger)
|
||||||
|
}
|
||||||
|
|
||||||
|
func pathMTUDiscover(ctx context.Context, addrPort netip.AddrPort,
|
||||||
|
minMTU, maxPossibleMTU uint32, excludeMark int, logger Logger,
|
||||||
) (mtu uint32, err error) {
|
) (mtu uint32, err error) {
|
||||||
mtusToTest := test.MakeMTUsToTest(minMTU, maxPossibleMTU)
|
mtusToTest := test.MakeMTUsToTest(minMTU, maxPossibleMTU)
|
||||||
if len(mtusToTest) == 1 { // only minMTU because minMTU == maxPossibleMTU
|
if len(mtusToTest) == 1 { // only minMTU because minMTU == maxPossibleMTU
|
||||||
@@ -36,7 +54,7 @@ func PathMTUDiscover(ctx context.Context, addrPort netip.AddrPort,
|
|||||||
if addrPort.Addr().Is6() {
|
if addrPort.Addr().Is6() {
|
||||||
family = constants.AF_INET6
|
family = constants.AF_INET6
|
||||||
}
|
}
|
||||||
fd, stop, err := startRawSocket(family)
|
fd, stop, err := startRawSocket(family, excludeMark)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("starting raw socket: %w", err)
|
return 0, fmt.Errorf("starting raw socket: %w", err)
|
||||||
}
|
}
|
||||||
@@ -80,8 +98,8 @@ func PathMTUDiscover(ctx context.Context, addrPort netip.AddrPort,
|
|||||||
if tests[i].ok {
|
if tests[i].ok {
|
||||||
stop()
|
stop()
|
||||||
cancel()
|
cancel()
|
||||||
return PathMTUDiscover(ctx, addrPort,
|
return pathMTUDiscover(ctx, addrPort,
|
||||||
tests[i].mtu, tests[i+1].mtu-1, logger)
|
tests[i].mtu, tests[i+1].mtu-1, excludeMark, logger)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -10,12 +10,18 @@ import (
|
|||||||
"github.com/qdm12/gluetun/internal/pmtud/ip"
|
"github.com/qdm12/gluetun/internal/pmtud/ip"
|
||||||
)
|
)
|
||||||
|
|
||||||
func startRawSocket(family int) (fd fileDescriptor, stop func(), err error) {
|
func startRawSocket(family, excludeMark int) (fd fileDescriptor, stop func(), err error) {
|
||||||
fdPlatform, err := socket(family, constants.SOCK_RAW, constants.IPPROTO_TCP)
|
fdPlatform, err := socket(family, constants.SOCK_RAW, constants.IPPROTO_TCP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, nil, fmt.Errorf("creating raw socket: %w", err)
|
return 0, nil, fmt.Errorf("creating raw socket: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = setMark(fdPlatform, excludeMark)
|
||||||
|
if err != nil {
|
||||||
|
_ = closeSocket(fdPlatform)
|
||||||
|
return 0, nil, fmt.Errorf("setting mark option on raw socket: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
if family == constants.AF_INET {
|
if family == constants.AF_INET {
|
||||||
err = ip.SetIPv4HeaderIncluded(fdPlatform)
|
err = ip.SetIPv4HeaderIncluded(fdPlatform)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -2,6 +2,17 @@ package tcp
|
|||||||
|
|
||||||
import "golang.org/x/sys/unix"
|
import "golang.org/x/sys/unix"
|
||||||
|
|
||||||
|
// setMark sets a mark on each packets sent through this socket.
|
||||||
|
// This is used in conjunction with iptables to block outgoing kernel automated
|
||||||
|
// RST packets, since the kernel is not aware of us handling the connection manually.
|
||||||
|
// For example:
|
||||||
|
// iptables -A OUTPUT -p tcp --tcp-flags RST RST -m mark ! --mark 123 -j DROP
|
||||||
|
//
|
||||||
|
//nolint:dupword
|
||||||
|
func setMark(fd, excludeMark int) error {
|
||||||
|
return unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_MARK, excludeMark)
|
||||||
|
}
|
||||||
|
|
||||||
func setMTUDiscovery(fd int) error {
|
func setMTUDiscovery(fd int) error {
|
||||||
return unix.SetsockoptInt(fd, unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_PROBE)
|
return unix.SetsockoptInt(fd, unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_PROBE)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,6 +10,8 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/qdm12/gluetun/internal/command"
|
||||||
|
"github.com/qdm12/gluetun/internal/firewall"
|
||||||
"github.com/qdm12/gluetun/internal/netlink"
|
"github.com/qdm12/gluetun/internal/netlink"
|
||||||
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||||
"github.com/qdm12/gluetun/internal/routing"
|
"github.com/qdm12/gluetun/internal/routing"
|
||||||
@@ -22,9 +24,32 @@ import (
|
|||||||
func Test_runTest(t *testing.T) {
|
func Test_runTest(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
t.Skipf("temporarily skipping test")
|
serverAddrs := map[string]netip.AddrPort{
|
||||||
|
"cloudflare-http": netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 80),
|
||||||
|
"cloudflare-https": netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443),
|
||||||
|
"google-https": netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 443),
|
||||||
|
}
|
||||||
|
|
||||||
noopLogger := &noopLogger{}
|
noopLogger := &noopLogger{}
|
||||||
|
|
||||||
|
cmder := command.New()
|
||||||
|
fw, err := firewall.NewConfig(t.Context(), noopLogger, cmder, nil, nil)
|
||||||
|
if errors.Is(err, firewall.ErrIPTablesNotSupported) {
|
||||||
|
t.Skip("iptables not installed, skipping TCP PMTUD tests")
|
||||||
|
}
|
||||||
|
require.NoError(t, err, "creating firewall config")
|
||||||
|
|
||||||
|
// Prevent Kernel from sending RST packets back to servers
|
||||||
|
const excludeMark = 4324
|
||||||
|
for _, addrPort := range serverAddrs {
|
||||||
|
revert, err := fw.TempDropOutputTCPRST(t.Context(), addrPort, excludeMark)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err := revert(context.Background())
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
netlinker := netlink.New(noopLogger)
|
netlinker := netlink.New(noopLogger)
|
||||||
loopbackMTU, err := findLoopbackMTU(netlinker)
|
loopbackMTU, err := findLoopbackMTU(netlinker)
|
||||||
require.NoError(t, err, "finding loopback IPv4 MTU")
|
require.NoError(t, err, "finding loopback IPv4 MTU")
|
||||||
@@ -34,7 +59,7 @@ func Test_runTest(t *testing.T) {
|
|||||||
ctx, cancel := context.WithCancel(t.Context())
|
ctx, cancel := context.WithCancel(t.Context())
|
||||||
|
|
||||||
const family = constants.AF_INET
|
const family = constants.AF_INET
|
||||||
fd, stop, err := startRawSocket(family)
|
fd, stop, err := startRawSocket(family, excludeMark)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
const ipv4 = true
|
const ipv4 = true
|
||||||
@@ -44,6 +69,8 @@ func Test_runTest(t *testing.T) {
|
|||||||
trackerCh <- tracker.listen(ctx)
|
trackerCh <- tracker.listen(ctx)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
const mtuSafetyBuffer = 200
|
||||||
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
stop()
|
stop()
|
||||||
cancel() // stop listening
|
cancel() // stop listening
|
||||||
@@ -72,30 +99,30 @@ func Test_runTest(t *testing.T) {
|
|||||||
dst: func(_ *testing.T) netip.AddrPort {
|
dst: func(_ *testing.T) netip.AddrPort {
|
||||||
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 12345)
|
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 12345)
|
||||||
},
|
},
|
||||||
mtu: defaultIPv4MTU,
|
mtu: defaultIPv4MTU - mtuSafetyBuffer,
|
||||||
},
|
},
|
||||||
"1.1.1.1:443": {
|
"1.1.1.1:443": {
|
||||||
timeout: time.Second,
|
timeout: 5 * time.Second,
|
||||||
dst: func(_ *testing.T) netip.AddrPort {
|
dst: func(_ *testing.T) netip.AddrPort {
|
||||||
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443)
|
return serverAddrs["cloudflare-https"]
|
||||||
},
|
},
|
||||||
mtu: defaultIPv4MTU,
|
mtu: defaultIPv4MTU - mtuSafetyBuffer,
|
||||||
success: true,
|
success: true,
|
||||||
},
|
},
|
||||||
"1.1.1.1:80": {
|
"1.1.1.1:80": {
|
||||||
timeout: time.Second,
|
timeout: 5 * time.Second,
|
||||||
dst: func(_ *testing.T) netip.AddrPort {
|
dst: func(_ *testing.T) netip.AddrPort {
|
||||||
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 80)
|
return serverAddrs["cloudflare-http"]
|
||||||
},
|
},
|
||||||
mtu: defaultIPv4MTU,
|
mtu: defaultIPv4MTU - mtuSafetyBuffer,
|
||||||
success: true,
|
success: true,
|
||||||
},
|
},
|
||||||
"8.8.8.8:443": {
|
"8.8.8.8:443": {
|
||||||
timeout: time.Second,
|
timeout: 5 * time.Second,
|
||||||
dst: func(_ *testing.T) netip.AddrPort {
|
dst: func(_ *testing.T) netip.AddrPort {
|
||||||
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 443)
|
return serverAddrs["google-https"]
|
||||||
},
|
},
|
||||||
mtu: defaultIPv4MTU,
|
mtu: defaultIPv4MTU - mtuSafetyBuffer,
|
||||||
success: true,
|
success: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -103,9 +130,11 @@ func Test_runTest(t *testing.T) {
|
|||||||
for name, testCase := range testCases {
|
for name, testCase := range testCases {
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
dst := testCase.dst(t)
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(t.Context(), testCase.timeout)
|
ctx, cancel := context.WithTimeout(t.Context(), testCase.timeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
dst := testCase.dst(t)
|
|
||||||
err := runTest(ctx, fd, tracker, dst, testCase.mtu)
|
err := runTest(ctx, fd, tracker, dst, testCase.mtu)
|
||||||
if testCase.success {
|
if testCase.success {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|||||||
@@ -2,6 +2,10 @@
|
|||||||
|
|
||||||
package tcp
|
package tcp
|
||||||
|
|
||||||
|
func setMark(fd, excludeMark int) error {
|
||||||
|
panic("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func setMTUDiscovery(fd int) error {
|
func setMTUDiscovery(fd int) error {
|
||||||
panic("not implemented")
|
panic("not implemented")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -31,6 +31,10 @@ func recvFrom(fd fileDescriptor, p []byte, flags int) (n int, from windows.Socka
|
|||||||
return windows.Recvfrom(windows.Handle(fd), p, flags)
|
return windows.Recvfrom(windows.Handle(fd), p, flags)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func setMark(fd windows.Handle, _ int) error {
|
||||||
|
panic("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func setMTUDiscovery(fd windows.Handle) error {
|
func setMTUDiscovery(fd windows.Handle) error {
|
||||||
panic("not implemented")
|
panic("not implemented")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||||
"github.com/qdm12/gluetun/internal/models"
|
"github.com/qdm12/gluetun/internal/models"
|
||||||
"github.com/qdm12/gluetun/internal/netlink"
|
"github.com/qdm12/gluetun/internal/netlink"
|
||||||
|
"github.com/qdm12/gluetun/internal/pmtud/tcp"
|
||||||
portforward "github.com/qdm12/gluetun/internal/portforward"
|
portforward "github.com/qdm12/gluetun/internal/portforward"
|
||||||
"github.com/qdm12/gluetun/internal/provider"
|
"github.com/qdm12/gluetun/internal/provider"
|
||||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||||
@@ -17,6 +18,7 @@ type Firewall interface {
|
|||||||
SetVPNConnection(ctx context.Context, connection models.Connection, interfaceName string) error
|
SetVPNConnection(ctx context.Context, connection models.Connection, interfaceName string) error
|
||||||
SetAllowedPort(ctx context.Context, port uint16, interfaceName string) error
|
SetAllowedPort(ctx context.Context, port uint16, interfaceName string) error
|
||||||
RemoveAllowedPort(ctx context.Context, port uint16) error
|
RemoveAllowedPort(ctx context.Context, port uint16) error
|
||||||
|
tcp.Firewall
|
||||||
}
|
}
|
||||||
|
|
||||||
type Routing interface {
|
type Routing interface {
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"github.com/qdm12/gluetun/internal/constants"
|
"github.com/qdm12/gluetun/internal/constants"
|
||||||
"github.com/qdm12/gluetun/internal/pmtud"
|
"github.com/qdm12/gluetun/internal/pmtud"
|
||||||
pconstants "github.com/qdm12/gluetun/internal/pmtud/constants"
|
pconstants "github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||||
|
"github.com/qdm12/gluetun/internal/pmtud/tcp"
|
||||||
"github.com/qdm12/gluetun/internal/version"
|
"github.com/qdm12/gluetun/internal/version"
|
||||||
"github.com/qdm12/log"
|
"github.com/qdm12/log"
|
||||||
)
|
)
|
||||||
@@ -58,7 +59,7 @@ func (l *Loop) onTunnelUp(ctx, loopCtx context.Context, data tunnelUpData) {
|
|||||||
mtuLogger := l.logger.New(log.SetComponent("MTU discovery"))
|
mtuLogger := l.logger.New(log.SetComponent("MTU discovery"))
|
||||||
err := updateToMaxMTU(ctx, data.vpnIntf, data.pmtud.vpnType,
|
err := updateToMaxMTU(ctx, data.vpnIntf, data.pmtud.vpnType,
|
||||||
data.pmtud.network, data.pmtud.icmpAddrs, data.pmtud.tcpAddrs,
|
data.pmtud.network, data.pmtud.icmpAddrs, data.pmtud.tcpAddrs,
|
||||||
l.netLinker, l.routing, mtuLogger)
|
l.netLinker, l.routing, l.fw, mtuLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
mtuLogger.Error(err.Error())
|
mtuLogger.Error(err.Error())
|
||||||
}
|
}
|
||||||
@@ -156,7 +157,7 @@ func (l *Loop) restartVPN(ctx context.Context, healthErr error) {
|
|||||||
|
|
||||||
func updateToMaxMTU(ctx context.Context, vpnInterface string,
|
func updateToMaxMTU(ctx context.Context, vpnInterface string,
|
||||||
vpnType, network string, icmpAddrs []netip.Addr, tcpAddrs []netip.AddrPort,
|
vpnType, network string, icmpAddrs []netip.Addr, tcpAddrs []netip.AddrPort,
|
||||||
netlinker NetLinker, routing Routing, logger *log.Logger,
|
netlinker NetLinker, routing Routing, firewall tcp.Firewall, logger *log.Logger,
|
||||||
) error {
|
) error {
|
||||||
logger.Info("finding maximum MTU, this can take up to 6 seconds")
|
logger.Info("finding maximum MTU, this can take up to 6 seconds")
|
||||||
|
|
||||||
@@ -185,7 +186,7 @@ func updateToMaxMTU(ctx context.Context, vpnInterface string,
|
|||||||
|
|
||||||
const pingTimeout = time.Second
|
const pingTimeout = time.Second
|
||||||
vpnLinkMTU, err = pmtud.PathMTUDiscover(ctx, icmpAddrs, tcpAddrs,
|
vpnLinkMTU, err = pmtud.PathMTUDiscover(ctx, icmpAddrs, tcpAddrs,
|
||||||
vpnLinkMTU, pingTimeout, logger)
|
vpnLinkMTU, pingTimeout, firewall, logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
vpnLinkMTU = originalMTU
|
vpnLinkMTU = originalMTU
|
||||||
logger.Infof("reverting VPN interface %s MTU to %d (due to: %s)",
|
logger.Infof("reverting VPN interface %s MTU to %d (due to: %s)",
|
||||||
|
|||||||
Reference in New Issue
Block a user