diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a590601c..8f4be3d2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -59,7 +59,7 @@ jobs: - name: Run tests in test container run: | touch coverage.txt - docker run --rm --device /dev/net/tun \ + docker run --rm --cap-add=NET_ADMIN --device /dev/net/tun \ -v "$(pwd)/coverage.txt:/tmp/gobuild/coverage.txt" \ test-container diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index f08cf8f1..67da5306 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -6,6 +6,7 @@ import ( "fmt" "io/fs" "net/http" + "net/netip" "os" "os/exec" "os/signal" @@ -549,26 +550,26 @@ type netLinker interface { Router Ruler Linker - IsWireguardSupported() bool + IsWireguardSupported() (ok bool, err error) IsIPv6Supported() (ok bool, err error) PatchLoggerLevel(level log.Level) } type Addresser interface { - AddrList(link netlink.Link, family int) ( - addresses []netlink.Addr, err error) - AddrReplace(link netlink.Link, addr netlink.Addr) error + AddrList(linkIndex uint32, family uint8) ( + addresses []netip.Prefix, err error) + AddrReplace(linkIndex uint32, addr netip.Prefix) error } type Router interface { - RouteList(family int) (routes []netlink.Route, err error) + RouteList(family uint8) (routes []netlink.Route, err error) RouteAdd(route netlink.Route) error RouteDel(route netlink.Route) error RouteReplace(route netlink.Route) error } type Ruler interface { - RuleList(family int) (rules []netlink.Rule, err error) + RuleList(family uint8) (rules []netlink.Rule, err error) RuleAdd(rule netlink.Rule) error RuleDel(rule netlink.Rule) error } @@ -576,12 +577,12 @@ type Ruler interface { type Linker interface { LinkList() (links []netlink.Link, err error) LinkByName(name string) (link netlink.Link, err error) - LinkByIndex(index int) (link netlink.Link, err error) - LinkAdd(link netlink.Link) (linkIndex int, err error) - LinkDel(link netlink.Link) (err error) - LinkSetUp(link netlink.Link) (linkIndex int, err error) - LinkSetDown(link netlink.Link) (err error) - LinkSetMTU(link netlink.Link, mtu uint32) error + LinkByIndex(index uint32) (link netlink.Link, err error) + LinkAdd(link netlink.Link) (linkIndex uint32, err error) + LinkDel(linkIndex uint32) (err error) + LinkSetUp(linkIndex uint32) (err error) + LinkSetDown(linkIndex uint32) (err error) + LinkSetMTU(linkIndex, mtu uint32) error } type clier interface { diff --git a/go.mod b/go.mod index 1773bf34..8190c32c 100644 --- a/go.mod +++ b/go.mod @@ -7,8 +7,10 @@ require ( github.com/breml/rootcerts v0.3.3 github.com/fatih/color v1.18.0 github.com/golang/mock v1.6.0 + github.com/jsimonetti/rtnetlink v1.4.2 github.com/klauspost/compress v1.18.1 github.com/klauspost/pgzip v1.2.6 + github.com/mdlayher/genetlink v1.3.2 github.com/pelletier/go-toml/v2 v2.2.4 github.com/qdm12/dns/v2 v2.0.0-rc10 github.com/qdm12/gosettings v0.4.4 @@ -19,12 +21,11 @@ require ( github.com/qdm12/ss-server v0.6.0 github.com/stretchr/testify v1.11.1 github.com/ulikunitz/xz v0.5.15 - github.com/vishvananda/netlink v1.3.1 github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c - golang.org/x/net v0.47.0 - golang.org/x/sys v0.38.0 - golang.org/x/text v0.31.0 + golang.org/x/net v0.49.0 + golang.org/x/sys v0.40.0 + golang.org/x/text v0.33.0 golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 gopkg.in/ini.v1 v1.67.0 @@ -38,13 +39,12 @@ require ( github.com/cloudflare/circl v1.6.1 // indirect github.com/cronokirby/saferith v0.33.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/google/go-cmp v0.6.0 // indirect + github.com/google/go-cmp v0.7.0 // indirect github.com/josharian/native v1.1.0 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect - github.com/mdlayher/genetlink v1.3.2 // indirect github.com/mdlayher/netlink v1.7.2 // indirect - github.com/mdlayher/socket v0.4.1 // indirect + github.com/mdlayher/socket v0.5.1 // indirect github.com/miekg/dns v1.1.62 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/pkg/errors v0.9.1 // indirect @@ -55,12 +55,11 @@ require ( github.com/prometheus/procfs v0.15.1 // indirect github.com/qdm12/goservices v0.1.1-0.20251104135713-6bee97bd4978 // indirect github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect - github.com/vishvananda/netns v0.0.5 // indirect - golang.org/x/crypto v0.45.0 // indirect - golang.org/x/mod v0.29.0 // indirect - golang.org/x/sync v0.18.0 // indirect + golang.org/x/crypto v0.47.0 // indirect + golang.org/x/mod v0.31.0 // indirect + golang.org/x/sync v0.19.0 // indirect golang.org/x/time v0.3.0 // indirect - golang.org/x/tools v0.38.0 // indirect + golang.org/x/tools v0.40.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect google.golang.org/protobuf v1.35.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index b45bfb56..680da7ae 100644 --- a/go.sum +++ b/go.sum @@ -13,6 +13,8 @@ github.com/breml/rootcerts v0.3.3/go.mod h1:S/PKh+4d1HUn4HQovEB8hPJZO6pUZYrIhmXB github.com/bwesterb/go-ristretto v1.2.0/go.mod h1:fUIoIZaG73pV5biE2Blr2xEzDoMj7NFEuV9ekS419A0= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cilium/ebpf v0.12.3 h1:8ht6F9MquybnY97at+VDZb3eQQr8ev79RueWeVaEcG4= +github.com/cilium/ebpf v0.12.3/go.mod h1:TctK1ivibvI3znr66ljgi4hqOT8EYQjz1KWBfb1UVgM= github.com/cloudflare/circl v1.1.0/go.mod h1:prBCrKB9DV4poKZY1l9zBXg2QJY7mvgRvtMxxK7fi4I= github.com/cloudflare/circl v1.6.1 h1:zqIqSPIndyBh1bjLVVDHMPpVKqp8Su/V+6MeDzzQBQ0= github.com/cloudflare/circl v1.6.1/go.mod h1:uddAzsPgqdMAYatqJ0lsjX1oECcQLIlRpzZh3pJrofs= @@ -26,10 +28,12 @@ github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4= github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA= -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA= github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w= +github.com/jsimonetti/rtnetlink v1.4.2 h1:Df9w9TZ3npHTyDn0Ev9e1uzmN2odmXd0QX+J5GTEn90= +github.com/jsimonetti/rtnetlink v1.4.2/go.mod h1:92s6LJdE+1iOrw+F2/RO7LYI2Qd8pPpFNNUYW06gcoM= github.com/klauspost/compress v1.18.1 h1:bcSGx7UbpBqMChDtsF28Lw6v/G94LPrrbMbdC3JH2co= github.com/klauspost/compress v1.18.1/go.mod h1:ZQFFVG+MdnR0P+l6wpXgIL4NTtwiKIdBnrBd8Nrxr+0= github.com/klauspost/pgzip v1.2.6 h1:8RXeL5crjEUFnR2/Sn6GJNWtSQ3Dk8pq4CL3jvdDyjU= @@ -47,8 +51,8 @@ github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy5 github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o= github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g= github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw= -github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U= -github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA= +github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos= +github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ= github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ= github.com/miekg/dns v1.1.62/go.mod h1:mvDlcItzm+br7MToIKqkglaGhlFMHJ9DTNNWONWXbNQ= github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws= @@ -93,10 +97,6 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/ulikunitz/xz v0.5.15 h1:9DNdB5s+SgV3bQ2ApL10xRc35ck0DuIX/isZvIk+ubY= github.com/ulikunitz/xz v0.5.15/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14= -github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0= -github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4= -github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY= -github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a h1:fZHgsYlfvtyqToslyjUt3VOPF4J7aK/3MPcK7xp3PDk= github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a/go.mod h1:ul22v+Nro/R083muKhosV54bj5niojjWZvU8xrevuH4= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= @@ -106,15 +106,15 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20200302210943-78000ba7a073/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= -golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= -golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= +golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= +golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c h1:7dEasQXItcW1xKJ2+gg5VOiBnqWrJc+rq0DPKyvvdbY= golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c/go.mod h1:NQtJDoLvd6faHhE7m4T/1IY708gDefGGjR/iUW8yQQ8= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA= -golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w= +golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI= +golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= @@ -122,14 +122,14 @@ golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96b golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= -golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= -golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= +golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= +golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= -golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -140,12 +140,10 @@ golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= -golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= +golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= @@ -155,8 +153,8 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= -golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM= -golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= +golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= +golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -164,8 +162,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ= -golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs= +golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA= +golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/internal/configuration/settings/wireguard.go b/internal/configuration/settings/wireguard.go index 3edf133d..bd76c096 100644 --- a/internal/configuration/settings/wireguard.go +++ b/internal/configuration/settings/wireguard.go @@ -46,7 +46,7 @@ type Wireguard struct { // investigation in the issue: // https://github.com/qdm12/gluetun/issues/2533. // Note this should now be replaced with the PMTUD feature. - MTU uint16 `json:"mtu"` + MTU uint32 `json:"mtu"` // Implementation is the Wireguard implementation to use. // It can be "auto", "userspace" or "kernelspace". // It defaults to "auto" and cannot be the empty string @@ -273,7 +273,7 @@ func (w *Wireguard) read(r *reader.Reader) (err error) { return err } - mtuPtr, err := r.Uint16Ptr("WIREGUARD_MTU") + mtuPtr, err := r.Uint32Ptr("WIREGUARD_MTU") if err != nil { return err } else if mtuPtr != nil { diff --git a/internal/netlink/address.go b/internal/netlink/address.go index 0313946c..932d2f6f 100644 --- a/internal/netlink/address.go +++ b/internal/netlink/address.go @@ -1,31 +1,75 @@ package netlink import ( - "github.com/vishvananda/netlink" + "fmt" + "net" + "net/netip" + + "github.com/jsimonetti/rtnetlink/rtnl" ) -func (n *NetLink) AddrList(link Link, family int) ( - addresses []Addr, err error, +func (n *NetLink) AddrList(linkIndex uint32, family uint8) ( + ipPrefixes []netip.Prefix, err error, ) { - netlinkLink := linkToNetlinkLink(&link) - netlinkAddresses, err := netlink.AddrList(netlinkLink, family) + conn, err := rtnl.Dial(nil) if err != nil { - return nil, err + return nil, fmt.Errorf("dialing netlink: %w", err) + } + defer conn.Close() + + ifc := &net.Interface{ + Index: int(linkIndex), + } + ipNets, err := conn.Addrs(ifc, int(family)) + if err != nil { + return nil, fmt.Errorf("failed to list addresses: %w", err) } - addresses = make([]Addr, len(netlinkAddresses)) - for i := range netlinkAddresses { - addresses[i].Network = netIPNetToNetipPrefix(netlinkAddresses[i].IPNet) + ipPrefixes = make([]netip.Prefix, len(ipNets)) + for i := range ipNets { + ipPrefixes[i] = netIPNetToNetipPrefix(ipNets[i]) } - return addresses, nil + return ipPrefixes, nil } -func (n *NetLink) AddrReplace(link Link, addr Addr) error { - netlinkLink := linkToNetlinkLink(&link) - netlinkAddress := netlink.Addr{ - IPNet: netipPrefixToIPNet(addr.Network), +func (n *NetLink) AddrReplace(linkIndex uint32, prefix netip.Prefix) error { + conn, err := rtnl.Dial(nil) + if err != nil { + return fmt.Errorf("dialing netlink: %w", err) + } + defer conn.Close() + + ipNet := netipPrefixToIPNet(prefix) + + // Remove any address identical to the one we want to add + family := FamilyV4 + if prefix.Addr().Is6() { + family = FamilyV6 + } + ifc := &net.Interface{ + Index: int(linkIndex), + } + addresses, err := conn.Addrs(ifc, int(family)) + if err != nil { + return fmt.Errorf("listing addresses: %w", err) + } + for _, address := range addresses { + if address.IP.Equal(ipNet.IP) && + net.IP(address.Mask).String() == net.IP(ipNet.Mask).String() { + err = conn.AddrDel(ifc, address) + if err != nil { + return fmt.Errorf("deleting address from interface: %w", err) + } + break + } } - return netlink.AddrReplace(netlinkLink, &netlinkAddress) + // Add the new address to the interface + err = conn.AddrAdd(ifc, ipNet) + if err != nil { + return fmt.Errorf("adding address to interface: %w", err) + } + + return nil } diff --git a/internal/netlink/conversion.go b/internal/netlink/conversion.go index 55910fc4..fed47e60 100644 --- a/internal/netlink/conversion.go +++ b/internal/netlink/conversion.go @@ -36,6 +36,30 @@ func netIPNetToNetipPrefix(ipNet *net.IPNet) (prefix netip.Prefix) { return netip.PrefixFrom(ip, bits) } +func ipAndLengthToPrefix(ip *net.IP, length uint8) netip.Prefix { + if ip == nil || len(*ip) == 0 { + return netip.Prefix{} + } + var dstIP netip.Addr + if ipv4 := ip.To4(); ipv4 != nil { // IPv6 + dstIP = netip.AddrFrom4([4]byte(*ip)) + } else { + dstIP = netip.AddrFrom16([16]byte(*ip)) + } + return netip.PrefixFrom(dstIP, int(length)) +} + +func prefixToIPAndLength(prefix netip.Prefix) (ip *net.IP, length uint8) { + if !prefix.IsValid() { + return nil, 0 + } + prefixIP := prefix.Addr().Unmap() + ip = new(net.IP) + *ip = netipAddrToNetIP(prefixIP) + length = uint8(prefix.Bits()) //nolint:gosec + return ip, length +} + func netipAddrToNetIP(address netip.Addr) (ip net.IP) { switch { case !address.IsValid(): diff --git a/internal/netlink/family.go b/internal/netlink/family.go index 9340ea7e..edd68afd 100644 --- a/internal/netlink/family.go +++ b/internal/netlink/family.go @@ -4,7 +4,7 @@ import ( "fmt" ) -func FamilyToString(family int) string { +func FamilyToString(family uint8) string { switch family { case FamilyAll: return "all" diff --git a/internal/netlink/family_linux.go b/internal/netlink/family_linux.go index 7410898d..6367a8da 100644 --- a/internal/netlink/family_linux.go +++ b/internal/netlink/family_linux.go @@ -3,7 +3,7 @@ package netlink import "golang.org/x/sys/unix" const ( - FamilyAll = unix.AF_UNSPEC - FamilyV4 = unix.AF_INET - FamilyV6 = unix.AF_INET6 + FamilyAll uint8 = unix.AF_UNSPEC + FamilyV4 uint8 = unix.AF_INET + FamilyV6 uint8 = unix.AF_INET6 ) diff --git a/internal/netlink/helpers_test.go b/internal/netlink/helpers_test.go index ecae2fa4..370253cb 100644 --- a/internal/netlink/helpers_test.go +++ b/internal/netlink/helpers_test.go @@ -1,16 +1,30 @@ package netlink import ( + "math/rand/v2" "net/netip" "github.com/qdm12/log" ) +func ptrTo[T any](v T) *T { return &v } + func makeNetipPrefix(n byte) netip.Prefix { const bits = 24 return netip.PrefixFrom(netip.AddrFrom4([4]byte{n, n, n, 0}), bits) } +var rng = rand.New(rand.NewChaCha8([32]byte{})) //nolint:gosec,gochecknoglobals + +func makeLinkName() string { + const alphabet = "abcdefghijklmnopqrstuvwxyz" + name := make([]byte, 8) + for i := range name { + name[i] = alphabet[rng.IntN(len(alphabet))] + } + return "test" + string(name) +} + type noopLogger struct{} func (l *noopLogger) Debug(_ string) {} diff --git a/internal/netlink/ipv6.go b/internal/netlink/ipv6.go index d8eff77e..bded7dde 100644 --- a/internal/netlink/ipv6.go +++ b/internal/netlink/ipv6.go @@ -19,7 +19,7 @@ func (n *NetLink) IsIPv6Supported() (supported bool, err error) { return false, fmt.Errorf("finding link corresponding to route: %w", err) } - sourceIsIPv6 := route.Src.IsValid() && route.Src.Is6() + sourceIsIPv6 := route.Src.Addr().IsValid() && route.Src.Addr().Is6() destinationIsIPv6 := route.Dst.IsValid() && route.Dst.Addr().Is6() switch { case !sourceIsIPv6 && !destinationIsIPv6, diff --git a/internal/netlink/link.go b/internal/netlink/link.go index 549fe8f1..5e103747 100644 --- a/internal/netlink/link.go +++ b/internal/netlink/link.go @@ -1,107 +1,191 @@ package netlink -import "github.com/vishvananda/netlink" +import ( + "errors" + "fmt" + + "github.com/jsimonetti/rtnetlink" +) + +type DeviceType uint16 + +type Link struct { + Index uint32 + Name string + DeviceType DeviceType + VirtualType string + MTU uint32 +} func (n *NetLink) LinkList() (links []Link, err error) { - netlinkLinks, err := netlink.LinkList() + conn, err := rtnetlink.Dial(nil) if err != nil { - return nil, err + return nil, fmt.Errorf("dialing netlink: %w", err) + } + defer conn.Close() + + linkMessages, err := conn.Link.List() + if err != nil { + return nil, fmt.Errorf("listing interfaces: %w", err) } - links = make([]Link, len(netlinkLinks)) - for i := range netlinkLinks { - links[i] = netlinkLinkToLink(netlinkLinks[i]) + links = make([]Link, len(linkMessages)) + for i, message := range linkMessages { + virtualType := "" + if message.Attributes.Info != nil { + virtualType = message.Attributes.Info.Kind + } + links[i] = Link{ + Index: message.Index, + Name: message.Attributes.Name, + DeviceType: DeviceType(message.Type), + VirtualType: virtualType, + MTU: message.Attributes.MTU, + } } return links, nil } +var ErrLinkNotFound = errors.New("link not found") + func (n *NetLink) LinkByName(name string) (link Link, err error) { - netlinkLink, err := netlink.LinkByName(name) + links, err := n.LinkList() if err != nil { - return Link{}, err + return Link{}, fmt.Errorf("listing links: %w", err) } - return netlinkLinkToLink(netlinkLink), nil + for _, link := range links { + if link.Name == name { + return link, nil + } + } + + return Link{}, fmt.Errorf("%w: for name %s", ErrLinkNotFound, name) } -func (n *NetLink) LinkByIndex(index int) (link Link, err error) { - netlinkLink, err := netlink.LinkByIndex(index) +func (n *NetLink) LinkByIndex(index uint32) (link Link, err error) { + links, err := n.LinkList() if err != nil { - return Link{}, err + return Link{}, fmt.Errorf("listing links: %w", err) } - return netlinkLinkToLink(netlinkLink), nil + for _, link = range links { + if link.Index == index { + return link, nil + } + } + + return Link{}, fmt.Errorf("%w: for index %d", ErrLinkNotFound, index) } -func (n *NetLink) LinkAdd(link Link) (linkIndex int, err error) { - netlinkLink := linkToNetlinkLink(&link) - err = netlink.LinkAdd(netlinkLink) +func (n *NetLink) LinkAdd(link Link) (linkIndex uint32, err error) { + conn, err := rtnetlink.Dial(nil) if err != nil { - return 0, err + return 0, fmt.Errorf("dialing netlink: %w", err) } - return netlinkLink.Attrs().Index, nil -} + defer conn.Close() -func (n *NetLink) LinkDel(link Link) (err error) { - return netlink.LinkDel(linkToNetlinkLink(&link)) -} - -func (n *NetLink) LinkSetUp(link Link) (linkIndex int, err error) { - netlinkLink := linkToNetlinkLink(&link) - err = netlink.LinkSetUp(netlinkLink) - if err != nil { - return 0, err - } - return netlinkLink.Attrs().Index, nil -} - -func (n *NetLink) LinkSetDown(link Link) (err error) { - return netlink.LinkSetDown(linkToNetlinkLink(&link)) -} - -func (n *NetLink) LinkSetMTU(link Link, mtu uint32) error { - return netlink.LinkSetMTU(linkToNetlinkLink(&link), int(mtu)) -} - -type netlinkLinkImpl struct { - attrs *netlink.LinkAttrs - linkType string -} - -func (n *netlinkLinkImpl) Attrs() *netlink.LinkAttrs { - return n.attrs -} - -func (n *netlinkLinkImpl) Type() string { - return n.linkType -} - -func netlinkLinkToLink(netlinkLink netlink.Link) Link { - attributes := netlinkLink.Attrs() - return Link{ - Type: netlinkLink.Type(), - Name: attributes.Name, - Index: attributes.Index, - EncapType: attributes.EncapType, - MTU: uint16(attributes.MTU), //nolint:gosec - } -} - -// Warning: we must return `netlink.Link` and not `netlinkLinkImpl` -// so that the vishvananda/netlink package can compare the returned -// value against an untyped nil. -func linkToNetlinkLink(link *Link) netlink.Link { - if link == nil { - return nil - } - return &netlinkLinkImpl{ - linkType: link.Type, - attrs: &netlink.LinkAttrs{ - Name: link.Name, - Index: link.Index, - EncapType: link.EncapType, - MTU: int(link.MTU), + tx := &rtnetlink.LinkMessage{ + Type: uint16(link.DeviceType), + Attributes: &rtnetlink.LinkAttributes{ + MTU: link.MTU, + Name: link.Name, }, } + if link.VirtualType != "" { + tx.Attributes.Info = &rtnetlink.LinkInfo{ + Kind: link.VirtualType, + } + } + + err = conn.Link.New(tx) + if err != nil { + return 0, fmt.Errorf("creating new link: %w", err) + } + + linkMessages, err := conn.Link.List() + if err != nil { + return 0, fmt.Errorf("listing links: %w", err) + } + for _, linkMessage := range linkMessages { + if linkMessage.Attributes.Name == link.Name { + return linkMessage.Index, nil + } + } + + return 0, fmt.Errorf("%w: matching name %s", ErrLinkNotFound, link.Name) +} + +func (n *NetLink) LinkDel(linkIndex uint32) (err error) { + conn, err := rtnetlink.Dial(nil) + if err != nil { + return fmt.Errorf("dialing netlink: %w", err) + } + defer conn.Close() + + return conn.Link.Delete(linkIndex) +} + +func (n *NetLink) LinkSetUp(linkIndex uint32) (err error) { + conn, err := rtnetlink.Dial(nil) + if err != nil { + return fmt.Errorf("dialing netlink: %w", err) + } + defer conn.Close() + + rx, err := conn.Link.Get(linkIndex) + if err != nil { + return fmt.Errorf("getting link: %w", err) + } + tx := &rtnetlink.LinkMessage{ + Type: rx.Type, + Index: linkIndex, + Flags: iffUp, + Change: iffUp, + } + return conn.Link.Set(tx) +} + +func (n *NetLink) LinkSetDown(linkIndex uint32) (err error) { + conn, err := rtnetlink.Dial(nil) + if err != nil { + return fmt.Errorf("dialing netlink: %w", err) + } + defer conn.Close() + + linkInfo, err := conn.Link.Get(linkIndex) + if err != nil { + return fmt.Errorf("getting link: %w", err) + } + message := &rtnetlink.LinkMessage{ + Type: linkInfo.Type, + Index: linkIndex, + Flags: 0, + Change: iffUp, + } + return conn.Link.Set(message) +} + +func (n *NetLink) LinkSetMTU(linkIndex, mtu uint32) error { + conn, err := rtnetlink.Dial(nil) + if err != nil { + return fmt.Errorf("dialing netlink: %w", err) + } + defer conn.Close() + + message := &rtnetlink.LinkMessage{ + Index: linkIndex, + Attributes: &rtnetlink.LinkAttributes{ + MTU: mtu, + }, + } + + err = conn.Link.Set(message) + if err != nil { + return fmt.Errorf("setting MTU to %d for link at index %d: %w", + mtu, linkIndex, err) + } + + return nil } diff --git a/internal/netlink/link_linux.go b/internal/netlink/link_linux.go new file mode 100644 index 00000000..1d0ef84a --- /dev/null +++ b/internal/netlink/link_linux.go @@ -0,0 +1,11 @@ +package netlink + +import "golang.org/x/sys/unix" + +const ( + DeviceTypeEthernet DeviceType = unix.ARPHRD_ETHER + DeviceTypeLoopback DeviceType = unix.ARPHRD_LOOPBACK + DeviceTypeNone DeviceType = unix.ARPHRD_NONE + + iffUp = unix.IFF_UP +) diff --git a/internal/netlink/link_test.go b/internal/netlink/link_test.go new file mode 100644 index 00000000..aa8951c3 --- /dev/null +++ b/internal/netlink/link_test.go @@ -0,0 +1,85 @@ +//go:build linux + +package netlink + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_NetLink_LinkList(t *testing.T) { + t.Parallel() + + netlink := &NetLink{} + + initialLinks, err := netlink.LinkList() + require.NoError(t, err) + require.NotEmpty(t, initialLinks) + + loopbackFound := false + for _, link := range initialLinks { + if link.Name != "lo" { + continue + } + loopbackFound = true + assert.Equal(t, DeviceTypeLoopback, link.DeviceType) + break + } + assert.True(t, loopbackFound, "loopback interface not found") + + testLink := Link{ + Name: makeLinkName(), + // note if [Link.VirtualType] is set, [Link.DeviceType] + // is ignored and gets set to [DeviceTypeNone] in LinkAdd. + DeviceType: DeviceTypeNone, + VirtualType: "wireguard", + MTU: 1420, + } + index, err := netlink.LinkAdd(testLink) + require.NoError(t, err) + t.Cleanup(func() { + _ = netlink.LinkDel(index) + }) + + links, err := netlink.LinkList() + require.NoError(t, err) + + testLink.Index = index + for _, link := range links { + if link.Name != testLink.Name { + continue + } + assert.Equal(t, testLink, link) + return + } + t.Errorf("created link %q not found", testLink.Name) +} + +func Test_NetLink_LinkSetMTU(t *testing.T) { + t.Parallel() + + netlink := &NetLink{} + + testLink := Link{ + Name: makeLinkName(), + DeviceType: DeviceTypeNone, + VirtualType: "wireguard", + MTU: 1420, + } + index, err := netlink.LinkAdd(testLink) + require.NoError(t, err) + t.Cleanup(func() { + _ = netlink.LinkDel(index) + }) + testLink.Index = index + + err = netlink.LinkSetMTU(index, 1500) + require.NoError(t, err) + + link, err := netlink.LinkByIndex(index) + require.NoError(t, err) + testLink.MTU = 1500 + assert.Equal(t, testLink, link) +} diff --git a/internal/netlink/netlink_unspecified.go b/internal/netlink/netlink_unspecified.go index cf03b9c9..62b27595 100644 --- a/internal/netlink/netlink_unspecified.go +++ b/internal/netlink/netlink_unspecified.go @@ -5,16 +5,41 @@ package netlink const ( // FamilyAll is a placeholder only and should not // be used. - FamilyAll = iota + FamilyAll uint8 = iota // FamilyV4 is a placeholder only and should not // be used. FamilyV4 // FamilyV6 is a placeholder only and should not // be used. FamilyV6 + + // DeviceTypeEthernet is a placeholder only and should not be used. + DeviceTypeEthernet DeviceType = 0 + // DeviceTypeLoopback is a placeholder only and should not be used. + DeviceTypeLoopback DeviceType = 0 + // DeviceTypeNone is a placeholder only and should not be used. + DeviceTypeNone DeviceType = 0 + + // iffUp is a placeholder only and should not be used. + iffUp = 0 + + // RouteTypeUnicast is a placeholder only and should not be used. + RouteTypeUnicast = 0 + // ScopeUniverse is a placeholder only and should not be used. + ScopeUniverse = 0 + // ProtoStatic is a placeholder only and should not be used. + ProtoStatic = 0 + + // FlagInvert is a placeholder only and should not be used. + FlagInvert = 0 + // ActionToTable is a placeholder only and should not be used. + ActionToTable = 0 + + // rtTableCompat is a placeholder only and should not be used. + rtTableCompat = 0 ) -func (n *NetLink) RuleList(family int) (rules []Rule, err error) { +func (n *NetLink) RuleList(family uint8) (rules []Rule, err error) { panic("not implemented") } @@ -26,6 +51,6 @@ func (n *NetLink) RuleDel(rule Rule) error { panic("not implemented") } -func (n *NetLink) IsWireguardSupported() bool { +func (n *NetLink) IsWireguardSupported() (bool, error) { panic("not implemented") } diff --git a/internal/netlink/route.go b/internal/netlink/route.go index 59c2045e..5fd2774e 100644 --- a/internal/netlink/route.go +++ b/internal/netlink/route.go @@ -1,67 +1,125 @@ package netlink import ( - "github.com/vishvananda/netlink" + "fmt" + "net/netip" + + "github.com/jsimonetti/rtnetlink" ) -func (n *NetLink) RouteList(family int) (routes []Route, err error) { - // We set the filter to netlink.RT_FILTER_TABLE so that - // routes from all tables are listed, as long as the filter - // table is set to 0. - const filterMask = netlink.RT_FILTER_TABLE - // The filter is not left to `nil` otherwise non-main tables - // are ignored. - filter := &netlink.Route{} +type Route struct { + LinkIndex uint32 + Dst netip.Prefix + Src netip.Prefix + Gw netip.Addr + Priority uint32 + Family uint8 + Table uint32 + Type uint8 + Scope uint8 + Proto uint8 +} - netlinkRoutes, err := netlink.RouteListFiltered(family, filter, filterMask) +func (r *Route) fromMessage(message rtnetlink.RouteMessage) { + table := uint32(message.Table) + if table == 0 || table == rtTableCompat { + table = message.Attributes.Table + } + r.LinkIndex = message.Attributes.OutIface + r.Dst = ipAndLengthToPrefix(&message.Attributes.Dst, message.DstLength) + r.Src = ipAndLengthToPrefix(&message.Attributes.Src, message.SrcLength) + r.Gw = netIPToNetipAddress(message.Attributes.Gateway) + r.Priority = message.Attributes.Priority + r.Family = message.Family + r.Table = table + r.Type = message.Type + r.Scope = message.Scope + r.Proto = message.Protocol +} + +func (r Route) message() *rtnetlink.RouteMessage { + dst, dstLength := prefixToIPAndLength(r.Dst) + src, srcLength := prefixToIPAndLength(r.Src) + var table uint8 + var extendedTable uint32 + if r.Table <= uint32(^uint8(0)) { + table = uint8(r.Table) + } else { + table = rtTableCompat + extendedTable = r.Table + } + message := &rtnetlink.RouteMessage{ + Family: r.Family, + DstLength: dstLength, + SrcLength: srcLength, + Table: table, + Type: r.Type, + Scope: r.Scope, + Protocol: r.Proto, + Attributes: rtnetlink.RouteAttributes{ + OutIface: r.LinkIndex, + Dst: *dst, // there should always be a dst for routes + Gateway: netipAddrToNetIP(r.Gw), + Priority: r.Priority, + Table: extendedTable, + }, + } + if src != nil { // src is optional + message.Attributes.Src = *src + } + return message +} + +func (n *NetLink) RouteList(family uint8) (routes []Route, err error) { + conn, err := rtnetlink.Dial(nil) if err != nil { - return nil, err + return nil, fmt.Errorf("dialing netlink: %w", err) + } + defer conn.Close() + + routeMessages, err := conn.Route.List() + if err != nil { + return nil, fmt.Errorf("listing interfaces: %w", err) } - routes = make([]Route, len(netlinkRoutes)) - for i := range netlinkRoutes { - routes[i] = netlinkRouteToRoute(netlinkRoutes[i]) + routes = make([]Route, 0, len(routeMessages)) + for _, routeMessage := range routeMessages { + if family != FamilyAll && routeMessage.Family != family { + continue + } + var route Route + route.fromMessage(routeMessage) + routes = append(routes, route) } return routes, nil } func (n *NetLink) RouteAdd(route Route) error { - netlinkRoute := routeToNetlinkRoute(route) - return netlink.RouteAdd(&netlinkRoute) + conn, err := rtnetlink.Dial(nil) + if err != nil { + return fmt.Errorf("dialing netlink: %w", err) + } + defer conn.Close() + + return conn.Route.Add(route.message()) } func (n *NetLink) RouteDel(route Route) error { - netlinkRoute := routeToNetlinkRoute(route) - return netlink.RouteDel(&netlinkRoute) + conn, err := rtnetlink.Dial(nil) + if err != nil { + return fmt.Errorf("dialing netlink: %w", err) + } + defer conn.Close() + + return conn.Route.Delete(route.message()) } func (n *NetLink) RouteReplace(route Route) error { - netlinkRoute := routeToNetlinkRoute(route) - return netlink.RouteReplace(&netlinkRoute) -} - -func netlinkRouteToRoute(netlinkRoute netlink.Route) (route Route) { - return Route{ - LinkIndex: netlinkRoute.LinkIndex, - Dst: netIPNetToNetipPrefix(netlinkRoute.Dst), - Src: netIPToNetipAddress(netlinkRoute.Src), - Gw: netIPToNetipAddress(netlinkRoute.Gw), - Priority: netlinkRoute.Priority, - Family: netlinkRoute.Family, - Table: netlinkRoute.Table, - Type: netlinkRoute.Type, + conn, err := rtnetlink.Dial(nil) + if err != nil { + return fmt.Errorf("dialing netlink: %w", err) } -} + defer conn.Close() -func routeToNetlinkRoute(route Route) (netlinkRoute netlink.Route) { - return netlink.Route{ - LinkIndex: route.LinkIndex, - Dst: netipPrefixToIPNet(route.Dst), - Src: netipAddrToNetIP(route.Src), - Gw: netipAddrToNetIP(route.Gw), - Priority: route.Priority, - Family: route.Family, - Table: route.Table, - Type: route.Type, - } + return conn.Route.Replace(route.message()) } diff --git a/internal/netlink/route_linux.go b/internal/netlink/route_linux.go new file mode 100644 index 00000000..dc777785 --- /dev/null +++ b/internal/netlink/route_linux.go @@ -0,0 +1,11 @@ +package netlink + +import "golang.org/x/sys/unix" + +const ( + RouteTypeUnicast = unix.RTN_UNICAST + ScopeUniverse = unix.RT_SCOPE_UNIVERSE + ProtoStatic = unix.RTPROT_STATIC + + rtTableCompat = unix.RT_TABLE_COMPAT +) diff --git a/internal/netlink/rule.go b/internal/netlink/rule.go index 7591c8ee..ae090e06 100644 --- a/internal/netlink/rule.go +++ b/internal/netlink/rule.go @@ -2,54 +2,95 @@ package netlink import ( "fmt" + "net/netip" - "github.com/vishvananda/netlink" + "github.com/jsimonetti/rtnetlink" ) -func NewRule() Rule { - // defaults found from netlink.NewRule() for fields we use, - // the rest of the defaults is set when converting from a `Rule` - // to a `netlink.Rule` - return Rule{ - Priority: -1, - Mark: 0, +type Rule struct { + Priority *uint32 + Family uint8 + Table uint32 + Mark *uint32 + Src netip.Prefix + Dst netip.Prefix + Flags uint32 + Action uint8 +} + +func (r *Rule) fromMessage(message rtnetlink.RuleMessage) { + table := uint32(message.Table) + if table == 0 || table == rtTableCompat { + table = *message.Attributes.Table } + r.Priority = message.Attributes.Priority + r.Family = message.Family + r.Table = table + r.Mark = message.Attributes.FwMark + r.Src = ipAndLengthToPrefix(message.Attributes.Src, message.SrcLength) + r.Dst = ipAndLengthToPrefix(message.Attributes.Dst, message.DstLength) + r.Flags = message.Flags + r.Action = message.Action } -func ruleToNetlinkRule(rule Rule) (netlinkRule netlink.Rule) { - netlinkRule = *netlink.NewRule() - netlinkRule.Priority = rule.Priority - netlinkRule.Family = rule.Family - netlinkRule.Table = rule.Table - netlinkRule.Mark = rule.Mark - netlinkRule.Src = netipPrefixToIPNet(rule.Src) - netlinkRule.Dst = netipPrefixToIPNet(rule.Dst) - netlinkRule.Invert = rule.Invert - return netlinkRule -} +func (r Rule) message() *rtnetlink.RuleMessage { + src, srcLength := prefixToIPAndLength(r.Src) + dst, dstLength := prefixToIPAndLength(r.Dst) -func netlinkRuleToRule(netlinkRule netlink.Rule) (rule Rule) { - return Rule{ - Priority: netlinkRule.Priority, - Family: netlinkRule.Family, - Table: netlinkRule.Table, - Mark: netlinkRule.Mark, - Src: netIPNetToNetipPrefix(netlinkRule.Src), - Dst: netIPNetToNetipPrefix(netlinkRule.Dst), - Invert: netlinkRule.Invert, + message := &rtnetlink.RuleMessage{ + Family: r.Family, + SrcLength: srcLength, + DstLength: dstLength, + Flags: r.Flags, + Action: r.Action, + Attributes: &rtnetlink.RuleAttributes{ + Priority: r.Priority, + FwMark: r.Mark, + Src: src, + Dst: dst, + }, } + + if r.Table <= uint32(^uint8(0)) { + message.Table = uint8(r.Table) + } else { + message.Table = rtTableCompat + message.Attributes.Table = &r.Table + } + + return message } -func ruleDbgMsg(add bool, rule Rule) (debugMessage string) { +func (r Rule) String() string { + from := "all" + if r.Src.IsValid() && !r.Src.Addr().IsUnspecified() { + from = r.Src.String() + } + + to := "all" + if r.Dst.IsValid() && !r.Dst.Addr().IsUnspecified() { + to = r.Dst.String() + } + + priority := "" + if r.Priority != nil { + priority = fmt.Sprintf(" %d", *r.Priority) + } + + return fmt.Sprintf("ip rule%s: from %s to %s table %d", + priority, from, to, r.Table) +} + +func (r Rule) debugMessage(add bool) (debugMessage string) { debugMessage = "ip" - switch rule.Family { + switch r.Family { case FamilyV4: debugMessage += " -f inet" case FamilyV6: debugMessage += " -f inet6" default: - debugMessage += " -f " + fmt.Sprint(rule.Family) + debugMessage += " -f " + fmt.Sprint(r.Family) } debugMessage += " rule" @@ -60,20 +101,20 @@ func ruleDbgMsg(add bool, rule Rule) (debugMessage string) { debugMessage += " del" } - if rule.Src.IsValid() { - debugMessage += " from " + rule.Src.String() + if r.Src.IsValid() { + debugMessage += " from " + r.Src.String() } - if rule.Dst.IsValid() { - debugMessage += " to " + rule.Dst.String() + if r.Dst.IsValid() { + debugMessage += " to " + r.Dst.String() } - if rule.Table != 0 { - debugMessage += " lookup " + fmt.Sprint(rule.Table) + if r.Table != 0 { + debugMessage += " lookup " + fmt.Sprint(r.Table) } - if rule.Priority != -1 { - debugMessage += " pref " + fmt.Sprint(rule.Priority) + if r.Priority != nil { + debugMessage += " pref " + fmt.Sprint(*r.Priority) } return debugMessage diff --git a/internal/netlink/rule_linux.go b/internal/netlink/rule_linux.go index a421ae80..08e2d79f 100644 --- a/internal/netlink/rule_linux.go +++ b/internal/netlink/rule_linux.go @@ -1,8 +1,18 @@ package netlink -import "github.com/vishvananda/netlink" +import ( + "fmt" -func (n *NetLink) RuleList(family int) (rules []Rule, err error) { + "github.com/jsimonetti/rtnetlink" + "golang.org/x/sys/unix" +) + +const ( + FlagInvert = unix.FIB_RULE_INVERT + ActionToTable = unix.FR_ACT_TO_TBL +) + +func (n *NetLink) RuleList(family uint8) (rules []Rule, err error) { switch family { case FamilyAll: n.debugLogger.Debug("ip -4 rule list") @@ -12,26 +22,48 @@ func (n *NetLink) RuleList(family int) (rules []Rule, err error) { case FamilyV6: n.debugLogger.Debug("ip -6 rule list") } - netlinkRules, err := netlink.RuleList(family) + + conn, err := rtnetlink.Dial(nil) + if err != nil { + return nil, fmt.Errorf("dialing netlink: %w", err) + } + defer conn.Close() + + ruleMessages, err := conn.Rule.List() if err != nil { return nil, err } - rules = make([]Rule, len(netlinkRules)) - for i := range netlinkRules { - rules[i] = netlinkRuleToRule(netlinkRules[i]) + rules = make([]Rule, 0, len(ruleMessages)) + for _, message := range ruleMessages { + if family != FamilyAll && family != message.Family { + continue + } + var rule Rule + rule.fromMessage(message) + rules = append(rules, rule) } return rules, nil } func (n *NetLink) RuleAdd(rule Rule) error { - n.debugLogger.Debug(ruleDbgMsg(true, rule)) - netlinkRule := ruleToNetlinkRule(rule) - return netlink.RuleAdd(&netlinkRule) + n.debugLogger.Debug(rule.debugMessage(true)) + + conn, err := rtnetlink.Dial(nil) + if err != nil { + return fmt.Errorf("dialing netlink: %w", err) + } + defer conn.Close() + return conn.Rule.Add(rule.message()) } func (n *NetLink) RuleDel(rule Rule) error { - n.debugLogger.Debug(ruleDbgMsg(false, rule)) - netlinkRule := ruleToNetlinkRule(rule) - return netlink.RuleDel(&netlinkRule) + n.debugLogger.Debug(rule.debugMessage(false)) + + conn, err := rtnetlink.Dial(nil) + if err != nil { + return fmt.Errorf("dialing netlink: %w", err) + } + defer conn.Close() + return conn.Rule.Delete(rule.message()) } diff --git a/internal/netlink/rule_test.go b/internal/netlink/rule_test.go index f056076d..e7c8be8b 100644 --- a/internal/netlink/rule_test.go +++ b/internal/netlink/rule_test.go @@ -6,7 +6,7 @@ import ( "github.com/stretchr/testify/assert" ) -func Test_ruleDbgMsg(t *testing.T) { +func Test_Rule_debugMessage(t *testing.T) { t.Parallel() testCases := map[string]struct { @@ -15,7 +15,7 @@ func Test_ruleDbgMsg(t *testing.T) { dbgMsg string }{ "default values": { - dbgMsg: "ip -f 0 rule del pref 0", + dbgMsg: "ip -f 0 rule del", }, "add rule": { add: true, @@ -24,7 +24,7 @@ func Test_ruleDbgMsg(t *testing.T) { Src: makeNetipPrefix(1), Dst: makeNetipPrefix(2), Table: 100, - Priority: 101, + Priority: ptrTo(uint32(101)), }, dbgMsg: "ip -f inet rule add from 1.1.1.0/24 to 2.2.2.0/24 lookup 100 pref 101", }, @@ -34,7 +34,7 @@ func Test_ruleDbgMsg(t *testing.T) { Src: makeNetipPrefix(1), Dst: makeNetipPrefix(2), Table: 100, - Priority: 101, + Priority: ptrTo(uint32(101)), }, dbgMsg: "ip -f inet rule del from 1.1.1.0/24 to 2.2.2.0/24 lookup 100 pref 101", }, @@ -44,7 +44,7 @@ func Test_ruleDbgMsg(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() - dbgMsg := ruleDbgMsg(testCase.add, testCase.rule) + dbgMsg := testCase.rule.debugMessage(testCase.add) assert.Equal(t, testCase.dbgMsg, dbgMsg) }) diff --git a/internal/netlink/types.go b/internal/netlink/types.go deleted file mode 100644 index 3633beda..00000000 --- a/internal/netlink/types.go +++ /dev/null @@ -1,58 +0,0 @@ -package netlink - -import ( - "fmt" - "net/netip" -) - -type Addr struct { - Network netip.Prefix -} - -func (a Addr) String() string { - return a.Network.String() -} - -type Link struct { - Type string - Name string - Index int - EncapType string - MTU uint16 -} - -type Route struct { - LinkIndex int - Dst netip.Prefix - Src netip.Addr - Gw netip.Addr - Priority int - Family int - Table int - Type int -} - -type Rule struct { - Priority int - Family int - Table int - Mark uint32 - Src netip.Prefix - Dst netip.Prefix - Invert bool -} - -func (r Rule) String() string { - from := "all" - if r.Src.IsValid() && !r.Src.Addr().IsUnspecified() { - from = r.Src.String() - } - - to := "all" - if r.Dst.IsValid() && !r.Dst.Addr().IsUnspecified() { - to = r.Dst.String() - } - - return fmt.Sprintf("ip rule %d: from %s to %s table %d", - r.Priority, from, to, r.Table) -} diff --git a/internal/netlink/wireguard_linux.go b/internal/netlink/wireguard_linux.go index 5d8d1366..94ba9ebb 100644 --- a/internal/netlink/wireguard_linux.go +++ b/internal/netlink/wireguard_linux.go @@ -1,35 +1,58 @@ package netlink import ( + "errors" + "fmt" + "os" + + "github.com/mdlayher/genetlink" "github.com/qdm12/gluetun/internal/mod" - "github.com/vishvananda/netlink" ) -func (n *NetLink) IsWireguardSupported() bool { +func (n *NetLink) IsWireguardSupported() (ok bool, err error) { // Check for Wireguard family without loading the wireguard module. // Some kernels have the wireguard module built-in, and don't have a // modules directory, such as WSL2 kernels. - ok := hasWireguardFamily() - if ok { - return true + ok, err = hasWireguardFamily() + if err != nil { + return false, fmt.Errorf("checking wireguard family: %w", err) + } else if ok { + return true, nil } // Try loading the wireguard module, since some systems do not load // it after a boot. If this fails, wireguard is assumed to not be supported. n.debugLogger.Debugf("wireguard family not found, trying to load wireguard kernel module") - err := mod.Probe("wireguard") + err = mod.Probe("wireguard") if err != nil { n.debugLogger.Debugf("failed loading wireguard kernel module: %s", err) - return false + return false, nil } n.debugLogger.Debugf("wireguard kernel module loaded successfully") // Re-check if the Wireguard family is now available, after loading // the wireguard kernel module. - return hasWireguardFamily() + ok, err = hasWireguardFamily() + if err != nil { + return false, fmt.Errorf("checking wireguard family: %w", err) + } + return ok, nil } -func hasWireguardFamily() bool { - _, err := netlink.GenlFamilyGet("wireguard") - return err == nil +func hasWireguardFamily() (ok bool, err error) { + conn, err := genetlink.Dial(nil) + if err != nil { + return false, fmt.Errorf("dialing netlink: %w", err) + } + defer conn.Close() + + _, err = conn.GetFamily("wireguard") + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return false, nil + } + return false, fmt.Errorf("getting wireguard family: %w", err) + } + + return true, nil } diff --git a/internal/netlink/wireguard_test.go b/internal/netlink/wireguard_test.go index 229f69fc..33b9e2bf 100644 --- a/internal/netlink/wireguard_test.go +++ b/internal/netlink/wireguard_test.go @@ -4,6 +4,8 @@ package netlink import ( "testing" + + "github.com/stretchr/testify/require" ) func Test_NetLink_IsWireguardSupported(t *testing.T) { @@ -12,7 +14,8 @@ func Test_NetLink_IsWireguardSupported(t *testing.T) { netLink := &NetLink{ debugLogger: &noopLogger{}, } - ok := netLink.IsWireguardSupported() + ok, err := netLink.IsWireguardSupported() + require.NoError(t, err) if ok { // cannot assert since this depends on kernel t.Log("wireguard is supported") } else { diff --git a/internal/portforward/interfaces.go b/internal/portforward/interfaces.go index fb442d5e..93277b48 100644 --- a/internal/portforward/interfaces.go +++ b/internal/portforward/interfaces.go @@ -14,7 +14,7 @@ type Service interface { type Routing interface { VPNLocalGatewayIP(vpnInterface string) (gateway netip.Addr, err error) - AssignedIP(interfaceName string, family int) (ip netip.Addr, err error) + AssignedIP(interfaceName string, family uint8) (ip netip.Addr, err error) } type PortAllower interface { diff --git a/internal/portforward/service/interfaces.go b/internal/portforward/service/interfaces.go index 01876be8..33288a30 100644 --- a/internal/portforward/service/interfaces.go +++ b/internal/portforward/service/interfaces.go @@ -17,7 +17,7 @@ type PortAllower interface { type Routing interface { VPNLocalGatewayIP(vpnInterface string) (gateway netip.Addr, err error) - AssignedIP(interfaceName string, family int) (ip netip.Addr, err error) + AssignedIP(interfaceName string, family uint8) (ip netip.Addr, err error) } type Logger interface { diff --git a/internal/routing/default.go b/internal/routing/default.go index b5696c52..3027bd6f 100644 --- a/internal/routing/default.go +++ b/internal/routing/default.go @@ -14,7 +14,7 @@ type DefaultRoute struct { NetInterface string Gateway netip.Addr AssignedIP netip.Addr - Family int + Family uint8 } func (d DefaultRoute) String() string { diff --git a/internal/routing/inbound.go b/internal/routing/inbound.go index 239f5666..63357d67 100644 --- a/internal/routing/inbound.go +++ b/internal/routing/inbound.go @@ -8,8 +8,8 @@ import ( ) const ( - inboundTable = 200 - inboundPriority = 100 + inboundTable uint32 = 200 + inboundPriority uint32 = 100 ) func (r *Routing) routeInboundFromDefault(defaultRoutes []DefaultRoute) (err error) { @@ -60,7 +60,7 @@ func (r *Routing) unrouteInboundFromDefault(defaultRoutes []DefaultRoute) (err e return nil } -func (r *Routing) addRuleInboundFromDefault(table int, defaultRoutes []DefaultRoute) (err error) { +func (r *Routing) addRuleInboundFromDefault(table uint32, defaultRoutes []DefaultRoute) (err error) { for _, defaultRoute := range defaultRoutes { assignedIP := defaultRoute.AssignedIP bits := 32 @@ -78,7 +78,7 @@ func (r *Routing) addRuleInboundFromDefault(table int, defaultRoutes []DefaultRo return nil } -func (r *Routing) delRuleInboundFromDefault(table int, defaultRoutes []DefaultRoute) (err error) { +func (r *Routing) delRuleInboundFromDefault(table uint32, defaultRoutes []DefaultRoute) (err error) { for _, defaultRoute := range defaultRoutes { assignedIP := defaultRoute.AssignedIP bits := 32 diff --git a/internal/routing/ip.go b/internal/routing/ip.go index 72312dba..a43ced75 100644 --- a/internal/routing/ip.go +++ b/internal/routing/ip.go @@ -16,12 +16,12 @@ func ipIsPrivate(ip netip.Addr) bool { var errInterfaceIPNotFound = errors.New("IP address not found for interface") -func ipMatchesFamily(ip netip.Addr, family int) bool { +func ipMatchesFamily(ip netip.Addr, family uint8) bool { return (family == netlink.FamilyV4 && ip.Is4()) || (family == netlink.FamilyV6 && ip.Is6()) } -func (r *Routing) AssignedIP(interfaceName string, family int) (ip netip.Addr, err error) { +func (r *Routing) AssignedIP(interfaceName string, family uint8) (ip netip.Addr, err error) { iface, err := net.InterfaceByName(interfaceName) if err != nil { return ip, fmt.Errorf("network interface %s not found: %w", interfaceName, err) diff --git a/internal/routing/local.go b/internal/routing/local.go index 674f5857..564ec120 100644 --- a/internal/routing/local.go +++ b/internal/routing/local.go @@ -26,10 +26,10 @@ func (r *Routing) LocalNetworks() (localNetworks []LocalNetwork, err error) { return localNetworks, fmt.Errorf("listing links: %w", err) } - localLinks := make(map[int]struct{}) + localLinks := make(map[uint32]struct{}) for _, link := range links { - if link.EncapType != "ether" { + if link.DeviceType != netlink.DeviceTypeEthernet { continue } @@ -95,7 +95,7 @@ func (r *Routing) AddLocalRules(subnets []LocalNetwork) (err error) { // Local has higher priority then outbound(99) and inbound(100) as the // local routes might be necessary to reach the outbound/inbound routes. - const localPriority = 98 + const localPriority uint32 = 98 // Main table was setup correctly by Docker, just need to add rules to use it src := netip.Prefix{} diff --git a/internal/routing/mocks_test.go b/internal/routing/mocks_test.go index f9ddd85d..1eccf147 100644 --- a/internal/routing/mocks_test.go +++ b/internal/routing/mocks_test.go @@ -5,6 +5,7 @@ package routing import ( + netip "net/netip" reflect "reflect" gomock "github.com/golang/mock/gomock" @@ -35,10 +36,10 @@ func (m *MockNetLinker) EXPECT() *MockNetLinkerMockRecorder { } // AddrList mocks base method. -func (m *MockNetLinker) AddrList(arg0 netlink.Link, arg1 int) ([]netlink.Addr, error) { +func (m *MockNetLinker) AddrList(arg0 uint32, arg1 byte) ([]netip.Prefix, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "AddrList", arg0, arg1) - ret0, _ := ret[0].([]netlink.Addr) + ret0, _ := ret[0].([]netip.Prefix) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -50,7 +51,7 @@ func (mr *MockNetLinkerMockRecorder) AddrList(arg0, arg1 interface{}) *gomock.Ca } // AddrReplace mocks base method. -func (m *MockNetLinker) AddrReplace(arg0 netlink.Link, arg1 netlink.Addr) error { +func (m *MockNetLinker) AddrReplace(arg0 uint32, arg1 netip.Prefix) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "AddrReplace", arg0, arg1) ret0, _ := ret[0].(error) @@ -64,10 +65,10 @@ func (mr *MockNetLinkerMockRecorder) AddrReplace(arg0, arg1 interface{}) *gomock } // LinkAdd mocks base method. -func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) (int, error) { +func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) (uint32, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "LinkAdd", arg0) - ret0, _ := ret[0].(int) + ret0, _ := ret[0].(uint32) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -79,7 +80,7 @@ func (mr *MockNetLinkerMockRecorder) LinkAdd(arg0 interface{}) *gomock.Call { } // LinkByIndex mocks base method. -func (m *MockNetLinker) LinkByIndex(arg0 int) (netlink.Link, error) { +func (m *MockNetLinker) LinkByIndex(arg0 uint32) (netlink.Link, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "LinkByIndex", arg0) ret0, _ := ret[0].(netlink.Link) @@ -109,7 +110,7 @@ func (mr *MockNetLinkerMockRecorder) LinkByName(arg0 interface{}) *gomock.Call { } // LinkDel mocks base method. -func (m *MockNetLinker) LinkDel(arg0 netlink.Link) error { +func (m *MockNetLinker) LinkDel(arg0 uint32) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "LinkDel", arg0) ret0, _ := ret[0].(error) @@ -138,7 +139,7 @@ func (mr *MockNetLinkerMockRecorder) LinkList() *gomock.Call { } // LinkSetDown mocks base method. -func (m *MockNetLinker) LinkSetDown(arg0 netlink.Link) error { +func (m *MockNetLinker) LinkSetDown(arg0 uint32) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "LinkSetDown", arg0) ret0, _ := ret[0].(error) @@ -152,12 +153,11 @@ func (mr *MockNetLinkerMockRecorder) LinkSetDown(arg0 interface{}) *gomock.Call } // LinkSetUp mocks base method. -func (m *MockNetLinker) LinkSetUp(arg0 netlink.Link) (int, error) { +func (m *MockNetLinker) LinkSetUp(arg0 uint32) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "LinkSetUp", arg0) - ret0, _ := ret[0].(int) - ret1, _ := ret[1].(error) - return ret0, ret1 + ret0, _ := ret[0].(error) + return ret0 } // LinkSetUp indicates an expected call of LinkSetUp. @@ -195,7 +195,7 @@ func (mr *MockNetLinkerMockRecorder) RouteDel(arg0 interface{}) *gomock.Call { } // RouteList mocks base method. -func (m *MockNetLinker) RouteList(arg0 int) ([]netlink.Route, error) { +func (m *MockNetLinker) RouteList(arg0 byte) ([]netlink.Route, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "RouteList", arg0) ret0, _ := ret[0].([]netlink.Route) @@ -252,7 +252,7 @@ func (mr *MockNetLinkerMockRecorder) RuleDel(arg0 interface{}) *gomock.Call { } // RuleList mocks base method. -func (m *MockNetLinker) RuleList(arg0 int) ([]netlink.Rule, error) { +func (m *MockNetLinker) RuleList(arg0 byte) ([]netlink.Rule, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "RuleList", arg0) ret0, _ := ret[0].([]netlink.Rule) diff --git a/internal/routing/outbound.go b/internal/routing/outbound.go index 828906f7..74a107dd 100644 --- a/internal/routing/outbound.go +++ b/internal/routing/outbound.go @@ -9,8 +9,8 @@ import ( ) const ( - outboundTable = 199 - outboundPriority = 99 + outboundTable uint32 = 199 + outboundPriority uint32 = 99 ) func (r *Routing) SetOutboundRoutes(outboundSubnets []netip.Prefix) error { diff --git a/internal/routing/routes.go b/internal/routing/routes.go index e212d0b2..37531d78 100644 --- a/internal/routing/routes.go +++ b/internal/routing/routes.go @@ -9,25 +9,33 @@ import ( ) func (r *Routing) addRouteVia(destination netip.Prefix, gateway netip.Addr, - iface string, table int, + iface string, table uint32, ) error { destinationStr := destination.String() r.logger.Info("adding route for " + destinationStr) r.logger.Debug("ip route replace " + destinationStr + " via " + gateway.String() + " dev " + iface + - " table " + strconv.Itoa(table)) + " table " + strconv.Itoa(int(table))) link, err := r.netLinker.LinkByName(iface) if err != nil { return fmt.Errorf("finding link for interface %s: %w", iface, err) } + family := netlink.FamilyV4 + if destination.Addr().Is6() { + family = netlink.FamilyV6 + } route := netlink.Route{ Dst: destination, Gw: gateway, LinkIndex: link.Index, + Family: family, Table: table, + Type: netlink.RouteTypeUnicast, + Scope: netlink.ScopeUniverse, + Proto: netlink.ProtoStatic, } if err := r.netLinker.RouteReplace(route); err != nil { return fmt.Errorf("replacing route for subnet %s at interface %s: %w", @@ -38,24 +46,29 @@ func (r *Routing) addRouteVia(destination netip.Prefix, gateway netip.Addr, } func (r *Routing) deleteRouteVia(destination netip.Prefix, gateway netip.Addr, - iface string, table int, + iface string, table uint32, ) (err error) { destinationStr := destination.String() r.logger.Info("deleting route for " + destinationStr) r.logger.Debug("ip route delete " + destinationStr + " via " + gateway.String() + " dev " + iface + - " table " + strconv.Itoa(table)) + " table " + strconv.Itoa(int(table))) link, err := r.netLinker.LinkByName(iface) if err != nil { return fmt.Errorf("finding link for interface %s: %w", iface, err) } + family := netlink.FamilyV4 + if destination.Addr().Is6() { + family = netlink.FamilyV6 + } route := netlink.Route{ Dst: destination, Gw: gateway, LinkIndex: link.Index, + Family: family, Table: table, } if err := r.netLinker.RouteDel(route); err != nil { diff --git a/internal/routing/routing.go b/internal/routing/routing.go index 1420e923..9609cb28 100644 --- a/internal/routing/routing.go +++ b/internal/routing/routing.go @@ -15,20 +15,20 @@ type NetLinker interface { } type Addresser interface { - AddrList(link netlink.Link, family int) ( - addresses []netlink.Addr, err error) - AddrReplace(link netlink.Link, addr netlink.Addr) error + AddrList(linkIndex uint32, family uint8) ( + addresses []netip.Prefix, err error) + AddrReplace(linkIndex uint32, prefix netip.Prefix) error } type Router interface { - RouteList(family int) (routes []netlink.Route, err error) + RouteList(family uint8) (routes []netlink.Route, err error) RouteAdd(route netlink.Route) error RouteDel(route netlink.Route) error RouteReplace(route netlink.Route) error } type Ruler interface { - RuleList(family int) (rules []netlink.Rule, err error) + RuleList(family uint8) (rules []netlink.Rule, err error) RuleAdd(rule netlink.Rule) error RuleDel(rule netlink.Rule) error } @@ -36,11 +36,11 @@ type Ruler interface { type Linker interface { LinkList() (links []netlink.Link, err error) LinkByName(name string) (link netlink.Link, err error) - LinkByIndex(index int) (link netlink.Link, err error) - LinkAdd(link netlink.Link) (linkIndex int, err error) - LinkDel(link netlink.Link) (err error) - LinkSetUp(link netlink.Link) (linkIndex int, err error) - LinkSetDown(link netlink.Link) (err error) + LinkByIndex(index uint32) (link netlink.Link, err error) + LinkAdd(link netlink.Link) (linkIndex uint32, err error) + LinkDel(index uint32) (err error) + LinkSetUp(index uint32) (err error) + LinkSetDown(index uint32) (err error) } type Routing struct { diff --git a/internal/routing/rules.go b/internal/routing/rules.go index f6ea3d0f..70308a72 100644 --- a/internal/routing/rules.go +++ b/internal/routing/rules.go @@ -7,12 +7,19 @@ import ( "github.com/qdm12/gluetun/internal/netlink" ) -func (r *Routing) addIPRule(src, dst netip.Prefix, table, priority int) error { - rule := netlink.NewRule() - rule.Src = src - rule.Dst = dst - rule.Priority = priority - rule.Table = table +func (r *Routing) addIPRule(src, dst netip.Prefix, table, priority uint32) error { + family := netlink.FamilyV4 + if (src.IsValid() && src.Addr().Is6()) || (dst.IsValid() && dst.Addr().Is6()) { + family = netlink.FamilyV6 + } + rule := netlink.Rule{ + Priority: &priority, + Family: family, + Table: table, + Src: src, + Dst: dst, + Action: netlink.ActionToTable, + } existingRules, err := r.netLinker.RuleList(netlink.FamilyAll) if err != nil { @@ -31,12 +38,19 @@ func (r *Routing) addIPRule(src, dst netip.Prefix, table, priority int) error { return nil } -func (r *Routing) deleteIPRule(src, dst netip.Prefix, table, priority int) error { - rule := netlink.NewRule() - rule.Src = src - rule.Dst = dst - rule.Priority = priority - rule.Table = table +func (r *Routing) deleteIPRule(src, dst netip.Prefix, table uint32, priority uint32) error { + family := netlink.FamilyV4 + if (src.IsValid() && src.Addr().Is6()) || (dst.IsValid() && dst.Addr().Is6()) { + family = netlink.FamilyV6 + } + rule := netlink.Rule{ + Priority: &priority, + Family: family, + Table: table, + Src: src, + Dst: dst, + Action: netlink.ActionToTable, + } existingRules, err := r.netLinker.RuleList(netlink.FamilyAll) if err != nil { @@ -53,10 +67,12 @@ func (r *Routing) deleteIPRule(src, dst netip.Prefix, table, priority int) error return nil } +// rulesAreEqual checks whether two rules are equal +// only according to src, dst, priority and table. func rulesAreEqual(a, b netlink.Rule) bool { return ipPrefixesAreEqual(a.Src, b.Src) && ipPrefixesAreEqual(a.Dst, b.Dst) && - a.Priority == b.Priority && + ptrsEqual(a.Priority, b.Priority) && a.Table == b.Table } @@ -70,3 +86,13 @@ func ipPrefixesAreEqual(a, b netip.Prefix) bool { return a.Bits() == b.Bits() && a.Addr().Compare(b.Addr()) == 0 } + +func ptrsEqual(a, b *uint32) bool { + if a == nil && b == nil { + return true + } + if a == nil || b == nil { + return false + } + return *a == *b +} diff --git a/internal/routing/rules_test.go b/internal/routing/rules_test.go index 54c790ef..49d7369e 100644 --- a/internal/routing/rules_test.go +++ b/internal/routing/rules_test.go @@ -17,14 +17,20 @@ func makeNetipPrefix(n byte) netip.Prefix { } func makeIPRule(src, dst netip.Prefix, - table, priority int, + table uint32, priority uint32, ) netlink.Rule { - rule := netlink.NewRule() - rule.Src = src - rule.Dst = dst - rule.Table = table - rule.Priority = priority - return rule + family := netlink.FamilyV4 + if (src.IsValid() && src.Addr().Is6()) || (dst.IsValid() && dst.Addr().Is6()) { + family = netlink.FamilyV6 + } + return netlink.Rule{ + Priority: &priority, + Family: family, + Table: table, + Src: src, + Dst: dst, + Action: netlink.ActionToTable, + } } func Test_Routing_addIPRule(t *testing.T) { @@ -46,8 +52,8 @@ func Test_Routing_addIPRule(t *testing.T) { testCases := map[string]struct { src netip.Prefix dst netip.Prefix - table int - priority int + table uint32 + priority uint32 ruleList ruleListCall ruleAdd ruleAddCall err error @@ -149,8 +155,8 @@ func Test_Routing_deleteIPRule(t *testing.T) { testCases := map[string]struct { src netip.Prefix dst netip.Prefix - table int - priority int + table uint32 + priority uint32 ruleList ruleListCall ruleDel ruleDelCall err error @@ -238,6 +244,8 @@ func Test_Routing_deleteIPRule(t *testing.T) { } } +func ptrTo[T any](v T) *T { return &v } + func Test_rulesAreEqual(t *testing.T) { t.Parallel() @@ -253,13 +261,13 @@ func Test_rulesAreEqual(t *testing.T) { a: netlink.Rule{ Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{9, 9, 9, 9}), 24), Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32), - Priority: 100, + Priority: ptrTo(uint32(100)), Table: 101, }, b: netlink.Rule{ Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24), Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32), - Priority: 100, + Priority: ptrTo(uint32(100)), Table: 101, }, }, @@ -267,13 +275,13 @@ func Test_rulesAreEqual(t *testing.T) { a: netlink.Rule{ Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24), Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{9, 9, 9, 9}), 32), - Priority: 100, + Priority: ptrTo(uint32(100)), Table: 101, }, b: netlink.Rule{ Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24), Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32), - Priority: 100, + Priority: ptrTo(uint32(100)), Table: 101, }, }, @@ -281,13 +289,13 @@ func Test_rulesAreEqual(t *testing.T) { a: netlink.Rule{ Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24), Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32), - Priority: 999, + Priority: ptrTo(uint32(999)), Table: 101, }, b: netlink.Rule{ Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24), Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32), - Priority: 100, + Priority: ptrTo(uint32(100)), Table: 101, }, }, @@ -295,13 +303,13 @@ func Test_rulesAreEqual(t *testing.T) { a: netlink.Rule{ Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24), Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32), - Priority: 100, - Table: 999, + Priority: ptrTo(uint32(100)), + Table: 102, }, b: netlink.Rule{ Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24), Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32), - Priority: 100, + Priority: ptrTo(uint32(100)), Table: 101, }, }, @@ -309,13 +317,13 @@ func Test_rulesAreEqual(t *testing.T) { a: netlink.Rule{ Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24), Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32), - Priority: 100, + Priority: ptrTo(uint32(100)), Table: 101, }, b: netlink.Rule{ Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24), Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32), - Priority: 100, + Priority: ptrTo(uint32(100)), Table: 101, }, equal: true, diff --git a/internal/routing/vpn.go b/internal/routing/vpn.go index 793b6299..fc68e148 100644 --- a/internal/routing/vpn.go +++ b/internal/routing/vpn.go @@ -33,13 +33,12 @@ func (r *Routing) VPNLocalGatewayIP(vpnIntf string) (ip netip.Addr, err error) { case route.Dst.IsValid() && route.Dst.Addr().IsUnspecified() && route.Gw.IsValid(): // OpenVPN return route.Gw, nil case route.Dst.IsSingleIP() && - route.Dst.Addr().Compare(route.Src) == 0 && + route.Dst.Addr().Compare(route.Src.Addr()) == 0 && route.Table == tableLocal: // Wireguard - route.Src = route.Src.Unmap() - if route.Src.Is6() { + if route.Src.Addr().Is6() { return netip.Addr{}, fmt.Errorf("%w: %s", ErrVPNLocalGatewayIPv6NotSupported, route.Src) } - bytes := route.Src.As4() + bytes := route.Src.Addr().As4() // force last byte to 1 to get the VPN gateway IP // This is not necessarily bullet proof but it seems to work. bytes[3] = 1 diff --git a/internal/vpn/interfaces.go b/internal/vpn/interfaces.go index b6295bea..8467ea66 100644 --- a/internal/vpn/interfaces.go +++ b/internal/vpn/interfaces.go @@ -57,15 +57,15 @@ type Storage interface { } type NetLinker interface { - AddrReplace(link netlink.Link, addr netlink.Addr) error + AddrReplace(linkIndex uint32, addr netip.Prefix) error Router Ruler Linker - IsWireguardSupported() bool + IsWireguardSupported() (ok bool, err error) } type Router interface { - RouteList(family int) (routes []netlink.Route, err error) + RouteList(family uint8) (routes []netlink.Route, err error) RouteAdd(route netlink.Route) error } @@ -77,11 +77,11 @@ type Ruler interface { type Linker interface { LinkList() (links []netlink.Link, err error) LinkByName(name string) (link netlink.Link, err error) - LinkAdd(link netlink.Link) (linkIndex int, err error) - LinkDel(link netlink.Link) (err error) - LinkSetUp(link netlink.Link) (linkIndex int, err error) - LinkSetDown(link netlink.Link) (err error) - LinkSetMTU(link netlink.Link, mtu uint32) (err error) + LinkAdd(link netlink.Link) (linkIndex uint32, err error) + LinkDel(linkIndex uint32) error + LinkSetUp(linkIndex uint32) error + LinkSetDown(linkIndex uint32) error + LinkSetMTU(linkIndex, mtu uint32) error } type DNSLoop interface { diff --git a/internal/vpn/tunnelup.go b/internal/vpn/tunnelup.go index e477e76e..438399d6 100644 --- a/internal/vpn/tunnelup.go +++ b/internal/vpn/tunnelup.go @@ -172,7 +172,7 @@ func updateToMaxMTU(ctx context.Context, vpnInterface string, // the new MTU is set again, but this is necessary to find the highest valid MTU. logger.Debugf("VPN interface %s MTU temporarily set to %d", vpnInterface, vpnLinkMTU) - err = netlinker.LinkSetMTU(link, vpnLinkMTU) + err = netlinker.LinkSetMTU(link.Index, vpnLinkMTU) if err != nil { return fmt.Errorf("setting VPN interface %s MTU to %d: %w", vpnInterface, vpnLinkMTU, err) } @@ -183,14 +183,14 @@ func updateToMaxMTU(ctx context.Context, vpnInterface string, case err == nil: logger.Infof("setting VPN interface %s MTU to maximum valid MTU %d", vpnInterface, vpnLinkMTU) case errors.Is(err, pmtud.ErrMTUNotFound) || errors.Is(err, pmtud.ErrICMPNotPermitted): - vpnLinkMTU = uint32(originalMTU) + vpnLinkMTU = originalMTU logger.Infof("reverting VPN interface %s MTU to %d (due to: %s)", vpnInterface, originalMTU, err) default: return fmt.Errorf("path MTU discovering: %w", err) } - err = netlinker.LinkSetMTU(link, vpnLinkMTU) + err = netlinker.LinkSetMTU(link.Index, vpnLinkMTU) if err != nil { return fmt.Errorf("setting VPN interface %s MTU to %d: %w", vpnInterface, vpnLinkMTU, err) } diff --git a/internal/wireguard/address.go b/internal/wireguard/address.go index 2c9c9124..a85cd7f8 100644 --- a/internal/wireguard/address.go +++ b/internal/wireguard/address.go @@ -3,26 +3,20 @@ package wireguard import ( "fmt" "net/netip" - - "github.com/qdm12/gluetun/internal/netlink" ) -func (w *Wireguard) addAddresses(link netlink.Link, +func (w *Wireguard) addAddresses(linkIndex uint32, addresses []netip.Prefix, ) (err error) { - for _, ipNet := range addresses { - if !*w.settings.IPv6 && ipNet.Addr().Is6() { + for _, address := range addresses { + if !*w.settings.IPv6 && address.Addr().Is6() { continue } - address := netlink.Addr{ - Network: ipNet, - } - - err = w.netlink.AddrReplace(link, address) + err = w.netlink.AddrReplace(linkIndex, address) if err != nil { - return fmt.Errorf("%w: when adding address %s to link %s", - err, address, link.Name) + return fmt.Errorf("%w: when adding address %s to link with index %d", + err, address, linkIndex) } } diff --git a/internal/wireguard/address_test.go b/internal/wireguard/address_test.go index d1851160..20707c4c 100644 --- a/internal/wireguard/address_test.go +++ b/internal/wireguard/address_test.go @@ -6,7 +6,6 @@ import ( "testing" "github.com/golang/mock/gomock" - "github.com/qdm12/gluetun/internal/netlink" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -20,21 +19,21 @@ func Test_Wireguard_addAddresses(t *testing.T) { errDummy := errors.New("dummy") testCases := map[string]struct { - link netlink.Link + linkIndex uint32 addrs []netip.Prefix - wgBuilder func(ctrl *gomock.Controller, link netlink.Link) *Wireguard + wgBuilder func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard err error }{ "success": { - link: netlink.Link{Type: "wireguard"}, - addrs: []netip.Prefix{ipNetOne, ipNetTwo}, - wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard { + linkIndex: 1, + addrs: []netip.Prefix{ipNetOne, ipNetTwo}, + wgBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard { netLinker := NewMockNetLinker(ctrl) firstCall := netLinker.EXPECT(). - AddrReplace(link, netlink.Addr{Network: ipNetOne}). + AddrReplace(linkIndex, ipNetOne). Return(nil) netLinker.EXPECT(). - AddrReplace(link, netlink.Addr{Network: ipNetTwo}). + AddrReplace(linkIndex, ipNetTwo). Return(nil).After(firstCall) return &Wireguard{ netlink: netLinker, @@ -45,12 +44,12 @@ func Test_Wireguard_addAddresses(t *testing.T) { }, }, "first add error": { - link: netlink.Link{Type: "wireguard", Name: "a_bridge"}, - addrs: []netip.Prefix{ipNetOne, ipNetTwo}, - wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard { + linkIndex: 1, + addrs: []netip.Prefix{ipNetOne, ipNetTwo}, + wgBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard { netLinker := NewMockNetLinker(ctrl) netLinker.EXPECT(). - AddrReplace(link, netlink.Addr{Network: ipNetOne}). + AddrReplace(linkIndex, ipNetOne). Return(errDummy) return &Wireguard{ netlink: netLinker, @@ -59,18 +58,18 @@ func Test_Wireguard_addAddresses(t *testing.T) { }, } }, - err: errors.New("dummy: when adding address 1.2.3.4/32 to link a_bridge"), + err: errors.New("dummy: when adding address 1.2.3.4/32 to link with index 1"), }, "second add error": { - link: netlink.Link{Type: "wireguard", Name: "a_bridge"}, - addrs: []netip.Prefix{ipNetOne, ipNetTwo}, - wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard { + linkIndex: 1, + addrs: []netip.Prefix{ipNetOne, ipNetTwo}, + wgBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard { netLinker := NewMockNetLinker(ctrl) firstCall := netLinker.EXPECT(). - AddrReplace(link, netlink.Addr{Network: ipNetOne}). + AddrReplace(linkIndex, ipNetOne). Return(nil) netLinker.EXPECT(). - AddrReplace(link, netlink.Addr{Network: ipNetTwo}). + AddrReplace(linkIndex, ipNetTwo). Return(errDummy).After(firstCall) return &Wireguard{ netlink: netLinker, @@ -79,11 +78,11 @@ func Test_Wireguard_addAddresses(t *testing.T) { }, } }, - err: errors.New("dummy: when adding address ::1234/64 to link a_bridge"), + err: errors.New("dummy: when adding address ::1234/64 to link with index 1"), }, "ignore IPv6": { addrs: []netip.Prefix{ipNetTwo}, - wgBuilder: func(_ *gomock.Controller, _ netlink.Link) *Wireguard { + wgBuilder: func(_ *gomock.Controller, _ uint32) *Wireguard { return &Wireguard{ settings: Settings{ IPv6: ptrTo(false), @@ -98,9 +97,9 @@ func Test_Wireguard_addAddresses(t *testing.T) { t.Parallel() ctrl := gomock.NewController(t) - wg := testCase.wgBuilder(ctrl, testCase.link) + wg := testCase.wgBuilder(ctrl, testCase.linkIndex) - err := wg.addAddresses(testCase.link, testCase.addrs) + err := wg.addAddresses(testCase.linkIndex, testCase.addrs) if testCase.err != nil { require.Error(t, err) diff --git a/internal/wireguard/helpers_test.go b/internal/wireguard/helpers_test.go index 2db84e77..3be81856 100644 --- a/internal/wireguard/helpers_test.go +++ b/internal/wireguard/helpers_test.go @@ -1,3 +1,53 @@ package wireguard +import ( + "math/rand/v2" + "net/netip" + + "github.com/qdm12/gluetun/internal/netlink" +) + func ptrTo[T any](x T) *T { return &x } + +var rng = rand.New(rand.NewChaCha8([32]byte{})) //nolint:gosec,gochecknoglobals + +func makeLinkName() string { + const alphabet = "abcdefghijklmnopqrstuvwxyz" + b := make([]byte, 8) + for i := range b { + b[i] = alphabet[rng.IntN(len(alphabet))] + } + return "test" + string(b) +} + +func rulesAreEqual(a, b netlink.Rule) bool { + return ipPrefixesAreEqual(a.Src, b.Src) && + ipPrefixesAreEqual(a.Dst, b.Dst) && + ptrsEqual(a.Priority, b.Priority) && + a.Table == b.Table && + a.Family == b.Family && + a.Flags == b.Flags && + a.Action == b.Action && + ptrsEqual(a.Mark, b.Mark) +} + +func ipPrefixesAreEqual(a, b netip.Prefix) bool { + if !a.IsValid() && !b.IsValid() { + return true + } + if !a.IsValid() || !b.IsValid() { + return false + } + return a.Bits() == b.Bits() && + a.Addr().Compare(b.Addr()) == 0 +} + +func ptrsEqual(a, b *uint32) bool { + if a == nil && b == nil { + return true + } + if a == nil || b == nil { + return false + } + return *a == *b +} diff --git a/internal/wireguard/netlink_integration_test.go b/internal/wireguard/netlink_integration_test.go index 7de40dd5..e4883799 100644 --- a/internal/wireguard/netlink_integration_test.go +++ b/internal/wireguard/netlink_integration_test.go @@ -1,4 +1,4 @@ -//go:build netlink && linux +//go:build linux package wireguard @@ -10,13 +10,16 @@ import ( "github.com/qdm12/log" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "golang.org/x/sys/unix" ) type noopDebugLogger struct{} -func (n noopDebugLogger) Debugf(format string, args ...any) {} -func (n noopDebugLogger) Patch(options ...log.Option) {} +func (n noopDebugLogger) Debug(_ string) {} +func (n noopDebugLogger) Debugf(_ string, _ ...any) {} +func (n noopDebugLogger) Info(_ string) {} +func (n noopDebugLogger) Error(_ string) {} +func (n noopDebugLogger) Errorf(_ string, _ ...any) {} +func (n noopDebugLogger) Patch(_ ...log.Option) {} func Test_netlink_Wireguard_addAddresses(t *testing.T) { t.Parallel() @@ -24,15 +27,9 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) { netlinker := netlink.New(&noopDebugLogger{}) link := netlink.Link{ - Type: "bridge", - Name: "test_8081", - } - - // Remove any previously created test interface from a crashed/panic - // test or test suite run. - err := netlinker.LinkDel(link) - if err != nil && err.Error() != "invalid argument" { - require.NoError(t, err) + DeviceType: netlink.DeviceTypeNone, + VirtualType: "bridge", + Name: makeLinkName(), } linkIndex, err := netlinker.LinkAdd(link) @@ -40,7 +37,7 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) { link.Index = linkIndex defer func() { - err = netlinker.LinkDel(link) + err = netlinker.LinkDel(linkIndex) assert.NoError(t, err) }() @@ -57,17 +54,15 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) { } const addIterations = 2 // initial + replace - - for i := 0; i < addIterations; i++ { - err = wg.addAddresses(link, addresses) + for range addIterations { + err = wg.addAddresses(link.Index, addresses) require.NoError(t, err) - netlinkAddresses, err := netlinker.AddrList(link, netlink.FamilyAll) + ipPrefixes, err := netlinker.AddrList(link.Index, netlink.FamilyAll) require.NoError(t, err) - require.Equal(t, len(addresses), len(netlinkAddresses)) - for i, netlinkAddress := range netlinkAddresses { - require.NotNil(t, netlinkAddress.Network) - assert.Equal(t, addresses[i], netlinkAddress.Network) + require.Equal(t, len(addresses), len(ipPrefixes)) + for i, ipPrefix := range ipPrefixes { + assert.Equal(t, addresses[i], ipPrefix) } } } @@ -78,38 +73,41 @@ func Test_netlink_Wireguard_addRule(t *testing.T) { netlinker := netlink.New(&noopDebugLogger{}) wg := &Wireguard{ netlink: netlinker, + logger: &noopDebugLogger{}, } - rulePriority := 10000 - const firewallMark = 999 - const family = unix.AF_INET // ipv4 + // Unique combination for this test + const rulePriority uint32 = 10000 + const firewallMark uint32 = 12345 + const family = netlink.FamilyV4 cleanup, err := wg.addRule(rulePriority, firewallMark, family) require.NoError(t, err) - defer func() { + t.Cleanup(func() { err := cleanup() assert.NoError(t, err) - }() + }) rules, err := netlinker.RuleList(netlink.FamilyV4) require.NoError(t, err) + expectedRule := netlink.Rule{ + Priority: ptrTo(rulePriority), + Family: netlink.FamilyV4, + Table: firewallMark, + Mark: ptrTo(firewallMark), + Flags: netlink.FlagInvert, + Action: netlink.ActionToTable, + } var rule netlink.Rule var ruleFound bool for _, rule = range rules { - if rule.Mark == firewallMark { + if rulesAreEqual(rule, expectedRule) { ruleFound = true break } } require.True(t, ruleFound) - expectedRule := netlink.Rule{ - Invert: true, - Priority: rulePriority, - Mark: firewallMark, - Table: firewallMark, - } - assert.Equal(t, expectedRule, rule) // Existing rule cannot be added nilCleanup, err := wg.addRule(rulePriority, @@ -118,5 +116,5 @@ func Test_netlink_Wireguard_addRule(t *testing.T) { _ = nilCleanup() // in case it succeeds } require.Error(t, err) - assert.EqualError(t, err, "adding ip rule 10000: from all to all table 999: file exists") + assert.EqualError(t, err, "adding ip rule 10000: from all to all table 12345: netlink receive: file exists") } diff --git a/internal/wireguard/netlinker.go b/internal/wireguard/netlinker.go index 118620aa..4f4a3fa8 100644 --- a/internal/wireguard/netlinker.go +++ b/internal/wireguard/netlinker.go @@ -1,19 +1,23 @@ package wireguard -import "github.com/qdm12/gluetun/internal/netlink" +import ( + "net/netip" + + "github.com/qdm12/gluetun/internal/netlink" +) //go:generate mockgen -destination=netlinker_mock_test.go -package wireguard . NetLinker type NetLinker interface { - AddrReplace(link netlink.Link, addr netlink.Addr) error + AddrReplace(linkIndex uint32, addr netip.Prefix) error Router Ruler Linker - IsWireguardSupported() bool + IsWireguardSupported() (ok bool, err error) } type Router interface { - RouteList(family int) (routes []netlink.Route, err error) + RouteList(family uint8) (routes []netlink.Route, err error) RouteAdd(route netlink.Route) error } @@ -23,10 +27,10 @@ type Ruler interface { } type Linker interface { - LinkAdd(link netlink.Link) (linkIndex int, err error) + LinkAdd(link netlink.Link) (linkIndex uint32, err error) LinkList() (links []netlink.Link, err error) LinkByName(name string) (link netlink.Link, err error) - LinkSetUp(link netlink.Link) (linkIndex int, err error) - LinkSetDown(link netlink.Link) error - LinkDel(link netlink.Link) error + LinkSetUp(linkIndex uint32) error + LinkSetDown(linkIndex uint32) error + LinkDel(linkIndex uint32) error } diff --git a/internal/wireguard/netlinker_mock_test.go b/internal/wireguard/netlinker_mock_test.go index c5fac81f..5fc76bf0 100644 --- a/internal/wireguard/netlinker_mock_test.go +++ b/internal/wireguard/netlinker_mock_test.go @@ -5,6 +5,7 @@ package wireguard import ( + netip "net/netip" reflect "reflect" gomock "github.com/golang/mock/gomock" @@ -35,7 +36,7 @@ func (m *MockNetLinker) EXPECT() *MockNetLinkerMockRecorder { } // AddrReplace mocks base method. -func (m *MockNetLinker) AddrReplace(arg0 netlink.Link, arg1 netlink.Addr) error { +func (m *MockNetLinker) AddrReplace(arg0 uint32, arg1 netip.Prefix) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "AddrReplace", arg0, arg1) ret0, _ := ret[0].(error) @@ -49,11 +50,12 @@ func (mr *MockNetLinkerMockRecorder) AddrReplace(arg0, arg1 interface{}) *gomock } // IsWireguardSupported mocks base method. -func (m *MockNetLinker) IsWireguardSupported() bool { +func (m *MockNetLinker) IsWireguardSupported() (bool, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "IsWireguardSupported") ret0, _ := ret[0].(bool) - return ret0 + ret1, _ := ret[1].(error) + return ret0, ret1 } // IsWireguardSupported indicates an expected call of IsWireguardSupported. @@ -63,10 +65,10 @@ func (mr *MockNetLinkerMockRecorder) IsWireguardSupported() *gomock.Call { } // LinkAdd mocks base method. -func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) (int, error) { +func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) (uint32, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "LinkAdd", arg0) - ret0, _ := ret[0].(int) + ret0, _ := ret[0].(uint32) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -93,7 +95,7 @@ func (mr *MockNetLinkerMockRecorder) LinkByName(arg0 interface{}) *gomock.Call { } // LinkDel mocks base method. -func (m *MockNetLinker) LinkDel(arg0 netlink.Link) error { +func (m *MockNetLinker) LinkDel(arg0 uint32) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "LinkDel", arg0) ret0, _ := ret[0].(error) @@ -122,7 +124,7 @@ func (mr *MockNetLinkerMockRecorder) LinkList() *gomock.Call { } // LinkSetDown mocks base method. -func (m *MockNetLinker) LinkSetDown(arg0 netlink.Link) error { +func (m *MockNetLinker) LinkSetDown(arg0 uint32) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "LinkSetDown", arg0) ret0, _ := ret[0].(error) @@ -136,12 +138,11 @@ func (mr *MockNetLinkerMockRecorder) LinkSetDown(arg0 interface{}) *gomock.Call } // LinkSetUp mocks base method. -func (m *MockNetLinker) LinkSetUp(arg0 netlink.Link) (int, error) { +func (m *MockNetLinker) LinkSetUp(arg0 uint32) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "LinkSetUp", arg0) - ret0, _ := ret[0].(int) - ret1, _ := ret[1].(error) - return ret0, ret1 + ret0, _ := ret[0].(error) + return ret0 } // LinkSetUp indicates an expected call of LinkSetUp. @@ -165,7 +166,7 @@ func (mr *MockNetLinkerMockRecorder) RouteAdd(arg0 interface{}) *gomock.Call { } // RouteList mocks base method. -func (m *MockNetLinker) RouteList(arg0 int) ([]netlink.Route, error) { +func (m *MockNetLinker) RouteList(arg0 byte) ([]netlink.Route, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "RouteList", arg0) ret0, _ := ret[0].([]netlink.Route) diff --git a/internal/wireguard/route.go b/internal/wireguard/route.go index 9893dc76..101b39fd 100644 --- a/internal/wireguard/route.go +++ b/internal/wireguard/route.go @@ -8,11 +8,11 @@ import ( "github.com/qdm12/gluetun/internal/netlink" ) -func (w *Wireguard) addRoutes(link netlink.Link, destinations []netip.Prefix, +func (w *Wireguard) addRoutes(linkIndex uint32, destinations []netip.Prefix, firewallMark uint32, ) (err error) { for _, dst := range destinations { - err = w.addRoute(link, dst, firewallMark) + err = w.addRoute(linkIndex, dst, firewallMark) if err == nil { continue } @@ -29,7 +29,7 @@ func (w *Wireguard) addRoutes(link netlink.Link, destinations []netip.Prefix, return nil } -func (w *Wireguard) addRoute(link netlink.Link, dst netip.Prefix, +func (w *Wireguard) addRoute(linkIndex uint32, dst netip.Prefix, firewallMark uint32, ) (err error) { family := netlink.FamilyV4 @@ -37,17 +37,20 @@ func (w *Wireguard) addRoute(link netlink.Link, dst netip.Prefix, family = netlink.FamilyV6 } route := netlink.Route{ - LinkIndex: link.Index, + LinkIndex: linkIndex, Dst: dst, Family: family, - Table: int(firewallMark), + Table: firewallMark, + Type: netlink.RouteTypeUnicast, + Scope: netlink.ScopeUniverse, + Proto: netlink.ProtoStatic, } err = w.netlink.RouteAdd(route) if err != nil { return fmt.Errorf( - "adding route for link %s, destination %s and table %d: %w", - link.Name, dst, firewallMark, err) + "adding route for link with index %d, destination %s and table %d: %w", + linkIndex, dst, firewallMark, err) } return err diff --git a/internal/wireguard/route_test.go b/internal/wireguard/route_test.go index ce05ac60..c75c8034 100644 --- a/internal/wireguard/route_test.go +++ b/internal/wireguard/route_test.go @@ -23,38 +23,36 @@ func Test_Wireguard_addRoute(t *testing.T) { errDummy := errors.New("dummy") testCases := map[string]struct { - link netlink.Link dst netip.Prefix expectedRoute netlink.Route routeAddErr error err error }{ "success": { - link: netlink.Link{ - Index: linkIndex, - }, dst: ipPrefix, expectedRoute: netlink.Route{ LinkIndex: linkIndex, Dst: ipPrefix, Family: netlink.FamilyV4, Table: firewallMark, + Type: netlink.RouteTypeUnicast, + Scope: netlink.ScopeUniverse, + Proto: netlink.ProtoStatic, }, }, "route add error": { - link: netlink.Link{ - Name: "a_bridge", - Index: linkIndex, - }, dst: ipPrefix, expectedRoute: netlink.Route{ LinkIndex: linkIndex, Dst: ipPrefix, Family: netlink.FamilyV4, Table: firewallMark, + Type: netlink.RouteTypeUnicast, + Scope: netlink.ScopeUniverse, + Proto: netlink.ProtoStatic, }, routeAddErr: errDummy, - err: errors.New("adding route for link a_bridge, destination 1.2.3.4/32 and table 51820: dummy"), //nolint:lll + err: errors.New("adding route for link with index 88, destination 1.2.3.4/32 and table 51820: dummy"), //nolint:lll }, } @@ -72,7 +70,7 @@ func Test_Wireguard_addRoute(t *testing.T) { RouteAdd(testCase.expectedRoute). Return(testCase.routeAddErr) - err := wg.addRoute(testCase.link, testCase.dst, firewallMark) + err := wg.addRoute(linkIndex, testCase.dst, firewallMark) if testCase.err != nil { require.Error(t, err) diff --git a/internal/wireguard/rule.go b/internal/wireguard/rule.go index c7cd0b23..24249cb5 100644 --- a/internal/wireguard/rule.go +++ b/internal/wireguard/rule.go @@ -7,15 +7,17 @@ import ( "github.com/qdm12/gluetun/internal/netlink" ) -func (w *Wireguard) addRule(rulePriority int, firewallMark uint32, - family int, +func (w *Wireguard) addRule(rulePriority, firewallMark uint32, + family uint8, ) (cleanup func() error, err error) { - rule := netlink.NewRule() - rule.Invert = true - rule.Priority = rulePriority - rule.Mark = firewallMark - rule.Table = int(firewallMark) - rule.Family = family + rule := netlink.Rule{ + Priority: &rulePriority, + Family: family, + Table: firewallMark, + Mark: &firewallMark, + Flags: netlink.FlagInvert, + Action: netlink.ActionToTable, + } if err := w.netlink.RuleAdd(rule); err != nil { if strings.HasSuffix(err.Error(), "file exists") { w.logger.Info("if you are using Kubernetes, this may fix the error below: " + diff --git a/internal/wireguard/rule_test.go b/internal/wireguard/rule_test.go index 766dbede..c6f02f36 100644 --- a/internal/wireguard/rule_test.go +++ b/internal/wireguard/rule_test.go @@ -8,15 +8,14 @@ import ( "github.com/qdm12/gluetun/internal/netlink" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "golang.org/x/sys/unix" ) func Test_Wireguard_addRule(t *testing.T) { t.Parallel() - const rulePriority = 987 - const firewallMark = 456 - const family = unix.AF_INET + const rulePriority uint32 = 987 + const firewallMark uint32 = 456 + const family = netlink.FamilyV4 errDummy := errors.New("dummy") @@ -29,31 +28,34 @@ func Test_Wireguard_addRule(t *testing.T) { }{ "success": { expectedRule: netlink.Rule{ - Invert: true, - Priority: rulePriority, - Mark: firewallMark, + Priority: ptrTo(rulePriority), + Mark: ptrTo(firewallMark), Table: firewallMark, Family: family, + Flags: netlink.FlagInvert, + Action: netlink.ActionToTable, }, }, "rule add error": { expectedRule: netlink.Rule{ - Invert: true, - Priority: rulePriority, - Mark: firewallMark, + Priority: ptrTo(rulePriority), + Mark: ptrTo(firewallMark), Table: firewallMark, Family: family, + Flags: netlink.FlagInvert, + Action: netlink.ActionToTable, }, ruleAddErr: errDummy, err: errors.New("adding ip rule 987: from all to all table 456: dummy"), }, "rule delete error": { expectedRule: netlink.Rule{ - Invert: true, - Priority: rulePriority, - Mark: firewallMark, + Priority: ptrTo(rulePriority), + Mark: ptrTo(firewallMark), Table: firewallMark, Family: family, + Flags: netlink.FlagInvert, + Action: netlink.ActionToTable, }, ruleDelErr: errDummy, cleanupErr: errors.New("deleting rule ip rule 987: from all to all table 456: dummy"), diff --git a/internal/wireguard/run.go b/internal/wireguard/run.go index 9b58c120..d4f2d1f1 100644 --- a/internal/wireguard/run.go +++ b/internal/wireguard/run.go @@ -14,6 +14,7 @@ import ( ) var ( + ErrDetectKernel = errors.New("cannot detect Kernel support") ErrCreateTun = errors.New("cannot create TUN device") ErrAddLink = errors.New("cannot add Wireguard link") ErrFindLink = errors.New("cannot find link") @@ -32,7 +33,11 @@ var ( // See https://git.zx2c4.com/wireguard-go/tree/main.go func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<- struct{}) { - kernelSupported := w.netlink.IsWireguardSupported() + kernelSupported, err := w.netlink.IsWireguardSupported() + if err != nil { + waitError <- fmt.Errorf("%w: %s", ErrDetectKernel, err) + return + } setupFunction := setupUserSpace switch w.settings.Implementation { @@ -65,14 +70,14 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan< defer closers.cleanup(w.logger) - link, waitAndCleanup, err := setupFunction(ctx, + linkIndex, waitAndCleanup, err := setupFunction(ctx, w.settings.InterfaceName, w.netlink, w.settings.MTU, &closers, w.logger) if err != nil { waitError <- err return } - err = w.addAddresses(link, w.settings.Addresses) + err = w.addAddresses(linkIndex, w.settings.Addresses) if err != nil { waitError <- fmt.Errorf("%w: %s", ErrAddAddress, err) return @@ -85,17 +90,16 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan< return } - linkIndex, err := w.netlink.LinkSetUp(link) + err = w.netlink.LinkSetUp(linkIndex) if err != nil { waitError <- fmt.Errorf("%w: %s", ErrIfaceUp, err) return } - link.Index = linkIndex closers.add("shutting down link", stepFour, func() error { - return w.netlink.LinkSetDown(link) + return w.netlink.LinkSetDown(linkIndex) }) - err = w.addRoutes(link, w.settings.AllowedIPs, w.settings.FirewallMark) + err = w.addRoutes(linkIndex, w.settings.AllowedIPs, w.settings.FirewallMark) if err != nil { waitError <- fmt.Errorf("%w: %s", ErrRouteAdd, err) return @@ -131,39 +135,38 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan< type waitAndCleanupFunc func() error func setupKernelSpace(ctx context.Context, - interfaceName string, netLinker NetLinker, mtu uint16, + interfaceName string, netLinker NetLinker, mtu uint32, closers *closers, logger Logger) ( - link netlink.Link, waitAndCleanup waitAndCleanupFunc, err error, + linkIndex uint32, waitAndCleanup waitAndCleanupFunc, err error, ) { - link = netlink.Link{ - Type: "wireguard", - Name: interfaceName, - MTU: mtu, - } links, err := netLinker.LinkList() if err != nil { - return link, nil, fmt.Errorf("listing links: %w", err) + return 0, nil, fmt.Errorf("listing links: %w", err) } // Cleanup any previous Wireguard interface with the same name // See https://github.com/qdm12/gluetun/issues/1669 for _, link := range links { - if link.Type == "wireguard" && link.Name == interfaceName { - err = netLinker.LinkDel(link) + if link.VirtualType == "wireguard" && link.Name == interfaceName { + err = netLinker.LinkDel(link.Index) if err != nil { - return link, nil, fmt.Errorf("deleting previous Wireguard link %s: %w", + return 0, nil, fmt.Errorf("deleting previous Wireguard link %s: %w", interfaceName, err) } } } - linkIndex, err := netLinker.LinkAdd(link) - if err != nil { - return link, nil, fmt.Errorf("%w: %s", ErrAddLink, err) + link := netlink.Link{ + VirtualType: "wireguard", + Name: interfaceName, + MTU: mtu, + } + linkIndex, err = netLinker.LinkAdd(link) + if err != nil { + return 0, nil, fmt.Errorf("%w: %s", ErrAddLink, err) } - link.Index = linkIndex closers.add("deleting link", stepFive, func() error { - return netLinker.LinkDel(link) + return netLinker.LinkDel(linkIndex) }) waitAndCleanup = func() error { @@ -172,35 +175,35 @@ func setupKernelSpace(ctx context.Context, return ctx.Err() } - return link, waitAndCleanup, nil + return linkIndex, waitAndCleanup, nil } func setupUserSpace(ctx context.Context, - interfaceName string, netLinker NetLinker, mtu uint16, + interfaceName string, netLinker NetLinker, mtu uint32, closers *closers, logger Logger) ( - link netlink.Link, waitAndCleanup waitAndCleanupFunc, err error, + linkIndex uint32, waitAndCleanup waitAndCleanupFunc, err error, ) { tun, err := tun.CreateTUN(interfaceName, int(mtu)) if err != nil { - return link, nil, fmt.Errorf("%w: %s", ErrCreateTun, err) + return 0, nil, fmt.Errorf("%w: %s", ErrCreateTun, err) } closers.add("closing TUN device", stepSeven, tun.Close) tunName, err := tun.Name() if err != nil { - return link, nil, fmt.Errorf("%w: cannot get TUN name: %s", ErrCreateTun, err) + return 0, nil, fmt.Errorf("%w: cannot get TUN name: %s", ErrCreateTun, err) } else if tunName != interfaceName { - return link, nil, fmt.Errorf("%w: names don't match: expected %q and got %q", + return 0, nil, fmt.Errorf("%w: names don't match: expected %q and got %q", ErrCreateTun, interfaceName, tunName) } - link, err = netLinker.LinkByName(interfaceName) + link, err := netLinker.LinkByName(interfaceName) if err != nil { - return link, nil, fmt.Errorf("%w: %s: %s", ErrFindLink, interfaceName, err) + return 0, nil, fmt.Errorf("%w: %s: %s", ErrFindLink, interfaceName, err) } closers.add("deleting link", stepFive, func() error { - return netLinker.LinkDel(link) + return netLinker.LinkDel(link.Index) }) bind := conn.NewDefaultBind() @@ -217,14 +220,14 @@ func setupUserSpace(ctx context.Context, uapiFile, err := uapiOpen(interfaceName) if err != nil { - return link, nil, fmt.Errorf("%w: %s", ErrUAPISocketOpening, err) + return 0, nil, fmt.Errorf("%w: %s", ErrUAPISocketOpening, err) } closers.add("closing UAPI file", stepThree, uapiFile.Close) uapiListener, err := uapiListen(interfaceName, uapiFile) if err != nil { - return link, nil, fmt.Errorf("%w: %s", ErrUAPIListen, err) + return 0, nil, fmt.Errorf("%w: %s", ErrUAPIListen, err) } closers.add("closing UAPI listener", stepTwo, uapiListener.Close) @@ -249,7 +252,7 @@ func setupUserSpace(ctx context.Context, return err } - return link, waitAndCleanup, nil + return link.Index, waitAndCleanup, nil } func acceptAndHandle(uapi net.Listener, device *device.Device, diff --git a/internal/wireguard/settings.go b/internal/wireguard/settings.go index 5b080613..66eee765 100644 --- a/internal/wireguard/settings.go +++ b/internal/wireguard/settings.go @@ -38,10 +38,10 @@ type Settings struct { FirewallMark uint32 // Maximum Transmission Unit (MTU) setting for the network interface. // It defaults to device.DefaultMTU from wireguard-go which is 1420 - MTU uint16 + MTU uint32 // RulePriority is the priority for the rule created with the // FirewallMark. - RulePriority int + RulePriority uint32 // IPv6 can bet set to true if IPv6 should be handled. // It defaults to false if left unset. IPv6 *bool