chore(all): replace netlink library for more flexibility (#3107)

This commit is contained in:
Quentin McGaw
2026-01-27 10:11:39 +01:00
committed by GitHub
parent e292a4c9be
commit facc6df3be
50 changed files with 1074 additions and 579 deletions
+1 -1
View File
@@ -59,7 +59,7 @@ jobs:
- name: Run tests in test container
run: |
touch coverage.txt
docker run --rm --device /dev/net/tun \
docker run --rm --cap-add=NET_ADMIN --device /dev/net/tun \
-v "$(pwd)/coverage.txt:/tmp/gobuild/coverage.txt" \
test-container
+13 -12
View File
@@ -6,6 +6,7 @@ import (
"fmt"
"io/fs"
"net/http"
"net/netip"
"os"
"os/exec"
"os/signal"
@@ -549,26 +550,26 @@ type netLinker interface {
Router
Ruler
Linker
IsWireguardSupported() bool
IsWireguardSupported() (ok bool, err error)
IsIPv6Supported() (ok bool, err error)
PatchLoggerLevel(level log.Level)
}
type Addresser interface {
AddrList(link netlink.Link, family int) (
addresses []netlink.Addr, err error)
AddrReplace(link netlink.Link, addr netlink.Addr) error
AddrList(linkIndex uint32, family uint8) (
addresses []netip.Prefix, err error)
AddrReplace(linkIndex uint32, addr netip.Prefix) error
}
type Router interface {
RouteList(family int) (routes []netlink.Route, err error)
RouteList(family uint8) (routes []netlink.Route, err error)
RouteAdd(route netlink.Route) error
RouteDel(route netlink.Route) error
RouteReplace(route netlink.Route) error
}
type Ruler interface {
RuleList(family int) (rules []netlink.Rule, err error)
RuleList(family uint8) (rules []netlink.Rule, err error)
RuleAdd(rule netlink.Rule) error
RuleDel(rule netlink.Rule) error
}
@@ -576,12 +577,12 @@ type Ruler interface {
type Linker interface {
LinkList() (links []netlink.Link, err error)
LinkByName(name string) (link netlink.Link, err error)
LinkByIndex(index int) (link netlink.Link, err error)
LinkAdd(link netlink.Link) (linkIndex int, err error)
LinkDel(link netlink.Link) (err error)
LinkSetUp(link netlink.Link) (linkIndex int, err error)
LinkSetDown(link netlink.Link) (err error)
LinkSetMTU(link netlink.Link, mtu uint32) error
LinkByIndex(index uint32) (link netlink.Link, err error)
LinkAdd(link netlink.Link) (linkIndex uint32, err error)
LinkDel(linkIndex uint32) (err error)
LinkSetUp(linkIndex uint32) (err error)
LinkSetDown(linkIndex uint32) (err error)
LinkSetMTU(linkIndex, mtu uint32) error
}
type clier interface {
+11 -12
View File
@@ -7,8 +7,10 @@ require (
github.com/breml/rootcerts v0.3.3
github.com/fatih/color v1.18.0
github.com/golang/mock v1.6.0
github.com/jsimonetti/rtnetlink v1.4.2
github.com/klauspost/compress v1.18.1
github.com/klauspost/pgzip v1.2.6
github.com/mdlayher/genetlink v1.3.2
github.com/pelletier/go-toml/v2 v2.2.4
github.com/qdm12/dns/v2 v2.0.0-rc10
github.com/qdm12/gosettings v0.4.4
@@ -19,12 +21,11 @@ require (
github.com/qdm12/ss-server v0.6.0
github.com/stretchr/testify v1.11.1
github.com/ulikunitz/xz v0.5.15
github.com/vishvananda/netlink v1.3.1
github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c
golang.org/x/net v0.47.0
golang.org/x/sys v0.38.0
golang.org/x/text v0.31.0
golang.org/x/net v0.49.0
golang.org/x/sys v0.40.0
golang.org/x/text v0.33.0
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
gopkg.in/ini.v1 v1.67.0
@@ -38,13 +39,12 @@ require (
github.com/cloudflare/circl v1.6.1 // indirect
github.com/cronokirby/saferith v0.33.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/google/go-cmp v0.6.0 // indirect
github.com/google/go-cmp v0.7.0 // indirect
github.com/josharian/native v1.1.0 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mdlayher/genetlink v1.3.2 // indirect
github.com/mdlayher/netlink v1.7.2 // indirect
github.com/mdlayher/socket v0.4.1 // indirect
github.com/mdlayher/socket v0.5.1 // indirect
github.com/miekg/dns v1.1.62 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/pkg/errors v0.9.1 // indirect
@@ -55,12 +55,11 @@ require (
github.com/prometheus/procfs v0.15.1 // indirect
github.com/qdm12/goservices v0.1.1-0.20251104135713-6bee97bd4978 // indirect
github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect
github.com/vishvananda/netns v0.0.5 // indirect
golang.org/x/crypto v0.45.0 // indirect
golang.org/x/mod v0.29.0 // indirect
golang.org/x/sync v0.18.0 // indirect
golang.org/x/crypto v0.47.0 // indirect
golang.org/x/mod v0.31.0 // indirect
golang.org/x/sync v0.19.0 // indirect
golang.org/x/time v0.3.0 // indirect
golang.org/x/tools v0.38.0 // indirect
golang.org/x/tools v0.40.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
google.golang.org/protobuf v1.35.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
+22 -24
View File
@@ -13,6 +13,8 @@ github.com/breml/rootcerts v0.3.3/go.mod h1:S/PKh+4d1HUn4HQovEB8hPJZO6pUZYrIhmXB
github.com/bwesterb/go-ristretto v1.2.0/go.mod h1:fUIoIZaG73pV5biE2Blr2xEzDoMj7NFEuV9ekS419A0=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cilium/ebpf v0.12.3 h1:8ht6F9MquybnY97at+VDZb3eQQr8ev79RueWeVaEcG4=
github.com/cilium/ebpf v0.12.3/go.mod h1:TctK1ivibvI3znr66ljgi4hqOT8EYQjz1KWBfb1UVgM=
github.com/cloudflare/circl v1.1.0/go.mod h1:prBCrKB9DV4poKZY1l9zBXg2QJY7mvgRvtMxxK7fi4I=
github.com/cloudflare/circl v1.6.1 h1:zqIqSPIndyBh1bjLVVDHMPpVKqp8Su/V+6MeDzzQBQ0=
github.com/cloudflare/circl v1.6.1/go.mod h1:uddAzsPgqdMAYatqJ0lsjX1oECcQLIlRpzZh3pJrofs=
@@ -26,10 +28,12 @@ github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA=
github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
github.com/jsimonetti/rtnetlink v1.4.2 h1:Df9w9TZ3npHTyDn0Ev9e1uzmN2odmXd0QX+J5GTEn90=
github.com/jsimonetti/rtnetlink v1.4.2/go.mod h1:92s6LJdE+1iOrw+F2/RO7LYI2Qd8pPpFNNUYW06gcoM=
github.com/klauspost/compress v1.18.1 h1:bcSGx7UbpBqMChDtsF28Lw6v/G94LPrrbMbdC3JH2co=
github.com/klauspost/compress v1.18.1/go.mod h1:ZQFFVG+MdnR0P+l6wpXgIL4NTtwiKIdBnrBd8Nrxr+0=
github.com/klauspost/pgzip v1.2.6 h1:8RXeL5crjEUFnR2/Sn6GJNWtSQ3Dk8pq4CL3jvdDyjU=
@@ -47,8 +51,8 @@ github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy5
github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o=
github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g=
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U=
github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA=
github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos=
github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ=
github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ=
github.com/miekg/dns v1.1.62/go.mod h1:mvDlcItzm+br7MToIKqkglaGhlFMHJ9DTNNWONWXbNQ=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
@@ -93,10 +97,6 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/ulikunitz/xz v0.5.15 h1:9DNdB5s+SgV3bQ2ApL10xRc35ck0DuIX/isZvIk+ubY=
github.com/ulikunitz/xz v0.5.15/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14=
github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0=
github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a h1:fZHgsYlfvtyqToslyjUt3VOPF4J7aK/3MPcK7xp3PDk=
github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a/go.mod h1:ul22v+Nro/R083muKhosV54bj5niojjWZvU8xrevuH4=
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
@@ -106,15 +106,15 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U
golang.org/x/crypto v0.0.0-20200302210943-78000ba7a073/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU=
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8=
golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A=
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c h1:7dEasQXItcW1xKJ2+gg5VOiBnqWrJc+rq0DPKyvvdbY=
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c/go.mod h1:NQtJDoLvd6faHhE7m4T/1IY708gDefGGjR/iUW8yQQ8=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA=
golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w=
golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI=
golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
@@ -122,14 +122,14 @@ golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96b
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o=
golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -140,12 +140,10 @@ golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
@@ -155,8 +153,8 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=
golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
@@ -164,8 +162,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn
golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+2 -2
View File
@@ -46,7 +46,7 @@ type Wireguard struct {
// investigation in the issue:
// https://github.com/qdm12/gluetun/issues/2533.
// Note this should now be replaced with the PMTUD feature.
MTU uint16 `json:"mtu"`
MTU uint32 `json:"mtu"`
// Implementation is the Wireguard implementation to use.
// It can be "auto", "userspace" or "kernelspace".
// It defaults to "auto" and cannot be the empty string
@@ -273,7 +273,7 @@ func (w *Wireguard) read(r *reader.Reader) (err error) {
return err
}
mtuPtr, err := r.Uint16Ptr("WIREGUARD_MTU")
mtuPtr, err := r.Uint32Ptr("WIREGUARD_MTU")
if err != nil {
return err
} else if mtuPtr != nil {
+59 -15
View File
@@ -1,31 +1,75 @@
package netlink
import (
"github.com/vishvananda/netlink"
"fmt"
"net"
"net/netip"
"github.com/jsimonetti/rtnetlink/rtnl"
)
func (n *NetLink) AddrList(link Link, family int) (
addresses []Addr, err error,
func (n *NetLink) AddrList(linkIndex uint32, family uint8) (
ipPrefixes []netip.Prefix, err error,
) {
netlinkLink := linkToNetlinkLink(&link)
netlinkAddresses, err := netlink.AddrList(netlinkLink, family)
conn, err := rtnl.Dial(nil)
if err != nil {
return nil, err
return nil, fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
ifc := &net.Interface{
Index: int(linkIndex),
}
ipNets, err := conn.Addrs(ifc, int(family))
if err != nil {
return nil, fmt.Errorf("failed to list addresses: %w", err)
}
addresses = make([]Addr, len(netlinkAddresses))
for i := range netlinkAddresses {
addresses[i].Network = netIPNetToNetipPrefix(netlinkAddresses[i].IPNet)
ipPrefixes = make([]netip.Prefix, len(ipNets))
for i := range ipNets {
ipPrefixes[i] = netIPNetToNetipPrefix(ipNets[i])
}
return addresses, nil
return ipPrefixes, nil
}
func (n *NetLink) AddrReplace(link Link, addr Addr) error {
netlinkLink := linkToNetlinkLink(&link)
netlinkAddress := netlink.Addr{
IPNet: netipPrefixToIPNet(addr.Network),
func (n *NetLink) AddrReplace(linkIndex uint32, prefix netip.Prefix) error {
conn, err := rtnl.Dial(nil)
if err != nil {
return fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
ipNet := netipPrefixToIPNet(prefix)
// Remove any address identical to the one we want to add
family := FamilyV4
if prefix.Addr().Is6() {
family = FamilyV6
}
ifc := &net.Interface{
Index: int(linkIndex),
}
addresses, err := conn.Addrs(ifc, int(family))
if err != nil {
return fmt.Errorf("listing addresses: %w", err)
}
for _, address := range addresses {
if address.IP.Equal(ipNet.IP) &&
net.IP(address.Mask).String() == net.IP(ipNet.Mask).String() {
err = conn.AddrDel(ifc, address)
if err != nil {
return fmt.Errorf("deleting address from interface: %w", err)
}
break
}
}
return netlink.AddrReplace(netlinkLink, &netlinkAddress)
// Add the new address to the interface
err = conn.AddrAdd(ifc, ipNet)
if err != nil {
return fmt.Errorf("adding address to interface: %w", err)
}
return nil
}
+24
View File
@@ -36,6 +36,30 @@ func netIPNetToNetipPrefix(ipNet *net.IPNet) (prefix netip.Prefix) {
return netip.PrefixFrom(ip, bits)
}
func ipAndLengthToPrefix(ip *net.IP, length uint8) netip.Prefix {
if ip == nil || len(*ip) == 0 {
return netip.Prefix{}
}
var dstIP netip.Addr
if ipv4 := ip.To4(); ipv4 != nil { // IPv6
dstIP = netip.AddrFrom4([4]byte(*ip))
} else {
dstIP = netip.AddrFrom16([16]byte(*ip))
}
return netip.PrefixFrom(dstIP, int(length))
}
func prefixToIPAndLength(prefix netip.Prefix) (ip *net.IP, length uint8) {
if !prefix.IsValid() {
return nil, 0
}
prefixIP := prefix.Addr().Unmap()
ip = new(net.IP)
*ip = netipAddrToNetIP(prefixIP)
length = uint8(prefix.Bits()) //nolint:gosec
return ip, length
}
func netipAddrToNetIP(address netip.Addr) (ip net.IP) {
switch {
case !address.IsValid():
+1 -1
View File
@@ -4,7 +4,7 @@ import (
"fmt"
)
func FamilyToString(family int) string {
func FamilyToString(family uint8) string {
switch family {
case FamilyAll:
return "all"
+3 -3
View File
@@ -3,7 +3,7 @@ package netlink
import "golang.org/x/sys/unix"
const (
FamilyAll = unix.AF_UNSPEC
FamilyV4 = unix.AF_INET
FamilyV6 = unix.AF_INET6
FamilyAll uint8 = unix.AF_UNSPEC
FamilyV4 uint8 = unix.AF_INET
FamilyV6 uint8 = unix.AF_INET6
)
+14
View File
@@ -1,16 +1,30 @@
package netlink
import (
"math/rand/v2"
"net/netip"
"github.com/qdm12/log"
)
func ptrTo[T any](v T) *T { return &v }
func makeNetipPrefix(n byte) netip.Prefix {
const bits = 24
return netip.PrefixFrom(netip.AddrFrom4([4]byte{n, n, n, 0}), bits)
}
var rng = rand.New(rand.NewChaCha8([32]byte{})) //nolint:gosec,gochecknoglobals
func makeLinkName() string {
const alphabet = "abcdefghijklmnopqrstuvwxyz"
name := make([]byte, 8)
for i := range name {
name[i] = alphabet[rng.IntN(len(alphabet))]
}
return "test" + string(name)
}
type noopLogger struct{}
func (l *noopLogger) Debug(_ string) {}
+1 -1
View File
@@ -19,7 +19,7 @@ func (n *NetLink) IsIPv6Supported() (supported bool, err error) {
return false, fmt.Errorf("finding link corresponding to route: %w", err)
}
sourceIsIPv6 := route.Src.IsValid() && route.Src.Is6()
sourceIsIPv6 := route.Src.Addr().IsValid() && route.Src.Addr().Is6()
destinationIsIPv6 := route.Dst.IsValid() && route.Dst.Addr().Is6()
switch {
case !sourceIsIPv6 && !destinationIsIPv6,
+163 -79
View File
@@ -1,107 +1,191 @@
package netlink
import "github.com/vishvananda/netlink"
import (
"errors"
"fmt"
func (n *NetLink) LinkList() (links []Link, err error) {
netlinkLinks, err := netlink.LinkList()
if err != nil {
return nil, err
"github.com/jsimonetti/rtnetlink"
)
type DeviceType uint16
type Link struct {
Index uint32
Name string
DeviceType DeviceType
VirtualType string
MTU uint32
}
links = make([]Link, len(netlinkLinks))
for i := range netlinkLinks {
links[i] = netlinkLinkToLink(netlinkLinks[i])
func (n *NetLink) LinkList() (links []Link, err error) {
conn, err := rtnetlink.Dial(nil)
if err != nil {
return nil, fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
linkMessages, err := conn.Link.List()
if err != nil {
return nil, fmt.Errorf("listing interfaces: %w", err)
}
links = make([]Link, len(linkMessages))
for i, message := range linkMessages {
virtualType := ""
if message.Attributes.Info != nil {
virtualType = message.Attributes.Info.Kind
}
links[i] = Link{
Index: message.Index,
Name: message.Attributes.Name,
DeviceType: DeviceType(message.Type),
VirtualType: virtualType,
MTU: message.Attributes.MTU,
}
}
return links, nil
}
var ErrLinkNotFound = errors.New("link not found")
func (n *NetLink) LinkByName(name string) (link Link, err error) {
netlinkLink, err := netlink.LinkByName(name)
links, err := n.LinkList()
if err != nil {
return Link{}, err
return Link{}, fmt.Errorf("listing links: %w", err)
}
return netlinkLinkToLink(netlinkLink), nil
for _, link := range links {
if link.Name == name {
return link, nil
}
}
func (n *NetLink) LinkByIndex(index int) (link Link, err error) {
netlinkLink, err := netlink.LinkByIndex(index)
return Link{}, fmt.Errorf("%w: for name %s", ErrLinkNotFound, name)
}
func (n *NetLink) LinkByIndex(index uint32) (link Link, err error) {
links, err := n.LinkList()
if err != nil {
return Link{}, err
return Link{}, fmt.Errorf("listing links: %w", err)
}
return netlinkLinkToLink(netlinkLink), nil
for _, link = range links {
if link.Index == index {
return link, nil
}
}
func (n *NetLink) LinkAdd(link Link) (linkIndex int, err error) {
netlinkLink := linkToNetlinkLink(&link)
err = netlink.LinkAdd(netlinkLink)
return Link{}, fmt.Errorf("%w: for index %d", ErrLinkNotFound, index)
}
func (n *NetLink) LinkAdd(link Link) (linkIndex uint32, err error) {
conn, err := rtnetlink.Dial(nil)
if err != nil {
return 0, err
}
return netlinkLink.Attrs().Index, nil
return 0, fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
func (n *NetLink) LinkDel(link Link) (err error) {
return netlink.LinkDel(linkToNetlinkLink(&link))
}
func (n *NetLink) LinkSetUp(link Link) (linkIndex int, err error) {
netlinkLink := linkToNetlinkLink(&link)
err = netlink.LinkSetUp(netlinkLink)
if err != nil {
return 0, err
}
return netlinkLink.Attrs().Index, nil
}
func (n *NetLink) LinkSetDown(link Link) (err error) {
return netlink.LinkSetDown(linkToNetlinkLink(&link))
}
func (n *NetLink) LinkSetMTU(link Link, mtu uint32) error {
return netlink.LinkSetMTU(linkToNetlinkLink(&link), int(mtu))
}
type netlinkLinkImpl struct {
attrs *netlink.LinkAttrs
linkType string
}
func (n *netlinkLinkImpl) Attrs() *netlink.LinkAttrs {
return n.attrs
}
func (n *netlinkLinkImpl) Type() string {
return n.linkType
}
func netlinkLinkToLink(netlinkLink netlink.Link) Link {
attributes := netlinkLink.Attrs()
return Link{
Type: netlinkLink.Type(),
Name: attributes.Name,
Index: attributes.Index,
EncapType: attributes.EncapType,
MTU: uint16(attributes.MTU), //nolint:gosec
}
}
// Warning: we must return `netlink.Link` and not `netlinkLinkImpl`
// so that the vishvananda/netlink package can compare the returned
// value against an untyped nil.
func linkToNetlinkLink(link *Link) netlink.Link {
if link == nil {
return nil
}
return &netlinkLinkImpl{
linkType: link.Type,
attrs: &netlink.LinkAttrs{
tx := &rtnetlink.LinkMessage{
Type: uint16(link.DeviceType),
Attributes: &rtnetlink.LinkAttributes{
MTU: link.MTU,
Name: link.Name,
Index: link.Index,
EncapType: link.EncapType,
MTU: int(link.MTU),
},
}
if link.VirtualType != "" {
tx.Attributes.Info = &rtnetlink.LinkInfo{
Kind: link.VirtualType,
}
}
err = conn.Link.New(tx)
if err != nil {
return 0, fmt.Errorf("creating new link: %w", err)
}
linkMessages, err := conn.Link.List()
if err != nil {
return 0, fmt.Errorf("listing links: %w", err)
}
for _, linkMessage := range linkMessages {
if linkMessage.Attributes.Name == link.Name {
return linkMessage.Index, nil
}
}
return 0, fmt.Errorf("%w: matching name %s", ErrLinkNotFound, link.Name)
}
func (n *NetLink) LinkDel(linkIndex uint32) (err error) {
conn, err := rtnetlink.Dial(nil)
if err != nil {
return fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
return conn.Link.Delete(linkIndex)
}
func (n *NetLink) LinkSetUp(linkIndex uint32) (err error) {
conn, err := rtnetlink.Dial(nil)
if err != nil {
return fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
rx, err := conn.Link.Get(linkIndex)
if err != nil {
return fmt.Errorf("getting link: %w", err)
}
tx := &rtnetlink.LinkMessage{
Type: rx.Type,
Index: linkIndex,
Flags: iffUp,
Change: iffUp,
}
return conn.Link.Set(tx)
}
func (n *NetLink) LinkSetDown(linkIndex uint32) (err error) {
conn, err := rtnetlink.Dial(nil)
if err != nil {
return fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
linkInfo, err := conn.Link.Get(linkIndex)
if err != nil {
return fmt.Errorf("getting link: %w", err)
}
message := &rtnetlink.LinkMessage{
Type: linkInfo.Type,
Index: linkIndex,
Flags: 0,
Change: iffUp,
}
return conn.Link.Set(message)
}
func (n *NetLink) LinkSetMTU(linkIndex, mtu uint32) error {
conn, err := rtnetlink.Dial(nil)
if err != nil {
return fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
message := &rtnetlink.LinkMessage{
Index: linkIndex,
Attributes: &rtnetlink.LinkAttributes{
MTU: mtu,
},
}
err = conn.Link.Set(message)
if err != nil {
return fmt.Errorf("setting MTU to %d for link at index %d: %w",
mtu, linkIndex, err)
}
return nil
}
+11
View File
@@ -0,0 +1,11 @@
package netlink
import "golang.org/x/sys/unix"
const (
DeviceTypeEthernet DeviceType = unix.ARPHRD_ETHER
DeviceTypeLoopback DeviceType = unix.ARPHRD_LOOPBACK
DeviceTypeNone DeviceType = unix.ARPHRD_NONE
iffUp = unix.IFF_UP
)
+85
View File
@@ -0,0 +1,85 @@
//go:build linux
package netlink
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_NetLink_LinkList(t *testing.T) {
t.Parallel()
netlink := &NetLink{}
initialLinks, err := netlink.LinkList()
require.NoError(t, err)
require.NotEmpty(t, initialLinks)
loopbackFound := false
for _, link := range initialLinks {
if link.Name != "lo" {
continue
}
loopbackFound = true
assert.Equal(t, DeviceTypeLoopback, link.DeviceType)
break
}
assert.True(t, loopbackFound, "loopback interface not found")
testLink := Link{
Name: makeLinkName(),
// note if [Link.VirtualType] is set, [Link.DeviceType]
// is ignored and gets set to [DeviceTypeNone] in LinkAdd.
DeviceType: DeviceTypeNone,
VirtualType: "wireguard",
MTU: 1420,
}
index, err := netlink.LinkAdd(testLink)
require.NoError(t, err)
t.Cleanup(func() {
_ = netlink.LinkDel(index)
})
links, err := netlink.LinkList()
require.NoError(t, err)
testLink.Index = index
for _, link := range links {
if link.Name != testLink.Name {
continue
}
assert.Equal(t, testLink, link)
return
}
t.Errorf("created link %q not found", testLink.Name)
}
func Test_NetLink_LinkSetMTU(t *testing.T) {
t.Parallel()
netlink := &NetLink{}
testLink := Link{
Name: makeLinkName(),
DeviceType: DeviceTypeNone,
VirtualType: "wireguard",
MTU: 1420,
}
index, err := netlink.LinkAdd(testLink)
require.NoError(t, err)
t.Cleanup(func() {
_ = netlink.LinkDel(index)
})
testLink.Index = index
err = netlink.LinkSetMTU(index, 1500)
require.NoError(t, err)
link, err := netlink.LinkByIndex(index)
require.NoError(t, err)
testLink.MTU = 1500
assert.Equal(t, testLink, link)
}
+28 -3
View File
@@ -5,16 +5,41 @@ package netlink
const (
// FamilyAll is a placeholder only and should not
// be used.
FamilyAll = iota
FamilyAll uint8 = iota
// FamilyV4 is a placeholder only and should not
// be used.
FamilyV4
// FamilyV6 is a placeholder only and should not
// be used.
FamilyV6
// DeviceTypeEthernet is a placeholder only and should not be used.
DeviceTypeEthernet DeviceType = 0
// DeviceTypeLoopback is a placeholder only and should not be used.
DeviceTypeLoopback DeviceType = 0
// DeviceTypeNone is a placeholder only and should not be used.
DeviceTypeNone DeviceType = 0
// iffUp is a placeholder only and should not be used.
iffUp = 0
// RouteTypeUnicast is a placeholder only and should not be used.
RouteTypeUnicast = 0
// ScopeUniverse is a placeholder only and should not be used.
ScopeUniverse = 0
// ProtoStatic is a placeholder only and should not be used.
ProtoStatic = 0
// FlagInvert is a placeholder only and should not be used.
FlagInvert = 0
// ActionToTable is a placeholder only and should not be used.
ActionToTable = 0
// rtTableCompat is a placeholder only and should not be used.
rtTableCompat = 0
)
func (n *NetLink) RuleList(family int) (rules []Rule, err error) {
func (n *NetLink) RuleList(family uint8) (rules []Rule, err error) {
panic("not implemented")
}
@@ -26,6 +51,6 @@ func (n *NetLink) RuleDel(rule Rule) error {
panic("not implemented")
}
func (n *NetLink) IsWireguardSupported() bool {
func (n *NetLink) IsWireguardSupported() (bool, error) {
panic("not implemented")
}
+104 -46
View File
@@ -1,67 +1,125 @@
package netlink
import (
"github.com/vishvananda/netlink"
"fmt"
"net/netip"
"github.com/jsimonetti/rtnetlink"
)
func (n *NetLink) RouteList(family int) (routes []Route, err error) {
// We set the filter to netlink.RT_FILTER_TABLE so that
// routes from all tables are listed, as long as the filter
// table is set to 0.
const filterMask = netlink.RT_FILTER_TABLE
// The filter is not left to `nil` otherwise non-main tables
// are ignored.
filter := &netlink.Route{}
netlinkRoutes, err := netlink.RouteListFiltered(family, filter, filterMask)
if err != nil {
return nil, err
type Route struct {
LinkIndex uint32
Dst netip.Prefix
Src netip.Prefix
Gw netip.Addr
Priority uint32
Family uint8
Table uint32
Type uint8
Scope uint8
Proto uint8
}
routes = make([]Route, len(netlinkRoutes))
for i := range netlinkRoutes {
routes[i] = netlinkRouteToRoute(netlinkRoutes[i])
func (r *Route) fromMessage(message rtnetlink.RouteMessage) {
table := uint32(message.Table)
if table == 0 || table == rtTableCompat {
table = message.Attributes.Table
}
r.LinkIndex = message.Attributes.OutIface
r.Dst = ipAndLengthToPrefix(&message.Attributes.Dst, message.DstLength)
r.Src = ipAndLengthToPrefix(&message.Attributes.Src, message.SrcLength)
r.Gw = netIPToNetipAddress(message.Attributes.Gateway)
r.Priority = message.Attributes.Priority
r.Family = message.Family
r.Table = table
r.Type = message.Type
r.Scope = message.Scope
r.Proto = message.Protocol
}
func (r Route) message() *rtnetlink.RouteMessage {
dst, dstLength := prefixToIPAndLength(r.Dst)
src, srcLength := prefixToIPAndLength(r.Src)
var table uint8
var extendedTable uint32
if r.Table <= uint32(^uint8(0)) {
table = uint8(r.Table)
} else {
table = rtTableCompat
extendedTable = r.Table
}
message := &rtnetlink.RouteMessage{
Family: r.Family,
DstLength: dstLength,
SrcLength: srcLength,
Table: table,
Type: r.Type,
Scope: r.Scope,
Protocol: r.Proto,
Attributes: rtnetlink.RouteAttributes{
OutIface: r.LinkIndex,
Dst: *dst, // there should always be a dst for routes
Gateway: netipAddrToNetIP(r.Gw),
Priority: r.Priority,
Table: extendedTable,
},
}
if src != nil { // src is optional
message.Attributes.Src = *src
}
return message
}
func (n *NetLink) RouteList(family uint8) (routes []Route, err error) {
conn, err := rtnetlink.Dial(nil)
if err != nil {
return nil, fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
routeMessages, err := conn.Route.List()
if err != nil {
return nil, fmt.Errorf("listing interfaces: %w", err)
}
routes = make([]Route, 0, len(routeMessages))
for _, routeMessage := range routeMessages {
if family != FamilyAll && routeMessage.Family != family {
continue
}
var route Route
route.fromMessage(routeMessage)
routes = append(routes, route)
}
return routes, nil
}
func (n *NetLink) RouteAdd(route Route) error {
netlinkRoute := routeToNetlinkRoute(route)
return netlink.RouteAdd(&netlinkRoute)
conn, err := rtnetlink.Dial(nil)
if err != nil {
return fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
return conn.Route.Add(route.message())
}
func (n *NetLink) RouteDel(route Route) error {
netlinkRoute := routeToNetlinkRoute(route)
return netlink.RouteDel(&netlinkRoute)
conn, err := rtnetlink.Dial(nil)
if err != nil {
return fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
return conn.Route.Delete(route.message())
}
func (n *NetLink) RouteReplace(route Route) error {
netlinkRoute := routeToNetlinkRoute(route)
return netlink.RouteReplace(&netlinkRoute)
conn, err := rtnetlink.Dial(nil)
if err != nil {
return fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
func netlinkRouteToRoute(netlinkRoute netlink.Route) (route Route) {
return Route{
LinkIndex: netlinkRoute.LinkIndex,
Dst: netIPNetToNetipPrefix(netlinkRoute.Dst),
Src: netIPToNetipAddress(netlinkRoute.Src),
Gw: netIPToNetipAddress(netlinkRoute.Gw),
Priority: netlinkRoute.Priority,
Family: netlinkRoute.Family,
Table: netlinkRoute.Table,
Type: netlinkRoute.Type,
}
}
func routeToNetlinkRoute(route Route) (netlinkRoute netlink.Route) {
return netlink.Route{
LinkIndex: route.LinkIndex,
Dst: netipPrefixToIPNet(route.Dst),
Src: netipAddrToNetIP(route.Src),
Gw: netipAddrToNetIP(route.Gw),
Priority: route.Priority,
Family: route.Family,
Table: route.Table,
Type: route.Type,
}
return conn.Route.Replace(route.message())
}
+11
View File
@@ -0,0 +1,11 @@
package netlink
import "golang.org/x/sys/unix"
const (
RouteTypeUnicast = unix.RTN_UNICAST
ScopeUniverse = unix.RT_SCOPE_UNIVERSE
ProtoStatic = unix.RTPROT_STATIC
rtTableCompat = unix.RT_TABLE_COMPAT
)
+81 -40
View File
@@ -2,54 +2,95 @@ package netlink
import (
"fmt"
"net/netip"
"github.com/vishvananda/netlink"
"github.com/jsimonetti/rtnetlink"
)
func NewRule() Rule {
// defaults found from netlink.NewRule() for fields we use,
// the rest of the defaults is set when converting from a `Rule`
// to a `netlink.Rule`
return Rule{
Priority: -1,
Mark: 0,
}
type Rule struct {
Priority *uint32
Family uint8
Table uint32
Mark *uint32
Src netip.Prefix
Dst netip.Prefix
Flags uint32
Action uint8
}
func ruleToNetlinkRule(rule Rule) (netlinkRule netlink.Rule) {
netlinkRule = *netlink.NewRule()
netlinkRule.Priority = rule.Priority
netlinkRule.Family = rule.Family
netlinkRule.Table = rule.Table
netlinkRule.Mark = rule.Mark
netlinkRule.Src = netipPrefixToIPNet(rule.Src)
netlinkRule.Dst = netipPrefixToIPNet(rule.Dst)
netlinkRule.Invert = rule.Invert
return netlinkRule
func (r *Rule) fromMessage(message rtnetlink.RuleMessage) {
table := uint32(message.Table)
if table == 0 || table == rtTableCompat {
table = *message.Attributes.Table
}
r.Priority = message.Attributes.Priority
r.Family = message.Family
r.Table = table
r.Mark = message.Attributes.FwMark
r.Src = ipAndLengthToPrefix(message.Attributes.Src, message.SrcLength)
r.Dst = ipAndLengthToPrefix(message.Attributes.Dst, message.DstLength)
r.Flags = message.Flags
r.Action = message.Action
}
func netlinkRuleToRule(netlinkRule netlink.Rule) (rule Rule) {
return Rule{
Priority: netlinkRule.Priority,
Family: netlinkRule.Family,
Table: netlinkRule.Table,
Mark: netlinkRule.Mark,
Src: netIPNetToNetipPrefix(netlinkRule.Src),
Dst: netIPNetToNetipPrefix(netlinkRule.Dst),
Invert: netlinkRule.Invert,
}
func (r Rule) message() *rtnetlink.RuleMessage {
src, srcLength := prefixToIPAndLength(r.Src)
dst, dstLength := prefixToIPAndLength(r.Dst)
message := &rtnetlink.RuleMessage{
Family: r.Family,
SrcLength: srcLength,
DstLength: dstLength,
Flags: r.Flags,
Action: r.Action,
Attributes: &rtnetlink.RuleAttributes{
Priority: r.Priority,
FwMark: r.Mark,
Src: src,
Dst: dst,
},
}
func ruleDbgMsg(add bool, rule Rule) (debugMessage string) {
if r.Table <= uint32(^uint8(0)) {
message.Table = uint8(r.Table)
} else {
message.Table = rtTableCompat
message.Attributes.Table = &r.Table
}
return message
}
func (r Rule) String() string {
from := "all"
if r.Src.IsValid() && !r.Src.Addr().IsUnspecified() {
from = r.Src.String()
}
to := "all"
if r.Dst.IsValid() && !r.Dst.Addr().IsUnspecified() {
to = r.Dst.String()
}
priority := ""
if r.Priority != nil {
priority = fmt.Sprintf(" %d", *r.Priority)
}
return fmt.Sprintf("ip rule%s: from %s to %s table %d",
priority, from, to, r.Table)
}
func (r Rule) debugMessage(add bool) (debugMessage string) {
debugMessage = "ip"
switch rule.Family {
switch r.Family {
case FamilyV4:
debugMessage += " -f inet"
case FamilyV6:
debugMessage += " -f inet6"
default:
debugMessage += " -f " + fmt.Sprint(rule.Family)
debugMessage += " -f " + fmt.Sprint(r.Family)
}
debugMessage += " rule"
@@ -60,20 +101,20 @@ func ruleDbgMsg(add bool, rule Rule) (debugMessage string) {
debugMessage += " del"
}
if rule.Src.IsValid() {
debugMessage += " from " + rule.Src.String()
if r.Src.IsValid() {
debugMessage += " from " + r.Src.String()
}
if rule.Dst.IsValid() {
debugMessage += " to " + rule.Dst.String()
if r.Dst.IsValid() {
debugMessage += " to " + r.Dst.String()
}
if rule.Table != 0 {
debugMessage += " lookup " + fmt.Sprint(rule.Table)
if r.Table != 0 {
debugMessage += " lookup " + fmt.Sprint(r.Table)
}
if rule.Priority != -1 {
debugMessage += " pref " + fmt.Sprint(rule.Priority)
if r.Priority != nil {
debugMessage += " pref " + fmt.Sprint(*r.Priority)
}
return debugMessage
+44 -12
View File
@@ -1,8 +1,18 @@
package netlink
import "github.com/vishvananda/netlink"
import (
"fmt"
func (n *NetLink) RuleList(family int) (rules []Rule, err error) {
"github.com/jsimonetti/rtnetlink"
"golang.org/x/sys/unix"
)
const (
FlagInvert = unix.FIB_RULE_INVERT
ActionToTable = unix.FR_ACT_TO_TBL
)
func (n *NetLink) RuleList(family uint8) (rules []Rule, err error) {
switch family {
case FamilyAll:
n.debugLogger.Debug("ip -4 rule list")
@@ -12,26 +22,48 @@ func (n *NetLink) RuleList(family int) (rules []Rule, err error) {
case FamilyV6:
n.debugLogger.Debug("ip -6 rule list")
}
netlinkRules, err := netlink.RuleList(family)
conn, err := rtnetlink.Dial(nil)
if err != nil {
return nil, fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
ruleMessages, err := conn.Rule.List()
if err != nil {
return nil, err
}
rules = make([]Rule, len(netlinkRules))
for i := range netlinkRules {
rules[i] = netlinkRuleToRule(netlinkRules[i])
rules = make([]Rule, 0, len(ruleMessages))
for _, message := range ruleMessages {
if family != FamilyAll && family != message.Family {
continue
}
var rule Rule
rule.fromMessage(message)
rules = append(rules, rule)
}
return rules, nil
}
func (n *NetLink) RuleAdd(rule Rule) error {
n.debugLogger.Debug(ruleDbgMsg(true, rule))
netlinkRule := ruleToNetlinkRule(rule)
return netlink.RuleAdd(&netlinkRule)
n.debugLogger.Debug(rule.debugMessage(true))
conn, err := rtnetlink.Dial(nil)
if err != nil {
return fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
return conn.Rule.Add(rule.message())
}
func (n *NetLink) RuleDel(rule Rule) error {
n.debugLogger.Debug(ruleDbgMsg(false, rule))
netlinkRule := ruleToNetlinkRule(rule)
return netlink.RuleDel(&netlinkRule)
n.debugLogger.Debug(rule.debugMessage(false))
conn, err := rtnetlink.Dial(nil)
if err != nil {
return fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
return conn.Rule.Delete(rule.message())
}
+5 -5
View File
@@ -6,7 +6,7 @@ import (
"github.com/stretchr/testify/assert"
)
func Test_ruleDbgMsg(t *testing.T) {
func Test_Rule_debugMessage(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
@@ -15,7 +15,7 @@ func Test_ruleDbgMsg(t *testing.T) {
dbgMsg string
}{
"default values": {
dbgMsg: "ip -f 0 rule del pref 0",
dbgMsg: "ip -f 0 rule del",
},
"add rule": {
add: true,
@@ -24,7 +24,7 @@ func Test_ruleDbgMsg(t *testing.T) {
Src: makeNetipPrefix(1),
Dst: makeNetipPrefix(2),
Table: 100,
Priority: 101,
Priority: ptrTo(uint32(101)),
},
dbgMsg: "ip -f inet rule add from 1.1.1.0/24 to 2.2.2.0/24 lookup 100 pref 101",
},
@@ -34,7 +34,7 @@ func Test_ruleDbgMsg(t *testing.T) {
Src: makeNetipPrefix(1),
Dst: makeNetipPrefix(2),
Table: 100,
Priority: 101,
Priority: ptrTo(uint32(101)),
},
dbgMsg: "ip -f inet rule del from 1.1.1.0/24 to 2.2.2.0/24 lookup 100 pref 101",
},
@@ -44,7 +44,7 @@ func Test_ruleDbgMsg(t *testing.T) {
t.Run(name, func(t *testing.T) {
t.Parallel()
dbgMsg := ruleDbgMsg(testCase.add, testCase.rule)
dbgMsg := testCase.rule.debugMessage(testCase.add)
assert.Equal(t, testCase.dbgMsg, dbgMsg)
})
-58
View File
@@ -1,58 +0,0 @@
package netlink
import (
"fmt"
"net/netip"
)
type Addr struct {
Network netip.Prefix
}
func (a Addr) String() string {
return a.Network.String()
}
type Link struct {
Type string
Name string
Index int
EncapType string
MTU uint16
}
type Route struct {
LinkIndex int
Dst netip.Prefix
Src netip.Addr
Gw netip.Addr
Priority int
Family int
Table int
Type int
}
type Rule struct {
Priority int
Family int
Table int
Mark uint32
Src netip.Prefix
Dst netip.Prefix
Invert bool
}
func (r Rule) String() string {
from := "all"
if r.Src.IsValid() && !r.Src.Addr().IsUnspecified() {
from = r.Src.String()
}
to := "all"
if r.Dst.IsValid() && !r.Dst.Addr().IsUnspecified() {
to = r.Dst.String()
}
return fmt.Sprintf("ip rule %d: from %s to %s table %d",
r.Priority, from, to, r.Table)
}
+34 -11
View File
@@ -1,35 +1,58 @@
package netlink
import (
"errors"
"fmt"
"os"
"github.com/mdlayher/genetlink"
"github.com/qdm12/gluetun/internal/mod"
"github.com/vishvananda/netlink"
)
func (n *NetLink) IsWireguardSupported() bool {
func (n *NetLink) IsWireguardSupported() (ok bool, err error) {
// Check for Wireguard family without loading the wireguard module.
// Some kernels have the wireguard module built-in, and don't have a
// modules directory, such as WSL2 kernels.
ok := hasWireguardFamily()
if ok {
return true
ok, err = hasWireguardFamily()
if err != nil {
return false, fmt.Errorf("checking wireguard family: %w", err)
} else if ok {
return true, nil
}
// Try loading the wireguard module, since some systems do not load
// it after a boot. If this fails, wireguard is assumed to not be supported.
n.debugLogger.Debugf("wireguard family not found, trying to load wireguard kernel module")
err := mod.Probe("wireguard")
err = mod.Probe("wireguard")
if err != nil {
n.debugLogger.Debugf("failed loading wireguard kernel module: %s", err)
return false
return false, nil
}
n.debugLogger.Debugf("wireguard kernel module loaded successfully")
// Re-check if the Wireguard family is now available, after loading
// the wireguard kernel module.
return hasWireguardFamily()
ok, err = hasWireguardFamily()
if err != nil {
return false, fmt.Errorf("checking wireguard family: %w", err)
}
return ok, nil
}
func hasWireguardFamily() bool {
_, err := netlink.GenlFamilyGet("wireguard")
return err == nil
func hasWireguardFamily() (ok bool, err error) {
conn, err := genetlink.Dial(nil)
if err != nil {
return false, fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
_, err = conn.GetFamily("wireguard")
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return false, nil
}
return false, fmt.Errorf("getting wireguard family: %w", err)
}
return true, nil
}
+4 -1
View File
@@ -4,6 +4,8 @@ package netlink
import (
"testing"
"github.com/stretchr/testify/require"
)
func Test_NetLink_IsWireguardSupported(t *testing.T) {
@@ -12,7 +14,8 @@ func Test_NetLink_IsWireguardSupported(t *testing.T) {
netLink := &NetLink{
debugLogger: &noopLogger{},
}
ok := netLink.IsWireguardSupported()
ok, err := netLink.IsWireguardSupported()
require.NoError(t, err)
if ok { // cannot assert since this depends on kernel
t.Log("wireguard is supported")
} else {
+1 -1
View File
@@ -14,7 +14,7 @@ type Service interface {
type Routing interface {
VPNLocalGatewayIP(vpnInterface string) (gateway netip.Addr, err error)
AssignedIP(interfaceName string, family int) (ip netip.Addr, err error)
AssignedIP(interfaceName string, family uint8) (ip netip.Addr, err error)
}
type PortAllower interface {
+1 -1
View File
@@ -17,7 +17,7 @@ type PortAllower interface {
type Routing interface {
VPNLocalGatewayIP(vpnInterface string) (gateway netip.Addr, err error)
AssignedIP(interfaceName string, family int) (ip netip.Addr, err error)
AssignedIP(interfaceName string, family uint8) (ip netip.Addr, err error)
}
type Logger interface {
+1 -1
View File
@@ -14,7 +14,7 @@ type DefaultRoute struct {
NetInterface string
Gateway netip.Addr
AssignedIP netip.Addr
Family int
Family uint8
}
func (d DefaultRoute) String() string {
+4 -4
View File
@@ -8,8 +8,8 @@ import (
)
const (
inboundTable = 200
inboundPriority = 100
inboundTable uint32 = 200
inboundPriority uint32 = 100
)
func (r *Routing) routeInboundFromDefault(defaultRoutes []DefaultRoute) (err error) {
@@ -60,7 +60,7 @@ func (r *Routing) unrouteInboundFromDefault(defaultRoutes []DefaultRoute) (err e
return nil
}
func (r *Routing) addRuleInboundFromDefault(table int, defaultRoutes []DefaultRoute) (err error) {
func (r *Routing) addRuleInboundFromDefault(table uint32, defaultRoutes []DefaultRoute) (err error) {
for _, defaultRoute := range defaultRoutes {
assignedIP := defaultRoute.AssignedIP
bits := 32
@@ -78,7 +78,7 @@ func (r *Routing) addRuleInboundFromDefault(table int, defaultRoutes []DefaultRo
return nil
}
func (r *Routing) delRuleInboundFromDefault(table int, defaultRoutes []DefaultRoute) (err error) {
func (r *Routing) delRuleInboundFromDefault(table uint32, defaultRoutes []DefaultRoute) (err error) {
for _, defaultRoute := range defaultRoutes {
assignedIP := defaultRoute.AssignedIP
bits := 32
+2 -2
View File
@@ -16,12 +16,12 @@ func ipIsPrivate(ip netip.Addr) bool {
var errInterfaceIPNotFound = errors.New("IP address not found for interface")
func ipMatchesFamily(ip netip.Addr, family int) bool {
func ipMatchesFamily(ip netip.Addr, family uint8) bool {
return (family == netlink.FamilyV4 && ip.Is4()) ||
(family == netlink.FamilyV6 && ip.Is6())
}
func (r *Routing) AssignedIP(interfaceName string, family int) (ip netip.Addr, err error) {
func (r *Routing) AssignedIP(interfaceName string, family uint8) (ip netip.Addr, err error) {
iface, err := net.InterfaceByName(interfaceName)
if err != nil {
return ip, fmt.Errorf("network interface %s not found: %w", interfaceName, err)
+3 -3
View File
@@ -26,10 +26,10 @@ func (r *Routing) LocalNetworks() (localNetworks []LocalNetwork, err error) {
return localNetworks, fmt.Errorf("listing links: %w", err)
}
localLinks := make(map[int]struct{})
localLinks := make(map[uint32]struct{})
for _, link := range links {
if link.EncapType != "ether" {
if link.DeviceType != netlink.DeviceTypeEthernet {
continue
}
@@ -95,7 +95,7 @@ func (r *Routing) AddLocalRules(subnets []LocalNetwork) (err error) {
// Local has higher priority then outbound(99) and inbound(100) as the
// local routes might be necessary to reach the outbound/inbound routes.
const localPriority = 98
const localPriority uint32 = 98
// Main table was setup correctly by Docker, just need to add rules to use it
src := netip.Prefix{}
+14 -14
View File
@@ -5,6 +5,7 @@
package routing
import (
netip "net/netip"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
@@ -35,10 +36,10 @@ func (m *MockNetLinker) EXPECT() *MockNetLinkerMockRecorder {
}
// AddrList mocks base method.
func (m *MockNetLinker) AddrList(arg0 netlink.Link, arg1 int) ([]netlink.Addr, error) {
func (m *MockNetLinker) AddrList(arg0 uint32, arg1 byte) ([]netip.Prefix, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddrList", arg0, arg1)
ret0, _ := ret[0].([]netlink.Addr)
ret0, _ := ret[0].([]netip.Prefix)
ret1, _ := ret[1].(error)
return ret0, ret1
}
@@ -50,7 +51,7 @@ func (mr *MockNetLinkerMockRecorder) AddrList(arg0, arg1 interface{}) *gomock.Ca
}
// AddrReplace mocks base method.
func (m *MockNetLinker) AddrReplace(arg0 netlink.Link, arg1 netlink.Addr) error {
func (m *MockNetLinker) AddrReplace(arg0 uint32, arg1 netip.Prefix) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddrReplace", arg0, arg1)
ret0, _ := ret[0].(error)
@@ -64,10 +65,10 @@ func (mr *MockNetLinkerMockRecorder) AddrReplace(arg0, arg1 interface{}) *gomock
}
// LinkAdd mocks base method.
func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) (int, error) {
func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) (uint32, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkAdd", arg0)
ret0, _ := ret[0].(int)
ret0, _ := ret[0].(uint32)
ret1, _ := ret[1].(error)
return ret0, ret1
}
@@ -79,7 +80,7 @@ func (mr *MockNetLinkerMockRecorder) LinkAdd(arg0 interface{}) *gomock.Call {
}
// LinkByIndex mocks base method.
func (m *MockNetLinker) LinkByIndex(arg0 int) (netlink.Link, error) {
func (m *MockNetLinker) LinkByIndex(arg0 uint32) (netlink.Link, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkByIndex", arg0)
ret0, _ := ret[0].(netlink.Link)
@@ -109,7 +110,7 @@ func (mr *MockNetLinkerMockRecorder) LinkByName(arg0 interface{}) *gomock.Call {
}
// LinkDel mocks base method.
func (m *MockNetLinker) LinkDel(arg0 netlink.Link) error {
func (m *MockNetLinker) LinkDel(arg0 uint32) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkDel", arg0)
ret0, _ := ret[0].(error)
@@ -138,7 +139,7 @@ func (mr *MockNetLinkerMockRecorder) LinkList() *gomock.Call {
}
// LinkSetDown mocks base method.
func (m *MockNetLinker) LinkSetDown(arg0 netlink.Link) error {
func (m *MockNetLinker) LinkSetDown(arg0 uint32) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkSetDown", arg0)
ret0, _ := ret[0].(error)
@@ -152,12 +153,11 @@ func (mr *MockNetLinkerMockRecorder) LinkSetDown(arg0 interface{}) *gomock.Call
}
// LinkSetUp mocks base method.
func (m *MockNetLinker) LinkSetUp(arg0 netlink.Link) (int, error) {
func (m *MockNetLinker) LinkSetUp(arg0 uint32) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkSetUp", arg0)
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
ret0, _ := ret[0].(error)
return ret0
}
// LinkSetUp indicates an expected call of LinkSetUp.
@@ -195,7 +195,7 @@ func (mr *MockNetLinkerMockRecorder) RouteDel(arg0 interface{}) *gomock.Call {
}
// RouteList mocks base method.
func (m *MockNetLinker) RouteList(arg0 int) ([]netlink.Route, error) {
func (m *MockNetLinker) RouteList(arg0 byte) ([]netlink.Route, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RouteList", arg0)
ret0, _ := ret[0].([]netlink.Route)
@@ -252,7 +252,7 @@ func (mr *MockNetLinkerMockRecorder) RuleDel(arg0 interface{}) *gomock.Call {
}
// RuleList mocks base method.
func (m *MockNetLinker) RuleList(arg0 int) ([]netlink.Rule, error) {
func (m *MockNetLinker) RuleList(arg0 byte) ([]netlink.Rule, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RuleList", arg0)
ret0, _ := ret[0].([]netlink.Rule)
+2 -2
View File
@@ -9,8 +9,8 @@ import (
)
const (
outboundTable = 199
outboundPriority = 99
outboundTable uint32 = 199
outboundPriority uint32 = 99
)
func (r *Routing) SetOutboundRoutes(outboundSubnets []netip.Prefix) error {
+17 -4
View File
@@ -9,25 +9,33 @@ import (
)
func (r *Routing) addRouteVia(destination netip.Prefix, gateway netip.Addr,
iface string, table int,
iface string, table uint32,
) error {
destinationStr := destination.String()
r.logger.Info("adding route for " + destinationStr)
r.logger.Debug("ip route replace " + destinationStr +
" via " + gateway.String() +
" dev " + iface +
" table " + strconv.Itoa(table))
" table " + strconv.Itoa(int(table)))
link, err := r.netLinker.LinkByName(iface)
if err != nil {
return fmt.Errorf("finding link for interface %s: %w", iface, err)
}
family := netlink.FamilyV4
if destination.Addr().Is6() {
family = netlink.FamilyV6
}
route := netlink.Route{
Dst: destination,
Gw: gateway,
LinkIndex: link.Index,
Family: family,
Table: table,
Type: netlink.RouteTypeUnicast,
Scope: netlink.ScopeUniverse,
Proto: netlink.ProtoStatic,
}
if err := r.netLinker.RouteReplace(route); err != nil {
return fmt.Errorf("replacing route for subnet %s at interface %s: %w",
@@ -38,24 +46,29 @@ func (r *Routing) addRouteVia(destination netip.Prefix, gateway netip.Addr,
}
func (r *Routing) deleteRouteVia(destination netip.Prefix, gateway netip.Addr,
iface string, table int,
iface string, table uint32,
) (err error) {
destinationStr := destination.String()
r.logger.Info("deleting route for " + destinationStr)
r.logger.Debug("ip route delete " + destinationStr +
" via " + gateway.String() +
" dev " + iface +
" table " + strconv.Itoa(table))
" table " + strconv.Itoa(int(table)))
link, err := r.netLinker.LinkByName(iface)
if err != nil {
return fmt.Errorf("finding link for interface %s: %w", iface, err)
}
family := netlink.FamilyV4
if destination.Addr().Is6() {
family = netlink.FamilyV6
}
route := netlink.Route{
Dst: destination,
Gw: gateway,
LinkIndex: link.Index,
Family: family,
Table: table,
}
if err := r.netLinker.RouteDel(route); err != nil {
+10 -10
View File
@@ -15,20 +15,20 @@ type NetLinker interface {
}
type Addresser interface {
AddrList(link netlink.Link, family int) (
addresses []netlink.Addr, err error)
AddrReplace(link netlink.Link, addr netlink.Addr) error
AddrList(linkIndex uint32, family uint8) (
addresses []netip.Prefix, err error)
AddrReplace(linkIndex uint32, prefix netip.Prefix) error
}
type Router interface {
RouteList(family int) (routes []netlink.Route, err error)
RouteList(family uint8) (routes []netlink.Route, err error)
RouteAdd(route netlink.Route) error
RouteDel(route netlink.Route) error
RouteReplace(route netlink.Route) error
}
type Ruler interface {
RuleList(family int) (rules []netlink.Rule, err error)
RuleList(family uint8) (rules []netlink.Rule, err error)
RuleAdd(rule netlink.Rule) error
RuleDel(rule netlink.Rule) error
}
@@ -36,11 +36,11 @@ type Ruler interface {
type Linker interface {
LinkList() (links []netlink.Link, err error)
LinkByName(name string) (link netlink.Link, err error)
LinkByIndex(index int) (link netlink.Link, err error)
LinkAdd(link netlink.Link) (linkIndex int, err error)
LinkDel(link netlink.Link) (err error)
LinkSetUp(link netlink.Link) (linkIndex int, err error)
LinkSetDown(link netlink.Link) (err error)
LinkByIndex(index uint32) (link netlink.Link, err error)
LinkAdd(link netlink.Link) (linkIndex uint32, err error)
LinkDel(index uint32) (err error)
LinkSetUp(index uint32) (err error)
LinkSetDown(index uint32) (err error)
}
type Routing struct {
+39 -13
View File
@@ -7,12 +7,19 @@ import (
"github.com/qdm12/gluetun/internal/netlink"
)
func (r *Routing) addIPRule(src, dst netip.Prefix, table, priority int) error {
rule := netlink.NewRule()
rule.Src = src
rule.Dst = dst
rule.Priority = priority
rule.Table = table
func (r *Routing) addIPRule(src, dst netip.Prefix, table, priority uint32) error {
family := netlink.FamilyV4
if (src.IsValid() && src.Addr().Is6()) || (dst.IsValid() && dst.Addr().Is6()) {
family = netlink.FamilyV6
}
rule := netlink.Rule{
Priority: &priority,
Family: family,
Table: table,
Src: src,
Dst: dst,
Action: netlink.ActionToTable,
}
existingRules, err := r.netLinker.RuleList(netlink.FamilyAll)
if err != nil {
@@ -31,12 +38,19 @@ func (r *Routing) addIPRule(src, dst netip.Prefix, table, priority int) error {
return nil
}
func (r *Routing) deleteIPRule(src, dst netip.Prefix, table, priority int) error {
rule := netlink.NewRule()
rule.Src = src
rule.Dst = dst
rule.Priority = priority
rule.Table = table
func (r *Routing) deleteIPRule(src, dst netip.Prefix, table uint32, priority uint32) error {
family := netlink.FamilyV4
if (src.IsValid() && src.Addr().Is6()) || (dst.IsValid() && dst.Addr().Is6()) {
family = netlink.FamilyV6
}
rule := netlink.Rule{
Priority: &priority,
Family: family,
Table: table,
Src: src,
Dst: dst,
Action: netlink.ActionToTable,
}
existingRules, err := r.netLinker.RuleList(netlink.FamilyAll)
if err != nil {
@@ -53,10 +67,12 @@ func (r *Routing) deleteIPRule(src, dst netip.Prefix, table, priority int) error
return nil
}
// rulesAreEqual checks whether two rules are equal
// only according to src, dst, priority and table.
func rulesAreEqual(a, b netlink.Rule) bool {
return ipPrefixesAreEqual(a.Src, b.Src) &&
ipPrefixesAreEqual(a.Dst, b.Dst) &&
a.Priority == b.Priority &&
ptrsEqual(a.Priority, b.Priority) &&
a.Table == b.Table
}
@@ -70,3 +86,13 @@ func ipPrefixesAreEqual(a, b netip.Prefix) bool {
return a.Bits() == b.Bits() &&
a.Addr().Compare(b.Addr()) == 0
}
func ptrsEqual(a, b *uint32) bool {
if a == nil && b == nil {
return true
}
if a == nil || b == nil {
return false
}
return *a == *b
}
+30 -22
View File
@@ -17,14 +17,20 @@ func makeNetipPrefix(n byte) netip.Prefix {
}
func makeIPRule(src, dst netip.Prefix,
table, priority int,
table uint32, priority uint32,
) netlink.Rule {
rule := netlink.NewRule()
rule.Src = src
rule.Dst = dst
rule.Table = table
rule.Priority = priority
return rule
family := netlink.FamilyV4
if (src.IsValid() && src.Addr().Is6()) || (dst.IsValid() && dst.Addr().Is6()) {
family = netlink.FamilyV6
}
return netlink.Rule{
Priority: &priority,
Family: family,
Table: table,
Src: src,
Dst: dst,
Action: netlink.ActionToTable,
}
}
func Test_Routing_addIPRule(t *testing.T) {
@@ -46,8 +52,8 @@ func Test_Routing_addIPRule(t *testing.T) {
testCases := map[string]struct {
src netip.Prefix
dst netip.Prefix
table int
priority int
table uint32
priority uint32
ruleList ruleListCall
ruleAdd ruleAddCall
err error
@@ -149,8 +155,8 @@ func Test_Routing_deleteIPRule(t *testing.T) {
testCases := map[string]struct {
src netip.Prefix
dst netip.Prefix
table int
priority int
table uint32
priority uint32
ruleList ruleListCall
ruleDel ruleDelCall
err error
@@ -238,6 +244,8 @@ func Test_Routing_deleteIPRule(t *testing.T) {
}
}
func ptrTo[T any](v T) *T { return &v }
func Test_rulesAreEqual(t *testing.T) {
t.Parallel()
@@ -253,13 +261,13 @@ func Test_rulesAreEqual(t *testing.T) {
a: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{9, 9, 9, 9}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 100,
Priority: ptrTo(uint32(100)),
Table: 101,
},
b: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 100,
Priority: ptrTo(uint32(100)),
Table: 101,
},
},
@@ -267,13 +275,13 @@ func Test_rulesAreEqual(t *testing.T) {
a: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{9, 9, 9, 9}), 32),
Priority: 100,
Priority: ptrTo(uint32(100)),
Table: 101,
},
b: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 100,
Priority: ptrTo(uint32(100)),
Table: 101,
},
},
@@ -281,13 +289,13 @@ func Test_rulesAreEqual(t *testing.T) {
a: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 999,
Priority: ptrTo(uint32(999)),
Table: 101,
},
b: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 100,
Priority: ptrTo(uint32(100)),
Table: 101,
},
},
@@ -295,13 +303,13 @@ func Test_rulesAreEqual(t *testing.T) {
a: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 100,
Table: 999,
Priority: ptrTo(uint32(100)),
Table: 102,
},
b: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 100,
Priority: ptrTo(uint32(100)),
Table: 101,
},
},
@@ -309,13 +317,13 @@ func Test_rulesAreEqual(t *testing.T) {
a: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 100,
Priority: ptrTo(uint32(100)),
Table: 101,
},
b: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 100,
Priority: ptrTo(uint32(100)),
Table: 101,
},
equal: true,
+3 -4
View File
@@ -33,13 +33,12 @@ func (r *Routing) VPNLocalGatewayIP(vpnIntf string) (ip netip.Addr, err error) {
case route.Dst.IsValid() && route.Dst.Addr().IsUnspecified() && route.Gw.IsValid(): // OpenVPN
return route.Gw, nil
case route.Dst.IsSingleIP() &&
route.Dst.Addr().Compare(route.Src) == 0 &&
route.Dst.Addr().Compare(route.Src.Addr()) == 0 &&
route.Table == tableLocal: // Wireguard
route.Src = route.Src.Unmap()
if route.Src.Is6() {
if route.Src.Addr().Is6() {
return netip.Addr{}, fmt.Errorf("%w: %s", ErrVPNLocalGatewayIPv6NotSupported, route.Src)
}
bytes := route.Src.As4()
bytes := route.Src.Addr().As4()
// force last byte to 1 to get the VPN gateway IP
// This is not necessarily bullet proof but it seems to work.
bytes[3] = 1
+8 -8
View File
@@ -57,15 +57,15 @@ type Storage interface {
}
type NetLinker interface {
AddrReplace(link netlink.Link, addr netlink.Addr) error
AddrReplace(linkIndex uint32, addr netip.Prefix) error
Router
Ruler
Linker
IsWireguardSupported() bool
IsWireguardSupported() (ok bool, err error)
}
type Router interface {
RouteList(family int) (routes []netlink.Route, err error)
RouteList(family uint8) (routes []netlink.Route, err error)
RouteAdd(route netlink.Route) error
}
@@ -77,11 +77,11 @@ type Ruler interface {
type Linker interface {
LinkList() (links []netlink.Link, err error)
LinkByName(name string) (link netlink.Link, err error)
LinkAdd(link netlink.Link) (linkIndex int, err error)
LinkDel(link netlink.Link) (err error)
LinkSetUp(link netlink.Link) (linkIndex int, err error)
LinkSetDown(link netlink.Link) (err error)
LinkSetMTU(link netlink.Link, mtu uint32) (err error)
LinkAdd(link netlink.Link) (linkIndex uint32, err error)
LinkDel(linkIndex uint32) error
LinkSetUp(linkIndex uint32) error
LinkSetDown(linkIndex uint32) error
LinkSetMTU(linkIndex, mtu uint32) error
}
type DNSLoop interface {
+3 -3
View File
@@ -172,7 +172,7 @@ func updateToMaxMTU(ctx context.Context, vpnInterface string,
// the new MTU is set again, but this is necessary to find the highest valid MTU.
logger.Debugf("VPN interface %s MTU temporarily set to %d", vpnInterface, vpnLinkMTU)
err = netlinker.LinkSetMTU(link, vpnLinkMTU)
err = netlinker.LinkSetMTU(link.Index, vpnLinkMTU)
if err != nil {
return fmt.Errorf("setting VPN interface %s MTU to %d: %w", vpnInterface, vpnLinkMTU, err)
}
@@ -183,14 +183,14 @@ func updateToMaxMTU(ctx context.Context, vpnInterface string,
case err == nil:
logger.Infof("setting VPN interface %s MTU to maximum valid MTU %d", vpnInterface, vpnLinkMTU)
case errors.Is(err, pmtud.ErrMTUNotFound) || errors.Is(err, pmtud.ErrICMPNotPermitted):
vpnLinkMTU = uint32(originalMTU)
vpnLinkMTU = originalMTU
logger.Infof("reverting VPN interface %s MTU to %d (due to: %s)",
vpnInterface, originalMTU, err)
default:
return fmt.Errorf("path MTU discovering: %w", err)
}
err = netlinker.LinkSetMTU(link, vpnLinkMTU)
err = netlinker.LinkSetMTU(link.Index, vpnLinkMTU)
if err != nil {
return fmt.Errorf("setting VPN interface %s MTU to %d: %w", vpnInterface, vpnLinkMTU, err)
}
+6 -12
View File
@@ -3,26 +3,20 @@ package wireguard
import (
"fmt"
"net/netip"
"github.com/qdm12/gluetun/internal/netlink"
)
func (w *Wireguard) addAddresses(link netlink.Link,
func (w *Wireguard) addAddresses(linkIndex uint32,
addresses []netip.Prefix,
) (err error) {
for _, ipNet := range addresses {
if !*w.settings.IPv6 && ipNet.Addr().Is6() {
for _, address := range addresses {
if !*w.settings.IPv6 && address.Addr().Is6() {
continue
}
address := netlink.Addr{
Network: ipNet,
}
err = w.netlink.AddrReplace(link, address)
err = w.netlink.AddrReplace(linkIndex, address)
if err != nil {
return fmt.Errorf("%w: when adding address %s to link %s",
err, address, link.Name)
return fmt.Errorf("%w: when adding address %s to link with index %d",
err, address, linkIndex)
}
}
+18 -19
View File
@@ -6,7 +6,6 @@ import (
"testing"
"github.com/golang/mock/gomock"
"github.com/qdm12/gluetun/internal/netlink"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -20,21 +19,21 @@ func Test_Wireguard_addAddresses(t *testing.T) {
errDummy := errors.New("dummy")
testCases := map[string]struct {
link netlink.Link
linkIndex uint32
addrs []netip.Prefix
wgBuilder func(ctrl *gomock.Controller, link netlink.Link) *Wireguard
wgBuilder func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard
err error
}{
"success": {
link: netlink.Link{Type: "wireguard"},
linkIndex: 1,
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard {
wgBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard {
netLinker := NewMockNetLinker(ctrl)
firstCall := netLinker.EXPECT().
AddrReplace(link, netlink.Addr{Network: ipNetOne}).
AddrReplace(linkIndex, ipNetOne).
Return(nil)
netLinker.EXPECT().
AddrReplace(link, netlink.Addr{Network: ipNetTwo}).
AddrReplace(linkIndex, ipNetTwo).
Return(nil).After(firstCall)
return &Wireguard{
netlink: netLinker,
@@ -45,12 +44,12 @@ func Test_Wireguard_addAddresses(t *testing.T) {
},
},
"first add error": {
link: netlink.Link{Type: "wireguard", Name: "a_bridge"},
linkIndex: 1,
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard {
wgBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard {
netLinker := NewMockNetLinker(ctrl)
netLinker.EXPECT().
AddrReplace(link, netlink.Addr{Network: ipNetOne}).
AddrReplace(linkIndex, ipNetOne).
Return(errDummy)
return &Wireguard{
netlink: netLinker,
@@ -59,18 +58,18 @@ func Test_Wireguard_addAddresses(t *testing.T) {
},
}
},
err: errors.New("dummy: when adding address 1.2.3.4/32 to link a_bridge"),
err: errors.New("dummy: when adding address 1.2.3.4/32 to link with index 1"),
},
"second add error": {
link: netlink.Link{Type: "wireguard", Name: "a_bridge"},
linkIndex: 1,
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard {
wgBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard {
netLinker := NewMockNetLinker(ctrl)
firstCall := netLinker.EXPECT().
AddrReplace(link, netlink.Addr{Network: ipNetOne}).
AddrReplace(linkIndex, ipNetOne).
Return(nil)
netLinker.EXPECT().
AddrReplace(link, netlink.Addr{Network: ipNetTwo}).
AddrReplace(linkIndex, ipNetTwo).
Return(errDummy).After(firstCall)
return &Wireguard{
netlink: netLinker,
@@ -79,11 +78,11 @@ func Test_Wireguard_addAddresses(t *testing.T) {
},
}
},
err: errors.New("dummy: when adding address ::1234/64 to link a_bridge"),
err: errors.New("dummy: when adding address ::1234/64 to link with index 1"),
},
"ignore IPv6": {
addrs: []netip.Prefix{ipNetTwo},
wgBuilder: func(_ *gomock.Controller, _ netlink.Link) *Wireguard {
wgBuilder: func(_ *gomock.Controller, _ uint32) *Wireguard {
return &Wireguard{
settings: Settings{
IPv6: ptrTo(false),
@@ -98,9 +97,9 @@ func Test_Wireguard_addAddresses(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
wg := testCase.wgBuilder(ctrl, testCase.link)
wg := testCase.wgBuilder(ctrl, testCase.linkIndex)
err := wg.addAddresses(testCase.link, testCase.addrs)
err := wg.addAddresses(testCase.linkIndex, testCase.addrs)
if testCase.err != nil {
require.Error(t, err)
+50
View File
@@ -1,3 +1,53 @@
package wireguard
import (
"math/rand/v2"
"net/netip"
"github.com/qdm12/gluetun/internal/netlink"
)
func ptrTo[T any](x T) *T { return &x }
var rng = rand.New(rand.NewChaCha8([32]byte{})) //nolint:gosec,gochecknoglobals
func makeLinkName() string {
const alphabet = "abcdefghijklmnopqrstuvwxyz"
b := make([]byte, 8)
for i := range b {
b[i] = alphabet[rng.IntN(len(alphabet))]
}
return "test" + string(b)
}
func rulesAreEqual(a, b netlink.Rule) bool {
return ipPrefixesAreEqual(a.Src, b.Src) &&
ipPrefixesAreEqual(a.Dst, b.Dst) &&
ptrsEqual(a.Priority, b.Priority) &&
a.Table == b.Table &&
a.Family == b.Family &&
a.Flags == b.Flags &&
a.Action == b.Action &&
ptrsEqual(a.Mark, b.Mark)
}
func ipPrefixesAreEqual(a, b netip.Prefix) bool {
if !a.IsValid() && !b.IsValid() {
return true
}
if !a.IsValid() || !b.IsValid() {
return false
}
return a.Bits() == b.Bits() &&
a.Addr().Compare(b.Addr()) == 0
}
func ptrsEqual(a, b *uint32) bool {
if a == nil && b == nil {
return true
}
if a == nil || b == nil {
return false
}
return *a == *b
}
+34 -36
View File
@@ -1,4 +1,4 @@
//go:build netlink && linux
//go:build linux
package wireguard
@@ -10,13 +10,16 @@ import (
"github.com/qdm12/log"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sys/unix"
)
type noopDebugLogger struct{}
func (n noopDebugLogger) Debugf(format string, args ...any) {}
func (n noopDebugLogger) Patch(options ...log.Option) {}
func (n noopDebugLogger) Debug(_ string) {}
func (n noopDebugLogger) Debugf(_ string, _ ...any) {}
func (n noopDebugLogger) Info(_ string) {}
func (n noopDebugLogger) Error(_ string) {}
func (n noopDebugLogger) Errorf(_ string, _ ...any) {}
func (n noopDebugLogger) Patch(_ ...log.Option) {}
func Test_netlink_Wireguard_addAddresses(t *testing.T) {
t.Parallel()
@@ -24,15 +27,9 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) {
netlinker := netlink.New(&noopDebugLogger{})
link := netlink.Link{
Type: "bridge",
Name: "test_8081",
}
// Remove any previously created test interface from a crashed/panic
// test or test suite run.
err := netlinker.LinkDel(link)
if err != nil && err.Error() != "invalid argument" {
require.NoError(t, err)
DeviceType: netlink.DeviceTypeNone,
VirtualType: "bridge",
Name: makeLinkName(),
}
linkIndex, err := netlinker.LinkAdd(link)
@@ -40,7 +37,7 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) {
link.Index = linkIndex
defer func() {
err = netlinker.LinkDel(link)
err = netlinker.LinkDel(linkIndex)
assert.NoError(t, err)
}()
@@ -57,17 +54,15 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) {
}
const addIterations = 2 // initial + replace
for i := 0; i < addIterations; i++ {
err = wg.addAddresses(link, addresses)
for range addIterations {
err = wg.addAddresses(link.Index, addresses)
require.NoError(t, err)
netlinkAddresses, err := netlinker.AddrList(link, netlink.FamilyAll)
ipPrefixes, err := netlinker.AddrList(link.Index, netlink.FamilyAll)
require.NoError(t, err)
require.Equal(t, len(addresses), len(netlinkAddresses))
for i, netlinkAddress := range netlinkAddresses {
require.NotNil(t, netlinkAddress.Network)
assert.Equal(t, addresses[i], netlinkAddress.Network)
require.Equal(t, len(addresses), len(ipPrefixes))
for i, ipPrefix := range ipPrefixes {
assert.Equal(t, addresses[i], ipPrefix)
}
}
}
@@ -78,38 +73,41 @@ func Test_netlink_Wireguard_addRule(t *testing.T) {
netlinker := netlink.New(&noopDebugLogger{})
wg := &Wireguard{
netlink: netlinker,
logger: &noopDebugLogger{},
}
rulePriority := 10000
const firewallMark = 999
const family = unix.AF_INET // ipv4
// Unique combination for this test
const rulePriority uint32 = 10000
const firewallMark uint32 = 12345
const family = netlink.FamilyV4
cleanup, err := wg.addRule(rulePriority,
firewallMark, family)
require.NoError(t, err)
defer func() {
t.Cleanup(func() {
err := cleanup()
assert.NoError(t, err)
}()
})
rules, err := netlinker.RuleList(netlink.FamilyV4)
require.NoError(t, err)
expectedRule := netlink.Rule{
Priority: ptrTo(rulePriority),
Family: netlink.FamilyV4,
Table: firewallMark,
Mark: ptrTo(firewallMark),
Flags: netlink.FlagInvert,
Action: netlink.ActionToTable,
}
var rule netlink.Rule
var ruleFound bool
for _, rule = range rules {
if rule.Mark == firewallMark {
if rulesAreEqual(rule, expectedRule) {
ruleFound = true
break
}
}
require.True(t, ruleFound)
expectedRule := netlink.Rule{
Invert: true,
Priority: rulePriority,
Mark: firewallMark,
Table: firewallMark,
}
assert.Equal(t, expectedRule, rule)
// Existing rule cannot be added
nilCleanup, err := wg.addRule(rulePriority,
@@ -118,5 +116,5 @@ func Test_netlink_Wireguard_addRule(t *testing.T) {
_ = nilCleanup() // in case it succeeds
}
require.Error(t, err)
assert.EqualError(t, err, "adding ip rule 10000: from all to all table 999: file exists")
assert.EqualError(t, err, "adding ip rule 10000: from all to all table 12345: netlink receive: file exists")
}
+12 -8
View File
@@ -1,19 +1,23 @@
package wireguard
import "github.com/qdm12/gluetun/internal/netlink"
import (
"net/netip"
"github.com/qdm12/gluetun/internal/netlink"
)
//go:generate mockgen -destination=netlinker_mock_test.go -package wireguard . NetLinker
type NetLinker interface {
AddrReplace(link netlink.Link, addr netlink.Addr) error
AddrReplace(linkIndex uint32, addr netip.Prefix) error
Router
Ruler
Linker
IsWireguardSupported() bool
IsWireguardSupported() (ok bool, err error)
}
type Router interface {
RouteList(family int) (routes []netlink.Route, err error)
RouteList(family uint8) (routes []netlink.Route, err error)
RouteAdd(route netlink.Route) error
}
@@ -23,10 +27,10 @@ type Ruler interface {
}
type Linker interface {
LinkAdd(link netlink.Link) (linkIndex int, err error)
LinkAdd(link netlink.Link) (linkIndex uint32, err error)
LinkList() (links []netlink.Link, err error)
LinkByName(name string) (link netlink.Link, err error)
LinkSetUp(link netlink.Link) (linkIndex int, err error)
LinkSetDown(link netlink.Link) error
LinkDel(link netlink.Link) error
LinkSetUp(linkIndex uint32) error
LinkSetDown(linkIndex uint32) error
LinkDel(linkIndex uint32) error
}
+13 -12
View File
@@ -5,6 +5,7 @@
package wireguard
import (
netip "net/netip"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
@@ -35,7 +36,7 @@ func (m *MockNetLinker) EXPECT() *MockNetLinkerMockRecorder {
}
// AddrReplace mocks base method.
func (m *MockNetLinker) AddrReplace(arg0 netlink.Link, arg1 netlink.Addr) error {
func (m *MockNetLinker) AddrReplace(arg0 uint32, arg1 netip.Prefix) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddrReplace", arg0, arg1)
ret0, _ := ret[0].(error)
@@ -49,11 +50,12 @@ func (mr *MockNetLinkerMockRecorder) AddrReplace(arg0, arg1 interface{}) *gomock
}
// IsWireguardSupported mocks base method.
func (m *MockNetLinker) IsWireguardSupported() bool {
func (m *MockNetLinker) IsWireguardSupported() (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IsWireguardSupported")
ret0, _ := ret[0].(bool)
return ret0
ret1, _ := ret[1].(error)
return ret0, ret1
}
// IsWireguardSupported indicates an expected call of IsWireguardSupported.
@@ -63,10 +65,10 @@ func (mr *MockNetLinkerMockRecorder) IsWireguardSupported() *gomock.Call {
}
// LinkAdd mocks base method.
func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) (int, error) {
func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) (uint32, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkAdd", arg0)
ret0, _ := ret[0].(int)
ret0, _ := ret[0].(uint32)
ret1, _ := ret[1].(error)
return ret0, ret1
}
@@ -93,7 +95,7 @@ func (mr *MockNetLinkerMockRecorder) LinkByName(arg0 interface{}) *gomock.Call {
}
// LinkDel mocks base method.
func (m *MockNetLinker) LinkDel(arg0 netlink.Link) error {
func (m *MockNetLinker) LinkDel(arg0 uint32) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkDel", arg0)
ret0, _ := ret[0].(error)
@@ -122,7 +124,7 @@ func (mr *MockNetLinkerMockRecorder) LinkList() *gomock.Call {
}
// LinkSetDown mocks base method.
func (m *MockNetLinker) LinkSetDown(arg0 netlink.Link) error {
func (m *MockNetLinker) LinkSetDown(arg0 uint32) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkSetDown", arg0)
ret0, _ := ret[0].(error)
@@ -136,12 +138,11 @@ func (mr *MockNetLinkerMockRecorder) LinkSetDown(arg0 interface{}) *gomock.Call
}
// LinkSetUp mocks base method.
func (m *MockNetLinker) LinkSetUp(arg0 netlink.Link) (int, error) {
func (m *MockNetLinker) LinkSetUp(arg0 uint32) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkSetUp", arg0)
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
ret0, _ := ret[0].(error)
return ret0
}
// LinkSetUp indicates an expected call of LinkSetUp.
@@ -165,7 +166,7 @@ func (mr *MockNetLinkerMockRecorder) RouteAdd(arg0 interface{}) *gomock.Call {
}
// RouteList mocks base method.
func (m *MockNetLinker) RouteList(arg0 int) ([]netlink.Route, error) {
func (m *MockNetLinker) RouteList(arg0 byte) ([]netlink.Route, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RouteList", arg0)
ret0, _ := ret[0].([]netlink.Route)
+10 -7
View File
@@ -8,11 +8,11 @@ import (
"github.com/qdm12/gluetun/internal/netlink"
)
func (w *Wireguard) addRoutes(link netlink.Link, destinations []netip.Prefix,
func (w *Wireguard) addRoutes(linkIndex uint32, destinations []netip.Prefix,
firewallMark uint32,
) (err error) {
for _, dst := range destinations {
err = w.addRoute(link, dst, firewallMark)
err = w.addRoute(linkIndex, dst, firewallMark)
if err == nil {
continue
}
@@ -29,7 +29,7 @@ func (w *Wireguard) addRoutes(link netlink.Link, destinations []netip.Prefix,
return nil
}
func (w *Wireguard) addRoute(link netlink.Link, dst netip.Prefix,
func (w *Wireguard) addRoute(linkIndex uint32, dst netip.Prefix,
firewallMark uint32,
) (err error) {
family := netlink.FamilyV4
@@ -37,17 +37,20 @@ func (w *Wireguard) addRoute(link netlink.Link, dst netip.Prefix,
family = netlink.FamilyV6
}
route := netlink.Route{
LinkIndex: link.Index,
LinkIndex: linkIndex,
Dst: dst,
Family: family,
Table: int(firewallMark),
Table: firewallMark,
Type: netlink.RouteTypeUnicast,
Scope: netlink.ScopeUniverse,
Proto: netlink.ProtoStatic,
}
err = w.netlink.RouteAdd(route)
if err != nil {
return fmt.Errorf(
"adding route for link %s, destination %s and table %d: %w",
link.Name, dst, firewallMark, err)
"adding route for link with index %d, destination %s and table %d: %w",
linkIndex, dst, firewallMark, err)
}
return err
+8 -10
View File
@@ -23,38 +23,36 @@ func Test_Wireguard_addRoute(t *testing.T) {
errDummy := errors.New("dummy")
testCases := map[string]struct {
link netlink.Link
dst netip.Prefix
expectedRoute netlink.Route
routeAddErr error
err error
}{
"success": {
link: netlink.Link{
Index: linkIndex,
},
dst: ipPrefix,
expectedRoute: netlink.Route{
LinkIndex: linkIndex,
Dst: ipPrefix,
Family: netlink.FamilyV4,
Table: firewallMark,
Type: netlink.RouteTypeUnicast,
Scope: netlink.ScopeUniverse,
Proto: netlink.ProtoStatic,
},
},
"route add error": {
link: netlink.Link{
Name: "a_bridge",
Index: linkIndex,
},
dst: ipPrefix,
expectedRoute: netlink.Route{
LinkIndex: linkIndex,
Dst: ipPrefix,
Family: netlink.FamilyV4,
Table: firewallMark,
Type: netlink.RouteTypeUnicast,
Scope: netlink.ScopeUniverse,
Proto: netlink.ProtoStatic,
},
routeAddErr: errDummy,
err: errors.New("adding route for link a_bridge, destination 1.2.3.4/32 and table 51820: dummy"), //nolint:lll
err: errors.New("adding route for link with index 88, destination 1.2.3.4/32 and table 51820: dummy"), //nolint:lll
},
}
@@ -72,7 +70,7 @@ func Test_Wireguard_addRoute(t *testing.T) {
RouteAdd(testCase.expectedRoute).
Return(testCase.routeAddErr)
err := wg.addRoute(testCase.link, testCase.dst, firewallMark)
err := wg.addRoute(linkIndex, testCase.dst, firewallMark)
if testCase.err != nil {
require.Error(t, err)
+10 -8
View File
@@ -7,15 +7,17 @@ import (
"github.com/qdm12/gluetun/internal/netlink"
)
func (w *Wireguard) addRule(rulePriority int, firewallMark uint32,
family int,
func (w *Wireguard) addRule(rulePriority, firewallMark uint32,
family uint8,
) (cleanup func() error, err error) {
rule := netlink.NewRule()
rule.Invert = true
rule.Priority = rulePriority
rule.Mark = firewallMark
rule.Table = int(firewallMark)
rule.Family = family
rule := netlink.Rule{
Priority: &rulePriority,
Family: family,
Table: firewallMark,
Mark: &firewallMark,
Flags: netlink.FlagInvert,
Action: netlink.ActionToTable,
}
if err := w.netlink.RuleAdd(rule); err != nil {
if strings.HasSuffix(err.Error(), "file exists") {
w.logger.Info("if you are using Kubernetes, this may fix the error below: " +
+15 -13
View File
@@ -8,15 +8,14 @@ import (
"github.com/qdm12/gluetun/internal/netlink"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sys/unix"
)
func Test_Wireguard_addRule(t *testing.T) {
t.Parallel()
const rulePriority = 987
const firewallMark = 456
const family = unix.AF_INET
const rulePriority uint32 = 987
const firewallMark uint32 = 456
const family = netlink.FamilyV4
errDummy := errors.New("dummy")
@@ -29,31 +28,34 @@ func Test_Wireguard_addRule(t *testing.T) {
}{
"success": {
expectedRule: netlink.Rule{
Invert: true,
Priority: rulePriority,
Mark: firewallMark,
Priority: ptrTo(rulePriority),
Mark: ptrTo(firewallMark),
Table: firewallMark,
Family: family,
Flags: netlink.FlagInvert,
Action: netlink.ActionToTable,
},
},
"rule add error": {
expectedRule: netlink.Rule{
Invert: true,
Priority: rulePriority,
Mark: firewallMark,
Priority: ptrTo(rulePriority),
Mark: ptrTo(firewallMark),
Table: firewallMark,
Family: family,
Flags: netlink.FlagInvert,
Action: netlink.ActionToTable,
},
ruleAddErr: errDummy,
err: errors.New("adding ip rule 987: from all to all table 456: dummy"),
},
"rule delete error": {
expectedRule: netlink.Rule{
Invert: true,
Priority: rulePriority,
Mark: firewallMark,
Priority: ptrTo(rulePriority),
Mark: ptrTo(firewallMark),
Table: firewallMark,
Family: family,
Flags: netlink.FlagInvert,
Action: netlink.ActionToTable,
},
ruleDelErr: errDummy,
cleanupErr: errors.New("deleting rule ip rule 987: from all to all table 456: dummy"),
+38 -35
View File
@@ -14,6 +14,7 @@ import (
)
var (
ErrDetectKernel = errors.New("cannot detect Kernel support")
ErrCreateTun = errors.New("cannot create TUN device")
ErrAddLink = errors.New("cannot add Wireguard link")
ErrFindLink = errors.New("cannot find link")
@@ -32,7 +33,11 @@ var (
// See https://git.zx2c4.com/wireguard-go/tree/main.go
func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<- struct{}) {
kernelSupported := w.netlink.IsWireguardSupported()
kernelSupported, err := w.netlink.IsWireguardSupported()
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrDetectKernel, err)
return
}
setupFunction := setupUserSpace
switch w.settings.Implementation {
@@ -65,14 +70,14 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
defer closers.cleanup(w.logger)
link, waitAndCleanup, err := setupFunction(ctx,
linkIndex, waitAndCleanup, err := setupFunction(ctx,
w.settings.InterfaceName, w.netlink, w.settings.MTU, &closers, w.logger)
if err != nil {
waitError <- err
return
}
err = w.addAddresses(link, w.settings.Addresses)
err = w.addAddresses(linkIndex, w.settings.Addresses)
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrAddAddress, err)
return
@@ -85,17 +90,16 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
return
}
linkIndex, err := w.netlink.LinkSetUp(link)
err = w.netlink.LinkSetUp(linkIndex)
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrIfaceUp, err)
return
}
link.Index = linkIndex
closers.add("shutting down link", stepFour, func() error {
return w.netlink.LinkSetDown(link)
return w.netlink.LinkSetDown(linkIndex)
})
err = w.addRoutes(link, w.settings.AllowedIPs, w.settings.FirewallMark)
err = w.addRoutes(linkIndex, w.settings.AllowedIPs, w.settings.FirewallMark)
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrRouteAdd, err)
return
@@ -131,39 +135,38 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
type waitAndCleanupFunc func() error
func setupKernelSpace(ctx context.Context,
interfaceName string, netLinker NetLinker, mtu uint16,
interfaceName string, netLinker NetLinker, mtu uint32,
closers *closers, logger Logger) (
link netlink.Link, waitAndCleanup waitAndCleanupFunc, err error,
linkIndex uint32, waitAndCleanup waitAndCleanupFunc, err error,
) {
link = netlink.Link{
Type: "wireguard",
Name: interfaceName,
MTU: mtu,
}
links, err := netLinker.LinkList()
if err != nil {
return link, nil, fmt.Errorf("listing links: %w", err)
return 0, nil, fmt.Errorf("listing links: %w", err)
}
// Cleanup any previous Wireguard interface with the same name
// See https://github.com/qdm12/gluetun/issues/1669
for _, link := range links {
if link.Type == "wireguard" && link.Name == interfaceName {
err = netLinker.LinkDel(link)
if link.VirtualType == "wireguard" && link.Name == interfaceName {
err = netLinker.LinkDel(link.Index)
if err != nil {
return link, nil, fmt.Errorf("deleting previous Wireguard link %s: %w",
return 0, nil, fmt.Errorf("deleting previous Wireguard link %s: %w",
interfaceName, err)
}
}
}
linkIndex, err := netLinker.LinkAdd(link)
if err != nil {
return link, nil, fmt.Errorf("%w: %s", ErrAddLink, err)
link := netlink.Link{
VirtualType: "wireguard",
Name: interfaceName,
MTU: mtu,
}
linkIndex, err = netLinker.LinkAdd(link)
if err != nil {
return 0, nil, fmt.Errorf("%w: %s", ErrAddLink, err)
}
link.Index = linkIndex
closers.add("deleting link", stepFive, func() error {
return netLinker.LinkDel(link)
return netLinker.LinkDel(linkIndex)
})
waitAndCleanup = func() error {
@@ -172,35 +175,35 @@ func setupKernelSpace(ctx context.Context,
return ctx.Err()
}
return link, waitAndCleanup, nil
return linkIndex, waitAndCleanup, nil
}
func setupUserSpace(ctx context.Context,
interfaceName string, netLinker NetLinker, mtu uint16,
interfaceName string, netLinker NetLinker, mtu uint32,
closers *closers, logger Logger) (
link netlink.Link, waitAndCleanup waitAndCleanupFunc, err error,
linkIndex uint32, waitAndCleanup waitAndCleanupFunc, err error,
) {
tun, err := tun.CreateTUN(interfaceName, int(mtu))
if err != nil {
return link, nil, fmt.Errorf("%w: %s", ErrCreateTun, err)
return 0, nil, fmt.Errorf("%w: %s", ErrCreateTun, err)
}
closers.add("closing TUN device", stepSeven, tun.Close)
tunName, err := tun.Name()
if err != nil {
return link, nil, fmt.Errorf("%w: cannot get TUN name: %s", ErrCreateTun, err)
return 0, nil, fmt.Errorf("%w: cannot get TUN name: %s", ErrCreateTun, err)
} else if tunName != interfaceName {
return link, nil, fmt.Errorf("%w: names don't match: expected %q and got %q",
return 0, nil, fmt.Errorf("%w: names don't match: expected %q and got %q",
ErrCreateTun, interfaceName, tunName)
}
link, err = netLinker.LinkByName(interfaceName)
link, err := netLinker.LinkByName(interfaceName)
if err != nil {
return link, nil, fmt.Errorf("%w: %s: %s", ErrFindLink, interfaceName, err)
return 0, nil, fmt.Errorf("%w: %s: %s", ErrFindLink, interfaceName, err)
}
closers.add("deleting link", stepFive, func() error {
return netLinker.LinkDel(link)
return netLinker.LinkDel(link.Index)
})
bind := conn.NewDefaultBind()
@@ -217,14 +220,14 @@ func setupUserSpace(ctx context.Context,
uapiFile, err := uapiOpen(interfaceName)
if err != nil {
return link, nil, fmt.Errorf("%w: %s", ErrUAPISocketOpening, err)
return 0, nil, fmt.Errorf("%w: %s", ErrUAPISocketOpening, err)
}
closers.add("closing UAPI file", stepThree, uapiFile.Close)
uapiListener, err := uapiListen(interfaceName, uapiFile)
if err != nil {
return link, nil, fmt.Errorf("%w: %s", ErrUAPIListen, err)
return 0, nil, fmt.Errorf("%w: %s", ErrUAPIListen, err)
}
closers.add("closing UAPI listener", stepTwo, uapiListener.Close)
@@ -249,7 +252,7 @@ func setupUserSpace(ctx context.Context,
return err
}
return link, waitAndCleanup, nil
return link.Index, waitAndCleanup, nil
}
func acceptAndHandle(uapi net.Listener, device *device.Device,
+2 -2
View File
@@ -38,10 +38,10 @@ type Settings struct {
FirewallMark uint32
// Maximum Transmission Unit (MTU) setting for the network interface.
// It defaults to device.DefaultMTU from wireguard-go which is 1420
MTU uint16
MTU uint32
// RulePriority is the priority for the rule created with the
// FirewallMark.
RulePriority int
RulePriority uint32
// IPv6 can bet set to true if IPv6 should be handled.
// It defaults to false if left unset.
IPv6 *bool