mirror of
https://github.com/qdm12/gluetun.git
synced 2026-05-06 20:10:11 +02:00
Path MTU discovery fixes and improvements (#3109)
- Existing option `WIREGUARD_MTU` , if set, disables PMTUD and is used - New option `PMTUD_ICMP_ADDRESSES=1.1.1.1,8.8.8.8` and `PMTUD_TCP_ADDRESSES=1.1.1.1:443,8.8.8.8:443` - ICMP PMTUD now targets external-by-default IP addresses - New TCP PMTUD (binary search only) as a second MTU confirmation and fallback mechanism. - Force set TCP MSS to MTU - IP header - TCP base header - "magic 20 bytes" 🎆 - Fix #3108
This commit is contained in:
@@ -1,2 +1,2 @@
|
||||
FROM ghcr.io/qdm12/godevcontainer:v0.21-alpine
|
||||
RUN apk add wireguard-tools htop openssl
|
||||
RUN apk add wireguard-tools htop openssl tcpdump
|
||||
|
||||
@@ -22,6 +22,7 @@ linters:
|
||||
- "^disabled$"
|
||||
# Firewall and routing strings
|
||||
- "^(ACCEPT|DROP)$"
|
||||
- "^--append$"
|
||||
- "^--delete$"
|
||||
- "^all$"
|
||||
- "^(tcp|udp)$"
|
||||
|
||||
+4
-1
@@ -110,8 +110,11 @@ ENV VPN_SERVICE_PROVIDER=pia \
|
||||
WIREGUARD_PERSISTENT_KEEPALIVE_INTERVAL=0 \
|
||||
WIREGUARD_ADDRESSES= \
|
||||
WIREGUARD_ADDRESSES_SECRETFILE=/run/secrets/wireguard_addresses \
|
||||
WIREGUARD_MTU=1320 \
|
||||
WIREGUARD_MTU= \
|
||||
WIREGUARD_IMPLEMENTATION=auto \
|
||||
# PMTUD
|
||||
PMTUD_ICMP_ADDRESSES=1.1.1.1,8.8.8.8 \
|
||||
PMTUD_TCP_ADDRESSES=1.1.1.1:443,8.8.8.8:443 \
|
||||
# VPN server filtering
|
||||
SERVER_REGIONS= \
|
||||
SERVER_COUNTRIES= \
|
||||
|
||||
@@ -0,0 +1,108 @@
|
||||
package settings
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/gosettings"
|
||||
"github.com/qdm12/gosettings/reader"
|
||||
"github.com/qdm12/gotree"
|
||||
)
|
||||
|
||||
// PMTUD contains settings to configure Path MTU Discovery.
|
||||
type PMTUD struct {
|
||||
// ICMPAddresses is the redundancy list of addresses to use
|
||||
// for ICMP path MTU discovery. Each address MUST handle ICMP
|
||||
// packets for PMTUD to work.
|
||||
// It cannot be nil in the internal state.
|
||||
ICMPAddresses []netip.Addr `json:"icmp_addresses"`
|
||||
// TCPAddresses is the redundancy list of addresses to use
|
||||
// for TCP path MTU discovery. Each address MUST have a listening
|
||||
// TCP server on the port specified.
|
||||
// It cannot be nil in the internal state.
|
||||
TCPAddresses []netip.AddrPort `json:"tcp_addresses"`
|
||||
}
|
||||
|
||||
var (
|
||||
ErrPMTUDICMPAddressNotValid = errors.New("PMTUD ICMP address is not valid")
|
||||
ErrPMTUDTCPAddressNotValid = errors.New("PMTUD TCP address is not valid")
|
||||
)
|
||||
|
||||
// Validate validates PMTUD settings.
|
||||
func (p PMTUD) validate() (err error) {
|
||||
for i, addr := range p.ICMPAddresses {
|
||||
if !addr.IsValid() {
|
||||
return fmt.Errorf("%w: at index %d", ErrPMTUDICMPAddressNotValid, i)
|
||||
}
|
||||
}
|
||||
for i, addr := range p.TCPAddresses {
|
||||
if !addr.IsValid() {
|
||||
return fmt.Errorf("%w: at index %d", ErrPMTUDTCPAddressNotValid, i)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *PMTUD) copy() (copied PMTUD) {
|
||||
return PMTUD{
|
||||
ICMPAddresses: gosettings.CopySlice(p.ICMPAddresses),
|
||||
TCPAddresses: gosettings.CopySlice(p.TCPAddresses),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PMTUD) overrideWith(other PMTUD) {
|
||||
p.ICMPAddresses = gosettings.OverrideWithSlice(p.ICMPAddresses, other.ICMPAddresses)
|
||||
p.TCPAddresses = gosettings.OverrideWithSlice(p.TCPAddresses, other.TCPAddresses)
|
||||
}
|
||||
|
||||
func (p *PMTUD) setDefaults() {
|
||||
defaultICMPAddresses := []netip.Addr{
|
||||
netip.AddrFrom4([4]byte{1, 1, 1, 1}),
|
||||
netip.AddrFrom4([4]byte{8, 8, 8, 8}),
|
||||
}
|
||||
p.ICMPAddresses = gosettings.DefaultSlice(p.ICMPAddresses, defaultICMPAddresses)
|
||||
|
||||
const tlsPort = 443
|
||||
defaultTCPAddresses := []netip.AddrPort{
|
||||
netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), tlsPort),
|
||||
netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), tlsPort),
|
||||
}
|
||||
p.TCPAddresses = gosettings.DefaultSlice(p.TCPAddresses, defaultTCPAddresses)
|
||||
}
|
||||
|
||||
func (p PMTUD) String() string {
|
||||
return p.toLinesNode().String()
|
||||
}
|
||||
|
||||
func (p PMTUD) toLinesNode() (node *gotree.Node) {
|
||||
node = gotree.New("Path MTU discovery:")
|
||||
|
||||
addrs := make([]string, len(p.ICMPAddresses))
|
||||
for i, addr := range p.ICMPAddresses {
|
||||
addrs[i] = addr.String()
|
||||
}
|
||||
node.Appendf("ICMP addresses: %s", strings.Join(addrs, ", "))
|
||||
|
||||
addrs = make([]string, len(p.TCPAddresses))
|
||||
for i, addr := range p.TCPAddresses {
|
||||
addrs[i] = addr.String()
|
||||
}
|
||||
node.Appendf("TCP addresses: %s", strings.Join(addrs, ", "))
|
||||
return node
|
||||
}
|
||||
|
||||
func (p *PMTUD) read(r *reader.Reader) (err error) {
|
||||
p.ICMPAddresses, err = r.CSVNetipAddresses("PMTUD_ICMP_ADDRESSES")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.TCPAddresses, err = r.CSVNetipAddrPorts("PMTUD_TCP_ADDRESSES")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -29,14 +29,17 @@ func Test_Settings_String(t *testing.T) {
|
||||
| | └── OpenVPN server selection settings:
|
||||
| | ├── Protocol: UDP
|
||||
| | └── Private Internet Access encryption preset: strong
|
||||
| └── OpenVPN settings:
|
||||
| ├── OpenVPN version: 2.6
|
||||
| ├── User: [not set]
|
||||
| ├── Password: [not set]
|
||||
| ├── Private Internet Access encryption preset: strong
|
||||
| ├── Network interface: tun0
|
||||
| ├── Run OpenVPN as: root
|
||||
| └── Verbosity level: 1
|
||||
| ├── OpenVPN settings:
|
||||
| | ├── OpenVPN version: 2.6
|
||||
| | ├── User: [not set]
|
||||
| | ├── Password: [not set]
|
||||
| | ├── Private Internet Access encryption preset: strong
|
||||
| | ├── Network interface: tun0
|
||||
| | ├── Run OpenVPN as: root
|
||||
| | └── Verbosity level: 1
|
||||
| └── Path MTU discovery:
|
||||
| ├── ICMP addresses: 1.1.1.1, 8.8.8.8
|
||||
| └── TCP addresses: 1.1.1.1:443, 8.8.8.8:443
|
||||
├── DNS settings:
|
||||
| ├── Keep existing nameserver(s): no
|
||||
| ├── DNS server address to use: 127.0.0.1
|
||||
|
||||
@@ -18,6 +18,7 @@ type VPN struct {
|
||||
Provider Provider `json:"provider"`
|
||||
OpenVPN OpenVPN `json:"openvpn"`
|
||||
Wireguard Wireguard `json:"wireguard"`
|
||||
PMTUD PMTUD `json:"pmtud"`
|
||||
}
|
||||
|
||||
// TODO v4 remove pointer for receiver (because of Surfshark).
|
||||
@@ -45,6 +46,11 @@ func (v *VPN) Validate(filterChoicesGetter FilterChoicesGetter, ipv6Supported bo
|
||||
}
|
||||
}
|
||||
|
||||
err = v.PMTUD.validate()
|
||||
if err != nil {
|
||||
return fmt.Errorf("PMTUD settings: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -54,6 +60,7 @@ func (v *VPN) Copy() (copied VPN) {
|
||||
Provider: v.Provider.copy(),
|
||||
OpenVPN: v.OpenVPN.copy(),
|
||||
Wireguard: v.Wireguard.copy(),
|
||||
PMTUD: v.PMTUD.copy(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -62,6 +69,7 @@ func (v *VPN) OverrideWith(other VPN) {
|
||||
v.Provider.overrideWith(other.Provider)
|
||||
v.OpenVPN.overrideWith(other.OpenVPN)
|
||||
v.Wireguard.overrideWith(other.Wireguard)
|
||||
v.PMTUD.overrideWith(other.PMTUD)
|
||||
}
|
||||
|
||||
func (v *VPN) setDefaults() {
|
||||
@@ -69,6 +77,7 @@ func (v *VPN) setDefaults() {
|
||||
v.Provider.setDefaults()
|
||||
v.OpenVPN.setDefaults(v.Provider.Name)
|
||||
v.Wireguard.setDefaults(v.Provider.Name)
|
||||
v.PMTUD.setDefaults()
|
||||
}
|
||||
|
||||
func (v VPN) String() string {
|
||||
@@ -85,6 +94,7 @@ func (v VPN) toLinesNode() (node *gotree.Node) {
|
||||
} else {
|
||||
node.AppendNode(v.Wireguard.toLinesNode())
|
||||
}
|
||||
node.AppendNode(v.PMTUD.toLinesNode())
|
||||
|
||||
return node
|
||||
}
|
||||
@@ -107,5 +117,10 @@ func (v *VPN) read(r *reader.Reader) (err error) {
|
||||
return fmt.Errorf("wireguard: %w", err)
|
||||
}
|
||||
|
||||
err = v.PMTUD.read(r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("PMTUD: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -38,15 +38,9 @@ type Wireguard struct {
|
||||
Interface string `json:"interface"`
|
||||
PersistentKeepaliveInterval *time.Duration `json:"persistent_keep_alive_interval"`
|
||||
// Maximum Transmission Unit (MTU) of the Wireguard interface.
|
||||
// It cannot be zero in the internal state, and defaults to
|
||||
// 1320. Note it is not the wireguard-go MTU default of 1420
|
||||
// because this impacts bandwidth a lot on some VPN providers,
|
||||
// see https://github.com/qdm12/gluetun/issues/1650.
|
||||
// It has been lowered to 1320 following quite a bit of
|
||||
// investigation in the issue:
|
||||
// https://github.com/qdm12/gluetun/issues/2533.
|
||||
// Note this should now be replaced with the PMTUD feature.
|
||||
MTU uint32 `json:"mtu"`
|
||||
// It cannot be nil in the internal state, and defaults to
|
||||
// 0 indicating to use PMTUD.
|
||||
MTU *uint32 `json:"mtu"`
|
||||
// Implementation is the Wireguard implementation to use.
|
||||
// It can be "auto", "userspace" or "kernelspace".
|
||||
// It defaults to "auto" and cannot be the empty string
|
||||
@@ -195,8 +189,7 @@ func (w *Wireguard) setDefaults(vpnProvider string) {
|
||||
w.AllowedIPs = gosettings.DefaultSlice(w.AllowedIPs, defaultAllowedIPs)
|
||||
w.PersistentKeepaliveInterval = gosettings.DefaultPointer(w.PersistentKeepaliveInterval, 0)
|
||||
w.Interface = gosettings.DefaultComparable(w.Interface, "wg0")
|
||||
const defaultMTU = 1320
|
||||
w.MTU = gosettings.DefaultComparable(w.MTU, defaultMTU)
|
||||
w.MTU = gosettings.DefaultPointer(w.MTU, 0)
|
||||
w.Implementation = gosettings.DefaultComparable(w.Implementation, "auto")
|
||||
}
|
||||
|
||||
@@ -232,7 +225,11 @@ func (w Wireguard) toLinesNode() (node *gotree.Node) {
|
||||
}
|
||||
|
||||
interfaceNode := node.Appendf("Network interface: %s", w.Interface)
|
||||
interfaceNode.Appendf("MTU: %d", w.MTU)
|
||||
if *w.MTU == 0 {
|
||||
interfaceNode.Append("MTU: use path MTU discovery")
|
||||
} else {
|
||||
interfaceNode.Appendf("MTU: %d", *w.MTU)
|
||||
}
|
||||
|
||||
if w.Implementation != "auto" {
|
||||
node.Appendf("Implementation: %s", w.Implementation)
|
||||
@@ -273,11 +270,9 @@ func (w *Wireguard) read(r *reader.Reader) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
mtuPtr, err := r.Uint32Ptr("WIREGUARD_MTU")
|
||||
w.MTU, err = r.Uint32Ptr("WIREGUARD_MTU")
|
||||
if err != nil {
|
||||
return err
|
||||
} else if mtuPtr != nil {
|
||||
w.MTU = *mtuPtr
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -29,17 +29,16 @@ func appendOrDelete(remove bool) string {
|
||||
// flipRule changes an append rule in a delete rule or a delete rule into an
|
||||
// append rule.
|
||||
func flipRule(rule string) string {
|
||||
switch {
|
||||
case strings.HasPrefix(rule, "-A"):
|
||||
return strings.Replace(rule, "-A", "-D", 1)
|
||||
case strings.HasPrefix(rule, "--append"):
|
||||
return strings.Replace(rule, "--append", "-D", 1)
|
||||
case strings.HasPrefix(rule, "-D"):
|
||||
return strings.Replace(rule, "-D", "-A", 1)
|
||||
case strings.HasPrefix(rule, "--delete"):
|
||||
return strings.Replace(rule, "--delete", "-A", 1)
|
||||
fields := strings.Fields(rule)
|
||||
for i, field := range fields {
|
||||
switch field {
|
||||
case "-A", "--append":
|
||||
fields[i] = "--delete"
|
||||
case "-D", "--delete":
|
||||
fields[i] = "--append"
|
||||
}
|
||||
return rule
|
||||
}
|
||||
return strings.Join(fields, " ")
|
||||
}
|
||||
|
||||
// Version obtains the version of the installed iptables.
|
||||
@@ -86,11 +85,15 @@ func (c *Config) runIptablesInstruction(ctx context.Context, instruction string)
|
||||
}
|
||||
|
||||
func (c *Config) clearAllRules(ctx context.Context) error {
|
||||
tables := []string{"filter"}
|
||||
for _, table := range tables {
|
||||
return c.runMixedIptablesInstructions(ctx, []string{
|
||||
"--flush", // flush all chains
|
||||
"--delete-chain", // delete all chains
|
||||
"-t " + table + " --flush", // flush all chains
|
||||
"-t " + table + " --delete-chain", // delete all chains
|
||||
})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) setIPv4AllPolicies(ctx context.Context, policy string) error {
|
||||
switch policy {
|
||||
|
||||
@@ -18,6 +18,7 @@ type Route struct {
|
||||
Type uint8
|
||||
Scope uint8
|
||||
Proto uint8
|
||||
AdvMSS uint32
|
||||
}
|
||||
|
||||
func (r *Route) fromMessage(message rtnetlink.RouteMessage) {
|
||||
@@ -35,6 +36,9 @@ func (r *Route) fromMessage(message rtnetlink.RouteMessage) {
|
||||
r.Type = message.Type
|
||||
r.Scope = message.Scope
|
||||
r.Proto = message.Protocol
|
||||
if metrics := message.Attributes.Metrics; metrics != nil {
|
||||
r.AdvMSS = metrics.AdvMSS
|
||||
}
|
||||
}
|
||||
|
||||
func (r Route) message() *rtnetlink.RouteMessage {
|
||||
@@ -58,7 +62,6 @@ func (r Route) message() *rtnetlink.RouteMessage {
|
||||
Protocol: r.Proto,
|
||||
Attributes: rtnetlink.RouteAttributes{
|
||||
OutIface: r.LinkIndex,
|
||||
Dst: *dst, // there should always be a dst for routes
|
||||
Gateway: netipAddrToNetIP(r.Gw),
|
||||
Priority: r.Priority,
|
||||
Table: extendedTable,
|
||||
@@ -67,6 +70,15 @@ func (r Route) message() *rtnetlink.RouteMessage {
|
||||
if src != nil { // src is optional
|
||||
message.Attributes.Src = *src
|
||||
}
|
||||
if dst != nil {
|
||||
message.Attributes.Dst = *dst
|
||||
}
|
||||
if r.AdvMSS != 0 {
|
||||
if message.Attributes.Metrics == nil {
|
||||
message.Attributes.Metrics = &rtnetlink.RouteMetrics{}
|
||||
}
|
||||
message.Attributes.Metrics.AdvMSS = r.AdvMSS
|
||||
}
|
||||
return message
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
package constants
|
||||
|
||||
const (
|
||||
MaxEthernetFrameSize uint32 = 1500
|
||||
// MinIPv4MTU is defined according to
|
||||
// https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media
|
||||
MinIPv4MTU uint32 = 68
|
||||
MinIPv6MTU uint32 = 1280
|
||||
|
||||
IPv4HeaderLength uint32 = 20
|
||||
IPv6HeaderLength uint32 = 40
|
||||
UDPHeaderLength uint32 = 8
|
||||
// BaseTCPHeaderLength is the TCP header length without options,
|
||||
// which is the minimum TCP header length.
|
||||
BaseTCPHeaderLength uint32 = 20
|
||||
// MaxTCPHeaderLength is the TCP header length with the maximum options length of 40 bytes.
|
||||
// Note this is a hard maximum because of the 4-bit data offset field in the TCP header (15x4=60).
|
||||
MaxTCPHeaderLength uint32 = 60
|
||||
WireguardHeaderLength uint32 = 32
|
||||
OpenVPNHeaderMaxLength uint32 = 1 + // opcode
|
||||
8 + // session id
|
||||
4 + // packet id
|
||||
28 // max possible auth tag/iv
|
||||
)
|
||||
@@ -1,29 +0,0 @@
|
||||
package pmtud
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrICMPNotPermitted = errors.New("ICMP not permitted")
|
||||
ErrICMPDestinationUnreachable = errors.New("ICMP destination unreachable")
|
||||
ErrICMPCommunicationAdministrativelyProhibited = errors.New("communication administratively prohibited")
|
||||
ErrICMPBodyUnsupported = errors.New("ICMP body type is not supported")
|
||||
)
|
||||
|
||||
func wrapConnErr(err error, timedCtx context.Context, pingTimeout time.Duration) error { //nolint:revive
|
||||
switch {
|
||||
case strings.HasSuffix(err.Error(), "sendto: operation not permitted"):
|
||||
err = fmt.Errorf("%w", ErrICMPNotPermitted)
|
||||
case errors.Is(timedCtx.Err(), context.DeadlineExceeded):
|
||||
err = fmt.Errorf("%w (timed out after %s)", net.ErrClosed, pingTimeout)
|
||||
case timedCtx.Err() != nil:
|
||||
err = timedCtx.Err()
|
||||
}
|
||||
return err
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package pmtud
|
||||
package icmp
|
||||
|
||||
import (
|
||||
"net"
|
||||
@@ -1,4 +1,4 @@
|
||||
package pmtud
|
||||
package icmp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
@@ -9,17 +9,17 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrICMPNextHopMTUTooLow = errors.New("ICMP Next Hop MTU is too low")
|
||||
ErrICMPNextHopMTUTooHigh = errors.New("ICMP Next Hop MTU is too high")
|
||||
ErrNextHopMTUTooLow = errors.New("ICMP Next Hop MTU is too low")
|
||||
ErrNextHopMTUTooHigh = errors.New("ICMP Next Hop MTU is too high")
|
||||
)
|
||||
|
||||
func checkMTU(mtu, minMTU, physicalLinkMTU uint32) (err error) {
|
||||
switch {
|
||||
case mtu < minMTU:
|
||||
return fmt.Errorf("%w: %d", ErrICMPNextHopMTUTooLow, mtu)
|
||||
return fmt.Errorf("%w: %d", ErrNextHopMTUTooLow, mtu)
|
||||
case mtu > physicalLinkMTU:
|
||||
return fmt.Errorf("%w: %d is larger than physical link MTU %d",
|
||||
ErrICMPNextHopMTUTooHigh, mtu, physicalLinkMTU)
|
||||
ErrNextHopMTUTooHigh, mtu, physicalLinkMTU)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
@@ -34,13 +34,13 @@ func checkInvokingReplyIDMatch(icmpProtocol int, received []byte,
|
||||
}
|
||||
inboundBody, ok := inboundMessage.Body.(*icmp.Echo)
|
||||
if !ok {
|
||||
return false, fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, inboundMessage.Body)
|
||||
return false, fmt.Errorf("%w: %T", ErrBodyUnsupported, inboundMessage.Body)
|
||||
}
|
||||
outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert
|
||||
return inboundBody.ID == outboundBody.ID, nil
|
||||
}
|
||||
|
||||
var ErrICMPIDMismatch = errors.New("ICMP id mismatch")
|
||||
var ErrIDMismatch = errors.New("ICMP id mismatch")
|
||||
|
||||
func checkEchoReply(icmpProtocol int, received []byte,
|
||||
outboundMessage *icmp.Message, truncatedBody bool,
|
||||
@@ -51,12 +51,12 @@ func checkEchoReply(icmpProtocol int, received []byte,
|
||||
}
|
||||
inboundBody, ok := inboundMessage.Body.(*icmp.Echo)
|
||||
if !ok {
|
||||
return fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, inboundMessage.Body)
|
||||
return fmt.Errorf("%w: %T", ErrBodyUnsupported, inboundMessage.Body)
|
||||
}
|
||||
outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert
|
||||
if inboundBody.ID != outboundBody.ID {
|
||||
return fmt.Errorf("%w: sent id %d and received id %d",
|
||||
ErrICMPIDMismatch, outboundBody.ID, inboundBody.ID)
|
||||
ErrIDMismatch, outboundBody.ID, inboundBody.ID)
|
||||
}
|
||||
err = checkEchoBodies(outboundBody.Data, inboundBody.Data, truncatedBody)
|
||||
if err != nil {
|
||||
@@ -65,19 +65,19 @@ func checkEchoReply(icmpProtocol int, received []byte,
|
||||
return nil
|
||||
}
|
||||
|
||||
var ErrICMPEchoDataMismatch = errors.New("ICMP data mismatch")
|
||||
var ErrEchoDataMismatch = errors.New("ICMP data mismatch")
|
||||
|
||||
func checkEchoBodies(sent, received []byte, receivedTruncated bool) (err error) {
|
||||
if len(received) > len(sent) {
|
||||
return fmt.Errorf("%w: sent %d bytes and received %d bytes",
|
||||
ErrICMPEchoDataMismatch, len(sent), len(received))
|
||||
ErrEchoDataMismatch, len(sent), len(received))
|
||||
}
|
||||
if receivedTruncated {
|
||||
sent = sent[:len(received)]
|
||||
}
|
||||
if !bytes.Equal(received, sent) {
|
||||
return fmt.Errorf("%w: sent %x and received %x",
|
||||
ErrICMPEchoDataMismatch, sent, received)
|
||||
ErrEchoDataMismatch, sent, received)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
//go:build !linux && !windows
|
||||
|
||||
package pmtud
|
||||
package icmp
|
||||
|
||||
// setDontFragment for platforms other than Linux and Windows
|
||||
// is not implemented, so we just return assuming the don't
|
||||
@@ -1,4 +1,4 @@
|
||||
package pmtud
|
||||
package icmp
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
@@ -1,6 +1,4 @@
|
||||
//go:build windows
|
||||
|
||||
package pmtud
|
||||
package icmp
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
@@ -0,0 +1,30 @@
|
||||
package icmp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNotPermitted = errors.New("ICMP not permitted")
|
||||
ErrDestinationUnreachable = errors.New("ICMP destination unreachable")
|
||||
ErrCommunicationAdministrativelyProhibited = errors.New("communication administratively prohibited")
|
||||
ErrBodyUnsupported = errors.New("ICMP body type is not supported")
|
||||
ErrMTUNotFound = errors.New("MTU not found")
|
||||
)
|
||||
|
||||
func wrapConnErr(err error, timedCtx context.Context, pingTimeout time.Duration) error { //nolint:revive
|
||||
switch {
|
||||
case strings.HasSuffix(err.Error(), "sendto: operation not permitted"):
|
||||
err = fmt.Errorf("%w", ErrNotPermitted)
|
||||
case errors.Is(timedCtx.Err(), context.DeadlineExceeded):
|
||||
err = fmt.Errorf("%w (timed out after %s)", net.ErrClosed, pingTimeout)
|
||||
case timedCtx.Err() != nil:
|
||||
err = timedCtx.Err()
|
||||
}
|
||||
return err
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
package icmp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||
)
|
||||
|
||||
// PathMTUDiscover discovers the path MTU to the given IP address
|
||||
// using ICMP.
|
||||
// It first tries to get the next hop MTU using ICMP messages.
|
||||
// If that fails, it falls back to sending echo requests with
|
||||
// different packet sizes to find the maximum MTU.
|
||||
// The function returns [ErrMTUNotFound] if the MTU could not be determined.
|
||||
func PathMTUDiscover(ctx context.Context, ip netip.Addr,
|
||||
physicalLinkMTU uint32, timeout time.Duration, logger Logger,
|
||||
) (mtu uint32, err error) {
|
||||
if ip.Is4() {
|
||||
logger.Debug("finding IPv4 next hop MTU")
|
||||
mtu, err = findIPv4NextHopMTU(ctx, ip, physicalLinkMTU, timeout, logger)
|
||||
switch {
|
||||
case err == nil:
|
||||
return mtu, nil
|
||||
case errors.Is(err, net.ErrClosed) || errors.Is(err, ErrCommunicationAdministrativelyProhibited): // blackhole
|
||||
default:
|
||||
return 0, fmt.Errorf("finding IPv4 next hop MTU: %w", err)
|
||||
}
|
||||
} else {
|
||||
logger.Debug("requesting IPv6 ICMP packet-too-big reply")
|
||||
mtu, err = getIPv6PacketTooBig(ctx, ip, physicalLinkMTU, timeout, logger)
|
||||
switch {
|
||||
case err == nil:
|
||||
return mtu, nil
|
||||
case errors.Is(err, net.ErrClosed): // blackhole
|
||||
default:
|
||||
return 0, fmt.Errorf("getting IPv6 packet-too-big message: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back method: send echo requests with different packet
|
||||
// sizes and check which ones succeed to find the maximum MTU.
|
||||
logger.Debug("falling back to sending different sized echo packets")
|
||||
minMTU := constants.MinIPv4MTU
|
||||
if ip.Is6() {
|
||||
minMTU = constants.MinIPv6MTU
|
||||
}
|
||||
return pmtudMultiSizes(ctx, ip, minMTU, physicalLinkMTU, timeout, logger)
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
package icmp
|
||||
|
||||
type Logger interface {
|
||||
Debug(msg string)
|
||||
Debugf(msg string, args ...any)
|
||||
Warnf(msg string, args ...any)
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package pmtud
|
||||
package icmp
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -11,14 +11,13 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||
"golang.org/x/net/icmp"
|
||||
"golang.org/x/net/ipv4"
|
||||
)
|
||||
|
||||
const (
|
||||
// see https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media
|
||||
minIPv4MTU uint32 = 68
|
||||
icmpv4Protocol int = 1
|
||||
icmpv4Protocol = 1
|
||||
)
|
||||
|
||||
func listenICMPv4(ctx context.Context) (conn net.PacketConn, err error) {
|
||||
@@ -38,7 +37,7 @@ func listenICMPv4(ctx context.Context) (conn net.PacketConn, err error) {
|
||||
packetConn, err := listenConfig.ListenPacket(ctx, "ip4:icmp", listenAddress)
|
||||
if err != nil {
|
||||
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
|
||||
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrICMPNotPermitted)
|
||||
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrNotPermitted)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
@@ -83,7 +82,9 @@ func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr,
|
||||
|
||||
buffer := make([]byte, physicalLinkMTU)
|
||||
|
||||
for { // for loop in case we read an echo reply for another ICMP request
|
||||
// for loop in case we read an ICMP message from another ICMP request
|
||||
// or TCP/UDP traffic triggering an ICMP response.
|
||||
for {
|
||||
// Note we need to read the whole packet in one call to ReadFrom, so the buffer
|
||||
// must be large enough to read the entire reply packet. See:
|
||||
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
|
||||
@@ -108,24 +109,27 @@ func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr,
|
||||
switch typedBody := inboundMessage.Body.(type) {
|
||||
case *icmp.DstUnreach:
|
||||
const fragmentationRequiredAndDFFlagSetCode = 4
|
||||
const portUnreachable = 3
|
||||
const communicationAdministrativelyProhibitedCode = 13
|
||||
switch inboundMessage.Code {
|
||||
case fragmentationRequiredAndDFFlagSetCode:
|
||||
case portUnreachable: // triggered by TCP or UDP from applications
|
||||
continue // ignore and wait for the next message
|
||||
case communicationAdministrativelyProhibitedCode:
|
||||
return 0, fmt.Errorf("%w: %w (code %d)",
|
||||
ErrICMPDestinationUnreachable,
|
||||
ErrICMPCommunicationAdministrativelyProhibited,
|
||||
ErrDestinationUnreachable,
|
||||
ErrCommunicationAdministrativelyProhibited,
|
||||
inboundMessage.Code)
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: code %d",
|
||||
ErrICMPDestinationUnreachable, inboundMessage.Code)
|
||||
ErrDestinationUnreachable, inboundMessage.Code)
|
||||
}
|
||||
|
||||
// See https://datatracker.ietf.org/doc/html/rfc1191#section-4
|
||||
// Note: the go library does not handle this NextHopMTU section.
|
||||
nextHopMTU := packetBytes[6:8]
|
||||
mtu = uint32(binary.BigEndian.Uint16(nextHopMTU))
|
||||
err = checkMTU(mtu, minIPv4MTU, physicalLinkMTU)
|
||||
err = checkMTU(mtu, constants.MinIPv4MTU, physicalLinkMTU)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("checking next-hop-mtu found: %w", err)
|
||||
}
|
||||
@@ -153,7 +157,7 @@ func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr,
|
||||
inboundID, outboundID)
|
||||
continue
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, typedBody)
|
||||
return 0, fmt.Errorf("%w: %T", ErrBodyUnsupported, typedBody)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package pmtud
|
||||
package icmp
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -8,12 +8,12 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||
"golang.org/x/net/icmp"
|
||||
"golang.org/x/net/ipv6"
|
||||
)
|
||||
|
||||
const (
|
||||
minIPv6MTU = 1280
|
||||
icmpv6Protocol = 58
|
||||
)
|
||||
|
||||
@@ -23,7 +23,7 @@ func listenICMPv6(ctx context.Context) (conn net.PacketConn, err error) {
|
||||
packetConn, err := listenConfig.ListenPacket(ctx, "ip6:ipv6-icmp", listenAddress)
|
||||
if err != nil {
|
||||
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
|
||||
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrICMPNotPermitted)
|
||||
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrNotPermitted)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
@@ -85,7 +85,7 @@ func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr,
|
||||
case *icmp.PacketTooBig:
|
||||
// https://datatracker.ietf.org/doc/html/rfc1885#section-3.2
|
||||
mtu = uint32(typedBody.MTU) //nolint:gosec
|
||||
err = checkMTU(mtu, minIPv6MTU, physicalLinkMTU)
|
||||
err = checkMTU(mtu, constants.MinIPv6MTU, physicalLinkMTU)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("checking MTU: %w", err)
|
||||
}
|
||||
@@ -103,7 +103,7 @@ func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr,
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("checking invoking message id: %w", err)
|
||||
} else if idMatch {
|
||||
return 0, fmt.Errorf("%w", ErrICMPDestinationUnreachable)
|
||||
return 0, fmt.Errorf("%w", ErrDestinationUnreachable)
|
||||
}
|
||||
logger.Debug("discarding received ICMP destination unreachable reply with an unknown id")
|
||||
continue
|
||||
@@ -116,7 +116,7 @@ func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr,
|
||||
inboundID, outboundID)
|
||||
continue
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, typedBody)
|
||||
return 0, fmt.Errorf("%w: %T", ErrBodyUnsupported, typedBody)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package pmtud
|
||||
package icmp
|
||||
|
||||
import (
|
||||
cryptorand "crypto/rand"
|
||||
@@ -0,0 +1,187 @@
|
||||
package icmp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/pmtud/test"
|
||||
"golang.org/x/net/icmp"
|
||||
)
|
||||
|
||||
type icmpTestUnit struct {
|
||||
mtu uint32
|
||||
echoID uint16
|
||||
sentBytes int
|
||||
ok bool
|
||||
}
|
||||
|
||||
func pmtudMultiSizes(ctx context.Context, ip netip.Addr,
|
||||
minMTU, maxPossibleMTU uint32, pingTimeout time.Duration,
|
||||
logger Logger,
|
||||
) (maxMTU uint32, err error) {
|
||||
var ipVersion string
|
||||
var conn net.PacketConn
|
||||
if ip.Is4() {
|
||||
ipVersion = "v4"
|
||||
conn, err = listenICMPv4(ctx)
|
||||
} else {
|
||||
ipVersion = "v6"
|
||||
conn, err = listenICMPv6(ctx)
|
||||
}
|
||||
if err != nil {
|
||||
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
|
||||
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrNotPermitted)
|
||||
}
|
||||
return 0, fmt.Errorf("listening for ICMP packets: %w", err)
|
||||
}
|
||||
|
||||
mtusToTest := test.MakeMTUsToTest(minMTU, maxPossibleMTU)
|
||||
if len(mtusToTest) == 1 { // only minMTU because minMTU == maxPossibleMTU
|
||||
return minMTU, nil
|
||||
}
|
||||
logger.Debugf("ICMP testing the following MTUs: %v", mtusToTest)
|
||||
|
||||
tests := make([]icmpTestUnit, len(mtusToTest))
|
||||
for i := range mtusToTest {
|
||||
tests[i] = icmpTestUnit{mtu: mtusToTest[i]}
|
||||
}
|
||||
|
||||
timedCtx, cancel := context.WithTimeout(ctx, pingTimeout)
|
||||
defer cancel()
|
||||
go func() {
|
||||
<-timedCtx.Done()
|
||||
conn.Close()
|
||||
}()
|
||||
|
||||
for i := range tests {
|
||||
id, message := buildMessageToSend(ipVersion, tests[i].mtu)
|
||||
tests[i].echoID = id
|
||||
|
||||
encodedMessage, err := message.Marshal(nil)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("encoding ICMP message: %w", err)
|
||||
}
|
||||
tests[i].sentBytes = len(encodedMessage)
|
||||
|
||||
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()})
|
||||
if err != nil {
|
||||
if strings.HasSuffix(err.Error(), "sendto: operation not permitted") {
|
||||
err = fmt.Errorf("%w", ErrNotPermitted)
|
||||
}
|
||||
return 0, fmt.Errorf("writing ICMP message: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = collectReplies(conn, ipVersion, tests, logger)
|
||||
switch {
|
||||
case err == nil: // max possible MTU is working
|
||||
return tests[len(tests)-1].mtu, nil
|
||||
case err != nil && errors.Is(err, net.ErrClosed):
|
||||
// we have timeouts (IPv4 testing or IPv6 PMTUD blackholes)
|
||||
// so find the highest MTU which worked.
|
||||
// Note we start from index len(tests) - 2 since the max MTU
|
||||
// cannot be working if we had a timeout.
|
||||
for i := len(tests) - 2; i >= 0; i-- { //nolint:mnd
|
||||
if tests[i].ok {
|
||||
return pmtudMultiSizes(ctx, ip, tests[i].mtu, tests[i+1].mtu-1,
|
||||
pingTimeout, logger)
|
||||
}
|
||||
}
|
||||
|
||||
// All MTUs failed.
|
||||
return 0, fmt.Errorf("%w: ICMP might be blocked", ErrMTUNotFound)
|
||||
case err != nil:
|
||||
return 0, fmt.Errorf("collecting ICMP echo replies: %w", err)
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
}
|
||||
|
||||
// The theoretical limit is 4GiB for IPv6 MTU path discovery jumbograms, but that would
|
||||
// create huge buffers which we don't really want to support anyway.
|
||||
// The standard frame maximum MTU is 1500 bytes, and there are Jumbo frames with
|
||||
// a conventional maximum of 9000 bytes. However, some manufacturers support up
|
||||
// 9216-20 = 9196 bytes for the maximum MTU. We thus use buffers of size 9196 to
|
||||
// match eventual Jumbo frames. More information at:
|
||||
// https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media
|
||||
const maxPossibleMTU = 9196
|
||||
|
||||
func collectReplies(conn net.PacketConn, ipVersion string,
|
||||
tests []icmpTestUnit, logger Logger,
|
||||
) (err error) {
|
||||
echoIDToTestIndex := make(map[uint16]int, len(tests))
|
||||
for i, test := range tests {
|
||||
echoIDToTestIndex[test.echoID] = i
|
||||
}
|
||||
|
||||
buffer := make([]byte, maxPossibleMTU)
|
||||
|
||||
idsFound := 0
|
||||
for idsFound < len(tests) {
|
||||
// Note we need to read the whole packet in one call to ReadFrom, so the buffer
|
||||
// must be large enough to read the entire reply packet. See:
|
||||
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
|
||||
bytesRead, _, err := conn.ReadFrom(buffer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading from ICMP connection: %w", err)
|
||||
}
|
||||
packetBytes := buffer[:bytesRead]
|
||||
|
||||
ipPacketLength := len(packetBytes)
|
||||
|
||||
var icmpProtocol int
|
||||
switch ipVersion {
|
||||
case "v4":
|
||||
icmpProtocol = icmpv4Protocol
|
||||
case "v6":
|
||||
icmpProtocol = icmpv6Protocol
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown IP version: %s", ipVersion))
|
||||
}
|
||||
|
||||
// Parse the ICMP message
|
||||
// Note: this parsing works for a truncated 556 bytes ICMP reply packet.
|
||||
message, err := icmp.ParseMessage(icmpProtocol, packetBytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing message: %w", err)
|
||||
}
|
||||
|
||||
echoBody, ok := message.Body.(*icmp.Echo)
|
||||
if !ok {
|
||||
return fmt.Errorf("%w: %T", ErrBodyUnsupported, message.Body)
|
||||
}
|
||||
|
||||
id := uint16(echoBody.ID) //nolint:gosec
|
||||
testIndex, testing := echoIDToTestIndex[id]
|
||||
if !testing { // not an id we expected so ignore it
|
||||
logger.Warnf("ignoring ICMP reply with unexpected ID %d (type: %d, code: %d, length: %d)",
|
||||
echoBody.ID, message.Type, message.Code, ipPacketLength)
|
||||
continue
|
||||
}
|
||||
idsFound++
|
||||
sentBytes := tests[testIndex].sentBytes
|
||||
|
||||
// echo reply should be at most the number of bytes sent,
|
||||
// and can be lower, more precisely 556 bytes, in case
|
||||
// the host we are reaching wants to stay out of trouble
|
||||
// and ensure its echo reply goes through without
|
||||
// fragmentation, see the following page:
|
||||
// https://datatracker.ietf.org/doc/html/rfc1122#page-59
|
||||
const conservativeReplyLength = 556
|
||||
truncated := ipPacketLength < sentBytes &&
|
||||
ipPacketLength == conservativeReplyLength
|
||||
// Check the packet size is the same if the reply is not truncated
|
||||
if !truncated && sentBytes != ipPacketLength {
|
||||
return fmt.Errorf("%w: sent %dB and received %dB",
|
||||
ErrEchoDataMismatch, sentBytes, ipPacketLength)
|
||||
}
|
||||
// Truncated reply or matching reply size
|
||||
tests[testIndex].ok = true
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,73 @@
|
||||
package ip
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||
)
|
||||
|
||||
func HeaderV4(srcIP, dstIP netip.Addr, payloadLength uint32) []byte {
|
||||
ipHeader := make([]byte, constants.IPv4HeaderLength)
|
||||
const version byte = 4
|
||||
const headerLength byte = 20 / 4 // in 32-bit words
|
||||
ipHeader[0] = (version << 4) | headerLength //nolint:mnd
|
||||
ipHeader[1] = 0 // type of Service
|
||||
putUint16(ipHeader[2:], uint16(constants.IPv4HeaderLength+payloadLength)) //nolint:gosec
|
||||
ipHeader[4], ipHeader[5] = 0, 0 // identification
|
||||
const flagsAndOffset uint16 = 0x4000 // DF bit set
|
||||
putUint16(ipHeader[6:], flagsAndOffset)
|
||||
ipHeader[8] = 64 // ttl
|
||||
ipHeader[9] = syscall.IPPROTO_TCP
|
||||
srcIPBytes := srcIP.As4()
|
||||
copy(ipHeader[12:16], srcIPBytes[:])
|
||||
dstIPBytes := dstIP.As4()
|
||||
copy(ipHeader[16:20], dstIPBytes[:])
|
||||
|
||||
checksum := ipChecksum(ipHeader)
|
||||
ipHeader[10] = byte(checksum >> 8) //nolint:mnd
|
||||
ipHeader[11] = byte(checksum & 0xff) //nolint:mnd
|
||||
|
||||
return ipHeader
|
||||
}
|
||||
|
||||
// ipChecksum calculates the checksum for the IP header.
|
||||
//
|
||||
//nolint:mnd
|
||||
func ipChecksum(header []byte) uint16 {
|
||||
sum := uint32(0)
|
||||
for i := 0; i < len(header)-1; i += 2 {
|
||||
sum += uint32(header[i])<<8 + uint32(header[i+1])
|
||||
}
|
||||
if len(header)%2 != 0 {
|
||||
sum += uint32(header[len(header)-1]) << 8
|
||||
}
|
||||
for (sum >> 16) > 0 {
|
||||
sum = (sum & 0xFFFF) + (sum >> 16)
|
||||
}
|
||||
return ^uint16(sum) //nolint:gosec
|
||||
}
|
||||
|
||||
// HeaderV6 makes an IPv6 header.
|
||||
// payloadLen is the length of the payload following the header.
|
||||
// nextHeader can be byte([syscall.IPPROTO_TCP]) for example.
|
||||
func HeaderV6(srcIP, dstIP netip.Addr,
|
||||
payloadLen uint16, nextHeader byte,
|
||||
) []byte {
|
||||
ipv6Header := make([]byte, constants.IPv6HeaderLength)
|
||||
ipv6Header[0] = 0x60 // version (4 bits) | traffic Class (4 bits)
|
||||
ipv6Header[1] = 0x00 // traffic Class (4 bits) | flow label (4 bits)
|
||||
|
||||
// Flow Label (remaining 16 bits)
|
||||
ipv6Header[2] = 0x00
|
||||
ipv6Header[3] = 0x00
|
||||
|
||||
binary.BigEndian.PutUint16(ipv6Header[4:], payloadLen)
|
||||
ipv6Header[6] = nextHeader
|
||||
const hopLimit = 64
|
||||
ipv6Header[7] = hopLimit
|
||||
copy(ipv6Header[8:24], srcIP.AsSlice())
|
||||
copy(ipv6Header[24:40], dstIP.AsSlice())
|
||||
return ipv6Header
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
package ip
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
)
|
||||
|
||||
func putUint16(b []byte, v uint16) {
|
||||
binary.NativeEndian.PutUint16(b, v)
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
//go:build !darwin
|
||||
|
||||
package ip
|
||||
|
||||
import "encoding/binary"
|
||||
|
||||
func putUint16(b []byte, v uint16) {
|
||||
binary.BigEndian.PutUint16(b, v)
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
//go:build linux || darwin
|
||||
|
||||
package ip
|
||||
|
||||
import "syscall"
|
||||
|
||||
func SetIPv4HeaderIncluded(fd int) error {
|
||||
return syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
//go:build !linux && !windows && !darwin
|
||||
|
||||
package ip
|
||||
|
||||
func SetIPv4HeaderIncluded(fd int) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
package ip
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func SetIPv4HeaderIncluded(handle syscall.Handle) error {
|
||||
const ipHdrIncluded = windows.IP_HDRINCL
|
||||
return syscall.SetsockoptInt(handle, syscall.IPPROTO_IP, ipHdrIncluded, 1)
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
package ip
|
||||
|
||||
func SetIPv6HeaderIncluded(fd int) error {
|
||||
panic("darwin does not allow an application to build IPv6 headers")
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
package ip
|
||||
|
||||
import "syscall"
|
||||
|
||||
func SetIPv6HeaderIncluded(fd int) error {
|
||||
const ipv6HdrIncluded = 36 // IPV6_HDRINCL
|
||||
return syscall.SetsockoptInt(fd, syscall.IPPROTO_IPV6, ipv6HdrIncluded, 1)
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
//go:build !linux && !windows && !darwin
|
||||
|
||||
package ip
|
||||
|
||||
func SetIPv6HeaderIncluded(fd int) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
package ip
|
||||
|
||||
import "syscall"
|
||||
|
||||
func SetIPv6HeaderIncluded(fd syscall.Handle) error {
|
||||
panic("windows does not allow an application to build IPv6 headers")
|
||||
}
|
||||
@@ -0,0 +1,123 @@
|
||||
package ip
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
|
||||
"github.com/jsimonetti/rtnetlink"
|
||||
)
|
||||
|
||||
// SrcAddr determines the appropriate source IP address to use when sending a packet to the
|
||||
// specified destination. It also reserves an ephemeral source port for the specified protocol
|
||||
// to ensure that the port is not used by other processes. The cleanup function returned should
|
||||
// be called to release the reserved port when done.
|
||||
func SrcAddr(dst netip.AddrPort, proto int) (src netip.AddrPort, cleanup func(), err error) {
|
||||
srcAddr, err := srcIP(dst.Addr())
|
||||
if err != nil {
|
||||
return netip.AddrPort{}, nil, fmt.Errorf("finding source IP: %w", err)
|
||||
}
|
||||
|
||||
srcPort, cleanup, err := srcPort(srcAddr, proto)
|
||||
if err != nil {
|
||||
return netip.AddrPort{}, nil, fmt.Errorf("reserving source port: %w", err)
|
||||
}
|
||||
|
||||
return netip.AddrPortFrom(srcAddr, srcPort), cleanup, nil
|
||||
}
|
||||
|
||||
var errNoRoute = fmt.Errorf("no route to destination")
|
||||
|
||||
func srcIP(dst netip.Addr) (netip.Addr, error) {
|
||||
conn, err := rtnetlink.Dial(nil)
|
||||
if err != nil {
|
||||
return netip.Addr{}, err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
family := uint8(syscall.AF_INET)
|
||||
if dst.Is6() {
|
||||
family = syscall.AF_INET6
|
||||
}
|
||||
|
||||
// Request route to destination
|
||||
requestMessage := &rtnetlink.RouteMessage{
|
||||
Family: family,
|
||||
Attributes: rtnetlink.RouteAttributes{
|
||||
Dst: dst.AsSlice(),
|
||||
},
|
||||
}
|
||||
messages, err := conn.Route.Get(requestMessage)
|
||||
if err != nil {
|
||||
return netip.Addr{}, fmt.Errorf("getting routes to %s: %w", dst, err)
|
||||
}
|
||||
|
||||
for _, message := range messages {
|
||||
if message.Attributes.Src == nil {
|
||||
continue
|
||||
}
|
||||
ipv6 := message.Attributes.Src.To4() == nil
|
||||
if ipv6 {
|
||||
return netip.AddrFrom16([16]byte(message.Attributes.Src)), nil
|
||||
}
|
||||
return netip.AddrFrom4([4]byte(message.Attributes.Src)), nil
|
||||
}
|
||||
|
||||
return netip.Addr{}, fmt.Errorf("%w: in %d route(s)", errNoRoute, len(messages))
|
||||
}
|
||||
|
||||
// srcPort reserves an ephemeral source port by opening a socket for the
|
||||
// protocol specified and binds it to the provided source address.
|
||||
// It doesn't actually listen on the port.
|
||||
// The cleanup function returned should be called to release the port when done.
|
||||
func srcPort(srcAddr netip.Addr, proto int) (srcPort uint16, cleanup func(), err error) {
|
||||
family := syscall.AF_INET
|
||||
if srcAddr.Is6() {
|
||||
family = syscall.AF_INET6
|
||||
}
|
||||
|
||||
fd, err := syscall.Socket(family, syscall.SOCK_STREAM, proto)
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("creating reservation socket: %w", err)
|
||||
}
|
||||
cleanup = func() {
|
||||
_ = syscall.Close(fd)
|
||||
}
|
||||
|
||||
// Bind to port 0 to get an ephemeral port
|
||||
const port = 0
|
||||
var bindAddr syscall.Sockaddr
|
||||
if srcAddr.Is4() {
|
||||
bindAddr = &syscall.SockaddrInet4{
|
||||
Port: port,
|
||||
Addr: srcAddr.As4(),
|
||||
}
|
||||
} else {
|
||||
bindAddr = &syscall.SockaddrInet6{
|
||||
Port: port,
|
||||
Addr: srcAddr.As16(),
|
||||
}
|
||||
}
|
||||
|
||||
err = syscall.Bind(fd, bindAddr)
|
||||
if err != nil {
|
||||
cleanup()
|
||||
return 0, nil, fmt.Errorf("binding reservation socket: %w", err)
|
||||
}
|
||||
|
||||
sockAddr, err := syscall.Getsockname(fd)
|
||||
if err != nil {
|
||||
cleanup()
|
||||
return 0, nil, fmt.Errorf("getting bound socket name: %w", err)
|
||||
}
|
||||
|
||||
switch typedSockAddr := sockAddr.(type) {
|
||||
case *syscall.SockaddrInet4:
|
||||
srcPort = uint16(typedSockAddr.Port) //nolint:gosec
|
||||
case *syscall.SockaddrInet6:
|
||||
srcPort = uint16(typedSockAddr.Port) //nolint:gosec
|
||||
default:
|
||||
panic(fmt.Sprintf("unexpected sockaddr type: %T", typedSockAddr))
|
||||
}
|
||||
return srcPort, cleanup, nil
|
||||
}
|
||||
+37
-232
@@ -4,268 +4,73 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/icmp"
|
||||
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||
"github.com/qdm12/gluetun/internal/pmtud/icmp"
|
||||
"github.com/qdm12/gluetun/internal/pmtud/tcp"
|
||||
)
|
||||
|
||||
var ErrMTUNotFound = errors.New("path MTU discovery failed to find MTU")
|
||||
|
||||
// PathMTUDiscover discovers the maximum MTU for the path to the given ip address.
|
||||
// PathMTUDiscover discovers the maximum MTU using both ICMP and TCP.
|
||||
// Multiple ICMP addresses and TCP addresses can be specified for redundancy.
|
||||
// ICMP PMTUD is run first. If successful, the range of possible MTU values to
|
||||
// check for TCP PMTUD is reduced to [maxMTU-150, maxMTU] where maxMTU is the
|
||||
// maximum MTU found with ICMP PMTUD. Otherwise, TCP PMTUD is run with the
|
||||
// whole range of possible MTU values up to the physical link MTU to check.
|
||||
// If the physicalLinkMTU is zero, it defaults to 1500 which is the ethernet standard MTU.
|
||||
// If the pingTimeout is zero, it defaults to 1 second.
|
||||
// If the logger is nil, a no-op logger is used.
|
||||
// It returns [ErrMTUNotFound] if the MTU could not be determined.
|
||||
func PathMTUDiscover(ctx context.Context, ip netip.Addr,
|
||||
physicalLinkMTU uint32, pingTimeout time.Duration, logger Logger) (
|
||||
func PathMTUDiscover(ctx context.Context, icmpAddrs []netip.Addr, tcpAddrs []netip.AddrPort,
|
||||
physicalLinkMTU uint32, tryTimeout time.Duration, logger Logger) (
|
||||
mtu uint32, err error,
|
||||
) {
|
||||
if physicalLinkMTU == 0 {
|
||||
const ethernetStandardMTU = 1500
|
||||
physicalLinkMTU = ethernetStandardMTU
|
||||
}
|
||||
if pingTimeout == 0 {
|
||||
pingTimeout = time.Second
|
||||
if tryTimeout == 0 {
|
||||
tryTimeout = time.Second
|
||||
}
|
||||
if logger == nil {
|
||||
logger = &noopLogger{}
|
||||
}
|
||||
|
||||
if ip.Is4() {
|
||||
logger.Debug("finding IPv4 next hop MTU")
|
||||
mtu, err = findIPv4NextHopMTU(ctx, ip, physicalLinkMTU, pingTimeout, logger)
|
||||
// Try finding the MTU using ICMP
|
||||
maxPossibleMTU := physicalLinkMTU
|
||||
icmpSuccess := false
|
||||
for _, icmpIP := range icmpAddrs {
|
||||
mtu, err := icmp.PathMTUDiscover(ctx, icmpIP, physicalLinkMTU,
|
||||
tryTimeout, logger)
|
||||
switch {
|
||||
case err == nil:
|
||||
return mtu, nil
|
||||
case errors.Is(err, net.ErrClosed) || errors.Is(err, ErrICMPCommunicationAdministrativelyProhibited): // blackhole
|
||||
logger.Debugf("ICMP path MTU discovery against %s found maximum valid MTU %d", icmpIP, mtu)
|
||||
icmpSuccess = true
|
||||
maxPossibleMTU = mtu
|
||||
case errors.Is(err, icmp.ErrNotPermitted), errors.Is(err, icmp.ErrMTUNotFound):
|
||||
logger.Debugf("ICMP path MTU discovery failed: %s", err)
|
||||
default:
|
||||
return 0, fmt.Errorf("finding IPv4 next hop MTU: %w", err)
|
||||
}
|
||||
} else {
|
||||
logger.Debug("requesting IPv6 ICMP packet-too-big reply")
|
||||
mtu, err = getIPv6PacketTooBig(ctx, ip, physicalLinkMTU, pingTimeout, logger)
|
||||
switch {
|
||||
case err == nil:
|
||||
return mtu, nil
|
||||
case errors.Is(err, net.ErrClosed): // blackhole
|
||||
default:
|
||||
return 0, fmt.Errorf("getting IPv6 packet-too-big message: %w", err)
|
||||
return 0, fmt.Errorf("ICMP path MTU discovery: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back method: send echo requests with different packet
|
||||
// sizes and check which ones succeed to find the maximum MTU.
|
||||
logger.Debug("falling back to sending different sized echo packets")
|
||||
minMTU := minIPv4MTU
|
||||
if ip.Is6() {
|
||||
minMTU = minIPv6MTU
|
||||
for _, addrPort := range tcpAddrs {
|
||||
minMTU := constants.MinIPv4MTU
|
||||
if addrPort.Addr().Is6() {
|
||||
minMTU = constants.MinIPv6MTU
|
||||
}
|
||||
return pmtudMultiSizes(ctx, ip, minMTU, physicalLinkMTU, pingTimeout, logger)
|
||||
}
|
||||
|
||||
type pmtudTestUnit struct {
|
||||
mtu uint32
|
||||
echoID uint16
|
||||
sentBytes int
|
||||
ok bool
|
||||
}
|
||||
|
||||
func pmtudMultiSizes(ctx context.Context, ip netip.Addr,
|
||||
minMTU, maxPossibleMTU uint32, pingTimeout time.Duration,
|
||||
logger Logger,
|
||||
) (maxMTU uint32, err error) {
|
||||
var ipVersion string
|
||||
var conn net.PacketConn
|
||||
if ip.Is4() {
|
||||
ipVersion = "v4"
|
||||
conn, err = listenICMPv4(ctx)
|
||||
} else {
|
||||
ipVersion = "v6"
|
||||
conn, err = listenICMPv6(ctx)
|
||||
if icmpSuccess {
|
||||
const mtuMargin = 150
|
||||
minMTU = max(maxPossibleMTU-mtuMargin, minMTU)
|
||||
}
|
||||
mtu, err = tcp.PathMTUDiscover(ctx, addrPort, minMTU, maxPossibleMTU, logger)
|
||||
if err != nil {
|
||||
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
|
||||
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrICMPNotPermitted)
|
||||
}
|
||||
return 0, fmt.Errorf("listening for ICMP packets: %w", err)
|
||||
}
|
||||
|
||||
mtusToTest := makeMTUsToTest(minMTU, maxPossibleMTU)
|
||||
if len(mtusToTest) == 1 { // only minMTU because minMTU == maxPossibleMTU
|
||||
return minMTU, nil
|
||||
}
|
||||
logger.Debugf("testing the following MTUs: %v", mtusToTest)
|
||||
|
||||
tests := make([]pmtudTestUnit, len(mtusToTest))
|
||||
for i := range mtusToTest {
|
||||
tests[i] = pmtudTestUnit{mtu: mtusToTest[i]}
|
||||
}
|
||||
|
||||
timedCtx, cancel := context.WithTimeout(ctx, pingTimeout)
|
||||
defer cancel()
|
||||
go func() {
|
||||
<-timedCtx.Done()
|
||||
conn.Close()
|
||||
}()
|
||||
|
||||
for i := range tests {
|
||||
id, message := buildMessageToSend(ipVersion, tests[i].mtu)
|
||||
tests[i].echoID = id
|
||||
|
||||
encodedMessage, err := message.Marshal(nil)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("encoding ICMP message: %w", err)
|
||||
}
|
||||
tests[i].sentBytes = len(encodedMessage)
|
||||
|
||||
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()})
|
||||
if err != nil {
|
||||
if strings.HasSuffix(err.Error(), "sendto: operation not permitted") {
|
||||
err = fmt.Errorf("%w", ErrICMPNotPermitted)
|
||||
}
|
||||
return 0, fmt.Errorf("writing ICMP message: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = collectReplies(conn, ipVersion, tests, logger)
|
||||
switch {
|
||||
case err == nil: // max possible MTU is working
|
||||
return tests[len(tests)-1].mtu, nil
|
||||
case err != nil && errors.Is(err, net.ErrClosed):
|
||||
// we have timeouts (IPv4 testing or IPv6 PMTUD blackholes)
|
||||
// so find the highest MTU which worked.
|
||||
// Note we start from index len(tests) - 2 since the max MTU
|
||||
// cannot be working if we had a timeout.
|
||||
for i := len(tests) - 2; i >= 0; i-- { //nolint:mnd
|
||||
if tests[i].ok {
|
||||
return pmtudMultiSizes(ctx, ip, tests[i].mtu, tests[i+1].mtu-1,
|
||||
pingTimeout, logger)
|
||||
}
|
||||
}
|
||||
|
||||
// All MTUs failed.
|
||||
return 0, fmt.Errorf("%w: ICMP might be blocked", ErrMTUNotFound)
|
||||
case err != nil:
|
||||
return 0, fmt.Errorf("collecting ICMP echo replies: %w", err)
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
}
|
||||
|
||||
// Create the MTU slice of length 11 such that:
|
||||
// - the first element is the minMTU
|
||||
// - the last element is the maxMTU
|
||||
// - elements in-between are separated as close to each other
|
||||
// The number 11 is chosen to find the final MTU in 3 searches,
|
||||
// with a total search space of 1728 MTUs which is enough;
|
||||
// to find it in 2 searches requires 37 parallel queries which
|
||||
// could be blocked by firewalls.
|
||||
func makeMTUsToTest(minMTU, maxMTU uint32) (mtus []uint32) {
|
||||
const mtusLength = 11 // find the final MTU in 3 searches
|
||||
diff := maxMTU - minMTU
|
||||
switch {
|
||||
case minMTU > maxMTU:
|
||||
panic("minMTU > maxMTU")
|
||||
case diff <= mtusLength:
|
||||
mtus = make([]uint32, 0, diff)
|
||||
for mtu := minMTU; mtu <= maxMTU; mtu++ {
|
||||
mtus = append(mtus, mtu)
|
||||
}
|
||||
default:
|
||||
step := float64(diff) / float64(mtusLength-1)
|
||||
mtus = make([]uint32, 0, mtusLength)
|
||||
for mtu := float64(minMTU); len(mtus) < mtusLength-1; mtu += step {
|
||||
mtus = append(mtus, uint32(math.Round(mtu)))
|
||||
}
|
||||
mtus = append(mtus, maxMTU) // last element is the maxMTU
|
||||
}
|
||||
|
||||
return mtus
|
||||
}
|
||||
|
||||
func collectReplies(conn net.PacketConn, ipVersion string,
|
||||
tests []pmtudTestUnit, logger Logger,
|
||||
) (err error) {
|
||||
echoIDToTestIndex := make(map[uint16]int, len(tests))
|
||||
for i, test := range tests {
|
||||
echoIDToTestIndex[test.echoID] = i
|
||||
}
|
||||
|
||||
// The theoretical limit is 4GiB for IPv6 MTU path discovery jumbograms, but that would
|
||||
// create huge buffers which we don't really want to support anyway.
|
||||
// The standard frame maximum MTU is 1500 bytes, and there are Jumbo frames with
|
||||
// a conventional maximum of 9000 bytes. However, some manufacturers support up
|
||||
// 9216-20 = 9196 bytes for the maximum MTU. We thus use buffers of size 9196 to
|
||||
// match eventual Jumbo frames. More information at:
|
||||
// https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media
|
||||
const maxPossibleMTU = 9196
|
||||
buffer := make([]byte, maxPossibleMTU)
|
||||
|
||||
idsFound := 0
|
||||
for idsFound < len(tests) {
|
||||
// Note we need to read the whole packet in one call to ReadFrom, so the buffer
|
||||
// must be large enough to read the entire reply packet. See:
|
||||
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
|
||||
bytesRead, _, err := conn.ReadFrom(buffer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading from ICMP connection: %w", err)
|
||||
}
|
||||
packetBytes := buffer[:bytesRead]
|
||||
|
||||
ipPacketLength := len(packetBytes)
|
||||
|
||||
var icmpProtocol int
|
||||
switch ipVersion {
|
||||
case "v4":
|
||||
icmpProtocol = icmpv4Protocol
|
||||
case "v6":
|
||||
icmpProtocol = icmpv6Protocol
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown IP version: %s", ipVersion))
|
||||
}
|
||||
|
||||
// Parse the ICMP message
|
||||
// Note: this parsing works for a truncated 556 bytes ICMP reply packet.
|
||||
message, err := icmp.ParseMessage(icmpProtocol, packetBytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing message: %w", err)
|
||||
}
|
||||
|
||||
echoBody, ok := message.Body.(*icmp.Echo)
|
||||
if !ok {
|
||||
return fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, message.Body)
|
||||
}
|
||||
|
||||
id := uint16(echoBody.ID) //nolint:gosec
|
||||
testIndex, testing := echoIDToTestIndex[id]
|
||||
if !testing { // not an id we expected so ignore it
|
||||
logger.Warnf("ignoring ICMP reply with unexpected ID %d (type: %d, code: %d, length: %d)",
|
||||
echoBody.ID, message.Type, message.Code, ipPacketLength)
|
||||
logger.Debugf("TCP path MTU discovery to %s failed: %s", addrPort, err)
|
||||
continue
|
||||
}
|
||||
idsFound++
|
||||
sentBytes := tests[testIndex].sentBytes
|
||||
|
||||
// echo reply should be at most the number of bytes sent,
|
||||
// and can be lower, more precisely 556 bytes, in case
|
||||
// the host we are reaching wants to stay out of trouble
|
||||
// and ensure its echo reply goes through without
|
||||
// fragmentation, see the following page:
|
||||
// https://datatracker.ietf.org/doc/html/rfc1122#page-59
|
||||
const conservativeReplyLength = 556
|
||||
truncated := ipPacketLength < sentBytes &&
|
||||
ipPacketLength == conservativeReplyLength
|
||||
// Check the packet size is the same if the reply is not truncated
|
||||
if !truncated && sentBytes != ipPacketLength {
|
||||
return fmt.Errorf("%w: sent %dB and received %dB",
|
||||
ErrICMPEchoDataMismatch, sentBytes, ipPacketLength)
|
||||
logger.Debugf("TCP path MTU discovery to %s found maximum valid MTU %d", addrPort, mtu)
|
||||
return mtu, nil
|
||||
}
|
||||
// Truncated reply or matching reply size
|
||||
tests[testIndex].ok = true
|
||||
}
|
||||
return nil
|
||||
return 0, fmt.Errorf("TCP path MTU discovery: last error: %w", err)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
package tcp
|
||||
|
||||
type Logger interface {
|
||||
Debug(msg string)
|
||||
Debugf(msg string, args ...any)
|
||||
Warnf(msg string, args ...any)
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
package tcp
|
||||
|
||||
//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . Logger
|
||||
@@ -0,0 +1,80 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/qdm12/gluetun/internal/pmtud/tcp (interfaces: Logger)
|
||||
|
||||
// Package tcp is a generated GoMock package.
|
||||
package tcp
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// MockLogger is a mock of Logger interface.
|
||||
type MockLogger struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockLoggerMockRecorder
|
||||
}
|
||||
|
||||
// MockLoggerMockRecorder is the mock recorder for MockLogger.
|
||||
type MockLoggerMockRecorder struct {
|
||||
mock *MockLogger
|
||||
}
|
||||
|
||||
// NewMockLogger creates a new mock instance.
|
||||
func NewMockLogger(ctrl *gomock.Controller) *MockLogger {
|
||||
mock := &MockLogger{ctrl: ctrl}
|
||||
mock.recorder = &MockLoggerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockLogger) EXPECT() *MockLoggerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Debug mocks base method.
|
||||
func (m *MockLogger) Debug(arg0 string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Debug", arg0)
|
||||
}
|
||||
|
||||
// Debug indicates an expected call of Debug.
|
||||
func (mr *MockLoggerMockRecorder) Debug(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockLogger)(nil).Debug), arg0)
|
||||
}
|
||||
|
||||
// Debugf mocks base method.
|
||||
func (m *MockLogger) Debugf(arg0 string, arg1 ...interface{}) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []interface{}{arg0}
|
||||
for _, a := range arg1 {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
m.ctrl.Call(m, "Debugf", varargs...)
|
||||
}
|
||||
|
||||
// Debugf indicates an expected call of Debugf.
|
||||
func (mr *MockLoggerMockRecorder) Debugf(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]interface{}{arg0}, arg1...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debugf", reflect.TypeOf((*MockLogger)(nil).Debugf), varargs...)
|
||||
}
|
||||
|
||||
// Warnf mocks base method.
|
||||
func (m *MockLogger) Warnf(arg0 string, arg1 ...interface{}) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []interface{}{arg0}
|
||||
for _, a := range arg1 {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
m.ctrl.Call(m, "Warnf", varargs...)
|
||||
}
|
||||
|
||||
// Warnf indicates an expected call of Warnf.
|
||||
func (mr *MockLoggerMockRecorder) Warnf(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]interface{}{arg0}, arg1...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Warnf", reflect.TypeOf((*MockLogger)(nil).Warnf), varargs...)
|
||||
}
|
||||
@@ -0,0 +1,89 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/pmtud/test"
|
||||
)
|
||||
|
||||
var ErrMTUNotFound = errors.New("MTU not found")
|
||||
|
||||
type testUnit struct {
|
||||
mtu uint32
|
||||
ok bool
|
||||
}
|
||||
|
||||
func PathMTUDiscover(ctx context.Context, addrPort netip.AddrPort,
|
||||
minMTU, maxPossibleMTU uint32, logger Logger,
|
||||
) (mtu uint32, err error) {
|
||||
mtusToTest := test.MakeMTUsToTest(minMTU, maxPossibleMTU)
|
||||
if len(mtusToTest) == 1 { // only minMTU because minMTU == maxPossibleMTU
|
||||
return minMTU, nil
|
||||
}
|
||||
logger.Debugf("TCP testing the following MTUs: %v", mtusToTest)
|
||||
|
||||
tests := make([]testUnit, len(mtusToTest))
|
||||
for i := range mtusToTest {
|
||||
tests[i] = testUnit{mtu: mtusToTest[i]}
|
||||
}
|
||||
|
||||
family := syscall.AF_INET
|
||||
if addrPort.Addr().Is6() {
|
||||
family = syscall.AF_INET6
|
||||
}
|
||||
fd, stop, err := startRawSocket(family)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("starting raw socket: %w", err)
|
||||
}
|
||||
defer stop()
|
||||
|
||||
tracker := newTracker(fd, addrPort.Addr().Is4())
|
||||
|
||||
const timeout = time.Second
|
||||
runCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
errCh <- tracker.listen(runCtx)
|
||||
}()
|
||||
|
||||
doneCh := make(chan struct{})
|
||||
for i := range tests {
|
||||
go func(i int) {
|
||||
err := runTest(runCtx, fd, tracker, addrPort, tests[i].mtu)
|
||||
tests[i].ok = err == nil
|
||||
doneCh <- struct{}{}
|
||||
}(i)
|
||||
}
|
||||
|
||||
for range tests {
|
||||
select {
|
||||
case <-doneCh:
|
||||
case err := <-errCh:
|
||||
if err == nil { // timeout
|
||||
break
|
||||
}
|
||||
return 0, fmt.Errorf("listening for TCP replies: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if tests[len(tests)-1].ok {
|
||||
return tests[len(tests)-1].mtu, nil
|
||||
}
|
||||
|
||||
for i := len(tests) - 2; i >= 0; i-- { //nolint:mnd
|
||||
if tests[i].ok {
|
||||
stop()
|
||||
cancel()
|
||||
return PathMTUDiscover(ctx, addrPort,
|
||||
tests[i].mtu, tests[i+1].mtu-1, logger)
|
||||
}
|
||||
}
|
||||
|
||||
return 0, fmt.Errorf("%w: your connection might not be working at all", ErrMTUNotFound)
|
||||
}
|
||||
@@ -0,0 +1,89 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"math/rand/v2"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||
"github.com/qdm12/gluetun/internal/pmtud/ip"
|
||||
)
|
||||
|
||||
// createSYNPacket creates a TCP SYN packet for initiating a handshake.
|
||||
// SYN packets have normally no data payload, so you SHOULD set mtu to 0.
|
||||
// However, in some cases where the server closes the connection with RST immediately,
|
||||
// it can be useful to add some data payload to a SYN packet and check if the server still
|
||||
// replies. Only set mtu to a non zero value if you know what you are doing.
|
||||
func createSYNPacket(src, dst netip.AddrPort, mtu uint32) (packet []byte, seq uint32) {
|
||||
seq = rand.Uint32() //nolint:gosec
|
||||
const ack = 0 // SYN has no ACK number
|
||||
payloadLength := constants.BaseTCPHeaderLength // no data payload
|
||||
if mtu > 0 {
|
||||
payloadLength = getPayloadLength(mtu, dst)
|
||||
}
|
||||
return createPacket(src, dst, seq, ack, payloadLength, synFlag), seq
|
||||
}
|
||||
|
||||
// createACKPacket creates a TCP ACK packet.
|
||||
// If the mtu is set to 0, no payload is sent.
|
||||
// Otherwise, the payload is calculated to test the MTU given.
|
||||
func createACKPacket(src, dst netip.AddrPort, seq, ack uint32, mtu uint32) []byte {
|
||||
payloadLength := constants.BaseTCPHeaderLength // no data payload
|
||||
if mtu > 0 {
|
||||
payloadLength = getPayloadLength(mtu, dst)
|
||||
}
|
||||
const flags = ackFlag | pshFlag
|
||||
return createPacket(src, dst, seq, ack, payloadLength, flags)
|
||||
}
|
||||
|
||||
func createRSTPacket(src, dst netip.AddrPort, seq, ack uint32) []byte {
|
||||
const payloadLength = constants.BaseTCPHeaderLength // no data payload
|
||||
return createPacket(src, dst, seq, ack, payloadLength, rstFlag)
|
||||
}
|
||||
|
||||
func getPayloadLength(mtu uint32, dst netip.AddrPort) uint32 {
|
||||
var ipHeaderLength uint32
|
||||
if dst.Addr().Is4() {
|
||||
ipHeaderLength = constants.IPv4HeaderLength
|
||||
} else {
|
||||
ipHeaderLength = constants.IPv6HeaderLength
|
||||
}
|
||||
if mtu < ipHeaderLength+constants.BaseTCPHeaderLength {
|
||||
panic("MTU too small to hold IP and TCP headers")
|
||||
}
|
||||
return mtu - ipHeaderLength
|
||||
}
|
||||
|
||||
func createPacket(src, dst netip.AddrPort,
|
||||
seq, ack, payloadLength uint32, flags byte,
|
||||
) []byte {
|
||||
if payloadLength < constants.BaseTCPHeaderLength {
|
||||
panic("payload length is too small to hold TCP header")
|
||||
}
|
||||
|
||||
var ipHeader []byte
|
||||
if dst.Addr().Is4() {
|
||||
ipHeader = ip.HeaderV4(src.Addr(), dst.Addr(), payloadLength)
|
||||
} else {
|
||||
ipHeader = ip.HeaderV6(src.Addr(), dst.Addr(),
|
||||
uint16(payloadLength), byte(syscall.IPPROTO_TCP)) //nolint:gosec
|
||||
}
|
||||
|
||||
tcpHeader := makeTCPHeader(src.Port(), dst.Port(), seq, ack, flags)
|
||||
|
||||
// data is just zeroes
|
||||
dataLength := int(payloadLength) - int(constants.BaseTCPHeaderLength)
|
||||
var data []byte
|
||||
if dataLength > 0 {
|
||||
data = make([]byte, dataLength)
|
||||
}
|
||||
checksum := tcpChecksum(ipHeader, tcpHeader, data)
|
||||
tcpHeader[16] = byte(checksum >> 8) //nolint:mnd
|
||||
tcpHeader[17] = byte(checksum & 0xff) //nolint:mnd
|
||||
|
||||
packet := make([]byte, len(ipHeader)+int(constants.BaseTCPHeaderLength)+dataLength)
|
||||
copy(packet, ipHeader)
|
||||
copy(packet[len(ipHeader):], tcpHeader)
|
||||
copy(packet[len(ipHeader)+int(constants.BaseTCPHeaderLength):], data)
|
||||
return packet
|
||||
}
|
||||
@@ -0,0 +1,196 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||
"github.com/qdm12/gluetun/internal/pmtud/ip"
|
||||
)
|
||||
|
||||
func startRawSocket(family int) (fd fileDescriptor, stop func(), err error) {
|
||||
fdPlatform, err := syscall.Socket(family, syscall.SOCK_RAW, syscall.IPPROTO_TCP)
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("creating raw socket: %w", err)
|
||||
}
|
||||
|
||||
if family == syscall.AF_INET {
|
||||
err = ip.SetIPv4HeaderIncluded(fdPlatform)
|
||||
} else {
|
||||
err = ip.SetIPv6HeaderIncluded(fdPlatform)
|
||||
}
|
||||
if err != nil {
|
||||
_ = syscall.Close(fdPlatform)
|
||||
return 0, nil, fmt.Errorf("setting header option on raw socket: %w", err)
|
||||
}
|
||||
|
||||
// Allow sending packets larger than cached PMTU (for PMTUD probing)
|
||||
err = setMTUDiscovery(fdPlatform)
|
||||
if err != nil {
|
||||
_ = syscall.Close(fdPlatform)
|
||||
return 0, nil, fmt.Errorf("setting IP_MTU_DISCOVER: %w", err)
|
||||
}
|
||||
|
||||
// use polling because some Linux systems do not cancel
|
||||
// blocking syscalls such as recvfrom when the socket is closed,
|
||||
// which would cause things to hang indefinitely.
|
||||
err = setNonBlock(fdPlatform)
|
||||
if err != nil {
|
||||
_ = syscall.Close(fdPlatform)
|
||||
return 0, nil, fmt.Errorf("setting non-blocking mode: %w", err)
|
||||
}
|
||||
|
||||
stop = func() {
|
||||
_ = syscall.Close(fdPlatform)
|
||||
}
|
||||
return fileDescriptor(fdPlatform), stop, nil
|
||||
}
|
||||
|
||||
var (
|
||||
errTCPPacketNotSynAck = errors.New("TCP packet is not a SYN-ACK")
|
||||
errTCPSynAckAckMismatch = errors.New("TCP SYN-ACK ACK number does not match expected value")
|
||||
errFinalPacketTypeUnexpected = errors.New("final TCP packet type is unexpected")
|
||||
)
|
||||
|
||||
// Craft and send a raw TCP packet to test the MTU.
|
||||
// It expects either an RST reply (if no server is listening)
|
||||
// or a SYN-ACK/ACK reply (if a server is listening).
|
||||
func runTest(ctx context.Context, fd fileDescriptor,
|
||||
tracker *tracker, dst netip.AddrPort, mtu uint32,
|
||||
) error {
|
||||
const proto = syscall.IPPROTO_TCP
|
||||
src, cleanup, err := ip.SrcAddr(dst, proto)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting source address: %w", err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
ch := make(chan []byte)
|
||||
abort := make(chan struct{})
|
||||
defer close(abort)
|
||||
tracker.register(src.Port(), dst.Port(), ch, abort)
|
||||
defer tracker.unregister(src.Port(), dst.Port())
|
||||
|
||||
dstSockAddr := makeSockAddr(dst)
|
||||
|
||||
synPacket, synSeq := createSYNPacket(src, dst, 0)
|
||||
const sendToFlags = 0
|
||||
err = sendTo(fd, synPacket, sendToFlags, dstSockAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sending SYN packet: %w", err)
|
||||
}
|
||||
|
||||
var reply []byte
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case reply = <-ch:
|
||||
}
|
||||
|
||||
packetType, synAckSeq, synAckAck, err := parseTCPHeader(reply[:constants.BaseTCPHeaderLength])
|
||||
switch {
|
||||
case err != nil:
|
||||
return fmt.Errorf("parsing first reply TCP header: %w", err)
|
||||
case packetType == packetTypeRST:
|
||||
// server actively closed the connection, try sending a SYN with data
|
||||
return handleRSTReply(ctx, fd, ch, src, dst, mtu)
|
||||
case packetType != packetTypeSYNACK:
|
||||
return fmt.Errorf("%w: unexpected packet type %s", errTCPPacketNotSynAck, packetType)
|
||||
case synAckAck != synSeq+1:
|
||||
return fmt.Errorf("%w: expected %d, got %d", errTCPSynAckAckMismatch, synSeq+1, synAckAck)
|
||||
}
|
||||
|
||||
// Send a no-data ACK packet to finish the 3-way handshake.
|
||||
const ackMTU = 0 // no data payload initially
|
||||
ackPacket := createACKPacket(src, dst, synAckAck, synAckSeq+1, ackMTU)
|
||||
err = sendTo(fd, ackPacket, sendToFlags, dstSockAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sending ACK-without-data packet: %w", err)
|
||||
}
|
||||
|
||||
// Send a data ACK packet to test the MTU given.
|
||||
ackPacket = createACKPacket(src, dst, synAckAck, synAckSeq+1, mtu)
|
||||
err = sendTo(fd, ackPacket, sendToFlags, dstSockAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sending ACK-with-data packet: %w", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case reply = <-ch:
|
||||
}
|
||||
|
||||
packetType, _, ack, err := parseTCPHeader(reply[:constants.BaseTCPHeaderLength])
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing second reply TCP header: %w", err)
|
||||
}
|
||||
|
||||
switch packetType { //nolint:exhaustive
|
||||
case packetTypeRST:
|
||||
return nil
|
||||
case packetTypeACK:
|
||||
err = sendRST(fd, src, dst, ack)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sending RST packet: %w", err)
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
_ = sendRST(fd, src, dst, ack)
|
||||
return fmt.Errorf("%w: %s", errFinalPacketTypeUnexpected, packetType)
|
||||
}
|
||||
}
|
||||
|
||||
func makeSockAddr(addr netip.AddrPort) syscall.Sockaddr {
|
||||
if addr.Addr().Is4() {
|
||||
return &syscall.SockaddrInet4{
|
||||
Port: int(addr.Port()),
|
||||
Addr: addr.Addr().As4(),
|
||||
}
|
||||
}
|
||||
return &syscall.SockaddrInet6{
|
||||
Port: int(addr.Port()),
|
||||
Addr: addr.Addr().As16(),
|
||||
}
|
||||
}
|
||||
|
||||
var errTCPPacketNotRST = errors.New("TCP packet is not an RST")
|
||||
|
||||
func handleRSTReply(ctx context.Context, fd fileDescriptor, ch <-chan []byte,
|
||||
src, dst netip.AddrPort, mtu uint32,
|
||||
) error {
|
||||
packet, _ := createSYNPacket(src, dst, mtu)
|
||||
const sendToFlags = 0
|
||||
err := sendTo(fd, packet, sendToFlags, makeSockAddr(dst))
|
||||
if err != nil {
|
||||
return fmt.Errorf("sending SYN MTU-test packet: %w", err)
|
||||
}
|
||||
|
||||
var reply []byte
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err() // timeout: the MTU test SYN packet was too big
|
||||
case reply = <-ch:
|
||||
}
|
||||
|
||||
packetType, _, _, err := parseTCPHeader(reply[:constants.BaseTCPHeaderLength])
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing reply TCP header: %w", err)
|
||||
} else if packetType != packetTypeRST {
|
||||
return fmt.Errorf("%w: %s", errTCPPacketNotRST, packetType)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func sendRST(fd fileDescriptor, src, dst netip.AddrPort,
|
||||
previousACK uint32,
|
||||
) error {
|
||||
seq := previousACK
|
||||
const ack = 0
|
||||
rstPacket := createRSTPacket(src, dst, seq, ack)
|
||||
const sendToFlags = 0
|
||||
return sendTo(fd, rstPacket, sendToFlags, makeSockAddr(dst))
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
package tcp
|
||||
|
||||
func stripIPv4Header(reply []byte) (result []byte, ok bool) {
|
||||
return reply, true
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
package tcp
|
||||
|
||||
import "syscall"
|
||||
|
||||
func setMTUDiscovery(fd int) error {
|
||||
return syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_MTU_DISCOVER, syscall.IP_PMTUDISC_PROBE)
|
||||
}
|
||||
@@ -0,0 +1,30 @@
|
||||
//go:build !darwin
|
||||
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||
)
|
||||
|
||||
func stripIPv4Header(reply []byte) (result []byte, ok bool) {
|
||||
if len(reply) < int(constants.IPv4HeaderLength) {
|
||||
return nil, false // not an IPv4 packet
|
||||
}
|
||||
|
||||
version := reply[0] >> 4 //nolint:mnd
|
||||
const ipv4Version = 4
|
||||
if version != ipv4Version {
|
||||
return nil, false
|
||||
}
|
||||
// For IPv4 we need to skip the IP header, which is at least
|
||||
// 20B and can be up to 60B.
|
||||
// The Internet Header Length is the lower 4 bits of the first byte and
|
||||
// represents the number of 32-bit words of the header length.
|
||||
const ihlMask byte = 0x0F
|
||||
const bytesInWord = 4
|
||||
headerLength := int((reply[0] & ihlMask)) * bytesInWord
|
||||
if len(reply) < headerLength {
|
||||
return nil, false // not enough data for full IPv4 header
|
||||
}
|
||||
return reply[headerLength:], true
|
||||
}
|
||||
@@ -0,0 +1,199 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/netlink"
|
||||
"github.com/qdm12/gluetun/internal/routing"
|
||||
"github.com/qdm12/log"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_runTest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
noopLogger := &noopLogger{}
|
||||
netlinker := netlink.New(noopLogger)
|
||||
loopbackMTU, err := findLoopbackMTU(netlinker)
|
||||
require.NoError(t, err, "finding loopback IPv4 MTU")
|
||||
defaultIPv4MTU, err := findDefaultIPv4RouteMTU(netlinker)
|
||||
require.NoError(t, err, "finding default IPv4 route MTU")
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
|
||||
const family = syscall.AF_INET
|
||||
fd, stop, err := startRawSocket(family)
|
||||
require.NoError(t, err)
|
||||
|
||||
const ipv4 = true
|
||||
tracker := newTracker(fd, ipv4)
|
||||
trackerCh := make(chan error)
|
||||
go func() {
|
||||
trackerCh <- tracker.listen(ctx)
|
||||
}()
|
||||
|
||||
t.Cleanup(func() {
|
||||
stop()
|
||||
cancel() // stop listening
|
||||
err = <-trackerCh
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
testCases := map[string]struct {
|
||||
timeout time.Duration
|
||||
dst func(t *testing.T) netip.AddrPort
|
||||
mtu uint32
|
||||
success bool
|
||||
}{
|
||||
"local_not_listening": {
|
||||
timeout: time.Hour,
|
||||
dst: func(t *testing.T) netip.AddrPort {
|
||||
t.Helper()
|
||||
port := reserveClosedPort(t)
|
||||
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), port)
|
||||
},
|
||||
mtu: loopbackMTU,
|
||||
success: true,
|
||||
},
|
||||
"remote_not_listening": {
|
||||
timeout: 50 * time.Millisecond,
|
||||
dst: func(_ *testing.T) netip.AddrPort {
|
||||
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 12345)
|
||||
},
|
||||
mtu: defaultIPv4MTU,
|
||||
},
|
||||
"1.1.1.1:443": {
|
||||
timeout: time.Second,
|
||||
dst: func(_ *testing.T) netip.AddrPort {
|
||||
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443)
|
||||
},
|
||||
mtu: defaultIPv4MTU,
|
||||
success: true,
|
||||
},
|
||||
"1.1.1.1:80": {
|
||||
timeout: time.Second,
|
||||
dst: func(_ *testing.T) netip.AddrPort {
|
||||
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 80)
|
||||
},
|
||||
mtu: defaultIPv4MTU,
|
||||
success: true,
|
||||
},
|
||||
"8.8.8.8:443": {
|
||||
timeout: time.Second,
|
||||
dst: func(_ *testing.T) netip.AddrPort {
|
||||
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 443)
|
||||
},
|
||||
mtu: defaultIPv4MTU,
|
||||
success: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancel := context.WithTimeout(t.Context(), testCase.timeout)
|
||||
defer cancel()
|
||||
dst := testCase.dst(t)
|
||||
err := runTest(ctx, fd, tracker, dst, testCase.mtu)
|
||||
if testCase.success {
|
||||
require.NoError(t, err)
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
var errRouteNotFound = errors.New("route not found")
|
||||
|
||||
func findLoopbackMTU(netlinker *netlink.NetLink) (mtu uint32, err error) {
|
||||
routes, err := netlinker.RouteList(netlink.FamilyV4)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("getting routes list: %w", err)
|
||||
}
|
||||
for _, route := range routes {
|
||||
if route.Dst.IsValid() && route.Dst.Addr().IsLoopback() {
|
||||
link, err := netlinker.LinkByIndex(route.LinkIndex)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("getting link by index: %w", err)
|
||||
}
|
||||
// Quirk: make sure it is maximum 65535, and not i.e. 65536
|
||||
// or the IP header 16 bits will fail to fit that packet length value.
|
||||
const maxMTU = 65535
|
||||
return min(link.MTU, maxMTU), nil
|
||||
}
|
||||
}
|
||||
return 0, fmt.Errorf("%w: no loopback route found", errRouteNotFound)
|
||||
}
|
||||
|
||||
func findDefaultIPv4RouteMTU(netlinker *netlink.NetLink) (mtu uint32, err error) {
|
||||
noopLogger := &noopLogger{}
|
||||
routing := routing.New(netlinker, noopLogger)
|
||||
defaultRoutes, err := routing.DefaultRoutes()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("getting default routes: %w", err)
|
||||
}
|
||||
for _, route := range defaultRoutes {
|
||||
if route.Family != netlink.FamilyV4 {
|
||||
continue
|
||||
}
|
||||
link, err := netlinker.LinkByName(defaultRoutes[0].NetInterface)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("getting link by name: %w", err)
|
||||
}
|
||||
return link.MTU, nil
|
||||
}
|
||||
return 0, fmt.Errorf("%w: no default route found", errRouteNotFound)
|
||||
}
|
||||
|
||||
func reserveClosedPort(t *testing.T) (port uint16) {
|
||||
t.Helper()
|
||||
|
||||
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM, syscall.IPPROTO_TCP)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
err := syscall.Close(fd)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
addr := &syscall.SockaddrInet4{
|
||||
Port: 0,
|
||||
Addr: [4]byte{127, 0, 0, 1},
|
||||
}
|
||||
|
||||
err = syscall.Bind(fd, addr)
|
||||
if err != nil {
|
||||
_ = syscall.Close(fd)
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
sockAddr, err := syscall.Getsockname(fd)
|
||||
if err != nil {
|
||||
_ = syscall.Close(fd)
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
sockAddr4, ok := sockAddr.(*syscall.SockaddrInet4)
|
||||
if !ok {
|
||||
_ = syscall.Close(fd)
|
||||
t.Fatal("not an IPv4 address")
|
||||
}
|
||||
|
||||
return uint16(sockAddr4.Port) //nolint:gosec
|
||||
}
|
||||
|
||||
type noopLogger struct{}
|
||||
|
||||
func (l *noopLogger) Patch(_ ...log.Option) {}
|
||||
func (l *noopLogger) Debug(_ string) {}
|
||||
func (l *noopLogger) Debugf(_ string, _ ...any) {}
|
||||
func (l *noopLogger) Info(_ string) {}
|
||||
func (l *noopLogger) Warn(_ string) {}
|
||||
func (l *noopLogger) Error(_ string) {}
|
||||
@@ -0,0 +1,28 @@
|
||||
//go:build linux || darwin
|
||||
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
// fileDescriptor is a platform-independent type for socket file descriptors.
|
||||
type fileDescriptor int
|
||||
|
||||
func sendTo(fd fileDescriptor, p []byte, flags int, to syscall.Sockaddr) (err error) {
|
||||
return syscall.Sendto(int(fd), p, flags, to)
|
||||
}
|
||||
|
||||
func setSocketTimeout(fd fileDescriptor, timeout time.Duration) (err error) {
|
||||
timeval := syscall.NsecToTimeval(timeout.Nanoseconds())
|
||||
return syscall.SetsockoptTimeval(int(fd), syscall.SOL_SOCKET, syscall.SO_RCVTIMEO, &timeval)
|
||||
}
|
||||
|
||||
func recvFrom(fd fileDescriptor, p []byte, flags int) (n int, from syscall.Sockaddr, err error) {
|
||||
return syscall.Recvfrom(int(fd), p, flags)
|
||||
}
|
||||
|
||||
func setNonBlock(fd int) error {
|
||||
return syscall.SetNonblock(fd, true)
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
//go:build !linux && !windows
|
||||
|
||||
package tcp
|
||||
|
||||
func setMTUDiscovery(fd int) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
type fileDescriptor syscall.Handle
|
||||
|
||||
func sendTo(fd fileDescriptor, p []byte, flags int, to syscall.Sockaddr) (err error) {
|
||||
return syscall.Sendto(syscall.Handle(fd), p, flags, to)
|
||||
}
|
||||
|
||||
func setSocketTimeout(fd fileDescriptor, timeout time.Duration) (err error) {
|
||||
timeval := int(timeout.Milliseconds())
|
||||
return syscall.SetsockoptInt(syscall.Handle(fd), syscall.SOL_SOCKET, windows.SO_RCVTIMEO, timeval)
|
||||
}
|
||||
|
||||
func recvFrom(fd fileDescriptor, p []byte, flags int) (n int, from syscall.Sockaddr, err error) {
|
||||
return syscall.Recvfrom(syscall.Handle(fd), p, flags)
|
||||
}
|
||||
|
||||
func setMTUDiscovery(fd syscall.Handle) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func setNonBlock(fd syscall.Handle) error {
|
||||
// Windows: Use ioctlsocket with FIONBIO
|
||||
var arg uint32 = 1 // 1 to enable non-blocking mode
|
||||
var bytesReturned uint32
|
||||
const FIONBIO = 0x8004667e
|
||||
return syscall.WSAIoctl(fd, FIONBIO, (*byte)(unsafe.Pointer(&arg)),
|
||||
uint32(unsafe.Sizeof(arg)), nil, 0, &bytesReturned, nil, 0)
|
||||
}
|
||||
@@ -0,0 +1,124 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||
)
|
||||
|
||||
// For SYN, ack is 0.
|
||||
// For SYN-ACK, ack is the sequence number + 1 of the SYN.
|
||||
func makeTCPHeader(srcPort, dstPort uint16, seq, ack uint32, flags byte) []byte {
|
||||
header := make([]byte, constants.BaseTCPHeaderLength)
|
||||
binary.BigEndian.PutUint16(header[0:], srcPort)
|
||||
binary.BigEndian.PutUint16(header[2:], dstPort)
|
||||
binary.BigEndian.PutUint32(header[4:], seq)
|
||||
binary.BigEndian.PutUint32(header[8:], ack)
|
||||
//nolint:mnd
|
||||
header[12] = byte(constants.BaseTCPHeaderLength) << 2 // data offset
|
||||
header[13] = flags
|
||||
// windowSize can be left to 5840 even for IPv6, it doesn't matter.
|
||||
const windowSize = 5840
|
||||
binary.BigEndian.PutUint16(header[14:], windowSize)
|
||||
// header[16:17] is the checksum, set later
|
||||
// header[18:19] is urgent pointer, not needed for our use case
|
||||
return header
|
||||
}
|
||||
|
||||
//nolint:mnd
|
||||
func tcpChecksum(ipHeader, tcpHeader, payload []byte) uint16 {
|
||||
var pseudoHeader []byte
|
||||
isIPv6 := len(ipHeader) >= 40 && (ipHeader[0]>>4) == 6
|
||||
if isIPv6 {
|
||||
pseudoHeader = make([]byte, 40)
|
||||
copy(pseudoHeader[0:16], ipHeader[8:24]) // Source Address
|
||||
copy(pseudoHeader[16:32], ipHeader[24:40]) // Destination Address
|
||||
totalLength := uint32(len(tcpHeader) + len(payload)) //nolint:gosec
|
||||
binary.BigEndian.PutUint32(pseudoHeader[32:], totalLength)
|
||||
pseudoHeader[39] = 6 // Next Header (TCP)
|
||||
} else {
|
||||
pseudoHeader = make([]byte, 12)
|
||||
copy(pseudoHeader[0:4], ipHeader[12:16])
|
||||
copy(pseudoHeader[4:8], ipHeader[16:20])
|
||||
pseudoHeader[9] = 6
|
||||
totalLength := uint16(len(tcpHeader) + len(payload)) //nolint:gosec
|
||||
binary.BigEndian.PutUint16(pseudoHeader[10:], totalLength)
|
||||
}
|
||||
|
||||
sum := uint32(0)
|
||||
for _, slice := range [][]byte{pseudoHeader, tcpHeader, payload} {
|
||||
for i := 0; i < len(slice)-1; i += 2 {
|
||||
sum += uint32(binary.BigEndian.Uint16(slice[i : i+2]))
|
||||
}
|
||||
if len(slice)%2 != 0 {
|
||||
sum += uint32(slice[len(slice)-1]) << 8
|
||||
}
|
||||
}
|
||||
for (sum >> 16) > 0 {
|
||||
sum = (sum & 0xFFFF) + (sum >> 16)
|
||||
}
|
||||
return ^uint16(sum) //nolint:gosec
|
||||
}
|
||||
|
||||
const (
|
||||
tcpFlagsOffset = 13
|
||||
rstFlag byte = 0x04
|
||||
synFlag byte = 0x02
|
||||
ackFlag byte = 0x10
|
||||
pshFlag byte = 0x08
|
||||
)
|
||||
|
||||
type packetType uint8
|
||||
|
||||
const (
|
||||
packetTypeSYN packetType = iota + 1
|
||||
packetTypeSYNACK
|
||||
packetTypeACK
|
||||
packetTypeRST
|
||||
)
|
||||
|
||||
func (p packetType) String() string {
|
||||
switch p {
|
||||
case packetTypeSYN:
|
||||
return "SYN"
|
||||
case packetTypeSYNACK:
|
||||
return "SYN-ACK"
|
||||
case packetTypeACK:
|
||||
return "ACK"
|
||||
case packetTypeRST:
|
||||
return "RST"
|
||||
default:
|
||||
panic("unknown packet type")
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
errTCPHeaderTooShort = errors.New("TCP header is too short")
|
||||
errTCPPacketTypeUnknown = errors.New("TCP packet type is unknown")
|
||||
)
|
||||
|
||||
// parseTCPHeader parses some elements from the TCP header.
|
||||
func parseTCPHeader(header []byte) (packetType packetType, seq, ack uint32, err error) {
|
||||
if len(header) < int(constants.BaseTCPHeaderLength) {
|
||||
return 0, 0, 0, fmt.Errorf("%w: %d bytes", errTCPHeaderTooShort, len(header))
|
||||
}
|
||||
flags := header[tcpFlagsOffset]
|
||||
switch {
|
||||
case (flags&synFlag) != 0 && (flags&ackFlag) == 0:
|
||||
packetType = packetTypeSYN
|
||||
case (flags&synFlag) != 0 && (flags&ackFlag) != 0:
|
||||
packetType = packetTypeSYNACK
|
||||
case (flags & rstFlag) != 0:
|
||||
packetType = packetTypeRST
|
||||
case (flags & ackFlag) != 0:
|
||||
packetType = packetTypeACK
|
||||
default:
|
||||
return 0, 0, 0, fmt.Errorf("%w: flags are 0x%02x", errTCPPacketTypeUnknown, flags)
|
||||
}
|
||||
|
||||
seq = binary.BigEndian.Uint32(header[4:8])
|
||||
ack = binary.BigEndian.Uint32(header[8:12])
|
||||
return packetType, seq, ack, nil
|
||||
}
|
||||
@@ -0,0 +1,134 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||
)
|
||||
|
||||
type tracker struct {
|
||||
fd fileDescriptor
|
||||
ipv4 bool
|
||||
mutex sync.RWMutex
|
||||
portsToDispatch map[uint32]dispatch
|
||||
}
|
||||
|
||||
type dispatch struct {
|
||||
replyCh chan<- []byte
|
||||
abort <-chan struct{}
|
||||
}
|
||||
|
||||
func newTracker(fd fileDescriptor, ipv4 bool) *tracker {
|
||||
return &tracker{
|
||||
fd: fd,
|
||||
ipv4: ipv4,
|
||||
portsToDispatch: make(map[uint32]dispatch),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *tracker) constructKey(localPort, remotePort uint16) uint32 {
|
||||
buf := make([]byte, 4) //nolint:mnd
|
||||
binary.BigEndian.PutUint16(buf[0:2], localPort)
|
||||
binary.BigEndian.PutUint16(buf[2:4], remotePort)
|
||||
return binary.BigEndian.Uint32(buf)
|
||||
}
|
||||
|
||||
func (t *tracker) register(localPort, remotePort uint16,
|
||||
ch chan<- []byte, abort <-chan struct{},
|
||||
) {
|
||||
key := t.constructKey(localPort, remotePort)
|
||||
t.mutex.Lock()
|
||||
defer t.mutex.Unlock()
|
||||
t.portsToDispatch[key] = dispatch{
|
||||
replyCh: ch,
|
||||
abort: abort,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *tracker) unregister(localPort, remotePort uint16) {
|
||||
key := t.constructKey(localPort, remotePort)
|
||||
t.mutex.Lock()
|
||||
defer t.mutex.Unlock()
|
||||
delete(t.portsToDispatch, key)
|
||||
}
|
||||
|
||||
// listen listens for incoming TCP packets and dispatches them to the
|
||||
// correct channel based on the source and destination port.
|
||||
// If the context has a deadline associated, this one is used on the socket.
|
||||
// Note it returns a nil error on context cancellation.
|
||||
func (t *tracker) listen(ctx context.Context) error {
|
||||
deadline, hasDeadline := ctx.Deadline()
|
||||
for ctx.Err() == nil {
|
||||
if hasDeadline {
|
||||
remaining := time.Until(deadline)
|
||||
if remaining <= 0 {
|
||||
return nil
|
||||
}
|
||||
err := setSocketTimeout(t.fd, remaining)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting socket receive timeout: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
reply := make([]byte, constants.MaxEthernetFrameSize)
|
||||
n, _, err := recvFrom(t.fd, reply, 0)
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, syscall.EAGAIN),
|
||||
errors.Is(err, syscall.EWOULDBLOCK):
|
||||
pollSleep(ctx)
|
||||
continue
|
||||
case ctx.Err() != nil:
|
||||
// context canceled, stop listening so exit cleanly with no error
|
||||
return nil //nolint:nilerr
|
||||
default:
|
||||
return fmt.Errorf("receiving on socket: %w", err)
|
||||
}
|
||||
}
|
||||
reply = reply[:n]
|
||||
|
||||
if t.ipv4 {
|
||||
var ok bool
|
||||
reply, ok = stripIPv4Header(reply)
|
||||
if !ok {
|
||||
continue // not an IPv4 packet
|
||||
}
|
||||
}
|
||||
|
||||
const minTCPHeaderLength = 20
|
||||
if len(reply) < minTCPHeaderLength {
|
||||
continue
|
||||
}
|
||||
|
||||
srcPort := binary.BigEndian.Uint16(reply[0:2])
|
||||
dstPort := binary.BigEndian.Uint16(reply[2:4])
|
||||
key := t.constructKey(dstPort, srcPort)
|
||||
t.mutex.RLock()
|
||||
dispatch, exists := t.portsToDispatch[key]
|
||||
t.mutex.RUnlock()
|
||||
if !exists {
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case dispatch.replyCh <- reply:
|
||||
case <-dispatch.abort:
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func pollSleep(ctx context.Context) {
|
||||
const sleepBetweenPolls = 10 * time.Millisecond
|
||||
timer := time.NewTimer(sleepBetweenPolls)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
timer.Stop()
|
||||
case <-timer.C:
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,36 @@
|
||||
package test
|
||||
|
||||
import "math"
|
||||
|
||||
// MakeMTUsToTest determines a slice of MTU values to test
|
||||
// between minMTU and maxMTU inclusive. It creates an MTU
|
||||
// slice of length up to 11 MTUs such that:
|
||||
// - the first element is the minMTU
|
||||
// - the last element is the maxMTU
|
||||
// - elements in-between are separated as close to each other
|
||||
// The number 11 is chosen to find the final MTU in 3 searches,
|
||||
// with a total search space of 1728 MTUs which is enough;
|
||||
// to find it in 2 searches requires 37 parallel queries which
|
||||
// could be blocked by firewalls.
|
||||
func MakeMTUsToTest(minMTU, maxMTU uint32) (mtus []uint32) {
|
||||
const mtusLength = 11 // find the final MTU in 3 searches
|
||||
diff := maxMTU - minMTU
|
||||
switch {
|
||||
case minMTU > maxMTU:
|
||||
panic("minMTU > maxMTU")
|
||||
case diff <= mtusLength:
|
||||
mtus = make([]uint32, 0, diff)
|
||||
for mtu := minMTU; mtu <= maxMTU; mtu++ {
|
||||
mtus = append(mtus, mtu)
|
||||
}
|
||||
default:
|
||||
step := float64(diff) / float64(mtusLength-1)
|
||||
mtus = make([]uint32, 0, mtusLength)
|
||||
for mtu := float64(minMTU); len(mtus) < mtusLength-1; mtu += step {
|
||||
mtus = append(mtus, uint32(math.Round(mtu)))
|
||||
}
|
||||
mtus = append(mtus, maxMTU) // last element is the maxMTU
|
||||
}
|
||||
|
||||
return mtus
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package pmtud
|
||||
package test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_makeMTUsToTest(t *testing.T) {
|
||||
func Test_MakeMTUsToTest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
@@ -48,7 +48,7 @@ func Test_makeMTUsToTest(t *testing.T) {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mtus := makeMTUsToTest(testCase.minMTU, testCase.maxMTU)
|
||||
mtus := MakeMTUsToTest(testCase.minMTU, testCase.maxMTU)
|
||||
assert.Equal(t, testCase.mtus, mtus)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,40 @@
|
||||
package pmtud
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/constants/vpn"
|
||||
pconstants "github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||
)
|
||||
|
||||
// MaxTheoreticalVPNMTU returns the theoretical maximum MTU for a VPN tunnel
|
||||
// given the VPN type, network protocol, and VPN gateway IP address.
|
||||
// This is notably useful to skip testing MTU values higher than this value.
|
||||
// The function panics if the network or VPN type is unknown.
|
||||
func MaxTheoreticalVPNMTU(vpnType, network string, vpnGateway netip.Addr) uint32 {
|
||||
const physicalLinkMTU = pconstants.MaxEthernetFrameSize
|
||||
vpnLinkMTU := physicalLinkMTU
|
||||
if vpnGateway.Is4() {
|
||||
vpnLinkMTU -= pconstants.IPv4HeaderLength
|
||||
} else {
|
||||
vpnLinkMTU -= pconstants.IPv6HeaderLength
|
||||
}
|
||||
switch network {
|
||||
case constants.TCP:
|
||||
vpnLinkMTU -= pconstants.BaseTCPHeaderLength
|
||||
case constants.UDP:
|
||||
vpnLinkMTU -= pconstants.UDPHeaderLength
|
||||
default:
|
||||
panic("unknown network protocol: " + network)
|
||||
}
|
||||
switch vpnType {
|
||||
case vpn.Wireguard:
|
||||
vpnLinkMTU -= pconstants.WireguardHeaderLength
|
||||
case vpn.OpenVPN:
|
||||
vpnLinkMTU -= pconstants.OpenVPNHeaderMaxLength
|
||||
default:
|
||||
panic("unknown VPN type: " + vpnType)
|
||||
}
|
||||
return vpnLinkMTU
|
||||
}
|
||||
@@ -16,7 +16,17 @@ func BuildWireguardSettings(connection models.Connection,
|
||||
settings.PreSharedKey = *userSettings.PreSharedKey
|
||||
settings.InterfaceName = userSettings.Interface
|
||||
settings.Implementation = userSettings.Implementation
|
||||
settings.MTU = userSettings.MTU
|
||||
if *userSettings.MTU > 0 {
|
||||
settings.MTU = *userSettings.MTU
|
||||
} else {
|
||||
// The default is 1320 which is NOT the wireguard-go default
|
||||
// of 1420 because this impacts bandwidth a lot on some
|
||||
// VPN providers, see https://github.com/qdm12/gluetun/issues/1650.
|
||||
// It has been lowered to 1320 following quite a bit of
|
||||
// investigation in the issue: https://github.com/qdm12/gluetun/issues/2533.
|
||||
const defaultMTU = 1320
|
||||
settings.MTU = defaultMTU
|
||||
}
|
||||
settings.IPv6 = &ipv6Supported
|
||||
|
||||
const rulePriority = 101 // 100 is to receive external connections
|
||||
|
||||
@@ -22,7 +22,7 @@ func Test_BuildWireguardSettings(t *testing.T) {
|
||||
ipv6Supported bool
|
||||
settings wireguard.Settings
|
||||
}{
|
||||
"some settings": {
|
||||
"some_settings": {
|
||||
connection: models.Connection{
|
||||
IP: netip.AddrFrom4([4]byte{1, 2, 3, 4}),
|
||||
Port: 51821,
|
||||
@@ -41,6 +41,7 @@ func Test_BuildWireguardSettings(t *testing.T) {
|
||||
},
|
||||
PersistentKeepaliveInterval: ptrTo(time.Hour),
|
||||
Interface: "wg1",
|
||||
MTU: ptrTo(uint32(1000)),
|
||||
},
|
||||
ipv6Supported: false,
|
||||
settings: wireguard.Settings{
|
||||
@@ -58,6 +59,7 @@ func Test_BuildWireguardSettings(t *testing.T) {
|
||||
PersistentKeepaliveInterval: time.Hour,
|
||||
RulePriority: 101,
|
||||
IPv6: boolPtr(false),
|
||||
MTU: 1000,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -47,3 +47,26 @@ func (r *Routing) VPNLocalGatewayIP(vpnIntf string) (ip netip.Addr, err error) {
|
||||
}
|
||||
return ip, fmt.Errorf("%w: in %d routes", ErrVPNLocalGatewayIPNotFound, len(routes))
|
||||
}
|
||||
|
||||
var ErrVPNRouteNotFound = errors.New("VPN route not found")
|
||||
|
||||
func (r *Routing) VPNRoute(vpnIntf string) (route netlink.Route, err error) {
|
||||
vpnLink, err := r.netLinker.LinkByName(vpnIntf)
|
||||
if err != nil {
|
||||
return route, fmt.Errorf("finding link %s: %w", vpnIntf, err)
|
||||
}
|
||||
vpnLinkIndex := vpnLink.Index
|
||||
|
||||
routes, err := r.netLinker.RouteList(netlink.FamilyAll)
|
||||
if err != nil {
|
||||
return route, fmt.Errorf("listing routes: %w", err)
|
||||
}
|
||||
for _, route := range routes {
|
||||
if route.LinkIndex == vpnLinkIndex &&
|
||||
!route.Dst.IsValid() {
|
||||
return route, nil
|
||||
}
|
||||
}
|
||||
return route, fmt.Errorf("%w: for interface %s in %d routes",
|
||||
ErrVPNRouteNotFound, vpnIntf, len(routes))
|
||||
}
|
||||
|
||||
@@ -21,6 +21,7 @@ type Firewall interface {
|
||||
|
||||
type Routing interface {
|
||||
VPNLocalGatewayIP(vpnInterface string) (gateway netip.Addr, err error)
|
||||
VPNRoute(vpnIntf string) (route netlink.Route, err error)
|
||||
}
|
||||
|
||||
type PortForward interface {
|
||||
@@ -67,6 +68,7 @@ type NetLinker interface {
|
||||
type Router interface {
|
||||
RouteList(family uint8) (routes []netlink.Route, err error)
|
||||
RouteAdd(route netlink.Route) error
|
||||
RouteReplace(route netlink.Route) error
|
||||
}
|
||||
|
||||
type Ruler interface {
|
||||
|
||||
@@ -47,7 +47,13 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
||||
continue
|
||||
}
|
||||
tunnelUpData := tunnelUpData{
|
||||
pmtud: tunnelUpPMTUDData{
|
||||
enabled: settings.Type != vpn.Wireguard || *settings.Wireguard.MTU == 0,
|
||||
vpnType: settings.Type,
|
||||
network: connection.Protocol,
|
||||
icmpAddrs: settings.PMTUD.ICMPAddresses,
|
||||
tcpAddrs: settings.PMTUD.TCPAddresses,
|
||||
},
|
||||
serverIP: connection.IP,
|
||||
serverName: connection.ServerName,
|
||||
canPortForward: connection.PortForward,
|
||||
|
||||
+60
-28
@@ -2,7 +2,6 @@ package vpn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"time"
|
||||
@@ -10,6 +9,7 @@ import (
|
||||
"github.com/qdm12/dns/v2/pkg/check"
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/pmtud"
|
||||
pconstants "github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||
"github.com/qdm12/gluetun/internal/version"
|
||||
"github.com/qdm12/log"
|
||||
)
|
||||
@@ -17,9 +17,7 @@ import (
|
||||
type tunnelUpData struct {
|
||||
// Healthcheck
|
||||
serverIP netip.Addr
|
||||
// vpnType is used for path MTU discovery to find the protocol overhead.
|
||||
// It can be "wireguard" or "openvpn".
|
||||
vpnType string
|
||||
pmtud tunnelUpPMTUDData
|
||||
// Port forwarding
|
||||
vpnIntf string
|
||||
serverName string // used for PIA
|
||||
@@ -29,6 +27,23 @@ type tunnelUpData struct {
|
||||
portForwarder PortForwarder
|
||||
}
|
||||
|
||||
type tunnelUpPMTUDData struct {
|
||||
// enabled is notably false if the user specifies a custom MTU.
|
||||
enabled bool
|
||||
// vpnType is used to find the maximum VPN header overhead.
|
||||
// It can be [vpn.Wireguard] or [vpn.OpenVPN].
|
||||
vpnType string
|
||||
// network is used to find the network level header overhead.
|
||||
// It can be [constants.UDP] or [constants.TCP].
|
||||
network string
|
||||
// icmpAddrs is the list of addresses to use for ICMP path MTU discovery.
|
||||
// Each address should handle ICMP packets for PMTUD to work.
|
||||
icmpAddrs []netip.Addr
|
||||
// tcpAddrs is the list of addresses to use for TCP path MTU discovery.
|
||||
// Each address should have a listening TCP server on the port specified.
|
||||
tcpAddrs []netip.AddrPort
|
||||
}
|
||||
|
||||
func (l *Loop) onTunnelUp(ctx, loopCtx context.Context, data tunnelUpData) {
|
||||
l.client.CloseIdleConnections()
|
||||
|
||||
@@ -39,12 +54,15 @@ func (l *Loop) onTunnelUp(ctx, loopCtx context.Context, data tunnelUpData) {
|
||||
}
|
||||
}
|
||||
|
||||
if data.pmtud.enabled {
|
||||
mtuLogger := l.logger.New(log.SetComponent("MTU discovery"))
|
||||
err := updateToMaxMTU(ctx, data.vpnIntf, data.vpnType,
|
||||
err := updateToMaxMTU(ctx, data.vpnIntf, data.pmtud.vpnType,
|
||||
data.pmtud.network, data.pmtud.icmpAddrs, data.pmtud.tcpAddrs,
|
||||
l.netLinker, l.routing, mtuLogger)
|
||||
if err != nil {
|
||||
mtuLogger.Error(err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
icmpTargetIPs := l.healthSettings.ICMPTargetIPs
|
||||
if len(icmpTargetIPs) == 1 && icmpTargetIPs[0].IsUnspecified() {
|
||||
@@ -136,12 +154,11 @@ func (l *Loop) restartVPN(ctx context.Context, healthErr error) {
|
||||
_, _ = l.ApplyStatus(ctx, constants.Running)
|
||||
}
|
||||
|
||||
var errVPNTypeUnknown = errors.New("unknown VPN type")
|
||||
|
||||
func updateToMaxMTU(ctx context.Context, vpnInterface string,
|
||||
vpnType string, netlinker NetLinker, routing Routing, logger *log.Logger,
|
||||
vpnType, network string, icmpAddrs []netip.Addr, tcpAddrs []netip.AddrPort,
|
||||
netlinker NetLinker, routing Routing, logger *log.Logger,
|
||||
) error {
|
||||
logger.Info("finding maximum MTU, this can take up to 4 seconds")
|
||||
logger.Info("finding maximum MTU, this can take up to 6 seconds")
|
||||
|
||||
vpnGatewayIP, err := routing.VPNLocalGatewayIP(vpnInterface)
|
||||
if err != nil {
|
||||
@@ -155,18 +172,7 @@ func updateToMaxMTU(ctx context.Context, vpnInterface string,
|
||||
|
||||
originalMTU := link.MTU
|
||||
|
||||
// Note: no point testing for an MTU of 1500, it will never work due to the VPN
|
||||
// protocol overhead, so start lower than 1500 according to the protocol used.
|
||||
const physicalLinkMTU uint32 = 1500
|
||||
vpnLinkMTU := physicalLinkMTU
|
||||
switch vpnType {
|
||||
case "wireguard":
|
||||
vpnLinkMTU -= 60 // Wireguard overhead
|
||||
case "openvpn":
|
||||
vpnLinkMTU -= 41 // OpenVPN overhead
|
||||
default:
|
||||
return fmt.Errorf("%w: %q", errVPNTypeUnknown, vpnType)
|
||||
}
|
||||
vpnLinkMTU := pmtud.MaxTheoreticalVPNMTU(vpnType, network, vpnGatewayIP)
|
||||
|
||||
// Setting the VPN link MTU to 1500 might interrupt the connection until
|
||||
// the new MTU is set again, but this is necessary to find the highest valid MTU.
|
||||
@@ -178,16 +184,14 @@ func updateToMaxMTU(ctx context.Context, vpnInterface string,
|
||||
}
|
||||
|
||||
const pingTimeout = time.Second
|
||||
vpnLinkMTU, err = pmtud.PathMTUDiscover(ctx, vpnGatewayIP, vpnLinkMTU, pingTimeout, logger)
|
||||
switch {
|
||||
case err == nil:
|
||||
logger.Infof("setting VPN interface %s MTU to maximum valid MTU %d", vpnInterface, vpnLinkMTU)
|
||||
case errors.Is(err, pmtud.ErrMTUNotFound) || errors.Is(err, pmtud.ErrICMPNotPermitted):
|
||||
vpnLinkMTU, err = pmtud.PathMTUDiscover(ctx, icmpAddrs, tcpAddrs,
|
||||
vpnLinkMTU, pingTimeout, logger)
|
||||
if err != nil {
|
||||
vpnLinkMTU = originalMTU
|
||||
logger.Infof("reverting VPN interface %s MTU to %d (due to: %s)",
|
||||
vpnInterface, originalMTU, err)
|
||||
default:
|
||||
return fmt.Errorf("path MTU discovering: %w", err)
|
||||
} else {
|
||||
logger.Infof("setting VPN interface %s MTU to maximum valid MTU %d", vpnInterface, vpnLinkMTU)
|
||||
}
|
||||
|
||||
err = netlinker.LinkSetMTU(link.Index, vpnLinkMTU)
|
||||
@@ -195,5 +199,33 @@ func updateToMaxMTU(ctx context.Context, vpnInterface string,
|
||||
return fmt.Errorf("setting VPN interface %s MTU to %d: %w", vpnInterface, vpnLinkMTU, err)
|
||||
}
|
||||
|
||||
err = setTCPMSSOnVPNRoute(vpnInterface, vpnLinkMTU, routing, netlinker)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting safe TCP MSS for MTU %d: %w", vpnLinkMTU, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func setTCPMSSOnVPNRoute(vpnIntf string, mtu uint32,
|
||||
routing Routing, netlinker NetLinker,
|
||||
) error {
|
||||
route, err := routing.VPNRoute(vpnIntf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting VPN route: %w", err)
|
||||
}
|
||||
|
||||
ipHeaderLength := pconstants.IPv4HeaderLength
|
||||
if route.Dst.Addr().Is6() {
|
||||
ipHeaderLength = pconstants.IPv6HeaderLength
|
||||
}
|
||||
const mysteriousOverhead = 20 // most likely TCP options, such as the 12B of timestamps
|
||||
overhead := ipHeaderLength + pconstants.BaseTCPHeaderLength + mysteriousOverhead
|
||||
mss := mtu - overhead
|
||||
route.AdvMSS = mss
|
||||
err = netlinker.RouteReplace(route)
|
||||
if err != nil {
|
||||
return fmt.Errorf("replacing VPN route with MSS changed to %d: %w", mss, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user