diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index 35e96b33..8ac79bf4 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -1,2 +1,2 @@ FROM ghcr.io/qdm12/godevcontainer:v0.21-alpine -RUN apk add wireguard-tools htop openssl +RUN apk add wireguard-tools htop openssl tcpdump diff --git a/.golangci.yml b/.golangci.yml index 9b6af6a4..0fd6c949 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -22,6 +22,7 @@ linters: - "^disabled$" # Firewall and routing strings - "^(ACCEPT|DROP)$" + - "^--append$" - "^--delete$" - "^all$" - "^(tcp|udp)$" diff --git a/Dockerfile b/Dockerfile index 0e1becef..be4f3f82 100644 --- a/Dockerfile +++ b/Dockerfile @@ -110,8 +110,11 @@ ENV VPN_SERVICE_PROVIDER=pia \ WIREGUARD_PERSISTENT_KEEPALIVE_INTERVAL=0 \ WIREGUARD_ADDRESSES= \ WIREGUARD_ADDRESSES_SECRETFILE=/run/secrets/wireguard_addresses \ - WIREGUARD_MTU=1320 \ + WIREGUARD_MTU= \ WIREGUARD_IMPLEMENTATION=auto \ + # PMTUD + PMTUD_ICMP_ADDRESSES=1.1.1.1,8.8.8.8 \ + PMTUD_TCP_ADDRESSES=1.1.1.1:443,8.8.8.8:443 \ # VPN server filtering SERVER_REGIONS= \ SERVER_COUNTRIES= \ diff --git a/internal/configuration/settings/pmtud.go b/internal/configuration/settings/pmtud.go new file mode 100644 index 00000000..a5692d55 --- /dev/null +++ b/internal/configuration/settings/pmtud.go @@ -0,0 +1,108 @@ +package settings + +import ( + "errors" + "fmt" + "net/netip" + "strings" + + "github.com/qdm12/gosettings" + "github.com/qdm12/gosettings/reader" + "github.com/qdm12/gotree" +) + +// PMTUD contains settings to configure Path MTU Discovery. +type PMTUD struct { + // ICMPAddresses is the redundancy list of addresses to use + // for ICMP path MTU discovery. Each address MUST handle ICMP + // packets for PMTUD to work. + // It cannot be nil in the internal state. + ICMPAddresses []netip.Addr `json:"icmp_addresses"` + // TCPAddresses is the redundancy list of addresses to use + // for TCP path MTU discovery. 
Each address MUST have a listening + // TCP server on the port specified. + // It cannot be nil in the internal state. + TCPAddresses []netip.AddrPort `json:"tcp_addresses"` +} + +var ( + ErrPMTUDICMPAddressNotValid = errors.New("PMTUD ICMP address is not valid") + ErrPMTUDTCPAddressNotValid = errors.New("PMTUD TCP address is not valid") +) + +// validate validates PMTUD settings. +func (p PMTUD) validate() (err error) { + for i, addr := range p.ICMPAddresses { + if !addr.IsValid() { + return fmt.Errorf("%w: at index %d", ErrPMTUDICMPAddressNotValid, i) + } + } + for i, addr := range p.TCPAddresses { + if !addr.IsValid() { + return fmt.Errorf("%w: at index %d", ErrPMTUDTCPAddressNotValid, i) + } + } + return nil +} + +func (p *PMTUD) copy() (copied PMTUD) { + return PMTUD{ + ICMPAddresses: gosettings.CopySlice(p.ICMPAddresses), + TCPAddresses: gosettings.CopySlice(p.TCPAddresses), + } +} + +func (p *PMTUD) overrideWith(other PMTUD) { + p.ICMPAddresses = gosettings.OverrideWithSlice(p.ICMPAddresses, other.ICMPAddresses) + p.TCPAddresses = gosettings.OverrideWithSlice(p.TCPAddresses, other.TCPAddresses) +} + +func (p *PMTUD) setDefaults() { + defaultICMPAddresses := []netip.Addr{ + netip.AddrFrom4([4]byte{1, 1, 1, 1}), + netip.AddrFrom4([4]byte{8, 8, 8, 8}), + } + p.ICMPAddresses = gosettings.DefaultSlice(p.ICMPAddresses, defaultICMPAddresses) + + const tlsPort = 443 + defaultTCPAddresses := []netip.AddrPort{ + netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), tlsPort), + netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), tlsPort), + } + p.TCPAddresses = gosettings.DefaultSlice(p.TCPAddresses, defaultTCPAddresses) +} + +func (p PMTUD) String() string { + return p.toLinesNode().String() +} + +func (p PMTUD) toLinesNode() (node *gotree.Node) { + node = gotree.New("Path MTU discovery:") + + addrs := make([]string, len(p.ICMPAddresses)) + for i, addr := range p.ICMPAddresses { + addrs[i] = addr.String() + } + node.Appendf("ICMP addresses: %s", 
strings.Join(addrs, ", ")) + + addrs = make([]string, len(p.TCPAddresses)) + for i, addr := range p.TCPAddresses { + addrs[i] = addr.String() + } + node.Appendf("TCP addresses: %s", strings.Join(addrs, ", ")) + return node +} + +func (p *PMTUD) read(r *reader.Reader) (err error) { + p.ICMPAddresses, err = r.CSVNetipAddresses("PMTUD_ICMP_ADDRESSES") + if err != nil { + return err + } + + p.TCPAddresses, err = r.CSVNetipAddrPorts("PMTUD_TCP_ADDRESSES") + if err != nil { + return err + } + + return nil +} diff --git a/internal/configuration/settings/settings_test.go b/internal/configuration/settings/settings_test.go index 4f051877..fa865446 100644 --- a/internal/configuration/settings/settings_test.go +++ b/internal/configuration/settings/settings_test.go @@ -29,14 +29,17 @@ func Test_Settings_String(t *testing.T) { | | └── OpenVPN server selection settings: | | ├── Protocol: UDP | | └── Private Internet Access encryption preset: strong -| └── OpenVPN settings: -| ├── OpenVPN version: 2.6 -| ├── User: [not set] -| ├── Password: [not set] -| ├── Private Internet Access encryption preset: strong -| ├── Network interface: tun0 -| ├── Run OpenVPN as: root -| └── Verbosity level: 1 +| ├── OpenVPN settings: +| | ├── OpenVPN version: 2.6 +| | ├── User: [not set] +| | ├── Password: [not set] +| | ├── Private Internet Access encryption preset: strong +| | ├── Network interface: tun0 +| | ├── Run OpenVPN as: root +| | └── Verbosity level: 1 +| └── Path MTU discovery: +| ├── ICMP addresses: 1.1.1.1, 8.8.8.8 +| └── TCP addresses: 1.1.1.1:443, 8.8.8.8:443 ├── DNS settings: | ├── Keep existing nameserver(s): no | ├── DNS server address to use: 127.0.0.1 diff --git a/internal/configuration/settings/vpn.go b/internal/configuration/settings/vpn.go index aec51543..d8aa6f1c 100644 --- a/internal/configuration/settings/vpn.go +++ b/internal/configuration/settings/vpn.go @@ -18,6 +18,7 @@ type VPN struct { Provider Provider `json:"provider"` OpenVPN OpenVPN `json:"openvpn"` Wireguard 
Wireguard `json:"wireguard"` + PMTUD PMTUD `json:"pmtud"` } // TODO v4 remove pointer for receiver (because of Surfshark). @@ -45,6 +46,11 @@ func (v *VPN) Validate(filterChoicesGetter FilterChoicesGetter, ipv6Supported bo } } + err = v.PMTUD.validate() + if err != nil { + return fmt.Errorf("PMTUD settings: %w", err) + } + return nil } @@ -54,6 +60,7 @@ func (v *VPN) Copy() (copied VPN) { Provider: v.Provider.copy(), OpenVPN: v.OpenVPN.copy(), Wireguard: v.Wireguard.copy(), + PMTUD: v.PMTUD.copy(), } } @@ -62,6 +69,7 @@ func (v *VPN) OverrideWith(other VPN) { v.Provider.overrideWith(other.Provider) v.OpenVPN.overrideWith(other.OpenVPN) v.Wireguard.overrideWith(other.Wireguard) + v.PMTUD.overrideWith(other.PMTUD) } func (v *VPN) setDefaults() { @@ -69,6 +77,7 @@ func (v *VPN) setDefaults() { v.Provider.setDefaults() v.OpenVPN.setDefaults(v.Provider.Name) v.Wireguard.setDefaults(v.Provider.Name) + v.PMTUD.setDefaults() } func (v VPN) String() string { @@ -85,6 +94,7 @@ func (v VPN) toLinesNode() (node *gotree.Node) { } else { node.AppendNode(v.Wireguard.toLinesNode()) } + node.AppendNode(v.PMTUD.toLinesNode()) return node } @@ -107,5 +117,10 @@ func (v *VPN) read(r *reader.Reader) (err error) { return fmt.Errorf("wireguard: %w", err) } + err = v.PMTUD.read(r) + if err != nil { + return fmt.Errorf("PMTUD: %w", err) + } + return nil } diff --git a/internal/configuration/settings/wireguard.go b/internal/configuration/settings/wireguard.go index bd76c096..b096968b 100644 --- a/internal/configuration/settings/wireguard.go +++ b/internal/configuration/settings/wireguard.go @@ -38,15 +38,9 @@ type Wireguard struct { Interface string `json:"interface"` PersistentKeepaliveInterval *time.Duration `json:"persistent_keep_alive_interval"` // Maximum Transmission Unit (MTU) of the Wireguard interface. - // It cannot be zero in the internal state, and defaults to - // 1320. 
Note it is not the wireguard-go MTU default of 1420 - // because this impacts bandwidth a lot on some VPN providers, - // see https://github.com/qdm12/gluetun/issues/1650. - // It has been lowered to 1320 following quite a bit of - // investigation in the issue: - // https://github.com/qdm12/gluetun/issues/2533. - // Note this should now be replaced with the PMTUD feature. - MTU uint32 `json:"mtu"` + // It cannot be nil in the internal state, and defaults to + // 0 indicating to use PMTUD. + MTU *uint32 `json:"mtu"` // Implementation is the Wireguard implementation to use. // It can be "auto", "userspace" or "kernelspace". // It defaults to "auto" and cannot be the empty string @@ -195,8 +189,7 @@ func (w *Wireguard) setDefaults(vpnProvider string) { w.AllowedIPs = gosettings.DefaultSlice(w.AllowedIPs, defaultAllowedIPs) w.PersistentKeepaliveInterval = gosettings.DefaultPointer(w.PersistentKeepaliveInterval, 0) w.Interface = gosettings.DefaultComparable(w.Interface, "wg0") - const defaultMTU = 1320 - w.MTU = gosettings.DefaultComparable(w.MTU, defaultMTU) + w.MTU = gosettings.DefaultPointer(w.MTU, 0) w.Implementation = gosettings.DefaultComparable(w.Implementation, "auto") } @@ -232,7 +225,11 @@ func (w Wireguard) toLinesNode() (node *gotree.Node) { } interfaceNode := node.Appendf("Network interface: %s", w.Interface) - interfaceNode.Appendf("MTU: %d", w.MTU) + if *w.MTU == 0 { + interfaceNode.Append("MTU: use path MTU discovery") + } else { + interfaceNode.Appendf("MTU: %d", *w.MTU) + } if w.Implementation != "auto" { node.Appendf("Implementation: %s", w.Implementation) @@ -273,11 +270,9 @@ func (w *Wireguard) read(r *reader.Reader) (err error) { return err } - mtuPtr, err := r.Uint32Ptr("WIREGUARD_MTU") + w.MTU, err = r.Uint32Ptr("WIREGUARD_MTU") if err != nil { return err - } else if mtuPtr != nil { - w.MTU = *mtuPtr } return nil } diff --git a/internal/firewall/iptables.go b/internal/firewall/iptables.go index da9f0b64..72b44a02 100644 --- 
a/internal/firewall/iptables.go +++ b/internal/firewall/iptables.go @@ -29,17 +29,16 @@ func appendOrDelete(remove bool) string { // flipRule changes an append rule in a delete rule or a delete rule into an // append rule. func flipRule(rule string) string { - switch { - case strings.HasPrefix(rule, "-A"): - return strings.Replace(rule, "-A", "-D", 1) - case strings.HasPrefix(rule, "--append"): - return strings.Replace(rule, "--append", "-D", 1) - case strings.HasPrefix(rule, "-D"): - return strings.Replace(rule, "-D", "-A", 1) - case strings.HasPrefix(rule, "--delete"): - return strings.Replace(rule, "--delete", "-A", 1) + fields := strings.Fields(rule) + for i, field := range fields { + switch field { + case "-A", "--append": + fields[i] = "--delete" + case "-D", "--delete": + fields[i] = "--append" + } } - return rule + return strings.Join(fields, " ") } // Version obtains the version of the installed iptables. @@ -86,10 +85,14 @@ func (c *Config) runIptablesInstruction(ctx context.Context, instruction string) } func (c *Config) clearAllRules(ctx context.Context) error { - return c.runMixedIptablesInstructions(ctx, []string{ - "--flush", // flush all chains - "--delete-chain", // delete all chains - }) + tables := []string{"filter"} + for _, table := range tables { + return c.runMixedIptablesInstructions(ctx, []string{ + "-t " + table + " --flush", // flush all chains + "-t " + table + " --delete-chain", // delete all chains + }) + } + return nil } func (c *Config) setIPv4AllPolicies(ctx context.Context, policy string) error { diff --git a/internal/netlink/route.go b/internal/netlink/route.go index 5fd2774e..cc9b132b 100644 --- a/internal/netlink/route.go +++ b/internal/netlink/route.go @@ -18,6 +18,7 @@ type Route struct { Type uint8 Scope uint8 Proto uint8 + AdvMSS uint32 } func (r *Route) fromMessage(message rtnetlink.RouteMessage) { @@ -35,6 +36,9 @@ func (r *Route) fromMessage(message rtnetlink.RouteMessage) { r.Type = message.Type r.Scope = message.Scope 
r.Proto = message.Protocol + if metrics := message.Attributes.Metrics; metrics != nil { + r.AdvMSS = metrics.AdvMSS + } } func (r Route) message() *rtnetlink.RouteMessage { @@ -58,7 +62,6 @@ func (r Route) message() *rtnetlink.RouteMessage { Protocol: r.Proto, Attributes: rtnetlink.RouteAttributes{ OutIface: r.LinkIndex, - Dst: *dst, // there should always be a dst for routes Gateway: netipAddrToNetIP(r.Gw), Priority: r.Priority, Table: extendedTable, @@ -67,6 +70,15 @@ func (r Route) message() *rtnetlink.RouteMessage { if src != nil { // src is optional message.Attributes.Src = *src } + if dst != nil { + message.Attributes.Dst = *dst + } + if r.AdvMSS != 0 { + if message.Attributes.Metrics == nil { + message.Attributes.Metrics = &rtnetlink.RouteMetrics{} + } + message.Attributes.Metrics.AdvMSS = r.AdvMSS + } return message } diff --git a/internal/pmtud/constants/lengths.go b/internal/pmtud/constants/lengths.go new file mode 100644 index 00000000..dc5cbcf9 --- /dev/null +++ b/internal/pmtud/constants/lengths.go @@ -0,0 +1,24 @@ +package constants + +const ( + MaxEthernetFrameSize uint32 = 1500 + // MinIPv4MTU is defined according to + // https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media + MinIPv4MTU uint32 = 68 + MinIPv6MTU uint32 = 1280 + + IPv4HeaderLength uint32 = 20 + IPv6HeaderLength uint32 = 40 + UDPHeaderLength uint32 = 8 + // BaseTCPHeaderLength is the TCP header length without options, + // which is the minimum TCP header length. + BaseTCPHeaderLength uint32 = 20 + // MaxTCPHeaderLength is the TCP header length with the maximum options length of 40 bytes. + // Note this is a hard maximum because of the 4-bit data offset field in the TCP header (15x4=60). 
+ MaxTCPHeaderLength uint32 = 60 + WireguardHeaderLength uint32 = 32 + OpenVPNHeaderMaxLength uint32 = 1 + // opcode + 8 + // session id + 4 + // packet id + 28 // max possible auth tag/iv +) diff --git a/internal/pmtud/errors.go b/internal/pmtud/errors.go deleted file mode 100644 index 5f6eaa41..00000000 --- a/internal/pmtud/errors.go +++ /dev/null @@ -1,29 +0,0 @@ -package pmtud - -import ( - "context" - "errors" - "fmt" - "net" - "strings" - "time" -) - -var ( - ErrICMPNotPermitted = errors.New("ICMP not permitted") - ErrICMPDestinationUnreachable = errors.New("ICMP destination unreachable") - ErrICMPCommunicationAdministrativelyProhibited = errors.New("communication administratively prohibited") - ErrICMPBodyUnsupported = errors.New("ICMP body type is not supported") -) - -func wrapConnErr(err error, timedCtx context.Context, pingTimeout time.Duration) error { //nolint:revive - switch { - case strings.HasSuffix(err.Error(), "sendto: operation not permitted"): - err = fmt.Errorf("%w", ErrICMPNotPermitted) - case errors.Is(timedCtx.Err(), context.DeadlineExceeded): - err = fmt.Errorf("%w (timed out after %s)", net.ErrClosed, pingTimeout) - case timedCtx.Err() != nil: - err = timedCtx.Err() - } - return err -} diff --git a/internal/pmtud/apple_ipv4.go b/internal/pmtud/icmp/apple_ipv4.go similarity index 98% rename from internal/pmtud/apple_ipv4.go rename to internal/pmtud/icmp/apple_ipv4.go index 6b298d79..7f9c6484 100644 --- a/internal/pmtud/apple_ipv4.go +++ b/internal/pmtud/icmp/apple_ipv4.go @@ -1,4 +1,4 @@ -package pmtud +package icmp import ( "net" diff --git a/internal/pmtud/check.go b/internal/pmtud/icmp/check.go similarity index 71% rename from internal/pmtud/check.go rename to internal/pmtud/icmp/check.go index a185a720..72e6dc3d 100644 --- a/internal/pmtud/check.go +++ b/internal/pmtud/icmp/check.go @@ -1,4 +1,4 @@ -package pmtud +package icmp import ( "bytes" @@ -9,17 +9,17 @@ import ( ) var ( - ErrICMPNextHopMTUTooLow = errors.New("ICMP Next Hop MTU 
is too low") - ErrICMPNextHopMTUTooHigh = errors.New("ICMP Next Hop MTU is too high") + ErrNextHopMTUTooLow = errors.New("ICMP Next Hop MTU is too low") + ErrNextHopMTUTooHigh = errors.New("ICMP Next Hop MTU is too high") ) func checkMTU(mtu, minMTU, physicalLinkMTU uint32) (err error) { switch { case mtu < minMTU: - return fmt.Errorf("%w: %d", ErrICMPNextHopMTUTooLow, mtu) + return fmt.Errorf("%w: %d", ErrNextHopMTUTooLow, mtu) case mtu > physicalLinkMTU: return fmt.Errorf("%w: %d is larger than physical link MTU %d", - ErrICMPNextHopMTUTooHigh, mtu, physicalLinkMTU) + ErrNextHopMTUTooHigh, mtu, physicalLinkMTU) default: return nil } @@ -34,13 +34,13 @@ func checkInvokingReplyIDMatch(icmpProtocol int, received []byte, } inboundBody, ok := inboundMessage.Body.(*icmp.Echo) if !ok { - return false, fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, inboundMessage.Body) + return false, fmt.Errorf("%w: %T", ErrBodyUnsupported, inboundMessage.Body) } outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert return inboundBody.ID == outboundBody.ID, nil } -var ErrICMPIDMismatch = errors.New("ICMP id mismatch") +var ErrIDMismatch = errors.New("ICMP id mismatch") func checkEchoReply(icmpProtocol int, received []byte, outboundMessage *icmp.Message, truncatedBody bool, @@ -51,12 +51,12 @@ func checkEchoReply(icmpProtocol int, received []byte, } inboundBody, ok := inboundMessage.Body.(*icmp.Echo) if !ok { - return fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, inboundMessage.Body) + return fmt.Errorf("%w: %T", ErrBodyUnsupported, inboundMessage.Body) } outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert if inboundBody.ID != outboundBody.ID { return fmt.Errorf("%w: sent id %d and received id %d", - ErrICMPIDMismatch, outboundBody.ID, inboundBody.ID) + ErrIDMismatch, outboundBody.ID, inboundBody.ID) } err = checkEchoBodies(outboundBody.Data, inboundBody.Data, truncatedBody) if err != nil { @@ -65,19 +65,19 @@ func checkEchoReply(icmpProtocol int, 
received []byte, return nil } -var ErrICMPEchoDataMismatch = errors.New("ICMP data mismatch") +var ErrEchoDataMismatch = errors.New("ICMP data mismatch") func checkEchoBodies(sent, received []byte, receivedTruncated bool) (err error) { if len(received) > len(sent) { return fmt.Errorf("%w: sent %d bytes and received %d bytes", - ErrICMPEchoDataMismatch, len(sent), len(received)) + ErrEchoDataMismatch, len(sent), len(received)) } if receivedTruncated { sent = sent[:len(received)] } if !bytes.Equal(received, sent) { return fmt.Errorf("%w: sent %x and received %x", - ErrICMPEchoDataMismatch, sent, received) + ErrEchoDataMismatch, sent, received) } return nil } diff --git a/internal/pmtud/df.go b/internal/pmtud/icmp/df.go similarity index 94% rename from internal/pmtud/df.go rename to internal/pmtud/icmp/df.go index 9e6ee59d..32c3ea32 100644 --- a/internal/pmtud/df.go +++ b/internal/pmtud/icmp/df.go @@ -1,6 +1,6 @@ //go:build !linux && !windows -package pmtud +package icmp // setDontFragment for platforms other than Linux and Windows // is not implemented, so we just return assuming the don't diff --git a/internal/pmtud/df_linux.go b/internal/pmtud/icmp/df_linux.go similarity index 93% rename from internal/pmtud/df_linux.go rename to internal/pmtud/icmp/df_linux.go index 08c7979c..d4334aff 100644 --- a/internal/pmtud/df_linux.go +++ b/internal/pmtud/icmp/df_linux.go @@ -1,4 +1,4 @@ -package pmtud +package icmp import ( "syscall" diff --git a/internal/pmtud/df_windows.go b/internal/pmtud/icmp/df_windows.go similarity index 90% rename from internal/pmtud/df_windows.go rename to internal/pmtud/icmp/df_windows.go index a8c98fc4..3416ff51 100644 --- a/internal/pmtud/df_windows.go +++ b/internal/pmtud/icmp/df_windows.go @@ -1,6 +1,4 @@ -//go:build windows - -package pmtud +package icmp import ( "syscall" diff --git a/internal/pmtud/icmp/errors.go b/internal/pmtud/icmp/errors.go new file mode 100644 index 00000000..277c5551 --- /dev/null +++ b/internal/pmtud/icmp/errors.go @@ 
-0,0 +1,30 @@ +package icmp + +import ( + "context" + "errors" + "fmt" + "net" + "strings" + "time" +) + +var ( + ErrNotPermitted = errors.New("ICMP not permitted") + ErrDestinationUnreachable = errors.New("ICMP destination unreachable") + ErrCommunicationAdministrativelyProhibited = errors.New("communication administratively prohibited") + ErrBodyUnsupported = errors.New("ICMP body type is not supported") + ErrMTUNotFound = errors.New("MTU not found") +) + +func wrapConnErr(err error, timedCtx context.Context, pingTimeout time.Duration) error { //nolint:revive + switch { + case strings.HasSuffix(err.Error(), "sendto: operation not permitted"): + err = fmt.Errorf("%w", ErrNotPermitted) + case errors.Is(timedCtx.Err(), context.DeadlineExceeded): + err = fmt.Errorf("%w (timed out after %s)", net.ErrClosed, pingTimeout) + case timedCtx.Err() != nil: + err = timedCtx.Err() + } + return err +} diff --git a/internal/pmtud/icmp/icmp.go b/internal/pmtud/icmp/icmp.go new file mode 100644 index 00000000..8e9b9acb --- /dev/null +++ b/internal/pmtud/icmp/icmp.go @@ -0,0 +1,53 @@ +package icmp + +import ( + "context" + "errors" + "fmt" + "net" + "net/netip" + "time" + + "github.com/qdm12/gluetun/internal/pmtud/constants" +) + +// PathMTUDiscover discovers the path MTU to the given IP address +// using ICMP. +// It first tries to get the next hop MTU using ICMP messages. +// If that fails, it falls back to sending echo requests with +// different packet sizes to find the maximum MTU. +// The function returns [ErrMTUNotFound] if the MTU could not be determined. 
+func PathMTUDiscover(ctx context.Context, ip netip.Addr, + physicalLinkMTU uint32, timeout time.Duration, logger Logger, +) (mtu uint32, err error) { + if ip.Is4() { + logger.Debug("finding IPv4 next hop MTU") + mtu, err = findIPv4NextHopMTU(ctx, ip, physicalLinkMTU, timeout, logger) + switch { + case err == nil: + return mtu, nil + case errors.Is(err, net.ErrClosed) || errors.Is(err, ErrCommunicationAdministrativelyProhibited): // blackhole + default: + return 0, fmt.Errorf("finding IPv4 next hop MTU: %w", err) + } + } else { + logger.Debug("requesting IPv6 ICMP packet-too-big reply") + mtu, err = getIPv6PacketTooBig(ctx, ip, physicalLinkMTU, timeout, logger) + switch { + case err == nil: + return mtu, nil + case errors.Is(err, net.ErrClosed): // blackhole + default: + return 0, fmt.Errorf("getting IPv6 packet-too-big message: %w", err) + } + } + + // Fall back method: send echo requests with different packet + // sizes and check which ones succeed to find the maximum MTU. + logger.Debug("falling back to sending different sized echo packets") + minMTU := constants.MinIPv4MTU + if ip.Is6() { + minMTU = constants.MinIPv6MTU + } + return pmtudMultiSizes(ctx, ip, minMTU, physicalLinkMTU, timeout, logger) +} diff --git a/internal/pmtud/icmp/interfaces.go b/internal/pmtud/icmp/interfaces.go new file mode 100644 index 00000000..37c32482 --- /dev/null +++ b/internal/pmtud/icmp/interfaces.go @@ -0,0 +1,7 @@ +package icmp + +type Logger interface { + Debug(msg string) + Debugf(msg string, args ...any) + Warnf(msg string, args ...any) +} diff --git a/internal/pmtud/ipv4.go b/internal/pmtud/icmp/ipv4.go similarity index 86% rename from internal/pmtud/ipv4.go rename to internal/pmtud/icmp/ipv4.go index 7e436847..684ea4a1 100644 --- a/internal/pmtud/ipv4.go +++ b/internal/pmtud/icmp/ipv4.go @@ -1,4 +1,4 @@ -package pmtud +package icmp import ( "context" @@ -11,14 +11,13 @@ import ( "syscall" "time" + "github.com/qdm12/gluetun/internal/pmtud/constants" "golang.org/x/net/icmp" 
"golang.org/x/net/ipv4" ) const ( - // see https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media - minIPv4MTU uint32 = 68 - icmpv4Protocol int = 1 + icmpv4Protocol = 1 ) func listenICMPv4(ctx context.Context) (conn net.PacketConn, err error) { @@ -38,7 +37,7 @@ func listenICMPv4(ctx context.Context) (conn net.PacketConn, err error) { packetConn, err := listenConfig.ListenPacket(ctx, "ip4:icmp", listenAddress) if err != nil { if strings.HasSuffix(err.Error(), "socket: operation not permitted") { - err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrICMPNotPermitted) + err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrNotPermitted) } return nil, err } @@ -83,7 +82,9 @@ func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr, buffer := make([]byte, physicalLinkMTU) - for { // for loop in case we read an echo reply for another ICMP request + // for loop in case we read an ICMP message from another ICMP request + // or TCP/UDP traffic triggering an ICMP response. + for { // Note we need to read the whole packet in one call to ReadFrom, so the buffer // must be large enough to read the entire reply packet. 
See: // https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J @@ -108,24 +109,27 @@ func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr, switch typedBody := inboundMessage.Body.(type) { case *icmp.DstUnreach: const fragmentationRequiredAndDFFlagSetCode = 4 + const portUnreachable = 3 const communicationAdministrativelyProhibitedCode = 13 switch inboundMessage.Code { case fragmentationRequiredAndDFFlagSetCode: + case portUnreachable: // triggered by TCP or UDP from applications + continue // ignore and wait for the next message case communicationAdministrativelyProhibitedCode: return 0, fmt.Errorf("%w: %w (code %d)", - ErrICMPDestinationUnreachable, - ErrICMPCommunicationAdministrativelyProhibited, + ErrDestinationUnreachable, + ErrCommunicationAdministrativelyProhibited, inboundMessage.Code) default: return 0, fmt.Errorf("%w: code %d", - ErrICMPDestinationUnreachable, inboundMessage.Code) + ErrDestinationUnreachable, inboundMessage.Code) } // See https://datatracker.ietf.org/doc/html/rfc1191#section-4 // Note: the go library does not handle this NextHopMTU section. 
nextHopMTU := packetBytes[6:8] mtu = uint32(binary.BigEndian.Uint16(nextHopMTU)) - err = checkMTU(mtu, minIPv4MTU, physicalLinkMTU) + err = checkMTU(mtu, constants.MinIPv4MTU, physicalLinkMTU) if err != nil { return 0, fmt.Errorf("checking next-hop-mtu found: %w", err) } @@ -153,7 +157,7 @@ func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr, inboundID, outboundID) continue default: - return 0, fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, typedBody) + return 0, fmt.Errorf("%w: %T", ErrBodyUnsupported, typedBody) } } } diff --git a/internal/pmtud/ipv6.go b/internal/pmtud/icmp/ipv6.go similarity index 92% rename from internal/pmtud/ipv6.go rename to internal/pmtud/icmp/ipv6.go index 787f4590..a707bed3 100644 --- a/internal/pmtud/ipv6.go +++ b/internal/pmtud/icmp/ipv6.go @@ -1,4 +1,4 @@ -package pmtud +package icmp import ( "context" @@ -8,12 +8,12 @@ import ( "strings" "time" + "github.com/qdm12/gluetun/internal/pmtud/constants" "golang.org/x/net/icmp" "golang.org/x/net/ipv6" ) const ( - minIPv6MTU = 1280 icmpv6Protocol = 58 ) @@ -23,7 +23,7 @@ func listenICMPv6(ctx context.Context) (conn net.PacketConn, err error) { packetConn, err := listenConfig.ListenPacket(ctx, "ip6:ipv6-icmp", listenAddress) if err != nil { if strings.HasSuffix(err.Error(), "socket: operation not permitted") { - err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrICMPNotPermitted) + err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrNotPermitted) } return nil, err } @@ -85,7 +85,7 @@ func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr, case *icmp.PacketTooBig: // https://datatracker.ietf.org/doc/html/rfc1885#section-3.2 mtu = uint32(typedBody.MTU) //nolint:gosec - err = checkMTU(mtu, minIPv6MTU, physicalLinkMTU) + err = checkMTU(mtu, constants.MinIPv6MTU, physicalLinkMTU) if err != nil { return 0, fmt.Errorf("checking MTU: %w", err) } @@ -103,7 +103,7 @@ func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr, if err 
!= nil { return 0, fmt.Errorf("checking invoking message id: %w", err) } else if idMatch { - return 0, fmt.Errorf("%w", ErrICMPDestinationUnreachable) + return 0, fmt.Errorf("%w", ErrDestinationUnreachable) } logger.Debug("discarding received ICMP destination unreachable reply with an unknown id") continue @@ -116,7 +116,7 @@ func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr, inboundID, outboundID) continue default: - return 0, fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, typedBody) + return 0, fmt.Errorf("%w: %T", ErrBodyUnsupported, typedBody) } } } diff --git a/internal/pmtud/message.go b/internal/pmtud/icmp/message.go similarity index 99% rename from internal/pmtud/message.go rename to internal/pmtud/icmp/message.go index a216ae00..10788ce0 100644 --- a/internal/pmtud/message.go +++ b/internal/pmtud/icmp/message.go @@ -1,4 +1,4 @@ -package pmtud +package icmp import ( cryptorand "crypto/rand" diff --git a/internal/pmtud/icmp/multi.go b/internal/pmtud/icmp/multi.go new file mode 100644 index 00000000..cdf3a692 --- /dev/null +++ b/internal/pmtud/icmp/multi.go @@ -0,0 +1,187 @@ +package icmp + +import ( + "context" + "errors" + "fmt" + "net" + "net/netip" + "strings" + "time" + + "github.com/qdm12/gluetun/internal/pmtud/test" + "golang.org/x/net/icmp" +) + +type icmpTestUnit struct { + mtu uint32 + echoID uint16 + sentBytes int + ok bool +} + +func pmtudMultiSizes(ctx context.Context, ip netip.Addr, + minMTU, maxPossibleMTU uint32, pingTimeout time.Duration, + logger Logger, +) (maxMTU uint32, err error) { + var ipVersion string + var conn net.PacketConn + if ip.Is4() { + ipVersion = "v4" + conn, err = listenICMPv4(ctx) + } else { + ipVersion = "v6" + conn, err = listenICMPv6(ctx) + } + if err != nil { + if strings.HasSuffix(err.Error(), "socket: operation not permitted") { + err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrNotPermitted) + } + return 0, fmt.Errorf("listening for ICMP packets: %w", err) + } + + mtusToTest 
:= test.MakeMTUsToTest(minMTU, maxPossibleMTU) + if len(mtusToTest) == 1 { // only minMTU because minMTU == maxPossibleMTU + return minMTU, nil + } + logger.Debugf("ICMP testing the following MTUs: %v", mtusToTest) + + tests := make([]icmpTestUnit, len(mtusToTest)) + for i := range mtusToTest { + tests[i] = icmpTestUnit{mtu: mtusToTest[i]} + } + + timedCtx, cancel := context.WithTimeout(ctx, pingTimeout) + defer cancel() + go func() { + <-timedCtx.Done() + conn.Close() + }() + + for i := range tests { + id, message := buildMessageToSend(ipVersion, tests[i].mtu) + tests[i].echoID = id + + encodedMessage, err := message.Marshal(nil) + if err != nil { + return 0, fmt.Errorf("encoding ICMP message: %w", err) + } + tests[i].sentBytes = len(encodedMessage) + + _, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()}) + if err != nil { + if strings.HasSuffix(err.Error(), "sendto: operation not permitted") { + err = fmt.Errorf("%w", ErrNotPermitted) + } + return 0, fmt.Errorf("writing ICMP message: %w", err) + } + } + + err = collectReplies(conn, ipVersion, tests, logger) + switch { + case err == nil: // max possible MTU is working + return tests[len(tests)-1].mtu, nil + case err != nil && errors.Is(err, net.ErrClosed): + // we have timeouts (IPv4 testing or IPv6 PMTUD blackholes) + // so find the highest MTU which worked. + // Note we start from index len(tests) - 2 since the max MTU + // cannot be working if we had a timeout. + for i := len(tests) - 2; i >= 0; i-- { //nolint:mnd + if tests[i].ok { + return pmtudMultiSizes(ctx, ip, tests[i].mtu, tests[i+1].mtu-1, + pingTimeout, logger) + } + } + + // All MTUs failed. 
+ return 0, fmt.Errorf("%w: ICMP might be blocked", ErrMTUNotFound) + case err != nil: + return 0, fmt.Errorf("collecting ICMP echo replies: %w", err) + default: + panic("unreachable") + } +} + +// The theoretical limit is 4GiB for IPv6 MTU path discovery jumbograms, but that would +// create huge buffers which we don't really want to support anyway. +// The standard frame maximum MTU is 1500 bytes, and there are Jumbo frames with +// a conventional maximum of 9000 bytes. However, some manufacturers support up to +// 9216-20 = 9196 bytes for the maximum MTU. We thus use buffers of size 9196 to +// match eventual Jumbo frames. More information at: +// https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media +const maxPossibleMTU = 9196 + +func collectReplies(conn net.PacketConn, ipVersion string, + tests []icmpTestUnit, logger Logger, +) (err error) { + echoIDToTestIndex := make(map[uint16]int, len(tests)) + for i, test := range tests { + echoIDToTestIndex[test.echoID] = i + } + + buffer := make([]byte, maxPossibleMTU) + + idsFound := 0 + for idsFound < len(tests) { + // Note we need to read the whole packet in one call to ReadFrom, so the buffer + // must be large enough to read the entire reply packet. See: + // https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J + bytesRead, _, err := conn.ReadFrom(buffer) + if err != nil { + return fmt.Errorf("reading from ICMP connection: %w", err) + } + packetBytes := buffer[:bytesRead] + + ipPacketLength := len(packetBytes) + + var icmpProtocol int + switch ipVersion { + case "v4": + icmpProtocol = icmpv4Protocol + case "v6": + icmpProtocol = icmpv6Protocol + default: + panic(fmt.Sprintf("unknown IP version: %s", ipVersion)) + } + + // Parse the ICMP message + // Note: this parsing works for a truncated 556 bytes ICMP reply packet. 
+ message, err := icmp.ParseMessage(icmpProtocol, packetBytes) + if err != nil { + return fmt.Errorf("parsing message: %w", err) + } + + echoBody, ok := message.Body.(*icmp.Echo) + if !ok { + return fmt.Errorf("%w: %T", ErrBodyUnsupported, message.Body) + } + + id := uint16(echoBody.ID) //nolint:gosec + testIndex, testing := echoIDToTestIndex[id] + if !testing { // not an id we expected so ignore it + logger.Warnf("ignoring ICMP reply with unexpected ID %d (type: %d, code: %d, length: %d)", + echoBody.ID, message.Type, message.Code, ipPacketLength) + continue + } + idsFound++ + sentBytes := tests[testIndex].sentBytes + + // echo reply should be at most the number of bytes sent, + // and can be lower, more precisely 556 bytes, in case + // the host we are reaching wants to stay out of trouble + // and ensure its echo reply goes through without + // fragmentation, see the following page: + // https://datatracker.ietf.org/doc/html/rfc1122#page-59 + const conservativeReplyLength = 556 + truncated := ipPacketLength < sentBytes && + ipPacketLength == conservativeReplyLength + // Check the packet size is the same if the reply is not truncated + if !truncated && sentBytes != ipPacketLength { + return fmt.Errorf("%w: sent %dB and received %dB", + ErrEchoDataMismatch, sentBytes, ipPacketLength) + } + // Truncated reply or matching reply size + tests[testIndex].ok = true + } + return nil +} diff --git a/internal/pmtud/ip/ipheader.go b/internal/pmtud/ip/ipheader.go new file mode 100644 index 00000000..89d86623 --- /dev/null +++ b/internal/pmtud/ip/ipheader.go @@ -0,0 +1,73 @@ +package ip + +import ( + "encoding/binary" + "net/netip" + "syscall" + + "github.com/qdm12/gluetun/internal/pmtud/constants" +) + +func HeaderV4(srcIP, dstIP netip.Addr, payloadLength uint32) []byte { + ipHeader := make([]byte, constants.IPv4HeaderLength) + const version byte = 4 + const headerLength byte = 20 / 4 // in 32-bit words + ipHeader[0] = (version << 4) | headerLength //nolint:mnd + ipHeader[1] 
= 0 // type of Service + putUint16(ipHeader[2:], uint16(constants.IPv4HeaderLength+payloadLength)) //nolint:gosec + ipHeader[4], ipHeader[5] = 0, 0 // identification + const flagsAndOffset uint16 = 0x4000 // DF bit set + putUint16(ipHeader[6:], flagsAndOffset) + ipHeader[8] = 64 // ttl + ipHeader[9] = syscall.IPPROTO_TCP + srcIPBytes := srcIP.As4() + copy(ipHeader[12:16], srcIPBytes[:]) + dstIPBytes := dstIP.As4() + copy(ipHeader[16:20], dstIPBytes[:]) + + checksum := ipChecksum(ipHeader) + ipHeader[10] = byte(checksum >> 8) //nolint:mnd + ipHeader[11] = byte(checksum & 0xff) //nolint:mnd + + return ipHeader +} + +// ipChecksum calculates the checksum for the IP header. +// +//nolint:mnd +func ipChecksum(header []byte) uint16 { + sum := uint32(0) + for i := 0; i < len(header)-1; i += 2 { + sum += uint32(header[i])<<8 + uint32(header[i+1]) + } + if len(header)%2 != 0 { + sum += uint32(header[len(header)-1]) << 8 + } + for (sum >> 16) > 0 { + sum = (sum & 0xFFFF) + (sum >> 16) + } + return ^uint16(sum) //nolint:gosec +} + +// HeaderV6 makes an IPv6 header. +// payloadLen is the length of the payload following the header. +// nextHeader can be byte([syscall.IPPROTO_TCP]) for example. 
+func HeaderV6(srcIP, dstIP netip.Addr, + payloadLen uint16, nextHeader byte, +) []byte { + ipv6Header := make([]byte, constants.IPv6HeaderLength) + ipv6Header[0] = 0x60 // version (4 bits) | traffic Class (4 bits) + ipv6Header[1] = 0x00 // traffic Class (4 bits) | flow label (4 bits) + + // Flow Label (remaining 16 bits) + ipv6Header[2] = 0x00 + ipv6Header[3] = 0x00 + + binary.BigEndian.PutUint16(ipv6Header[4:], payloadLen) + ipv6Header[6] = nextHeader + const hopLimit = 64 + ipv6Header[7] = hopLimit + copy(ipv6Header[8:24], srcIP.AsSlice()) + copy(ipv6Header[24:40], dstIP.AsSlice()) + return ipv6Header +} diff --git a/internal/pmtud/ip/ipheader_darwin.go b/internal/pmtud/ip/ipheader_darwin.go new file mode 100644 index 00000000..a8c11ff7 --- /dev/null +++ b/internal/pmtud/ip/ipheader_darwin.go @@ -0,0 +1,9 @@ +package ip + +import ( + "encoding/binary" +) + +func putUint16(b []byte, v uint16) { + binary.NativeEndian.PutUint16(b, v) +} diff --git a/internal/pmtud/ip/ipheader_unspecified.go b/internal/pmtud/ip/ipheader_unspecified.go new file mode 100644 index 00000000..c2f702a1 --- /dev/null +++ b/internal/pmtud/ip/ipheader_unspecified.go @@ -0,0 +1,9 @@ +//go:build !darwin + +package ip + +import "encoding/binary" + +func putUint16(b []byte, v uint16) { + binary.BigEndian.PutUint16(b, v) +} diff --git a/internal/pmtud/ip/ipv4_unix.go b/internal/pmtud/ip/ipv4_unix.go new file mode 100644 index 00000000..6d70b427 --- /dev/null +++ b/internal/pmtud/ip/ipv4_unix.go @@ -0,0 +1,9 @@ +//go:build linux || darwin + +package ip + +import "syscall" + +func SetIPv4HeaderIncluded(fd int) error { + return syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1) +} diff --git a/internal/pmtud/ip/ipv4_unspecified.go b/internal/pmtud/ip/ipv4_unspecified.go new file mode 100644 index 00000000..09469021 --- /dev/null +++ b/internal/pmtud/ip/ipv4_unspecified.go @@ -0,0 +1,7 @@ +//go:build !linux && !windows && !darwin + +package ip + +func SetIPv4HeaderIncluded(fd int) 
error { + panic("not implemented") +} diff --git a/internal/pmtud/ip/ipv4_windows.go b/internal/pmtud/ip/ipv4_windows.go new file mode 100644 index 00000000..66263b6b --- /dev/null +++ b/internal/pmtud/ip/ipv4_windows.go @@ -0,0 +1,12 @@ +package ip + +import ( + "syscall" + + "golang.org/x/sys/windows" +) + +func SetIPv4HeaderIncluded(handle syscall.Handle) error { + const ipHdrIncluded = windows.IP_HDRINCL + return syscall.SetsockoptInt(handle, syscall.IPPROTO_IP, ipHdrIncluded, 1) +} diff --git a/internal/pmtud/ip/ipv6_darwin.go b/internal/pmtud/ip/ipv6_darwin.go new file mode 100644 index 00000000..40d8b03d --- /dev/null +++ b/internal/pmtud/ip/ipv6_darwin.go @@ -0,0 +1,5 @@ +package ip + +func SetIPv6HeaderIncluded(fd int) error { + panic("darwin does not allow an application to build IPv6 headers") +} diff --git a/internal/pmtud/ip/ipv6_linux.go b/internal/pmtud/ip/ipv6_linux.go new file mode 100644 index 00000000..8b1ba81b --- /dev/null +++ b/internal/pmtud/ip/ipv6_linux.go @@ -0,0 +1,8 @@ +package ip + +import "syscall" + +func SetIPv6HeaderIncluded(fd int) error { + const ipv6HdrIncluded = 36 // IPV6_HDRINCL + return syscall.SetsockoptInt(fd, syscall.IPPROTO_IPV6, ipv6HdrIncluded, 1) +} diff --git a/internal/pmtud/ip/ipv6_unspecified.go b/internal/pmtud/ip/ipv6_unspecified.go new file mode 100644 index 00000000..b4ea2b2a --- /dev/null +++ b/internal/pmtud/ip/ipv6_unspecified.go @@ -0,0 +1,7 @@ +//go:build !linux && !windows && !darwin + +package ip + +func SetIPv6HeaderIncluded(fd int) error { + panic("not implemented") +} diff --git a/internal/pmtud/ip/ipv6_windows.go b/internal/pmtud/ip/ipv6_windows.go new file mode 100644 index 00000000..882db5cf --- /dev/null +++ b/internal/pmtud/ip/ipv6_windows.go @@ -0,0 +1,7 @@ +package ip + +import "syscall" + +func SetIPv6HeaderIncluded(fd syscall.Handle) error { + panic("windows does not allow an application to build IPv6 headers") +} diff --git a/internal/pmtud/ip/source.go b/internal/pmtud/ip/source.go new file 
mode 100644 index 00000000..64d1705e --- /dev/null +++ b/internal/pmtud/ip/source.go @@ -0,0 +1,123 @@ +package ip + +import ( + "fmt" + "net/netip" + "syscall" + + "github.com/jsimonetti/rtnetlink" +) + +// SrcAddr determines the appropriate source IP address to use when sending a packet to the +// specified destination. It also reserves an ephemeral source port for the specified protocol +// to ensure that the port is not used by other processes. The cleanup function returned should +// be called to release the reserved port when done. +func SrcAddr(dst netip.AddrPort, proto int) (src netip.AddrPort, cleanup func(), err error) { + srcAddr, err := srcIP(dst.Addr()) + if err != nil { + return netip.AddrPort{}, nil, fmt.Errorf("finding source IP: %w", err) + } + + srcPort, cleanup, err := srcPort(srcAddr, proto) + if err != nil { + return netip.AddrPort{}, nil, fmt.Errorf("reserving source port: %w", err) + } + + return netip.AddrPortFrom(srcAddr, srcPort), cleanup, nil +} + +var errNoRoute = fmt.Errorf("no route to destination") + +func srcIP(dst netip.Addr) (netip.Addr, error) { + conn, err := rtnetlink.Dial(nil) + if err != nil { + return netip.Addr{}, err + } + defer conn.Close() + + family := uint8(syscall.AF_INET) + if dst.Is6() { + family = syscall.AF_INET6 + } + + // Request route to destination + requestMessage := &rtnetlink.RouteMessage{ + Family: family, + Attributes: rtnetlink.RouteAttributes{ + Dst: dst.AsSlice(), + }, + } + messages, err := conn.Route.Get(requestMessage) + if err != nil { + return netip.Addr{}, fmt.Errorf("getting routes to %s: %w", dst, err) + } + + for _, message := range messages { + if message.Attributes.Src == nil { + continue + } + ipv6 := message.Attributes.Src.To4() == nil + if ipv6 { + return netip.AddrFrom16([16]byte(message.Attributes.Src)), nil + } + return netip.AddrFrom4([4]byte(message.Attributes.Src)), nil + } + + return netip.Addr{}, fmt.Errorf("%w: in %d route(s)", errNoRoute, len(messages)) +} + +// srcPort reserves 
an ephemeral source port by opening a socket for the +// protocol specified and binds it to the provided source address. +// It doesn't actually listen on the port. +// The cleanup function returned should be called to release the port when done. +func srcPort(srcAddr netip.Addr, proto int) (srcPort uint16, cleanup func(), err error) { + family := syscall.AF_INET + if srcAddr.Is6() { + family = syscall.AF_INET6 + } + + fd, err := syscall.Socket(family, syscall.SOCK_STREAM, proto) + if err != nil { + return 0, nil, fmt.Errorf("creating reservation socket: %w", err) + } + cleanup = func() { + _ = syscall.Close(fd) + } + + // Bind to port 0 to get an ephemeral port + const port = 0 + var bindAddr syscall.Sockaddr + if srcAddr.Is4() { + bindAddr = &syscall.SockaddrInet4{ + Port: port, + Addr: srcAddr.As4(), + } + } else { + bindAddr = &syscall.SockaddrInet6{ + Port: port, + Addr: srcAddr.As16(), + } + } + + err = syscall.Bind(fd, bindAddr) + if err != nil { + cleanup() + return 0, nil, fmt.Errorf("binding reservation socket: %w", err) + } + + sockAddr, err := syscall.Getsockname(fd) + if err != nil { + cleanup() + return 0, nil, fmt.Errorf("getting bound socket name: %w", err) + } + + switch typedSockAddr := sockAddr.(type) { + case *syscall.SockaddrInet4: + srcPort = uint16(typedSockAddr.Port) //nolint:gosec + case *syscall.SockaddrInet6: + srcPort = uint16(typedSockAddr.Port) //nolint:gosec + default: + panic(fmt.Sprintf("unexpected sockaddr type: %T", typedSockAddr)) + } + return srcPort, cleanup, nil +} diff --git a/internal/pmtud/pmtud.go b/internal/pmtud/pmtud.go index 4a42dba7..37e8ef8d 100644 --- a/internal/pmtud/pmtud.go +++ b/internal/pmtud/pmtud.go @@ -4,268 +4,73 @@ import ( "context" "errors" "fmt" - "math" - "net" "net/netip" - "strings" "time" - "golang.org/x/net/icmp" + "github.com/qdm12/gluetun/internal/pmtud/constants" + "github.com/qdm12/gluetun/internal/pmtud/icmp" + "github.com/qdm12/gluetun/internal/pmtud/tcp" ) -var ErrMTUNotFound = 
errors.New("path MTU discovery failed to find MTU") - -// PathMTUDiscover discovers the maximum MTU for the path to the given ip address. +// PathMTUDiscover discovers the maximum MTU using both ICMP and TCP. +// Multiple ICMP addresses and TCP addresses can be specified for redundancy. +// ICMP PMTUD is run first. If successful, the range of possible MTU values to +// check for TCP PMTUD is reduced to [maxMTU-150, maxMTU] where maxMTU is the +// maximum MTU found with ICMP PMTUD. Otherwise, TCP PMTUD is run with the +// whole range of possible MTU values up to the physical link MTU to check. // If the physicalLinkMTU is zero, it defaults to 1500 which is the ethernet standard MTU. // If the pingTimeout is zero, it defaults to 1 second. // If the logger is nil, a no-op logger is used. // It returns [ErrMTUNotFound] if the MTU could not be determined. -func PathMTUDiscover(ctx context.Context, ip netip.Addr, - physicalLinkMTU uint32, pingTimeout time.Duration, logger Logger) ( +func PathMTUDiscover(ctx context.Context, icmpAddrs []netip.Addr, tcpAddrs []netip.AddrPort, + physicalLinkMTU uint32, tryTimeout time.Duration, logger Logger) ( mtu uint32, err error, ) { if physicalLinkMTU == 0 { const ethernetStandardMTU = 1500 physicalLinkMTU = ethernetStandardMTU } - if pingTimeout == 0 { - pingTimeout = time.Second + if tryTimeout == 0 { + tryTimeout = time.Second } if logger == nil { logger = &noopLogger{} } - if ip.Is4() { - logger.Debug("finding IPv4 next hop MTU") - mtu, err = findIPv4NextHopMTU(ctx, ip, physicalLinkMTU, pingTimeout, logger) + // Try finding the MTU using ICMP + maxPossibleMTU := physicalLinkMTU + icmpSuccess := false + for _, icmpIP := range icmpAddrs { + mtu, err := icmp.PathMTUDiscover(ctx, icmpIP, physicalLinkMTU, + tryTimeout, logger) switch { case err == nil: - return mtu, nil - case errors.Is(err, net.ErrClosed) || errors.Is(err, ErrICMPCommunicationAdministrativelyProhibited): // blackhole + logger.Debugf("ICMP path MTU discovery against %s 
found maximum valid MTU %d", icmpIP, mtu) + icmpSuccess = true + maxPossibleMTU = mtu + case errors.Is(err, icmp.ErrNotPermitted), errors.Is(err, icmp.ErrMTUNotFound): + logger.Debugf("ICMP path MTU discovery failed: %s", err) default: - return 0, fmt.Errorf("finding IPv4 next hop MTU: %w", err) - } - } else { - logger.Debug("requesting IPv6 ICMP packet-too-big reply") - mtu, err = getIPv6PacketTooBig(ctx, ip, physicalLinkMTU, pingTimeout, logger) - switch { - case err == nil: - return mtu, nil - case errors.Is(err, net.ErrClosed): // blackhole - default: - return 0, fmt.Errorf("getting IPv6 packet-too-big message: %w", err) + return 0, fmt.Errorf("ICMP path MTU discovery: %w", err) } } - // Fall back method: send echo requests with different packet - // sizes and check which ones succeed to find the maximum MTU. - logger.Debug("falling back to sending different sized echo packets") - minMTU := minIPv4MTU - if ip.Is6() { - minMTU = minIPv6MTU - } - return pmtudMultiSizes(ctx, ip, minMTU, physicalLinkMTU, pingTimeout, logger) -} - -type pmtudTestUnit struct { - mtu uint32 - echoID uint16 - sentBytes int - ok bool -} - -func pmtudMultiSizes(ctx context.Context, ip netip.Addr, - minMTU, maxPossibleMTU uint32, pingTimeout time.Duration, - logger Logger, -) (maxMTU uint32, err error) { - var ipVersion string - var conn net.PacketConn - if ip.Is4() { - ipVersion = "v4" - conn, err = listenICMPv4(ctx) - } else { - ipVersion = "v6" - conn, err = listenICMPv6(ctx) - } - if err != nil { - if strings.HasSuffix(err.Error(), "socket: operation not permitted") { - err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrICMPNotPermitted) + for _, addrPort := range tcpAddrs { + minMTU := constants.MinIPv4MTU + if addrPort.Addr().Is6() { + minMTU = constants.MinIPv6MTU } - return 0, fmt.Errorf("listening for ICMP packets: %w", err) - } - - mtusToTest := makeMTUsToTest(minMTU, maxPossibleMTU) - if len(mtusToTest) == 1 { // only minMTU because minMTU == 
maxPossibleMTU - return minMTU, nil - } - logger.Debugf("testing the following MTUs: %v", mtusToTest) - - tests := make([]pmtudTestUnit, len(mtusToTest)) - for i := range mtusToTest { - tests[i] = pmtudTestUnit{mtu: mtusToTest[i]} - } - - timedCtx, cancel := context.WithTimeout(ctx, pingTimeout) - defer cancel() - go func() { - <-timedCtx.Done() - conn.Close() - }() - - for i := range tests { - id, message := buildMessageToSend(ipVersion, tests[i].mtu) - tests[i].echoID = id - - encodedMessage, err := message.Marshal(nil) + if icmpSuccess { + const mtuMargin = 150 + minMTU = max(maxPossibleMTU-mtuMargin, minMTU) + } + mtu, err = tcp.PathMTUDiscover(ctx, addrPort, minMTU, maxPossibleMTU, logger) if err != nil { - return 0, fmt.Errorf("encoding ICMP message: %w", err) - } - tests[i].sentBytes = len(encodedMessage) - - _, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()}) - if err != nil { - if strings.HasSuffix(err.Error(), "sendto: operation not permitted") { - err = fmt.Errorf("%w", ErrICMPNotPermitted) - } - return 0, fmt.Errorf("writing ICMP message: %w", err) - } - } - - err = collectReplies(conn, ipVersion, tests, logger) - switch { - case err == nil: // max possible MTU is working - return tests[len(tests)-1].mtu, nil - case err != nil && errors.Is(err, net.ErrClosed): - // we have timeouts (IPv4 testing or IPv6 PMTUD blackholes) - // so find the highest MTU which worked. - // Note we start from index len(tests) - 2 since the max MTU - // cannot be working if we had a timeout. - for i := len(tests) - 2; i >= 0; i-- { //nolint:mnd - if tests[i].ok { - return pmtudMultiSizes(ctx, ip, tests[i].mtu, tests[i+1].mtu-1, - pingTimeout, logger) - } - } - - // All MTUs failed. 
- return 0, fmt.Errorf("%w: ICMP might be blocked", ErrMTUNotFound) - case err != nil: - return 0, fmt.Errorf("collecting ICMP echo replies: %w", err) - default: - panic("unreachable") - } -} - -// Create the MTU slice of length 11 such that: -// - the first element is the minMTU -// - the last element is the maxMTU -// - elements in-between are separated as close to each other -// The number 11 is chosen to find the final MTU in 3 searches, -// with a total search space of 1728 MTUs which is enough; -// to find it in 2 searches requires 37 parallel queries which -// could be blocked by firewalls. -func makeMTUsToTest(minMTU, maxMTU uint32) (mtus []uint32) { - const mtusLength = 11 // find the final MTU in 3 searches - diff := maxMTU - minMTU - switch { - case minMTU > maxMTU: - panic("minMTU > maxMTU") - case diff <= mtusLength: - mtus = make([]uint32, 0, diff) - for mtu := minMTU; mtu <= maxMTU; mtu++ { - mtus = append(mtus, mtu) - } - default: - step := float64(diff) / float64(mtusLength-1) - mtus = make([]uint32, 0, mtusLength) - for mtu := float64(minMTU); len(mtus) < mtusLength-1; mtu += step { - mtus = append(mtus, uint32(math.Round(mtu))) - } - mtus = append(mtus, maxMTU) // last element is the maxMTU - } - - return mtus -} - -func collectReplies(conn net.PacketConn, ipVersion string, - tests []pmtudTestUnit, logger Logger, -) (err error) { - echoIDToTestIndex := make(map[uint16]int, len(tests)) - for i, test := range tests { - echoIDToTestIndex[test.echoID] = i - } - - // The theoretical limit is 4GiB for IPv6 MTU path discovery jumbograms, but that would - // create huge buffers which we don't really want to support anyway. - // The standard frame maximum MTU is 1500 bytes, and there are Jumbo frames with - // a conventional maximum of 9000 bytes. However, some manufacturers support up - // 9216-20 = 9196 bytes for the maximum MTU. We thus use buffers of size 9196 to - // match eventual Jumbo frames. 
More information at: - // https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media - const maxPossibleMTU = 9196 - buffer := make([]byte, maxPossibleMTU) - - idsFound := 0 - for idsFound < len(tests) { - // Note we need to read the whole packet in one call to ReadFrom, so the buffer - // must be large enough to read the entire reply packet. See: - // https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J - bytesRead, _, err := conn.ReadFrom(buffer) - if err != nil { - return fmt.Errorf("reading from ICMP connection: %w", err) - } - packetBytes := buffer[:bytesRead] - - ipPacketLength := len(packetBytes) - - var icmpProtocol int - switch ipVersion { - case "v4": - icmpProtocol = icmpv4Protocol - case "v6": - icmpProtocol = icmpv6Protocol - default: - panic(fmt.Sprintf("unknown IP version: %s", ipVersion)) - } - - // Parse the ICMP message - // Note: this parsing works for a truncated 556 bytes ICMP reply packet. - message, err := icmp.ParseMessage(icmpProtocol, packetBytes) - if err != nil { - return fmt.Errorf("parsing message: %w", err) - } - - echoBody, ok := message.Body.(*icmp.Echo) - if !ok { - return fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, message.Body) - } - - id := uint16(echoBody.ID) //nolint:gosec - testIndex, testing := echoIDToTestIndex[id] - if !testing { // not an id we expected so ignore it - logger.Warnf("ignoring ICMP reply with unexpected ID %d (type: %d, code: %d, length: %d)", - echoBody.ID, message.Type, message.Code, ipPacketLength) + logger.Debugf("TCP path MTU discovery to %s failed: %s", addrPort, err) continue } - idsFound++ - sentBytes := tests[testIndex].sentBytes - - // echo reply should be at most the number of bytes sent, - // and can be lower, more precisely 556 bytes, in case - // the host we are reaching wants to stay out of trouble - // and ensure its echo reply goes through without - // fragmentation, see the following page: - // https://datatracker.ietf.org/doc/html/rfc1122#page-59 - const 
conservativeReplyLength = 556 - truncated := ipPacketLength < sentBytes && - ipPacketLength == conservativeReplyLength - // Check the packet size is the same if the reply is not truncated - if !truncated && sentBytes != ipPacketLength { - return fmt.Errorf("%w: sent %dB and received %dB", - ErrICMPEchoDataMismatch, sentBytes, ipPacketLength) - } - // Truncated reply or matching reply size - tests[testIndex].ok = true + logger.Debugf("TCP path MTU discovery to %s found maximum valid MTU %d", addrPort, mtu) + return mtu, nil } - return nil + return 0, fmt.Errorf("TCP path MTU discovery: last error: %w", err) } diff --git a/internal/pmtud/tcp/interfaces.go b/internal/pmtud/tcp/interfaces.go new file mode 100644 index 00000000..2709d75f --- /dev/null +++ b/internal/pmtud/tcp/interfaces.go @@ -0,0 +1,7 @@ +package tcp + +type Logger interface { + Debug(msg string) + Debugf(msg string, args ...any) + Warnf(msg string, args ...any) +} diff --git a/internal/pmtud/tcp/mocks_generate_test.go b/internal/pmtud/tcp/mocks_generate_test.go new file mode 100644 index 00000000..fb93ff6a --- /dev/null +++ b/internal/pmtud/tcp/mocks_generate_test.go @@ -0,0 +1,3 @@ +package tcp + +//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . Logger diff --git a/internal/pmtud/tcp/mocks_test.go b/internal/pmtud/tcp/mocks_test.go new file mode 100644 index 00000000..4766e4f4 --- /dev/null +++ b/internal/pmtud/tcp/mocks_test.go @@ -0,0 +1,80 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/qdm12/gluetun/internal/pmtud/tcp (interfaces: Logger) + +// Package tcp is a generated GoMock package. +package tcp + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockLogger is a mock of Logger interface. +type MockLogger struct { + ctrl *gomock.Controller + recorder *MockLoggerMockRecorder +} + +// MockLoggerMockRecorder is the mock recorder for MockLogger. 
+type MockLoggerMockRecorder struct { + mock *MockLogger +} + +// NewMockLogger creates a new mock instance. +func NewMockLogger(ctrl *gomock.Controller) *MockLogger { + mock := &MockLogger{ctrl: ctrl} + mock.recorder = &MockLoggerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockLogger) EXPECT() *MockLoggerMockRecorder { + return m.recorder +} + +// Debug mocks base method. +func (m *MockLogger) Debug(arg0 string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Debug", arg0) +} + +// Debug indicates an expected call of Debug. +func (mr *MockLoggerMockRecorder) Debug(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockLogger)(nil).Debug), arg0) +} + +// Debugf mocks base method. +func (m *MockLogger) Debugf(arg0 string, arg1 ...interface{}) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + m.ctrl.Call(m, "Debugf", varargs...) +} + +// Debugf indicates an expected call of Debugf. +func (mr *MockLoggerMockRecorder) Debugf(arg0 interface{}, arg1 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debugf", reflect.TypeOf((*MockLogger)(nil).Debugf), varargs...) +} + +// Warnf mocks base method. +func (m *MockLogger) Warnf(arg0 string, arg1 ...interface{}) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + m.ctrl.Call(m, "Warnf", varargs...) +} + +// Warnf indicates an expected call of Warnf. +func (mr *MockLoggerMockRecorder) Warnf(arg0 interface{}, arg1 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0}, arg1...) 
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Warnf", reflect.TypeOf((*MockLogger)(nil).Warnf), varargs...) +} diff --git a/internal/pmtud/tcp/multi.go b/internal/pmtud/tcp/multi.go new file mode 100644 index 00000000..9ac48a94 --- /dev/null +++ b/internal/pmtud/tcp/multi.go @@ -0,0 +1,89 @@ +package tcp + +import ( + "context" + "errors" + "fmt" + "net/netip" + "syscall" + "time" + + "github.com/qdm12/gluetun/internal/pmtud/test" +) + +var ErrMTUNotFound = errors.New("MTU not found") + +type testUnit struct { + mtu uint32 + ok bool +} + +func PathMTUDiscover(ctx context.Context, addrPort netip.AddrPort, + minMTU, maxPossibleMTU uint32, logger Logger, +) (mtu uint32, err error) { + mtusToTest := test.MakeMTUsToTest(minMTU, maxPossibleMTU) + if len(mtusToTest) == 1 { // only minMTU because minMTU == maxPossibleMTU + return minMTU, nil + } + logger.Debugf("TCP testing the following MTUs: %v", mtusToTest) + + tests := make([]testUnit, len(mtusToTest)) + for i := range mtusToTest { + tests[i] = testUnit{mtu: mtusToTest[i]} + } + + family := syscall.AF_INET + if addrPort.Addr().Is6() { + family = syscall.AF_INET6 + } + fd, stop, err := startRawSocket(family) + if err != nil { + return 0, fmt.Errorf("starting raw socket: %w", err) + } + defer stop() + + tracker := newTracker(fd, addrPort.Addr().Is4()) + + const timeout = time.Second + runCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + errCh := make(chan error) + go func() { + errCh <- tracker.listen(runCtx) + }() + + doneCh := make(chan struct{}) + for i := range tests { + go func(i int) { + err := runTest(runCtx, fd, tracker, addrPort, tests[i].mtu) + tests[i].ok = err == nil + doneCh <- struct{}{} + }(i) + } + + for range tests { + select { + case <-doneCh: + case err := <-errCh: + if err == nil { // timeout + break + } + return 0, fmt.Errorf("listening for TCP replies: %w", err) + } + } + + if tests[len(tests)-1].ok { + return tests[len(tests)-1].mtu, nil + } + + for i := len(tests) - 2; 
i >= 0; i-- { //nolint:mnd + if tests[i].ok { + stop() + cancel() + return PathMTUDiscover(ctx, addrPort, + tests[i].mtu, tests[i+1].mtu-1, logger) + } + } + + return 0, fmt.Errorf("%w: your connection might not be working at all", ErrMTUNotFound) +} diff --git a/internal/pmtud/tcp/packet.go b/internal/pmtud/tcp/packet.go new file mode 100644 index 00000000..f9e1b621 --- /dev/null +++ b/internal/pmtud/tcp/packet.go @@ -0,0 +1,89 @@ +package tcp + +import ( + "math/rand/v2" + "net/netip" + "syscall" + + "github.com/qdm12/gluetun/internal/pmtud/constants" + "github.com/qdm12/gluetun/internal/pmtud/ip" +) + +// createSYNPacket creates a TCP SYN packet for initiating a handshake. +// SYN packets have normally no data payload, so you SHOULD set mtu to 0. +// However, in some cases where the server closes the connection with RST immediately, +// it can be useful to add some data payload to a SYN packet and check if the server still +// replies. Only set mtu to a non zero value if you know what you are doing. +func createSYNPacket(src, dst netip.AddrPort, mtu uint32) (packet []byte, seq uint32) { + seq = rand.Uint32() //nolint:gosec + const ack = 0 // SYN has no ACK number + payloadLength := constants.BaseTCPHeaderLength // no data payload + if mtu > 0 { + payloadLength = getPayloadLength(mtu, dst) + } + return createPacket(src, dst, seq, ack, payloadLength, synFlag), seq +} + +// createACKPacket creates a TCP ACK packet. +// If the mtu is set to 0, no payload is sent. +// Otherwise, the payload is calculated to test the MTU given. 
+func createACKPacket(src, dst netip.AddrPort, seq, ack uint32, mtu uint32) []byte { + payloadLength := constants.BaseTCPHeaderLength // no data payload + if mtu > 0 { + payloadLength = getPayloadLength(mtu, dst) + } + const flags = ackFlag | pshFlag + return createPacket(src, dst, seq, ack, payloadLength, flags) +} + +func createRSTPacket(src, dst netip.AddrPort, seq, ack uint32) []byte { + const payloadLength = constants.BaseTCPHeaderLength // no data payload + return createPacket(src, dst, seq, ack, payloadLength, rstFlag) +} + +func getPayloadLength(mtu uint32, dst netip.AddrPort) uint32 { + var ipHeaderLength uint32 + if dst.Addr().Is4() { + ipHeaderLength = constants.IPv4HeaderLength + } else { + ipHeaderLength = constants.IPv6HeaderLength + } + if mtu < ipHeaderLength+constants.BaseTCPHeaderLength { + panic("MTU too small to hold IP and TCP headers") + } + return mtu - ipHeaderLength +} + +func createPacket(src, dst netip.AddrPort, + seq, ack, payloadLength uint32, flags byte, +) []byte { + if payloadLength < constants.BaseTCPHeaderLength { + panic("payload length is too small to hold TCP header") + } + + var ipHeader []byte + if dst.Addr().Is4() { + ipHeader = ip.HeaderV4(src.Addr(), dst.Addr(), payloadLength) + } else { + ipHeader = ip.HeaderV6(src.Addr(), dst.Addr(), + uint16(payloadLength), byte(syscall.IPPROTO_TCP)) //nolint:gosec + } + + tcpHeader := makeTCPHeader(src.Port(), dst.Port(), seq, ack, flags) + + // data is just zeroes + dataLength := int(payloadLength) - int(constants.BaseTCPHeaderLength) + var data []byte + if dataLength > 0 { + data = make([]byte, dataLength) + } + checksum := tcpChecksum(ipHeader, tcpHeader, data) + tcpHeader[16] = byte(checksum >> 8) //nolint:mnd + tcpHeader[17] = byte(checksum & 0xff) //nolint:mnd + + packet := make([]byte, len(ipHeader)+int(constants.BaseTCPHeaderLength)+dataLength) + copy(packet, ipHeader) + copy(packet[len(ipHeader):], tcpHeader) + copy(packet[len(ipHeader)+int(constants.BaseTCPHeaderLength):], 
data) + return packet +} diff --git a/internal/pmtud/tcp/tcp.go b/internal/pmtud/tcp/tcp.go new file mode 100644 index 00000000..79decfb9 --- /dev/null +++ b/internal/pmtud/tcp/tcp.go @@ -0,0 +1,196 @@ +package tcp + +import ( + "context" + "errors" + "fmt" + "net/netip" + "syscall" + + "github.com/qdm12/gluetun/internal/pmtud/constants" + "github.com/qdm12/gluetun/internal/pmtud/ip" +) + +func startRawSocket(family int) (fd fileDescriptor, stop func(), err error) { + fdPlatform, err := syscall.Socket(family, syscall.SOCK_RAW, syscall.IPPROTO_TCP) + if err != nil { + return 0, nil, fmt.Errorf("creating raw socket: %w", err) + } + + if family == syscall.AF_INET { + err = ip.SetIPv4HeaderIncluded(fdPlatform) + } else { + err = ip.SetIPv6HeaderIncluded(fdPlatform) + } + if err != nil { + _ = syscall.Close(fdPlatform) + return 0, nil, fmt.Errorf("setting header option on raw socket: %w", err) + } + + // Allow sending packets larger than cached PMTU (for PMTUD probing) + err = setMTUDiscovery(fdPlatform) + if err != nil { + _ = syscall.Close(fdPlatform) + return 0, nil, fmt.Errorf("setting IP_MTU_DISCOVER: %w", err) + } + + // use polling because some Linux systems do not cancel + // blocking syscalls such as recvfrom when the socket is closed, + // which would cause things to hang indefinitely. + err = setNonBlock(fdPlatform) + if err != nil { + _ = syscall.Close(fdPlatform) + return 0, nil, fmt.Errorf("setting non-blocking mode: %w", err) + } + + stop = func() { + _ = syscall.Close(fdPlatform) + } + return fileDescriptor(fdPlatform), stop, nil +} + +var ( + errTCPPacketNotSynAck = errors.New("TCP packet is not a SYN-ACK") + errTCPSynAckAckMismatch = errors.New("TCP SYN-ACK ACK number does not match expected value") + errFinalPacketTypeUnexpected = errors.New("final TCP packet type is unexpected") +) + +// Craft and send a raw TCP packet to test the MTU. 
+// It expects either an RST reply (if no server is listening) +// or a SYN-ACK/ACK reply (if a server is listening). +func runTest(ctx context.Context, fd fileDescriptor, + tracker *tracker, dst netip.AddrPort, mtu uint32, +) error { + const proto = syscall.IPPROTO_TCP + src, cleanup, err := ip.SrcAddr(dst, proto) + if err != nil { + return fmt.Errorf("getting source address: %w", err) + } + defer cleanup() + + ch := make(chan []byte) + abort := make(chan struct{}) + defer close(abort) + tracker.register(src.Port(), dst.Port(), ch, abort) + defer tracker.unregister(src.Port(), dst.Port()) + + dstSockAddr := makeSockAddr(dst) + + synPacket, synSeq := createSYNPacket(src, dst, 0) + const sendToFlags = 0 + err = sendTo(fd, synPacket, sendToFlags, dstSockAddr) + if err != nil { + return fmt.Errorf("sending SYN packet: %w", err) + } + + var reply []byte + select { + case <-ctx.Done(): + return ctx.Err() + case reply = <-ch: + } + + packetType, synAckSeq, synAckAck, err := parseTCPHeader(reply[:constants.BaseTCPHeaderLength]) + switch { + case err != nil: + return fmt.Errorf("parsing first reply TCP header: %w", err) + case packetType == packetTypeRST: + // server actively closed the connection, try sending a SYN with data + return handleRSTReply(ctx, fd, ch, src, dst, mtu) + case packetType != packetTypeSYNACK: + return fmt.Errorf("%w: unexpected packet type %s", errTCPPacketNotSynAck, packetType) + case synAckAck != synSeq+1: + return fmt.Errorf("%w: expected %d, got %d", errTCPSynAckAckMismatch, synSeq+1, synAckAck) + } + + // Send a no-data ACK packet to finish the 3-way handshake. + const ackMTU = 0 // no data payload initially + ackPacket := createACKPacket(src, dst, synAckAck, synAckSeq+1, ackMTU) + err = sendTo(fd, ackPacket, sendToFlags, dstSockAddr) + if err != nil { + return fmt.Errorf("sending ACK-without-data packet: %w", err) + } + + // Send a data ACK packet to test the MTU given. 
+ ackPacket = createACKPacket(src, dst, synAckAck, synAckSeq+1, mtu) + err = sendTo(fd, ackPacket, sendToFlags, dstSockAddr) + if err != nil { + return fmt.Errorf("sending ACK-with-data packet: %w", err) + } + + select { + case <-ctx.Done(): + return ctx.Err() + case reply = <-ch: + } + + packetType, _, ack, err := parseTCPHeader(reply[:constants.BaseTCPHeaderLength]) + if err != nil { + return fmt.Errorf("parsing second reply TCP header: %w", err) + } + + switch packetType { //nolint:exhaustive + case packetTypeRST: + return nil + case packetTypeACK: + err = sendRST(fd, src, dst, ack) + if err != nil { + return fmt.Errorf("sending RST packet: %w", err) + } + return nil + default: + _ = sendRST(fd, src, dst, ack) + return fmt.Errorf("%w: %s", errFinalPacketTypeUnexpected, packetType) + } +} + +func makeSockAddr(addr netip.AddrPort) syscall.Sockaddr { + if addr.Addr().Is4() { + return &syscall.SockaddrInet4{ + Port: int(addr.Port()), + Addr: addr.Addr().As4(), + } + } + return &syscall.SockaddrInet6{ + Port: int(addr.Port()), + Addr: addr.Addr().As16(), + } +} + +var errTCPPacketNotRST = errors.New("TCP packet is not an RST") + +func handleRSTReply(ctx context.Context, fd fileDescriptor, ch <-chan []byte, + src, dst netip.AddrPort, mtu uint32, +) error { + packet, _ := createSYNPacket(src, dst, mtu) + const sendToFlags = 0 + err := sendTo(fd, packet, sendToFlags, makeSockAddr(dst)) + if err != nil { + return fmt.Errorf("sending SYN MTU-test packet: %w", err) + } + + var reply []byte + select { + case <-ctx.Done(): + return ctx.Err() // timeout: the MTU test SYN packet was too big + case reply = <-ch: + } + + packetType, _, _, err := parseTCPHeader(reply[:constants.BaseTCPHeaderLength]) + if err != nil { + return fmt.Errorf("parsing reply TCP header: %w", err) + } else if packetType != packetTypeRST { + return fmt.Errorf("%w: %s", errTCPPacketNotRST, packetType) + } + return nil +} + +func sendRST(fd fileDescriptor, src, dst netip.AddrPort, + previousACK uint32, +) 
error { + seq := previousACK + const ack = 0 + rstPacket := createRSTPacket(src, dst, seq, ack) + const sendToFlags = 0 + return sendTo(fd, rstPacket, sendToFlags, makeSockAddr(dst)) +} diff --git a/internal/pmtud/tcp/tcp_darwin.go b/internal/pmtud/tcp/tcp_darwin.go new file mode 100644 index 00000000..6a5533a4 --- /dev/null +++ b/internal/pmtud/tcp/tcp_darwin.go @@ -0,0 +1,5 @@ +package tcp + +func stripIPv4Header(reply []byte) (result []byte, ok bool) { + return reply, true +} diff --git a/internal/pmtud/tcp/tcp_linux.go b/internal/pmtud/tcp/tcp_linux.go new file mode 100644 index 00000000..4ddbb81c --- /dev/null +++ b/internal/pmtud/tcp/tcp_linux.go @@ -0,0 +1,7 @@ +package tcp + +import "syscall" + +func setMTUDiscovery(fd int) error { + return syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_MTU_DISCOVER, syscall.IP_PMTUDISC_PROBE) +} diff --git a/internal/pmtud/tcp/tcp_notdarwin.go b/internal/pmtud/tcp/tcp_notdarwin.go new file mode 100644 index 00000000..0a49d355 --- /dev/null +++ b/internal/pmtud/tcp/tcp_notdarwin.go @@ -0,0 +1,30 @@ +//go:build !darwin + +package tcp + +import ( + "github.com/qdm12/gluetun/internal/pmtud/constants" +) + +func stripIPv4Header(reply []byte) (result []byte, ok bool) { + if len(reply) < int(constants.IPv4HeaderLength) { + return nil, false // not an IPv4 packet + } + + version := reply[0] >> 4 //nolint:mnd + const ipv4Version = 4 + if version != ipv4Version { + return nil, false + } + // For IPv4 we need to skip the IP header, which is at least + // 20B and can be up to 60B. + // The Internet Header Length is the lower 4 bits of the first byte and + // represents the number of 32-bit words of the header length. 
+ const ihlMask byte = 0x0F + const bytesInWord = 4 + headerLength := int((reply[0] & ihlMask)) * bytesInWord + if len(reply) < headerLength { + return nil, false // not enough data for full IPv4 header + } + return reply[headerLength:], true +} diff --git a/internal/pmtud/tcp/tcp_test.go b/internal/pmtud/tcp/tcp_test.go new file mode 100644 index 00000000..a87a9c73 --- /dev/null +++ b/internal/pmtud/tcp/tcp_test.go @@ -0,0 +1,199 @@ +package tcp + +import ( + "context" + "errors" + "fmt" + "net/netip" + "syscall" + "testing" + "time" + + "github.com/qdm12/gluetun/internal/netlink" + "github.com/qdm12/gluetun/internal/routing" + "github.com/qdm12/log" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_runTest(t *testing.T) { + t.Parallel() + + noopLogger := &noopLogger{} + netlinker := netlink.New(noopLogger) + loopbackMTU, err := findLoopbackMTU(netlinker) + require.NoError(t, err, "finding loopback IPv4 MTU") + defaultIPv4MTU, err := findDefaultIPv4RouteMTU(netlinker) + require.NoError(t, err, "finding default IPv4 route MTU") + + ctx, cancel := context.WithCancel(t.Context()) + + const family = syscall.AF_INET + fd, stop, err := startRawSocket(family) + require.NoError(t, err) + + const ipv4 = true + tracker := newTracker(fd, ipv4) + trackerCh := make(chan error) + go func() { + trackerCh <- tracker.listen(ctx) + }() + + t.Cleanup(func() { + stop() + cancel() // stop listening + err = <-trackerCh + require.NoError(t, err) + }) + + testCases := map[string]struct { + timeout time.Duration + dst func(t *testing.T) netip.AddrPort + mtu uint32 + success bool + }{ + "local_not_listening": { + timeout: time.Hour, + dst: func(t *testing.T) netip.AddrPort { + t.Helper() + port := reserveClosedPort(t) + return netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), port) + }, + mtu: loopbackMTU, + success: true, + }, + "remote_not_listening": { + timeout: 50 * time.Millisecond, + dst: func(_ *testing.T) netip.AddrPort { + 
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 12345) + }, + mtu: defaultIPv4MTU, + }, + "1.1.1.1:443": { + timeout: time.Second, + dst: func(_ *testing.T) netip.AddrPort { + return netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443) + }, + mtu: defaultIPv4MTU, + success: true, + }, + "1.1.1.1:80": { + timeout: time.Second, + dst: func(_ *testing.T) netip.AddrPort { + return netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 80) + }, + mtu: defaultIPv4MTU, + success: true, + }, + "8.8.8.8:443": { + timeout: time.Second, + dst: func(_ *testing.T) netip.AddrPort { + return netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 443) + }, + mtu: defaultIPv4MTU, + success: true, + }, + } + + for name, testCase := range testCases { + t.Run(name, func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(t.Context(), testCase.timeout) + defer cancel() + dst := testCase.dst(t) + err := runTest(ctx, fd, tracker, dst, testCase.mtu) + if testCase.success { + require.NoError(t, err) + } else { + require.Error(t, err) + } + }) + } +} + +var errRouteNotFound = errors.New("route not found") + +func findLoopbackMTU(netlinker *netlink.NetLink) (mtu uint32, err error) { + routes, err := netlinker.RouteList(netlink.FamilyV4) + if err != nil { + return 0, fmt.Errorf("getting routes list: %w", err) + } + for _, route := range routes { + if route.Dst.IsValid() && route.Dst.Addr().IsLoopback() { + link, err := netlinker.LinkByIndex(route.LinkIndex) + if err != nil { + return 0, fmt.Errorf("getting link by index: %w", err) + } + // Quirk: make sure it is maximum 65535, and not i.e. 65536 + // or the IP header 16 bits will fail to fit that packet length value. 
+ const maxMTU = 65535 + return min(link.MTU, maxMTU), nil + } + } + return 0, fmt.Errorf("%w: no loopback route found", errRouteNotFound) +} + +func findDefaultIPv4RouteMTU(netlinker *netlink.NetLink) (mtu uint32, err error) { + noopLogger := &noopLogger{} + routing := routing.New(netlinker, noopLogger) + defaultRoutes, err := routing.DefaultRoutes() + if err != nil { + return 0, fmt.Errorf("getting default routes: %w", err) + } + for _, route := range defaultRoutes { + if route.Family != netlink.FamilyV4 { + continue + } + link, err := netlinker.LinkByName(defaultRoutes[0].NetInterface) + if err != nil { + return 0, fmt.Errorf("getting link by name: %w", err) + } + return link.MTU, nil + } + return 0, fmt.Errorf("%w: no default route found", errRouteNotFound) +} + +func reserveClosedPort(t *testing.T) (port uint16) { + t.Helper() + + fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM, syscall.IPPROTO_TCP) + require.NoError(t, err) + t.Cleanup(func() { + err := syscall.Close(fd) + assert.NoError(t, err) + }) + + addr := &syscall.SockaddrInet4{ + Port: 0, + Addr: [4]byte{127, 0, 0, 1}, + } + + err = syscall.Bind(fd, addr) + if err != nil { + _ = syscall.Close(fd) + t.Fatal(err) + } + + sockAddr, err := syscall.Getsockname(fd) + if err != nil { + _ = syscall.Close(fd) + t.Fatal(err) + } + + sockAddr4, ok := sockAddr.(*syscall.SockaddrInet4) + if !ok { + _ = syscall.Close(fd) + t.Fatal("not an IPv4 address") + } + + return uint16(sockAddr4.Port) //nolint:gosec +} + +type noopLogger struct{} + +func (l *noopLogger) Patch(_ ...log.Option) {} +func (l *noopLogger) Debug(_ string) {} +func (l *noopLogger) Debugf(_ string, _ ...any) {} +func (l *noopLogger) Info(_ string) {} +func (l *noopLogger) Warn(_ string) {} +func (l *noopLogger) Error(_ string) {} diff --git a/internal/pmtud/tcp/tcp_unix.go b/internal/pmtud/tcp/tcp_unix.go new file mode 100644 index 00000000..0b2b24ea --- /dev/null +++ b/internal/pmtud/tcp/tcp_unix.go @@ -0,0 +1,28 @@ +//go:build linux 
|| darwin + +package tcp + +import ( + "syscall" + "time" +) + +// fileDescriptor is a platform-independent type for socket file descriptors. +type fileDescriptor int + +func sendTo(fd fileDescriptor, p []byte, flags int, to syscall.Sockaddr) (err error) { + return syscall.Sendto(int(fd), p, flags, to) +} + +func setSocketTimeout(fd fileDescriptor, timeout time.Duration) (err error) { + timeval := syscall.NsecToTimeval(timeout.Nanoseconds()) + return syscall.SetsockoptTimeval(int(fd), syscall.SOL_SOCKET, syscall.SO_RCVTIMEO, &timeval) +} + +func recvFrom(fd fileDescriptor, p []byte, flags int) (n int, from syscall.Sockaddr, err error) { + return syscall.Recvfrom(int(fd), p, flags) +} + +func setNonBlock(fd int) error { + return syscall.SetNonblock(fd, true) +} diff --git a/internal/pmtud/tcp/tcp_unspecified.go b/internal/pmtud/tcp/tcp_unspecified.go new file mode 100644 index 00000000..ff50b22e --- /dev/null +++ b/internal/pmtud/tcp/tcp_unspecified.go @@ -0,0 +1,7 @@ +//go:build !linux && !windows + +package tcp + +func setMTUDiscovery(fd int) error { + panic("not implemented") +} diff --git a/internal/pmtud/tcp/tcp_windows.go b/internal/pmtud/tcp/tcp_windows.go new file mode 100644 index 00000000..9cab5507 --- /dev/null +++ b/internal/pmtud/tcp/tcp_windows.go @@ -0,0 +1,37 @@ +package tcp + +import ( + "syscall" + "time" + "unsafe" + + "golang.org/x/sys/windows" +) + +type fileDescriptor syscall.Handle + +func sendTo(fd fileDescriptor, p []byte, flags int, to syscall.Sockaddr) (err error) { + return syscall.Sendto(syscall.Handle(fd), p, flags, to) +} + +func setSocketTimeout(fd fileDescriptor, timeout time.Duration) (err error) { + timeval := int(timeout.Milliseconds()) + return syscall.SetsockoptInt(syscall.Handle(fd), syscall.SOL_SOCKET, windows.SO_RCVTIMEO, timeval) +} + +func recvFrom(fd fileDescriptor, p []byte, flags int) (n int, from syscall.Sockaddr, err error) { + return syscall.Recvfrom(syscall.Handle(fd), p, flags) +} + +func setMTUDiscovery(fd 
syscall.Handle) error { + panic("not implemented") +} + +func setNonBlock(fd syscall.Handle) error { + // Windows: Use ioctlsocket with FIONBIO + var arg uint32 = 1 // 1 to enable non-blocking mode + var bytesReturned uint32 + const FIONBIO = 0x8004667e + return syscall.WSAIoctl(fd, FIONBIO, (*byte)(unsafe.Pointer(&arg)), + uint32(unsafe.Sizeof(arg)), nil, 0, &bytesReturned, nil, 0) +} diff --git a/internal/pmtud/tcp/tcpheader.go b/internal/pmtud/tcp/tcpheader.go new file mode 100644 index 00000000..f9deb13b --- /dev/null +++ b/internal/pmtud/tcp/tcpheader.go @@ -0,0 +1,124 @@ +package tcp + +import ( + "encoding/binary" + "errors" + "fmt" + + "github.com/qdm12/gluetun/internal/pmtud/constants" +) + +// For SYN, ack is 0. +// For SYN-ACK, ack is the sequence number + 1 of the SYN. +func makeTCPHeader(srcPort, dstPort uint16, seq, ack uint32, flags byte) []byte { + header := make([]byte, constants.BaseTCPHeaderLength) + binary.BigEndian.PutUint16(header[0:], srcPort) + binary.BigEndian.PutUint16(header[2:], dstPort) + binary.BigEndian.PutUint32(header[4:], seq) + binary.BigEndian.PutUint32(header[8:], ack) + //nolint:mnd + header[12] = byte(constants.BaseTCPHeaderLength) << 2 // data offset + header[13] = flags + // windowSize can be left to 5840 even for IPv6, it doesn't matter. 
+ const windowSize = 5840 + binary.BigEndian.PutUint16(header[14:], windowSize) + // header[16:17] is the checksum, set later + // header[18:19] is urgent pointer, not needed for our use case + return header +} + +//nolint:mnd +func tcpChecksum(ipHeader, tcpHeader, payload []byte) uint16 { + var pseudoHeader []byte + isIPv6 := len(ipHeader) >= 40 && (ipHeader[0]>>4) == 6 + if isIPv6 { + pseudoHeader = make([]byte, 40) + copy(pseudoHeader[0:16], ipHeader[8:24]) // Source Address + copy(pseudoHeader[16:32], ipHeader[24:40]) // Destination Address + totalLength := uint32(len(tcpHeader) + len(payload)) //nolint:gosec + binary.BigEndian.PutUint32(pseudoHeader[32:], totalLength) + pseudoHeader[39] = 6 // Next Header (TCP) + } else { + pseudoHeader = make([]byte, 12) + copy(pseudoHeader[0:4], ipHeader[12:16]) + copy(pseudoHeader[4:8], ipHeader[16:20]) + pseudoHeader[9] = 6 + totalLength := uint16(len(tcpHeader) + len(payload)) //nolint:gosec + binary.BigEndian.PutUint16(pseudoHeader[10:], totalLength) + } + + sum := uint32(0) + for _, slice := range [][]byte{pseudoHeader, tcpHeader, payload} { + for i := 0; i < len(slice)-1; i += 2 { + sum += uint32(binary.BigEndian.Uint16(slice[i : i+2])) + } + if len(slice)%2 != 0 { + sum += uint32(slice[len(slice)-1]) << 8 + } + } + for (sum >> 16) > 0 { + sum = (sum & 0xFFFF) + (sum >> 16) + } + return ^uint16(sum) //nolint:gosec +} + +const ( + tcpFlagsOffset = 13 + rstFlag byte = 0x04 + synFlag byte = 0x02 + ackFlag byte = 0x10 + pshFlag byte = 0x08 +) + +type packetType uint8 + +const ( + packetTypeSYN packetType = iota + 1 + packetTypeSYNACK + packetTypeACK + packetTypeRST +) + +func (p packetType) String() string { + switch p { + case packetTypeSYN: + return "SYN" + case packetTypeSYNACK: + return "SYN-ACK" + case packetTypeACK: + return "ACK" + case packetTypeRST: + return "RST" + default: + panic("unknown packet type") + } +} + +var ( + errTCPHeaderTooShort = errors.New("TCP header is too short") + errTCPPacketTypeUnknown = 
errors.New("TCP packet type is unknown") +) + +// parseTCPHeader parses some elements from the TCP header. +func parseTCPHeader(header []byte) (packetType packetType, seq, ack uint32, err error) { + if len(header) < int(constants.BaseTCPHeaderLength) { + return 0, 0, 0, fmt.Errorf("%w: %d bytes", errTCPHeaderTooShort, len(header)) + } + flags := header[tcpFlagsOffset] + switch { + case (flags&synFlag) != 0 && (flags&ackFlag) == 0: + packetType = packetTypeSYN + case (flags&synFlag) != 0 && (flags&ackFlag) != 0: + packetType = packetTypeSYNACK + case (flags & rstFlag) != 0: + packetType = packetTypeRST + case (flags & ackFlag) != 0: + packetType = packetTypeACK + default: + return 0, 0, 0, fmt.Errorf("%w: flags are 0x%02x", errTCPPacketTypeUnknown, flags) + } + + seq = binary.BigEndian.Uint32(header[4:8]) + ack = binary.BigEndian.Uint32(header[8:12]) + return packetType, seq, ack, nil +} diff --git a/internal/pmtud/tcp/tracker.go b/internal/pmtud/tcp/tracker.go new file mode 100644 index 00000000..13a7167b --- /dev/null +++ b/internal/pmtud/tcp/tracker.go @@ -0,0 +1,134 @@ +package tcp + +import ( + "context" + "encoding/binary" + "errors" + "fmt" + "sync" + "syscall" + "time" + + "github.com/qdm12/gluetun/internal/pmtud/constants" +) + +type tracker struct { + fd fileDescriptor + ipv4 bool + mutex sync.RWMutex + portsToDispatch map[uint32]dispatch +} + +type dispatch struct { + replyCh chan<- []byte + abort <-chan struct{} +} + +func newTracker(fd fileDescriptor, ipv4 bool) *tracker { + return &tracker{ + fd: fd, + ipv4: ipv4, + portsToDispatch: make(map[uint32]dispatch), + } +} + +func (t *tracker) constructKey(localPort, remotePort uint16) uint32 { + buf := make([]byte, 4) //nolint:mnd + binary.BigEndian.PutUint16(buf[0:2], localPort) + binary.BigEndian.PutUint16(buf[2:4], remotePort) + return binary.BigEndian.Uint32(buf) +} + +func (t *tracker) register(localPort, remotePort uint16, + ch chan<- []byte, abort <-chan struct{}, +) { + key := 
t.constructKey(localPort, remotePort) + t.mutex.Lock() + defer t.mutex.Unlock() + t.portsToDispatch[key] = dispatch{ + replyCh: ch, + abort: abort, + } +} + +func (t *tracker) unregister(localPort, remotePort uint16) { + key := t.constructKey(localPort, remotePort) + t.mutex.Lock() + defer t.mutex.Unlock() + delete(t.portsToDispatch, key) +} + +// listen listens for incoming TCP packets and dispatches them to the +// correct channel based on the source and destination port. +// If the context has a deadline associated, this one is used on the socket. +// Note it returns a nil error on context cancellation. +func (t *tracker) listen(ctx context.Context) error { + deadline, hasDeadline := ctx.Deadline() + for ctx.Err() == nil { + if hasDeadline { + remaining := time.Until(deadline) + if remaining <= 0 { + return nil + } + err := setSocketTimeout(t.fd, remaining) + if err != nil { + return fmt.Errorf("setting socket receive timeout: %w", err) + } + } + + reply := make([]byte, constants.MaxEthernetFrameSize) + n, _, err := recvFrom(t.fd, reply, 0) + if err != nil { + switch { + case errors.Is(err, syscall.EAGAIN), + errors.Is(err, syscall.EWOULDBLOCK): + pollSleep(ctx) + continue + case ctx.Err() != nil: + // context canceled, stop listening so exit cleanly with no error + return nil //nolint:nilerr + default: + return fmt.Errorf("receiving on socket: %w", err) + } + } + reply = reply[:n] + + if t.ipv4 { + var ok bool + reply, ok = stripIPv4Header(reply) + if !ok { + continue // not an IPv4 packet + } + } + + const minTCPHeaderLength = 20 + if len(reply) < minTCPHeaderLength { + continue + } + + srcPort := binary.BigEndian.Uint16(reply[0:2]) + dstPort := binary.BigEndian.Uint16(reply[2:4]) + key := t.constructKey(dstPort, srcPort) + t.mutex.RLock() + dispatch, exists := t.portsToDispatch[key] + t.mutex.RUnlock() + if !exists { + continue + } + select { + case dispatch.replyCh <- reply: + case <-dispatch.abort: + } + } + return nil +} + +func pollSleep(ctx 
context.Context) { + const sleepBetweenPolls = 10 * time.Millisecond + timer := time.NewTimer(sleepBetweenPolls) + select { + case <-ctx.Done(): + timer.Stop() + case <-timer.C: + } +} diff --git a/internal/pmtud/test/mtu.go b/internal/pmtud/test/mtu.go new file mode 100644 index 00000000..e66976fb --- /dev/null +++ b/internal/pmtud/test/mtu.go @@ -0,0 +1,36 @@ +package test + +import "math" + +// MakeMTUsToTest determines a slice of MTU values to test +// between minMTU and maxMTU inclusive. It creates an MTU +// slice of length up to 11 MTUs such that: +// - the first element is the minMTU +// - the last element is the maxMTU +// - elements in-between are separated as close to each other +// The number 11 is chosen to find the final MTU in 3 searches, +// with a total search space of 1728 MTUs which is enough; +// to find it in 2 searches requires 37 parallel queries which +// could be blocked by firewalls. +func MakeMTUsToTest(minMTU, maxMTU uint32) (mtus []uint32) { + const mtusLength = 11 // find the final MTU in 3 searches + diff := maxMTU - minMTU + switch { + case minMTU > maxMTU: + panic("minMTU > maxMTU") + case diff <= mtusLength: + mtus = make([]uint32, 0, diff) + for mtu := minMTU; mtu <= maxMTU; mtu++ { + mtus = append(mtus, mtu) + } + default: + step := float64(diff) / float64(mtusLength-1) + mtus = make([]uint32, 0, mtusLength) + for mtu := float64(minMTU); len(mtus) < mtusLength-1; mtu += step { + mtus = append(mtus, uint32(math.Round(mtu))) + } + mtus = append(mtus, maxMTU) // last element is the maxMTU + } + + return mtus +} diff --git a/internal/pmtud/pmtud_test.go b/internal/pmtud/test/mtu_test.go similarity index 88% rename from internal/pmtud/pmtud_test.go rename to internal/pmtud/test/mtu_test.go index db187821..a43bb565 100644 --- a/internal/pmtud/pmtud_test.go +++ b/internal/pmtud/test/mtu_test.go @@ -1,4 +1,4 @@ -package pmtud +package test import ( "testing" @@ -6,7 +6,7 @@ import ( "github.com/stretchr/testify/assert" ) -func 
Test_makeMTUsToTest(t *testing.T) { +func Test_MakeMTUsToTest(t *testing.T) { t.Parallel() testCases := map[string]struct { @@ -48,7 +48,7 @@ func Test_makeMTUsToTest(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() - mtus := makeMTUsToTest(testCase.minMTU, testCase.maxMTU) + mtus := MakeMTUsToTest(testCase.minMTU, testCase.maxMTU) assert.Equal(t, testCase.mtus, mtus) }) } diff --git a/internal/pmtud/vpn.go b/internal/pmtud/vpn.go new file mode 100644 index 00000000..935928f4 --- /dev/null +++ b/internal/pmtud/vpn.go @@ -0,0 +1,40 @@ +package pmtud + +import ( + "net/netip" + + "github.com/qdm12/gluetun/internal/constants" + "github.com/qdm12/gluetun/internal/constants/vpn" + pconstants "github.com/qdm12/gluetun/internal/pmtud/constants" +) + +// MaxTheoreticalVPNMTU returns the theoretical maximum MTU for a VPN tunnel +// given the VPN type, network protocol, and VPN gateway IP address. +// This is notably useful to skip testing MTU values higher than this value. +// The function panics if the network or VPN type is unknown. 
+func MaxTheoreticalVPNMTU(vpnType, network string, vpnGateway netip.Addr) uint32 { + const physicalLinkMTU = pconstants.MaxEthernetFrameSize + vpnLinkMTU := physicalLinkMTU + if vpnGateway.Is4() { + vpnLinkMTU -= pconstants.IPv4HeaderLength + } else { + vpnLinkMTU -= pconstants.IPv6HeaderLength + } + switch network { + case constants.TCP: + vpnLinkMTU -= pconstants.BaseTCPHeaderLength + case constants.UDP: + vpnLinkMTU -= pconstants.UDPHeaderLength + default: + panic("unknown network protocol: " + network) + } + switch vpnType { + case vpn.Wireguard: + vpnLinkMTU -= pconstants.WireguardHeaderLength + case vpn.OpenVPN: + vpnLinkMTU -= pconstants.OpenVPNHeaderMaxLength + default: + panic("unknown VPN type: " + vpnType) + } + return vpnLinkMTU +} diff --git a/internal/provider/utils/wireguard.go b/internal/provider/utils/wireguard.go index f3b205cc..05c2df8a 100644 --- a/internal/provider/utils/wireguard.go +++ b/internal/provider/utils/wireguard.go @@ -16,7 +16,17 @@ func BuildWireguardSettings(connection models.Connection, settings.PreSharedKey = *userSettings.PreSharedKey settings.InterfaceName = userSettings.Interface settings.Implementation = userSettings.Implementation - settings.MTU = userSettings.MTU + if *userSettings.MTU > 0 { + settings.MTU = *userSettings.MTU + } else { + // The default is 1320 which is NOT the wireguard-go default + // of 1420 because this impacts bandwidth a lot on some + // VPN providers, see https://github.com/qdm12/gluetun/issues/1650. + // It has been lowered to 1320 following quite a bit of + // investigation in the issue: https://github.com/qdm12/gluetun/issues/2533. 
+ const defaultMTU = 1320 + settings.MTU = defaultMTU + } settings.IPv6 = &ipv6Supported const rulePriority = 101 // 100 is to receive external connections diff --git a/internal/provider/utils/wireguard_test.go b/internal/provider/utils/wireguard_test.go index 61303e4d..42d973b3 100644 --- a/internal/provider/utils/wireguard_test.go +++ b/internal/provider/utils/wireguard_test.go @@ -22,7 +22,7 @@ func Test_BuildWireguardSettings(t *testing.T) { ipv6Supported bool settings wireguard.Settings }{ - "some settings": { + "some_settings": { connection: models.Connection{ IP: netip.AddrFrom4([4]byte{1, 2, 3, 4}), Port: 51821, @@ -41,6 +41,7 @@ func Test_BuildWireguardSettings(t *testing.T) { }, PersistentKeepaliveInterval: ptrTo(time.Hour), Interface: "wg1", + MTU: ptrTo(uint32(1000)), }, ipv6Supported: false, settings: wireguard.Settings{ @@ -58,6 +59,7 @@ func Test_BuildWireguardSettings(t *testing.T) { PersistentKeepaliveInterval: time.Hour, RulePriority: 101, IPv6: boolPtr(false), + MTU: 1000, }, }, } diff --git a/internal/routing/vpn.go b/internal/routing/vpn.go index fc68e148..47e795e1 100644 --- a/internal/routing/vpn.go +++ b/internal/routing/vpn.go @@ -47,3 +47,26 @@ func (r *Routing) VPNLocalGatewayIP(vpnIntf string) (ip netip.Addr, err error) { } return ip, fmt.Errorf("%w: in %d routes", ErrVPNLocalGatewayIPNotFound, len(routes)) } + +var ErrVPNRouteNotFound = errors.New("VPN route not found") + +func (r *Routing) VPNRoute(vpnIntf string) (route netlink.Route, err error) { + vpnLink, err := r.netLinker.LinkByName(vpnIntf) + if err != nil { + return route, fmt.Errorf("finding link %s: %w", vpnIntf, err) + } + vpnLinkIndex := vpnLink.Index + + routes, err := r.netLinker.RouteList(netlink.FamilyAll) + if err != nil { + return route, fmt.Errorf("listing routes: %w", err) + } + for _, route := range routes { + if route.LinkIndex == vpnLinkIndex && + !route.Dst.IsValid() { + return route, nil + } + } + return route, fmt.Errorf("%w: for interface %s in %d routes", + 
ErrVPNRouteNotFound, vpnIntf, len(routes)) +} diff --git a/internal/vpn/interfaces.go b/internal/vpn/interfaces.go index 3eda60c9..b2a3618c 100644 --- a/internal/vpn/interfaces.go +++ b/internal/vpn/interfaces.go @@ -21,6 +21,7 @@ type Firewall interface { type Routing interface { VPNLocalGatewayIP(vpnInterface string) (gateway netip.Addr, err error) + VPNRoute(vpnIntf string) (route netlink.Route, err error) } type PortForward interface { @@ -67,6 +68,7 @@ type NetLinker interface { type Router interface { RouteList(family uint8) (routes []netlink.Route, err error) RouteAdd(route netlink.Route) error + RouteReplace(route netlink.Route) error } type Ruler interface { diff --git a/internal/vpn/run.go b/internal/vpn/run.go index 17c5e0a7..167750a3 100644 --- a/internal/vpn/run.go +++ b/internal/vpn/run.go @@ -47,7 +47,13 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) { continue } tunnelUpData := tunnelUpData{ - vpnType: settings.Type, + pmtud: tunnelUpPMTUDData{ + enabled: settings.Type != vpn.Wireguard || *settings.Wireguard.MTU == 0, + vpnType: settings.Type, + network: connection.Protocol, + icmpAddrs: settings.PMTUD.ICMPAddresses, + tcpAddrs: settings.PMTUD.TCPAddresses, + }, serverIP: connection.IP, serverName: connection.ServerName, canPortForward: connection.PortForward, diff --git a/internal/vpn/tunnelup.go b/internal/vpn/tunnelup.go index 355a7be2..6af4df60 100644 --- a/internal/vpn/tunnelup.go +++ b/internal/vpn/tunnelup.go @@ -2,7 +2,6 @@ package vpn import ( "context" - "errors" "fmt" "net/netip" "time" @@ -10,6 +9,7 @@ import ( "github.com/qdm12/dns/v2/pkg/check" "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/pmtud" + pconstants "github.com/qdm12/gluetun/internal/pmtud/constants" "github.com/qdm12/gluetun/internal/version" "github.com/qdm12/log" ) @@ -17,9 +17,7 @@ import ( type tunnelUpData struct { // Healthcheck serverIP netip.Addr - // vpnType is used for path MTU discovery to find the protocol 
overhead. - // It can be "wireguard" or "openvpn". - vpnType string + pmtud tunnelUpPMTUDData // Port forwarding vpnIntf string serverName string // used for PIA @@ -29,6 +27,23 @@ type tunnelUpData struct { portForwarder PortForwarder } +type tunnelUpPMTUDData struct { + // enabled is notably false if the user specifies a custom MTU. + enabled bool + // vpnType is used to find the maximum VPN header overhead. + // It can be [vpn.Wireguard] or [vpn.OpenVPN]. + vpnType string + // network is used to find the network level header overhead. + // It can be [constants.UDP] or [constants.TCP]. + network string + // icmpAddrs is the list of addresses to use for ICMP path MTU discovery. + // Each address should handle ICMP packets for PMTUD to work. + icmpAddrs []netip.Addr + // tcpAddrs is the list of addresses to use for TCP path MTU discovery. + // Each address should have a listening TCP server on the port specified. + tcpAddrs []netip.AddrPort +} + func (l *Loop) onTunnelUp(ctx, loopCtx context.Context, data tunnelUpData) { l.client.CloseIdleConnections() @@ -39,11 +54,14 @@ func (l *Loop) onTunnelUp(ctx, loopCtx context.Context, data tunnelUpData) { } } - mtuLogger := l.logger.New(log.SetComponent("MTU discovery")) - err := updateToMaxMTU(ctx, data.vpnIntf, data.vpnType, - l.netLinker, l.routing, mtuLogger) - if err != nil { - mtuLogger.Error(err.Error()) + if data.pmtud.enabled { + mtuLogger := l.logger.New(log.SetComponent("MTU discovery")) + err := updateToMaxMTU(ctx, data.vpnIntf, data.pmtud.vpnType, + data.pmtud.network, data.pmtud.icmpAddrs, data.pmtud.tcpAddrs, + l.netLinker, l.routing, mtuLogger) + if err != nil { + mtuLogger.Error(err.Error()) + } } icmpTargetIPs := l.healthSettings.ICMPTargetIPs @@ -136,12 +154,11 @@ func (l *Loop) restartVPN(ctx context.Context, healthErr error) { _, _ = l.ApplyStatus(ctx, constants.Running) } -var errVPNTypeUnknown = errors.New("unknown VPN type") - func updateToMaxMTU(ctx context.Context, vpnInterface string, - vpnType 
string, netlinker NetLinker, routing Routing, logger *log.Logger, + vpnType, network string, icmpAddrs []netip.Addr, tcpAddrs []netip.AddrPort, + netlinker NetLinker, routing Routing, logger *log.Logger, ) error { - logger.Info("finding maximum MTU, this can take up to 4 seconds") + logger.Info("finding maximum MTU, this can take up to 6 seconds") vpnGatewayIP, err := routing.VPNLocalGatewayIP(vpnInterface) if err != nil { @@ -155,18 +172,7 @@ func updateToMaxMTU(ctx context.Context, vpnInterface string, originalMTU := link.MTU - // Note: no point testing for an MTU of 1500, it will never work due to the VPN - // protocol overhead, so start lower than 1500 according to the protocol used. - const physicalLinkMTU uint32 = 1500 - vpnLinkMTU := physicalLinkMTU - switch vpnType { - case "wireguard": - vpnLinkMTU -= 60 // Wireguard overhead - case "openvpn": - vpnLinkMTU -= 41 // OpenVPN overhead - default: - return fmt.Errorf("%w: %q", errVPNTypeUnknown, vpnType) - } + vpnLinkMTU := pmtud.MaxTheoreticalVPNMTU(vpnType, network, vpnGatewayIP) // Setting the VPN link MTU to 1500 might interrupt the connection until // the new MTU is set again, but this is necessary to find the highest valid MTU. 
@@ -178,16 +184,14 @@ func updateToMaxMTU(ctx context.Context, vpnInterface string, } const pingTimeout = time.Second - vpnLinkMTU, err = pmtud.PathMTUDiscover(ctx, vpnGatewayIP, vpnLinkMTU, pingTimeout, logger) - switch { - case err == nil: - logger.Infof("setting VPN interface %s MTU to maximum valid MTU %d", vpnInterface, vpnLinkMTU) - case errors.Is(err, pmtud.ErrMTUNotFound) || errors.Is(err, pmtud.ErrICMPNotPermitted): + vpnLinkMTU, err = pmtud.PathMTUDiscover(ctx, icmpAddrs, tcpAddrs, + vpnLinkMTU, pingTimeout, logger) + if err != nil { vpnLinkMTU = originalMTU logger.Infof("reverting VPN interface %s MTU to %d (due to: %s)", vpnInterface, originalMTU, err) - default: - return fmt.Errorf("path MTU discovering: %w", err) + } else { + logger.Infof("setting VPN interface %s MTU to maximum valid MTU %d", vpnInterface, vpnLinkMTU) } err = netlinker.LinkSetMTU(link.Index, vpnLinkMTU) @@ -195,5 +199,33 @@ func updateToMaxMTU(ctx context.Context, vpnInterface string, return fmt.Errorf("setting VPN interface %s MTU to %d: %w", vpnInterface, vpnLinkMTU, err) } + err = setTCPMSSOnVPNRoute(vpnInterface, vpnLinkMTU, routing, netlinker) + if err != nil { + return fmt.Errorf("setting safe TCP MSS for MTU %d: %w", vpnLinkMTU, err) + } + + return nil +} + +func setTCPMSSOnVPNRoute(vpnIntf string, mtu uint32, + routing Routing, netlinker NetLinker, +) error { + route, err := routing.VPNRoute(vpnIntf) + if err != nil { + return fmt.Errorf("getting VPN route: %w", err) + } + + ipHeaderLength := pconstants.IPv4HeaderLength + if route.Dst.Addr().Is6() { + ipHeaderLength = pconstants.IPv6HeaderLength + } + const mysteriousOverhead = 20 // most likely TCP options, such as the 12B of timestamps + overhead := ipHeaderLength + pconstants.BaseTCPHeaderLength + mysteriousOverhead + mss := mtu - overhead + route.AdvMSS = mss + err = netlinker.RouteReplace(route) + if err != nil { + return fmt.Errorf("replacing VPN route with MSS changed to %d: %w", mss, err) + } return nil }