Compare commits

..

6 Commits

Author SHA1 Message Date
Quentin McGaw 7f22fb3276 fix(protonvpn): support port 51820 for UDP OpenVPN 2026-02-11 14:13:34 +00:00
Quentin McGaw 6909a0c123 fix(healthcheck): prevent race condition and fix #3096 (#3123) 2026-02-11 14:12:20 +00:00
Quentin McGaw 3e1f48932a fix(openvpn): only log openvpn version corresponding to OPENVPN_VERSION 2026-02-11 14:12:08 +00:00
Chris Duck 50744852c5 fix(protonvpn): update OpenVPN settings (#3120) 2026-02-11 14:11:57 +00:00
Quentin McGaw 09e52bc685 fix(httpproxy): remove info log when no Proxy-Authorization header is present 2026-02-11 14:11:46 +00:00
Quentin McGaw 857fe425ec fix(wireguard): fix detection of kernelspace wireguard 2026-02-11 14:11:36 +00:00
98 changed files with 8993 additions and 8961 deletions
+1 -4
View File
@@ -59,13 +59,10 @@ jobs:
- name: Run tests in test container
run: |
touch coverage.txt
docker run --rm --cap-add=NET_ADMIN --device /dev/net/tun \
docker run --rm --device /dev/net/tun \
-v "$(pwd)/coverage.txt:/tmp/gobuild/coverage.txt" \
test-container
- name: Verify dev cross platform compatibility
run: docker build --target xcompile .
- name: Build final image
run: docker build -t final-image .
-4
View File
@@ -46,10 +46,6 @@ RUN git init && \
git diff --exit-code && \
rm -rf .git/
FROM --platform=${BUILDPLATFORM} base AS xcompile
RUN GOOS=darwin go build -o /dev/null ./...
RUN GOOS=windows go build -o /dev/null ./...
FROM --platform=${BUILDPLATFORM} base AS build
ARG TARGETPLATFORM
ARG VERSION=unknown
+12 -14
View File
@@ -6,7 +6,6 @@ import (
"fmt"
"io/fs"
"net/http"
"net/netip"
"os"
"os/exec"
"os/signal"
@@ -394,7 +393,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
}
dnsLogger := logger.New(log.SetComponent("dns"))
dnsLooper, err := dns.NewLoop(allSettings.DNS, httpClient, firewallConf,
dnsLooper, err := dns.NewLoop(allSettings.DNS, httpClient,
dnsLogger)
if err != nil {
return fmt.Errorf("creating DNS loop: %w", err)
@@ -554,26 +553,26 @@ type netLinker interface {
Router
Ruler
Linker
IsWireguardSupported() (ok bool, err error)
IsWireguardSupported() bool
IsIPv6Supported() (ok bool, err error)
PatchLoggerLevel(level log.Level)
}
type Addresser interface {
AddrList(linkIndex uint32, family uint8) (
addresses []netip.Prefix, err error)
AddrReplace(linkIndex uint32, addr netip.Prefix) error
AddrList(link netlink.Link, family int) (
addresses []netlink.Addr, err error)
AddrReplace(link netlink.Link, addr netlink.Addr) error
}
type Router interface {
RouteList(family uint8) (routes []netlink.Route, err error)
RouteList(family int) (routes []netlink.Route, err error)
RouteAdd(route netlink.Route) error
RouteDel(route netlink.Route) error
RouteReplace(route netlink.Route) error
}
type Ruler interface {
RuleList(family uint8) (rules []netlink.Rule, err error)
RuleList(family int) (rules []netlink.Rule, err error)
RuleAdd(rule netlink.Rule) error
RuleDel(rule netlink.Rule) error
}
@@ -581,12 +580,11 @@ type Ruler interface {
type Linker interface {
LinkList() (links []netlink.Link, err error)
LinkByName(name string) (link netlink.Link, err error)
LinkByIndex(index uint32) (link netlink.Link, err error)
LinkAdd(link netlink.Link) (linkIndex uint32, err error)
LinkDel(linkIndex uint32) (err error)
LinkSetUp(linkIndex uint32) (err error)
LinkSetDown(linkIndex uint32) (err error)
LinkSetMTU(linkIndex, mtu uint32) error
LinkByIndex(index int) (link netlink.Link, err error)
LinkAdd(link netlink.Link) (linkIndex int, err error)
LinkDel(link netlink.Link) (err error)
LinkSetUp(link netlink.Link) (linkIndex int, err error)
LinkSetDown(link netlink.Link) (err error)
}
type clier interface {
+12 -11
View File
@@ -7,10 +7,8 @@ require (
github.com/breml/rootcerts v0.3.3
github.com/fatih/color v1.18.0
github.com/golang/mock v1.6.0
github.com/jsimonetti/rtnetlink v1.4.2
github.com/klauspost/compress v1.18.1
github.com/klauspost/pgzip v1.2.6
github.com/mdlayher/genetlink v1.3.2
github.com/pelletier/go-toml/v2 v2.2.4
github.com/qdm12/dns/v2 v2.0.0-rc10
github.com/qdm12/gosettings v0.4.4
@@ -21,11 +19,12 @@ require (
github.com/qdm12/ss-server v0.6.0
github.com/stretchr/testify v1.11.1
github.com/ulikunitz/xz v0.5.15
github.com/vishvananda/netlink v1.3.1
github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c
golang.org/x/net v0.49.0
golang.org/x/sys v0.40.0
golang.org/x/text v0.33.0
golang.org/x/net v0.47.0
golang.org/x/sys v0.38.0
golang.org/x/text v0.31.0
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
gopkg.in/ini.v1 v1.67.0
@@ -39,12 +38,13 @@ require (
github.com/cloudflare/circl v1.6.1 // indirect
github.com/cronokirby/saferith v0.33.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/google/go-cmp v0.7.0 // indirect
github.com/google/go-cmp v0.6.0 // indirect
github.com/josharian/native v1.1.0 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mdlayher/genetlink v1.3.2 // indirect
github.com/mdlayher/netlink v1.7.2 // indirect
github.com/mdlayher/socket v0.5.1 // indirect
github.com/mdlayher/socket v0.4.1 // indirect
github.com/miekg/dns v1.1.62 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/pkg/errors v0.9.1 // indirect
@@ -55,11 +55,12 @@ require (
github.com/prometheus/procfs v0.15.1 // indirect
github.com/qdm12/goservices v0.1.1-0.20251104135713-6bee97bd4978 // indirect
github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect
golang.org/x/crypto v0.47.0 // indirect
golang.org/x/mod v0.31.0 // indirect
golang.org/x/sync v0.19.0 // indirect
github.com/vishvananda/netns v0.0.5 // indirect
golang.org/x/crypto v0.45.0 // indirect
golang.org/x/mod v0.29.0 // indirect
golang.org/x/sync v0.18.0 // indirect
golang.org/x/time v0.3.0 // indirect
golang.org/x/tools v0.40.0 // indirect
golang.org/x/tools v0.38.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
google.golang.org/protobuf v1.35.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
+24 -22
View File
@@ -13,8 +13,6 @@ github.com/breml/rootcerts v0.3.3/go.mod h1:S/PKh+4d1HUn4HQovEB8hPJZO6pUZYrIhmXB
github.com/bwesterb/go-ristretto v1.2.0/go.mod h1:fUIoIZaG73pV5biE2Blr2xEzDoMj7NFEuV9ekS419A0=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cilium/ebpf v0.12.3 h1:8ht6F9MquybnY97at+VDZb3eQQr8ev79RueWeVaEcG4=
github.com/cilium/ebpf v0.12.3/go.mod h1:TctK1ivibvI3znr66ljgi4hqOT8EYQjz1KWBfb1UVgM=
github.com/cloudflare/circl v1.1.0/go.mod h1:prBCrKB9DV4poKZY1l9zBXg2QJY7mvgRvtMxxK7fi4I=
github.com/cloudflare/circl v1.6.1 h1:zqIqSPIndyBh1bjLVVDHMPpVKqp8Su/V+6MeDzzQBQ0=
github.com/cloudflare/circl v1.6.1/go.mod h1:uddAzsPgqdMAYatqJ0lsjX1oECcQLIlRpzZh3pJrofs=
@@ -28,12 +26,10 @@ github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA=
github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
github.com/jsimonetti/rtnetlink v1.4.2 h1:Df9w9TZ3npHTyDn0Ev9e1uzmN2odmXd0QX+J5GTEn90=
github.com/jsimonetti/rtnetlink v1.4.2/go.mod h1:92s6LJdE+1iOrw+F2/RO7LYI2Qd8pPpFNNUYW06gcoM=
github.com/klauspost/compress v1.18.1 h1:bcSGx7UbpBqMChDtsF28Lw6v/G94LPrrbMbdC3JH2co=
github.com/klauspost/compress v1.18.1/go.mod h1:ZQFFVG+MdnR0P+l6wpXgIL4NTtwiKIdBnrBd8Nrxr+0=
github.com/klauspost/pgzip v1.2.6 h1:8RXeL5crjEUFnR2/Sn6GJNWtSQ3Dk8pq4CL3jvdDyjU=
@@ -51,8 +47,8 @@ github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy5
github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o=
github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g=
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos=
github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ=
github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U=
github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA=
github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ=
github.com/miekg/dns v1.1.62/go.mod h1:mvDlcItzm+br7MToIKqkglaGhlFMHJ9DTNNWONWXbNQ=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
@@ -97,6 +93,10 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/ulikunitz/xz v0.5.15 h1:9DNdB5s+SgV3bQ2ApL10xRc35ck0DuIX/isZvIk+ubY=
github.com/ulikunitz/xz v0.5.15/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14=
github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0=
github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a h1:fZHgsYlfvtyqToslyjUt3VOPF4J7aK/3MPcK7xp3PDk=
github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a/go.mod h1:ul22v+Nro/R083muKhosV54bj5niojjWZvU8xrevuH4=
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
@@ -106,15 +106,15 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U
golang.org/x/crypto v0.0.0-20200302210943-78000ba7a073/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU=
golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8=
golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A=
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c h1:7dEasQXItcW1xKJ2+gg5VOiBnqWrJc+rq0DPKyvvdbY=
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c/go.mod h1:NQtJDoLvd6faHhE7m4T/1IY708gDefGGjR/iUW8yQQ8=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI=
golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg=
golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA=
golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
@@ -122,14 +122,14 @@ golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96b
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc=
golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o=
golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -140,10 +140,12 @@ golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
@@ -153,8 +155,8 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=
golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
@@ -162,8 +164,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn
golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -48,10 +48,6 @@ func (p *Provider) validate(vpnType string, filterChoicesGetter FilterChoicesGet
return fmt.Errorf("%w for Wireguard: %w", ErrVPNProviderNameNotValid, err)
}
if p.Name == providers.Mullvad && vpnType == vpn.OpenVPN {
warner.Warn("https://mullvad.net/en/blog/removing-openvpn-15th-january-2026")
}
err = p.ServerSelection.validate(p.Name, filterChoicesGetter, warner)
if err != nil {
return fmt.Errorf("server selection: %w", err)
+2 -3
View File
@@ -45,8 +45,7 @@ type Wireguard struct {
// It has been lowered to 1320 following quite a bit of
// investigation in the issue:
// https://github.com/qdm12/gluetun/issues/2533.
// Note this should now be replaced with the PMTUD feature.
MTU uint32 `json:"mtu"`
MTU uint16 `json:"mtu"`
// Implementation is the Wireguard implementation to use.
// It can be "auto", "userspace" or "kernelspace".
// It defaults to "auto" and cannot be the empty string
@@ -273,7 +272,7 @@ func (w *Wireguard) read(r *reader.Reader) (err error) {
return err
}
mtuPtr, err := r.Uint32Ptr("WIREGUARD_MTU")
mtuPtr, err := r.Uint16Ptr("WIREGUARD_MTU")
if err != nil {
return err
} else if mtuPtr != nil {
-17
View File
@@ -1,17 +0,0 @@
package dns
import (
"context"
"net/netip"
)
type Logger interface {
Debug(s string)
Info(s string)
Warn(s string)
Error(s string)
}
type Firewall interface {
RestrictOutputAddrPort(ctx context.Context, addrPort netip.AddrPort) (err error)
}
+8
View File
@@ -0,0 +1,8 @@
package dns
type Logger interface {
Debug(s string)
Info(s string)
Warn(s string)
Error(s string)
}
+1 -3
View File
@@ -24,7 +24,6 @@ type Loop struct {
localResolvers []netip.Addr
resolvConf string
client *http.Client
firewall Firewall
logger Logger
userTrigger bool
start <-chan struct{}
@@ -40,7 +39,7 @@ type Loop struct {
const defaultBackoffTime = 10 * time.Second
func NewLoop(settings settings.DNS,
client *http.Client, firewall Firewall, logger Logger,
client *http.Client, logger Logger,
) (loop *Loop, err error) {
start := make(chan struct{})
running := make(chan models.LoopStatus)
@@ -65,7 +64,6 @@ func NewLoop(settings settings.DNS,
filter: filter,
resolvConf: "/etc/resolv.conf",
client: client,
firewall: firewall,
logger: logger,
userTrigger: true,
start: start,
+2 -9
View File
@@ -1,14 +1,13 @@
package dns
import (
"context"
"net/netip"
"time"
"github.com/qdm12/dns/v2/pkg/nameserver"
)
func (l *Loop) useUnencryptedDNS(ctx context.Context, fallback bool) {
func (l *Loop) useUnencryptedDNS(fallback bool) {
settings := l.GetSettings()
targetIP := settings.GetFirstPlaintextIPv4()
@@ -21,9 +20,8 @@ func (l *Loop) useUnencryptedDNS(ctx context.Context, fallback bool) {
const dialTimeout = 3 * time.Second
const defaultDNSPort = 53
addrPort := netip.AddrPortFrom(targetIP, defaultDNSPort)
settingsInternalDNS := nameserver.SettingsInternalDNS{
AddrPort: addrPort,
AddrPort: netip.AddrPortFrom(targetIP, defaultDNSPort),
Timeout: dialTimeout,
}
nameserver.UseDNSInternally(settingsInternalDNS)
@@ -36,9 +34,4 @@ func (l *Loop) useUnencryptedDNS(ctx context.Context, fallback bool) {
if err != nil {
l.logger.Error(err.Error())
}
err = l.firewall.RestrictOutputAddrPort(ctx, addrPort)
if err != nil {
l.logger.Error("restricting plain DNS traffic to " + targetIP.String() + ": " + err.Error())
}
}
+5 -5
View File
@@ -24,7 +24,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
"and go through your container network DNS outside the VPN tunnel!")
} else {
const fallback = false
l.useUnencryptedDNS(ctx, fallback)
l.useUnencryptedDNS(fallback)
}
select {
@@ -56,7 +56,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
if !errors.Is(err, errUpdateBlockLists) {
const fallback = true
l.useUnencryptedDNS(ctx, fallback)
l.useUnencryptedDNS(fallback)
}
l.logAndWait(ctx, err)
settings = l.GetSettings()
@@ -66,7 +66,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
settings = l.GetSettings()
if !*settings.KeepNameserver && !*settings.ServerEnabled {
const fallback = false
l.useUnencryptedDNS(ctx, fallback)
l.useUnencryptedDNS(fallback)
}
l.userTrigger = false
@@ -94,7 +94,7 @@ func (l *Loop) runWait(ctx context.Context, runError <-chan error) (exitLoop boo
settings := l.GetSettings()
if !*settings.KeepNameserver && *settings.ServerEnabled {
const fallback = false
l.useUnencryptedDNS(ctx, fallback)
l.useUnencryptedDNS(fallback)
l.stopServer()
}
l.stopped <- struct{}{}
@@ -105,7 +105,7 @@ func (l *Loop) runWait(ctx context.Context, runError <-chan error) (exitLoop boo
case err := <-runError: // unexpected error
l.statusManager.SetStatus(constants.Crashed)
const fallback = true
l.useUnencryptedDNS(ctx, fallback)
l.useUnencryptedDNS(fallback)
l.logAndWait(ctx, err)
return false
}
+1 -7
View File
@@ -39,9 +39,8 @@ func (l *Loop) setupServer(ctx context.Context) (runError <-chan error, err erro
// use internal DNS server
const defaultDNSPort = 53
addrPort := netip.AddrPortFrom(settings.ServerAddress, defaultDNSPort)
nameserver.UseDNSInternally(nameserver.SettingsInternalDNS{
AddrPort: addrPort,
AddrPort: netip.AddrPortFrom(settings.ServerAddress, defaultDNSPort),
})
err = nameserver.UseDNSSystemWide(nameserver.SettingsSystemDNS{
IPs: []netip.Addr{settings.ServerAddress},
@@ -51,11 +50,6 @@ func (l *Loop) setupServer(ctx context.Context) (runError <-chan error, err erro
l.logger.Error(err.Error())
}
err = l.firewall.RestrictOutputAddrPort(ctx, addrPort)
if err != nil {
l.logger.Error("restricting plain DNS traffic to " + addrPort.Addr().String() + ": " + err.Error())
}
err = check.WaitForDNS(ctx, check.Settings{})
if err != nil {
l.stopServer()
+2 -2
View File
@@ -69,8 +69,8 @@ func Test_deleteIPTablesRule(t *testing.T) {
"invalid_instruction": {
instruction: "invalid",
errWrapped: ErrIptablesCommandMalformed,
errMessage: "parsing iptables command: parsing \"invalid\": " +
"iptables command is malformed: flag \"invalid\" requires a value, but got none",
errMessage: "parsing iptables command: iptables command is malformed: " +
"fields count 1 is not even: \"invalid\"",
},
"list_error": {
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
-2
View File
@@ -29,7 +29,6 @@ type Config struct {
outboundSubnets []netip.Prefix
allowedInputPorts map[uint16]map[string]struct{} // port to interfaces set mapping
portRedirections portRedirections
outputAddrPort map[uint16]netip.Addr
stateMutex sync.Mutex
}
@@ -53,7 +52,6 @@ func NewConfig(ctx context.Context, logger Logger,
runner: runner,
logger: logger,
allowedInputPorts: make(map[uint16]map[string]struct{}),
outputAddrPort: make(map[uint16]netip.Addr),
ipTables: iptables,
ip6Tables: ip6tables,
customRulesPath: "/iptables/post-rules.txt",
-13
View File
@@ -2,7 +2,6 @@ package firewall
import (
"context"
"fmt"
)
func (c *Config) runMixedIptablesInstructions(ctx context.Context, instructions []string) error {
@@ -20,15 +19,3 @@ func (c *Config) runMixedIptablesInstruction(ctx context.Context, instruction st
}
return c.runIP6tablesInstruction(ctx, instruction)
}
func (c *Config) runIPv4AndV6IptablesInstructions(ctx context.Context,
ipv4Instructions, ipv6Instructions []string,
) error {
if err := c.runIptablesInstructions(ctx, ipv4Instructions); err != nil {
return fmt.Errorf("running iptables instructions: %w", err)
}
if err := c.runIP6tablesInstructions(ctx, ipv6Instructions); err != nil {
return fmt.Errorf("running ip6tables instructions: %w", err)
}
return nil
}
+19 -111
View File
@@ -9,19 +9,9 @@ import (
"strings"
)
type operation uint8
const (
opNone operation = iota
opAppend
opDelete
opInsert
opReplace
)
type iptablesInstruction struct {
table string // defaults to "filter", and can be "nat" for example.
operation operation
append bool
chain string // for example INPUT, PREROUTING. Cannot be empty.
target string // for example ACCEPT. Can be empty.
protocol string // "tcp" or "udp" or "" for all protocols.
@@ -32,7 +22,6 @@ type iptablesInstruction struct {
destinationPort uint16 // if zero, there is no destination port
toPorts []uint16 // if empty, there is no redirection
ctstate []string // if empty, there is no ctstate
lineNumber uint16 // for replace operation, the line number to replace
}
func (i *iptablesInstruction) setDefaults() {
@@ -71,58 +60,6 @@ func (i *iptablesInstruction) equalToRule(table, chain string, rule chainRule) (
}
}
func (i *iptablesInstruction) String() string {
var sb strings.Builder
if i.table != "" && i.table != "filter" {
sb.WriteString(fmt.Sprintf("-t %s ", i.table))
}
switch i.operation {
case opNone:
panic("no operation specified")
case opAppend:
sb.WriteString(fmt.Sprintf("--append %s ", i.chain))
case opDelete:
sb.WriteString(fmt.Sprintf("--delete %s ", i.chain))
case opInsert:
sb.WriteString(fmt.Sprintf("--insert %s ", i.chain))
case opReplace:
sb.WriteString(fmt.Sprintf("--replace %s %d ", i.chain, i.lineNumber))
}
if i.inputInterface != "" {
sb.WriteString(fmt.Sprintf("-i %s ", i.inputInterface))
}
if i.outputInterface != "" {
sb.WriteString(fmt.Sprintf("-o %s ", i.outputInterface))
}
if i.protocol != "" {
sb.WriteString(fmt.Sprintf("-p %s ", i.protocol))
}
if i.source.IsValid() {
sb.WriteString(fmt.Sprintf("-s %s ", i.source.String()))
}
if i.destination.IsValid() {
sb.WriteString(fmt.Sprintf("-d %s ", i.destination.String()))
}
if i.destinationPort != 0 {
sb.WriteString(fmt.Sprintf("--dport %d ", i.destinationPort))
}
if len(i.ctstate) > 0 {
sb.WriteString(fmt.Sprintf("--ctstate %s ", strings.Join(i.ctstate, ",")))
}
if len(i.toPorts) > 0 {
var portStrings []string
for _, port := range i.toPorts {
portStrings = append(portStrings, strconv.FormatUint(uint64(port), 10))
}
sb.WriteString(fmt.Sprintf("--to-ports %s ", strings.Join(portStrings, ",")))
}
if i.target != "" {
sb.WriteString(fmt.Sprintf("-j %s ", i.target))
}
return strings.TrimSpace(sb.String())
}
// instruction can be "" which equivalent to the "*" chain rule interface.
func networkInterfacesEqual(instruction, chainRule string) bool {
return instruction == chainRule || (instruction == "" && chainRule == "*")
@@ -140,63 +77,34 @@ func parseIptablesInstruction(s string) (instruction iptablesInstruction, err er
return iptablesInstruction{}, fmt.Errorf("%w: empty instruction", ErrIptablesCommandMalformed)
}
fields := strings.Fields(s)
if len(fields)%2 != 0 {
return iptablesInstruction{}, fmt.Errorf("%w: fields count %d is not even: %q",
ErrIptablesCommandMalformed, len(fields), s)
}
i := 0
for i < len(fields) {
consumed, err := parseInstructionFlag(fields[i:], &instruction)
for i := 0; i < len(fields); i += 2 {
key := fields[i]
value := fields[i+1]
err = parseInstructionFlag(key, value, &instruction)
if err != nil {
return iptablesInstruction{}, fmt.Errorf("parsing %q: %w", s, err)
}
i += consumed
}
instruction.setDefaults()
return instruction, nil
}
func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (consumed int, err error) {
flag := fields[0]
// All flags use one value after the flag, except the following:
switch flag {
case "-R", "--replace":
const expected = 3
if len(fields) < expected {
return 0, fmt.Errorf("%w: flag %q requires at least 2 values, but got %s",
ErrIptablesCommandMalformed, flag, strings.Join(fields, " "))
}
consumed = expected
default:
const expected = 2
if len(fields) < expected {
return 0, fmt.Errorf("%w: flag %q requires a value, but got none",
ErrIptablesCommandMalformed, flag)
}
consumed = expected
}
value := fields[1]
switch flag {
func parseInstructionFlag(key, value string, instruction *iptablesInstruction) (err error) {
switch key {
case "-t", "--table":
instruction.table = value
case "-D", "--delete":
instruction.operation = opDelete
instruction.append = false
instruction.chain = value
case "-A", "--append":
instruction.operation = opAppend
instruction.append = true
instruction.chain = value
case "-I", "--insert":
instruction.operation = opInsert
instruction.chain = value
case "-R", "--replace":
instruction.operation = opReplace
instruction.chain = value
const base, bits = 10, 16
n, err := strconv.ParseUint(fields[2], base, bits)
if err != nil {
return 0, fmt.Errorf("parsing line number for --replace operation: %w", err)
}
instruction.lineNumber = uint16(n)
case "-j", "--jump":
instruction.target = value
case "-p", "--protocol":
@@ -209,18 +117,18 @@ func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (co
case "-s", "--source":
instruction.source, err = parseIPPrefix(value)
if err != nil {
return 0, fmt.Errorf("parsing source IP CIDR: %w", err)
return fmt.Errorf("parsing source IP CIDR: %w", err)
}
case "-d", "--destination":
instruction.destination, err = parseIPPrefix(value)
if err != nil {
return 0, fmt.Errorf("parsing destination IP CIDR: %w", err)
return fmt.Errorf("parsing destination IP CIDR: %w", err)
}
case "--dport":
const base, bitLength = 10, 16
destinationPort, err := strconv.ParseUint(value, base, bitLength)
if err != nil {
return 0, fmt.Errorf("parsing destination port: %w", err)
return fmt.Errorf("parsing destination port: %w", err)
}
instruction.destinationPort = uint16(destinationPort)
case "--ctstate":
@@ -232,14 +140,14 @@ func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (co
const base, bitLength = 10, 16
port, err := strconv.ParseUint(portString, base, bitLength)
if err != nil {
return 0, fmt.Errorf("parsing port redirection: %w", err)
return fmt.Errorf("parsing port redirection: %w", err)
}
instruction.toPorts[i] = uint16(port)
}
default:
return 0, fmt.Errorf("%w: unknown flag %q", ErrIptablesCommandMalformed, flag)
return fmt.Errorf("%w: unknown key %q", ErrIptablesCommandMalformed, key)
}
return consumed, nil
return nil
}
func parseIPPrefix(value string) (prefix netip.Prefix, err error) {
+8 -8
View File
@@ -23,19 +23,19 @@ func Test_parseIptablesInstruction(t *testing.T) {
"uneven_fields": {
s: "-A",
errWrapped: ErrIptablesCommandMalformed,
errMessage: "parsing \"-A\": iptables command is malformed: flag \"-A\" requires a value, but got none",
errMessage: "iptables command is malformed: fields count 1 is not even: \"-A\"",
},
"unknown_key": {
s: "-x something",
errWrapped: ErrIptablesCommandMalformed,
errMessage: "parsing \"-x something\": iptables command is malformed: unknown flag \"-x\"",
errMessage: "parsing \"-x something\": iptables command is malformed: unknown key \"-x\"",
},
"one_pair": {
s: "-I INPUT",
s: "-A INPUT",
instruction: iptablesInstruction{
table: "filter",
chain: "INPUT",
operation: opInsert,
table: "filter",
chain: "INPUT",
append: true,
},
},
"instruction_A": {
@@ -43,7 +43,7 @@ func Test_parseIptablesInstruction(t *testing.T) {
instruction: iptablesInstruction{
table: "filter",
chain: "INPUT",
operation: opAppend,
append: true,
inputInterface: "tun0",
protocol: "tcp",
source: netip.MustParsePrefix("1.2.3.4/32"),
@@ -57,7 +57,7 @@ func Test_parseIptablesInstruction(t *testing.T) {
instruction: iptablesInstruction{
table: "nat",
chain: "PREROUTING",
operation: opDelete,
append: false,
inputInterface: "tun0",
protocol: "tcp",
destinationPort: 43716,
-131
View File
@@ -3,7 +3,6 @@ package firewall
import (
"context"
"fmt"
"net/netip"
"strconv"
)
@@ -82,133 +81,3 @@ func (c *Config) RemoveAllowedPort(ctx context.Context, port uint16) (err error)
return nil
}
// RestrictOutputAddrPort allows outgoing traffic to a specific IP and port for both tcp and udp,
// while blocking other tcp or udp traffic to that port going to other IP addresses, both IPv4 and IPv6.
// If the port was previously allowed for another IP address, that previous allowance will be removed.
// Giving an invalid address will remove any existing restrictions for the port specified.
func (c *Config) RestrictOutputAddrPort(ctx context.Context, addrPort netip.AddrPort) (err error) {
c.stateMutex.Lock()
defer c.stateMutex.Unlock()
existingIP := c.outputAddrPort[addrPort.Port()]
switch {
case existingIP == addrPort.Addr():
return nil
case !addrPort.Addr().IsValid():
// invalid address, remove any existing rules for the port
return c.removeOutputAddrPortRestriction(ctx, existingIP, addrPort.Port())
case !existingIP.IsValid():
// no previous existing address for the port
return c.insertOutputAddrPortRestriction(ctx, addrPort)
default:
// existing rule in the same IP family or different family
return c.replaceOutputAddrPortRestriction(ctx, existingIP, addrPort)
}
}
func (c *Config) removeOutputAddrPortRestriction(ctx context.Context, existingIP netip.Addr, port uint16) (err error) {
commonInstructions := []string{
fmt.Sprintf("--delete OUTPUT -p udp --dport %d -j DROP", port),
fmt.Sprintf("--delete OUTPUT -p tcp --dport %d -j DROP", port),
}
ipv4Instructions := commonInstructions
ipv6Instructions := commonInstructions
familySpecificInstructions := []string{
fmt.Sprintf("--delete OUTPUT -p udp --dport %d -d %s -j ACCEPT", port, existingIP),
fmt.Sprintf("--delete OUTPUT -p tcp --dport %d -d %s -j ACCEPT", port, existingIP),
}
if existingIP.Is4() {
ipv4Instructions = append(ipv4Instructions, familySpecificInstructions...)
} else {
ipv6Instructions = append(ipv6Instructions, familySpecificInstructions...)
}
err = c.runIPv4AndV6IptablesInstructions(ctx, ipv4Instructions, ipv6Instructions)
if err != nil {
return err
}
delete(c.outputAddrPort, port)
return nil
}
func (c *Config) insertOutputAddrPortRestriction(ctx context.Context, addrPort netip.AddrPort) (err error) {
commonInstructions := []string{
fmt.Sprintf("--insert OUTPUT -p udp --dport %d -j DROP", addrPort.Port()),
fmt.Sprintf("--insert OUTPUT -p tcp --dport %d -j DROP", addrPort.Port()),
}
ipv4Instructions := commonInstructions
ipv6Instructions := commonInstructions
familySpecificInstructions := []string{
fmt.Sprintf("--insert OUTPUT -p udp --dport %d -d %s -j ACCEPT", addrPort.Port(), addrPort.Addr()),
fmt.Sprintf("--insert OUTPUT -p tcp --dport %d -d %s -j ACCEPT", addrPort.Port(), addrPort.Addr()),
}
if addrPort.Addr().Is4() {
ipv4Instructions = append(ipv4Instructions, familySpecificInstructions...)
} else {
ipv6Instructions = append(ipv6Instructions, familySpecificInstructions...)
}
err = c.runIPv4AndV6IptablesInstructions(ctx, ipv4Instructions, ipv6Instructions)
if err != nil {
return err
}
c.outputAddrPort[addrPort.Port()] = addrPort.Addr()
return nil
}
func (c *Config) replaceOutputAddrPortRestriction(ctx context.Context,
existingIP netip.Addr, addrPort netip.AddrPort,
) (err error) {
for _, protocol := range [...]string{"udp", "tcp"} {
switch {
case existingIP.Is4() && addrPort.Addr().Is4():
oldInstruction := fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT",
protocol, addrPort.Port(), existingIP)
newInstruction := fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT",
protocol, addrPort.Port(), addrPort.Addr())
err = c.replaceIptablesRule(ctx, oldInstruction, newInstruction)
if err != nil {
return fmt.Errorf("replacing existing IPv4 rule: %w", err)
}
case existingIP.Is6() && addrPort.Addr().Is6():
oldInstruction := fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT",
protocol, addrPort.Port(), existingIP)
newInstruction := fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT",
protocol, addrPort.Port(), addrPort.Addr())
err = c.replaceIP6tablesRule(ctx, oldInstruction, newInstruction)
if err != nil {
return fmt.Errorf("replacing existing IPv6 rule: %w", err)
}
case existingIP.Is4() && addrPort.Addr().Is6():
instruction := fmt.Sprintf("--delete OUTPUT -p %s --dport %d -d %s -j ACCEPT",
protocol, addrPort.Port(), existingIP)
err = c.runIptablesInstruction(ctx, instruction)
if err != nil {
return fmt.Errorf("removing existing IPv4 rule: %w", err)
}
instruction = fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT",
protocol, addrPort.Port(), addrPort.Addr())
err = c.runIP6tablesInstruction(ctx, instruction)
if err != nil {
return fmt.Errorf("inserting new IPv6 rule: %w", err)
}
case existingIP.Is6() && addrPort.Addr().Is4():
instruction := fmt.Sprintf("--delete OUTPUT -p %s --dport %d -d %s -j ACCEPT",
protocol, addrPort.Port(), existingIP)
err = c.runIP6tablesInstruction(ctx, instruction)
if err != nil {
return fmt.Errorf("removing existing IPv6 rule: %w", err)
}
instruction = fmt.Sprintf("--insert OUTPUT -p %s --dport %d -d %s -j ACCEPT",
protocol, addrPort.Port(), addrPort.Addr())
err = c.runIptablesInstruction(ctx, instruction)
if err != nil {
return fmt.Errorf("inserting new IPv4 rule: %w", err)
}
}
}
c.outputAddrPort[addrPort.Port()] = addrPort.Addr()
return nil
}
-51
View File
@@ -1,51 +0,0 @@
package firewall
import (
"context"
"errors"
"fmt"
)
var errRuleNotFound = errors.New("rule not found")
func (c *Config) replaceIptablesRule(ctx context.Context, oldInstruction, newInstruction string) error {
targetRule, err := parseIptablesInstruction(oldInstruction)
if err != nil {
return fmt.Errorf("parsing iptables command to replace: %w", err)
}
lineNumber, err := findLineNumber(ctx, c.ipTables, targetRule, c.runner, c.logger)
if err != nil {
return fmt.Errorf("finding to-be-replaced chain rule line number: %w", err)
} else if lineNumber == 0 {
return fmt.Errorf("%w: matching to-be-replaced instruction %q", errRuleNotFound, oldInstruction)
}
parsed, err := parseIptablesInstruction(newInstruction)
if err != nil {
return fmt.Errorf("parsing replacement iptables command: %w", err)
}
parsed.operation = opReplace
parsed.lineNumber = lineNumber
return c.runIptablesInstruction(ctx, parsed.String())
}
func (c *Config) replaceIP6tablesRule(ctx context.Context, oldInstruction, newInstruction string) error {
targetRule, err := parseIptablesInstruction(oldInstruction)
if err != nil {
return fmt.Errorf("parsing iptables command to replace: %w", err)
}
lineNumber, err := findLineNumber(ctx, c.ip6Tables, targetRule, c.runner, c.logger)
if err != nil {
return fmt.Errorf("finding to-be-replaced chain rule line number: %w", err)
} else if lineNumber == 0 {
return fmt.Errorf("%w: matching to-be-replaced instruction %q", errRuleNotFound, oldInstruction)
}
parsed, err := parseIptablesInstruction(newInstruction)
if err != nil {
return fmt.Errorf("parsing replacement iptables command: %w", err)
}
parsed.operation = opReplace
parsed.lineNumber = lineNumber
return c.runIP6tablesInstruction(ctx, parsed.String())
}
@@ -1,5 +1,3 @@
//go:build !windows
package mod
import (
-7
View File
@@ -1,7 +0,0 @@
//go:build !linux
package mod
func Probe(moduleName string) error {
panic("not implemented")
}
+17 -59
View File
@@ -1,75 +1,33 @@
//go:build linux || darwin
package netlink
import (
"fmt"
"net"
"net/netip"
"github.com/jsimonetti/rtnetlink/rtnl"
"github.com/vishvananda/netlink"
)
func (n *NetLink) AddrList(linkIndex uint32, family uint8) (
ipPrefixes []netip.Prefix, err error,
func (n *NetLink) AddrList(link Link, family int) (
addresses []Addr, err error,
) {
conn, err := rtnl.Dial(nil)
netlinkLink := linkToNetlinkLink(&link)
netlinkAddresses, err := netlink.AddrList(netlinkLink, family)
if err != nil {
return nil, fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
ifc := &net.Interface{
Index: int(linkIndex),
}
ipNets, err := conn.Addrs(ifc, int(family))
if err != nil {
return nil, fmt.Errorf("failed to list addresses: %w", err)
return nil, err
}
ipPrefixes = make([]netip.Prefix, len(ipNets))
for i := range ipNets {
ipPrefixes[i] = netIPNetToNetipPrefix(ipNets[i])
addresses = make([]Addr, len(netlinkAddresses))
for i := range netlinkAddresses {
addresses[i].Network = netIPNetToNetipPrefix(netlinkAddresses[i].IPNet)
}
return ipPrefixes, nil
return addresses, nil
}
func (n *NetLink) AddrReplace(linkIndex uint32, prefix netip.Prefix) error {
conn, err := rtnl.Dial(nil)
if err != nil {
return fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
ipNet := netipPrefixToIPNet(prefix)
// Remove any address identical to the one we want to add
family := FamilyV4
if prefix.Addr().Is6() {
family = FamilyV6
}
ifc := &net.Interface{
Index: int(linkIndex),
}
addresses, err := conn.Addrs(ifc, int(family))
if err != nil {
return fmt.Errorf("listing addresses: %w", err)
}
for _, address := range addresses {
if address.IP.Equal(ipNet.IP) &&
net.IP(address.Mask).String() == net.IP(ipNet.Mask).String() {
err = conn.AddrDel(ifc, address)
if err != nil {
return fmt.Errorf("deleting address from interface: %w", err)
}
break
}
func (n *NetLink) AddrReplace(link Link, addr Addr) error {
netlinkLink := linkToNetlinkLink(&link)
netlinkAddress := netlink.Addr{
IPNet: netipPrefixToIPNet(addr.Network),
}
// Add the new address to the interface
err = conn.AddrAdd(ifc, ipNet)
if err != nil {
return fmt.Errorf("adding address to interface: %w", err)
}
return nil
return netlink.AddrReplace(netlinkLink, &netlinkAddress)
}
+13
View File
@@ -0,0 +1,13 @@
//go:build !linux && !darwin
package netlink
func (n *NetLink) AddrList(link Link, family int) (
addresses []Addr, err error,
) {
panic("not implemented")
}
func (n *NetLink) AddrReplace(Link, Addr) error {
panic("not implemented")
}
-24
View File
@@ -36,30 +36,6 @@ func netIPNetToNetipPrefix(ipNet *net.IPNet) (prefix netip.Prefix) {
return netip.PrefixFrom(ip, bits)
}
func ipAndLengthToPrefix(ip *net.IP, length uint8) netip.Prefix {
if ip == nil || len(*ip) == 0 {
return netip.Prefix{}
}
var dstIP netip.Addr
if ipv4 := ip.To4(); ipv4 != nil { // IPv6
dstIP = netip.AddrFrom4([4]byte(*ip))
} else {
dstIP = netip.AddrFrom16([16]byte(*ip))
}
return netip.PrefixFrom(dstIP, int(length))
}
func prefixToIPAndLength(prefix netip.Prefix) (ip *net.IP, length uint8) {
if !prefix.IsValid() {
return nil, 0
}
prefixIP := prefix.Addr().Unmap()
ip = new(net.IP)
*ip = netipAddrToNetIP(prefixIP)
length = uint8(prefix.Bits()) //nolint:gosec
return ip, length
}
func netipAddrToNetIP(address netip.Addr) (ip net.IP) {
switch {
case !address.IsValid():
+7 -1
View File
@@ -4,7 +4,13 @@ import (
"fmt"
)
func FamilyToString(family uint8) string {
const (
FamilyAll = 0
FamilyV4 = 2
FamilyV6 = 10
)
func FamilyToString(family int) string {
switch family {
case FamilyAll:
return "all"
-9
View File
@@ -1,9 +0,0 @@
package netlink
import "golang.org/x/sys/unix"
const (
FamilyAll uint8 = unix.AF_UNSPEC
FamilyV4 uint8 = unix.AF_INET
FamilyV6 uint8 = unix.AF_INET6
)
-14
View File
@@ -1,30 +1,16 @@
package netlink
import (
"math/rand/v2"
"net/netip"
"github.com/qdm12/log"
)
func ptrTo[T any](v T) *T { return &v }
func makeNetipPrefix(n byte) netip.Prefix {
const bits = 24
return netip.PrefixFrom(netip.AddrFrom4([4]byte{n, n, n, 0}), bits)
}
var rng = rand.New(rand.NewChaCha8([32]byte{})) //nolint:gosec,gochecknoglobals
func makeLinkName() string {
const alphabet = "abcdefghijklmnopqrstuvwxyz"
name := make([]byte, 8)
for i := range name {
name[i] = alphabet[rng.IntN(len(alphabet))]
}
return "test" + string(name)
}
type noopLogger struct{}
func (l *noopLogger) Debug(_ string) {}
+1 -1
View File
@@ -19,7 +19,7 @@ func (n *NetLink) IsIPv6Supported() (supported bool, err error) {
return false, fmt.Errorf("finding link corresponding to route: %w", err)
}
sourceIsIPv6 := route.Src.Addr().IsValid() && route.Src.Addr().Is6()
sourceIsIPv6 := route.Src.IsValid() && route.Src.Is6()
destinationIsIPv6 := route.Dst.IsValid() && route.Dst.Addr().Is6()
switch {
case !sourceIsIPv6 && !destinationIsIPv6,
+76 -162
View File
@@ -1,191 +1,105 @@
//go:build linux || darwin
package netlink
import (
"errors"
"fmt"
"github.com/jsimonetti/rtnetlink"
)
type DeviceType uint16
type Link struct {
Index uint32
Name string
DeviceType DeviceType
VirtualType string
MTU uint32
}
import "github.com/vishvananda/netlink"
func (n *NetLink) LinkList() (links []Link, err error) {
conn, err := rtnetlink.Dial(nil)
netlinkLinks, err := netlink.LinkList()
if err != nil {
return nil, fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
linkMessages, err := conn.Link.List()
if err != nil {
return nil, fmt.Errorf("listing interfaces: %w", err)
return nil, err
}
links = make([]Link, len(linkMessages))
for i, message := range linkMessages {
virtualType := ""
if message.Attributes.Info != nil {
virtualType = message.Attributes.Info.Kind
}
links[i] = Link{
Index: message.Index,
Name: message.Attributes.Name,
DeviceType: DeviceType(message.Type),
VirtualType: virtualType,
MTU: message.Attributes.MTU,
}
links = make([]Link, len(netlinkLinks))
for i := range netlinkLinks {
links[i] = netlinkLinkToLink(netlinkLinks[i])
}
return links, nil
}
var ErrLinkNotFound = errors.New("link not found")
func (n *NetLink) LinkByName(name string) (link Link, err error) {
links, err := n.LinkList()
netlinkLink, err := netlink.LinkByName(name)
if err != nil {
return Link{}, fmt.Errorf("listing links: %w", err)
return Link{}, err
}
for _, link := range links {
if link.Name == name {
return link, nil
}
}
return Link{}, fmt.Errorf("%w: for name %s", ErrLinkNotFound, name)
return netlinkLinkToLink(netlinkLink), nil
}
func (n *NetLink) LinkByIndex(index uint32) (link Link, err error) {
links, err := n.LinkList()
func (n *NetLink) LinkByIndex(index int) (link Link, err error) {
netlinkLink, err := netlink.LinkByIndex(index)
if err != nil {
return Link{}, fmt.Errorf("listing links: %w", err)
return Link{}, err
}
for _, link = range links {
if link.Index == index {
return link, nil
}
}
return Link{}, fmt.Errorf("%w: for index %d", ErrLinkNotFound, index)
return netlinkLinkToLink(netlinkLink), nil
}
func (n *NetLink) LinkAdd(link Link) (linkIndex uint32, err error) {
conn, err := rtnetlink.Dial(nil)
func (n *NetLink) LinkAdd(link Link) (linkIndex int, err error) {
netlinkLink := linkToNetlinkLink(&link)
err = netlink.LinkAdd(netlinkLink)
if err != nil {
return 0, fmt.Errorf("dialing netlink: %w", err)
return 0, err
}
defer conn.Close()
return netlinkLink.Attrs().Index, nil
}
tx := &rtnetlink.LinkMessage{
Type: uint16(link.DeviceType),
Attributes: &rtnetlink.LinkAttributes{
MTU: link.MTU,
Name: link.Name,
func (n *NetLink) LinkDel(link Link) (err error) {
return netlink.LinkDel(linkToNetlinkLink(&link))
}
func (n *NetLink) LinkSetUp(link Link) (linkIndex int, err error) {
netlinkLink := linkToNetlinkLink(&link)
err = netlink.LinkSetUp(netlinkLink)
if err != nil {
return 0, err
}
return netlinkLink.Attrs().Index, nil
}
func (n *NetLink) LinkSetDown(link Link) (err error) {
return netlink.LinkSetDown(linkToNetlinkLink(&link))
}
type netlinkLinkImpl struct {
attrs *netlink.LinkAttrs
linkType string
}
func (n *netlinkLinkImpl) Attrs() *netlink.LinkAttrs {
return n.attrs
}
func (n *netlinkLinkImpl) Type() string {
return n.linkType
}
func netlinkLinkToLink(netlinkLink netlink.Link) Link {
attributes := netlinkLink.Attrs()
return Link{
Type: netlinkLink.Type(),
Name: attributes.Name,
Index: attributes.Index,
EncapType: attributes.EncapType,
MTU: uint16(attributes.MTU), //nolint:gosec
}
}
// Warning: we must return `netlink.Link` and not `netlinkLinkImpl`
// so that the vishvananda/netlink package can compare the returned
// value against an untyped nil.
func linkToNetlinkLink(link *Link) netlink.Link {
if link == nil {
return nil
}
return &netlinkLinkImpl{
linkType: link.Type,
attrs: &netlink.LinkAttrs{
Name: link.Name,
Index: link.Index,
EncapType: link.EncapType,
MTU: int(link.MTU),
},
}
if link.VirtualType != "" {
tx.Attributes.Info = &rtnetlink.LinkInfo{
Kind: link.VirtualType,
}
}
err = conn.Link.New(tx)
if err != nil {
return 0, fmt.Errorf("creating new link: %w", err)
}
linkMessages, err := conn.Link.List()
if err != nil {
return 0, fmt.Errorf("listing links: %w", err)
}
for _, linkMessage := range linkMessages {
if linkMessage.Attributes.Name == link.Name {
return linkMessage.Index, nil
}
}
return 0, fmt.Errorf("%w: matching name %s", ErrLinkNotFound, link.Name)
}
func (n *NetLink) LinkDel(linkIndex uint32) (err error) {
conn, err := rtnetlink.Dial(nil)
if err != nil {
return fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
return conn.Link.Delete(linkIndex)
}
func (n *NetLink) LinkSetUp(linkIndex uint32) (err error) {
conn, err := rtnetlink.Dial(nil)
if err != nil {
return fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
rx, err := conn.Link.Get(linkIndex)
if err != nil {
return fmt.Errorf("getting link: %w", err)
}
tx := &rtnetlink.LinkMessage{
Type: rx.Type,
Index: linkIndex,
Flags: iffUp,
Change: iffUp,
}
return conn.Link.Set(tx)
}
func (n *NetLink) LinkSetDown(linkIndex uint32) (err error) {
conn, err := rtnetlink.Dial(nil)
if err != nil {
return fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
linkInfo, err := conn.Link.Get(linkIndex)
if err != nil {
return fmt.Errorf("getting link: %w", err)
}
message := &rtnetlink.LinkMessage{
Type: linkInfo.Type,
Index: linkIndex,
Flags: 0,
Change: iffUp,
}
return conn.Link.Set(message)
}
func (n *NetLink) LinkSetMTU(linkIndex, mtu uint32) error {
conn, err := rtnetlink.Dial(nil)
if err != nil {
return fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
message := &rtnetlink.LinkMessage{
Index: linkIndex,
Attributes: &rtnetlink.LinkAttributes{
MTU: mtu,
},
}
err = conn.Link.Set(message)
if err != nil {
return fmt.Errorf("setting MTU to %d for link at index %d: %w",
mtu, linkIndex, err)
}
return nil
}
-11
View File
@@ -1,11 +0,0 @@
package netlink
import "golang.org/x/sys/unix"
const (
DeviceTypeEthernet DeviceType = unix.ARPHRD_ETHER
DeviceTypeLoopback DeviceType = unix.ARPHRD_LOOPBACK
DeviceTypeNone DeviceType = unix.ARPHRD_NONE
iffUp = unix.IFF_UP
)
-85
View File
@@ -1,85 +0,0 @@
//go:build linux
package netlink
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_NetLink_LinkList(t *testing.T) {
t.Parallel()
netlink := &NetLink{}
initialLinks, err := netlink.LinkList()
require.NoError(t, err)
require.NotEmpty(t, initialLinks)
loopbackFound := false
for _, link := range initialLinks {
if link.Name != "lo" {
continue
}
loopbackFound = true
assert.Equal(t, DeviceTypeLoopback, link.DeviceType)
break
}
assert.True(t, loopbackFound, "loopback interface not found")
testLink := Link{
Name: makeLinkName(),
// note if [Link.VirtualType] is set, [Link.DeviceType]
// is ignored and gets set to [DeviceTypeNone] in LinkAdd.
DeviceType: DeviceTypeNone,
VirtualType: "wireguard",
MTU: 1420,
}
index, err := netlink.LinkAdd(testLink)
require.NoError(t, err)
t.Cleanup(func() {
_ = netlink.LinkDel(index)
})
links, err := netlink.LinkList()
require.NoError(t, err)
testLink.Index = index
for _, link := range links {
if link.Name != testLink.Name {
continue
}
assert.Equal(t, testLink, link)
return
}
t.Errorf("created link %q not found", testLink.Name)
}
func Test_NetLink_LinkSetMTU(t *testing.T) {
t.Parallel()
netlink := &NetLink{}
testLink := Link{
Name: makeLinkName(),
DeviceType: DeviceTypeNone,
VirtualType: "wireguard",
MTU: 1420,
}
index, err := netlink.LinkAdd(testLink)
require.NoError(t, err)
t.Cleanup(func() {
_ = netlink.LinkDel(index)
})
testLink.Index = index
err = netlink.LinkSetMTU(index, 1500)
require.NoError(t, err)
link, err := netlink.LinkByIndex(index)
require.NoError(t, err)
testLink.MTU = 1500
assert.Equal(t, testLink, link)
}
+31
View File
@@ -0,0 +1,31 @@
//go:build !linux && !darwin
package netlink
func (n *NetLink) LinkList() (links []Link, err error) {
panic("not implemented")
}
func (n *NetLink) LinkByName(name string) (link Link, err error) {
panic("not implemented")
}
func (n *NetLink) LinkByIndex(index int) (link Link, err error) {
panic("not implemented")
}
func (n *NetLink) LinkAdd(link Link) (linkIndex int, err error) {
panic("not implemented")
}
func (n *NetLink) LinkDel(link Link) (err error) {
panic("not implemented")
}
func (n *NetLink) LinkSetUp(link Link) (linkIndex int, err error) {
panic("not implemented")
}
func (n *NetLink) LinkSetDown(link Link) (err error) {
panic("not implemented")
}
-56
View File
@@ -1,56 +0,0 @@
//go:build !linux
package netlink
const (
// FamilyAll is a placeholder only and should not
// be used.
FamilyAll uint8 = iota
// FamilyV4 is a placeholder only and should not
// be used.
FamilyV4
// FamilyV6 is a placeholder only and should not
// be used.
FamilyV6
// DeviceTypeEthernet is a placeholder only and should not be used.
DeviceTypeEthernet DeviceType = 0
// DeviceTypeLoopback is a placeholder only and should not be used.
DeviceTypeLoopback DeviceType = 0
// DeviceTypeNone is a placeholder only and should not be used.
DeviceTypeNone DeviceType = 0
// iffUp is a placeholder only and should not be used.
iffUp = 0
// RouteTypeUnicast is a placeholder only and should not be used.
RouteTypeUnicast = 0
// ScopeUniverse is a placeholder only and should not be used.
ScopeUniverse = 0
// ProtoStatic is a placeholder only and should not be used.
ProtoStatic = 0
// FlagInvert is a placeholder only and should not be used.
FlagInvert = 0
// ActionToTable is a placeholder only and should not be used.
ActionToTable = 0
// rtTableCompat is a placeholder only and should not be used.
rtTableCompat = 0
)
func (n *NetLink) RuleList(family uint8) (rules []Rule, err error) {
panic("not implemented")
}
func (n *NetLink) RuleAdd(rule Rule) error {
panic("not implemented")
}
func (n *NetLink) RuleDel(rule Rule) error {
panic("not implemented")
}
func (n *NetLink) IsWireguardSupported() (bool, error) {
panic("not implemented")
}
+48 -104
View File
@@ -1,125 +1,69 @@
//go:build linux || darwin
package netlink
import (
"fmt"
"net/netip"
"github.com/jsimonetti/rtnetlink"
"github.com/vishvananda/netlink"
)
type Route struct {
LinkIndex uint32
Dst netip.Prefix
Src netip.Prefix
Gw netip.Addr
Priority uint32
Family uint8
Table uint32
Type uint8
Scope uint8
Proto uint8
}
func (n *NetLink) RouteList(family int) (routes []Route, err error) {
// We set the filter to netlink.RT_FILTER_TABLE so that
// routes from all tables are listed, as long as the filter
// table is set to 0.
const filterMask = netlink.RT_FILTER_TABLE
// The filter is not left to `nil` otherwise non-main tables
// are ignored.
filter := &netlink.Route{}
func (r *Route) fromMessage(message rtnetlink.RouteMessage) {
table := uint32(message.Table)
if table == 0 || table == rtTableCompat {
table = message.Attributes.Table
}
r.LinkIndex = message.Attributes.OutIface
r.Dst = ipAndLengthToPrefix(&message.Attributes.Dst, message.DstLength)
r.Src = ipAndLengthToPrefix(&message.Attributes.Src, message.SrcLength)
r.Gw = netIPToNetipAddress(message.Attributes.Gateway)
r.Priority = message.Attributes.Priority
r.Family = message.Family
r.Table = table
r.Type = message.Type
r.Scope = message.Scope
r.Proto = message.Protocol
}
func (r Route) message() *rtnetlink.RouteMessage {
dst, dstLength := prefixToIPAndLength(r.Dst)
src, srcLength := prefixToIPAndLength(r.Src)
var table uint8
var extendedTable uint32
if r.Table <= uint32(^uint8(0)) {
table = uint8(r.Table)
} else {
table = rtTableCompat
extendedTable = r.Table
}
message := &rtnetlink.RouteMessage{
Family: r.Family,
DstLength: dstLength,
SrcLength: srcLength,
Table: table,
Type: r.Type,
Scope: r.Scope,
Protocol: r.Proto,
Attributes: rtnetlink.RouteAttributes{
OutIface: r.LinkIndex,
Dst: *dst, // there should always be a dst for routes
Gateway: netipAddrToNetIP(r.Gw),
Priority: r.Priority,
Table: extendedTable,
},
}
if src != nil { // src is optional
message.Attributes.Src = *src
}
return message
}
func (n *NetLink) RouteList(family uint8) (routes []Route, err error) {
conn, err := rtnetlink.Dial(nil)
netlinkRoutes, err := netlink.RouteListFiltered(family, filter, filterMask)
if err != nil {
return nil, fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
routeMessages, err := conn.Route.List()
if err != nil {
return nil, fmt.Errorf("listing interfaces: %w", err)
return nil, err
}
routes = make([]Route, 0, len(routeMessages))
for _, routeMessage := range routeMessages {
if family != FamilyAll && routeMessage.Family != family {
continue
}
var route Route
route.fromMessage(routeMessage)
routes = append(routes, route)
routes = make([]Route, len(netlinkRoutes))
for i := range netlinkRoutes {
routes[i] = netlinkRouteToRoute(netlinkRoutes[i])
}
return routes, nil
}
func (n *NetLink) RouteAdd(route Route) error {
conn, err := rtnetlink.Dial(nil)
if err != nil {
return fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
return conn.Route.Add(route.message())
netlinkRoute := routeToNetlinkRoute(route)
return netlink.RouteAdd(&netlinkRoute)
}
func (n *NetLink) RouteDel(route Route) error {
conn, err := rtnetlink.Dial(nil)
if err != nil {
return fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
return conn.Route.Delete(route.message())
netlinkRoute := routeToNetlinkRoute(route)
return netlink.RouteDel(&netlinkRoute)
}
func (n *NetLink) RouteReplace(route Route) error {
conn, err := rtnetlink.Dial(nil)
if err != nil {
return fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
return conn.Route.Replace(route.message())
netlinkRoute := routeToNetlinkRoute(route)
return netlink.RouteReplace(&netlinkRoute)
}
func netlinkRouteToRoute(netlinkRoute netlink.Route) (route Route) {
return Route{
LinkIndex: netlinkRoute.LinkIndex,
Dst: netIPNetToNetipPrefix(netlinkRoute.Dst),
Src: netIPToNetipAddress(netlinkRoute.Src),
Gw: netIPToNetipAddress(netlinkRoute.Gw),
Priority: netlinkRoute.Priority,
Family: netlinkRoute.Family,
Table: netlinkRoute.Table,
Type: netlinkRoute.Type,
}
}
func routeToNetlinkRoute(route Route) (netlinkRoute netlink.Route) {
return netlink.Route{
LinkIndex: route.LinkIndex,
Dst: netipPrefixToIPNet(route.Dst),
Src: netipAddrToNetIP(route.Src),
Gw: netipAddrToNetIP(route.Gw),
Priority: route.Priority,
Family: route.Family,
Table: route.Table,
Type: route.Type,
}
}
-11
View File
@@ -1,11 +0,0 @@
package netlink
import "golang.org/x/sys/unix"
const (
RouteTypeUnicast = unix.RTN_UNICAST
ScopeUniverse = unix.RT_SCOPE_UNIVERSE
ProtoStatic = unix.RTPROT_STATIC
rtTableCompat = unix.RT_TABLE_COMPAT
)
+21
View File
@@ -0,0 +1,21 @@
//go:build !linux && !darwin
package netlink
func (n *NetLink) RouteList(family int) (
routes []Route, err error,
) {
panic("not implemented")
}
func (n *NetLink) RouteAdd(route Route) error {
panic("not implemented")
}
func (n *NetLink) RouteDel(route Route) error {
panic("not implemented")
}
func (n *NetLink) RouteReplace(route Route) error {
panic("not implemented")
}
+73 -78
View File
@@ -1,96 +1,91 @@
//go:build linux
package netlink
import (
"fmt"
"net/netip"
"github.com/jsimonetti/rtnetlink"
"github.com/vishvananda/netlink"
)
type Rule struct {
Priority *uint32
Family uint8
Table uint32
Mark *uint32
Src netip.Prefix
Dst netip.Prefix
Flags uint32
Action uint8
func NewRule() Rule {
// defaults found from netlink.NewRule() for fields we use,
// the rest of the defaults is set when converting from a `Rule`
// to a `netlink.Rule`
return Rule{
Priority: -1,
Mark: 0,
}
}
func (r *Rule) fromMessage(message rtnetlink.RuleMessage) {
table := uint32(message.Table)
if table == 0 || table == rtTableCompat {
table = *message.Attributes.Table
func (n *NetLink) RuleList(family int) (rules []Rule, err error) {
switch family {
case FamilyAll:
n.debugLogger.Debug("ip -4 rule list")
n.debugLogger.Debug("ip -6 rule list")
case FamilyV4:
n.debugLogger.Debug("ip -4 rule list")
case FamilyV6:
n.debugLogger.Debug("ip -6 rule list")
}
r.Priority = message.Attributes.Priority
r.Family = message.Family
r.Table = table
r.Mark = message.Attributes.FwMark
r.Src = ipAndLengthToPrefix(message.Attributes.Src, message.SrcLength)
r.Dst = ipAndLengthToPrefix(message.Attributes.Dst, message.DstLength)
r.Flags = message.Flags
r.Action = message.Action
netlinkRules, err := netlink.RuleList(family)
if err != nil {
return nil, err
}
rules = make([]Rule, len(netlinkRules))
for i := range netlinkRules {
rules[i] = netlinkRuleToRule(netlinkRules[i])
}
return rules, nil
}
func (r Rule) message() *rtnetlink.RuleMessage {
src, srcLength := prefixToIPAndLength(r.Src)
dst, dstLength := prefixToIPAndLength(r.Dst)
message := &rtnetlink.RuleMessage{
Family: r.Family,
SrcLength: srcLength,
DstLength: dstLength,
Flags: r.Flags,
Action: r.Action,
Attributes: &rtnetlink.RuleAttributes{
Priority: r.Priority,
FwMark: r.Mark,
Src: src,
Dst: dst,
},
}
if r.Table <= uint32(^uint8(0)) {
message.Table = uint8(r.Table)
} else {
message.Table = rtTableCompat
message.Attributes.Table = &r.Table
}
return message
func (n *NetLink) RuleAdd(rule Rule) error {
n.debugLogger.Debug(ruleDbgMsg(true, rule))
netlinkRule := ruleToNetlinkRule(rule)
return netlink.RuleAdd(&netlinkRule)
}
func (r Rule) String() string {
from := "all"
if r.Src.IsValid() && !r.Src.Addr().IsUnspecified() {
from = r.Src.String()
}
to := "all"
if r.Dst.IsValid() && !r.Dst.Addr().IsUnspecified() {
to = r.Dst.String()
}
priority := ""
if r.Priority != nil {
priority = fmt.Sprintf(" %d", *r.Priority)
}
return fmt.Sprintf("ip rule%s: from %s to %s table %d",
priority, from, to, r.Table)
func (n *NetLink) RuleDel(rule Rule) error {
n.debugLogger.Debug(ruleDbgMsg(false, rule))
netlinkRule := ruleToNetlinkRule(rule)
return netlink.RuleDel(&netlinkRule)
}
func (r Rule) debugMessage(add bool) (debugMessage string) {
func ruleToNetlinkRule(rule Rule) (netlinkRule netlink.Rule) {
netlinkRule = *netlink.NewRule()
netlinkRule.Priority = rule.Priority
netlinkRule.Family = rule.Family
netlinkRule.Table = rule.Table
netlinkRule.Mark = rule.Mark
netlinkRule.Src = netipPrefixToIPNet(rule.Src)
netlinkRule.Dst = netipPrefixToIPNet(rule.Dst)
netlinkRule.Invert = rule.Invert
return netlinkRule
}
func netlinkRuleToRule(netlinkRule netlink.Rule) (rule Rule) {
return Rule{
Priority: netlinkRule.Priority,
Family: netlinkRule.Family,
Table: netlinkRule.Table,
Mark: netlinkRule.Mark,
Src: netIPNetToNetipPrefix(netlinkRule.Src),
Dst: netIPNetToNetipPrefix(netlinkRule.Dst),
Invert: netlinkRule.Invert,
}
}
func ruleDbgMsg(add bool, rule Rule) (debugMessage string) {
debugMessage = "ip"
switch r.Family {
switch rule.Family {
case FamilyV4:
debugMessage += " -f inet"
case FamilyV6:
debugMessage += " -f inet6"
default:
debugMessage += " -f " + fmt.Sprint(r.Family)
debugMessage += " -f " + fmt.Sprint(rule.Family)
}
debugMessage += " rule"
@@ -101,20 +96,20 @@ func (r Rule) debugMessage(add bool) (debugMessage string) {
debugMessage += " del"
}
if r.Src.IsValid() {
debugMessage += " from " + r.Src.String()
if rule.Src.IsValid() {
debugMessage += " from " + rule.Src.String()
}
if r.Dst.IsValid() {
debugMessage += " to " + r.Dst.String()
if rule.Dst.IsValid() {
debugMessage += " to " + rule.Dst.String()
}
if r.Table != 0 {
debugMessage += " lookup " + fmt.Sprint(r.Table)
if rule.Table != 0 {
debugMessage += " lookup " + fmt.Sprint(rule.Table)
}
if r.Priority != nil {
debugMessage += " pref " + fmt.Sprint(*r.Priority)
if rule.Priority != -1 {
debugMessage += " pref " + fmt.Sprint(rule.Priority)
}
return debugMessage
-69
View File
@@ -1,69 +0,0 @@
package netlink
import (
"fmt"
"github.com/jsimonetti/rtnetlink"
"golang.org/x/sys/unix"
)
const (
FlagInvert = unix.FIB_RULE_INVERT
ActionToTable = unix.FR_ACT_TO_TBL
)
func (n *NetLink) RuleList(family uint8) (rules []Rule, err error) {
switch family {
case FamilyAll:
n.debugLogger.Debug("ip -4 rule list")
n.debugLogger.Debug("ip -6 rule list")
case FamilyV4:
n.debugLogger.Debug("ip -4 rule list")
case FamilyV6:
n.debugLogger.Debug("ip -6 rule list")
}
conn, err := rtnetlink.Dial(nil)
if err != nil {
return nil, fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
ruleMessages, err := conn.Rule.List()
if err != nil {
return nil, err
}
rules = make([]Rule, 0, len(ruleMessages))
for _, message := range ruleMessages {
if family != FamilyAll && family != message.Family {
continue
}
var rule Rule
rule.fromMessage(message)
rules = append(rules, rule)
}
return rules, nil
}
func (n *NetLink) RuleAdd(rule Rule) error {
n.debugLogger.Debug(rule.debugMessage(true))
conn, err := rtnetlink.Dial(nil)
if err != nil {
return fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
return conn.Rule.Add(rule.message())
}
func (n *NetLink) RuleDel(rule Rule) error {
n.debugLogger.Debug(rule.debugMessage(false))
conn, err := rtnetlink.Dial(nil)
if err != nil {
return fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
return conn.Rule.Delete(rule.message())
}
+5 -5
View File
@@ -6,7 +6,7 @@ import (
"github.com/stretchr/testify/assert"
)
func Test_Rule_debugMessage(t *testing.T) {
func Test_ruleDbgMsg(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
@@ -15,7 +15,7 @@ func Test_Rule_debugMessage(t *testing.T) {
dbgMsg string
}{
"default values": {
dbgMsg: "ip -f 0 rule del",
dbgMsg: "ip -f 0 rule del pref 0",
},
"add rule": {
add: true,
@@ -24,7 +24,7 @@ func Test_Rule_debugMessage(t *testing.T) {
Src: makeNetipPrefix(1),
Dst: makeNetipPrefix(2),
Table: 100,
Priority: ptrTo(uint32(101)),
Priority: 101,
},
dbgMsg: "ip -f inet rule add from 1.1.1.0/24 to 2.2.2.0/24 lookup 100 pref 101",
},
@@ -34,7 +34,7 @@ func Test_Rule_debugMessage(t *testing.T) {
Src: makeNetipPrefix(1),
Dst: makeNetipPrefix(2),
Table: 100,
Priority: ptrTo(uint32(101)),
Priority: 101,
},
dbgMsg: "ip -f inet rule del from 1.1.1.0/24 to 2.2.2.0/24 lookup 100 pref 101",
},
@@ -44,7 +44,7 @@ func Test_Rule_debugMessage(t *testing.T) {
t.Run(name, func(t *testing.T) {
t.Parallel()
dbgMsg := testCase.rule.debugMessage(testCase.add)
dbgMsg := ruleDbgMsg(testCase.add, testCase.rule)
assert.Equal(t, testCase.dbgMsg, dbgMsg)
})
+19
View File
@@ -0,0 +1,19 @@
//go:build !linux
package netlink
func NewRule() Rule {
return Rule{}
}
func (n *NetLink) RuleList(family int) (rules []Rule, err error) {
panic("not implemented")
}
func (n *NetLink) RuleAdd(rule Rule) error {
panic("not implemented")
}
func (n *NetLink) RuleDel(rule Rule) error {
panic("not implemented")
}
+58
View File
@@ -0,0 +1,58 @@
package netlink
import (
"fmt"
"net/netip"
)
type Addr struct {
Network netip.Prefix
}
func (a Addr) String() string {
return a.Network.String()
}
type Link struct {
Type string
Name string
Index int
EncapType string
MTU uint16
}
type Route struct {
LinkIndex int
Dst netip.Prefix
Src netip.Addr
Gw netip.Addr
Priority int
Family int
Table int
Type int
}
type Rule struct {
Priority int
Family int
Table int
Mark uint32
Src netip.Prefix
Dst netip.Prefix
Invert bool
}
func (r Rule) String() string {
from := "all"
if r.Src.IsValid() && !r.Src.Addr().IsUnspecified() {
from = r.Src.String()
}
to := "all"
if r.Dst.IsValid() && !r.Dst.Addr().IsUnspecified() {
to = r.Dst.String()
}
return fmt.Sprintf("ip rule %d: from %s to %s table %d",
r.Priority, from, to, r.Table)
}
+37
View File
@@ -0,0 +1,37 @@
//go:build linux
package netlink
import (
"github.com/qdm12/gluetun/internal/mod"
"github.com/vishvananda/netlink"
)
func (n *NetLink) IsWireguardSupported() bool {
// Check for Wireguard family without loading the wireguard module.
// Some kernels have the wireguard module built-in, and don't have a
// modules directory, such as WSL2 kernels.
ok := hasWireguardFamily()
if ok {
return true
}
// Try loading the wireguard module, since some systems do not load
// it after a boot. If this fails, wireguard is assumed to not be supported.
n.debugLogger.Debugf("wireguard family not found, trying to load wireguard kernel module")
err := mod.Probe("wireguard")
if err != nil {
n.debugLogger.Debugf("failed loading wireguard kernel module: %s", err)
return false
}
n.debugLogger.Debugf("wireguard kernel module loaded successfully")
// Re-check if the Wireguard family is now available, after loading
// the wireguard kernel module.
return hasWireguardFamily()
}
func hasWireguardFamily() bool {
_, err := netlink.GenlFamilyGet("wireguard")
return err == nil
}
-58
View File
@@ -1,58 +0,0 @@
package netlink
import (
"errors"
"fmt"
"os"
"github.com/mdlayher/genetlink"
"github.com/qdm12/gluetun/internal/mod"
)
func (n *NetLink) IsWireguardSupported() (ok bool, err error) {
// Check for Wireguard family without loading the wireguard module.
// Some kernels have the wireguard module built-in, and don't have a
// modules directory, such as WSL2 kernels.
ok, err = hasWireguardFamily()
if err != nil {
return false, fmt.Errorf("checking wireguard family: %w", err)
} else if ok {
return true, nil
}
// Try loading the wireguard module, since some systems do not load
// it after a boot. If this fails, wireguard is assumed to not be supported.
n.debugLogger.Debugf("wireguard family not found, trying to load wireguard kernel module")
err = mod.Probe("wireguard")
if err != nil {
n.debugLogger.Debugf("failed loading wireguard kernel module: %s", err)
return false, nil
}
n.debugLogger.Debugf("wireguard kernel module loaded successfully")
// Re-check if the Wireguard family is now available, after loading
// the wireguard kernel module.
ok, err = hasWireguardFamily()
if err != nil {
return false, fmt.Errorf("checking wireguard family: %w", err)
}
return ok, nil
}
func hasWireguardFamily() (ok bool, err error) {
conn, err := genetlink.Dial(nil)
if err != nil {
return false, fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
_, err = conn.GetFamily("wireguard")
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return false, nil
}
return false, fmt.Errorf("getting wireguard family: %w", err)
}
return true, nil
}
+1 -4
View File
@@ -4,8 +4,6 @@ package netlink
import (
"testing"
"github.com/stretchr/testify/require"
)
func Test_NetLink_IsWireguardSupported(t *testing.T) {
@@ -14,8 +12,7 @@ func Test_NetLink_IsWireguardSupported(t *testing.T) {
netLink := &NetLink{
debugLogger: &noopLogger{},
}
ok, err := netLink.IsWireguardSupported()
require.NoError(t, err)
ok := netLink.IsWireguardSupported()
if ok { // cannot assert since this depends on kernel
t.Log("wireguard is supported")
} else {
@@ -0,0 +1,7 @@
//go:build !linux
package netlink
func (n *NetLink) IsWireguardSupported() (ok bool, err error) {
panic("not implemented")
}
+2 -1
View File
@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"os/exec"
"syscall"
"github.com/qdm12/gluetun/internal/constants/openvpn"
)
@@ -32,7 +33,7 @@ func start(ctx context.Context, starter CmdStarter, version string, flags []stri
args := []string{"--config", configPath}
args = append(args, flags...)
cmd := exec.CommandContext(ctx, bin, args...)
setCmdSysProcAttr(cmd)
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
return starter.Start(cmd)
}
-10
View File
@@ -1,10 +0,0 @@
package openvpn
import (
"os/exec"
"syscall"
)
func setCmdSysProcAttr(cmd *exec.Cmd) {
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
}
-12
View File
@@ -1,12 +0,0 @@
//go:build !linux
package openvpn
import (
"os/exec"
"syscall"
)
func setCmdSysProcAttr(cmd *exec.Cmd) {
cmd.SysProcAttr = &syscall.SysProcAttr{}
}
-49
View File
@@ -1,49 +0,0 @@
package pmtud
import (
"net"
"time"
"golang.org/x/net/ipv4"
)
var _ net.PacketConn = &ipv4Wrapper{}
// ipv4Wrapper is a wrapper around ipv4.PacketConn to implement
// the net.PacketConn interface. It's only used for Darwin or iOS.
type ipv4Wrapper struct {
ipv4Conn *ipv4.PacketConn
}
func ipv4ToNetPacketConn(ipv4 *ipv4.PacketConn) *ipv4Wrapper {
return &ipv4Wrapper{ipv4Conn: ipv4}
}
func (i *ipv4Wrapper) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
n, _, addr, err = i.ipv4Conn.ReadFrom(p)
return n, addr, err
}
func (i *ipv4Wrapper) WriteTo(p []byte, addr net.Addr) (n int, err error) {
return i.ipv4Conn.WriteTo(p, nil, addr)
}
func (i *ipv4Wrapper) Close() error {
return i.ipv4Conn.Close()
}
func (i *ipv4Wrapper) LocalAddr() net.Addr {
return i.ipv4Conn.LocalAddr()
}
func (i *ipv4Wrapper) SetDeadline(t time.Time) error {
return i.ipv4Conn.SetDeadline(t)
}
func (i *ipv4Wrapper) SetReadDeadline(t time.Time) error {
return i.ipv4Conn.SetReadDeadline(t)
}
func (i *ipv4Wrapper) SetWriteDeadline(t time.Time) error {
return i.ipv4Conn.SetWriteDeadline(t)
}
-83
View File
@@ -1,83 +0,0 @@
package pmtud
import (
"bytes"
"errors"
"fmt"
"golang.org/x/net/icmp"
)
var (
ErrICMPNextHopMTUTooLow = errors.New("ICMP Next Hop MTU is too low")
ErrICMPNextHopMTUTooHigh = errors.New("ICMP Next Hop MTU is too high")
)
func checkMTU(mtu, minMTU, physicalLinkMTU uint32) (err error) {
switch {
case mtu < minMTU:
return fmt.Errorf("%w: %d", ErrICMPNextHopMTUTooLow, mtu)
case mtu > physicalLinkMTU:
return fmt.Errorf("%w: %d is larger than physical link MTU %d",
ErrICMPNextHopMTUTooHigh, mtu, physicalLinkMTU)
default:
return nil
}
}
func checkInvokingReplyIDMatch(icmpProtocol int, received []byte,
outboundMessage *icmp.Message,
) (match bool, err error) {
inboundMessage, err := icmp.ParseMessage(icmpProtocol, received)
if err != nil {
return false, fmt.Errorf("parsing invoking packet: %w", err)
}
inboundBody, ok := inboundMessage.Body.(*icmp.Echo)
if !ok {
return false, fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, inboundMessage.Body)
}
outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert
return inboundBody.ID == outboundBody.ID, nil
}
var ErrICMPIDMismatch = errors.New("ICMP id mismatch")
func checkEchoReply(icmpProtocol int, received []byte,
outboundMessage *icmp.Message, truncatedBody bool,
) (err error) {
inboundMessage, err := icmp.ParseMessage(icmpProtocol, received)
if err != nil {
return fmt.Errorf("parsing invoking packet: %w", err)
}
inboundBody, ok := inboundMessage.Body.(*icmp.Echo)
if !ok {
return fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, inboundMessage.Body)
}
outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert
if inboundBody.ID != outboundBody.ID {
return fmt.Errorf("%w: sent id %d and received id %d",
ErrICMPIDMismatch, outboundBody.ID, inboundBody.ID)
}
err = checkEchoBodies(outboundBody.Data, inboundBody.Data, truncatedBody)
if err != nil {
return fmt.Errorf("checking sent and received bodies: %w", err)
}
return nil
}
var ErrICMPEchoDataMismatch = errors.New("ICMP data mismatch")
func checkEchoBodies(sent, received []byte, receivedTruncated bool) (err error) {
if len(received) > len(sent) {
return fmt.Errorf("%w: sent %d bytes and received %d bytes",
ErrICMPEchoDataMismatch, len(sent), len(received))
}
if receivedTruncated {
sent = sent[:len(received)]
}
if !bytes.Equal(received, sent) {
return fmt.Errorf("%w: sent %x and received %x",
ErrICMPEchoDataMismatch, sent, received)
}
return nil
}
-10
View File
@@ -1,10 +0,0 @@
//go:build !linux && !windows
package pmtud
// setDontFragment for platforms other than Linux and Windows
// is not implemented, so we just return assuming the don't
// fragment flag is set on IP packets.
func setDontFragment(fd uintptr) (err error) {
return nil
}
-10
View File
@@ -1,10 +0,0 @@
package pmtud
import (
"syscall"
)
func setDontFragment(fd uintptr) (err error) {
return syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP,
syscall.IP_MTU_DISCOVER, syscall.IP_PMTUDISC_PROBE)
}
-13
View File
@@ -1,13 +0,0 @@
//go:build windows
package pmtud
import (
"syscall"
)
func setDontFragment(fd uintptr) (err error) {
// https://docs.microsoft.com/en-us/troubleshoot/windows/win32/header-library-requirement-socket-ipproto-ip
// #define IP_DONTFRAGMENT 14 /* don't fragment IP datagrams */
return syscall.SetsockoptInt(syscall.Handle(fd), syscall.IPPROTO_IP, 14, 1)
}
-29
View File
@@ -1,29 +0,0 @@
package pmtud
import (
"context"
"errors"
"fmt"
"net"
"strings"
"time"
)
var (
ErrICMPNotPermitted = errors.New("ICMP not permitted")
ErrICMPDestinationUnreachable = errors.New("ICMP destination unreachable")
ErrICMPCommunicationAdministrativelyProhibited = errors.New("communication administratively prohibited")
ErrICMPBodyUnsupported = errors.New("ICMP body type is not supported")
)
func wrapConnErr(err error, timedCtx context.Context, pingTimeout time.Duration) error { //nolint:revive
switch {
case strings.HasSuffix(err.Error(), "sendto: operation not permitted"):
err = fmt.Errorf("%w", ErrICMPNotPermitted)
case errors.Is(timedCtx.Err(), context.DeadlineExceeded):
err = fmt.Errorf("%w (timed out after %s)", net.ErrClosed, pingTimeout)
case timedCtx.Err() != nil:
err = timedCtx.Err()
}
return err
}
-7
View File
@@ -1,7 +0,0 @@
package pmtud
type Logger interface {
Debug(msg string)
Debugf(msg string, args ...any)
Warnf(msg string, args ...any)
}
-159
View File
@@ -1,159 +0,0 @@
package pmtud
import (
"context"
"encoding/binary"
"fmt"
"net"
"net/netip"
"runtime"
"strings"
"syscall"
"time"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
)
const (
// see https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media
minIPv4MTU uint32 = 68
icmpv4Protocol int = 1
)
func listenICMPv4(ctx context.Context) (conn net.PacketConn, err error) {
var listenConfig net.ListenConfig
listenConfig.Control = func(_, _ string, rawConn syscall.RawConn) error {
var setDFErr error
err := rawConn.Control(func(fd uintptr) {
setDFErr = setDontFragment(fd) // runs when calling ListenPacket
})
if err == nil {
err = setDFErr
}
return err
}
const listenAddress = ""
packetConn, err := listenConfig.ListenPacket(ctx, "ip4:icmp", listenAddress)
if err != nil {
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrICMPNotPermitted)
}
return nil, err
}
if runtime.GOOS == "darwin" || runtime.GOOS == "ios" {
packetConn = ipv4ToNetPacketConn(ipv4.NewPacketConn(packetConn))
}
return packetConn, nil
}
func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr,
physicalLinkMTU uint32, pingTimeout time.Duration, logger Logger,
) (mtu uint32, err error) {
if ip.Is6() {
panic("IP address is not v4")
}
conn, err := listenICMPv4(ctx)
if err != nil {
return 0, fmt.Errorf("listening for ICMP packets: %w", err)
}
ctx, cancel := context.WithTimeout(ctx, pingTimeout)
defer cancel()
go func() {
<-ctx.Done()
conn.Close()
}()
// First try to send a packet which is too big to get the maximum MTU
// directly.
outboundID, outboundMessage := buildMessageToSend("v4", physicalLinkMTU)
encodedMessage, err := outboundMessage.Marshal(nil)
if err != nil {
return 0, fmt.Errorf("encoding ICMP message: %w", err)
}
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()})
if err != nil {
err = wrapConnErr(err, ctx, pingTimeout)
return 0, fmt.Errorf("writing ICMP message: %w", err)
}
buffer := make([]byte, physicalLinkMTU)
for { // for loop in case we read an echo reply for another ICMP request
// Note we need to read the whole packet in one call to ReadFrom, so the buffer
// must be large enough to read the entire reply packet. See:
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
bytesRead, _, err := conn.ReadFrom(buffer)
if err != nil {
err = wrapConnErr(err, ctx, pingTimeout)
return 0, fmt.Errorf("reading from ICMP connection: %w", err)
}
packetBytes := buffer[:bytesRead]
// Side note: echo reply should be at most the number of bytes
// sent, and can be lower, more precisely 576-ipHeader bytes,
// in case the next hop we are reaching replies with a destination
// unreachable and wants to ensure the response makes it way back
// by keeping a low packet size, see:
// https://datatracker.ietf.org/doc/html/rfc1122#page-59
inboundMessage, err := icmp.ParseMessage(icmpv4Protocol, packetBytes)
if err != nil {
return 0, fmt.Errorf("parsing message: %w", err)
}
switch typedBody := inboundMessage.Body.(type) {
case *icmp.DstUnreach:
const fragmentationRequiredAndDFFlagSetCode = 4
const communicationAdministrativelyProhibitedCode = 13
switch inboundMessage.Code {
case fragmentationRequiredAndDFFlagSetCode:
case communicationAdministrativelyProhibitedCode:
return 0, fmt.Errorf("%w: %w (code %d)",
ErrICMPDestinationUnreachable,
ErrICMPCommunicationAdministrativelyProhibited,
inboundMessage.Code)
default:
return 0, fmt.Errorf("%w: code %d",
ErrICMPDestinationUnreachable, inboundMessage.Code)
}
// See https://datatracker.ietf.org/doc/html/rfc1191#section-4
// Note: the go library does not handle this NextHopMTU section.
nextHopMTU := packetBytes[6:8]
mtu = uint32(binary.BigEndian.Uint16(nextHopMTU))
err = checkMTU(mtu, minIPv4MTU, physicalLinkMTU)
if err != nil {
return 0, fmt.Errorf("checking next-hop-mtu found: %w", err)
}
// The code below is really for sanity checks
packetBytes = packetBytes[8:]
header, err := ipv4.ParseHeader(packetBytes)
if err != nil {
return 0, fmt.Errorf("parsing IPv4 header: %w", err)
}
packetBytes = packetBytes[header.Len:] // truncated original datagram
const truncated = true
err = checkEchoReply(icmpv4Protocol, packetBytes, outboundMessage, truncated)
if err != nil {
return 0, fmt.Errorf("checking echo reply: %w", err)
}
return mtu, nil
case *icmp.Echo:
inboundID := uint16(typedBody.ID) //nolint:gosec
if inboundID == outboundID {
return physicalLinkMTU, nil
}
logger.Debugf("discarding received ICMP echo reply with id %d mismatching sent id %d",
inboundID, outboundID)
continue
default:
return 0, fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, typedBody)
}
}
}
-122
View File
@@ -1,122 +0,0 @@
package pmtud
import (
"context"
"fmt"
"net"
"net/netip"
"strings"
"time"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv6"
)
const (
minIPv6MTU = 1280
icmpv6Protocol = 58
)
func listenICMPv6(ctx context.Context) (conn net.PacketConn, err error) {
var listenConfig net.ListenConfig
const listenAddress = ""
packetConn, err := listenConfig.ListenPacket(ctx, "ip6:ipv6-icmp", listenAddress)
if err != nil {
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrICMPNotPermitted)
}
return nil, err
}
return packetConn, nil
}
func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr,
physicalLinkMTU uint32, pingTimeout time.Duration, logger Logger,
) (mtu uint32, err error) {
if ip.Is4() {
panic("IP address is not v6")
}
conn, err := listenICMPv6(ctx)
if err != nil {
return 0, fmt.Errorf("listening for ICMP packets: %w", err)
}
ctx, cancel := context.WithTimeout(ctx, pingTimeout)
defer cancel()
go func() {
<-ctx.Done()
conn.Close()
}()
// First try to send a packet which is too big to get the maximum MTU
// directly.
outboundID, outboundMessage := buildMessageToSend("v6", physicalLinkMTU)
encodedMessage, err := outboundMessage.Marshal(nil)
if err != nil {
return 0, fmt.Errorf("encoding ICMP message: %w", err)
}
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice(), Zone: ip.Zone()})
if err != nil {
err = wrapConnErr(err, ctx, pingTimeout)
return 0, fmt.Errorf("writing ICMP message: %w", err)
}
buffer := make([]byte, physicalLinkMTU)
for { // for loop if we encounter another ICMP packet with an unknown id.
// Note we need to read the whole packet in one call to ReadFrom, so the buffer
// must be large enough to read the entire reply packet. See:
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
bytesRead, _, err := conn.ReadFrom(buffer)
if err != nil {
err = wrapConnErr(err, ctx, pingTimeout)
return 0, fmt.Errorf("reading from ICMP connection: %w", err)
}
packetBytes := buffer[:bytesRead]
packetBytes = packetBytes[ipv6.HeaderLen:]
inboundMessage, err := icmp.ParseMessage(icmpv6Protocol, packetBytes)
if err != nil {
return 0, fmt.Errorf("parsing message: %w", err)
}
switch typedBody := inboundMessage.Body.(type) {
case *icmp.PacketTooBig:
// https://datatracker.ietf.org/doc/html/rfc1885#section-3.2
mtu = uint32(typedBody.MTU) //nolint:gosec
err = checkMTU(mtu, minIPv6MTU, physicalLinkMTU)
if err != nil {
return 0, fmt.Errorf("checking MTU: %w", err)
}
// Sanity checks
const truncatedBody = true
err = checkEchoReply(icmpv6Protocol, typedBody.Data, outboundMessage, truncatedBody)
if err != nil {
return 0, fmt.Errorf("checking invoking message: %w", err)
}
return uint32(typedBody.MTU), nil //nolint:gosec
case *icmp.DstUnreach:
// https://datatracker.ietf.org/doc/html/rfc1885#section-3.1
idMatch, err := checkInvokingReplyIDMatch(icmpv6Protocol, packetBytes, outboundMessage)
if err != nil {
return 0, fmt.Errorf("checking invoking message id: %w", err)
} else if idMatch {
return 0, fmt.Errorf("%w", ErrICMPDestinationUnreachable)
}
logger.Debug("discarding received ICMP destination unreachable reply with an unknown id")
continue
case *icmp.Echo:
inboundID := uint16(typedBody.ID) //nolint:gosec
if inboundID == outboundID {
return physicalLinkMTU, nil
}
logger.Debugf("discarding received ICMP echo reply with id %d mismatching sent id %d",
inboundID, outboundID)
continue
default:
return 0, fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, typedBody)
}
}
}
-58
View File
@@ -1,58 +0,0 @@
package pmtud
import (
cryptorand "crypto/rand"
"encoding/binary"
"fmt"
"math/rand/v2"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
func buildMessageToSend(ipVersion string, mtu uint32) (id uint16, message *icmp.Message) {
var seed [32]byte
_, _ = cryptorand.Read(seed[:])
randomSource := rand.NewChaCha8(seed)
const uint16Bytes = 2
idBytes := make([]byte, uint16Bytes)
_, _ = randomSource.Read(idBytes)
id = binary.BigEndian.Uint16(idBytes)
var ipHeaderLength uint32
var icmpType icmp.Type
switch ipVersion {
case "v4":
ipHeaderLength = ipv4.HeaderLen
icmpType = ipv4.ICMPTypeEcho
case "v6":
ipHeaderLength = ipv6.HeaderLen
icmpType = ipv6.ICMPTypeEchoRequest
default:
panic(fmt.Sprintf("IP version %q not supported", ipVersion))
}
const pingHeaderLength = 0 +
1 + // type
1 + // code
2 + // checksum
2 + // identifier
2 // sequence number
pingBodyDataSize := mtu - ipHeaderLength - pingHeaderLength
messageBodyData := make([]byte, pingBodyDataSize)
_, _ = randomSource.Read(messageBodyData)
// See https://www.iana.org/assignments/icmp-parameters/icmp-parameters.xhtml#icmp-parameters-types
message = &icmp.Message{
Type: icmpType, // echo request
Code: 0, // no code
Checksum: 0, // calculated at encoding (ipv4) or sending (ipv6)
Body: &icmp.Echo{
ID: int(id),
Seq: 0, // only one packet
Data: messageBodyData,
},
}
return id, message
}
-7
View File
@@ -1,7 +0,0 @@
package pmtud
type noopLogger struct{}
func (noopLogger) Debug(_ string) {}
func (noopLogger) Debugf(_ string, _ ...any) {}
func (noopLogger) Warnf(_ string, _ ...any) {}
-271
View File
@@ -1,271 +0,0 @@
package pmtud
import (
"context"
"errors"
"fmt"
"math"
"net"
"net/netip"
"strings"
"time"
"golang.org/x/net/icmp"
)
var ErrMTUNotFound = errors.New("path MTU discovery failed to find MTU")
// PathMTUDiscover discovers the maximum MTU for the path to the given ip address.
// If the physicalLinkMTU is zero, it defaults to 1500 which is the ethernet standard MTU.
// If the pingTimeout is zero, it defaults to 1 second.
// If the logger is nil, a no-op logger is used.
// It returns [ErrMTUNotFound] if the MTU could not be determined.
func PathMTUDiscover(ctx context.Context, ip netip.Addr,
physicalLinkMTU uint32, pingTimeout time.Duration, logger Logger) (
mtu uint32, err error,
) {
if physicalLinkMTU == 0 {
const ethernetStandardMTU = 1500
physicalLinkMTU = ethernetStandardMTU
}
if pingTimeout == 0 {
pingTimeout = time.Second
}
if logger == nil {
logger = &noopLogger{}
}
if ip.Is4() {
logger.Debug("finding IPv4 next hop MTU")
mtu, err = findIPv4NextHopMTU(ctx, ip, physicalLinkMTU, pingTimeout, logger)
switch {
case err == nil:
return mtu, nil
case errors.Is(err, net.ErrClosed) || errors.Is(err, ErrICMPCommunicationAdministrativelyProhibited): // blackhole
default:
return 0, fmt.Errorf("finding IPv4 next hop MTU: %w", err)
}
} else {
logger.Debug("requesting IPv6 ICMP packet-too-big reply")
mtu, err = getIPv6PacketTooBig(ctx, ip, physicalLinkMTU, pingTimeout, logger)
switch {
case err == nil:
return mtu, nil
case errors.Is(err, net.ErrClosed): // blackhole
default:
return 0, fmt.Errorf("getting IPv6 packet-too-big message: %w", err)
}
}
// Fall back method: send echo requests with different packet
// sizes and check which ones succeed to find the maximum MTU.
logger.Debug("falling back to sending different sized echo packets")
minMTU := minIPv4MTU
if ip.Is6() {
minMTU = minIPv6MTU
}
return pmtudMultiSizes(ctx, ip, minMTU, physicalLinkMTU, pingTimeout, logger)
}
type pmtudTestUnit struct {
mtu uint32
echoID uint16
sentBytes int
ok bool
}
func pmtudMultiSizes(ctx context.Context, ip netip.Addr,
minMTU, maxPossibleMTU uint32, pingTimeout time.Duration,
logger Logger,
) (maxMTU uint32, err error) {
var ipVersion string
var conn net.PacketConn
if ip.Is4() {
ipVersion = "v4"
conn, err = listenICMPv4(ctx)
} else {
ipVersion = "v6"
conn, err = listenICMPv6(ctx)
}
if err != nil {
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrICMPNotPermitted)
}
return 0, fmt.Errorf("listening for ICMP packets: %w", err)
}
mtusToTest := makeMTUsToTest(minMTU, maxPossibleMTU)
if len(mtusToTest) == 1 { // only minMTU because minMTU == maxPossibleMTU
return minMTU, nil
}
logger.Debugf("testing the following MTUs: %v", mtusToTest)
tests := make([]pmtudTestUnit, len(mtusToTest))
for i := range mtusToTest {
tests[i] = pmtudTestUnit{mtu: mtusToTest[i]}
}
timedCtx, cancel := context.WithTimeout(ctx, pingTimeout)
defer cancel()
go func() {
<-timedCtx.Done()
conn.Close()
}()
for i := range tests {
id, message := buildMessageToSend(ipVersion, tests[i].mtu)
tests[i].echoID = id
encodedMessage, err := message.Marshal(nil)
if err != nil {
return 0, fmt.Errorf("encoding ICMP message: %w", err)
}
tests[i].sentBytes = len(encodedMessage)
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()})
if err != nil {
if strings.HasSuffix(err.Error(), "sendto: operation not permitted") {
err = fmt.Errorf("%w", ErrICMPNotPermitted)
}
return 0, fmt.Errorf("writing ICMP message: %w", err)
}
}
err = collectReplies(conn, ipVersion, tests, logger)
switch {
case err == nil: // max possible MTU is working
return tests[len(tests)-1].mtu, nil
case err != nil && errors.Is(err, net.ErrClosed):
// we have timeouts (IPv4 testing or IPv6 PMTUD blackholes)
// so find the highest MTU which worked.
// Note we start from index len(tests) - 2 since the max MTU
// cannot be working if we had a timeout.
for i := len(tests) - 2; i >= 0; i-- { //nolint:mnd
if tests[i].ok {
return pmtudMultiSizes(ctx, ip, tests[i].mtu, tests[i+1].mtu-1,
pingTimeout, logger)
}
}
// All MTUs failed.
return 0, fmt.Errorf("%w: ICMP might be blocked", ErrMTUNotFound)
case err != nil:
return 0, fmt.Errorf("collecting ICMP echo replies: %w", err)
default:
panic("unreachable")
}
}
// Create the MTU slice of length 11 such that:
// - the first element is the minMTU
// - the last element is the maxMTU
// - elements in-between are separated as close to each other
// The number 11 is chosen to find the final MTU in 3 searches,
// with a total search space of 1728 MTUs which is enough;
// to find it in 2 searches requires 37 parallel queries which
// could be blocked by firewalls.
func makeMTUsToTest(minMTU, maxMTU uint32) (mtus []uint32) {
const mtusLength = 11 // find the final MTU in 3 searches
diff := maxMTU - minMTU
switch {
case minMTU > maxMTU:
panic("minMTU > maxMTU")
case diff <= mtusLength:
mtus = make([]uint32, 0, diff)
for mtu := minMTU; mtu <= maxMTU; mtu++ {
mtus = append(mtus, mtu)
}
default:
step := float64(diff) / float64(mtusLength-1)
mtus = make([]uint32, 0, mtusLength)
for mtu := float64(minMTU); len(mtus) < mtusLength-1; mtu += step {
mtus = append(mtus, uint32(math.Round(mtu)))
}
mtus = append(mtus, maxMTU) // last element is the maxMTU
}
return mtus
}
func collectReplies(conn net.PacketConn, ipVersion string,
tests []pmtudTestUnit, logger Logger,
) (err error) {
echoIDToTestIndex := make(map[uint16]int, len(tests))
for i, test := range tests {
echoIDToTestIndex[test.echoID] = i
}
// The theoretical limit is 4GiB for IPv6 MTU path discovery jumbograms, but that would
// create huge buffers which we don't really want to support anyway.
// The standard frame maximum MTU is 1500 bytes, and there are Jumbo frames with
// a conventional maximum of 9000 bytes. However, some manufacturers support up
// 9216-20 = 9196 bytes for the maximum MTU. We thus use buffers of size 9196 to
// match eventual Jumbo frames. More information at:
// https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media
const maxPossibleMTU = 9196
buffer := make([]byte, maxPossibleMTU)
idsFound := 0
for idsFound < len(tests) {
// Note we need to read the whole packet in one call to ReadFrom, so the buffer
// must be large enough to read the entire reply packet. See:
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
bytesRead, _, err := conn.ReadFrom(buffer)
if err != nil {
return fmt.Errorf("reading from ICMP connection: %w", err)
}
packetBytes := buffer[:bytesRead]
ipPacketLength := len(packetBytes)
var icmpProtocol int
switch ipVersion {
case "v4":
icmpProtocol = icmpv4Protocol
case "v6":
icmpProtocol = icmpv6Protocol
default:
panic(fmt.Sprintf("unknown IP version: %s", ipVersion))
}
// Parse the ICMP message
// Note: this parsing works for a truncated 556 bytes ICMP reply packet.
message, err := icmp.ParseMessage(icmpProtocol, packetBytes)
if err != nil {
return fmt.Errorf("parsing message: %w", err)
}
echoBody, ok := message.Body.(*icmp.Echo)
if !ok {
return fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, message.Body)
}
id := uint16(echoBody.ID) //nolint:gosec
testIndex, testing := echoIDToTestIndex[id]
if !testing { // not an id we expected so ignore it
logger.Warnf("ignoring ICMP reply with unexpected ID %d (type: %d, code: %d, length: %d)",
echoBody.ID, message.Type, message.Code, ipPacketLength)
continue
}
idsFound++
sentBytes := tests[testIndex].sentBytes
// echo reply should be at most the number of bytes sent,
// and can be lower, more precisely 556 bytes, in case
// the host we are reaching wants to stay out of trouble
// and ensure its echo reply goes through without
// fragmentation, see the following page:
// https://datatracker.ietf.org/doc/html/rfc1122#page-59
const conservativeReplyLength = 556
truncated := ipPacketLength < sentBytes &&
ipPacketLength == conservativeReplyLength
// Check the packet size is the same if the reply is not truncated
if !truncated && sentBytes != ipPacketLength {
return fmt.Errorf("%w: sent %dB and received %dB",
ErrICMPEchoDataMismatch, sentBytes, ipPacketLength)
}
// Truncated reply or matching reply size
tests[testIndex].ok = true
}
return nil
}
-22
View File
@@ -1,22 +0,0 @@
//go:build integration
package pmtud
import (
"context"
"net/netip"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func Test_PathMTUDiscover(t *testing.T) {
t.Parallel()
const physicalLinkMTU = 1500
const timeout = time.Second
mtu, err := PathMTUDiscover(context.Background(), netip.MustParseAddr("1.1.1.1"),
physicalLinkMTU, timeout, nil)
require.NoError(t, err)
t.Log("MTU found:", mtu)
}
-55
View File
@@ -1,55 +0,0 @@
package pmtud
import (
"testing"
"github.com/stretchr/testify/assert"
)
func Test_makeMTUsToTest(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
minMTU uint32
maxMTU uint32
mtus []uint32
}{
"0_0": {
mtus: []uint32{0},
},
"0_1": {
maxMTU: 1,
mtus: []uint32{0, 1},
},
"0_8": {
maxMTU: 8,
mtus: []uint32{0, 1, 2, 3, 4, 5, 6, 7, 8},
},
"0_12": {
maxMTU: 12,
mtus: []uint32{0, 1, 2, 4, 5, 6, 7, 8, 10, 11, 12},
},
"0_80": {
maxMTU: 80,
mtus: []uint32{0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80},
},
"0_100": {
maxMTU: 100,
mtus: []uint32{0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100},
},
"1280_1500": {
minMTU: 1280,
maxMTU: 1500,
mtus: []uint32{1280, 1302, 1324, 1346, 1368, 1390, 1412, 1434, 1456, 1478, 1500},
},
}
for name, testCase := range testCases {
t.Run(name, func(t *testing.T) {
t.Parallel()
mtus := makeMTUsToTest(testCase.minMTU, testCase.maxMTU)
assert.Equal(t, testCase.mtus, mtus)
})
}
}
+1 -1
View File
@@ -14,7 +14,7 @@ type Service interface {
type Routing interface {
VPNLocalGatewayIP(vpnInterface string) (gateway netip.Addr, err error)
AssignedIP(interfaceName string, family uint8) (ip netip.Addr, err error)
AssignedIP(interfaceName string, family int) (ip netip.Addr, err error)
}
type PortAllower interface {
+1 -1
View File
@@ -17,7 +17,7 @@ type PortAllower interface {
type Routing interface {
VPNLocalGatewayIP(vpnInterface string) (gateway netip.Addr, err error)
AssignedIP(interfaceName string, family uint8) (ip netip.Addr, err error)
AssignedIP(interfaceName string, family int) (ip netip.Addr, err error)
}
type Logger interface {
+3 -2
View File
@@ -6,6 +6,7 @@ import (
"net/netip"
"github.com/qdm12/gluetun/internal/netlink"
"golang.org/x/sys/unix"
)
var ErrRouteDefaultNotFound = errors.New("default route not found")
@@ -14,7 +15,7 @@ type DefaultRoute struct {
NetInterface string
Gateway netip.Addr
AssignedIP netip.Addr
Family uint8
Family int
}
func (d DefaultRoute) String() string {
@@ -29,7 +30,7 @@ func (r *Routing) DefaultRoutes() (defaultRoutes []DefaultRoute, err error) {
}
for _, route := range routes {
if route.Table != tableMain {
if route.Table != unix.RT_TABLE_MAIN {
// ignore non-main table
continue
}
+4 -4
View File
@@ -8,8 +8,8 @@ import (
)
const (
inboundTable uint32 = 200
inboundPriority uint32 = 100
inboundTable = 200
inboundPriority = 100
)
func (r *Routing) routeInboundFromDefault(defaultRoutes []DefaultRoute) (err error) {
@@ -60,7 +60,7 @@ func (r *Routing) unrouteInboundFromDefault(defaultRoutes []DefaultRoute) (err e
return nil
}
func (r *Routing) addRuleInboundFromDefault(table uint32, defaultRoutes []DefaultRoute) (err error) {
func (r *Routing) addRuleInboundFromDefault(table int, defaultRoutes []DefaultRoute) (err error) {
for _, defaultRoute := range defaultRoutes {
assignedIP := defaultRoute.AssignedIP
bits := 32
@@ -78,7 +78,7 @@ func (r *Routing) addRuleInboundFromDefault(table uint32, defaultRoutes []Defaul
return nil
}
func (r *Routing) delRuleInboundFromDefault(table uint32, defaultRoutes []DefaultRoute) (err error) {
func (r *Routing) delRuleInboundFromDefault(table int, defaultRoutes []DefaultRoute) (err error) {
for _, defaultRoute := range defaultRoutes {
assignedIP := defaultRoute.AssignedIP
bits := 32
+2 -2
View File
@@ -16,12 +16,12 @@ func ipIsPrivate(ip netip.Addr) bool {
var errInterfaceIPNotFound = errors.New("IP address not found for interface")
func ipMatchesFamily(ip netip.Addr, family uint8) bool {
func ipMatchesFamily(ip netip.Addr, family int) bool {
return (family == netlink.FamilyV4 && ip.Is4()) ||
(family == netlink.FamilyV6 && ip.Is6())
}
func (r *Routing) AssignedIP(interfaceName string, family uint8) (ip netip.Addr, err error) {
func (r *Routing) AssignedIP(interfaceName string, family int) (ip netip.Addr, err error) {
iface, err := net.InterfaceByName(interfaceName)
if err != nil {
return ip, fmt.Errorf("network interface %s not found: %w", interfaceName, err)
+5 -4
View File
@@ -6,6 +6,7 @@ import (
"net/netip"
"github.com/qdm12/gluetun/internal/netlink"
"golang.org/x/sys/unix"
)
var (
@@ -26,10 +27,10 @@ func (r *Routing) LocalNetworks() (localNetworks []LocalNetwork, err error) {
return localNetworks, fmt.Errorf("listing links: %w", err)
}
localLinks := make(map[uint32]struct{})
localLinks := make(map[int]struct{})
for _, link := range links {
if link.DeviceType != netlink.DeviceTypeEthernet {
if link.EncapType != "ether" {
continue
}
@@ -47,7 +48,7 @@ func (r *Routing) LocalNetworks() (localNetworks []LocalNetwork, err error) {
}
for _, route := range routes {
if route.Table != tableMain ||
if route.Table != unix.RT_TABLE_MAIN ||
(route.Gw.IsValid() && !route.Gw.IsUnspecified()) ||
(route.Dst.IsValid() && route.Dst.Addr().IsUnspecified()) {
continue
@@ -95,7 +96,7 @@ func (r *Routing) AddLocalRules(subnets []LocalNetwork) (err error) {
// Local has higher priority then outbound(99) and inbound(100) as the
// local routes might be necessary to reach the outbound/inbound routes.
const localPriority uint32 = 98
const localPriority = 98
// Main table was setup correctly by Docker, just need to add rules to use it
src := netip.Prefix{}
+14 -14
View File
@@ -5,7 +5,6 @@
package routing
import (
netip "net/netip"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
@@ -36,10 +35,10 @@ func (m *MockNetLinker) EXPECT() *MockNetLinkerMockRecorder {
}
// AddrList mocks base method.
func (m *MockNetLinker) AddrList(arg0 uint32, arg1 byte) ([]netip.Prefix, error) {
func (m *MockNetLinker) AddrList(arg0 netlink.Link, arg1 int) ([]netlink.Addr, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddrList", arg0, arg1)
ret0, _ := ret[0].([]netip.Prefix)
ret0, _ := ret[0].([]netlink.Addr)
ret1, _ := ret[1].(error)
return ret0, ret1
}
@@ -51,7 +50,7 @@ func (mr *MockNetLinkerMockRecorder) AddrList(arg0, arg1 interface{}) *gomock.Ca
}
// AddrReplace mocks base method.
func (m *MockNetLinker) AddrReplace(arg0 uint32, arg1 netip.Prefix) error {
func (m *MockNetLinker) AddrReplace(arg0 netlink.Link, arg1 netlink.Addr) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddrReplace", arg0, arg1)
ret0, _ := ret[0].(error)
@@ -65,10 +64,10 @@ func (mr *MockNetLinkerMockRecorder) AddrReplace(arg0, arg1 interface{}) *gomock
}
// LinkAdd mocks base method.
func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) (uint32, error) {
func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) (int, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkAdd", arg0)
ret0, _ := ret[0].(uint32)
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
@@ -80,7 +79,7 @@ func (mr *MockNetLinkerMockRecorder) LinkAdd(arg0 interface{}) *gomock.Call {
}
// LinkByIndex mocks base method.
func (m *MockNetLinker) LinkByIndex(arg0 uint32) (netlink.Link, error) {
func (m *MockNetLinker) LinkByIndex(arg0 int) (netlink.Link, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkByIndex", arg0)
ret0, _ := ret[0].(netlink.Link)
@@ -110,7 +109,7 @@ func (mr *MockNetLinkerMockRecorder) LinkByName(arg0 interface{}) *gomock.Call {
}
// LinkDel mocks base method.
func (m *MockNetLinker) LinkDel(arg0 uint32) error {
func (m *MockNetLinker) LinkDel(arg0 netlink.Link) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkDel", arg0)
ret0, _ := ret[0].(error)
@@ -139,7 +138,7 @@ func (mr *MockNetLinkerMockRecorder) LinkList() *gomock.Call {
}
// LinkSetDown mocks base method.
func (m *MockNetLinker) LinkSetDown(arg0 uint32) error {
func (m *MockNetLinker) LinkSetDown(arg0 netlink.Link) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkSetDown", arg0)
ret0, _ := ret[0].(error)
@@ -153,11 +152,12 @@ func (mr *MockNetLinkerMockRecorder) LinkSetDown(arg0 interface{}) *gomock.Call
}
// LinkSetUp mocks base method.
func (m *MockNetLinker) LinkSetUp(arg0 uint32) error {
func (m *MockNetLinker) LinkSetUp(arg0 netlink.Link) (int, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkSetUp", arg0)
ret0, _ := ret[0].(error)
return ret0
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// LinkSetUp indicates an expected call of LinkSetUp.
@@ -195,7 +195,7 @@ func (mr *MockNetLinkerMockRecorder) RouteDel(arg0 interface{}) *gomock.Call {
}
// RouteList mocks base method.
func (m *MockNetLinker) RouteList(arg0 byte) ([]netlink.Route, error) {
func (m *MockNetLinker) RouteList(arg0 int) ([]netlink.Route, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RouteList", arg0)
ret0, _ := ret[0].([]netlink.Route)
@@ -252,7 +252,7 @@ func (mr *MockNetLinkerMockRecorder) RuleDel(arg0 interface{}) *gomock.Call {
}
// RuleList mocks base method.
func (m *MockNetLinker) RuleList(arg0 byte) ([]netlink.Rule, error) {
func (m *MockNetLinker) RuleList(arg0 int) ([]netlink.Rule, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RuleList", arg0)
ret0, _ := ret[0].([]netlink.Rule)
+2 -2
View File
@@ -9,8 +9,8 @@ import (
)
const (
outboundTable uint32 = 199
outboundPriority uint32 = 99
outboundTable = 199
outboundPriority = 99
)
func (r *Routing) SetOutboundRoutes(outboundSubnets []netip.Prefix) error {
+4 -17
View File
@@ -9,33 +9,25 @@ import (
)
func (r *Routing) addRouteVia(destination netip.Prefix, gateway netip.Addr,
iface string, table uint32,
iface string, table int,
) error {
destinationStr := destination.String()
r.logger.Info("adding route for " + destinationStr)
r.logger.Debug("ip route replace " + destinationStr +
" via " + gateway.String() +
" dev " + iface +
" table " + strconv.Itoa(int(table)))
" table " + strconv.Itoa(table))
link, err := r.netLinker.LinkByName(iface)
if err != nil {
return fmt.Errorf("finding link for interface %s: %w", iface, err)
}
family := netlink.FamilyV4
if destination.Addr().Is6() {
family = netlink.FamilyV6
}
route := netlink.Route{
Dst: destination,
Gw: gateway,
LinkIndex: link.Index,
Family: family,
Table: table,
Type: netlink.RouteTypeUnicast,
Scope: netlink.ScopeUniverse,
Proto: netlink.ProtoStatic,
}
if err := r.netLinker.RouteReplace(route); err != nil {
return fmt.Errorf("replacing route for subnet %s at interface %s: %w",
@@ -46,29 +38,24 @@ func (r *Routing) addRouteVia(destination netip.Prefix, gateway netip.Addr,
}
func (r *Routing) deleteRouteVia(destination netip.Prefix, gateway netip.Addr,
iface string, table uint32,
iface string, table int,
) (err error) {
destinationStr := destination.String()
r.logger.Info("deleting route for " + destinationStr)
r.logger.Debug("ip route delete " + destinationStr +
" via " + gateway.String() +
" dev " + iface +
" table " + strconv.Itoa(int(table)))
" table " + strconv.Itoa(table))
link, err := r.netLinker.LinkByName(iface)
if err != nil {
return fmt.Errorf("finding link for interface %s: %w", iface, err)
}
family := netlink.FamilyV4
if destination.Addr().Is6() {
family = netlink.FamilyV6
}
route := netlink.Route{
Dst: destination,
Gw: gateway,
LinkIndex: link.Index,
Family: family,
Table: table,
}
if err := r.netLinker.RouteDel(route); err != nil {
+10 -10
View File
@@ -15,20 +15,20 @@ type NetLinker interface {
}
type Addresser interface {
AddrList(linkIndex uint32, family uint8) (
addresses []netip.Prefix, err error)
AddrReplace(linkIndex uint32, prefix netip.Prefix) error
AddrList(link netlink.Link, family int) (
addresses []netlink.Addr, err error)
AddrReplace(link netlink.Link, addr netlink.Addr) error
}
type Router interface {
RouteList(family uint8) (routes []netlink.Route, err error)
RouteList(family int) (routes []netlink.Route, err error)
RouteAdd(route netlink.Route) error
RouteDel(route netlink.Route) error
RouteReplace(route netlink.Route) error
}
type Ruler interface {
RuleList(family uint8) (rules []netlink.Rule, err error)
RuleList(family int) (rules []netlink.Rule, err error)
RuleAdd(rule netlink.Rule) error
RuleDel(rule netlink.Rule) error
}
@@ -36,11 +36,11 @@ type Ruler interface {
type Linker interface {
LinkList() (links []netlink.Link, err error)
LinkByName(name string) (link netlink.Link, err error)
LinkByIndex(index uint32) (link netlink.Link, err error)
LinkAdd(link netlink.Link) (linkIndex uint32, err error)
LinkDel(index uint32) (err error)
LinkSetUp(index uint32) (err error)
LinkSetDown(index uint32) (err error)
LinkByIndex(index int) (link netlink.Link, err error)
LinkAdd(link netlink.Link) (linkIndex int, err error)
LinkDel(link netlink.Link) (err error)
LinkSetUp(link netlink.Link) (linkIndex int, err error)
LinkSetDown(link netlink.Link) (err error)
}
type Routing struct {
+13 -39
View File
@@ -7,19 +7,12 @@ import (
"github.com/qdm12/gluetun/internal/netlink"
)
func (r *Routing) addIPRule(src, dst netip.Prefix, table, priority uint32) error {
family := netlink.FamilyV4
if (src.IsValid() && src.Addr().Is6()) || (dst.IsValid() && dst.Addr().Is6()) {
family = netlink.FamilyV6
}
rule := netlink.Rule{
Priority: &priority,
Family: family,
Table: table,
Src: src,
Dst: dst,
Action: netlink.ActionToTable,
}
func (r *Routing) addIPRule(src, dst netip.Prefix, table, priority int) error {
rule := netlink.NewRule()
rule.Src = src
rule.Dst = dst
rule.Priority = priority
rule.Table = table
existingRules, err := r.netLinker.RuleList(netlink.FamilyAll)
if err != nil {
@@ -38,19 +31,12 @@ func (r *Routing) addIPRule(src, dst netip.Prefix, table, priority uint32) error
return nil
}
func (r *Routing) deleteIPRule(src, dst netip.Prefix, table uint32, priority uint32) error {
family := netlink.FamilyV4
if (src.IsValid() && src.Addr().Is6()) || (dst.IsValid() && dst.Addr().Is6()) {
family = netlink.FamilyV6
}
rule := netlink.Rule{
Priority: &priority,
Family: family,
Table: table,
Src: src,
Dst: dst,
Action: netlink.ActionToTable,
}
func (r *Routing) deleteIPRule(src, dst netip.Prefix, table, priority int) error {
rule := netlink.NewRule()
rule.Src = src
rule.Dst = dst
rule.Priority = priority
rule.Table = table
existingRules, err := r.netLinker.RuleList(netlink.FamilyAll)
if err != nil {
@@ -67,12 +53,10 @@ func (r *Routing) deleteIPRule(src, dst netip.Prefix, table uint32, priority uin
return nil
}
// rulesAreEqual checks whether two rules are equal
// only according to src, dst, priority and table.
func rulesAreEqual(a, b netlink.Rule) bool {
return ipPrefixesAreEqual(a.Src, b.Src) &&
ipPrefixesAreEqual(a.Dst, b.Dst) &&
ptrsEqual(a.Priority, b.Priority) &&
a.Priority == b.Priority &&
a.Table == b.Table
}
@@ -86,13 +70,3 @@ func ipPrefixesAreEqual(a, b netip.Prefix) bool {
return a.Bits() == b.Bits() &&
a.Addr().Compare(b.Addr()) == 0
}
func ptrsEqual(a, b *uint32) bool {
if a == nil && b == nil {
return true
}
if a == nil || b == nil {
return false
}
return *a == *b
}
+22 -30
View File
@@ -17,20 +17,14 @@ func makeNetipPrefix(n byte) netip.Prefix {
}
func makeIPRule(src, dst netip.Prefix,
table uint32, priority uint32,
table, priority int,
) netlink.Rule {
family := netlink.FamilyV4
if (src.IsValid() && src.Addr().Is6()) || (dst.IsValid() && dst.Addr().Is6()) {
family = netlink.FamilyV6
}
return netlink.Rule{
Priority: &priority,
Family: family,
Table: table,
Src: src,
Dst: dst,
Action: netlink.ActionToTable,
}
rule := netlink.NewRule()
rule.Src = src
rule.Dst = dst
rule.Table = table
rule.Priority = priority
return rule
}
func Test_Routing_addIPRule(t *testing.T) {
@@ -52,8 +46,8 @@ func Test_Routing_addIPRule(t *testing.T) {
testCases := map[string]struct {
src netip.Prefix
dst netip.Prefix
table uint32
priority uint32
table int
priority int
ruleList ruleListCall
ruleAdd ruleAddCall
err error
@@ -155,8 +149,8 @@ func Test_Routing_deleteIPRule(t *testing.T) {
testCases := map[string]struct {
src netip.Prefix
dst netip.Prefix
table uint32
priority uint32
table int
priority int
ruleList ruleListCall
ruleDel ruleDelCall
err error
@@ -244,8 +238,6 @@ func Test_Routing_deleteIPRule(t *testing.T) {
}
}
func ptrTo[T any](v T) *T { return &v }
func Test_rulesAreEqual(t *testing.T) {
t.Parallel()
@@ -261,13 +253,13 @@ func Test_rulesAreEqual(t *testing.T) {
a: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{9, 9, 9, 9}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: ptrTo(uint32(100)),
Priority: 100,
Table: 101,
},
b: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: ptrTo(uint32(100)),
Priority: 100,
Table: 101,
},
},
@@ -275,13 +267,13 @@ func Test_rulesAreEqual(t *testing.T) {
a: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{9, 9, 9, 9}), 32),
Priority: ptrTo(uint32(100)),
Priority: 100,
Table: 101,
},
b: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: ptrTo(uint32(100)),
Priority: 100,
Table: 101,
},
},
@@ -289,13 +281,13 @@ func Test_rulesAreEqual(t *testing.T) {
a: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: ptrTo(uint32(999)),
Priority: 999,
Table: 101,
},
b: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: ptrTo(uint32(100)),
Priority: 100,
Table: 101,
},
},
@@ -303,13 +295,13 @@ func Test_rulesAreEqual(t *testing.T) {
a: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: ptrTo(uint32(100)),
Table: 102,
Priority: 100,
Table: 999,
},
b: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: ptrTo(uint32(100)),
Priority: 100,
Table: 101,
},
},
@@ -317,13 +309,13 @@ func Test_rulesAreEqual(t *testing.T) {
a: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: ptrTo(uint32(100)),
Priority: 100,
Table: 101,
},
b: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: ptrTo(uint32(100)),
Priority: 100,
Table: 101,
},
equal: true,
-8
View File
@@ -1,8 +0,0 @@
package routing
import "golang.org/x/sys/unix"
const (
tableMain = unix.RT_TABLE_MAIN
tableLocal = unix.RT_TABLE_LOCAL
)
-8
View File
@@ -1,8 +0,0 @@
//go:build !linux
package routing
const (
tableMain = 0
tableLocal = 0
)
+6 -4
View File
@@ -6,6 +6,7 @@ import (
"net/netip"
"github.com/qdm12/gluetun/internal/netlink"
"golang.org/x/sys/unix"
)
var (
@@ -33,12 +34,13 @@ func (r *Routing) VPNLocalGatewayIP(vpnIntf string) (ip netip.Addr, err error) {
case route.Dst.IsValid() && route.Dst.Addr().IsUnspecified() && route.Gw.IsValid(): // OpenVPN
return route.Gw, nil
case route.Dst.IsSingleIP() &&
route.Dst.Addr().Compare(route.Src.Addr()) == 0 &&
route.Table == tableLocal: // Wireguard
if route.Src.Addr().Is6() {
route.Dst.Addr().Compare(route.Src) == 0 &&
route.Table == unix.RT_TABLE_LOCAL: // Wireguard
route.Src = route.Src.Unmap()
if route.Src.Is6() {
return netip.Addr{}, fmt.Errorf("%w: %s", ErrVPNLocalGatewayIPv6NotSupported, route.Src)
}
bytes := route.Src.Addr().As4()
bytes := route.Src.As4()
// force last byte to 1 to get the VPN gateway IP
// This is not necessarily bullet proof but it seems to work.
bytes[3] = 1
+8216 -6376
View File
File diff suppressed because it is too large Load Diff
+7 -8
View File
@@ -57,15 +57,15 @@ type Storage interface {
}
type NetLinker interface {
AddrReplace(linkIndex uint32, addr netip.Prefix) error
AddrReplace(link netlink.Link, addr netlink.Addr) error
Router
Ruler
Linker
IsWireguardSupported() (ok bool, err error)
IsWireguardSupported() bool
}
type Router interface {
RouteList(family uint8) (routes []netlink.Route, err error)
RouteList(family int) (routes []netlink.Route, err error)
RouteAdd(route netlink.Route) error
}
@@ -77,11 +77,10 @@ type Ruler interface {
type Linker interface {
LinkList() (links []netlink.Link, err error)
LinkByName(name string) (link netlink.Link, err error)
LinkAdd(link netlink.Link) (linkIndex uint32, err error)
LinkDel(linkIndex uint32) error
LinkSetUp(linkIndex uint32) error
LinkSetDown(linkIndex uint32) error
LinkSetMTU(linkIndex, mtu uint32) error
LinkAdd(link netlink.Link) (linkIndex int, err error)
LinkDel(link netlink.Link) (err error)
LinkSetUp(link netlink.Link) (linkIndex int, err error)
LinkSetDown(link netlink.Link) (err error)
}
type DNSLoop interface {
-1
View File
@@ -47,7 +47,6 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
continue
}
tunnelUpData := tunnelUpData{
vpnType: settings.Type,
serverIP: connection.IP,
serverName: connection.ServerName,
canPortForward: connection.PortForward,
-77
View File
@@ -2,24 +2,16 @@ package vpn
import (
"context"
"errors"
"fmt"
"net/netip"
"time"
"github.com/qdm12/dns/v2/pkg/check"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/pmtud"
"github.com/qdm12/gluetun/internal/version"
"github.com/qdm12/log"
)
type tunnelUpData struct {
// Healthcheck
serverIP netip.Addr
// vpnType is used for path MTU discovery to find the protocol overhead.
// It can be "wireguard" or "openvpn".
vpnType string
// Port forwarding
vpnIntf string
serverName string // used for PIA
@@ -39,13 +31,6 @@ func (l *Loop) onTunnelUp(ctx, loopCtx context.Context, data tunnelUpData) {
}
}
mtuLogger := l.logger.New(log.SetComponent("MTU discovery"))
err := updateToMaxMTU(ctx, data.vpnIntf, data.vpnType,
l.netLinker, l.routing, mtuLogger)
if err != nil {
mtuLogger.Error(err.Error())
}
icmpTargetIPs := l.healthSettings.ICMPTargetIPs
if len(icmpTargetIPs) == 1 && icmpTargetIPs[0].IsUnspecified() {
icmpTargetIPs = []netip.Addr{data.serverIP}
@@ -135,65 +120,3 @@ func (l *Loop) restartVPN(ctx context.Context, healthErr error) {
_, _ = l.ApplyStatus(ctx, constants.Stopped)
_, _ = l.ApplyStatus(ctx, constants.Running)
}
var errVPNTypeUnknown = errors.New("unknown VPN type")
func updateToMaxMTU(ctx context.Context, vpnInterface string,
vpnType string, netlinker NetLinker, routing Routing, logger *log.Logger,
) error {
logger.Info("finding maximum MTU, this can take up to 4 seconds")
vpnGatewayIP, err := routing.VPNLocalGatewayIP(vpnInterface)
if err != nil {
return fmt.Errorf("getting VPN gateway IP address: %w", err)
}
link, err := netlinker.LinkByName(vpnInterface)
if err != nil {
return fmt.Errorf("getting VPN interface by name: %w", err)
}
originalMTU := link.MTU
// Note: no point testing for an MTU of 1500, it will never work due to the VPN
// protocol overhead, so start lower than 1500 according to the protocol used.
const physicalLinkMTU uint32 = 1500
vpnLinkMTU := physicalLinkMTU
switch vpnType {
case "wireguard":
vpnLinkMTU -= 60 // Wireguard overhead
case "openvpn":
vpnLinkMTU -= 41 // OpenVPN overhead
default:
return fmt.Errorf("%w: %q", errVPNTypeUnknown, vpnType)
}
// Setting the VPN link MTU to 1500 might interrupt the connection until
// the new MTU is set again, but this is necessary to find the highest valid MTU.
logger.Debugf("VPN interface %s MTU temporarily set to %d", vpnInterface, vpnLinkMTU)
err = netlinker.LinkSetMTU(link.Index, vpnLinkMTU)
if err != nil {
return fmt.Errorf("setting VPN interface %s MTU to %d: %w", vpnInterface, vpnLinkMTU, err)
}
const pingTimeout = time.Second
vpnLinkMTU, err = pmtud.PathMTUDiscover(ctx, vpnGatewayIP, vpnLinkMTU, pingTimeout, logger)
switch {
case err == nil:
logger.Infof("setting VPN interface %s MTU to maximum valid MTU %d", vpnInterface, vpnLinkMTU)
case errors.Is(err, pmtud.ErrMTUNotFound) || errors.Is(err, pmtud.ErrICMPNotPermitted):
vpnLinkMTU = originalMTU
logger.Infof("reverting VPN interface %s MTU to %d (due to: %s)",
vpnInterface, originalMTU, err)
default:
return fmt.Errorf("path MTU discovering: %w", err)
}
err = netlinker.LinkSetMTU(link.Index, vpnLinkMTU)
if err != nil {
return fmt.Errorf("setting VPN interface %s MTU to %d: %w", vpnInterface, vpnLinkMTU, err)
}
return nil
}
+12 -6
View File
@@ -3,20 +3,26 @@ package wireguard
import (
"fmt"
"net/netip"
"github.com/qdm12/gluetun/internal/netlink"
)
func (w *Wireguard) addAddresses(linkIndex uint32,
func (w *Wireguard) addAddresses(link netlink.Link,
addresses []netip.Prefix,
) (err error) {
for _, address := range addresses {
if !*w.settings.IPv6 && address.Addr().Is6() {
for _, ipNet := range addresses {
if !*w.settings.IPv6 && ipNet.Addr().Is6() {
continue
}
err = w.netlink.AddrReplace(linkIndex, address)
address := netlink.Addr{
Network: ipNet,
}
err = w.netlink.AddrReplace(link, address)
if err != nil {
return fmt.Errorf("%w: when adding address %s to link with index %d",
err, address, linkIndex)
return fmt.Errorf("%w: when adding address %s to link %s",
err, address, link.Name)
}
}
+22 -21
View File
@@ -6,6 +6,7 @@ import (
"testing"
"github.com/golang/mock/gomock"
"github.com/qdm12/gluetun/internal/netlink"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -19,21 +20,21 @@ func Test_Wireguard_addAddresses(t *testing.T) {
errDummy := errors.New("dummy")
testCases := map[string]struct {
linkIndex uint32
link netlink.Link
addrs []netip.Prefix
wgBuilder func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard
wgBuilder func(ctrl *gomock.Controller, link netlink.Link) *Wireguard
err error
}{
"success": {
linkIndex: 1,
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
wgBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard {
link: netlink.Link{Type: "wireguard"},
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard {
netLinker := NewMockNetLinker(ctrl)
firstCall := netLinker.EXPECT().
AddrReplace(linkIndex, ipNetOne).
AddrReplace(link, netlink.Addr{Network: ipNetOne}).
Return(nil)
netLinker.EXPECT().
AddrReplace(linkIndex, ipNetTwo).
AddrReplace(link, netlink.Addr{Network: ipNetTwo}).
Return(nil).After(firstCall)
return &Wireguard{
netlink: netLinker,
@@ -44,12 +45,12 @@ func Test_Wireguard_addAddresses(t *testing.T) {
},
},
"first add error": {
linkIndex: 1,
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
wgBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard {
link: netlink.Link{Type: "wireguard", Name: "a_bridge"},
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard {
netLinker := NewMockNetLinker(ctrl)
netLinker.EXPECT().
AddrReplace(linkIndex, ipNetOne).
AddrReplace(link, netlink.Addr{Network: ipNetOne}).
Return(errDummy)
return &Wireguard{
netlink: netLinker,
@@ -58,18 +59,18 @@ func Test_Wireguard_addAddresses(t *testing.T) {
},
}
},
err: errors.New("dummy: when adding address 1.2.3.4/32 to link with index 1"),
err: errors.New("dummy: when adding address 1.2.3.4/32 to link a_bridge"),
},
"second add error": {
linkIndex: 1,
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
wgBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard {
link: netlink.Link{Type: "wireguard", Name: "a_bridge"},
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard {
netLinker := NewMockNetLinker(ctrl)
firstCall := netLinker.EXPECT().
AddrReplace(linkIndex, ipNetOne).
AddrReplace(link, netlink.Addr{Network: ipNetOne}).
Return(nil)
netLinker.EXPECT().
AddrReplace(linkIndex, ipNetTwo).
AddrReplace(link, netlink.Addr{Network: ipNetTwo}).
Return(errDummy).After(firstCall)
return &Wireguard{
netlink: netLinker,
@@ -78,11 +79,11 @@ func Test_Wireguard_addAddresses(t *testing.T) {
},
}
},
err: errors.New("dummy: when adding address ::1234/64 to link with index 1"),
err: errors.New("dummy: when adding address ::1234/64 to link a_bridge"),
},
"ignore IPv6": {
addrs: []netip.Prefix{ipNetTwo},
wgBuilder: func(_ *gomock.Controller, _ uint32) *Wireguard {
wgBuilder: func(_ *gomock.Controller, _ netlink.Link) *Wireguard {
return &Wireguard{
settings: Settings{
IPv6: ptrTo(false),
@@ -97,9 +98,9 @@ func Test_Wireguard_addAddresses(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
wg := testCase.wgBuilder(ctrl, testCase.linkIndex)
wg := testCase.wgBuilder(ctrl, testCase.link)
err := wg.addAddresses(testCase.linkIndex, testCase.addrs)
err := wg.addAddresses(testCase.link, testCase.addrs)
if testCase.err != nil {
require.Error(t, err)
-50
View File
@@ -1,53 +1,3 @@
package wireguard
import (
"math/rand/v2"
"net/netip"
"github.com/qdm12/gluetun/internal/netlink"
)
func ptrTo[T any](x T) *T { return &x }
var rng = rand.New(rand.NewChaCha8([32]byte{})) //nolint:gosec,gochecknoglobals
func makeLinkName() string {
const alphabet = "abcdefghijklmnopqrstuvwxyz"
b := make([]byte, 8)
for i := range b {
b[i] = alphabet[rng.IntN(len(alphabet))]
}
return "test" + string(b)
}
func rulesAreEqual(a, b netlink.Rule) bool {
return ipPrefixesAreEqual(a.Src, b.Src) &&
ipPrefixesAreEqual(a.Dst, b.Dst) &&
ptrsEqual(a.Priority, b.Priority) &&
a.Table == b.Table &&
a.Family == b.Family &&
a.Flags == b.Flags &&
a.Action == b.Action &&
ptrsEqual(a.Mark, b.Mark)
}
func ipPrefixesAreEqual(a, b netip.Prefix) bool {
if !a.IsValid() && !b.IsValid() {
return true
}
if !a.IsValid() || !b.IsValid() {
return false
}
return a.Bits() == b.Bits() &&
a.Addr().Compare(b.Addr()) == 0
}
func ptrsEqual(a, b *uint32) bool {
if a == nil && b == nil {
return true
}
if a == nil || b == nil {
return false
}
return *a == *b
}
+36 -34
View File
@@ -1,4 +1,4 @@
//go:build linux
//go:build netlink && linux
package wireguard
@@ -10,16 +10,13 @@ import (
"github.com/qdm12/log"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sys/unix"
)
type noopDebugLogger struct{}
func (n noopDebugLogger) Debug(_ string) {}
func (n noopDebugLogger) Debugf(_ string, _ ...any) {}
func (n noopDebugLogger) Info(_ string) {}
func (n noopDebugLogger) Error(_ string) {}
func (n noopDebugLogger) Errorf(_ string, _ ...any) {}
func (n noopDebugLogger) Patch(_ ...log.Option) {}
func (n noopDebugLogger) Debugf(format string, args ...any) {}
func (n noopDebugLogger) Patch(options ...log.Option) {}
func Test_netlink_Wireguard_addAddresses(t *testing.T) {
t.Parallel()
@@ -27,9 +24,15 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) {
netlinker := netlink.New(&noopDebugLogger{})
link := netlink.Link{
DeviceType: netlink.DeviceTypeNone,
VirtualType: "bridge",
Name: makeLinkName(),
Type: "bridge",
Name: "test_8081",
}
// Remove any previously created test interface from a crashed/panic
// test or test suite run.
err := netlinker.LinkDel(link)
if err != nil && err.Error() != "invalid argument" {
require.NoError(t, err)
}
linkIndex, err := netlinker.LinkAdd(link)
@@ -37,7 +40,7 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) {
link.Index = linkIndex
defer func() {
err = netlinker.LinkDel(linkIndex)
err = netlinker.LinkDel(link)
assert.NoError(t, err)
}()
@@ -54,15 +57,17 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) {
}
const addIterations = 2 // initial + replace
for range addIterations {
err = wg.addAddresses(link.Index, addresses)
for i := 0; i < addIterations; i++ {
err = wg.addAddresses(link, addresses)
require.NoError(t, err)
ipPrefixes, err := netlinker.AddrList(link.Index, netlink.FamilyAll)
netlinkAddresses, err := netlinker.AddrList(link, netlink.FamilyAll)
require.NoError(t, err)
require.Equal(t, len(addresses), len(ipPrefixes))
for i, ipPrefix := range ipPrefixes {
assert.Equal(t, addresses[i], ipPrefix)
require.Equal(t, len(addresses), len(netlinkAddresses))
for i, netlinkAddress := range netlinkAddresses {
require.NotNil(t, netlinkAddress.Network)
assert.Equal(t, addresses[i], netlinkAddress.Network)
}
}
}
@@ -73,41 +78,38 @@ func Test_netlink_Wireguard_addRule(t *testing.T) {
netlinker := netlink.New(&noopDebugLogger{})
wg := &Wireguard{
netlink: netlinker,
logger: &noopDebugLogger{},
}
// Unique combination for this test
const rulePriority uint32 = 10000
const firewallMark uint32 = 12345
const family = netlink.FamilyV4
rulePriority := 10000
const firewallMark = 999
const family = unix.AF_INET // ipv4
cleanup, err := wg.addRule(rulePriority,
firewallMark, family)
require.NoError(t, err)
t.Cleanup(func() {
defer func() {
err := cleanup()
assert.NoError(t, err)
})
}()
rules, err := netlinker.RuleList(netlink.FamilyV4)
require.NoError(t, err)
expectedRule := netlink.Rule{
Priority: ptrTo(rulePriority),
Family: netlink.FamilyV4,
Table: firewallMark,
Mark: ptrTo(firewallMark),
Flags: netlink.FlagInvert,
Action: netlink.ActionToTable,
}
var rule netlink.Rule
var ruleFound bool
for _, rule = range rules {
if rulesAreEqual(rule, expectedRule) {
if rule.Mark == firewallMark {
ruleFound = true
break
}
}
require.True(t, ruleFound)
expectedRule := netlink.Rule{
Invert: true,
Priority: rulePriority,
Mark: firewallMark,
Table: firewallMark,
}
assert.Equal(t, expectedRule, rule)
// Existing rule cannot be added
nilCleanup, err := wg.addRule(rulePriority,
@@ -116,5 +118,5 @@ func Test_netlink_Wireguard_addRule(t *testing.T) {
_ = nilCleanup() // in case it succeeds
}
require.Error(t, err)
assert.EqualError(t, err, "adding ip rule 10000: from all to all table 12345: netlink receive: file exists")
assert.EqualError(t, err, "adding ip rule 10000: from all to all table 999: file exists")
}
+8 -12
View File
@@ -1,23 +1,19 @@
package wireguard
import (
"net/netip"
"github.com/qdm12/gluetun/internal/netlink"
)
import "github.com/qdm12/gluetun/internal/netlink"
//go:generate mockgen -destination=netlinker_mock_test.go -package wireguard . NetLinker
type NetLinker interface {
AddrReplace(linkIndex uint32, addr netip.Prefix) error
AddrReplace(link netlink.Link, addr netlink.Addr) error
Router
Ruler
Linker
IsWireguardSupported() (ok bool, err error)
IsWireguardSupported() bool
}
type Router interface {
RouteList(family uint8) (routes []netlink.Route, err error)
RouteList(family int) (routes []netlink.Route, err error)
RouteAdd(route netlink.Route) error
}
@@ -27,10 +23,10 @@ type Ruler interface {
}
type Linker interface {
LinkAdd(link netlink.Link) (linkIndex uint32, err error)
LinkAdd(link netlink.Link) (linkIndex int, err error)
LinkList() (links []netlink.Link, err error)
LinkByName(name string) (link netlink.Link, err error)
LinkSetUp(linkIndex uint32) error
LinkSetDown(linkIndex uint32) error
LinkDel(linkIndex uint32) error
LinkSetUp(link netlink.Link) (linkIndex int, err error)
LinkSetDown(link netlink.Link) error
LinkDel(link netlink.Link) error
}
+12 -13
View File
@@ -5,7 +5,6 @@
package wireguard
import (
netip "net/netip"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
@@ -36,7 +35,7 @@ func (m *MockNetLinker) EXPECT() *MockNetLinkerMockRecorder {
}
// AddrReplace mocks base method.
func (m *MockNetLinker) AddrReplace(arg0 uint32, arg1 netip.Prefix) error {
func (m *MockNetLinker) AddrReplace(arg0 netlink.Link, arg1 netlink.Addr) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddrReplace", arg0, arg1)
ret0, _ := ret[0].(error)
@@ -50,12 +49,11 @@ func (mr *MockNetLinkerMockRecorder) AddrReplace(arg0, arg1 interface{}) *gomock
}
// IsWireguardSupported mocks base method.
func (m *MockNetLinker) IsWireguardSupported() (bool, error) {
func (m *MockNetLinker) IsWireguardSupported() bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IsWireguardSupported")
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
return ret0
}
// IsWireguardSupported indicates an expected call of IsWireguardSupported.
@@ -65,10 +63,10 @@ func (mr *MockNetLinkerMockRecorder) IsWireguardSupported() *gomock.Call {
}
// LinkAdd mocks base method.
func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) (uint32, error) {
func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) (int, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkAdd", arg0)
ret0, _ := ret[0].(uint32)
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
@@ -95,7 +93,7 @@ func (mr *MockNetLinkerMockRecorder) LinkByName(arg0 interface{}) *gomock.Call {
}
// LinkDel mocks base method.
func (m *MockNetLinker) LinkDel(arg0 uint32) error {
func (m *MockNetLinker) LinkDel(arg0 netlink.Link) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkDel", arg0)
ret0, _ := ret[0].(error)
@@ -124,7 +122,7 @@ func (mr *MockNetLinkerMockRecorder) LinkList() *gomock.Call {
}
// LinkSetDown mocks base method.
func (m *MockNetLinker) LinkSetDown(arg0 uint32) error {
func (m *MockNetLinker) LinkSetDown(arg0 netlink.Link) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkSetDown", arg0)
ret0, _ := ret[0].(error)
@@ -138,11 +136,12 @@ func (mr *MockNetLinkerMockRecorder) LinkSetDown(arg0 interface{}) *gomock.Call
}
// LinkSetUp mocks base method.
func (m *MockNetLinker) LinkSetUp(arg0 uint32) error {
func (m *MockNetLinker) LinkSetUp(arg0 netlink.Link) (int, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkSetUp", arg0)
ret0, _ := ret[0].(error)
return ret0
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// LinkSetUp indicates an expected call of LinkSetUp.
@@ -166,7 +165,7 @@ func (mr *MockNetLinkerMockRecorder) RouteAdd(arg0 interface{}) *gomock.Call {
}
// RouteList mocks base method.
func (m *MockNetLinker) RouteList(arg0 byte) ([]netlink.Route, error) {
func (m *MockNetLinker) RouteList(arg0 int) ([]netlink.Route, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RouteList", arg0)
ret0, _ := ret[0].([]netlink.Route)
+7 -10
View File
@@ -8,11 +8,11 @@ import (
"github.com/qdm12/gluetun/internal/netlink"
)
func (w *Wireguard) addRoutes(linkIndex uint32, destinations []netip.Prefix,
func (w *Wireguard) addRoutes(link netlink.Link, destinations []netip.Prefix,
firewallMark uint32,
) (err error) {
for _, dst := range destinations {
err = w.addRoute(linkIndex, dst, firewallMark)
err = w.addRoute(link, dst, firewallMark)
if err == nil {
continue
}
@@ -29,7 +29,7 @@ func (w *Wireguard) addRoutes(linkIndex uint32, destinations []netip.Prefix,
return nil
}
func (w *Wireguard) addRoute(linkIndex uint32, dst netip.Prefix,
func (w *Wireguard) addRoute(link netlink.Link, dst netip.Prefix,
firewallMark uint32,
) (err error) {
family := netlink.FamilyV4
@@ -37,20 +37,17 @@ func (w *Wireguard) addRoute(linkIndex uint32, dst netip.Prefix,
family = netlink.FamilyV6
}
route := netlink.Route{
LinkIndex: linkIndex,
LinkIndex: link.Index,
Dst: dst,
Family: family,
Table: firewallMark,
Type: netlink.RouteTypeUnicast,
Scope: netlink.ScopeUniverse,
Proto: netlink.ProtoStatic,
Table: int(firewallMark),
}
err = w.netlink.RouteAdd(route)
if err != nil {
return fmt.Errorf(
"adding route for link with index %d, destination %s and table %d: %w",
linkIndex, dst, firewallMark, err)
"adding route for link %s, destination %s and table %d: %w",
link.Name, dst, firewallMark, err)
}
return err
+10 -8
View File
@@ -23,36 +23,38 @@ func Test_Wireguard_addRoute(t *testing.T) {
errDummy := errors.New("dummy")
testCases := map[string]struct {
link netlink.Link
dst netip.Prefix
expectedRoute netlink.Route
routeAddErr error
err error
}{
"success": {
link: netlink.Link{
Index: linkIndex,
},
dst: ipPrefix,
expectedRoute: netlink.Route{
LinkIndex: linkIndex,
Dst: ipPrefix,
Family: netlink.FamilyV4,
Table: firewallMark,
Type: netlink.RouteTypeUnicast,
Scope: netlink.ScopeUniverse,
Proto: netlink.ProtoStatic,
},
},
"route add error": {
link: netlink.Link{
Name: "a_bridge",
Index: linkIndex,
},
dst: ipPrefix,
expectedRoute: netlink.Route{
LinkIndex: linkIndex,
Dst: ipPrefix,
Family: netlink.FamilyV4,
Table: firewallMark,
Type: netlink.RouteTypeUnicast,
Scope: netlink.ScopeUniverse,
Proto: netlink.ProtoStatic,
},
routeAddErr: errDummy,
err: errors.New("adding route for link with index 88, destination 1.2.3.4/32 and table 51820: dummy"), //nolint:lll
err: errors.New("adding route for link a_bridge, destination 1.2.3.4/32 and table 51820: dummy"), //nolint:lll
},
}
@@ -70,7 +72,7 @@ func Test_Wireguard_addRoute(t *testing.T) {
RouteAdd(testCase.expectedRoute).
Return(testCase.routeAddErr)
err := wg.addRoute(linkIndex, testCase.dst, firewallMark)
err := wg.addRoute(testCase.link, testCase.dst, firewallMark)
if testCase.err != nil {
require.Error(t, err)
+8 -10
View File
@@ -7,17 +7,15 @@ import (
"github.com/qdm12/gluetun/internal/netlink"
)
func (w *Wireguard) addRule(rulePriority, firewallMark uint32,
family uint8,
func (w *Wireguard) addRule(rulePriority int, firewallMark uint32,
family int,
) (cleanup func() error, err error) {
rule := netlink.Rule{
Priority: &rulePriority,
Family: family,
Table: firewallMark,
Mark: &firewallMark,
Flags: netlink.FlagInvert,
Action: netlink.ActionToTable,
}
rule := netlink.NewRule()
rule.Invert = true
rule.Priority = rulePriority
rule.Mark = firewallMark
rule.Table = int(firewallMark)
rule.Family = family
if err := w.netlink.RuleAdd(rule); err != nil {
if strings.HasSuffix(err.Error(), "file exists") {
w.logger.Info("if you are using Kubernetes, this may fix the error below: " +
+13 -15
View File
@@ -8,14 +8,15 @@ import (
"github.com/qdm12/gluetun/internal/netlink"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sys/unix"
)
func Test_Wireguard_addRule(t *testing.T) {
t.Parallel()
const rulePriority uint32 = 987
const firewallMark uint32 = 456
const family = netlink.FamilyV4
const rulePriority = 987
const firewallMark = 456
const family = unix.AF_INET
errDummy := errors.New("dummy")
@@ -28,34 +29,31 @@ func Test_Wireguard_addRule(t *testing.T) {
}{
"success": {
expectedRule: netlink.Rule{
Priority: ptrTo(rulePriority),
Mark: ptrTo(firewallMark),
Invert: true,
Priority: rulePriority,
Mark: firewallMark,
Table: firewallMark,
Family: family,
Flags: netlink.FlagInvert,
Action: netlink.ActionToTable,
},
},
"rule add error": {
expectedRule: netlink.Rule{
Priority: ptrTo(rulePriority),
Mark: ptrTo(firewallMark),
Invert: true,
Priority: rulePriority,
Mark: firewallMark,
Table: firewallMark,
Family: family,
Flags: netlink.FlagInvert,
Action: netlink.ActionToTable,
},
ruleAddErr: errDummy,
err: errors.New("adding ip rule 987: from all to all table 456: dummy"),
},
"rule delete error": {
expectedRule: netlink.Rule{
Priority: ptrTo(rulePriority),
Mark: ptrTo(firewallMark),
Invert: true,
Priority: rulePriority,
Mark: firewallMark,
Table: firewallMark,
Family: family,
Flags: netlink.FlagInvert,
Action: netlink.ActionToTable,
},
ruleDelErr: errDummy,
cleanupErr: errors.New("deleting rule ip rule 987: from all to all table 456: dummy"),
+40 -41
View File
@@ -7,14 +7,15 @@ import (
"net"
"github.com/qdm12/gluetun/internal/netlink"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/ipc"
"golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/wgctrl"
)
var (
ErrDetectKernel = errors.New("cannot detect Kernel support")
ErrCreateTun = errors.New("cannot create TUN device")
ErrAddLink = errors.New("cannot add Wireguard link")
ErrFindLink = errors.New("cannot find link")
@@ -33,11 +34,7 @@ var (
// See https://git.zx2c4.com/wireguard-go/tree/main.go
func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<- struct{}) {
kernelSupported, err := w.netlink.IsWireguardSupported()
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrDetectKernel, err)
return
}
kernelSupported := w.netlink.IsWireguardSupported()
setupFunction := setupUserSpace
switch w.settings.Implementation {
@@ -70,14 +67,14 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
defer closers.cleanup(w.logger)
linkIndex, waitAndCleanup, err := setupFunction(ctx,
link, waitAndCleanup, err := setupFunction(ctx,
w.settings.InterfaceName, w.netlink, w.settings.MTU, &closers, w.logger)
if err != nil {
waitError <- err
return
}
err = w.addAddresses(linkIndex, w.settings.Addresses)
err = w.addAddresses(link, w.settings.Addresses)
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrAddAddress, err)
return
@@ -90,16 +87,17 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
return
}
err = w.netlink.LinkSetUp(linkIndex)
linkIndex, err := w.netlink.LinkSetUp(link)
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrIfaceUp, err)
return
}
link.Index = linkIndex
closers.add("shutting down link", stepFour, func() error {
return w.netlink.LinkSetDown(linkIndex)
return w.netlink.LinkSetDown(link)
})
err = w.addRoutes(linkIndex, w.settings.AllowedIPs, w.settings.FirewallMark)
err = w.addRoutes(link, w.settings.AllowedIPs, w.settings.FirewallMark)
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrRouteAdd, err)
return
@@ -108,7 +106,7 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
if *w.settings.IPv6 {
// requires net.ipv6.conf.all.disable_ipv6=0
ruleCleanup6, err := w.addRule(w.settings.RulePriority,
w.settings.FirewallMark, netlink.FamilyV6)
w.settings.FirewallMark, unix.AF_INET6)
if err != nil {
waitError <- fmt.Errorf("adding IPv6 rule: %w", err)
return
@@ -117,7 +115,7 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
}
ruleCleanup, err := w.addRule(w.settings.RulePriority,
w.settings.FirewallMark, netlink.FamilyV4)
w.settings.FirewallMark, unix.AF_INET)
if err != nil {
waitError <- fmt.Errorf("adding IPv4 rule: %w", err)
return
@@ -135,38 +133,39 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
type waitAndCleanupFunc func() error
func setupKernelSpace(ctx context.Context,
interfaceName string, netLinker NetLinker, mtu uint32,
interfaceName string, netLinker NetLinker, mtu uint16,
closers *closers, logger Logger) (
linkIndex uint32, waitAndCleanup waitAndCleanupFunc, err error,
link netlink.Link, waitAndCleanup waitAndCleanupFunc, err error,
) {
link = netlink.Link{
Type: "wireguard",
Name: interfaceName,
MTU: mtu,
}
links, err := netLinker.LinkList()
if err != nil {
return 0, nil, fmt.Errorf("listing links: %w", err)
return link, nil, fmt.Errorf("listing links: %w", err)
}
// Cleanup any previous Wireguard interface with the same name
// See https://github.com/qdm12/gluetun/issues/1669
for _, link := range links {
if link.VirtualType == "wireguard" && link.Name == interfaceName {
err = netLinker.LinkDel(link.Index)
if link.Type == "wireguard" && link.Name == interfaceName {
err = netLinker.LinkDel(link)
if err != nil {
return 0, nil, fmt.Errorf("deleting previous Wireguard link %s: %w",
return link, nil, fmt.Errorf("deleting previous Wireguard link %s: %w",
interfaceName, err)
}
}
}
link := netlink.Link{
VirtualType: "wireguard",
Name: interfaceName,
MTU: mtu,
}
linkIndex, err = netLinker.LinkAdd(link)
linkIndex, err := netLinker.LinkAdd(link)
if err != nil {
return 0, nil, fmt.Errorf("%w: %s", ErrAddLink, err)
return link, nil, fmt.Errorf("%w: %s", ErrAddLink, err)
}
link.Index = linkIndex
closers.add("deleting link", stepFive, func() error {
return netLinker.LinkDel(linkIndex)
return netLinker.LinkDel(link)
})
waitAndCleanup = func() error {
@@ -175,35 +174,35 @@ func setupKernelSpace(ctx context.Context,
return ctx.Err()
}
return linkIndex, waitAndCleanup, nil
return link, waitAndCleanup, nil
}
func setupUserSpace(ctx context.Context,
interfaceName string, netLinker NetLinker, mtu uint32,
interfaceName string, netLinker NetLinker, mtu uint16,
closers *closers, logger Logger) (
linkIndex uint32, waitAndCleanup waitAndCleanupFunc, err error,
link netlink.Link, waitAndCleanup waitAndCleanupFunc, err error,
) {
tun, err := tun.CreateTUN(interfaceName, int(mtu))
if err != nil {
return 0, nil, fmt.Errorf("%w: %s", ErrCreateTun, err)
return link, nil, fmt.Errorf("%w: %s", ErrCreateTun, err)
}
closers.add("closing TUN device", stepSeven, tun.Close)
tunName, err := tun.Name()
if err != nil {
return 0, nil, fmt.Errorf("%w: cannot get TUN name: %s", ErrCreateTun, err)
return link, nil, fmt.Errorf("%w: cannot get TUN name: %s", ErrCreateTun, err)
} else if tunName != interfaceName {
return 0, nil, fmt.Errorf("%w: names don't match: expected %q and got %q",
return link, nil, fmt.Errorf("%w: names don't match: expected %q and got %q",
ErrCreateTun, interfaceName, tunName)
}
link, err := netLinker.LinkByName(interfaceName)
link, err = netLinker.LinkByName(interfaceName)
if err != nil {
return 0, nil, fmt.Errorf("%w: %s: %s", ErrFindLink, interfaceName, err)
return link, nil, fmt.Errorf("%w: %s: %s", ErrFindLink, interfaceName, err)
}
closers.add("deleting link", stepFive, func() error {
return netLinker.LinkDel(link.Index)
return netLinker.LinkDel(link)
})
bind := conn.NewDefaultBind()
@@ -218,16 +217,16 @@ func setupUserSpace(ctx context.Context,
return nil
})
uapiFile, err := uapiOpen(interfaceName)
uapiFile, err := ipc.UAPIOpen(interfaceName)
if err != nil {
return 0, nil, fmt.Errorf("%w: %s", ErrUAPISocketOpening, err)
return link, nil, fmt.Errorf("%w: %s", ErrUAPISocketOpening, err)
}
closers.add("closing UAPI file", stepThree, uapiFile.Close)
uapiListener, err := uapiListen(interfaceName, uapiFile)
uapiListener, err := ipc.UAPIListen(interfaceName, uapiFile)
if err != nil {
return 0, nil, fmt.Errorf("%w: %s", ErrUAPIListen, err)
return link, nil, fmt.Errorf("%w: %s", ErrUAPIListen, err)
}
closers.add("closing UAPI listener", stepTwo, uapiListener.Close)
@@ -252,7 +251,7 @@ func setupUserSpace(ctx context.Context,
return err
}
return link.Index, waitAndCleanup, nil
return link, waitAndCleanup, nil
}
func acceptAndHandle(uapi net.Listener, device *device.Device,
+2 -2
View File
@@ -38,10 +38,10 @@ type Settings struct {
FirewallMark uint32
// Maximum Transmission Unit (MTU) setting for the network interface.
// It defaults to device.DefaultMTU from wireguard-go which is 1420
MTU uint32
MTU uint16
// RulePriority is the priority for the rule created with the
// FirewallMark.
RulePriority uint32
RulePriority int
// IPv6 can bet set to true if IPv6 should be handled.
// It defaults to false if left unset.
IPv6 *bool
-16
View File
@@ -1,16 +0,0 @@
package wireguard
import (
"net"
"os"
"golang.zx2c4.com/wireguard/ipc"
)
func uapiOpen(name string) (*os.File, error) {
return ipc.UAPIOpen(name)
}
func uapiListen(interfaceName string, uapiFile *os.File) (net.Listener, error) {
return ipc.UAPIListen(interfaceName, uapiFile)
}
@@ -1,16 +0,0 @@
//go:build !linux
package wireguard
import (
"net"
"os"
)
func uapiOpen(name string) (*os.File, error) {
panic("not implemented")
}
func uapiListen(interfaceName string, uapiFile *os.File) (net.Listener, error) {
panic("not implemented")
}