mirror of
https://github.com/qdm12/gluetun.git
synced 2026-05-07 04:20:12 +02:00
Path MTU discovery fixes and improvements (#3109)
- Existing option `WIREGUARD_MTU` , if set, disables PMTUD and is used - New option `PMTUD_ICMP_ADDRESSES=1.1.1.1,8.8.8.8` and `PMTUD_TCP_ADDRESSES=1.1.1.1:443,8.8.8.8:443` - ICMP PMTUD now targets external-by-default IP addresses - New TCP PMTUD (binary search only) as a second MTU confirmation and fallback mechanism. - Force set TCP MSS to MTU - IP header - TCP base header - "magic 20 bytes" 🎆 - Fix #3108
This commit is contained in:
@@ -1,2 +1,2 @@
|
|||||||
FROM ghcr.io/qdm12/godevcontainer:v0.21-alpine
|
FROM ghcr.io/qdm12/godevcontainer:v0.21-alpine
|
||||||
RUN apk add wireguard-tools htop openssl
|
RUN apk add wireguard-tools htop openssl tcpdump
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ linters:
|
|||||||
- "^disabled$"
|
- "^disabled$"
|
||||||
# Firewall and routing strings
|
# Firewall and routing strings
|
||||||
- "^(ACCEPT|DROP)$"
|
- "^(ACCEPT|DROP)$"
|
||||||
|
- "^--append$"
|
||||||
- "^--delete$"
|
- "^--delete$"
|
||||||
- "^all$"
|
- "^all$"
|
||||||
- "^(tcp|udp)$"
|
- "^(tcp|udp)$"
|
||||||
|
|||||||
+4
-1
@@ -110,8 +110,11 @@ ENV VPN_SERVICE_PROVIDER=pia \
|
|||||||
WIREGUARD_PERSISTENT_KEEPALIVE_INTERVAL=0 \
|
WIREGUARD_PERSISTENT_KEEPALIVE_INTERVAL=0 \
|
||||||
WIREGUARD_ADDRESSES= \
|
WIREGUARD_ADDRESSES= \
|
||||||
WIREGUARD_ADDRESSES_SECRETFILE=/run/secrets/wireguard_addresses \
|
WIREGUARD_ADDRESSES_SECRETFILE=/run/secrets/wireguard_addresses \
|
||||||
WIREGUARD_MTU=1320 \
|
WIREGUARD_MTU= \
|
||||||
WIREGUARD_IMPLEMENTATION=auto \
|
WIREGUARD_IMPLEMENTATION=auto \
|
||||||
|
# PMTUD
|
||||||
|
PMTUD_ICMP_ADDRESSES=1.1.1.1,8.8.8.8 \
|
||||||
|
PMTUD_TCP_ADDRESSES=1.1.1.1:443,8.8.8.8:443 \
|
||||||
# VPN server filtering
|
# VPN server filtering
|
||||||
SERVER_REGIONS= \
|
SERVER_REGIONS= \
|
||||||
SERVER_COUNTRIES= \
|
SERVER_COUNTRIES= \
|
||||||
|
|||||||
@@ -0,0 +1,108 @@
|
|||||||
|
package settings
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/qdm12/gosettings"
|
||||||
|
"github.com/qdm12/gosettings/reader"
|
||||||
|
"github.com/qdm12/gotree"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PMTUD contains settings to configure Path MTU Discovery.
|
||||||
|
type PMTUD struct {
|
||||||
|
// ICMPAddresses is the redundancy list of addresses to use
|
||||||
|
// for ICMP path MTU discovery. Each address MUST handle ICMP
|
||||||
|
// packets for PMTUD to work.
|
||||||
|
// It cannot be nil in the internal state.
|
||||||
|
ICMPAddresses []netip.Addr `json:"icmp_addresses"`
|
||||||
|
// TCPAddresses is the redundancy list of addresses to use
|
||||||
|
// for TCP path MTU discovery. Each address MUST have a listening
|
||||||
|
// TCP server on the port specified.
|
||||||
|
// It cannot be nil in the internal state.
|
||||||
|
TCPAddresses []netip.AddrPort `json:"tcp_addresses"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrPMTUDICMPAddressNotValid = errors.New("PMTUD ICMP address is not valid")
|
||||||
|
ErrPMTUDTCPAddressNotValid = errors.New("PMTUD TCP address is not valid")
|
||||||
|
)
|
||||||
|
|
||||||
|
// Validate validates PMTUD settings.
|
||||||
|
func (p PMTUD) validate() (err error) {
|
||||||
|
for i, addr := range p.ICMPAddresses {
|
||||||
|
if !addr.IsValid() {
|
||||||
|
return fmt.Errorf("%w: at index %d", ErrPMTUDICMPAddressNotValid, i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for i, addr := range p.TCPAddresses {
|
||||||
|
if !addr.IsValid() {
|
||||||
|
return fmt.Errorf("%w: at index %d", ErrPMTUDTCPAddressNotValid, i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PMTUD) copy() (copied PMTUD) {
|
||||||
|
return PMTUD{
|
||||||
|
ICMPAddresses: gosettings.CopySlice(p.ICMPAddresses),
|
||||||
|
TCPAddresses: gosettings.CopySlice(p.TCPAddresses),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PMTUD) overrideWith(other PMTUD) {
|
||||||
|
p.ICMPAddresses = gosettings.OverrideWithSlice(p.ICMPAddresses, other.ICMPAddresses)
|
||||||
|
p.TCPAddresses = gosettings.OverrideWithSlice(p.TCPAddresses, other.TCPAddresses)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PMTUD) setDefaults() {
|
||||||
|
defaultICMPAddresses := []netip.Addr{
|
||||||
|
netip.AddrFrom4([4]byte{1, 1, 1, 1}),
|
||||||
|
netip.AddrFrom4([4]byte{8, 8, 8, 8}),
|
||||||
|
}
|
||||||
|
p.ICMPAddresses = gosettings.DefaultSlice(p.ICMPAddresses, defaultICMPAddresses)
|
||||||
|
|
||||||
|
const tlsPort = 443
|
||||||
|
defaultTCPAddresses := []netip.AddrPort{
|
||||||
|
netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), tlsPort),
|
||||||
|
netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), tlsPort),
|
||||||
|
}
|
||||||
|
p.TCPAddresses = gosettings.DefaultSlice(p.TCPAddresses, defaultTCPAddresses)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p PMTUD) String() string {
|
||||||
|
return p.toLinesNode().String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p PMTUD) toLinesNode() (node *gotree.Node) {
|
||||||
|
node = gotree.New("Path MTU discovery:")
|
||||||
|
|
||||||
|
addrs := make([]string, len(p.ICMPAddresses))
|
||||||
|
for i, addr := range p.ICMPAddresses {
|
||||||
|
addrs[i] = addr.String()
|
||||||
|
}
|
||||||
|
node.Appendf("ICMP addresses: %s", strings.Join(addrs, ", "))
|
||||||
|
|
||||||
|
addrs = make([]string, len(p.TCPAddresses))
|
||||||
|
for i, addr := range p.TCPAddresses {
|
||||||
|
addrs[i] = addr.String()
|
||||||
|
}
|
||||||
|
node.Appendf("TCP addresses: %s", strings.Join(addrs, ", "))
|
||||||
|
return node
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PMTUD) read(r *reader.Reader) (err error) {
|
||||||
|
p.ICMPAddresses, err = r.CSVNetipAddresses("PMTUD_ICMP_ADDRESSES")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
p.TCPAddresses, err = r.CSVNetipAddrPorts("PMTUD_TCP_ADDRESSES")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -29,14 +29,17 @@ func Test_Settings_String(t *testing.T) {
|
|||||||
| | └── OpenVPN server selection settings:
|
| | └── OpenVPN server selection settings:
|
||||||
| | ├── Protocol: UDP
|
| | ├── Protocol: UDP
|
||||||
| | └── Private Internet Access encryption preset: strong
|
| | └── Private Internet Access encryption preset: strong
|
||||||
| └── OpenVPN settings:
|
| ├── OpenVPN settings:
|
||||||
| ├── OpenVPN version: 2.6
|
| | ├── OpenVPN version: 2.6
|
||||||
| ├── User: [not set]
|
| | ├── User: [not set]
|
||||||
| ├── Password: [not set]
|
| | ├── Password: [not set]
|
||||||
| ├── Private Internet Access encryption preset: strong
|
| | ├── Private Internet Access encryption preset: strong
|
||||||
| ├── Network interface: tun0
|
| | ├── Network interface: tun0
|
||||||
| ├── Run OpenVPN as: root
|
| | ├── Run OpenVPN as: root
|
||||||
| └── Verbosity level: 1
|
| | └── Verbosity level: 1
|
||||||
|
| └── Path MTU discovery:
|
||||||
|
| ├── ICMP addresses: 1.1.1.1, 8.8.8.8
|
||||||
|
| └── TCP addresses: 1.1.1.1:443, 8.8.8.8:443
|
||||||
├── DNS settings:
|
├── DNS settings:
|
||||||
| ├── Keep existing nameserver(s): no
|
| ├── Keep existing nameserver(s): no
|
||||||
| ├── DNS server address to use: 127.0.0.1
|
| ├── DNS server address to use: 127.0.0.1
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ type VPN struct {
|
|||||||
Provider Provider `json:"provider"`
|
Provider Provider `json:"provider"`
|
||||||
OpenVPN OpenVPN `json:"openvpn"`
|
OpenVPN OpenVPN `json:"openvpn"`
|
||||||
Wireguard Wireguard `json:"wireguard"`
|
Wireguard Wireguard `json:"wireguard"`
|
||||||
|
PMTUD PMTUD `json:"pmtud"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO v4 remove pointer for receiver (because of Surfshark).
|
// TODO v4 remove pointer for receiver (because of Surfshark).
|
||||||
@@ -45,6 +46,11 @@ func (v *VPN) Validate(filterChoicesGetter FilterChoicesGetter, ipv6Supported bo
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = v.PMTUD.validate()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("PMTUD settings: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -54,6 +60,7 @@ func (v *VPN) Copy() (copied VPN) {
|
|||||||
Provider: v.Provider.copy(),
|
Provider: v.Provider.copy(),
|
||||||
OpenVPN: v.OpenVPN.copy(),
|
OpenVPN: v.OpenVPN.copy(),
|
||||||
Wireguard: v.Wireguard.copy(),
|
Wireguard: v.Wireguard.copy(),
|
||||||
|
PMTUD: v.PMTUD.copy(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -62,6 +69,7 @@ func (v *VPN) OverrideWith(other VPN) {
|
|||||||
v.Provider.overrideWith(other.Provider)
|
v.Provider.overrideWith(other.Provider)
|
||||||
v.OpenVPN.overrideWith(other.OpenVPN)
|
v.OpenVPN.overrideWith(other.OpenVPN)
|
||||||
v.Wireguard.overrideWith(other.Wireguard)
|
v.Wireguard.overrideWith(other.Wireguard)
|
||||||
|
v.PMTUD.overrideWith(other.PMTUD)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *VPN) setDefaults() {
|
func (v *VPN) setDefaults() {
|
||||||
@@ -69,6 +77,7 @@ func (v *VPN) setDefaults() {
|
|||||||
v.Provider.setDefaults()
|
v.Provider.setDefaults()
|
||||||
v.OpenVPN.setDefaults(v.Provider.Name)
|
v.OpenVPN.setDefaults(v.Provider.Name)
|
||||||
v.Wireguard.setDefaults(v.Provider.Name)
|
v.Wireguard.setDefaults(v.Provider.Name)
|
||||||
|
v.PMTUD.setDefaults()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v VPN) String() string {
|
func (v VPN) String() string {
|
||||||
@@ -85,6 +94,7 @@ func (v VPN) toLinesNode() (node *gotree.Node) {
|
|||||||
} else {
|
} else {
|
||||||
node.AppendNode(v.Wireguard.toLinesNode())
|
node.AppendNode(v.Wireguard.toLinesNode())
|
||||||
}
|
}
|
||||||
|
node.AppendNode(v.PMTUD.toLinesNode())
|
||||||
|
|
||||||
return node
|
return node
|
||||||
}
|
}
|
||||||
@@ -107,5 +117,10 @@ func (v *VPN) read(r *reader.Reader) (err error) {
|
|||||||
return fmt.Errorf("wireguard: %w", err)
|
return fmt.Errorf("wireguard: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = v.PMTUD.read(r)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("PMTUD: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -38,15 +38,9 @@ type Wireguard struct {
|
|||||||
Interface string `json:"interface"`
|
Interface string `json:"interface"`
|
||||||
PersistentKeepaliveInterval *time.Duration `json:"persistent_keep_alive_interval"`
|
PersistentKeepaliveInterval *time.Duration `json:"persistent_keep_alive_interval"`
|
||||||
// Maximum Transmission Unit (MTU) of the Wireguard interface.
|
// Maximum Transmission Unit (MTU) of the Wireguard interface.
|
||||||
// It cannot be zero in the internal state, and defaults to
|
// It cannot be nil in the internal state, and defaults to
|
||||||
// 1320. Note it is not the wireguard-go MTU default of 1420
|
// 0 indicating to use PMTUD.
|
||||||
// because this impacts bandwidth a lot on some VPN providers,
|
MTU *uint32 `json:"mtu"`
|
||||||
// see https://github.com/qdm12/gluetun/issues/1650.
|
|
||||||
// It has been lowered to 1320 following quite a bit of
|
|
||||||
// investigation in the issue:
|
|
||||||
// https://github.com/qdm12/gluetun/issues/2533.
|
|
||||||
// Note this should now be replaced with the PMTUD feature.
|
|
||||||
MTU uint32 `json:"mtu"`
|
|
||||||
// Implementation is the Wireguard implementation to use.
|
// Implementation is the Wireguard implementation to use.
|
||||||
// It can be "auto", "userspace" or "kernelspace".
|
// It can be "auto", "userspace" or "kernelspace".
|
||||||
// It defaults to "auto" and cannot be the empty string
|
// It defaults to "auto" and cannot be the empty string
|
||||||
@@ -195,8 +189,7 @@ func (w *Wireguard) setDefaults(vpnProvider string) {
|
|||||||
w.AllowedIPs = gosettings.DefaultSlice(w.AllowedIPs, defaultAllowedIPs)
|
w.AllowedIPs = gosettings.DefaultSlice(w.AllowedIPs, defaultAllowedIPs)
|
||||||
w.PersistentKeepaliveInterval = gosettings.DefaultPointer(w.PersistentKeepaliveInterval, 0)
|
w.PersistentKeepaliveInterval = gosettings.DefaultPointer(w.PersistentKeepaliveInterval, 0)
|
||||||
w.Interface = gosettings.DefaultComparable(w.Interface, "wg0")
|
w.Interface = gosettings.DefaultComparable(w.Interface, "wg0")
|
||||||
const defaultMTU = 1320
|
w.MTU = gosettings.DefaultPointer(w.MTU, 0)
|
||||||
w.MTU = gosettings.DefaultComparable(w.MTU, defaultMTU)
|
|
||||||
w.Implementation = gosettings.DefaultComparable(w.Implementation, "auto")
|
w.Implementation = gosettings.DefaultComparable(w.Implementation, "auto")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -232,7 +225,11 @@ func (w Wireguard) toLinesNode() (node *gotree.Node) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
interfaceNode := node.Appendf("Network interface: %s", w.Interface)
|
interfaceNode := node.Appendf("Network interface: %s", w.Interface)
|
||||||
interfaceNode.Appendf("MTU: %d", w.MTU)
|
if *w.MTU == 0 {
|
||||||
|
interfaceNode.Append("MTU: use path MTU discovery")
|
||||||
|
} else {
|
||||||
|
interfaceNode.Appendf("MTU: %d", *w.MTU)
|
||||||
|
}
|
||||||
|
|
||||||
if w.Implementation != "auto" {
|
if w.Implementation != "auto" {
|
||||||
node.Appendf("Implementation: %s", w.Implementation)
|
node.Appendf("Implementation: %s", w.Implementation)
|
||||||
@@ -273,11 +270,9 @@ func (w *Wireguard) read(r *reader.Reader) (err error) {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
mtuPtr, err := r.Uint32Ptr("WIREGUARD_MTU")
|
w.MTU, err = r.Uint32Ptr("WIREGUARD_MTU")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
} else if mtuPtr != nil {
|
|
||||||
w.MTU = *mtuPtr
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -29,17 +29,16 @@ func appendOrDelete(remove bool) string {
|
|||||||
// flipRule changes an append rule in a delete rule or a delete rule into an
|
// flipRule changes an append rule in a delete rule or a delete rule into an
|
||||||
// append rule.
|
// append rule.
|
||||||
func flipRule(rule string) string {
|
func flipRule(rule string) string {
|
||||||
switch {
|
fields := strings.Fields(rule)
|
||||||
case strings.HasPrefix(rule, "-A"):
|
for i, field := range fields {
|
||||||
return strings.Replace(rule, "-A", "-D", 1)
|
switch field {
|
||||||
case strings.HasPrefix(rule, "--append"):
|
case "-A", "--append":
|
||||||
return strings.Replace(rule, "--append", "-D", 1)
|
fields[i] = "--delete"
|
||||||
case strings.HasPrefix(rule, "-D"):
|
case "-D", "--delete":
|
||||||
return strings.Replace(rule, "-D", "-A", 1)
|
fields[i] = "--append"
|
||||||
case strings.HasPrefix(rule, "--delete"):
|
}
|
||||||
return strings.Replace(rule, "--delete", "-A", 1)
|
|
||||||
}
|
}
|
||||||
return rule
|
return strings.Join(fields, " ")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Version obtains the version of the installed iptables.
|
// Version obtains the version of the installed iptables.
|
||||||
@@ -86,10 +85,14 @@ func (c *Config) runIptablesInstruction(ctx context.Context, instruction string)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Config) clearAllRules(ctx context.Context) error {
|
func (c *Config) clearAllRules(ctx context.Context) error {
|
||||||
return c.runMixedIptablesInstructions(ctx, []string{
|
tables := []string{"filter"}
|
||||||
"--flush", // flush all chains
|
for _, table := range tables {
|
||||||
"--delete-chain", // delete all chains
|
return c.runMixedIptablesInstructions(ctx, []string{
|
||||||
})
|
"-t " + table + " --flush", // flush all chains
|
||||||
|
"-t " + table + " --delete-chain", // delete all chains
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Config) setIPv4AllPolicies(ctx context.Context, policy string) error {
|
func (c *Config) setIPv4AllPolicies(ctx context.Context, policy string) error {
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ type Route struct {
|
|||||||
Type uint8
|
Type uint8
|
||||||
Scope uint8
|
Scope uint8
|
||||||
Proto uint8
|
Proto uint8
|
||||||
|
AdvMSS uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Route) fromMessage(message rtnetlink.RouteMessage) {
|
func (r *Route) fromMessage(message rtnetlink.RouteMessage) {
|
||||||
@@ -35,6 +36,9 @@ func (r *Route) fromMessage(message rtnetlink.RouteMessage) {
|
|||||||
r.Type = message.Type
|
r.Type = message.Type
|
||||||
r.Scope = message.Scope
|
r.Scope = message.Scope
|
||||||
r.Proto = message.Protocol
|
r.Proto = message.Protocol
|
||||||
|
if metrics := message.Attributes.Metrics; metrics != nil {
|
||||||
|
r.AdvMSS = metrics.AdvMSS
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r Route) message() *rtnetlink.RouteMessage {
|
func (r Route) message() *rtnetlink.RouteMessage {
|
||||||
@@ -58,7 +62,6 @@ func (r Route) message() *rtnetlink.RouteMessage {
|
|||||||
Protocol: r.Proto,
|
Protocol: r.Proto,
|
||||||
Attributes: rtnetlink.RouteAttributes{
|
Attributes: rtnetlink.RouteAttributes{
|
||||||
OutIface: r.LinkIndex,
|
OutIface: r.LinkIndex,
|
||||||
Dst: *dst, // there should always be a dst for routes
|
|
||||||
Gateway: netipAddrToNetIP(r.Gw),
|
Gateway: netipAddrToNetIP(r.Gw),
|
||||||
Priority: r.Priority,
|
Priority: r.Priority,
|
||||||
Table: extendedTable,
|
Table: extendedTable,
|
||||||
@@ -67,6 +70,15 @@ func (r Route) message() *rtnetlink.RouteMessage {
|
|||||||
if src != nil { // src is optional
|
if src != nil { // src is optional
|
||||||
message.Attributes.Src = *src
|
message.Attributes.Src = *src
|
||||||
}
|
}
|
||||||
|
if dst != nil {
|
||||||
|
message.Attributes.Dst = *dst
|
||||||
|
}
|
||||||
|
if r.AdvMSS != 0 {
|
||||||
|
if message.Attributes.Metrics == nil {
|
||||||
|
message.Attributes.Metrics = &rtnetlink.RouteMetrics{}
|
||||||
|
}
|
||||||
|
message.Attributes.Metrics.AdvMSS = r.AdvMSS
|
||||||
|
}
|
||||||
return message
|
return message
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,24 @@
|
|||||||
|
package constants
|
||||||
|
|
||||||
|
const (
|
||||||
|
MaxEthernetFrameSize uint32 = 1500
|
||||||
|
// MinIPv4MTU is defined according to
|
||||||
|
// https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media
|
||||||
|
MinIPv4MTU uint32 = 68
|
||||||
|
MinIPv6MTU uint32 = 1280
|
||||||
|
|
||||||
|
IPv4HeaderLength uint32 = 20
|
||||||
|
IPv6HeaderLength uint32 = 40
|
||||||
|
UDPHeaderLength uint32 = 8
|
||||||
|
// BaseTCPHeaderLength is the TCP header length without options,
|
||||||
|
// which is the minimum TCP header length.
|
||||||
|
BaseTCPHeaderLength uint32 = 20
|
||||||
|
// MaxTCPHeaderLength is the TCP header length with the maximum options length of 40 bytes.
|
||||||
|
// Note this is a hard maximum because of the 4-bit data offset field in the TCP header (15x4=60).
|
||||||
|
MaxTCPHeaderLength uint32 = 60
|
||||||
|
WireguardHeaderLength uint32 = 32
|
||||||
|
OpenVPNHeaderMaxLength uint32 = 1 + // opcode
|
||||||
|
8 + // session id
|
||||||
|
4 + // packet id
|
||||||
|
28 // max possible auth tag/iv
|
||||||
|
)
|
||||||
@@ -1,29 +0,0 @@
|
|||||||
package pmtud
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
ErrICMPNotPermitted = errors.New("ICMP not permitted")
|
|
||||||
ErrICMPDestinationUnreachable = errors.New("ICMP destination unreachable")
|
|
||||||
ErrICMPCommunicationAdministrativelyProhibited = errors.New("communication administratively prohibited")
|
|
||||||
ErrICMPBodyUnsupported = errors.New("ICMP body type is not supported")
|
|
||||||
)
|
|
||||||
|
|
||||||
func wrapConnErr(err error, timedCtx context.Context, pingTimeout time.Duration) error { //nolint:revive
|
|
||||||
switch {
|
|
||||||
case strings.HasSuffix(err.Error(), "sendto: operation not permitted"):
|
|
||||||
err = fmt.Errorf("%w", ErrICMPNotPermitted)
|
|
||||||
case errors.Is(timedCtx.Err(), context.DeadlineExceeded):
|
|
||||||
err = fmt.Errorf("%w (timed out after %s)", net.ErrClosed, pingTimeout)
|
|
||||||
case timedCtx.Err() != nil:
|
|
||||||
err = timedCtx.Err()
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package pmtud
|
package icmp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package pmtud
|
package icmp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
@@ -9,17 +9,17 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrICMPNextHopMTUTooLow = errors.New("ICMP Next Hop MTU is too low")
|
ErrNextHopMTUTooLow = errors.New("ICMP Next Hop MTU is too low")
|
||||||
ErrICMPNextHopMTUTooHigh = errors.New("ICMP Next Hop MTU is too high")
|
ErrNextHopMTUTooHigh = errors.New("ICMP Next Hop MTU is too high")
|
||||||
)
|
)
|
||||||
|
|
||||||
func checkMTU(mtu, minMTU, physicalLinkMTU uint32) (err error) {
|
func checkMTU(mtu, minMTU, physicalLinkMTU uint32) (err error) {
|
||||||
switch {
|
switch {
|
||||||
case mtu < minMTU:
|
case mtu < minMTU:
|
||||||
return fmt.Errorf("%w: %d", ErrICMPNextHopMTUTooLow, mtu)
|
return fmt.Errorf("%w: %d", ErrNextHopMTUTooLow, mtu)
|
||||||
case mtu > physicalLinkMTU:
|
case mtu > physicalLinkMTU:
|
||||||
return fmt.Errorf("%w: %d is larger than physical link MTU %d",
|
return fmt.Errorf("%w: %d is larger than physical link MTU %d",
|
||||||
ErrICMPNextHopMTUTooHigh, mtu, physicalLinkMTU)
|
ErrNextHopMTUTooHigh, mtu, physicalLinkMTU)
|
||||||
default:
|
default:
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -34,13 +34,13 @@ func checkInvokingReplyIDMatch(icmpProtocol int, received []byte,
|
|||||||
}
|
}
|
||||||
inboundBody, ok := inboundMessage.Body.(*icmp.Echo)
|
inboundBody, ok := inboundMessage.Body.(*icmp.Echo)
|
||||||
if !ok {
|
if !ok {
|
||||||
return false, fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, inboundMessage.Body)
|
return false, fmt.Errorf("%w: %T", ErrBodyUnsupported, inboundMessage.Body)
|
||||||
}
|
}
|
||||||
outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert
|
outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert
|
||||||
return inboundBody.ID == outboundBody.ID, nil
|
return inboundBody.ID == outboundBody.ID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var ErrICMPIDMismatch = errors.New("ICMP id mismatch")
|
var ErrIDMismatch = errors.New("ICMP id mismatch")
|
||||||
|
|
||||||
func checkEchoReply(icmpProtocol int, received []byte,
|
func checkEchoReply(icmpProtocol int, received []byte,
|
||||||
outboundMessage *icmp.Message, truncatedBody bool,
|
outboundMessage *icmp.Message, truncatedBody bool,
|
||||||
@@ -51,12 +51,12 @@ func checkEchoReply(icmpProtocol int, received []byte,
|
|||||||
}
|
}
|
||||||
inboundBody, ok := inboundMessage.Body.(*icmp.Echo)
|
inboundBody, ok := inboundMessage.Body.(*icmp.Echo)
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, inboundMessage.Body)
|
return fmt.Errorf("%w: %T", ErrBodyUnsupported, inboundMessage.Body)
|
||||||
}
|
}
|
||||||
outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert
|
outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert
|
||||||
if inboundBody.ID != outboundBody.ID {
|
if inboundBody.ID != outboundBody.ID {
|
||||||
return fmt.Errorf("%w: sent id %d and received id %d",
|
return fmt.Errorf("%w: sent id %d and received id %d",
|
||||||
ErrICMPIDMismatch, outboundBody.ID, inboundBody.ID)
|
ErrIDMismatch, outboundBody.ID, inboundBody.ID)
|
||||||
}
|
}
|
||||||
err = checkEchoBodies(outboundBody.Data, inboundBody.Data, truncatedBody)
|
err = checkEchoBodies(outboundBody.Data, inboundBody.Data, truncatedBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -65,19 +65,19 @@ func checkEchoReply(icmpProtocol int, received []byte,
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var ErrICMPEchoDataMismatch = errors.New("ICMP data mismatch")
|
var ErrEchoDataMismatch = errors.New("ICMP data mismatch")
|
||||||
|
|
||||||
func checkEchoBodies(sent, received []byte, receivedTruncated bool) (err error) {
|
func checkEchoBodies(sent, received []byte, receivedTruncated bool) (err error) {
|
||||||
if len(received) > len(sent) {
|
if len(received) > len(sent) {
|
||||||
return fmt.Errorf("%w: sent %d bytes and received %d bytes",
|
return fmt.Errorf("%w: sent %d bytes and received %d bytes",
|
||||||
ErrICMPEchoDataMismatch, len(sent), len(received))
|
ErrEchoDataMismatch, len(sent), len(received))
|
||||||
}
|
}
|
||||||
if receivedTruncated {
|
if receivedTruncated {
|
||||||
sent = sent[:len(received)]
|
sent = sent[:len(received)]
|
||||||
}
|
}
|
||||||
if !bytes.Equal(received, sent) {
|
if !bytes.Equal(received, sent) {
|
||||||
return fmt.Errorf("%w: sent %x and received %x",
|
return fmt.Errorf("%w: sent %x and received %x",
|
||||||
ErrICMPEchoDataMismatch, sent, received)
|
ErrEchoDataMismatch, sent, received)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
//go:build !linux && !windows
|
//go:build !linux && !windows
|
||||||
|
|
||||||
package pmtud
|
package icmp
|
||||||
|
|
||||||
// setDontFragment for platforms other than Linux and Windows
|
// setDontFragment for platforms other than Linux and Windows
|
||||||
// is not implemented, so we just return assuming the don't
|
// is not implemented, so we just return assuming the don't
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package pmtud
|
package icmp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"syscall"
|
"syscall"
|
||||||
@@ -1,6 +1,4 @@
|
|||||||
//go:build windows
|
package icmp
|
||||||
|
|
||||||
package pmtud
|
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"syscall"
|
"syscall"
|
||||||
@@ -0,0 +1,30 @@
|
|||||||
|
package icmp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrNotPermitted = errors.New("ICMP not permitted")
|
||||||
|
ErrDestinationUnreachable = errors.New("ICMP destination unreachable")
|
||||||
|
ErrCommunicationAdministrativelyProhibited = errors.New("communication administratively prohibited")
|
||||||
|
ErrBodyUnsupported = errors.New("ICMP body type is not supported")
|
||||||
|
ErrMTUNotFound = errors.New("MTU not found")
|
||||||
|
)
|
||||||
|
|
||||||
|
func wrapConnErr(err error, timedCtx context.Context, pingTimeout time.Duration) error { //nolint:revive
|
||||||
|
switch {
|
||||||
|
case strings.HasSuffix(err.Error(), "sendto: operation not permitted"):
|
||||||
|
err = fmt.Errorf("%w", ErrNotPermitted)
|
||||||
|
case errors.Is(timedCtx.Err(), context.DeadlineExceeded):
|
||||||
|
err = fmt.Errorf("%w (timed out after %s)", net.ErrClosed, pingTimeout)
|
||||||
|
case timedCtx.Err() != nil:
|
||||||
|
err = timedCtx.Err()
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
@@ -0,0 +1,53 @@
|
|||||||
|
package icmp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PathMTUDiscover discovers the path MTU to the given IP address
|
||||||
|
// using ICMP.
|
||||||
|
// It first tries to get the next hop MTU using ICMP messages.
|
||||||
|
// If that fails, it falls back to sending echo requests with
|
||||||
|
// different packet sizes to find the maximum MTU.
|
||||||
|
// The function returns [ErrMTUNotFound] if the MTU could not be determined.
|
||||||
|
func PathMTUDiscover(ctx context.Context, ip netip.Addr,
|
||||||
|
physicalLinkMTU uint32, timeout time.Duration, logger Logger,
|
||||||
|
) (mtu uint32, err error) {
|
||||||
|
if ip.Is4() {
|
||||||
|
logger.Debug("finding IPv4 next hop MTU")
|
||||||
|
mtu, err = findIPv4NextHopMTU(ctx, ip, physicalLinkMTU, timeout, logger)
|
||||||
|
switch {
|
||||||
|
case err == nil:
|
||||||
|
return mtu, nil
|
||||||
|
case errors.Is(err, net.ErrClosed) || errors.Is(err, ErrCommunicationAdministrativelyProhibited): // blackhole
|
||||||
|
default:
|
||||||
|
return 0, fmt.Errorf("finding IPv4 next hop MTU: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logger.Debug("requesting IPv6 ICMP packet-too-big reply")
|
||||||
|
mtu, err = getIPv6PacketTooBig(ctx, ip, physicalLinkMTU, timeout, logger)
|
||||||
|
switch {
|
||||||
|
case err == nil:
|
||||||
|
return mtu, nil
|
||||||
|
case errors.Is(err, net.ErrClosed): // blackhole
|
||||||
|
default:
|
||||||
|
return 0, fmt.Errorf("getting IPv6 packet-too-big message: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fall back method: send echo requests with different packet
|
||||||
|
// sizes and check which ones succeed to find the maximum MTU.
|
||||||
|
logger.Debug("falling back to sending different sized echo packets")
|
||||||
|
minMTU := constants.MinIPv4MTU
|
||||||
|
if ip.Is6() {
|
||||||
|
minMTU = constants.MinIPv6MTU
|
||||||
|
}
|
||||||
|
return pmtudMultiSizes(ctx, ip, minMTU, physicalLinkMTU, timeout, logger)
|
||||||
|
}
|
||||||
@@ -0,0 +1,7 @@
|
|||||||
|
package icmp
|
||||||
|
|
||||||
|
type Logger interface {
|
||||||
|
Debug(msg string)
|
||||||
|
Debugf(msg string, args ...any)
|
||||||
|
Warnf(msg string, args ...any)
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package pmtud
|
package icmp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -11,14 +11,13 @@ import (
|
|||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||||
"golang.org/x/net/icmp"
|
"golang.org/x/net/icmp"
|
||||||
"golang.org/x/net/ipv4"
|
"golang.org/x/net/ipv4"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// see https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media
|
icmpv4Protocol = 1
|
||||||
minIPv4MTU uint32 = 68
|
|
||||||
icmpv4Protocol int = 1
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func listenICMPv4(ctx context.Context) (conn net.PacketConn, err error) {
|
func listenICMPv4(ctx context.Context) (conn net.PacketConn, err error) {
|
||||||
@@ -38,7 +37,7 @@ func listenICMPv4(ctx context.Context) (conn net.PacketConn, err error) {
|
|||||||
packetConn, err := listenConfig.ListenPacket(ctx, "ip4:icmp", listenAddress)
|
packetConn, err := listenConfig.ListenPacket(ctx, "ip4:icmp", listenAddress)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
|
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
|
||||||
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrICMPNotPermitted)
|
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrNotPermitted)
|
||||||
}
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -83,7 +82,9 @@ func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr,
|
|||||||
|
|
||||||
buffer := make([]byte, physicalLinkMTU)
|
buffer := make([]byte, physicalLinkMTU)
|
||||||
|
|
||||||
for { // for loop in case we read an echo reply for another ICMP request
|
// for loop in case we read an ICMP message from another ICMP request
|
||||||
|
// or TCP/UDP traffic triggering an ICMP response.
|
||||||
|
for {
|
||||||
// Note we need to read the whole packet in one call to ReadFrom, so the buffer
|
// Note we need to read the whole packet in one call to ReadFrom, so the buffer
|
||||||
// must be large enough to read the entire reply packet. See:
|
// must be large enough to read the entire reply packet. See:
|
||||||
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
|
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
|
||||||
@@ -108,24 +109,27 @@ func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr,
|
|||||||
switch typedBody := inboundMessage.Body.(type) {
|
switch typedBody := inboundMessage.Body.(type) {
|
||||||
case *icmp.DstUnreach:
|
case *icmp.DstUnreach:
|
||||||
const fragmentationRequiredAndDFFlagSetCode = 4
|
const fragmentationRequiredAndDFFlagSetCode = 4
|
||||||
|
const portUnreachable = 3
|
||||||
const communicationAdministrativelyProhibitedCode = 13
|
const communicationAdministrativelyProhibitedCode = 13
|
||||||
switch inboundMessage.Code {
|
switch inboundMessage.Code {
|
||||||
case fragmentationRequiredAndDFFlagSetCode:
|
case fragmentationRequiredAndDFFlagSetCode:
|
||||||
|
case portUnreachable: // triggered by TCP or UDP from applications
|
||||||
|
continue // ignore and wait for the next message
|
||||||
case communicationAdministrativelyProhibitedCode:
|
case communicationAdministrativelyProhibitedCode:
|
||||||
return 0, fmt.Errorf("%w: %w (code %d)",
|
return 0, fmt.Errorf("%w: %w (code %d)",
|
||||||
ErrICMPDestinationUnreachable,
|
ErrDestinationUnreachable,
|
||||||
ErrICMPCommunicationAdministrativelyProhibited,
|
ErrCommunicationAdministrativelyProhibited,
|
||||||
inboundMessage.Code)
|
inboundMessage.Code)
|
||||||
default:
|
default:
|
||||||
return 0, fmt.Errorf("%w: code %d",
|
return 0, fmt.Errorf("%w: code %d",
|
||||||
ErrICMPDestinationUnreachable, inboundMessage.Code)
|
ErrDestinationUnreachable, inboundMessage.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
// See https://datatracker.ietf.org/doc/html/rfc1191#section-4
|
// See https://datatracker.ietf.org/doc/html/rfc1191#section-4
|
||||||
// Note: the go library does not handle this NextHopMTU section.
|
// Note: the go library does not handle this NextHopMTU section.
|
||||||
nextHopMTU := packetBytes[6:8]
|
nextHopMTU := packetBytes[6:8]
|
||||||
mtu = uint32(binary.BigEndian.Uint16(nextHopMTU))
|
mtu = uint32(binary.BigEndian.Uint16(nextHopMTU))
|
||||||
err = checkMTU(mtu, minIPv4MTU, physicalLinkMTU)
|
err = checkMTU(mtu, constants.MinIPv4MTU, physicalLinkMTU)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("checking next-hop-mtu found: %w", err)
|
return 0, fmt.Errorf("checking next-hop-mtu found: %w", err)
|
||||||
}
|
}
|
||||||
@@ -153,7 +157,7 @@ func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr,
|
|||||||
inboundID, outboundID)
|
inboundID, outboundID)
|
||||||
continue
|
continue
|
||||||
default:
|
default:
|
||||||
return 0, fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, typedBody)
|
return 0, fmt.Errorf("%w: %T", ErrBodyUnsupported, typedBody)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package pmtud
|
package icmp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -8,12 +8,12 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||||
"golang.org/x/net/icmp"
|
"golang.org/x/net/icmp"
|
||||||
"golang.org/x/net/ipv6"
|
"golang.org/x/net/ipv6"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
minIPv6MTU = 1280
|
|
||||||
icmpv6Protocol = 58
|
icmpv6Protocol = 58
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -23,7 +23,7 @@ func listenICMPv6(ctx context.Context) (conn net.PacketConn, err error) {
|
|||||||
packetConn, err := listenConfig.ListenPacket(ctx, "ip6:ipv6-icmp", listenAddress)
|
packetConn, err := listenConfig.ListenPacket(ctx, "ip6:ipv6-icmp", listenAddress)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
|
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
|
||||||
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrICMPNotPermitted)
|
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrNotPermitted)
|
||||||
}
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -85,7 +85,7 @@ func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr,
|
|||||||
case *icmp.PacketTooBig:
|
case *icmp.PacketTooBig:
|
||||||
// https://datatracker.ietf.org/doc/html/rfc1885#section-3.2
|
// https://datatracker.ietf.org/doc/html/rfc1885#section-3.2
|
||||||
mtu = uint32(typedBody.MTU) //nolint:gosec
|
mtu = uint32(typedBody.MTU) //nolint:gosec
|
||||||
err = checkMTU(mtu, minIPv6MTU, physicalLinkMTU)
|
err = checkMTU(mtu, constants.MinIPv6MTU, physicalLinkMTU)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("checking MTU: %w", err)
|
return 0, fmt.Errorf("checking MTU: %w", err)
|
||||||
}
|
}
|
||||||
@@ -103,7 +103,7 @@ func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr,
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("checking invoking message id: %w", err)
|
return 0, fmt.Errorf("checking invoking message id: %w", err)
|
||||||
} else if idMatch {
|
} else if idMatch {
|
||||||
return 0, fmt.Errorf("%w", ErrICMPDestinationUnreachable)
|
return 0, fmt.Errorf("%w", ErrDestinationUnreachable)
|
||||||
}
|
}
|
||||||
logger.Debug("discarding received ICMP destination unreachable reply with an unknown id")
|
logger.Debug("discarding received ICMP destination unreachable reply with an unknown id")
|
||||||
continue
|
continue
|
||||||
@@ -116,7 +116,7 @@ func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr,
|
|||||||
inboundID, outboundID)
|
inboundID, outboundID)
|
||||||
continue
|
continue
|
||||||
default:
|
default:
|
||||||
return 0, fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, typedBody)
|
return 0, fmt.Errorf("%w: %T", ErrBodyUnsupported, typedBody)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package pmtud
|
package icmp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
cryptorand "crypto/rand"
|
cryptorand "crypto/rand"
|
||||||
@@ -0,0 +1,187 @@
|
|||||||
|
package icmp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/qdm12/gluetun/internal/pmtud/test"
|
||||||
|
"golang.org/x/net/icmp"
|
||||||
|
)
|
||||||
|
|
||||||
|
type icmpTestUnit struct {
|
||||||
|
mtu uint32
|
||||||
|
echoID uint16
|
||||||
|
sentBytes int
|
||||||
|
ok bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func pmtudMultiSizes(ctx context.Context, ip netip.Addr,
|
||||||
|
minMTU, maxPossibleMTU uint32, pingTimeout time.Duration,
|
||||||
|
logger Logger,
|
||||||
|
) (maxMTU uint32, err error) {
|
||||||
|
var ipVersion string
|
||||||
|
var conn net.PacketConn
|
||||||
|
if ip.Is4() {
|
||||||
|
ipVersion = "v4"
|
||||||
|
conn, err = listenICMPv4(ctx)
|
||||||
|
} else {
|
||||||
|
ipVersion = "v6"
|
||||||
|
conn, err = listenICMPv6(ctx)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
|
||||||
|
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrNotPermitted)
|
||||||
|
}
|
||||||
|
return 0, fmt.Errorf("listening for ICMP packets: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mtusToTest := test.MakeMTUsToTest(minMTU, maxPossibleMTU)
|
||||||
|
if len(mtusToTest) == 1 { // only minMTU because minMTU == maxPossibleMTU
|
||||||
|
return minMTU, nil
|
||||||
|
}
|
||||||
|
logger.Debugf("ICMP testing the following MTUs: %v", mtusToTest)
|
||||||
|
|
||||||
|
tests := make([]icmpTestUnit, len(mtusToTest))
|
||||||
|
for i := range mtusToTest {
|
||||||
|
tests[i] = icmpTestUnit{mtu: mtusToTest[i]}
|
||||||
|
}
|
||||||
|
|
||||||
|
timedCtx, cancel := context.WithTimeout(ctx, pingTimeout)
|
||||||
|
defer cancel()
|
||||||
|
go func() {
|
||||||
|
<-timedCtx.Done()
|
||||||
|
conn.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
for i := range tests {
|
||||||
|
id, message := buildMessageToSend(ipVersion, tests[i].mtu)
|
||||||
|
tests[i].echoID = id
|
||||||
|
|
||||||
|
encodedMessage, err := message.Marshal(nil)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("encoding ICMP message: %w", err)
|
||||||
|
}
|
||||||
|
tests[i].sentBytes = len(encodedMessage)
|
||||||
|
|
||||||
|
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()})
|
||||||
|
if err != nil {
|
||||||
|
if strings.HasSuffix(err.Error(), "sendto: operation not permitted") {
|
||||||
|
err = fmt.Errorf("%w", ErrNotPermitted)
|
||||||
|
}
|
||||||
|
return 0, fmt.Errorf("writing ICMP message: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = collectReplies(conn, ipVersion, tests, logger)
|
||||||
|
switch {
|
||||||
|
case err == nil: // max possible MTU is working
|
||||||
|
return tests[len(tests)-1].mtu, nil
|
||||||
|
case err != nil && errors.Is(err, net.ErrClosed):
|
||||||
|
// we have timeouts (IPv4 testing or IPv6 PMTUD blackholes)
|
||||||
|
// so find the highest MTU which worked.
|
||||||
|
// Note we start from index len(tests) - 2 since the max MTU
|
||||||
|
// cannot be working if we had a timeout.
|
||||||
|
for i := len(tests) - 2; i >= 0; i-- { //nolint:mnd
|
||||||
|
if tests[i].ok {
|
||||||
|
return pmtudMultiSizes(ctx, ip, tests[i].mtu, tests[i+1].mtu-1,
|
||||||
|
pingTimeout, logger)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// All MTUs failed.
|
||||||
|
return 0, fmt.Errorf("%w: ICMP might be blocked", ErrMTUNotFound)
|
||||||
|
case err != nil:
|
||||||
|
return 0, fmt.Errorf("collecting ICMP echo replies: %w", err)
|
||||||
|
default:
|
||||||
|
panic("unreachable")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// The theoretical limit is 4GiB for IPv6 MTU path discovery jumbograms, but that would
|
||||||
|
// create huge buffers which we don't really want to support anyway.
|
||||||
|
// The standard frame maximum MTU is 1500 bytes, and there are Jumbo frames with
|
||||||
|
// a conventional maximum of 9000 bytes. However, some manufacturers support up
|
||||||
|
// 9216-20 = 9196 bytes for the maximum MTU. We thus use buffers of size 9196 to
|
||||||
|
// match eventual Jumbo frames. More information at:
|
||||||
|
// https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media
|
||||||
|
const maxPossibleMTU = 9196
|
||||||
|
|
||||||
|
func collectReplies(conn net.PacketConn, ipVersion string,
|
||||||
|
tests []icmpTestUnit, logger Logger,
|
||||||
|
) (err error) {
|
||||||
|
echoIDToTestIndex := make(map[uint16]int, len(tests))
|
||||||
|
for i, test := range tests {
|
||||||
|
echoIDToTestIndex[test.echoID] = i
|
||||||
|
}
|
||||||
|
|
||||||
|
buffer := make([]byte, maxPossibleMTU)
|
||||||
|
|
||||||
|
idsFound := 0
|
||||||
|
for idsFound < len(tests) {
|
||||||
|
// Note we need to read the whole packet in one call to ReadFrom, so the buffer
|
||||||
|
// must be large enough to read the entire reply packet. See:
|
||||||
|
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
|
||||||
|
bytesRead, _, err := conn.ReadFrom(buffer)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("reading from ICMP connection: %w", err)
|
||||||
|
}
|
||||||
|
packetBytes := buffer[:bytesRead]
|
||||||
|
|
||||||
|
ipPacketLength := len(packetBytes)
|
||||||
|
|
||||||
|
var icmpProtocol int
|
||||||
|
switch ipVersion {
|
||||||
|
case "v4":
|
||||||
|
icmpProtocol = icmpv4Protocol
|
||||||
|
case "v6":
|
||||||
|
icmpProtocol = icmpv6Protocol
|
||||||
|
default:
|
||||||
|
panic(fmt.Sprintf("unknown IP version: %s", ipVersion))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the ICMP message
|
||||||
|
// Note: this parsing works for a truncated 556 bytes ICMP reply packet.
|
||||||
|
message, err := icmp.ParseMessage(icmpProtocol, packetBytes)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parsing message: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
echoBody, ok := message.Body.(*icmp.Echo)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("%w: %T", ErrBodyUnsupported, message.Body)
|
||||||
|
}
|
||||||
|
|
||||||
|
id := uint16(echoBody.ID) //nolint:gosec
|
||||||
|
testIndex, testing := echoIDToTestIndex[id]
|
||||||
|
if !testing { // not an id we expected so ignore it
|
||||||
|
logger.Warnf("ignoring ICMP reply with unexpected ID %d (type: %d, code: %d, length: %d)",
|
||||||
|
echoBody.ID, message.Type, message.Code, ipPacketLength)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
idsFound++
|
||||||
|
sentBytes := tests[testIndex].sentBytes
|
||||||
|
|
||||||
|
// echo reply should be at most the number of bytes sent,
|
||||||
|
// and can be lower, more precisely 556 bytes, in case
|
||||||
|
// the host we are reaching wants to stay out of trouble
|
||||||
|
// and ensure its echo reply goes through without
|
||||||
|
// fragmentation, see the following page:
|
||||||
|
// https://datatracker.ietf.org/doc/html/rfc1122#page-59
|
||||||
|
const conservativeReplyLength = 556
|
||||||
|
truncated := ipPacketLength < sentBytes &&
|
||||||
|
ipPacketLength == conservativeReplyLength
|
||||||
|
// Check the packet size is the same if the reply is not truncated
|
||||||
|
if !truncated && sentBytes != ipPacketLength {
|
||||||
|
return fmt.Errorf("%w: sent %dB and received %dB",
|
||||||
|
ErrEchoDataMismatch, sentBytes, ipPacketLength)
|
||||||
|
}
|
||||||
|
// Truncated reply or matching reply size
|
||||||
|
tests[testIndex].ok = true
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,73 @@
|
|||||||
|
package ip
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"net/netip"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||||
|
)
|
||||||
|
|
||||||
|
func HeaderV4(srcIP, dstIP netip.Addr, payloadLength uint32) []byte {
|
||||||
|
ipHeader := make([]byte, constants.IPv4HeaderLength)
|
||||||
|
const version byte = 4
|
||||||
|
const headerLength byte = 20 / 4 // in 32-bit words
|
||||||
|
ipHeader[0] = (version << 4) | headerLength //nolint:mnd
|
||||||
|
ipHeader[1] = 0 // type of Service
|
||||||
|
putUint16(ipHeader[2:], uint16(constants.IPv4HeaderLength+payloadLength)) //nolint:gosec
|
||||||
|
ipHeader[4], ipHeader[5] = 0, 0 // identification
|
||||||
|
const flagsAndOffset uint16 = 0x4000 // DF bit set
|
||||||
|
putUint16(ipHeader[6:], flagsAndOffset)
|
||||||
|
ipHeader[8] = 64 // ttl
|
||||||
|
ipHeader[9] = syscall.IPPROTO_TCP
|
||||||
|
srcIPBytes := srcIP.As4()
|
||||||
|
copy(ipHeader[12:16], srcIPBytes[:])
|
||||||
|
dstIPBytes := dstIP.As4()
|
||||||
|
copy(ipHeader[16:20], dstIPBytes[:])
|
||||||
|
|
||||||
|
checksum := ipChecksum(ipHeader)
|
||||||
|
ipHeader[10] = byte(checksum >> 8) //nolint:mnd
|
||||||
|
ipHeader[11] = byte(checksum & 0xff) //nolint:mnd
|
||||||
|
|
||||||
|
return ipHeader
|
||||||
|
}
|
||||||
|
|
||||||
|
// ipChecksum calculates the checksum for the IP header.
|
||||||
|
//
|
||||||
|
//nolint:mnd
|
||||||
|
func ipChecksum(header []byte) uint16 {
|
||||||
|
sum := uint32(0)
|
||||||
|
for i := 0; i < len(header)-1; i += 2 {
|
||||||
|
sum += uint32(header[i])<<8 + uint32(header[i+1])
|
||||||
|
}
|
||||||
|
if len(header)%2 != 0 {
|
||||||
|
sum += uint32(header[len(header)-1]) << 8
|
||||||
|
}
|
||||||
|
for (sum >> 16) > 0 {
|
||||||
|
sum = (sum & 0xFFFF) + (sum >> 16)
|
||||||
|
}
|
||||||
|
return ^uint16(sum) //nolint:gosec
|
||||||
|
}
|
||||||
|
|
||||||
|
// HeaderV6 makes an IPv6 header.
|
||||||
|
// payloadLen is the length of the payload following the header.
|
||||||
|
// nextHeader can be byte([syscall.IPPROTO_TCP]) for example.
|
||||||
|
func HeaderV6(srcIP, dstIP netip.Addr,
|
||||||
|
payloadLen uint16, nextHeader byte,
|
||||||
|
) []byte {
|
||||||
|
ipv6Header := make([]byte, constants.IPv6HeaderLength)
|
||||||
|
ipv6Header[0] = 0x60 // version (4 bits) | traffic Class (4 bits)
|
||||||
|
ipv6Header[1] = 0x00 // traffic Class (4 bits) | flow label (4 bits)
|
||||||
|
|
||||||
|
// Flow Label (remaining 16 bits)
|
||||||
|
ipv6Header[2] = 0x00
|
||||||
|
ipv6Header[3] = 0x00
|
||||||
|
|
||||||
|
binary.BigEndian.PutUint16(ipv6Header[4:], payloadLen)
|
||||||
|
ipv6Header[6] = nextHeader
|
||||||
|
const hopLimit = 64
|
||||||
|
ipv6Header[7] = hopLimit
|
||||||
|
copy(ipv6Header[8:24], srcIP.AsSlice())
|
||||||
|
copy(ipv6Header[24:40], dstIP.AsSlice())
|
||||||
|
return ipv6Header
|
||||||
|
}
|
||||||
@@ -0,0 +1,9 @@
|
|||||||
|
package ip
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
)
|
||||||
|
|
||||||
|
func putUint16(b []byte, v uint16) {
|
||||||
|
binary.NativeEndian.PutUint16(b, v)
|
||||||
|
}
|
||||||
@@ -0,0 +1,9 @@
|
|||||||
|
//go:build !darwin
|
||||||
|
|
||||||
|
package ip
|
||||||
|
|
||||||
|
import "encoding/binary"
|
||||||
|
|
||||||
|
func putUint16(b []byte, v uint16) {
|
||||||
|
binary.BigEndian.PutUint16(b, v)
|
||||||
|
}
|
||||||
@@ -0,0 +1,9 @@
|
|||||||
|
//go:build linux || darwin
|
||||||
|
|
||||||
|
package ip
|
||||||
|
|
||||||
|
import "syscall"
|
||||||
|
|
||||||
|
func SetIPv4HeaderIncluded(fd int) error {
|
||||||
|
return syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
|
||||||
|
}
|
||||||
@@ -0,0 +1,7 @@
|
|||||||
|
//go:build !linux && !windows && !darwin
|
||||||
|
|
||||||
|
package ip
|
||||||
|
|
||||||
|
func SetIPv4HeaderIncluded(fd int) error {
|
||||||
|
panic("not implemented")
|
||||||
|
}
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
package ip
|
||||||
|
|
||||||
|
import (
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
func SetIPv4HeaderIncluded(handle syscall.Handle) error {
|
||||||
|
const ipHdrIncluded = windows.IP_HDRINCL
|
||||||
|
return syscall.SetsockoptInt(handle, syscall.IPPROTO_IP, ipHdrIncluded, 1)
|
||||||
|
}
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
package ip
|
||||||
|
|
||||||
|
func SetIPv6HeaderIncluded(fd int) error {
|
||||||
|
panic("darwin does not allow an application to build IPv6 headers")
|
||||||
|
}
|
||||||
@@ -0,0 +1,8 @@
|
|||||||
|
package ip
|
||||||
|
|
||||||
|
import "syscall"
|
||||||
|
|
||||||
|
func SetIPv6HeaderIncluded(fd int) error {
|
||||||
|
const ipv6HdrIncluded = 36 // IPV6_HDRINCL
|
||||||
|
return syscall.SetsockoptInt(fd, syscall.IPPROTO_IPV6, ipv6HdrIncluded, 1)
|
||||||
|
}
|
||||||
@@ -0,0 +1,7 @@
|
|||||||
|
//go:build !linux && !windows && !darwin
|
||||||
|
|
||||||
|
package ip
|
||||||
|
|
||||||
|
func SetIPv6HeaderIncluded(fd int) error {
|
||||||
|
panic("not implemented")
|
||||||
|
}
|
||||||
@@ -0,0 +1,7 @@
|
|||||||
|
package ip
|
||||||
|
|
||||||
|
import "syscall"
|
||||||
|
|
||||||
|
func SetIPv6HeaderIncluded(fd syscall.Handle) error {
|
||||||
|
panic("windows does not allow an application to build IPv6 headers")
|
||||||
|
}
|
||||||
@@ -0,0 +1,123 @@
|
|||||||
|
package ip
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"github.com/jsimonetti/rtnetlink"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SrcAddr determines the appropriate source IP address to use when sending a packet to the
|
||||||
|
// specified destination. It also reserves an ephemeral source port for the specified protocol
|
||||||
|
// to ensure that the port is not used by other processes. The cleanup function returned should
|
||||||
|
// be called to release the reserved port when done.
|
||||||
|
func SrcAddr(dst netip.AddrPort, proto int) (src netip.AddrPort, cleanup func(), err error) {
|
||||||
|
srcAddr, err := srcIP(dst.Addr())
|
||||||
|
if err != nil {
|
||||||
|
return netip.AddrPort{}, nil, fmt.Errorf("finding source IP: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
srcPort, cleanup, err := srcPort(srcAddr, proto)
|
||||||
|
if err != nil {
|
||||||
|
return netip.AddrPort{}, nil, fmt.Errorf("reserving source port: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return netip.AddrPortFrom(srcAddr, srcPort), cleanup, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var errNoRoute = fmt.Errorf("no route to destination")
|
||||||
|
|
||||||
|
func srcIP(dst netip.Addr) (netip.Addr, error) {
|
||||||
|
conn, err := rtnetlink.Dial(nil)
|
||||||
|
if err != nil {
|
||||||
|
return netip.Addr{}, err
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
family := uint8(syscall.AF_INET)
|
||||||
|
if dst.Is6() {
|
||||||
|
family = syscall.AF_INET6
|
||||||
|
}
|
||||||
|
|
||||||
|
// Request route to destination
|
||||||
|
requestMessage := &rtnetlink.RouteMessage{
|
||||||
|
Family: family,
|
||||||
|
Attributes: rtnetlink.RouteAttributes{
|
||||||
|
Dst: dst.AsSlice(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
messages, err := conn.Route.Get(requestMessage)
|
||||||
|
if err != nil {
|
||||||
|
return netip.Addr{}, fmt.Errorf("getting routes to %s: %w", dst, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, message := range messages {
|
||||||
|
if message.Attributes.Src == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ipv6 := message.Attributes.Src.To4() == nil
|
||||||
|
if ipv6 {
|
||||||
|
return netip.AddrFrom16([16]byte(message.Attributes.Src)), nil
|
||||||
|
}
|
||||||
|
return netip.AddrFrom4([4]byte(message.Attributes.Src)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return netip.Addr{}, fmt.Errorf("%w: in %d route(s)", errNoRoute, len(messages))
|
||||||
|
}
|
||||||
|
|
||||||
|
// srcPort reserves an ephemeral source port by opening a socket for the
|
||||||
|
// protocol specified and binds it to the provided source address.
|
||||||
|
// It doesn't actually listen on the port.
|
||||||
|
// The cleanup function returned should be called to release the port when done.
|
||||||
|
func srcPort(srcAddr netip.Addr, proto int) (srcPort uint16, cleanup func(), err error) {
|
||||||
|
family := syscall.AF_INET
|
||||||
|
if srcAddr.Is6() {
|
||||||
|
family = syscall.AF_INET6
|
||||||
|
}
|
||||||
|
|
||||||
|
fd, err := syscall.Socket(family, syscall.SOCK_STREAM, proto)
|
||||||
|
if err != nil {
|
||||||
|
return 0, nil, fmt.Errorf("creating reservation socket: %w", err)
|
||||||
|
}
|
||||||
|
cleanup = func() {
|
||||||
|
_ = syscall.Close(fd)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bind to port 0 to get an ephemeral port
|
||||||
|
const port = 0
|
||||||
|
var bindAddr syscall.Sockaddr
|
||||||
|
if srcAddr.Is4() {
|
||||||
|
bindAddr = &syscall.SockaddrInet4{
|
||||||
|
Port: port,
|
||||||
|
Addr: srcAddr.As4(),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
bindAddr = &syscall.SockaddrInet6{
|
||||||
|
Port: port,
|
||||||
|
Addr: srcAddr.As16(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = syscall.Bind(fd, bindAddr)
|
||||||
|
if err != nil {
|
||||||
|
cleanup()
|
||||||
|
return 0, nil, fmt.Errorf("binding reservation socket: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sockAddr, err := syscall.Getsockname(fd)
|
||||||
|
if err != nil {
|
||||||
|
cleanup()
|
||||||
|
return 0, nil, fmt.Errorf("getting bound socket name: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch typedSockAddr := sockAddr.(type) {
|
||||||
|
case *syscall.SockaddrInet4:
|
||||||
|
srcPort = uint16(typedSockAddr.Port) //nolint:gosec
|
||||||
|
case *syscall.SockaddrInet6:
|
||||||
|
srcPort = uint16(typedSockAddr.Port) //nolint:gosec
|
||||||
|
default:
|
||||||
|
panic(fmt.Sprintf("unexpected sockaddr type: %T", typedSockAddr))
|
||||||
|
}
|
||||||
|
return srcPort, cleanup, nil
|
||||||
|
}
|
||||||
+38
-233
@@ -4,268 +4,73 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/net/icmp"
|
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||||
|
"github.com/qdm12/gluetun/internal/pmtud/icmp"
|
||||||
|
"github.com/qdm12/gluetun/internal/pmtud/tcp"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrMTUNotFound = errors.New("path MTU discovery failed to find MTU")
|
// PathMTUDiscover discovers the maximum MTU using both ICMP and TCP.
|
||||||
|
// Multiple ICMP addresses and TCP addresses can be specified for redundancy.
|
||||||
// PathMTUDiscover discovers the maximum MTU for the path to the given ip address.
|
// ICMP PMTUD is run first. If successful, the range of possible MTU values to
|
||||||
|
// check for TCP PMTUD is reduced to [maxMTU-150, maxMTU] where maxMTU is the
|
||||||
|
// maximum MTU found with ICMP PMTUD. Otherwise, TCP PMTUD is run with the
|
||||||
|
// whole range of possible MTU values up to the physical link MTU to check.
|
||||||
// If the physicalLinkMTU is zero, it defaults to 1500 which is the ethernet standard MTU.
|
// If the physicalLinkMTU is zero, it defaults to 1500 which is the ethernet standard MTU.
|
||||||
// If the pingTimeout is zero, it defaults to 1 second.
|
// If the pingTimeout is zero, it defaults to 1 second.
|
||||||
// If the logger is nil, a no-op logger is used.
|
// If the logger is nil, a no-op logger is used.
|
||||||
// It returns [ErrMTUNotFound] if the MTU could not be determined.
|
// It returns [ErrMTUNotFound] if the MTU could not be determined.
|
||||||
func PathMTUDiscover(ctx context.Context, ip netip.Addr,
|
func PathMTUDiscover(ctx context.Context, icmpAddrs []netip.Addr, tcpAddrs []netip.AddrPort,
|
||||||
physicalLinkMTU uint32, pingTimeout time.Duration, logger Logger) (
|
physicalLinkMTU uint32, tryTimeout time.Duration, logger Logger) (
|
||||||
mtu uint32, err error,
|
mtu uint32, err error,
|
||||||
) {
|
) {
|
||||||
if physicalLinkMTU == 0 {
|
if physicalLinkMTU == 0 {
|
||||||
const ethernetStandardMTU = 1500
|
const ethernetStandardMTU = 1500
|
||||||
physicalLinkMTU = ethernetStandardMTU
|
physicalLinkMTU = ethernetStandardMTU
|
||||||
}
|
}
|
||||||
if pingTimeout == 0 {
|
if tryTimeout == 0 {
|
||||||
pingTimeout = time.Second
|
tryTimeout = time.Second
|
||||||
}
|
}
|
||||||
if logger == nil {
|
if logger == nil {
|
||||||
logger = &noopLogger{}
|
logger = &noopLogger{}
|
||||||
}
|
}
|
||||||
|
|
||||||
if ip.Is4() {
|
// Try finding the MTU using ICMP
|
||||||
logger.Debug("finding IPv4 next hop MTU")
|
maxPossibleMTU := physicalLinkMTU
|
||||||
mtu, err = findIPv4NextHopMTU(ctx, ip, physicalLinkMTU, pingTimeout, logger)
|
icmpSuccess := false
|
||||||
|
for _, icmpIP := range icmpAddrs {
|
||||||
|
mtu, err := icmp.PathMTUDiscover(ctx, icmpIP, physicalLinkMTU,
|
||||||
|
tryTimeout, logger)
|
||||||
switch {
|
switch {
|
||||||
case err == nil:
|
case err == nil:
|
||||||
return mtu, nil
|
logger.Debugf("ICMP path MTU discovery against %s found maximum valid MTU %d", icmpIP, mtu)
|
||||||
case errors.Is(err, net.ErrClosed) || errors.Is(err, ErrICMPCommunicationAdministrativelyProhibited): // blackhole
|
icmpSuccess = true
|
||||||
|
maxPossibleMTU = mtu
|
||||||
|
case errors.Is(err, icmp.ErrNotPermitted), errors.Is(err, icmp.ErrMTUNotFound):
|
||||||
|
logger.Debugf("ICMP path MTU discovery failed: %s", err)
|
||||||
default:
|
default:
|
||||||
return 0, fmt.Errorf("finding IPv4 next hop MTU: %w", err)
|
return 0, fmt.Errorf("ICMP path MTU discovery: %w", err)
|
||||||
}
|
|
||||||
} else {
|
|
||||||
logger.Debug("requesting IPv6 ICMP packet-too-big reply")
|
|
||||||
mtu, err = getIPv6PacketTooBig(ctx, ip, physicalLinkMTU, pingTimeout, logger)
|
|
||||||
switch {
|
|
||||||
case err == nil:
|
|
||||||
return mtu, nil
|
|
||||||
case errors.Is(err, net.ErrClosed): // blackhole
|
|
||||||
default:
|
|
||||||
return 0, fmt.Errorf("getting IPv6 packet-too-big message: %w", err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fall back method: send echo requests with different packet
|
for _, addrPort := range tcpAddrs {
|
||||||
// sizes and check which ones succeed to find the maximum MTU.
|
minMTU := constants.MinIPv4MTU
|
||||||
logger.Debug("falling back to sending different sized echo packets")
|
if addrPort.Addr().Is6() {
|
||||||
minMTU := minIPv4MTU
|
minMTU = constants.MinIPv6MTU
|
||||||
if ip.Is6() {
|
|
||||||
minMTU = minIPv6MTU
|
|
||||||
}
|
|
||||||
return pmtudMultiSizes(ctx, ip, minMTU, physicalLinkMTU, pingTimeout, logger)
|
|
||||||
}
|
|
||||||
|
|
||||||
type pmtudTestUnit struct {
|
|
||||||
mtu uint32
|
|
||||||
echoID uint16
|
|
||||||
sentBytes int
|
|
||||||
ok bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func pmtudMultiSizes(ctx context.Context, ip netip.Addr,
|
|
||||||
minMTU, maxPossibleMTU uint32, pingTimeout time.Duration,
|
|
||||||
logger Logger,
|
|
||||||
) (maxMTU uint32, err error) {
|
|
||||||
var ipVersion string
|
|
||||||
var conn net.PacketConn
|
|
||||||
if ip.Is4() {
|
|
||||||
ipVersion = "v4"
|
|
||||||
conn, err = listenICMPv4(ctx)
|
|
||||||
} else {
|
|
||||||
ipVersion = "v6"
|
|
||||||
conn, err = listenICMPv6(ctx)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
|
|
||||||
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrICMPNotPermitted)
|
|
||||||
}
|
}
|
||||||
return 0, fmt.Errorf("listening for ICMP packets: %w", err)
|
if icmpSuccess {
|
||||||
}
|
const mtuMargin = 150
|
||||||
|
minMTU = max(maxPossibleMTU-mtuMargin, minMTU)
|
||||||
mtusToTest := makeMTUsToTest(minMTU, maxPossibleMTU)
|
}
|
||||||
if len(mtusToTest) == 1 { // only minMTU because minMTU == maxPossibleMTU
|
mtu, err = tcp.PathMTUDiscover(ctx, addrPort, minMTU, maxPossibleMTU, logger)
|
||||||
return minMTU, nil
|
|
||||||
}
|
|
||||||
logger.Debugf("testing the following MTUs: %v", mtusToTest)
|
|
||||||
|
|
||||||
tests := make([]pmtudTestUnit, len(mtusToTest))
|
|
||||||
for i := range mtusToTest {
|
|
||||||
tests[i] = pmtudTestUnit{mtu: mtusToTest[i]}
|
|
||||||
}
|
|
||||||
|
|
||||||
timedCtx, cancel := context.WithTimeout(ctx, pingTimeout)
|
|
||||||
defer cancel()
|
|
||||||
go func() {
|
|
||||||
<-timedCtx.Done()
|
|
||||||
conn.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
for i := range tests {
|
|
||||||
id, message := buildMessageToSend(ipVersion, tests[i].mtu)
|
|
||||||
tests[i].echoID = id
|
|
||||||
|
|
||||||
encodedMessage, err := message.Marshal(nil)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("encoding ICMP message: %w", err)
|
logger.Debugf("TCP path MTU discovery to %s failed: %s", addrPort, err)
|
||||||
}
|
|
||||||
tests[i].sentBytes = len(encodedMessage)
|
|
||||||
|
|
||||||
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()})
|
|
||||||
if err != nil {
|
|
||||||
if strings.HasSuffix(err.Error(), "sendto: operation not permitted") {
|
|
||||||
err = fmt.Errorf("%w", ErrICMPNotPermitted)
|
|
||||||
}
|
|
||||||
return 0, fmt.Errorf("writing ICMP message: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err = collectReplies(conn, ipVersion, tests, logger)
|
|
||||||
switch {
|
|
||||||
case err == nil: // max possible MTU is working
|
|
||||||
return tests[len(tests)-1].mtu, nil
|
|
||||||
case err != nil && errors.Is(err, net.ErrClosed):
|
|
||||||
// we have timeouts (IPv4 testing or IPv6 PMTUD blackholes)
|
|
||||||
// so find the highest MTU which worked.
|
|
||||||
// Note we start from index len(tests) - 2 since the max MTU
|
|
||||||
// cannot be working if we had a timeout.
|
|
||||||
for i := len(tests) - 2; i >= 0; i-- { //nolint:mnd
|
|
||||||
if tests[i].ok {
|
|
||||||
return pmtudMultiSizes(ctx, ip, tests[i].mtu, tests[i+1].mtu-1,
|
|
||||||
pingTimeout, logger)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// All MTUs failed.
|
|
||||||
return 0, fmt.Errorf("%w: ICMP might be blocked", ErrMTUNotFound)
|
|
||||||
case err != nil:
|
|
||||||
return 0, fmt.Errorf("collecting ICMP echo replies: %w", err)
|
|
||||||
default:
|
|
||||||
panic("unreachable")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create the MTU slice of length 11 such that:
|
|
||||||
// - the first element is the minMTU
|
|
||||||
// - the last element is the maxMTU
|
|
||||||
// - elements in-between are separated as close to each other
|
|
||||||
// The number 11 is chosen to find the final MTU in 3 searches,
|
|
||||||
// with a total search space of 1728 MTUs which is enough;
|
|
||||||
// to find it in 2 searches requires 37 parallel queries which
|
|
||||||
// could be blocked by firewalls.
|
|
||||||
func makeMTUsToTest(minMTU, maxMTU uint32) (mtus []uint32) {
|
|
||||||
const mtusLength = 11 // find the final MTU in 3 searches
|
|
||||||
diff := maxMTU - minMTU
|
|
||||||
switch {
|
|
||||||
case minMTU > maxMTU:
|
|
||||||
panic("minMTU > maxMTU")
|
|
||||||
case diff <= mtusLength:
|
|
||||||
mtus = make([]uint32, 0, diff)
|
|
||||||
for mtu := minMTU; mtu <= maxMTU; mtu++ {
|
|
||||||
mtus = append(mtus, mtu)
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
step := float64(diff) / float64(mtusLength-1)
|
|
||||||
mtus = make([]uint32, 0, mtusLength)
|
|
||||||
for mtu := float64(minMTU); len(mtus) < mtusLength-1; mtu += step {
|
|
||||||
mtus = append(mtus, uint32(math.Round(mtu)))
|
|
||||||
}
|
|
||||||
mtus = append(mtus, maxMTU) // last element is the maxMTU
|
|
||||||
}
|
|
||||||
|
|
||||||
return mtus
|
|
||||||
}
|
|
||||||
|
|
||||||
func collectReplies(conn net.PacketConn, ipVersion string,
|
|
||||||
tests []pmtudTestUnit, logger Logger,
|
|
||||||
) (err error) {
|
|
||||||
echoIDToTestIndex := make(map[uint16]int, len(tests))
|
|
||||||
for i, test := range tests {
|
|
||||||
echoIDToTestIndex[test.echoID] = i
|
|
||||||
}
|
|
||||||
|
|
||||||
// The theoretical limit is 4GiB for IPv6 MTU path discovery jumbograms, but that would
|
|
||||||
// create huge buffers which we don't really want to support anyway.
|
|
||||||
// The standard frame maximum MTU is 1500 bytes, and there are Jumbo frames with
|
|
||||||
// a conventional maximum of 9000 bytes. However, some manufacturers support up
|
|
||||||
// 9216-20 = 9196 bytes for the maximum MTU. We thus use buffers of size 9196 to
|
|
||||||
// match eventual Jumbo frames. More information at:
|
|
||||||
// https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media
|
|
||||||
const maxPossibleMTU = 9196
|
|
||||||
buffer := make([]byte, maxPossibleMTU)
|
|
||||||
|
|
||||||
idsFound := 0
|
|
||||||
for idsFound < len(tests) {
|
|
||||||
// Note we need to read the whole packet in one call to ReadFrom, so the buffer
|
|
||||||
// must be large enough to read the entire reply packet. See:
|
|
||||||
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
|
|
||||||
bytesRead, _, err := conn.ReadFrom(buffer)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("reading from ICMP connection: %w", err)
|
|
||||||
}
|
|
||||||
packetBytes := buffer[:bytesRead]
|
|
||||||
|
|
||||||
ipPacketLength := len(packetBytes)
|
|
||||||
|
|
||||||
var icmpProtocol int
|
|
||||||
switch ipVersion {
|
|
||||||
case "v4":
|
|
||||||
icmpProtocol = icmpv4Protocol
|
|
||||||
case "v6":
|
|
||||||
icmpProtocol = icmpv6Protocol
|
|
||||||
default:
|
|
||||||
panic(fmt.Sprintf("unknown IP version: %s", ipVersion))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse the ICMP message
|
|
||||||
// Note: this parsing works for a truncated 556 bytes ICMP reply packet.
|
|
||||||
message, err := icmp.ParseMessage(icmpProtocol, packetBytes)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("parsing message: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
echoBody, ok := message.Body.(*icmp.Echo)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, message.Body)
|
|
||||||
}
|
|
||||||
|
|
||||||
id := uint16(echoBody.ID) //nolint:gosec
|
|
||||||
testIndex, testing := echoIDToTestIndex[id]
|
|
||||||
if !testing { // not an id we expected so ignore it
|
|
||||||
logger.Warnf("ignoring ICMP reply with unexpected ID %d (type: %d, code: %d, length: %d)",
|
|
||||||
echoBody.ID, message.Type, message.Code, ipPacketLength)
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
idsFound++
|
logger.Debugf("TCP path MTU discovery to %s found maximum valid MTU %d", addrPort, mtu)
|
||||||
sentBytes := tests[testIndex].sentBytes
|
return mtu, nil
|
||||||
|
|
||||||
// echo reply should be at most the number of bytes sent,
|
|
||||||
// and can be lower, more precisely 556 bytes, in case
|
|
||||||
// the host we are reaching wants to stay out of trouble
|
|
||||||
// and ensure its echo reply goes through without
|
|
||||||
// fragmentation, see the following page:
|
|
||||||
// https://datatracker.ietf.org/doc/html/rfc1122#page-59
|
|
||||||
const conservativeReplyLength = 556
|
|
||||||
truncated := ipPacketLength < sentBytes &&
|
|
||||||
ipPacketLength == conservativeReplyLength
|
|
||||||
// Check the packet size is the same if the reply is not truncated
|
|
||||||
if !truncated && sentBytes != ipPacketLength {
|
|
||||||
return fmt.Errorf("%w: sent %dB and received %dB",
|
|
||||||
ErrICMPEchoDataMismatch, sentBytes, ipPacketLength)
|
|
||||||
}
|
|
||||||
// Truncated reply or matching reply size
|
|
||||||
tests[testIndex].ok = true
|
|
||||||
}
|
}
|
||||||
return nil
|
return 0, fmt.Errorf("TCP path MTU discovery: last error: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,7 @@
|
|||||||
|
package tcp
|
||||||
|
|
||||||
|
type Logger interface {
|
||||||
|
Debug(msg string)
|
||||||
|
Debugf(msg string, args ...any)
|
||||||
|
Warnf(msg string, args ...any)
|
||||||
|
}
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
package tcp
|
||||||
|
|
||||||
|
//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . Logger
|
||||||
@@ -0,0 +1,80 @@
|
|||||||
|
// Code generated by MockGen. DO NOT EDIT.
|
||||||
|
// Source: github.com/qdm12/gluetun/internal/pmtud/tcp (interfaces: Logger)
|
||||||
|
|
||||||
|
// Package tcp is a generated GoMock package.
|
||||||
|
package tcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
reflect "reflect"
|
||||||
|
|
||||||
|
gomock "github.com/golang/mock/gomock"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockLogger is a mock of Logger interface.
|
||||||
|
type MockLogger struct {
|
||||||
|
ctrl *gomock.Controller
|
||||||
|
recorder *MockLoggerMockRecorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockLoggerMockRecorder is the mock recorder for MockLogger.
|
||||||
|
type MockLoggerMockRecorder struct {
|
||||||
|
mock *MockLogger
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockLogger creates a new mock instance.
|
||||||
|
func NewMockLogger(ctrl *gomock.Controller) *MockLogger {
|
||||||
|
mock := &MockLogger{ctrl: ctrl}
|
||||||
|
mock.recorder = &MockLoggerMockRecorder{mock}
|
||||||
|
return mock
|
||||||
|
}
|
||||||
|
|
||||||
|
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||||
|
func (m *MockLogger) EXPECT() *MockLoggerMockRecorder {
|
||||||
|
return m.recorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// Debug mocks base method.
|
||||||
|
func (m *MockLogger) Debug(arg0 string) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
m.ctrl.Call(m, "Debug", arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Debug indicates an expected call of Debug.
|
||||||
|
func (mr *MockLoggerMockRecorder) Debug(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockLogger)(nil).Debug), arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Debugf mocks base method.
|
||||||
|
func (m *MockLogger) Debugf(arg0 string, arg1 ...interface{}) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
varargs := []interface{}{arg0}
|
||||||
|
for _, a := range arg1 {
|
||||||
|
varargs = append(varargs, a)
|
||||||
|
}
|
||||||
|
m.ctrl.Call(m, "Debugf", varargs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Debugf indicates an expected call of Debugf.
|
||||||
|
func (mr *MockLoggerMockRecorder) Debugf(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
varargs := append([]interface{}{arg0}, arg1...)
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debugf", reflect.TypeOf((*MockLogger)(nil).Debugf), varargs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Warnf mocks base method.
|
||||||
|
func (m *MockLogger) Warnf(arg0 string, arg1 ...interface{}) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
varargs := []interface{}{arg0}
|
||||||
|
for _, a := range arg1 {
|
||||||
|
varargs = append(varargs, a)
|
||||||
|
}
|
||||||
|
m.ctrl.Call(m, "Warnf", varargs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Warnf indicates an expected call of Warnf.
|
||||||
|
func (mr *MockLoggerMockRecorder) Warnf(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
varargs := append([]interface{}{arg0}, arg1...)
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Warnf", reflect.TypeOf((*MockLogger)(nil).Warnf), varargs...)
|
||||||
|
}
|
||||||
@@ -0,0 +1,89 @@
|
|||||||
|
package tcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/qdm12/gluetun/internal/pmtud/test"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrMTUNotFound = errors.New("MTU not found")
|
||||||
|
|
||||||
|
type testUnit struct {
|
||||||
|
mtu uint32
|
||||||
|
ok bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func PathMTUDiscover(ctx context.Context, addrPort netip.AddrPort,
|
||||||
|
minMTU, maxPossibleMTU uint32, logger Logger,
|
||||||
|
) (mtu uint32, err error) {
|
||||||
|
mtusToTest := test.MakeMTUsToTest(minMTU, maxPossibleMTU)
|
||||||
|
if len(mtusToTest) == 1 { // only minMTU because minMTU == maxPossibleMTU
|
||||||
|
return minMTU, nil
|
||||||
|
}
|
||||||
|
logger.Debugf("TCP testing the following MTUs: %v", mtusToTest)
|
||||||
|
|
||||||
|
tests := make([]testUnit, len(mtusToTest))
|
||||||
|
for i := range mtusToTest {
|
||||||
|
tests[i] = testUnit{mtu: mtusToTest[i]}
|
||||||
|
}
|
||||||
|
|
||||||
|
family := syscall.AF_INET
|
||||||
|
if addrPort.Addr().Is6() {
|
||||||
|
family = syscall.AF_INET6
|
||||||
|
}
|
||||||
|
fd, stop, err := startRawSocket(family)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("starting raw socket: %w", err)
|
||||||
|
}
|
||||||
|
defer stop()
|
||||||
|
|
||||||
|
tracker := newTracker(fd, addrPort.Addr().Is4())
|
||||||
|
|
||||||
|
const timeout = time.Second
|
||||||
|
runCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||||
|
defer cancel()
|
||||||
|
errCh := make(chan error)
|
||||||
|
go func() {
|
||||||
|
errCh <- tracker.listen(runCtx)
|
||||||
|
}()
|
||||||
|
|
||||||
|
doneCh := make(chan struct{})
|
||||||
|
for i := range tests {
|
||||||
|
go func(i int) {
|
||||||
|
err := runTest(runCtx, fd, tracker, addrPort, tests[i].mtu)
|
||||||
|
tests[i].ok = err == nil
|
||||||
|
doneCh <- struct{}{}
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
for range tests {
|
||||||
|
select {
|
||||||
|
case <-doneCh:
|
||||||
|
case err := <-errCh:
|
||||||
|
if err == nil { // timeout
|
||||||
|
break
|
||||||
|
}
|
||||||
|
return 0, fmt.Errorf("listening for TCP replies: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if tests[len(tests)-1].ok {
|
||||||
|
return tests[len(tests)-1].mtu, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := len(tests) - 2; i >= 0; i-- { //nolint:mnd
|
||||||
|
if tests[i].ok {
|
||||||
|
stop()
|
||||||
|
cancel()
|
||||||
|
return PathMTUDiscover(ctx, addrPort,
|
||||||
|
tests[i].mtu, tests[i+1].mtu-1, logger)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0, fmt.Errorf("%w: your connection might not be working at all", ErrMTUNotFound)
|
||||||
|
}
|
||||||
@@ -0,0 +1,89 @@
|
|||||||
|
package tcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math/rand/v2"
|
||||||
|
"net/netip"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||||
|
"github.com/qdm12/gluetun/internal/pmtud/ip"
|
||||||
|
)
|
||||||
|
|
||||||
|
// createSYNPacket creates a TCP SYN packet for initiating a handshake.
|
||||||
|
// SYN packets have normally no data payload, so you SHOULD set mtu to 0.
|
||||||
|
// However, in some cases where the server closes the connection with RST immediately,
|
||||||
|
// it can be useful to add some data payload to a SYN packet and check if the server still
|
||||||
|
// replies. Only set mtu to a non zero value if you know what you are doing.
|
||||||
|
func createSYNPacket(src, dst netip.AddrPort, mtu uint32) (packet []byte, seq uint32) {
|
||||||
|
seq = rand.Uint32() //nolint:gosec
|
||||||
|
const ack = 0 // SYN has no ACK number
|
||||||
|
payloadLength := constants.BaseTCPHeaderLength // no data payload
|
||||||
|
if mtu > 0 {
|
||||||
|
payloadLength = getPayloadLength(mtu, dst)
|
||||||
|
}
|
||||||
|
return createPacket(src, dst, seq, ack, payloadLength, synFlag), seq
|
||||||
|
}
|
||||||
|
|
||||||
|
// createACKPacket creates a TCP ACK packet.
|
||||||
|
// If the mtu is set to 0, no payload is sent.
|
||||||
|
// Otherwise, the payload is calculated to test the MTU given.
|
||||||
|
func createACKPacket(src, dst netip.AddrPort, seq, ack uint32, mtu uint32) []byte {
|
||||||
|
payloadLength := constants.BaseTCPHeaderLength // no data payload
|
||||||
|
if mtu > 0 {
|
||||||
|
payloadLength = getPayloadLength(mtu, dst)
|
||||||
|
}
|
||||||
|
const flags = ackFlag | pshFlag
|
||||||
|
return createPacket(src, dst, seq, ack, payloadLength, flags)
|
||||||
|
}
|
||||||
|
|
||||||
|
func createRSTPacket(src, dst netip.AddrPort, seq, ack uint32) []byte {
|
||||||
|
const payloadLength = constants.BaseTCPHeaderLength // no data payload
|
||||||
|
return createPacket(src, dst, seq, ack, payloadLength, rstFlag)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getPayloadLength(mtu uint32, dst netip.AddrPort) uint32 {
|
||||||
|
var ipHeaderLength uint32
|
||||||
|
if dst.Addr().Is4() {
|
||||||
|
ipHeaderLength = constants.IPv4HeaderLength
|
||||||
|
} else {
|
||||||
|
ipHeaderLength = constants.IPv6HeaderLength
|
||||||
|
}
|
||||||
|
if mtu < ipHeaderLength+constants.BaseTCPHeaderLength {
|
||||||
|
panic("MTU too small to hold IP and TCP headers")
|
||||||
|
}
|
||||||
|
return mtu - ipHeaderLength
|
||||||
|
}
|
||||||
|
|
||||||
|
func createPacket(src, dst netip.AddrPort,
|
||||||
|
seq, ack, payloadLength uint32, flags byte,
|
||||||
|
) []byte {
|
||||||
|
if payloadLength < constants.BaseTCPHeaderLength {
|
||||||
|
panic("payload length is too small to hold TCP header")
|
||||||
|
}
|
||||||
|
|
||||||
|
var ipHeader []byte
|
||||||
|
if dst.Addr().Is4() {
|
||||||
|
ipHeader = ip.HeaderV4(src.Addr(), dst.Addr(), payloadLength)
|
||||||
|
} else {
|
||||||
|
ipHeader = ip.HeaderV6(src.Addr(), dst.Addr(),
|
||||||
|
uint16(payloadLength), byte(syscall.IPPROTO_TCP)) //nolint:gosec
|
||||||
|
}
|
||||||
|
|
||||||
|
tcpHeader := makeTCPHeader(src.Port(), dst.Port(), seq, ack, flags)
|
||||||
|
|
||||||
|
// data is just zeroes
|
||||||
|
dataLength := int(payloadLength) - int(constants.BaseTCPHeaderLength)
|
||||||
|
var data []byte
|
||||||
|
if dataLength > 0 {
|
||||||
|
data = make([]byte, dataLength)
|
||||||
|
}
|
||||||
|
checksum := tcpChecksum(ipHeader, tcpHeader, data)
|
||||||
|
tcpHeader[16] = byte(checksum >> 8) //nolint:mnd
|
||||||
|
tcpHeader[17] = byte(checksum & 0xff) //nolint:mnd
|
||||||
|
|
||||||
|
packet := make([]byte, len(ipHeader)+int(constants.BaseTCPHeaderLength)+dataLength)
|
||||||
|
copy(packet, ipHeader)
|
||||||
|
copy(packet[len(ipHeader):], tcpHeader)
|
||||||
|
copy(packet[len(ipHeader)+int(constants.BaseTCPHeaderLength):], data)
|
||||||
|
return packet
|
||||||
|
}
|
||||||
@@ -0,0 +1,196 @@
|
|||||||
|
package tcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||||
|
"github.com/qdm12/gluetun/internal/pmtud/ip"
|
||||||
|
)
|
||||||
|
|
||||||
|
func startRawSocket(family int) (fd fileDescriptor, stop func(), err error) {
|
||||||
|
fdPlatform, err := syscall.Socket(family, syscall.SOCK_RAW, syscall.IPPROTO_TCP)
|
||||||
|
if err != nil {
|
||||||
|
return 0, nil, fmt.Errorf("creating raw socket: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if family == syscall.AF_INET {
|
||||||
|
err = ip.SetIPv4HeaderIncluded(fdPlatform)
|
||||||
|
} else {
|
||||||
|
err = ip.SetIPv6HeaderIncluded(fdPlatform)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
_ = syscall.Close(fdPlatform)
|
||||||
|
return 0, nil, fmt.Errorf("setting header option on raw socket: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allow sending packets larger than cached PMTU (for PMTUD probing)
|
||||||
|
err = setMTUDiscovery(fdPlatform)
|
||||||
|
if err != nil {
|
||||||
|
_ = syscall.Close(fdPlatform)
|
||||||
|
return 0, nil, fmt.Errorf("setting IP_MTU_DISCOVER: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// use polling because some Linux systems do not cancel
|
||||||
|
// blocking syscalls such as recvfrom when the socket is closed,
|
||||||
|
// which would cause things to hang indefinitely.
|
||||||
|
err = setNonBlock(fdPlatform)
|
||||||
|
if err != nil {
|
||||||
|
_ = syscall.Close(fdPlatform)
|
||||||
|
return 0, nil, fmt.Errorf("setting non-blocking mode: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
stop = func() {
|
||||||
|
_ = syscall.Close(fdPlatform)
|
||||||
|
}
|
||||||
|
return fileDescriptor(fdPlatform), stop, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
errTCPPacketNotSynAck = errors.New("TCP packet is not a SYN-ACK")
|
||||||
|
errTCPSynAckAckMismatch = errors.New("TCP SYN-ACK ACK number does not match expected value")
|
||||||
|
errFinalPacketTypeUnexpected = errors.New("final TCP packet type is unexpected")
|
||||||
|
)
|
||||||
|
|
||||||
|
// Craft and send a raw TCP packet to test the MTU.
|
||||||
|
// It expects either an RST reply (if no server is listening)
|
||||||
|
// or a SYN-ACK/ACK reply (if a server is listening).
|
||||||
|
func runTest(ctx context.Context, fd fileDescriptor,
|
||||||
|
tracker *tracker, dst netip.AddrPort, mtu uint32,
|
||||||
|
) error {
|
||||||
|
const proto = syscall.IPPROTO_TCP
|
||||||
|
src, cleanup, err := ip.SrcAddr(dst, proto)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("getting source address: %w", err)
|
||||||
|
}
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
ch := make(chan []byte)
|
||||||
|
abort := make(chan struct{})
|
||||||
|
defer close(abort)
|
||||||
|
tracker.register(src.Port(), dst.Port(), ch, abort)
|
||||||
|
defer tracker.unregister(src.Port(), dst.Port())
|
||||||
|
|
||||||
|
dstSockAddr := makeSockAddr(dst)
|
||||||
|
|
||||||
|
synPacket, synSeq := createSYNPacket(src, dst, 0)
|
||||||
|
const sendToFlags = 0
|
||||||
|
err = sendTo(fd, synPacket, sendToFlags, dstSockAddr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("sending SYN packet: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var reply []byte
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case reply = <-ch:
|
||||||
|
}
|
||||||
|
|
||||||
|
packetType, synAckSeq, synAckAck, err := parseTCPHeader(reply[:constants.BaseTCPHeaderLength])
|
||||||
|
switch {
|
||||||
|
case err != nil:
|
||||||
|
return fmt.Errorf("parsing first reply TCP header: %w", err)
|
||||||
|
case packetType == packetTypeRST:
|
||||||
|
// server actively closed the connection, try sending a SYN with data
|
||||||
|
return handleRSTReply(ctx, fd, ch, src, dst, mtu)
|
||||||
|
case packetType != packetTypeSYNACK:
|
||||||
|
return fmt.Errorf("%w: unexpected packet type %s", errTCPPacketNotSynAck, packetType)
|
||||||
|
case synAckAck != synSeq+1:
|
||||||
|
return fmt.Errorf("%w: expected %d, got %d", errTCPSynAckAckMismatch, synSeq+1, synAckAck)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send a no-data ACK packet to finish the 3-way handshake.
|
||||||
|
const ackMTU = 0 // no data payload initially
|
||||||
|
ackPacket := createACKPacket(src, dst, synAckAck, synAckSeq+1, ackMTU)
|
||||||
|
err = sendTo(fd, ackPacket, sendToFlags, dstSockAddr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("sending ACK-without-data packet: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send a data ACK packet to test the MTU given.
|
||||||
|
ackPacket = createACKPacket(src, dst, synAckAck, synAckSeq+1, mtu)
|
||||||
|
err = sendTo(fd, ackPacket, sendToFlags, dstSockAddr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("sending ACK-with-data packet: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case reply = <-ch:
|
||||||
|
}
|
||||||
|
|
||||||
|
packetType, _, ack, err := parseTCPHeader(reply[:constants.BaseTCPHeaderLength])
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parsing second reply TCP header: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch packetType { //nolint:exhaustive
|
||||||
|
case packetTypeRST:
|
||||||
|
return nil
|
||||||
|
case packetTypeACK:
|
||||||
|
err = sendRST(fd, src, dst, ack)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("sending RST packet: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
_ = sendRST(fd, src, dst, ack)
|
||||||
|
return fmt.Errorf("%w: %s", errFinalPacketTypeUnexpected, packetType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeSockAddr(addr netip.AddrPort) syscall.Sockaddr {
|
||||||
|
if addr.Addr().Is4() {
|
||||||
|
return &syscall.SockaddrInet4{
|
||||||
|
Port: int(addr.Port()),
|
||||||
|
Addr: addr.Addr().As4(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &syscall.SockaddrInet6{
|
||||||
|
Port: int(addr.Port()),
|
||||||
|
Addr: addr.Addr().As16(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var errTCPPacketNotRST = errors.New("TCP packet is not an RST")
|
||||||
|
|
||||||
|
func handleRSTReply(ctx context.Context, fd fileDescriptor, ch <-chan []byte,
|
||||||
|
src, dst netip.AddrPort, mtu uint32,
|
||||||
|
) error {
|
||||||
|
packet, _ := createSYNPacket(src, dst, mtu)
|
||||||
|
const sendToFlags = 0
|
||||||
|
err := sendTo(fd, packet, sendToFlags, makeSockAddr(dst))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("sending SYN MTU-test packet: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var reply []byte
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err() // timeout: the MTU test SYN packet was too big
|
||||||
|
case reply = <-ch:
|
||||||
|
}
|
||||||
|
|
||||||
|
packetType, _, _, err := parseTCPHeader(reply[:constants.BaseTCPHeaderLength])
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parsing reply TCP header: %w", err)
|
||||||
|
} else if packetType != packetTypeRST {
|
||||||
|
return fmt.Errorf("%w: %s", errTCPPacketNotRST, packetType)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendRST(fd fileDescriptor, src, dst netip.AddrPort,
|
||||||
|
previousACK uint32,
|
||||||
|
) error {
|
||||||
|
seq := previousACK
|
||||||
|
const ack = 0
|
||||||
|
rstPacket := createRSTPacket(src, dst, seq, ack)
|
||||||
|
const sendToFlags = 0
|
||||||
|
return sendTo(fd, rstPacket, sendToFlags, makeSockAddr(dst))
|
||||||
|
}
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
package tcp
|
||||||
|
|
||||||
|
func stripIPv4Header(reply []byte) (result []byte, ok bool) {
|
||||||
|
return reply, true
|
||||||
|
}
|
||||||
@@ -0,0 +1,7 @@
|
|||||||
|
package tcp
|
||||||
|
|
||||||
|
import "syscall"
|
||||||
|
|
||||||
|
func setMTUDiscovery(fd int) error {
|
||||||
|
return syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_MTU_DISCOVER, syscall.IP_PMTUDISC_PROBE)
|
||||||
|
}
|
||||||
@@ -0,0 +1,30 @@
|
|||||||
|
//go:build !darwin
|
||||||
|
|
||||||
|
package tcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||||
|
)
|
||||||
|
|
||||||
|
func stripIPv4Header(reply []byte) (result []byte, ok bool) {
|
||||||
|
if len(reply) < int(constants.IPv4HeaderLength) {
|
||||||
|
return nil, false // not an IPv4 packet
|
||||||
|
}
|
||||||
|
|
||||||
|
version := reply[0] >> 4 //nolint:mnd
|
||||||
|
const ipv4Version = 4
|
||||||
|
if version != ipv4Version {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
// For IPv4 we need to skip the IP header, which is at least
|
||||||
|
// 20B and can be up to 60B.
|
||||||
|
// The Internet Header Length is the lower 4 bits of the first byte and
|
||||||
|
// represents the number of 32-bit words of the header length.
|
||||||
|
const ihlMask byte = 0x0F
|
||||||
|
const bytesInWord = 4
|
||||||
|
headerLength := int((reply[0] & ihlMask)) * bytesInWord
|
||||||
|
if len(reply) < headerLength {
|
||||||
|
return nil, false // not enough data for full IPv4 header
|
||||||
|
}
|
||||||
|
return reply[headerLength:], true
|
||||||
|
}
|
||||||
@@ -0,0 +1,199 @@
|
|||||||
|
package tcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"syscall"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/qdm12/gluetun/internal/netlink"
|
||||||
|
"github.com/qdm12/gluetun/internal/routing"
|
||||||
|
"github.com/qdm12/log"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_runTest(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
noopLogger := &noopLogger{}
|
||||||
|
netlinker := netlink.New(noopLogger)
|
||||||
|
loopbackMTU, err := findLoopbackMTU(netlinker)
|
||||||
|
require.NoError(t, err, "finding loopback IPv4 MTU")
|
||||||
|
defaultIPv4MTU, err := findDefaultIPv4RouteMTU(netlinker)
|
||||||
|
require.NoError(t, err, "finding default IPv4 route MTU")
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(t.Context())
|
||||||
|
|
||||||
|
const family = syscall.AF_INET
|
||||||
|
fd, stop, err := startRawSocket(family)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
const ipv4 = true
|
||||||
|
tracker := newTracker(fd, ipv4)
|
||||||
|
trackerCh := make(chan error)
|
||||||
|
go func() {
|
||||||
|
trackerCh <- tracker.listen(ctx)
|
||||||
|
}()
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
stop()
|
||||||
|
cancel() // stop listening
|
||||||
|
err = <-trackerCh
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
testCases := map[string]struct {
|
||||||
|
timeout time.Duration
|
||||||
|
dst func(t *testing.T) netip.AddrPort
|
||||||
|
mtu uint32
|
||||||
|
success bool
|
||||||
|
}{
|
||||||
|
"local_not_listening": {
|
||||||
|
timeout: time.Hour,
|
||||||
|
dst: func(t *testing.T) netip.AddrPort {
|
||||||
|
t.Helper()
|
||||||
|
port := reserveClosedPort(t)
|
||||||
|
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), port)
|
||||||
|
},
|
||||||
|
mtu: loopbackMTU,
|
||||||
|
success: true,
|
||||||
|
},
|
||||||
|
"remote_not_listening": {
|
||||||
|
timeout: 50 * time.Millisecond,
|
||||||
|
dst: func(_ *testing.T) netip.AddrPort {
|
||||||
|
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 12345)
|
||||||
|
},
|
||||||
|
mtu: defaultIPv4MTU,
|
||||||
|
},
|
||||||
|
"1.1.1.1:443": {
|
||||||
|
timeout: time.Second,
|
||||||
|
dst: func(_ *testing.T) netip.AddrPort {
|
||||||
|
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443)
|
||||||
|
},
|
||||||
|
mtu: defaultIPv4MTU,
|
||||||
|
success: true,
|
||||||
|
},
|
||||||
|
"1.1.1.1:80": {
|
||||||
|
timeout: time.Second,
|
||||||
|
dst: func(_ *testing.T) netip.AddrPort {
|
||||||
|
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 80)
|
||||||
|
},
|
||||||
|
mtu: defaultIPv4MTU,
|
||||||
|
success: true,
|
||||||
|
},
|
||||||
|
"8.8.8.8:443": {
|
||||||
|
timeout: time.Second,
|
||||||
|
dst: func(_ *testing.T) netip.AddrPort {
|
||||||
|
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 443)
|
||||||
|
},
|
||||||
|
mtu: defaultIPv4MTU,
|
||||||
|
success: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, testCase := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
ctx, cancel := context.WithTimeout(t.Context(), testCase.timeout)
|
||||||
|
defer cancel()
|
||||||
|
dst := testCase.dst(t)
|
||||||
|
err := runTest(ctx, fd, tracker, dst, testCase.mtu)
|
||||||
|
if testCase.success {
|
||||||
|
require.NoError(t, err)
|
||||||
|
} else {
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var errRouteNotFound = errors.New("route not found")
|
||||||
|
|
||||||
|
func findLoopbackMTU(netlinker *netlink.NetLink) (mtu uint32, err error) {
|
||||||
|
routes, err := netlinker.RouteList(netlink.FamilyV4)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("getting routes list: %w", err)
|
||||||
|
}
|
||||||
|
for _, route := range routes {
|
||||||
|
if route.Dst.IsValid() && route.Dst.Addr().IsLoopback() {
|
||||||
|
link, err := netlinker.LinkByIndex(route.LinkIndex)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("getting link by index: %w", err)
|
||||||
|
}
|
||||||
|
// Quirk: make sure it is maximum 65535, and not i.e. 65536
|
||||||
|
// or the IP header 16 bits will fail to fit that packet length value.
|
||||||
|
const maxMTU = 65535
|
||||||
|
return min(link.MTU, maxMTU), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0, fmt.Errorf("%w: no loopback route found", errRouteNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func findDefaultIPv4RouteMTU(netlinker *netlink.NetLink) (mtu uint32, err error) {
|
||||||
|
noopLogger := &noopLogger{}
|
||||||
|
routing := routing.New(netlinker, noopLogger)
|
||||||
|
defaultRoutes, err := routing.DefaultRoutes()
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("getting default routes: %w", err)
|
||||||
|
}
|
||||||
|
for _, route := range defaultRoutes {
|
||||||
|
if route.Family != netlink.FamilyV4 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
link, err := netlinker.LinkByName(defaultRoutes[0].NetInterface)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("getting link by name: %w", err)
|
||||||
|
}
|
||||||
|
return link.MTU, nil
|
||||||
|
}
|
||||||
|
return 0, fmt.Errorf("%w: no default route found", errRouteNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func reserveClosedPort(t *testing.T) (port uint16) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM, syscall.IPPROTO_TCP)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err := syscall.Close(fd)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
addr := &syscall.SockaddrInet4{
|
||||||
|
Port: 0,
|
||||||
|
Addr: [4]byte{127, 0, 0, 1},
|
||||||
|
}
|
||||||
|
|
||||||
|
err = syscall.Bind(fd, addr)
|
||||||
|
if err != nil {
|
||||||
|
_ = syscall.Close(fd)
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sockAddr, err := syscall.Getsockname(fd)
|
||||||
|
if err != nil {
|
||||||
|
_ = syscall.Close(fd)
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sockAddr4, ok := sockAddr.(*syscall.SockaddrInet4)
|
||||||
|
if !ok {
|
||||||
|
_ = syscall.Close(fd)
|
||||||
|
t.Fatal("not an IPv4 address")
|
||||||
|
}
|
||||||
|
|
||||||
|
return uint16(sockAddr4.Port) //nolint:gosec
|
||||||
|
}
|
||||||
|
|
||||||
|
type noopLogger struct{}
|
||||||
|
|
||||||
|
func (l *noopLogger) Patch(_ ...log.Option) {}
|
||||||
|
func (l *noopLogger) Debug(_ string) {}
|
||||||
|
func (l *noopLogger) Debugf(_ string, _ ...any) {}
|
||||||
|
func (l *noopLogger) Info(_ string) {}
|
||||||
|
func (l *noopLogger) Warn(_ string) {}
|
||||||
|
func (l *noopLogger) Error(_ string) {}
|
||||||
@@ -0,0 +1,28 @@
|
|||||||
|
//go:build linux || darwin
|
||||||
|
|
||||||
|
package tcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// fileDescriptor is a platform-independent type for socket file descriptors.
|
||||||
|
type fileDescriptor int
|
||||||
|
|
||||||
|
func sendTo(fd fileDescriptor, p []byte, flags int, to syscall.Sockaddr) (err error) {
|
||||||
|
return syscall.Sendto(int(fd), p, flags, to)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setSocketTimeout(fd fileDescriptor, timeout time.Duration) (err error) {
|
||||||
|
timeval := syscall.NsecToTimeval(timeout.Nanoseconds())
|
||||||
|
return syscall.SetsockoptTimeval(int(fd), syscall.SOL_SOCKET, syscall.SO_RCVTIMEO, &timeval)
|
||||||
|
}
|
||||||
|
|
||||||
|
func recvFrom(fd fileDescriptor, p []byte, flags int) (n int, from syscall.Sockaddr, err error) {
|
||||||
|
return syscall.Recvfrom(int(fd), p, flags)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setNonBlock(fd int) error {
|
||||||
|
return syscall.SetNonblock(fd, true)
|
||||||
|
}
|
||||||
@@ -0,0 +1,7 @@
|
|||||||
|
//go:build !linux && !windows
|
||||||
|
|
||||||
|
package tcp
|
||||||
|
|
||||||
|
func setMTUDiscovery(fd int) error {
|
||||||
|
panic("not implemented")
|
||||||
|
}
|
||||||
@@ -0,0 +1,37 @@
|
|||||||
|
package tcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
type fileDescriptor syscall.Handle
|
||||||
|
|
||||||
|
func sendTo(fd fileDescriptor, p []byte, flags int, to syscall.Sockaddr) (err error) {
|
||||||
|
return syscall.Sendto(syscall.Handle(fd), p, flags, to)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setSocketTimeout(fd fileDescriptor, timeout time.Duration) (err error) {
|
||||||
|
timeval := int(timeout.Milliseconds())
|
||||||
|
return syscall.SetsockoptInt(syscall.Handle(fd), syscall.SOL_SOCKET, windows.SO_RCVTIMEO, timeval)
|
||||||
|
}
|
||||||
|
|
||||||
|
func recvFrom(fd fileDescriptor, p []byte, flags int) (n int, from syscall.Sockaddr, err error) {
|
||||||
|
return syscall.Recvfrom(syscall.Handle(fd), p, flags)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setMTUDiscovery(fd syscall.Handle) error {
|
||||||
|
panic("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func setNonBlock(fd syscall.Handle) error {
|
||||||
|
// Windows: Use ioctlsocket with FIONBIO
|
||||||
|
var arg uint32 = 1 // 1 to enable non-blocking mode
|
||||||
|
var bytesReturned uint32
|
||||||
|
const FIONBIO = 0x8004667e
|
||||||
|
return syscall.WSAIoctl(fd, FIONBIO, (*byte)(unsafe.Pointer(&arg)),
|
||||||
|
uint32(unsafe.Sizeof(arg)), nil, 0, &bytesReturned, nil, 0)
|
||||||
|
}
|
||||||
@@ -0,0 +1,124 @@
|
|||||||
|
package tcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||||
|
)
|
||||||
|
|
||||||
|
// For SYN, ack is 0.
|
||||||
|
// For SYN-ACK, ack is the sequence number + 1 of the SYN.
|
||||||
|
func makeTCPHeader(srcPort, dstPort uint16, seq, ack uint32, flags byte) []byte {
|
||||||
|
header := make([]byte, constants.BaseTCPHeaderLength)
|
||||||
|
binary.BigEndian.PutUint16(header[0:], srcPort)
|
||||||
|
binary.BigEndian.PutUint16(header[2:], dstPort)
|
||||||
|
binary.BigEndian.PutUint32(header[4:], seq)
|
||||||
|
binary.BigEndian.PutUint32(header[8:], ack)
|
||||||
|
//nolint:mnd
|
||||||
|
header[12] = byte(constants.BaseTCPHeaderLength) << 2 // data offset
|
||||||
|
header[13] = flags
|
||||||
|
// windowSize can be left to 5840 even for IPv6, it doesn't matter.
|
||||||
|
const windowSize = 5840
|
||||||
|
binary.BigEndian.PutUint16(header[14:], windowSize)
|
||||||
|
// header[16:17] is the checksum, set later
|
||||||
|
// header[18:19] is urgent pointer, not needed for our use case
|
||||||
|
return header
|
||||||
|
}
|
||||||
|
|
||||||
|
//nolint:mnd
|
||||||
|
func tcpChecksum(ipHeader, tcpHeader, payload []byte) uint16 {
|
||||||
|
var pseudoHeader []byte
|
||||||
|
isIPv6 := len(ipHeader) >= 40 && (ipHeader[0]>>4) == 6
|
||||||
|
if isIPv6 {
|
||||||
|
pseudoHeader = make([]byte, 40)
|
||||||
|
copy(pseudoHeader[0:16], ipHeader[8:24]) // Source Address
|
||||||
|
copy(pseudoHeader[16:32], ipHeader[24:40]) // Destination Address
|
||||||
|
totalLength := uint32(len(tcpHeader) + len(payload)) //nolint:gosec
|
||||||
|
binary.BigEndian.PutUint32(pseudoHeader[32:], totalLength)
|
||||||
|
pseudoHeader[39] = 6 // Next Header (TCP)
|
||||||
|
} else {
|
||||||
|
pseudoHeader = make([]byte, 12)
|
||||||
|
copy(pseudoHeader[0:4], ipHeader[12:16])
|
||||||
|
copy(pseudoHeader[4:8], ipHeader[16:20])
|
||||||
|
pseudoHeader[9] = 6
|
||||||
|
totalLength := uint16(len(tcpHeader) + len(payload)) //nolint:gosec
|
||||||
|
binary.BigEndian.PutUint16(pseudoHeader[10:], totalLength)
|
||||||
|
}
|
||||||
|
|
||||||
|
sum := uint32(0)
|
||||||
|
for _, slice := range [][]byte{pseudoHeader, tcpHeader, payload} {
|
||||||
|
for i := 0; i < len(slice)-1; i += 2 {
|
||||||
|
sum += uint32(binary.BigEndian.Uint16(slice[i : i+2]))
|
||||||
|
}
|
||||||
|
if len(slice)%2 != 0 {
|
||||||
|
sum += uint32(slice[len(slice)-1]) << 8
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (sum >> 16) > 0 {
|
||||||
|
sum = (sum & 0xFFFF) + (sum >> 16)
|
||||||
|
}
|
||||||
|
return ^uint16(sum) //nolint:gosec
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
tcpFlagsOffset = 13
|
||||||
|
rstFlag byte = 0x04
|
||||||
|
synFlag byte = 0x02
|
||||||
|
ackFlag byte = 0x10
|
||||||
|
pshFlag byte = 0x08
|
||||||
|
)
|
||||||
|
|
||||||
|
type packetType uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
packetTypeSYN packetType = iota + 1
|
||||||
|
packetTypeSYNACK
|
||||||
|
packetTypeACK
|
||||||
|
packetTypeRST
|
||||||
|
)
|
||||||
|
|
||||||
|
func (p packetType) String() string {
|
||||||
|
switch p {
|
||||||
|
case packetTypeSYN:
|
||||||
|
return "SYN"
|
||||||
|
case packetTypeSYNACK:
|
||||||
|
return "SYN-ACK"
|
||||||
|
case packetTypeACK:
|
||||||
|
return "ACK"
|
||||||
|
case packetTypeRST:
|
||||||
|
return "RST"
|
||||||
|
default:
|
||||||
|
panic("unknown packet type")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
errTCPHeaderTooShort = errors.New("TCP header is too short")
|
||||||
|
errTCPPacketTypeUnknown = errors.New("TCP packet type is unknown")
|
||||||
|
)
|
||||||
|
|
||||||
|
// parseTCPHeader parses some elements from the TCP header.
|
||||||
|
func parseTCPHeader(header []byte) (packetType packetType, seq, ack uint32, err error) {
|
||||||
|
if len(header) < int(constants.BaseTCPHeaderLength) {
|
||||||
|
return 0, 0, 0, fmt.Errorf("%w: %d bytes", errTCPHeaderTooShort, len(header))
|
||||||
|
}
|
||||||
|
flags := header[tcpFlagsOffset]
|
||||||
|
switch {
|
||||||
|
case (flags&synFlag) != 0 && (flags&ackFlag) == 0:
|
||||||
|
packetType = packetTypeSYN
|
||||||
|
case (flags&synFlag) != 0 && (flags&ackFlag) != 0:
|
||||||
|
packetType = packetTypeSYNACK
|
||||||
|
case (flags & rstFlag) != 0:
|
||||||
|
packetType = packetTypeRST
|
||||||
|
case (flags & ackFlag) != 0:
|
||||||
|
packetType = packetTypeACK
|
||||||
|
default:
|
||||||
|
return 0, 0, 0, fmt.Errorf("%w: flags are 0x%02x", errTCPPacketTypeUnknown, flags)
|
||||||
|
}
|
||||||
|
|
||||||
|
seq = binary.BigEndian.Uint32(header[4:8])
|
||||||
|
ack = binary.BigEndian.Uint32(header[8:12])
|
||||||
|
return packetType, seq, ack, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,134 @@
|
|||||||
|
package tcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||||
|
)
|
||||||
|
|
||||||
|
type tracker struct {
|
||||||
|
fd fileDescriptor
|
||||||
|
ipv4 bool
|
||||||
|
mutex sync.RWMutex
|
||||||
|
portsToDispatch map[uint32]dispatch
|
||||||
|
}
|
||||||
|
|
||||||
|
type dispatch struct {
|
||||||
|
replyCh chan<- []byte
|
||||||
|
abort <-chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTracker(fd fileDescriptor, ipv4 bool) *tracker {
|
||||||
|
return &tracker{
|
||||||
|
fd: fd,
|
||||||
|
ipv4: ipv4,
|
||||||
|
portsToDispatch: make(map[uint32]dispatch),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tracker) constructKey(localPort, remotePort uint16) uint32 {
|
||||||
|
buf := make([]byte, 4) //nolint:mnd
|
||||||
|
binary.BigEndian.PutUint16(buf[0:2], localPort)
|
||||||
|
binary.BigEndian.PutUint16(buf[2:4], remotePort)
|
||||||
|
return binary.BigEndian.Uint32(buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tracker) register(localPort, remotePort uint16,
|
||||||
|
ch chan<- []byte, abort <-chan struct{},
|
||||||
|
) {
|
||||||
|
key := t.constructKey(localPort, remotePort)
|
||||||
|
t.mutex.Lock()
|
||||||
|
defer t.mutex.Unlock()
|
||||||
|
t.portsToDispatch[key] = dispatch{
|
||||||
|
replyCh: ch,
|
||||||
|
abort: abort,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tracker) unregister(localPort, remotePort uint16) {
|
||||||
|
key := t.constructKey(localPort, remotePort)
|
||||||
|
t.mutex.Lock()
|
||||||
|
defer t.mutex.Unlock()
|
||||||
|
delete(t.portsToDispatch, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// listen listens for incoming TCP packets and dispatches them to the
|
||||||
|
// correct channel based on the source and destination port.
|
||||||
|
// If the context has a deadline associated, this one is used on the socket.
|
||||||
|
// Note it returns a nil error on context cancellation.
|
||||||
|
func (t *tracker) listen(ctx context.Context) error {
|
||||||
|
deadline, hasDeadline := ctx.Deadline()
|
||||||
|
for ctx.Err() == nil {
|
||||||
|
if hasDeadline {
|
||||||
|
remaining := time.Until(deadline)
|
||||||
|
if remaining <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
err := setSocketTimeout(t.fd, remaining)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("setting socket receive timeout: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
reply := make([]byte, constants.MaxEthernetFrameSize)
|
||||||
|
n, _, err := recvFrom(t.fd, reply, 0)
|
||||||
|
if err != nil {
|
||||||
|
switch {
|
||||||
|
case errors.Is(err, syscall.EAGAIN),
|
||||||
|
errors.Is(err, syscall.EWOULDBLOCK):
|
||||||
|
pollSleep(ctx)
|
||||||
|
continue
|
||||||
|
case ctx.Err() != nil:
|
||||||
|
// context canceled, stop listening so exit cleanly with no error
|
||||||
|
return nil //nolint:nilerr
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("receiving on socket: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
reply = reply[:n]
|
||||||
|
|
||||||
|
if t.ipv4 {
|
||||||
|
var ok bool
|
||||||
|
reply, ok = stripIPv4Header(reply)
|
||||||
|
if !ok {
|
||||||
|
continue // not an IPv4 packet
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const minTCPHeaderLength = 20
|
||||||
|
if len(reply) < minTCPHeaderLength {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
srcPort := binary.BigEndian.Uint16(reply[0:2])
|
||||||
|
dstPort := binary.BigEndian.Uint16(reply[2:4])
|
||||||
|
key := t.constructKey(dstPort, srcPort)
|
||||||
|
t.mutex.RLock()
|
||||||
|
dispatch, exists := t.portsToDispatch[key]
|
||||||
|
t.mutex.RUnlock()
|
||||||
|
if !exists {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case dispatch.replyCh <- reply:
|
||||||
|
case <-dispatch.abort:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func pollSleep(ctx context.Context) {
|
||||||
|
const sleepBetweenPolls = 10 * time.Millisecond
|
||||||
|
timer := time.NewTimer(sleepBetweenPolls)
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
timer.Stop()
|
||||||
|
case <-timer.C:
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,36 @@
|
|||||||
|
package test
|
||||||
|
|
||||||
|
import "math"
|
||||||
|
|
||||||
|
// MakeMTUsToTest determines a slice of MTU values to test
|
||||||
|
// between minMTU and maxMTU inclusive. It creates an MTU
|
||||||
|
// slice of length up to 11 MTUs such that:
|
||||||
|
// - the first element is the minMTU
|
||||||
|
// - the last element is the maxMTU
|
||||||
|
// - elements in-between are separated as close to each other
|
||||||
|
// The number 11 is chosen to find the final MTU in 3 searches,
|
||||||
|
// with a total search space of 1728 MTUs which is enough;
|
||||||
|
// to find it in 2 searches requires 37 parallel queries which
|
||||||
|
// could be blocked by firewalls.
|
||||||
|
func MakeMTUsToTest(minMTU, maxMTU uint32) (mtus []uint32) {
|
||||||
|
const mtusLength = 11 // find the final MTU in 3 searches
|
||||||
|
diff := maxMTU - minMTU
|
||||||
|
switch {
|
||||||
|
case minMTU > maxMTU:
|
||||||
|
panic("minMTU > maxMTU")
|
||||||
|
case diff <= mtusLength:
|
||||||
|
mtus = make([]uint32, 0, diff)
|
||||||
|
for mtu := minMTU; mtu <= maxMTU; mtu++ {
|
||||||
|
mtus = append(mtus, mtu)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
step := float64(diff) / float64(mtusLength-1)
|
||||||
|
mtus = make([]uint32, 0, mtusLength)
|
||||||
|
for mtu := float64(minMTU); len(mtus) < mtusLength-1; mtu += step {
|
||||||
|
mtus = append(mtus, uint32(math.Round(mtu)))
|
||||||
|
}
|
||||||
|
mtus = append(mtus, maxMTU) // last element is the maxMTU
|
||||||
|
}
|
||||||
|
|
||||||
|
return mtus
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package pmtud
|
package test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
@@ -6,7 +6,7 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_makeMTUsToTest(t *testing.T) {
|
func Test_MakeMTUsToTest(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
@@ -48,7 +48,7 @@ func Test_makeMTUsToTest(t *testing.T) {
|
|||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
mtus := makeMTUsToTest(testCase.minMTU, testCase.maxMTU)
|
mtus := MakeMTUsToTest(testCase.minMTU, testCase.maxMTU)
|
||||||
assert.Equal(t, testCase.mtus, mtus)
|
assert.Equal(t, testCase.mtus, mtus)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -0,0 +1,40 @@
|
|||||||
|
package pmtud
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/qdm12/gluetun/internal/constants"
|
||||||
|
"github.com/qdm12/gluetun/internal/constants/vpn"
|
||||||
|
pconstants "github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MaxTheoreticalVPNMTU returns the theoretical maximum MTU for a VPN tunnel
|
||||||
|
// given the VPN type, network protocol, and VPN gateway IP address.
|
||||||
|
// This is notably useful to skip testing MTU values higher than this value.
|
||||||
|
// The function panics if the network or VPN type is unknown.
|
||||||
|
func MaxTheoreticalVPNMTU(vpnType, network string, vpnGateway netip.Addr) uint32 {
|
||||||
|
const physicalLinkMTU = pconstants.MaxEthernetFrameSize
|
||||||
|
vpnLinkMTU := physicalLinkMTU
|
||||||
|
if vpnGateway.Is4() {
|
||||||
|
vpnLinkMTU -= pconstants.IPv4HeaderLength
|
||||||
|
} else {
|
||||||
|
vpnLinkMTU -= pconstants.IPv6HeaderLength
|
||||||
|
}
|
||||||
|
switch network {
|
||||||
|
case constants.TCP:
|
||||||
|
vpnLinkMTU -= pconstants.BaseTCPHeaderLength
|
||||||
|
case constants.UDP:
|
||||||
|
vpnLinkMTU -= pconstants.UDPHeaderLength
|
||||||
|
default:
|
||||||
|
panic("unknown network protocol: " + network)
|
||||||
|
}
|
||||||
|
switch vpnType {
|
||||||
|
case vpn.Wireguard:
|
||||||
|
vpnLinkMTU -= pconstants.WireguardHeaderLength
|
||||||
|
case vpn.OpenVPN:
|
||||||
|
vpnLinkMTU -= pconstants.OpenVPNHeaderMaxLength
|
||||||
|
default:
|
||||||
|
panic("unknown VPN type: " + vpnType)
|
||||||
|
}
|
||||||
|
return vpnLinkMTU
|
||||||
|
}
|
||||||
@@ -16,7 +16,17 @@ func BuildWireguardSettings(connection models.Connection,
|
|||||||
settings.PreSharedKey = *userSettings.PreSharedKey
|
settings.PreSharedKey = *userSettings.PreSharedKey
|
||||||
settings.InterfaceName = userSettings.Interface
|
settings.InterfaceName = userSettings.Interface
|
||||||
settings.Implementation = userSettings.Implementation
|
settings.Implementation = userSettings.Implementation
|
||||||
settings.MTU = userSettings.MTU
|
if *userSettings.MTU > 0 {
|
||||||
|
settings.MTU = *userSettings.MTU
|
||||||
|
} else {
|
||||||
|
// The default is 1320 which is NOT the wireguard-go default
|
||||||
|
// of 1420 because this impacts bandwidth a lot on some
|
||||||
|
// VPN providers, see https://github.com/qdm12/gluetun/issues/1650.
|
||||||
|
// It has been lowered to 1320 following quite a bit of
|
||||||
|
// investigation in the issue: https://github.com/qdm12/gluetun/issues/2533.
|
||||||
|
const defaultMTU = 1320
|
||||||
|
settings.MTU = defaultMTU
|
||||||
|
}
|
||||||
settings.IPv6 = &ipv6Supported
|
settings.IPv6 = &ipv6Supported
|
||||||
|
|
||||||
const rulePriority = 101 // 100 is to receive external connections
|
const rulePriority = 101 // 100 is to receive external connections
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ func Test_BuildWireguardSettings(t *testing.T) {
|
|||||||
ipv6Supported bool
|
ipv6Supported bool
|
||||||
settings wireguard.Settings
|
settings wireguard.Settings
|
||||||
}{
|
}{
|
||||||
"some settings": {
|
"some_settings": {
|
||||||
connection: models.Connection{
|
connection: models.Connection{
|
||||||
IP: netip.AddrFrom4([4]byte{1, 2, 3, 4}),
|
IP: netip.AddrFrom4([4]byte{1, 2, 3, 4}),
|
||||||
Port: 51821,
|
Port: 51821,
|
||||||
@@ -41,6 +41,7 @@ func Test_BuildWireguardSettings(t *testing.T) {
|
|||||||
},
|
},
|
||||||
PersistentKeepaliveInterval: ptrTo(time.Hour),
|
PersistentKeepaliveInterval: ptrTo(time.Hour),
|
||||||
Interface: "wg1",
|
Interface: "wg1",
|
||||||
|
MTU: ptrTo(uint32(1000)),
|
||||||
},
|
},
|
||||||
ipv6Supported: false,
|
ipv6Supported: false,
|
||||||
settings: wireguard.Settings{
|
settings: wireguard.Settings{
|
||||||
@@ -58,6 +59,7 @@ func Test_BuildWireguardSettings(t *testing.T) {
|
|||||||
PersistentKeepaliveInterval: time.Hour,
|
PersistentKeepaliveInterval: time.Hour,
|
||||||
RulePriority: 101,
|
RulePriority: 101,
|
||||||
IPv6: boolPtr(false),
|
IPv6: boolPtr(false),
|
||||||
|
MTU: 1000,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -47,3 +47,26 @@ func (r *Routing) VPNLocalGatewayIP(vpnIntf string) (ip netip.Addr, err error) {
|
|||||||
}
|
}
|
||||||
return ip, fmt.Errorf("%w: in %d routes", ErrVPNLocalGatewayIPNotFound, len(routes))
|
return ip, fmt.Errorf("%w: in %d routes", ErrVPNLocalGatewayIPNotFound, len(routes))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var ErrVPNRouteNotFound = errors.New("VPN route not found")
|
||||||
|
|
||||||
|
func (r *Routing) VPNRoute(vpnIntf string) (route netlink.Route, err error) {
|
||||||
|
vpnLink, err := r.netLinker.LinkByName(vpnIntf)
|
||||||
|
if err != nil {
|
||||||
|
return route, fmt.Errorf("finding link %s: %w", vpnIntf, err)
|
||||||
|
}
|
||||||
|
vpnLinkIndex := vpnLink.Index
|
||||||
|
|
||||||
|
routes, err := r.netLinker.RouteList(netlink.FamilyAll)
|
||||||
|
if err != nil {
|
||||||
|
return route, fmt.Errorf("listing routes: %w", err)
|
||||||
|
}
|
||||||
|
for _, route := range routes {
|
||||||
|
if route.LinkIndex == vpnLinkIndex &&
|
||||||
|
!route.Dst.IsValid() {
|
||||||
|
return route, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return route, fmt.Errorf("%w: for interface %s in %d routes",
|
||||||
|
ErrVPNRouteNotFound, vpnIntf, len(routes))
|
||||||
|
}
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ type Firewall interface {
|
|||||||
|
|
||||||
type Routing interface {
|
type Routing interface {
|
||||||
VPNLocalGatewayIP(vpnInterface string) (gateway netip.Addr, err error)
|
VPNLocalGatewayIP(vpnInterface string) (gateway netip.Addr, err error)
|
||||||
|
VPNRoute(vpnIntf string) (route netlink.Route, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type PortForward interface {
|
type PortForward interface {
|
||||||
@@ -67,6 +68,7 @@ type NetLinker interface {
|
|||||||
type Router interface {
|
type Router interface {
|
||||||
RouteList(family uint8) (routes []netlink.Route, err error)
|
RouteList(family uint8) (routes []netlink.Route, err error)
|
||||||
RouteAdd(route netlink.Route) error
|
RouteAdd(route netlink.Route) error
|
||||||
|
RouteReplace(route netlink.Route) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type Ruler interface {
|
type Ruler interface {
|
||||||
|
|||||||
+7
-1
@@ -47,7 +47,13 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
tunnelUpData := tunnelUpData{
|
tunnelUpData := tunnelUpData{
|
||||||
vpnType: settings.Type,
|
pmtud: tunnelUpPMTUDData{
|
||||||
|
enabled: settings.Type != vpn.Wireguard || *settings.Wireguard.MTU == 0,
|
||||||
|
vpnType: settings.Type,
|
||||||
|
network: connection.Protocol,
|
||||||
|
icmpAddrs: settings.PMTUD.ICMPAddresses,
|
||||||
|
tcpAddrs: settings.PMTUD.TCPAddresses,
|
||||||
|
},
|
||||||
serverIP: connection.IP,
|
serverIP: connection.IP,
|
||||||
serverName: connection.ServerName,
|
serverName: connection.ServerName,
|
||||||
canPortForward: connection.PortForward,
|
canPortForward: connection.PortForward,
|
||||||
|
|||||||
+64
-32
@@ -2,7 +2,6 @@ package vpn
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
@@ -10,6 +9,7 @@ import (
|
|||||||
"github.com/qdm12/dns/v2/pkg/check"
|
"github.com/qdm12/dns/v2/pkg/check"
|
||||||
"github.com/qdm12/gluetun/internal/constants"
|
"github.com/qdm12/gluetun/internal/constants"
|
||||||
"github.com/qdm12/gluetun/internal/pmtud"
|
"github.com/qdm12/gluetun/internal/pmtud"
|
||||||
|
pconstants "github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||||
"github.com/qdm12/gluetun/internal/version"
|
"github.com/qdm12/gluetun/internal/version"
|
||||||
"github.com/qdm12/log"
|
"github.com/qdm12/log"
|
||||||
)
|
)
|
||||||
@@ -17,9 +17,7 @@ import (
|
|||||||
type tunnelUpData struct {
|
type tunnelUpData struct {
|
||||||
// Healthcheck
|
// Healthcheck
|
||||||
serverIP netip.Addr
|
serverIP netip.Addr
|
||||||
// vpnType is used for path MTU discovery to find the protocol overhead.
|
pmtud tunnelUpPMTUDData
|
||||||
// It can be "wireguard" or "openvpn".
|
|
||||||
vpnType string
|
|
||||||
// Port forwarding
|
// Port forwarding
|
||||||
vpnIntf string
|
vpnIntf string
|
||||||
serverName string // used for PIA
|
serverName string // used for PIA
|
||||||
@@ -29,6 +27,23 @@ type tunnelUpData struct {
|
|||||||
portForwarder PortForwarder
|
portForwarder PortForwarder
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type tunnelUpPMTUDData struct {
|
||||||
|
// enabled is notably false if the user specifies a custom MTU.
|
||||||
|
enabled bool
|
||||||
|
// vpnType is used to find the maximum VPN header overhead.
|
||||||
|
// It can be [vpn.Wireguard] or [vpn.OpenVPN].
|
||||||
|
vpnType string
|
||||||
|
// network is used to find the network level header overhead.
|
||||||
|
// It can be [constants.UDP] or [constants.TCP].
|
||||||
|
network string
|
||||||
|
// icmpAddrs is the list of addresses to use for ICMP path MTU discovery.
|
||||||
|
// Each address should handle ICMP packets for PMTUD to work.
|
||||||
|
icmpAddrs []netip.Addr
|
||||||
|
// tcpAddrs is the list of addresses to use for TCP path MTU discovery.
|
||||||
|
// Each address should have a listening TCP server on the port specified.
|
||||||
|
tcpAddrs []netip.AddrPort
|
||||||
|
}
|
||||||
|
|
||||||
func (l *Loop) onTunnelUp(ctx, loopCtx context.Context, data tunnelUpData) {
|
func (l *Loop) onTunnelUp(ctx, loopCtx context.Context, data tunnelUpData) {
|
||||||
l.client.CloseIdleConnections()
|
l.client.CloseIdleConnections()
|
||||||
|
|
||||||
@@ -39,11 +54,14 @@ func (l *Loop) onTunnelUp(ctx, loopCtx context.Context, data tunnelUpData) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
mtuLogger := l.logger.New(log.SetComponent("MTU discovery"))
|
if data.pmtud.enabled {
|
||||||
err := updateToMaxMTU(ctx, data.vpnIntf, data.vpnType,
|
mtuLogger := l.logger.New(log.SetComponent("MTU discovery"))
|
||||||
l.netLinker, l.routing, mtuLogger)
|
err := updateToMaxMTU(ctx, data.vpnIntf, data.pmtud.vpnType,
|
||||||
if err != nil {
|
data.pmtud.network, data.pmtud.icmpAddrs, data.pmtud.tcpAddrs,
|
||||||
mtuLogger.Error(err.Error())
|
l.netLinker, l.routing, mtuLogger)
|
||||||
|
if err != nil {
|
||||||
|
mtuLogger.Error(err.Error())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
icmpTargetIPs := l.healthSettings.ICMPTargetIPs
|
icmpTargetIPs := l.healthSettings.ICMPTargetIPs
|
||||||
@@ -136,12 +154,11 @@ func (l *Loop) restartVPN(ctx context.Context, healthErr error) {
|
|||||||
_, _ = l.ApplyStatus(ctx, constants.Running)
|
_, _ = l.ApplyStatus(ctx, constants.Running)
|
||||||
}
|
}
|
||||||
|
|
||||||
var errVPNTypeUnknown = errors.New("unknown VPN type")
|
|
||||||
|
|
||||||
func updateToMaxMTU(ctx context.Context, vpnInterface string,
|
func updateToMaxMTU(ctx context.Context, vpnInterface string,
|
||||||
vpnType string, netlinker NetLinker, routing Routing, logger *log.Logger,
|
vpnType, network string, icmpAddrs []netip.Addr, tcpAddrs []netip.AddrPort,
|
||||||
|
netlinker NetLinker, routing Routing, logger *log.Logger,
|
||||||
) error {
|
) error {
|
||||||
logger.Info("finding maximum MTU, this can take up to 4 seconds")
|
logger.Info("finding maximum MTU, this can take up to 6 seconds")
|
||||||
|
|
||||||
vpnGatewayIP, err := routing.VPNLocalGatewayIP(vpnInterface)
|
vpnGatewayIP, err := routing.VPNLocalGatewayIP(vpnInterface)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -155,18 +172,7 @@ func updateToMaxMTU(ctx context.Context, vpnInterface string,
|
|||||||
|
|
||||||
originalMTU := link.MTU
|
originalMTU := link.MTU
|
||||||
|
|
||||||
// Note: no point testing for an MTU of 1500, it will never work due to the VPN
|
vpnLinkMTU := pmtud.MaxTheoreticalVPNMTU(vpnType, network, vpnGatewayIP)
|
||||||
// protocol overhead, so start lower than 1500 according to the protocol used.
|
|
||||||
const physicalLinkMTU uint32 = 1500
|
|
||||||
vpnLinkMTU := physicalLinkMTU
|
|
||||||
switch vpnType {
|
|
||||||
case "wireguard":
|
|
||||||
vpnLinkMTU -= 60 // Wireguard overhead
|
|
||||||
case "openvpn":
|
|
||||||
vpnLinkMTU -= 41 // OpenVPN overhead
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("%w: %q", errVPNTypeUnknown, vpnType)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Setting the VPN link MTU to 1500 might interrupt the connection until
|
// Setting the VPN link MTU to 1500 might interrupt the connection until
|
||||||
// the new MTU is set again, but this is necessary to find the highest valid MTU.
|
// the new MTU is set again, but this is necessary to find the highest valid MTU.
|
||||||
@@ -178,16 +184,14 @@ func updateToMaxMTU(ctx context.Context, vpnInterface string,
|
|||||||
}
|
}
|
||||||
|
|
||||||
const pingTimeout = time.Second
|
const pingTimeout = time.Second
|
||||||
vpnLinkMTU, err = pmtud.PathMTUDiscover(ctx, vpnGatewayIP, vpnLinkMTU, pingTimeout, logger)
|
vpnLinkMTU, err = pmtud.PathMTUDiscover(ctx, icmpAddrs, tcpAddrs,
|
||||||
switch {
|
vpnLinkMTU, pingTimeout, logger)
|
||||||
case err == nil:
|
if err != nil {
|
||||||
logger.Infof("setting VPN interface %s MTU to maximum valid MTU %d", vpnInterface, vpnLinkMTU)
|
|
||||||
case errors.Is(err, pmtud.ErrMTUNotFound) || errors.Is(err, pmtud.ErrICMPNotPermitted):
|
|
||||||
vpnLinkMTU = originalMTU
|
vpnLinkMTU = originalMTU
|
||||||
logger.Infof("reverting VPN interface %s MTU to %d (due to: %s)",
|
logger.Infof("reverting VPN interface %s MTU to %d (due to: %s)",
|
||||||
vpnInterface, originalMTU, err)
|
vpnInterface, originalMTU, err)
|
||||||
default:
|
} else {
|
||||||
return fmt.Errorf("path MTU discovering: %w", err)
|
logger.Infof("setting VPN interface %s MTU to maximum valid MTU %d", vpnInterface, vpnLinkMTU)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = netlinker.LinkSetMTU(link.Index, vpnLinkMTU)
|
err = netlinker.LinkSetMTU(link.Index, vpnLinkMTU)
|
||||||
@@ -195,5 +199,33 @@ func updateToMaxMTU(ctx context.Context, vpnInterface string,
|
|||||||
return fmt.Errorf("setting VPN interface %s MTU to %d: %w", vpnInterface, vpnLinkMTU, err)
|
return fmt.Errorf("setting VPN interface %s MTU to %d: %w", vpnInterface, vpnLinkMTU, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = setTCPMSSOnVPNRoute(vpnInterface, vpnLinkMTU, routing, netlinker)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("setting safe TCP MSS for MTU %d: %w", vpnLinkMTU, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setTCPMSSOnVPNRoute(vpnIntf string, mtu uint32,
|
||||||
|
routing Routing, netlinker NetLinker,
|
||||||
|
) error {
|
||||||
|
route, err := routing.VPNRoute(vpnIntf)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("getting VPN route: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ipHeaderLength := pconstants.IPv4HeaderLength
|
||||||
|
if route.Dst.Addr().Is6() {
|
||||||
|
ipHeaderLength = pconstants.IPv6HeaderLength
|
||||||
|
}
|
||||||
|
const mysteriousOverhead = 20 // most likely TCP options, such as the 12B of timestamps
|
||||||
|
overhead := ipHeaderLength + pconstants.BaseTCPHeaderLength + mysteriousOverhead
|
||||||
|
mss := mtu - overhead
|
||||||
|
route.AdvMSS = mss
|
||||||
|
err = netlinker.RouteReplace(route)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("replacing VPN route with MSS changed to %d: %w", mss, err)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user