chore(all): replace netlink library for more flexibility (#3107)

This commit is contained in:
Quentin McGaw
2026-01-27 10:11:39 +01:00
committed by GitHub
parent e292a4c9be
commit facc6df3be
50 changed files with 1074 additions and 579 deletions
+1 -1
View File
@@ -59,7 +59,7 @@ jobs:
- name: Run tests in test container - name: Run tests in test container
run: | run: |
touch coverage.txt touch coverage.txt
docker run --rm --device /dev/net/tun \ docker run --rm --cap-add=NET_ADMIN --device /dev/net/tun \
-v "$(pwd)/coverage.txt:/tmp/gobuild/coverage.txt" \ -v "$(pwd)/coverage.txt:/tmp/gobuild/coverage.txt" \
test-container test-container
+13 -12
View File
@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"io/fs" "io/fs"
"net/http" "net/http"
"net/netip"
"os" "os"
"os/exec" "os/exec"
"os/signal" "os/signal"
@@ -549,26 +550,26 @@ type netLinker interface {
Router Router
Ruler Ruler
Linker Linker
IsWireguardSupported() bool IsWireguardSupported() (ok bool, err error)
IsIPv6Supported() (ok bool, err error) IsIPv6Supported() (ok bool, err error)
PatchLoggerLevel(level log.Level) PatchLoggerLevel(level log.Level)
} }
type Addresser interface { type Addresser interface {
AddrList(link netlink.Link, family int) ( AddrList(linkIndex uint32, family uint8) (
addresses []netlink.Addr, err error) addresses []netip.Prefix, err error)
AddrReplace(link netlink.Link, addr netlink.Addr) error AddrReplace(linkIndex uint32, addr netip.Prefix) error
} }
type Router interface { type Router interface {
RouteList(family int) (routes []netlink.Route, err error) RouteList(family uint8) (routes []netlink.Route, err error)
RouteAdd(route netlink.Route) error RouteAdd(route netlink.Route) error
RouteDel(route netlink.Route) error RouteDel(route netlink.Route) error
RouteReplace(route netlink.Route) error RouteReplace(route netlink.Route) error
} }
type Ruler interface { type Ruler interface {
RuleList(family int) (rules []netlink.Rule, err error) RuleList(family uint8) (rules []netlink.Rule, err error)
RuleAdd(rule netlink.Rule) error RuleAdd(rule netlink.Rule) error
RuleDel(rule netlink.Rule) error RuleDel(rule netlink.Rule) error
} }
@@ -576,12 +577,12 @@ type Ruler interface {
type Linker interface { type Linker interface {
LinkList() (links []netlink.Link, err error) LinkList() (links []netlink.Link, err error)
LinkByName(name string) (link netlink.Link, err error) LinkByName(name string) (link netlink.Link, err error)
LinkByIndex(index int) (link netlink.Link, err error) LinkByIndex(index uint32) (link netlink.Link, err error)
LinkAdd(link netlink.Link) (linkIndex int, err error) LinkAdd(link netlink.Link) (linkIndex uint32, err error)
LinkDel(link netlink.Link) (err error) LinkDel(linkIndex uint32) (err error)
LinkSetUp(link netlink.Link) (linkIndex int, err error) LinkSetUp(linkIndex uint32) (err error)
LinkSetDown(link netlink.Link) (err error) LinkSetDown(linkIndex uint32) (err error)
LinkSetMTU(link netlink.Link, mtu uint32) error LinkSetMTU(linkIndex, mtu uint32) error
} }
type clier interface { type clier interface {
+11 -12
View File
@@ -7,8 +7,10 @@ require (
github.com/breml/rootcerts v0.3.3 github.com/breml/rootcerts v0.3.3
github.com/fatih/color v1.18.0 github.com/fatih/color v1.18.0
github.com/golang/mock v1.6.0 github.com/golang/mock v1.6.0
github.com/jsimonetti/rtnetlink v1.4.2
github.com/klauspost/compress v1.18.1 github.com/klauspost/compress v1.18.1
github.com/klauspost/pgzip v1.2.6 github.com/klauspost/pgzip v1.2.6
github.com/mdlayher/genetlink v1.3.2
github.com/pelletier/go-toml/v2 v2.2.4 github.com/pelletier/go-toml/v2 v2.2.4
github.com/qdm12/dns/v2 v2.0.0-rc10 github.com/qdm12/dns/v2 v2.0.0-rc10
github.com/qdm12/gosettings v0.4.4 github.com/qdm12/gosettings v0.4.4
@@ -19,12 +21,11 @@ require (
github.com/qdm12/ss-server v0.6.0 github.com/qdm12/ss-server v0.6.0
github.com/stretchr/testify v1.11.1 github.com/stretchr/testify v1.11.1
github.com/ulikunitz/xz v0.5.15 github.com/ulikunitz/xz v0.5.15
github.com/vishvananda/netlink v1.3.1
github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c
golang.org/x/net v0.47.0 golang.org/x/net v0.49.0
golang.org/x/sys v0.38.0 golang.org/x/sys v0.40.0
golang.org/x/text v0.31.0 golang.org/x/text v0.33.0
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
gopkg.in/ini.v1 v1.67.0 gopkg.in/ini.v1 v1.67.0
@@ -38,13 +39,12 @@ require (
github.com/cloudflare/circl v1.6.1 // indirect github.com/cloudflare/circl v1.6.1 // indirect
github.com/cronokirby/saferith v0.33.0 // indirect github.com/cronokirby/saferith v0.33.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
github.com/google/go-cmp v0.6.0 // indirect github.com/google/go-cmp v0.7.0 // indirect
github.com/josharian/native v1.1.0 // indirect github.com/josharian/native v1.1.0 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mdlayher/genetlink v1.3.2 // indirect
github.com/mdlayher/netlink v1.7.2 // indirect github.com/mdlayher/netlink v1.7.2 // indirect
github.com/mdlayher/socket v0.4.1 // indirect github.com/mdlayher/socket v0.5.1 // indirect
github.com/miekg/dns v1.1.62 // indirect github.com/miekg/dns v1.1.62 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/pkg/errors v0.9.1 // indirect github.com/pkg/errors v0.9.1 // indirect
@@ -55,12 +55,11 @@ require (
github.com/prometheus/procfs v0.15.1 // indirect github.com/prometheus/procfs v0.15.1 // indirect
github.com/qdm12/goservices v0.1.1-0.20251104135713-6bee97bd4978 // indirect github.com/qdm12/goservices v0.1.1-0.20251104135713-6bee97bd4978 // indirect
github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect
github.com/vishvananda/netns v0.0.5 // indirect golang.org/x/crypto v0.47.0 // indirect
golang.org/x/crypto v0.45.0 // indirect golang.org/x/mod v0.31.0 // indirect
golang.org/x/mod v0.29.0 // indirect golang.org/x/sync v0.19.0 // indirect
golang.org/x/sync v0.18.0 // indirect
golang.org/x/time v0.3.0 // indirect golang.org/x/time v0.3.0 // indirect
golang.org/x/tools v0.38.0 // indirect golang.org/x/tools v0.40.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
google.golang.org/protobuf v1.35.1 // indirect google.golang.org/protobuf v1.35.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
+22 -24
View File
@@ -13,6 +13,8 @@ github.com/breml/rootcerts v0.3.3/go.mod h1:S/PKh+4d1HUn4HQovEB8hPJZO6pUZYrIhmXB
github.com/bwesterb/go-ristretto v1.2.0/go.mod h1:fUIoIZaG73pV5biE2Blr2xEzDoMj7NFEuV9ekS419A0= github.com/bwesterb/go-ristretto v1.2.0/go.mod h1:fUIoIZaG73pV5biE2Blr2xEzDoMj7NFEuV9ekS419A0=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cilium/ebpf v0.12.3 h1:8ht6F9MquybnY97at+VDZb3eQQr8ev79RueWeVaEcG4=
github.com/cilium/ebpf v0.12.3/go.mod h1:TctK1ivibvI3znr66ljgi4hqOT8EYQjz1KWBfb1UVgM=
github.com/cloudflare/circl v1.1.0/go.mod h1:prBCrKB9DV4poKZY1l9zBXg2QJY7mvgRvtMxxK7fi4I= github.com/cloudflare/circl v1.1.0/go.mod h1:prBCrKB9DV4poKZY1l9zBXg2QJY7mvgRvtMxxK7fi4I=
github.com/cloudflare/circl v1.6.1 h1:zqIqSPIndyBh1bjLVVDHMPpVKqp8Su/V+6MeDzzQBQ0= github.com/cloudflare/circl v1.6.1 h1:zqIqSPIndyBh1bjLVVDHMPpVKqp8Su/V+6MeDzzQBQ0=
github.com/cloudflare/circl v1.6.1/go.mod h1:uddAzsPgqdMAYatqJ0lsjX1oECcQLIlRpzZh3pJrofs= github.com/cloudflare/circl v1.6.1/go.mod h1:uddAzsPgqdMAYatqJ0lsjX1oECcQLIlRpzZh3pJrofs=
@@ -26,10 +28,12 @@ github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4= github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA= github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA= github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA=
github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w= github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
github.com/jsimonetti/rtnetlink v1.4.2 h1:Df9w9TZ3npHTyDn0Ev9e1uzmN2odmXd0QX+J5GTEn90=
github.com/jsimonetti/rtnetlink v1.4.2/go.mod h1:92s6LJdE+1iOrw+F2/RO7LYI2Qd8pPpFNNUYW06gcoM=
github.com/klauspost/compress v1.18.1 h1:bcSGx7UbpBqMChDtsF28Lw6v/G94LPrrbMbdC3JH2co= github.com/klauspost/compress v1.18.1 h1:bcSGx7UbpBqMChDtsF28Lw6v/G94LPrrbMbdC3JH2co=
github.com/klauspost/compress v1.18.1/go.mod h1:ZQFFVG+MdnR0P+l6wpXgIL4NTtwiKIdBnrBd8Nrxr+0= github.com/klauspost/compress v1.18.1/go.mod h1:ZQFFVG+MdnR0P+l6wpXgIL4NTtwiKIdBnrBd8Nrxr+0=
github.com/klauspost/pgzip v1.2.6 h1:8RXeL5crjEUFnR2/Sn6GJNWtSQ3Dk8pq4CL3jvdDyjU= github.com/klauspost/pgzip v1.2.6 h1:8RXeL5crjEUFnR2/Sn6GJNWtSQ3Dk8pq4CL3jvdDyjU=
@@ -47,8 +51,8 @@ github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy5
github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o= github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o=
github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g= github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g=
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw= github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U= github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos=
github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA= github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ=
github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ= github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ=
github.com/miekg/dns v1.1.62/go.mod h1:mvDlcItzm+br7MToIKqkglaGhlFMHJ9DTNNWONWXbNQ= github.com/miekg/dns v1.1.62/go.mod h1:mvDlcItzm+br7MToIKqkglaGhlFMHJ9DTNNWONWXbNQ=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws= github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
@@ -93,10 +97,6 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/ulikunitz/xz v0.5.15 h1:9DNdB5s+SgV3bQ2ApL10xRc35ck0DuIX/isZvIk+ubY= github.com/ulikunitz/xz v0.5.15 h1:9DNdB5s+SgV3bQ2ApL10xRc35ck0DuIX/isZvIk+ubY=
github.com/ulikunitz/xz v0.5.15/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14= github.com/ulikunitz/xz v0.5.15/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14=
github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0=
github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a h1:fZHgsYlfvtyqToslyjUt3VOPF4J7aK/3MPcK7xp3PDk= github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a h1:fZHgsYlfvtyqToslyjUt3VOPF4J7aK/3MPcK7xp3PDk=
github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a/go.mod h1:ul22v+Nro/R083muKhosV54bj5niojjWZvU8xrevuH4= github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a/go.mod h1:ul22v+Nro/R083muKhosV54bj5niojjWZvU8xrevuH4=
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
@@ -106,15 +106,15 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U
golang.org/x/crypto v0.0.0-20200302210943-78000ba7a073/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200302210943-78000ba7a073/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU=
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8=
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A=
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c h1:7dEasQXItcW1xKJ2+gg5VOiBnqWrJc+rq0DPKyvvdbY= golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c h1:7dEasQXItcW1xKJ2+gg5VOiBnqWrJc+rq0DPKyvvdbY=
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c/go.mod h1:NQtJDoLvd6faHhE7m4T/1IY708gDefGGjR/iUW8yQQ8= golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c/go.mod h1:NQtJDoLvd6faHhE7m4T/1IY708gDefGGjR/iUW8yQQ8=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA= golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI=
golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w= golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
@@ -122,14 +122,14 @@ golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96b
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -140,12 +140,10 @@ golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
@@ -155,8 +153,8 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM= golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=
golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
@@ -164,8 +162,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn
golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ= golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs= golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+2 -2
View File
@@ -46,7 +46,7 @@ type Wireguard struct {
// investigation in the issue: // investigation in the issue:
// https://github.com/qdm12/gluetun/issues/2533. // https://github.com/qdm12/gluetun/issues/2533.
// Note this should now be replaced with the PMTUD feature. // Note this should now be replaced with the PMTUD feature.
MTU uint16 `json:"mtu"` MTU uint32 `json:"mtu"`
// Implementation is the Wireguard implementation to use. // Implementation is the Wireguard implementation to use.
// It can be "auto", "userspace" or "kernelspace". // It can be "auto", "userspace" or "kernelspace".
// It defaults to "auto" and cannot be the empty string // It defaults to "auto" and cannot be the empty string
@@ -273,7 +273,7 @@ func (w *Wireguard) read(r *reader.Reader) (err error) {
return err return err
} }
mtuPtr, err := r.Uint16Ptr("WIREGUARD_MTU") mtuPtr, err := r.Uint32Ptr("WIREGUARD_MTU")
if err != nil { if err != nil {
return err return err
} else if mtuPtr != nil { } else if mtuPtr != nil {
+59 -15
View File
@@ -1,31 +1,75 @@
package netlink package netlink
import ( import (
"github.com/vishvananda/netlink" "fmt"
"net"
"net/netip"
"github.com/jsimonetti/rtnetlink/rtnl"
) )
func (n *NetLink) AddrList(link Link, family int) ( func (n *NetLink) AddrList(linkIndex uint32, family uint8) (
addresses []Addr, err error, ipPrefixes []netip.Prefix, err error,
) { ) {
netlinkLink := linkToNetlinkLink(&link) conn, err := rtnl.Dial(nil)
netlinkAddresses, err := netlink.AddrList(netlinkLink, family)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
ifc := &net.Interface{
Index: int(linkIndex),
}
ipNets, err := conn.Addrs(ifc, int(family))
if err != nil {
return nil, fmt.Errorf("failed to list addresses: %w", err)
} }
addresses = make([]Addr, len(netlinkAddresses)) ipPrefixes = make([]netip.Prefix, len(ipNets))
for i := range netlinkAddresses { for i := range ipNets {
addresses[i].Network = netIPNetToNetipPrefix(netlinkAddresses[i].IPNet) ipPrefixes[i] = netIPNetToNetipPrefix(ipNets[i])
} }
return addresses, nil return ipPrefixes, nil
} }
func (n *NetLink) AddrReplace(link Link, addr Addr) error { func (n *NetLink) AddrReplace(linkIndex uint32, prefix netip.Prefix) error {
netlinkLink := linkToNetlinkLink(&link) conn, err := rtnl.Dial(nil)
netlinkAddress := netlink.Addr{ if err != nil {
IPNet: netipPrefixToIPNet(addr.Network), return fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
ipNet := netipPrefixToIPNet(prefix)
// Remove any address identical to the one we want to add
family := FamilyV4
if prefix.Addr().Is6() {
family = FamilyV6
}
ifc := &net.Interface{
Index: int(linkIndex),
}
addresses, err := conn.Addrs(ifc, int(family))
if err != nil {
return fmt.Errorf("listing addresses: %w", err)
}
for _, address := range addresses {
if address.IP.Equal(ipNet.IP) &&
net.IP(address.Mask).String() == net.IP(ipNet.Mask).String() {
err = conn.AddrDel(ifc, address)
if err != nil {
return fmt.Errorf("deleting address from interface: %w", err)
}
break
}
} }
return netlink.AddrReplace(netlinkLink, &netlinkAddress) // Add the new address to the interface
err = conn.AddrAdd(ifc, ipNet)
if err != nil {
return fmt.Errorf("adding address to interface: %w", err)
}
return nil
} }
+24
View File
@@ -36,6 +36,30 @@ func netIPNetToNetipPrefix(ipNet *net.IPNet) (prefix netip.Prefix) {
return netip.PrefixFrom(ip, bits) return netip.PrefixFrom(ip, bits)
} }
func ipAndLengthToPrefix(ip *net.IP, length uint8) netip.Prefix {
if ip == nil || len(*ip) == 0 {
return netip.Prefix{}
}
var dstIP netip.Addr
if ipv4 := ip.To4(); ipv4 != nil { // IPv6
dstIP = netip.AddrFrom4([4]byte(*ip))
} else {
dstIP = netip.AddrFrom16([16]byte(*ip))
}
return netip.PrefixFrom(dstIP, int(length))
}
func prefixToIPAndLength(prefix netip.Prefix) (ip *net.IP, length uint8) {
if !prefix.IsValid() {
return nil, 0
}
prefixIP := prefix.Addr().Unmap()
ip = new(net.IP)
*ip = netipAddrToNetIP(prefixIP)
length = uint8(prefix.Bits()) //nolint:gosec
return ip, length
}
func netipAddrToNetIP(address netip.Addr) (ip net.IP) { func netipAddrToNetIP(address netip.Addr) (ip net.IP) {
switch { switch {
case !address.IsValid(): case !address.IsValid():
+1 -1
View File
@@ -4,7 +4,7 @@ import (
"fmt" "fmt"
) )
func FamilyToString(family int) string { func FamilyToString(family uint8) string {
switch family { switch family {
case FamilyAll: case FamilyAll:
return "all" return "all"
+3 -3
View File
@@ -3,7 +3,7 @@ package netlink
import "golang.org/x/sys/unix" import "golang.org/x/sys/unix"
const ( const (
FamilyAll = unix.AF_UNSPEC FamilyAll uint8 = unix.AF_UNSPEC
FamilyV4 = unix.AF_INET FamilyV4 uint8 = unix.AF_INET
FamilyV6 = unix.AF_INET6 FamilyV6 uint8 = unix.AF_INET6
) )
+14
View File
@@ -1,16 +1,30 @@
package netlink package netlink
import ( import (
"math/rand/v2"
"net/netip" "net/netip"
"github.com/qdm12/log" "github.com/qdm12/log"
) )
func ptrTo[T any](v T) *T { return &v }
func makeNetipPrefix(n byte) netip.Prefix { func makeNetipPrefix(n byte) netip.Prefix {
const bits = 24 const bits = 24
return netip.PrefixFrom(netip.AddrFrom4([4]byte{n, n, n, 0}), bits) return netip.PrefixFrom(netip.AddrFrom4([4]byte{n, n, n, 0}), bits)
} }
var rng = rand.New(rand.NewChaCha8([32]byte{})) //nolint:gosec,gochecknoglobals
func makeLinkName() string {
const alphabet = "abcdefghijklmnopqrstuvwxyz"
name := make([]byte, 8)
for i := range name {
name[i] = alphabet[rng.IntN(len(alphabet))]
}
return "test" + string(name)
}
type noopLogger struct{} type noopLogger struct{}
func (l *noopLogger) Debug(_ string) {} func (l *noopLogger) Debug(_ string) {}
+1 -1
View File
@@ -19,7 +19,7 @@ func (n *NetLink) IsIPv6Supported() (supported bool, err error) {
return false, fmt.Errorf("finding link corresponding to route: %w", err) return false, fmt.Errorf("finding link corresponding to route: %w", err)
} }
sourceIsIPv6 := route.Src.IsValid() && route.Src.Is6() sourceIsIPv6 := route.Src.Addr().IsValid() && route.Src.Addr().Is6()
destinationIsIPv6 := route.Dst.IsValid() && route.Dst.Addr().Is6() destinationIsIPv6 := route.Dst.IsValid() && route.Dst.Addr().Is6()
switch { switch {
case !sourceIsIPv6 && !destinationIsIPv6, case !sourceIsIPv6 && !destinationIsIPv6,
+162 -78
View File
@@ -1,107 +1,191 @@
package netlink package netlink
import "github.com/vishvananda/netlink" import (
"errors"
"fmt"
"github.com/jsimonetti/rtnetlink"
)
type DeviceType uint16
type Link struct {
Index uint32
Name string
DeviceType DeviceType
VirtualType string
MTU uint32
}
func (n *NetLink) LinkList() (links []Link, err error) { func (n *NetLink) LinkList() (links []Link, err error) {
netlinkLinks, err := netlink.LinkList() conn, err := rtnetlink.Dial(nil)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
linkMessages, err := conn.Link.List()
if err != nil {
return nil, fmt.Errorf("listing interfaces: %w", err)
} }
links = make([]Link, len(netlinkLinks)) links = make([]Link, len(linkMessages))
for i := range netlinkLinks { for i, message := range linkMessages {
links[i] = netlinkLinkToLink(netlinkLinks[i]) virtualType := ""
if message.Attributes.Info != nil {
virtualType = message.Attributes.Info.Kind
}
links[i] = Link{
Index: message.Index,
Name: message.Attributes.Name,
DeviceType: DeviceType(message.Type),
VirtualType: virtualType,
MTU: message.Attributes.MTU,
}
} }
return links, nil return links, nil
} }
var ErrLinkNotFound = errors.New("link not found")
func (n *NetLink) LinkByName(name string) (link Link, err error) { func (n *NetLink) LinkByName(name string) (link Link, err error) {
netlinkLink, err := netlink.LinkByName(name) links, err := n.LinkList()
if err != nil { if err != nil {
return Link{}, err return Link{}, fmt.Errorf("listing links: %w", err)
} }
return netlinkLinkToLink(netlinkLink), nil for _, link := range links {
if link.Name == name {
return link, nil
}
}
return Link{}, fmt.Errorf("%w: for name %s", ErrLinkNotFound, name)
} }
func (n *NetLink) LinkByIndex(index int) (link Link, err error) { func (n *NetLink) LinkByIndex(index uint32) (link Link, err error) {
netlinkLink, err := netlink.LinkByIndex(index) links, err := n.LinkList()
if err != nil { if err != nil {
return Link{}, err return Link{}, fmt.Errorf("listing links: %w", err)
} }
return netlinkLinkToLink(netlinkLink), nil for _, link = range links {
if link.Index == index {
return link, nil
}
}
return Link{}, fmt.Errorf("%w: for index %d", ErrLinkNotFound, index)
} }
func (n *NetLink) LinkAdd(link Link) (linkIndex int, err error) { func (n *NetLink) LinkAdd(link Link) (linkIndex uint32, err error) {
netlinkLink := linkToNetlinkLink(&link) conn, err := rtnetlink.Dial(nil)
err = netlink.LinkAdd(netlinkLink)
if err != nil { if err != nil {
return 0, err return 0, fmt.Errorf("dialing netlink: %w", err)
} }
return netlinkLink.Attrs().Index, nil defer conn.Close()
}
func (n *NetLink) LinkDel(link Link) (err error) { tx := &rtnetlink.LinkMessage{
return netlink.LinkDel(linkToNetlinkLink(&link)) Type: uint16(link.DeviceType),
} Attributes: &rtnetlink.LinkAttributes{
MTU: link.MTU,
func (n *NetLink) LinkSetUp(link Link) (linkIndex int, err error) { Name: link.Name,
netlinkLink := linkToNetlinkLink(&link)
err = netlink.LinkSetUp(netlinkLink)
if err != nil {
return 0, err
}
return netlinkLink.Attrs().Index, nil
}
func (n *NetLink) LinkSetDown(link Link) (err error) {
return netlink.LinkSetDown(linkToNetlinkLink(&link))
}
func (n *NetLink) LinkSetMTU(link Link, mtu uint32) error {
return netlink.LinkSetMTU(linkToNetlinkLink(&link), int(mtu))
}
type netlinkLinkImpl struct {
attrs *netlink.LinkAttrs
linkType string
}
func (n *netlinkLinkImpl) Attrs() *netlink.LinkAttrs {
return n.attrs
}
func (n *netlinkLinkImpl) Type() string {
return n.linkType
}
func netlinkLinkToLink(netlinkLink netlink.Link) Link {
attributes := netlinkLink.Attrs()
return Link{
Type: netlinkLink.Type(),
Name: attributes.Name,
Index: attributes.Index,
EncapType: attributes.EncapType,
MTU: uint16(attributes.MTU), //nolint:gosec
}
}
// Warning: we must return `netlink.Link` and not `netlinkLinkImpl`
// so that the vishvananda/netlink package can compare the returned
// value against an untyped nil.
func linkToNetlinkLink(link *Link) netlink.Link {
if link == nil {
return nil
}
return &netlinkLinkImpl{
linkType: link.Type,
attrs: &netlink.LinkAttrs{
Name: link.Name,
Index: link.Index,
EncapType: link.EncapType,
MTU: int(link.MTU),
}, },
} }
if link.VirtualType != "" {
tx.Attributes.Info = &rtnetlink.LinkInfo{
Kind: link.VirtualType,
}
}
err = conn.Link.New(tx)
if err != nil {
return 0, fmt.Errorf("creating new link: %w", err)
}
linkMessages, err := conn.Link.List()
if err != nil {
return 0, fmt.Errorf("listing links: %w", err)
}
for _, linkMessage := range linkMessages {
if linkMessage.Attributes.Name == link.Name {
return linkMessage.Index, nil
}
}
return 0, fmt.Errorf("%w: matching name %s", ErrLinkNotFound, link.Name)
}
func (n *NetLink) LinkDel(linkIndex uint32) (err error) {
conn, err := rtnetlink.Dial(nil)
if err != nil {
return fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
return conn.Link.Delete(linkIndex)
}
func (n *NetLink) LinkSetUp(linkIndex uint32) (err error) {
conn, err := rtnetlink.Dial(nil)
if err != nil {
return fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
rx, err := conn.Link.Get(linkIndex)
if err != nil {
return fmt.Errorf("getting link: %w", err)
}
tx := &rtnetlink.LinkMessage{
Type: rx.Type,
Index: linkIndex,
Flags: iffUp,
Change: iffUp,
}
return conn.Link.Set(tx)
}
func (n *NetLink) LinkSetDown(linkIndex uint32) (err error) {
conn, err := rtnetlink.Dial(nil)
if err != nil {
return fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
linkInfo, err := conn.Link.Get(linkIndex)
if err != nil {
return fmt.Errorf("getting link: %w", err)
}
message := &rtnetlink.LinkMessage{
Type: linkInfo.Type,
Index: linkIndex,
Flags: 0,
Change: iffUp,
}
return conn.Link.Set(message)
}
func (n *NetLink) LinkSetMTU(linkIndex, mtu uint32) error {
conn, err := rtnetlink.Dial(nil)
if err != nil {
return fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
message := &rtnetlink.LinkMessage{
Index: linkIndex,
Attributes: &rtnetlink.LinkAttributes{
MTU: mtu,
},
}
err = conn.Link.Set(message)
if err != nil {
return fmt.Errorf("setting MTU to %d for link at index %d: %w",
mtu, linkIndex, err)
}
return nil
} }
+11
View File
@@ -0,0 +1,11 @@
package netlink
import "golang.org/x/sys/unix"
const (
DeviceTypeEthernet DeviceType = unix.ARPHRD_ETHER
DeviceTypeLoopback DeviceType = unix.ARPHRD_LOOPBACK
DeviceTypeNone DeviceType = unix.ARPHRD_NONE
iffUp = unix.IFF_UP
)
+85
View File
@@ -0,0 +1,85 @@
//go:build linux
package netlink
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_NetLink_LinkList(t *testing.T) {
t.Parallel()
netlink := &NetLink{}
initialLinks, err := netlink.LinkList()
require.NoError(t, err)
require.NotEmpty(t, initialLinks)
loopbackFound := false
for _, link := range initialLinks {
if link.Name != "lo" {
continue
}
loopbackFound = true
assert.Equal(t, DeviceTypeLoopback, link.DeviceType)
break
}
assert.True(t, loopbackFound, "loopback interface not found")
testLink := Link{
Name: makeLinkName(),
// note if [Link.VirtualType] is set, [Link.DeviceType]
// is ignored and gets set to [DeviceTypeNone] in LinkAdd.
DeviceType: DeviceTypeNone,
VirtualType: "wireguard",
MTU: 1420,
}
index, err := netlink.LinkAdd(testLink)
require.NoError(t, err)
t.Cleanup(func() {
_ = netlink.LinkDel(index)
})
links, err := netlink.LinkList()
require.NoError(t, err)
testLink.Index = index
for _, link := range links {
if link.Name != testLink.Name {
continue
}
assert.Equal(t, testLink, link)
return
}
t.Errorf("created link %q not found", testLink.Name)
}
func Test_NetLink_LinkSetMTU(t *testing.T) {
t.Parallel()
netlink := &NetLink{}
testLink := Link{
Name: makeLinkName(),
DeviceType: DeviceTypeNone,
VirtualType: "wireguard",
MTU: 1420,
}
index, err := netlink.LinkAdd(testLink)
require.NoError(t, err)
t.Cleanup(func() {
_ = netlink.LinkDel(index)
})
testLink.Index = index
err = netlink.LinkSetMTU(index, 1500)
require.NoError(t, err)
link, err := netlink.LinkByIndex(index)
require.NoError(t, err)
testLink.MTU = 1500
assert.Equal(t, testLink, link)
}
+28 -3
View File
@@ -5,16 +5,41 @@ package netlink
const ( const (
// FamilyAll is a placeholder only and should not // FamilyAll is a placeholder only and should not
// be used. // be used.
FamilyAll = iota FamilyAll uint8 = iota
// FamilyV4 is a placeholder only and should not // FamilyV4 is a placeholder only and should not
// be used. // be used.
FamilyV4 FamilyV4
// FamilyV6 is a placeholder only and should not // FamilyV6 is a placeholder only and should not
// be used. // be used.
FamilyV6 FamilyV6
// DeviceTypeEthernet is a placeholder only and should not be used.
DeviceTypeEthernet DeviceType = 0
// DeviceTypeLoopback is a placeholder only and should not be used.
DeviceTypeLoopback DeviceType = 0
// DeviceTypeNone is a placeholder only and should not be used.
DeviceTypeNone DeviceType = 0
// iffUp is a placeholder only and should not be used.
iffUp = 0
// RouteTypeUnicast is a placeholder only and should not be used.
RouteTypeUnicast = 0
// ScopeUniverse is a placeholder only and should not be used.
ScopeUniverse = 0
// ProtoStatic is a placeholder only and should not be used.
ProtoStatic = 0
// FlagInvert is a placeholder only and should not be used.
FlagInvert = 0
// ActionToTable is a placeholder only and should not be used.
ActionToTable = 0
// rtTableCompat is a placeholder only and should not be used.
rtTableCompat = 0
) )
func (n *NetLink) RuleList(family int) (rules []Rule, err error) { func (n *NetLink) RuleList(family uint8) (rules []Rule, err error) {
panic("not implemented") panic("not implemented")
} }
@@ -26,6 +51,6 @@ func (n *NetLink) RuleDel(rule Rule) error {
panic("not implemented") panic("not implemented")
} }
func (n *NetLink) IsWireguardSupported() bool { func (n *NetLink) IsWireguardSupported() (bool, error) {
panic("not implemented") panic("not implemented")
} }
+102 -44
View File
@@ -1,67 +1,125 @@
package netlink package netlink
import ( import (
"github.com/vishvananda/netlink" "fmt"
"net/netip"
"github.com/jsimonetti/rtnetlink"
) )
func (n *NetLink) RouteList(family int) (routes []Route, err error) { type Route struct {
// We set the filter to netlink.RT_FILTER_TABLE so that LinkIndex uint32
// routes from all tables are listed, as long as the filter Dst netip.Prefix
// table is set to 0. Src netip.Prefix
const filterMask = netlink.RT_FILTER_TABLE Gw netip.Addr
// The filter is not left to `nil` otherwise non-main tables Priority uint32
// are ignored. Family uint8
filter := &netlink.Route{} Table uint32
Type uint8
Scope uint8
Proto uint8
}
netlinkRoutes, err := netlink.RouteListFiltered(family, filter, filterMask) func (r *Route) fromMessage(message rtnetlink.RouteMessage) {
table := uint32(message.Table)
if table == 0 || table == rtTableCompat {
table = message.Attributes.Table
}
r.LinkIndex = message.Attributes.OutIface
r.Dst = ipAndLengthToPrefix(&message.Attributes.Dst, message.DstLength)
r.Src = ipAndLengthToPrefix(&message.Attributes.Src, message.SrcLength)
r.Gw = netIPToNetipAddress(message.Attributes.Gateway)
r.Priority = message.Attributes.Priority
r.Family = message.Family
r.Table = table
r.Type = message.Type
r.Scope = message.Scope
r.Proto = message.Protocol
}
func (r Route) message() *rtnetlink.RouteMessage {
dst, dstLength := prefixToIPAndLength(r.Dst)
src, srcLength := prefixToIPAndLength(r.Src)
var table uint8
var extendedTable uint32
if r.Table <= uint32(^uint8(0)) {
table = uint8(r.Table)
} else {
table = rtTableCompat
extendedTable = r.Table
}
message := &rtnetlink.RouteMessage{
Family: r.Family,
DstLength: dstLength,
SrcLength: srcLength,
Table: table,
Type: r.Type,
Scope: r.Scope,
Protocol: r.Proto,
Attributes: rtnetlink.RouteAttributes{
OutIface: r.LinkIndex,
Dst: *dst, // there should always be a dst for routes
Gateway: netipAddrToNetIP(r.Gw),
Priority: r.Priority,
Table: extendedTable,
},
}
if src != nil { // src is optional
message.Attributes.Src = *src
}
return message
}
func (n *NetLink) RouteList(family uint8) (routes []Route, err error) {
conn, err := rtnetlink.Dial(nil)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
routeMessages, err := conn.Route.List()
if err != nil {
return nil, fmt.Errorf("listing interfaces: %w", err)
} }
routes = make([]Route, len(netlinkRoutes)) routes = make([]Route, 0, len(routeMessages))
for i := range netlinkRoutes { for _, routeMessage := range routeMessages {
routes[i] = netlinkRouteToRoute(netlinkRoutes[i]) if family != FamilyAll && routeMessage.Family != family {
continue
}
var route Route
route.fromMessage(routeMessage)
routes = append(routes, route)
} }
return routes, nil return routes, nil
} }
func (n *NetLink) RouteAdd(route Route) error { func (n *NetLink) RouteAdd(route Route) error {
netlinkRoute := routeToNetlinkRoute(route) conn, err := rtnetlink.Dial(nil)
return netlink.RouteAdd(&netlinkRoute) if err != nil {
return fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
return conn.Route.Add(route.message())
} }
func (n *NetLink) RouteDel(route Route) error { func (n *NetLink) RouteDel(route Route) error {
netlinkRoute := routeToNetlinkRoute(route) conn, err := rtnetlink.Dial(nil)
return netlink.RouteDel(&netlinkRoute) if err != nil {
return fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
return conn.Route.Delete(route.message())
} }
func (n *NetLink) RouteReplace(route Route) error { func (n *NetLink) RouteReplace(route Route) error {
netlinkRoute := routeToNetlinkRoute(route) conn, err := rtnetlink.Dial(nil)
return netlink.RouteReplace(&netlinkRoute) if err != nil {
} return fmt.Errorf("dialing netlink: %w", err)
func netlinkRouteToRoute(netlinkRoute netlink.Route) (route Route) {
return Route{
LinkIndex: netlinkRoute.LinkIndex,
Dst: netIPNetToNetipPrefix(netlinkRoute.Dst),
Src: netIPToNetipAddress(netlinkRoute.Src),
Gw: netIPToNetipAddress(netlinkRoute.Gw),
Priority: netlinkRoute.Priority,
Family: netlinkRoute.Family,
Table: netlinkRoute.Table,
Type: netlinkRoute.Type,
} }
} defer conn.Close()
func routeToNetlinkRoute(route Route) (netlinkRoute netlink.Route) { return conn.Route.Replace(route.message())
return netlink.Route{
LinkIndex: route.LinkIndex,
Dst: netipPrefixToIPNet(route.Dst),
Src: netipAddrToNetIP(route.Src),
Gw: netipAddrToNetIP(route.Gw),
Priority: route.Priority,
Family: route.Family,
Table: route.Table,
Type: route.Type,
}
} }
+11
View File
@@ -0,0 +1,11 @@
package netlink
import "golang.org/x/sys/unix"
const (
RouteTypeUnicast = unix.RTN_UNICAST
ScopeUniverse = unix.RT_SCOPE_UNIVERSE
ProtoStatic = unix.RTPROT_STATIC
rtTableCompat = unix.RT_TABLE_COMPAT
)
+80 -39
View File
@@ -2,54 +2,95 @@ package netlink
import ( import (
"fmt" "fmt"
"net/netip"
"github.com/vishvananda/netlink" "github.com/jsimonetti/rtnetlink"
) )
func NewRule() Rule { type Rule struct {
// defaults found from netlink.NewRule() for fields we use, Priority *uint32
// the rest of the defaults is set when converting from a `Rule` Family uint8
// to a `netlink.Rule` Table uint32
return Rule{ Mark *uint32
Priority: -1, Src netip.Prefix
Mark: 0, Dst netip.Prefix
Flags uint32
Action uint8
}
func (r *Rule) fromMessage(message rtnetlink.RuleMessage) {
table := uint32(message.Table)
if table == 0 || table == rtTableCompat {
table = *message.Attributes.Table
} }
r.Priority = message.Attributes.Priority
r.Family = message.Family
r.Table = table
r.Mark = message.Attributes.FwMark
r.Src = ipAndLengthToPrefix(message.Attributes.Src, message.SrcLength)
r.Dst = ipAndLengthToPrefix(message.Attributes.Dst, message.DstLength)
r.Flags = message.Flags
r.Action = message.Action
} }
func ruleToNetlinkRule(rule Rule) (netlinkRule netlink.Rule) { func (r Rule) message() *rtnetlink.RuleMessage {
netlinkRule = *netlink.NewRule() src, srcLength := prefixToIPAndLength(r.Src)
netlinkRule.Priority = rule.Priority dst, dstLength := prefixToIPAndLength(r.Dst)
netlinkRule.Family = rule.Family
netlinkRule.Table = rule.Table
netlinkRule.Mark = rule.Mark
netlinkRule.Src = netipPrefixToIPNet(rule.Src)
netlinkRule.Dst = netipPrefixToIPNet(rule.Dst)
netlinkRule.Invert = rule.Invert
return netlinkRule
}
func netlinkRuleToRule(netlinkRule netlink.Rule) (rule Rule) { message := &rtnetlink.RuleMessage{
return Rule{ Family: r.Family,
Priority: netlinkRule.Priority, SrcLength: srcLength,
Family: netlinkRule.Family, DstLength: dstLength,
Table: netlinkRule.Table, Flags: r.Flags,
Mark: netlinkRule.Mark, Action: r.Action,
Src: netIPNetToNetipPrefix(netlinkRule.Src), Attributes: &rtnetlink.RuleAttributes{
Dst: netIPNetToNetipPrefix(netlinkRule.Dst), Priority: r.Priority,
Invert: netlinkRule.Invert, FwMark: r.Mark,
Src: src,
Dst: dst,
},
} }
if r.Table <= uint32(^uint8(0)) {
message.Table = uint8(r.Table)
} else {
message.Table = rtTableCompat
message.Attributes.Table = &r.Table
}
return message
} }
func ruleDbgMsg(add bool, rule Rule) (debugMessage string) { func (r Rule) String() string {
from := "all"
if r.Src.IsValid() && !r.Src.Addr().IsUnspecified() {
from = r.Src.String()
}
to := "all"
if r.Dst.IsValid() && !r.Dst.Addr().IsUnspecified() {
to = r.Dst.String()
}
priority := ""
if r.Priority != nil {
priority = fmt.Sprintf(" %d", *r.Priority)
}
return fmt.Sprintf("ip rule%s: from %s to %s table %d",
priority, from, to, r.Table)
}
func (r Rule) debugMessage(add bool) (debugMessage string) {
debugMessage = "ip" debugMessage = "ip"
switch rule.Family { switch r.Family {
case FamilyV4: case FamilyV4:
debugMessage += " -f inet" debugMessage += " -f inet"
case FamilyV6: case FamilyV6:
debugMessage += " -f inet6" debugMessage += " -f inet6"
default: default:
debugMessage += " -f " + fmt.Sprint(rule.Family) debugMessage += " -f " + fmt.Sprint(r.Family)
} }
debugMessage += " rule" debugMessage += " rule"
@@ -60,20 +101,20 @@ func ruleDbgMsg(add bool, rule Rule) (debugMessage string) {
debugMessage += " del" debugMessage += " del"
} }
if rule.Src.IsValid() { if r.Src.IsValid() {
debugMessage += " from " + rule.Src.String() debugMessage += " from " + r.Src.String()
} }
if rule.Dst.IsValid() { if r.Dst.IsValid() {
debugMessage += " to " + rule.Dst.String() debugMessage += " to " + r.Dst.String()
} }
if rule.Table != 0 { if r.Table != 0 {
debugMessage += " lookup " + fmt.Sprint(rule.Table) debugMessage += " lookup " + fmt.Sprint(r.Table)
} }
if rule.Priority != -1 { if r.Priority != nil {
debugMessage += " pref " + fmt.Sprint(rule.Priority) debugMessage += " pref " + fmt.Sprint(*r.Priority)
} }
return debugMessage return debugMessage
+44 -12
View File
@@ -1,8 +1,18 @@
package netlink package netlink
import "github.com/vishvananda/netlink" import (
"fmt"
func (n *NetLink) RuleList(family int) (rules []Rule, err error) { "github.com/jsimonetti/rtnetlink"
"golang.org/x/sys/unix"
)
const (
FlagInvert = unix.FIB_RULE_INVERT
ActionToTable = unix.FR_ACT_TO_TBL
)
func (n *NetLink) RuleList(family uint8) (rules []Rule, err error) {
switch family { switch family {
case FamilyAll: case FamilyAll:
n.debugLogger.Debug("ip -4 rule list") n.debugLogger.Debug("ip -4 rule list")
@@ -12,26 +22,48 @@ func (n *NetLink) RuleList(family int) (rules []Rule, err error) {
case FamilyV6: case FamilyV6:
n.debugLogger.Debug("ip -6 rule list") n.debugLogger.Debug("ip -6 rule list")
} }
netlinkRules, err := netlink.RuleList(family)
conn, err := rtnetlink.Dial(nil)
if err != nil {
return nil, fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
ruleMessages, err := conn.Rule.List()
if err != nil { if err != nil {
return nil, err return nil, err
} }
rules = make([]Rule, len(netlinkRules)) rules = make([]Rule, 0, len(ruleMessages))
for i := range netlinkRules { for _, message := range ruleMessages {
rules[i] = netlinkRuleToRule(netlinkRules[i]) if family != FamilyAll && family != message.Family {
continue
}
var rule Rule
rule.fromMessage(message)
rules = append(rules, rule)
} }
return rules, nil return rules, nil
} }
func (n *NetLink) RuleAdd(rule Rule) error { func (n *NetLink) RuleAdd(rule Rule) error {
n.debugLogger.Debug(ruleDbgMsg(true, rule)) n.debugLogger.Debug(rule.debugMessage(true))
netlinkRule := ruleToNetlinkRule(rule)
return netlink.RuleAdd(&netlinkRule) conn, err := rtnetlink.Dial(nil)
if err != nil {
return fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
return conn.Rule.Add(rule.message())
} }
func (n *NetLink) RuleDel(rule Rule) error { func (n *NetLink) RuleDel(rule Rule) error {
n.debugLogger.Debug(ruleDbgMsg(false, rule)) n.debugLogger.Debug(rule.debugMessage(false))
netlinkRule := ruleToNetlinkRule(rule)
return netlink.RuleDel(&netlinkRule) conn, err := rtnetlink.Dial(nil)
if err != nil {
return fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
return conn.Rule.Delete(rule.message())
} }
+5 -5
View File
@@ -6,7 +6,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func Test_ruleDbgMsg(t *testing.T) { func Test_Rule_debugMessage(t *testing.T) {
t.Parallel() t.Parallel()
testCases := map[string]struct { testCases := map[string]struct {
@@ -15,7 +15,7 @@ func Test_ruleDbgMsg(t *testing.T) {
dbgMsg string dbgMsg string
}{ }{
"default values": { "default values": {
dbgMsg: "ip -f 0 rule del pref 0", dbgMsg: "ip -f 0 rule del",
}, },
"add rule": { "add rule": {
add: true, add: true,
@@ -24,7 +24,7 @@ func Test_ruleDbgMsg(t *testing.T) {
Src: makeNetipPrefix(1), Src: makeNetipPrefix(1),
Dst: makeNetipPrefix(2), Dst: makeNetipPrefix(2),
Table: 100, Table: 100,
Priority: 101, Priority: ptrTo(uint32(101)),
}, },
dbgMsg: "ip -f inet rule add from 1.1.1.0/24 to 2.2.2.0/24 lookup 100 pref 101", dbgMsg: "ip -f inet rule add from 1.1.1.0/24 to 2.2.2.0/24 lookup 100 pref 101",
}, },
@@ -34,7 +34,7 @@ func Test_ruleDbgMsg(t *testing.T) {
Src: makeNetipPrefix(1), Src: makeNetipPrefix(1),
Dst: makeNetipPrefix(2), Dst: makeNetipPrefix(2),
Table: 100, Table: 100,
Priority: 101, Priority: ptrTo(uint32(101)),
}, },
dbgMsg: "ip -f inet rule del from 1.1.1.0/24 to 2.2.2.0/24 lookup 100 pref 101", dbgMsg: "ip -f inet rule del from 1.1.1.0/24 to 2.2.2.0/24 lookup 100 pref 101",
}, },
@@ -44,7 +44,7 @@ func Test_ruleDbgMsg(t *testing.T) {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
t.Parallel() t.Parallel()
dbgMsg := ruleDbgMsg(testCase.add, testCase.rule) dbgMsg := testCase.rule.debugMessage(testCase.add)
assert.Equal(t, testCase.dbgMsg, dbgMsg) assert.Equal(t, testCase.dbgMsg, dbgMsg)
}) })
-58
View File
@@ -1,58 +0,0 @@
package netlink
import (
"fmt"
"net/netip"
)
type Addr struct {
Network netip.Prefix
}
func (a Addr) String() string {
return a.Network.String()
}
type Link struct {
Type string
Name string
Index int
EncapType string
MTU uint16
}
type Route struct {
LinkIndex int
Dst netip.Prefix
Src netip.Addr
Gw netip.Addr
Priority int
Family int
Table int
Type int
}
type Rule struct {
Priority int
Family int
Table int
Mark uint32
Src netip.Prefix
Dst netip.Prefix
Invert bool
}
func (r Rule) String() string {
from := "all"
if r.Src.IsValid() && !r.Src.Addr().IsUnspecified() {
from = r.Src.String()
}
to := "all"
if r.Dst.IsValid() && !r.Dst.Addr().IsUnspecified() {
to = r.Dst.String()
}
return fmt.Sprintf("ip rule %d: from %s to %s table %d",
r.Priority, from, to, r.Table)
}
+34 -11
View File
@@ -1,35 +1,58 @@
package netlink package netlink
import ( import (
"errors"
"fmt"
"os"
"github.com/mdlayher/genetlink"
"github.com/qdm12/gluetun/internal/mod" "github.com/qdm12/gluetun/internal/mod"
"github.com/vishvananda/netlink"
) )
func (n *NetLink) IsWireguardSupported() bool { func (n *NetLink) IsWireguardSupported() (ok bool, err error) {
// Check for Wireguard family without loading the wireguard module. // Check for Wireguard family without loading the wireguard module.
// Some kernels have the wireguard module built-in, and don't have a // Some kernels have the wireguard module built-in, and don't have a
// modules directory, such as WSL2 kernels. // modules directory, such as WSL2 kernels.
ok := hasWireguardFamily() ok, err = hasWireguardFamily()
if ok { if err != nil {
return true return false, fmt.Errorf("checking wireguard family: %w", err)
} else if ok {
return true, nil
} }
// Try loading the wireguard module, since some systems do not load // Try loading the wireguard module, since some systems do not load
// it after a boot. If this fails, wireguard is assumed to not be supported. // it after a boot. If this fails, wireguard is assumed to not be supported.
n.debugLogger.Debugf("wireguard family not found, trying to load wireguard kernel module") n.debugLogger.Debugf("wireguard family not found, trying to load wireguard kernel module")
err := mod.Probe("wireguard") err = mod.Probe("wireguard")
if err != nil { if err != nil {
n.debugLogger.Debugf("failed loading wireguard kernel module: %s", err) n.debugLogger.Debugf("failed loading wireguard kernel module: %s", err)
return false return false, nil
} }
n.debugLogger.Debugf("wireguard kernel module loaded successfully") n.debugLogger.Debugf("wireguard kernel module loaded successfully")
// Re-check if the Wireguard family is now available, after loading // Re-check if the Wireguard family is now available, after loading
// the wireguard kernel module. // the wireguard kernel module.
return hasWireguardFamily() ok, err = hasWireguardFamily()
if err != nil {
return false, fmt.Errorf("checking wireguard family: %w", err)
}
return ok, nil
} }
func hasWireguardFamily() bool { func hasWireguardFamily() (ok bool, err error) {
_, err := netlink.GenlFamilyGet("wireguard") conn, err := genetlink.Dial(nil)
return err == nil if err != nil {
return false, fmt.Errorf("dialing netlink: %w", err)
}
defer conn.Close()
_, err = conn.GetFamily("wireguard")
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return false, nil
}
return false, fmt.Errorf("getting wireguard family: %w", err)
}
return true, nil
} }
+4 -1
View File
@@ -4,6 +4,8 @@ package netlink
import ( import (
"testing" "testing"
"github.com/stretchr/testify/require"
) )
func Test_NetLink_IsWireguardSupported(t *testing.T) { func Test_NetLink_IsWireguardSupported(t *testing.T) {
@@ -12,7 +14,8 @@ func Test_NetLink_IsWireguardSupported(t *testing.T) {
netLink := &NetLink{ netLink := &NetLink{
debugLogger: &noopLogger{}, debugLogger: &noopLogger{},
} }
ok := netLink.IsWireguardSupported() ok, err := netLink.IsWireguardSupported()
require.NoError(t, err)
if ok { // cannot assert since this depends on kernel if ok { // cannot assert since this depends on kernel
t.Log("wireguard is supported") t.Log("wireguard is supported")
} else { } else {
+1 -1
View File
@@ -14,7 +14,7 @@ type Service interface {
type Routing interface { type Routing interface {
VPNLocalGatewayIP(vpnInterface string) (gateway netip.Addr, err error) VPNLocalGatewayIP(vpnInterface string) (gateway netip.Addr, err error)
AssignedIP(interfaceName string, family int) (ip netip.Addr, err error) AssignedIP(interfaceName string, family uint8) (ip netip.Addr, err error)
} }
type PortAllower interface { type PortAllower interface {
+1 -1
View File
@@ -17,7 +17,7 @@ type PortAllower interface {
type Routing interface { type Routing interface {
VPNLocalGatewayIP(vpnInterface string) (gateway netip.Addr, err error) VPNLocalGatewayIP(vpnInterface string) (gateway netip.Addr, err error)
AssignedIP(interfaceName string, family int) (ip netip.Addr, err error) AssignedIP(interfaceName string, family uint8) (ip netip.Addr, err error)
} }
type Logger interface { type Logger interface {
+1 -1
View File
@@ -14,7 +14,7 @@ type DefaultRoute struct {
NetInterface string NetInterface string
Gateway netip.Addr Gateway netip.Addr
AssignedIP netip.Addr AssignedIP netip.Addr
Family int Family uint8
} }
func (d DefaultRoute) String() string { func (d DefaultRoute) String() string {
+4 -4
View File
@@ -8,8 +8,8 @@ import (
) )
const ( const (
inboundTable = 200 inboundTable uint32 = 200
inboundPriority = 100 inboundPriority uint32 = 100
) )
func (r *Routing) routeInboundFromDefault(defaultRoutes []DefaultRoute) (err error) { func (r *Routing) routeInboundFromDefault(defaultRoutes []DefaultRoute) (err error) {
@@ -60,7 +60,7 @@ func (r *Routing) unrouteInboundFromDefault(defaultRoutes []DefaultRoute) (err e
return nil return nil
} }
func (r *Routing) addRuleInboundFromDefault(table int, defaultRoutes []DefaultRoute) (err error) { func (r *Routing) addRuleInboundFromDefault(table uint32, defaultRoutes []DefaultRoute) (err error) {
for _, defaultRoute := range defaultRoutes { for _, defaultRoute := range defaultRoutes {
assignedIP := defaultRoute.AssignedIP assignedIP := defaultRoute.AssignedIP
bits := 32 bits := 32
@@ -78,7 +78,7 @@ func (r *Routing) addRuleInboundFromDefault(table int, defaultRoutes []DefaultRo
return nil return nil
} }
func (r *Routing) delRuleInboundFromDefault(table int, defaultRoutes []DefaultRoute) (err error) { func (r *Routing) delRuleInboundFromDefault(table uint32, defaultRoutes []DefaultRoute) (err error) {
for _, defaultRoute := range defaultRoutes { for _, defaultRoute := range defaultRoutes {
assignedIP := defaultRoute.AssignedIP assignedIP := defaultRoute.AssignedIP
bits := 32 bits := 32
+2 -2
View File
@@ -16,12 +16,12 @@ func ipIsPrivate(ip netip.Addr) bool {
var errInterfaceIPNotFound = errors.New("IP address not found for interface") var errInterfaceIPNotFound = errors.New("IP address not found for interface")
func ipMatchesFamily(ip netip.Addr, family int) bool { func ipMatchesFamily(ip netip.Addr, family uint8) bool {
return (family == netlink.FamilyV4 && ip.Is4()) || return (family == netlink.FamilyV4 && ip.Is4()) ||
(family == netlink.FamilyV6 && ip.Is6()) (family == netlink.FamilyV6 && ip.Is6())
} }
func (r *Routing) AssignedIP(interfaceName string, family int) (ip netip.Addr, err error) { func (r *Routing) AssignedIP(interfaceName string, family uint8) (ip netip.Addr, err error) {
iface, err := net.InterfaceByName(interfaceName) iface, err := net.InterfaceByName(interfaceName)
if err != nil { if err != nil {
return ip, fmt.Errorf("network interface %s not found: %w", interfaceName, err) return ip, fmt.Errorf("network interface %s not found: %w", interfaceName, err)
+3 -3
View File
@@ -26,10 +26,10 @@ func (r *Routing) LocalNetworks() (localNetworks []LocalNetwork, err error) {
return localNetworks, fmt.Errorf("listing links: %w", err) return localNetworks, fmt.Errorf("listing links: %w", err)
} }
localLinks := make(map[int]struct{}) localLinks := make(map[uint32]struct{})
for _, link := range links { for _, link := range links {
if link.EncapType != "ether" { if link.DeviceType != netlink.DeviceTypeEthernet {
continue continue
} }
@@ -95,7 +95,7 @@ func (r *Routing) AddLocalRules(subnets []LocalNetwork) (err error) {
// Local has higher priority then outbound(99) and inbound(100) as the // Local has higher priority then outbound(99) and inbound(100) as the
// local routes might be necessary to reach the outbound/inbound routes. // local routes might be necessary to reach the outbound/inbound routes.
const localPriority = 98 const localPriority uint32 = 98
// Main table was setup correctly by Docker, just need to add rules to use it // Main table was setup correctly by Docker, just need to add rules to use it
src := netip.Prefix{} src := netip.Prefix{}
+14 -14
View File
@@ -5,6 +5,7 @@
package routing package routing
import ( import (
netip "net/netip"
reflect "reflect" reflect "reflect"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
@@ -35,10 +36,10 @@ func (m *MockNetLinker) EXPECT() *MockNetLinkerMockRecorder {
} }
// AddrList mocks base method. // AddrList mocks base method.
func (m *MockNetLinker) AddrList(arg0 netlink.Link, arg1 int) ([]netlink.Addr, error) { func (m *MockNetLinker) AddrList(arg0 uint32, arg1 byte) ([]netip.Prefix, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddrList", arg0, arg1) ret := m.ctrl.Call(m, "AddrList", arg0, arg1)
ret0, _ := ret[0].([]netlink.Addr) ret0, _ := ret[0].([]netip.Prefix)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
@@ -50,7 +51,7 @@ func (mr *MockNetLinkerMockRecorder) AddrList(arg0, arg1 interface{}) *gomock.Ca
} }
// AddrReplace mocks base method. // AddrReplace mocks base method.
func (m *MockNetLinker) AddrReplace(arg0 netlink.Link, arg1 netlink.Addr) error { func (m *MockNetLinker) AddrReplace(arg0 uint32, arg1 netip.Prefix) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddrReplace", arg0, arg1) ret := m.ctrl.Call(m, "AddrReplace", arg0, arg1)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
@@ -64,10 +65,10 @@ func (mr *MockNetLinkerMockRecorder) AddrReplace(arg0, arg1 interface{}) *gomock
} }
// LinkAdd mocks base method. // LinkAdd mocks base method.
func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) (int, error) { func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) (uint32, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkAdd", arg0) ret := m.ctrl.Call(m, "LinkAdd", arg0)
ret0, _ := ret[0].(int) ret0, _ := ret[0].(uint32)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
@@ -79,7 +80,7 @@ func (mr *MockNetLinkerMockRecorder) LinkAdd(arg0 interface{}) *gomock.Call {
} }
// LinkByIndex mocks base method. // LinkByIndex mocks base method.
func (m *MockNetLinker) LinkByIndex(arg0 int) (netlink.Link, error) { func (m *MockNetLinker) LinkByIndex(arg0 uint32) (netlink.Link, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkByIndex", arg0) ret := m.ctrl.Call(m, "LinkByIndex", arg0)
ret0, _ := ret[0].(netlink.Link) ret0, _ := ret[0].(netlink.Link)
@@ -109,7 +110,7 @@ func (mr *MockNetLinkerMockRecorder) LinkByName(arg0 interface{}) *gomock.Call {
} }
// LinkDel mocks base method. // LinkDel mocks base method.
func (m *MockNetLinker) LinkDel(arg0 netlink.Link) error { func (m *MockNetLinker) LinkDel(arg0 uint32) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkDel", arg0) ret := m.ctrl.Call(m, "LinkDel", arg0)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
@@ -138,7 +139,7 @@ func (mr *MockNetLinkerMockRecorder) LinkList() *gomock.Call {
} }
// LinkSetDown mocks base method. // LinkSetDown mocks base method.
func (m *MockNetLinker) LinkSetDown(arg0 netlink.Link) error { func (m *MockNetLinker) LinkSetDown(arg0 uint32) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkSetDown", arg0) ret := m.ctrl.Call(m, "LinkSetDown", arg0)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
@@ -152,12 +153,11 @@ func (mr *MockNetLinkerMockRecorder) LinkSetDown(arg0 interface{}) *gomock.Call
} }
// LinkSetUp mocks base method. // LinkSetUp mocks base method.
func (m *MockNetLinker) LinkSetUp(arg0 netlink.Link) (int, error) { func (m *MockNetLinker) LinkSetUp(arg0 uint32) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkSetUp", arg0) ret := m.ctrl.Call(m, "LinkSetUp", arg0)
ret0, _ := ret[0].(int) ret0, _ := ret[0].(error)
ret1, _ := ret[1].(error) return ret0
return ret0, ret1
} }
// LinkSetUp indicates an expected call of LinkSetUp. // LinkSetUp indicates an expected call of LinkSetUp.
@@ -195,7 +195,7 @@ func (mr *MockNetLinkerMockRecorder) RouteDel(arg0 interface{}) *gomock.Call {
} }
// RouteList mocks base method. // RouteList mocks base method.
func (m *MockNetLinker) RouteList(arg0 int) ([]netlink.Route, error) { func (m *MockNetLinker) RouteList(arg0 byte) ([]netlink.Route, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RouteList", arg0) ret := m.ctrl.Call(m, "RouteList", arg0)
ret0, _ := ret[0].([]netlink.Route) ret0, _ := ret[0].([]netlink.Route)
@@ -252,7 +252,7 @@ func (mr *MockNetLinkerMockRecorder) RuleDel(arg0 interface{}) *gomock.Call {
} }
// RuleList mocks base method. // RuleList mocks base method.
func (m *MockNetLinker) RuleList(arg0 int) ([]netlink.Rule, error) { func (m *MockNetLinker) RuleList(arg0 byte) ([]netlink.Rule, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RuleList", arg0) ret := m.ctrl.Call(m, "RuleList", arg0)
ret0, _ := ret[0].([]netlink.Rule) ret0, _ := ret[0].([]netlink.Rule)
+2 -2
View File
@@ -9,8 +9,8 @@ import (
) )
const ( const (
outboundTable = 199 outboundTable uint32 = 199
outboundPriority = 99 outboundPriority uint32 = 99
) )
func (r *Routing) SetOutboundRoutes(outboundSubnets []netip.Prefix) error { func (r *Routing) SetOutboundRoutes(outboundSubnets []netip.Prefix) error {
+17 -4
View File
@@ -9,25 +9,33 @@ import (
) )
func (r *Routing) addRouteVia(destination netip.Prefix, gateway netip.Addr, func (r *Routing) addRouteVia(destination netip.Prefix, gateway netip.Addr,
iface string, table int, iface string, table uint32,
) error { ) error {
destinationStr := destination.String() destinationStr := destination.String()
r.logger.Info("adding route for " + destinationStr) r.logger.Info("adding route for " + destinationStr)
r.logger.Debug("ip route replace " + destinationStr + r.logger.Debug("ip route replace " + destinationStr +
" via " + gateway.String() + " via " + gateway.String() +
" dev " + iface + " dev " + iface +
" table " + strconv.Itoa(table)) " table " + strconv.Itoa(int(table)))
link, err := r.netLinker.LinkByName(iface) link, err := r.netLinker.LinkByName(iface)
if err != nil { if err != nil {
return fmt.Errorf("finding link for interface %s: %w", iface, err) return fmt.Errorf("finding link for interface %s: %w", iface, err)
} }
family := netlink.FamilyV4
if destination.Addr().Is6() {
family = netlink.FamilyV6
}
route := netlink.Route{ route := netlink.Route{
Dst: destination, Dst: destination,
Gw: gateway, Gw: gateway,
LinkIndex: link.Index, LinkIndex: link.Index,
Family: family,
Table: table, Table: table,
Type: netlink.RouteTypeUnicast,
Scope: netlink.ScopeUniverse,
Proto: netlink.ProtoStatic,
} }
if err := r.netLinker.RouteReplace(route); err != nil { if err := r.netLinker.RouteReplace(route); err != nil {
return fmt.Errorf("replacing route for subnet %s at interface %s: %w", return fmt.Errorf("replacing route for subnet %s at interface %s: %w",
@@ -38,24 +46,29 @@ func (r *Routing) addRouteVia(destination netip.Prefix, gateway netip.Addr,
} }
func (r *Routing) deleteRouteVia(destination netip.Prefix, gateway netip.Addr, func (r *Routing) deleteRouteVia(destination netip.Prefix, gateway netip.Addr,
iface string, table int, iface string, table uint32,
) (err error) { ) (err error) {
destinationStr := destination.String() destinationStr := destination.String()
r.logger.Info("deleting route for " + destinationStr) r.logger.Info("deleting route for " + destinationStr)
r.logger.Debug("ip route delete " + destinationStr + r.logger.Debug("ip route delete " + destinationStr +
" via " + gateway.String() + " via " + gateway.String() +
" dev " + iface + " dev " + iface +
" table " + strconv.Itoa(table)) " table " + strconv.Itoa(int(table)))
link, err := r.netLinker.LinkByName(iface) link, err := r.netLinker.LinkByName(iface)
if err != nil { if err != nil {
return fmt.Errorf("finding link for interface %s: %w", iface, err) return fmt.Errorf("finding link for interface %s: %w", iface, err)
} }
family := netlink.FamilyV4
if destination.Addr().Is6() {
family = netlink.FamilyV6
}
route := netlink.Route{ route := netlink.Route{
Dst: destination, Dst: destination,
Gw: gateway, Gw: gateway,
LinkIndex: link.Index, LinkIndex: link.Index,
Family: family,
Table: table, Table: table,
} }
if err := r.netLinker.RouteDel(route); err != nil { if err := r.netLinker.RouteDel(route); err != nil {
+10 -10
View File
@@ -15,20 +15,20 @@ type NetLinker interface {
} }
type Addresser interface { type Addresser interface {
AddrList(link netlink.Link, family int) ( AddrList(linkIndex uint32, family uint8) (
addresses []netlink.Addr, err error) addresses []netip.Prefix, err error)
AddrReplace(link netlink.Link, addr netlink.Addr) error AddrReplace(linkIndex uint32, prefix netip.Prefix) error
} }
type Router interface { type Router interface {
RouteList(family int) (routes []netlink.Route, err error) RouteList(family uint8) (routes []netlink.Route, err error)
RouteAdd(route netlink.Route) error RouteAdd(route netlink.Route) error
RouteDel(route netlink.Route) error RouteDel(route netlink.Route) error
RouteReplace(route netlink.Route) error RouteReplace(route netlink.Route) error
} }
type Ruler interface { type Ruler interface {
RuleList(family int) (rules []netlink.Rule, err error) RuleList(family uint8) (rules []netlink.Rule, err error)
RuleAdd(rule netlink.Rule) error RuleAdd(rule netlink.Rule) error
RuleDel(rule netlink.Rule) error RuleDel(rule netlink.Rule) error
} }
@@ -36,11 +36,11 @@ type Ruler interface {
type Linker interface { type Linker interface {
LinkList() (links []netlink.Link, err error) LinkList() (links []netlink.Link, err error)
LinkByName(name string) (link netlink.Link, err error) LinkByName(name string) (link netlink.Link, err error)
LinkByIndex(index int) (link netlink.Link, err error) LinkByIndex(index uint32) (link netlink.Link, err error)
LinkAdd(link netlink.Link) (linkIndex int, err error) LinkAdd(link netlink.Link) (linkIndex uint32, err error)
LinkDel(link netlink.Link) (err error) LinkDel(index uint32) (err error)
LinkSetUp(link netlink.Link) (linkIndex int, err error) LinkSetUp(index uint32) (err error)
LinkSetDown(link netlink.Link) (err error) LinkSetDown(index uint32) (err error)
} }
type Routing struct { type Routing struct {
+39 -13
View File
@@ -7,12 +7,19 @@ import (
"github.com/qdm12/gluetun/internal/netlink" "github.com/qdm12/gluetun/internal/netlink"
) )
func (r *Routing) addIPRule(src, dst netip.Prefix, table, priority int) error { func (r *Routing) addIPRule(src, dst netip.Prefix, table, priority uint32) error {
rule := netlink.NewRule() family := netlink.FamilyV4
rule.Src = src if (src.IsValid() && src.Addr().Is6()) || (dst.IsValid() && dst.Addr().Is6()) {
rule.Dst = dst family = netlink.FamilyV6
rule.Priority = priority }
rule.Table = table rule := netlink.Rule{
Priority: &priority,
Family: family,
Table: table,
Src: src,
Dst: dst,
Action: netlink.ActionToTable,
}
existingRules, err := r.netLinker.RuleList(netlink.FamilyAll) existingRules, err := r.netLinker.RuleList(netlink.FamilyAll)
if err != nil { if err != nil {
@@ -31,12 +38,19 @@ func (r *Routing) addIPRule(src, dst netip.Prefix, table, priority int) error {
return nil return nil
} }
func (r *Routing) deleteIPRule(src, dst netip.Prefix, table, priority int) error { func (r *Routing) deleteIPRule(src, dst netip.Prefix, table uint32, priority uint32) error {
rule := netlink.NewRule() family := netlink.FamilyV4
rule.Src = src if (src.IsValid() && src.Addr().Is6()) || (dst.IsValid() && dst.Addr().Is6()) {
rule.Dst = dst family = netlink.FamilyV6
rule.Priority = priority }
rule.Table = table rule := netlink.Rule{
Priority: &priority,
Family: family,
Table: table,
Src: src,
Dst: dst,
Action: netlink.ActionToTable,
}
existingRules, err := r.netLinker.RuleList(netlink.FamilyAll) existingRules, err := r.netLinker.RuleList(netlink.FamilyAll)
if err != nil { if err != nil {
@@ -53,10 +67,12 @@ func (r *Routing) deleteIPRule(src, dst netip.Prefix, table, priority int) error
return nil return nil
} }
// rulesAreEqual checks whether two rules are equal
// only according to src, dst, priority and table.
func rulesAreEqual(a, b netlink.Rule) bool { func rulesAreEqual(a, b netlink.Rule) bool {
return ipPrefixesAreEqual(a.Src, b.Src) && return ipPrefixesAreEqual(a.Src, b.Src) &&
ipPrefixesAreEqual(a.Dst, b.Dst) && ipPrefixesAreEqual(a.Dst, b.Dst) &&
a.Priority == b.Priority && ptrsEqual(a.Priority, b.Priority) &&
a.Table == b.Table a.Table == b.Table
} }
@@ -70,3 +86,13 @@ func ipPrefixesAreEqual(a, b netip.Prefix) bool {
return a.Bits() == b.Bits() && return a.Bits() == b.Bits() &&
a.Addr().Compare(b.Addr()) == 0 a.Addr().Compare(b.Addr()) == 0
} }
func ptrsEqual(a, b *uint32) bool {
if a == nil && b == nil {
return true
}
if a == nil || b == nil {
return false
}
return *a == *b
}
+30 -22
View File
@@ -17,14 +17,20 @@ func makeNetipPrefix(n byte) netip.Prefix {
} }
func makeIPRule(src, dst netip.Prefix, func makeIPRule(src, dst netip.Prefix,
table, priority int, table uint32, priority uint32,
) netlink.Rule { ) netlink.Rule {
rule := netlink.NewRule() family := netlink.FamilyV4
rule.Src = src if (src.IsValid() && src.Addr().Is6()) || (dst.IsValid() && dst.Addr().Is6()) {
rule.Dst = dst family = netlink.FamilyV6
rule.Table = table }
rule.Priority = priority return netlink.Rule{
return rule Priority: &priority,
Family: family,
Table: table,
Src: src,
Dst: dst,
Action: netlink.ActionToTable,
}
} }
func Test_Routing_addIPRule(t *testing.T) { func Test_Routing_addIPRule(t *testing.T) {
@@ -46,8 +52,8 @@ func Test_Routing_addIPRule(t *testing.T) {
testCases := map[string]struct { testCases := map[string]struct {
src netip.Prefix src netip.Prefix
dst netip.Prefix dst netip.Prefix
table int table uint32
priority int priority uint32
ruleList ruleListCall ruleList ruleListCall
ruleAdd ruleAddCall ruleAdd ruleAddCall
err error err error
@@ -149,8 +155,8 @@ func Test_Routing_deleteIPRule(t *testing.T) {
testCases := map[string]struct { testCases := map[string]struct {
src netip.Prefix src netip.Prefix
dst netip.Prefix dst netip.Prefix
table int table uint32
priority int priority uint32
ruleList ruleListCall ruleList ruleListCall
ruleDel ruleDelCall ruleDel ruleDelCall
err error err error
@@ -238,6 +244,8 @@ func Test_Routing_deleteIPRule(t *testing.T) {
} }
} }
func ptrTo[T any](v T) *T { return &v }
func Test_rulesAreEqual(t *testing.T) { func Test_rulesAreEqual(t *testing.T) {
t.Parallel() t.Parallel()
@@ -253,13 +261,13 @@ func Test_rulesAreEqual(t *testing.T) {
a: netlink.Rule{ a: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{9, 9, 9, 9}), 24), Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{9, 9, 9, 9}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32), Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 100, Priority: ptrTo(uint32(100)),
Table: 101, Table: 101,
}, },
b: netlink.Rule{ b: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24), Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32), Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 100, Priority: ptrTo(uint32(100)),
Table: 101, Table: 101,
}, },
}, },
@@ -267,13 +275,13 @@ func Test_rulesAreEqual(t *testing.T) {
a: netlink.Rule{ a: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24), Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{9, 9, 9, 9}), 32), Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{9, 9, 9, 9}), 32),
Priority: 100, Priority: ptrTo(uint32(100)),
Table: 101, Table: 101,
}, },
b: netlink.Rule{ b: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24), Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32), Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 100, Priority: ptrTo(uint32(100)),
Table: 101, Table: 101,
}, },
}, },
@@ -281,13 +289,13 @@ func Test_rulesAreEqual(t *testing.T) {
a: netlink.Rule{ a: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24), Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32), Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 999, Priority: ptrTo(uint32(999)),
Table: 101, Table: 101,
}, },
b: netlink.Rule{ b: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24), Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32), Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 100, Priority: ptrTo(uint32(100)),
Table: 101, Table: 101,
}, },
}, },
@@ -295,13 +303,13 @@ func Test_rulesAreEqual(t *testing.T) {
a: netlink.Rule{ a: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24), Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32), Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 100, Priority: ptrTo(uint32(100)),
Table: 999, Table: 102,
}, },
b: netlink.Rule{ b: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24), Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32), Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 100, Priority: ptrTo(uint32(100)),
Table: 101, Table: 101,
}, },
}, },
@@ -309,13 +317,13 @@ func Test_rulesAreEqual(t *testing.T) {
a: netlink.Rule{ a: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24), Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32), Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 100, Priority: ptrTo(uint32(100)),
Table: 101, Table: 101,
}, },
b: netlink.Rule{ b: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24), Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32), Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 100, Priority: ptrTo(uint32(100)),
Table: 101, Table: 101,
}, },
equal: true, equal: true,
+3 -4
View File
@@ -33,13 +33,12 @@ func (r *Routing) VPNLocalGatewayIP(vpnIntf string) (ip netip.Addr, err error) {
case route.Dst.IsValid() && route.Dst.Addr().IsUnspecified() && route.Gw.IsValid(): // OpenVPN case route.Dst.IsValid() && route.Dst.Addr().IsUnspecified() && route.Gw.IsValid(): // OpenVPN
return route.Gw, nil return route.Gw, nil
case route.Dst.IsSingleIP() && case route.Dst.IsSingleIP() &&
route.Dst.Addr().Compare(route.Src) == 0 && route.Dst.Addr().Compare(route.Src.Addr()) == 0 &&
route.Table == tableLocal: // Wireguard route.Table == tableLocal: // Wireguard
route.Src = route.Src.Unmap() if route.Src.Addr().Is6() {
if route.Src.Is6() {
return netip.Addr{}, fmt.Errorf("%w: %s", ErrVPNLocalGatewayIPv6NotSupported, route.Src) return netip.Addr{}, fmt.Errorf("%w: %s", ErrVPNLocalGatewayIPv6NotSupported, route.Src)
} }
bytes := route.Src.As4() bytes := route.Src.Addr().As4()
// force last byte to 1 to get the VPN gateway IP // force last byte to 1 to get the VPN gateway IP
// This is not necessarily bullet proof but it seems to work. // This is not necessarily bullet proof but it seems to work.
bytes[3] = 1 bytes[3] = 1
+8 -8
View File
@@ -57,15 +57,15 @@ type Storage interface {
} }
type NetLinker interface { type NetLinker interface {
AddrReplace(link netlink.Link, addr netlink.Addr) error AddrReplace(linkIndex uint32, addr netip.Prefix) error
Router Router
Ruler Ruler
Linker Linker
IsWireguardSupported() bool IsWireguardSupported() (ok bool, err error)
} }
type Router interface { type Router interface {
RouteList(family int) (routes []netlink.Route, err error) RouteList(family uint8) (routes []netlink.Route, err error)
RouteAdd(route netlink.Route) error RouteAdd(route netlink.Route) error
} }
@@ -77,11 +77,11 @@ type Ruler interface {
type Linker interface { type Linker interface {
LinkList() (links []netlink.Link, err error) LinkList() (links []netlink.Link, err error)
LinkByName(name string) (link netlink.Link, err error) LinkByName(name string) (link netlink.Link, err error)
LinkAdd(link netlink.Link) (linkIndex int, err error) LinkAdd(link netlink.Link) (linkIndex uint32, err error)
LinkDel(link netlink.Link) (err error) LinkDel(linkIndex uint32) error
LinkSetUp(link netlink.Link) (linkIndex int, err error) LinkSetUp(linkIndex uint32) error
LinkSetDown(link netlink.Link) (err error) LinkSetDown(linkIndex uint32) error
LinkSetMTU(link netlink.Link, mtu uint32) (err error) LinkSetMTU(linkIndex, mtu uint32) error
} }
type DNSLoop interface { type DNSLoop interface {
+3 -3
View File
@@ -172,7 +172,7 @@ func updateToMaxMTU(ctx context.Context, vpnInterface string,
// the new MTU is set again, but this is necessary to find the highest valid MTU. // the new MTU is set again, but this is necessary to find the highest valid MTU.
logger.Debugf("VPN interface %s MTU temporarily set to %d", vpnInterface, vpnLinkMTU) logger.Debugf("VPN interface %s MTU temporarily set to %d", vpnInterface, vpnLinkMTU)
err = netlinker.LinkSetMTU(link, vpnLinkMTU) err = netlinker.LinkSetMTU(link.Index, vpnLinkMTU)
if err != nil { if err != nil {
return fmt.Errorf("setting VPN interface %s MTU to %d: %w", vpnInterface, vpnLinkMTU, err) return fmt.Errorf("setting VPN interface %s MTU to %d: %w", vpnInterface, vpnLinkMTU, err)
} }
@@ -183,14 +183,14 @@ func updateToMaxMTU(ctx context.Context, vpnInterface string,
case err == nil: case err == nil:
logger.Infof("setting VPN interface %s MTU to maximum valid MTU %d", vpnInterface, vpnLinkMTU) logger.Infof("setting VPN interface %s MTU to maximum valid MTU %d", vpnInterface, vpnLinkMTU)
case errors.Is(err, pmtud.ErrMTUNotFound) || errors.Is(err, pmtud.ErrICMPNotPermitted): case errors.Is(err, pmtud.ErrMTUNotFound) || errors.Is(err, pmtud.ErrICMPNotPermitted):
vpnLinkMTU = uint32(originalMTU) vpnLinkMTU = originalMTU
logger.Infof("reverting VPN interface %s MTU to %d (due to: %s)", logger.Infof("reverting VPN interface %s MTU to %d (due to: %s)",
vpnInterface, originalMTU, err) vpnInterface, originalMTU, err)
default: default:
return fmt.Errorf("path MTU discovering: %w", err) return fmt.Errorf("path MTU discovering: %w", err)
} }
err = netlinker.LinkSetMTU(link, vpnLinkMTU) err = netlinker.LinkSetMTU(link.Index, vpnLinkMTU)
if err != nil { if err != nil {
return fmt.Errorf("setting VPN interface %s MTU to %d: %w", vpnInterface, vpnLinkMTU, err) return fmt.Errorf("setting VPN interface %s MTU to %d: %w", vpnInterface, vpnLinkMTU, err)
} }
+6 -12
View File
@@ -3,26 +3,20 @@ package wireguard
import ( import (
"fmt" "fmt"
"net/netip" "net/netip"
"github.com/qdm12/gluetun/internal/netlink"
) )
func (w *Wireguard) addAddresses(link netlink.Link, func (w *Wireguard) addAddresses(linkIndex uint32,
addresses []netip.Prefix, addresses []netip.Prefix,
) (err error) { ) (err error) {
for _, ipNet := range addresses { for _, address := range addresses {
if !*w.settings.IPv6 && ipNet.Addr().Is6() { if !*w.settings.IPv6 && address.Addr().Is6() {
continue continue
} }
address := netlink.Addr{ err = w.netlink.AddrReplace(linkIndex, address)
Network: ipNet,
}
err = w.netlink.AddrReplace(link, address)
if err != nil { if err != nil {
return fmt.Errorf("%w: when adding address %s to link %s", return fmt.Errorf("%w: when adding address %s to link with index %d",
err, address, link.Name) err, address, linkIndex)
} }
} }
+21 -22
View File
@@ -6,7 +6,6 @@ import (
"testing" "testing"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/qdm12/gluetun/internal/netlink"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -20,21 +19,21 @@ func Test_Wireguard_addAddresses(t *testing.T) {
errDummy := errors.New("dummy") errDummy := errors.New("dummy")
testCases := map[string]struct { testCases := map[string]struct {
link netlink.Link linkIndex uint32
addrs []netip.Prefix addrs []netip.Prefix
wgBuilder func(ctrl *gomock.Controller, link netlink.Link) *Wireguard wgBuilder func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard
err error err error
}{ }{
"success": { "success": {
link: netlink.Link{Type: "wireguard"}, linkIndex: 1,
addrs: []netip.Prefix{ipNetOne, ipNetTwo}, addrs: []netip.Prefix{ipNetOne, ipNetTwo},
wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard { wgBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard {
netLinker := NewMockNetLinker(ctrl) netLinker := NewMockNetLinker(ctrl)
firstCall := netLinker.EXPECT(). firstCall := netLinker.EXPECT().
AddrReplace(link, netlink.Addr{Network: ipNetOne}). AddrReplace(linkIndex, ipNetOne).
Return(nil) Return(nil)
netLinker.EXPECT(). netLinker.EXPECT().
AddrReplace(link, netlink.Addr{Network: ipNetTwo}). AddrReplace(linkIndex, ipNetTwo).
Return(nil).After(firstCall) Return(nil).After(firstCall)
return &Wireguard{ return &Wireguard{
netlink: netLinker, netlink: netLinker,
@@ -45,12 +44,12 @@ func Test_Wireguard_addAddresses(t *testing.T) {
}, },
}, },
"first add error": { "first add error": {
link: netlink.Link{Type: "wireguard", Name: "a_bridge"}, linkIndex: 1,
addrs: []netip.Prefix{ipNetOne, ipNetTwo}, addrs: []netip.Prefix{ipNetOne, ipNetTwo},
wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard { wgBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard {
netLinker := NewMockNetLinker(ctrl) netLinker := NewMockNetLinker(ctrl)
netLinker.EXPECT(). netLinker.EXPECT().
AddrReplace(link, netlink.Addr{Network: ipNetOne}). AddrReplace(linkIndex, ipNetOne).
Return(errDummy) Return(errDummy)
return &Wireguard{ return &Wireguard{
netlink: netLinker, netlink: netLinker,
@@ -59,18 +58,18 @@ func Test_Wireguard_addAddresses(t *testing.T) {
}, },
} }
}, },
err: errors.New("dummy: when adding address 1.2.3.4/32 to link a_bridge"), err: errors.New("dummy: when adding address 1.2.3.4/32 to link with index 1"),
}, },
"second add error": { "second add error": {
link: netlink.Link{Type: "wireguard", Name: "a_bridge"}, linkIndex: 1,
addrs: []netip.Prefix{ipNetOne, ipNetTwo}, addrs: []netip.Prefix{ipNetOne, ipNetTwo},
wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard { wgBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard {
netLinker := NewMockNetLinker(ctrl) netLinker := NewMockNetLinker(ctrl)
firstCall := netLinker.EXPECT(). firstCall := netLinker.EXPECT().
AddrReplace(link, netlink.Addr{Network: ipNetOne}). AddrReplace(linkIndex, ipNetOne).
Return(nil) Return(nil)
netLinker.EXPECT(). netLinker.EXPECT().
AddrReplace(link, netlink.Addr{Network: ipNetTwo}). AddrReplace(linkIndex, ipNetTwo).
Return(errDummy).After(firstCall) Return(errDummy).After(firstCall)
return &Wireguard{ return &Wireguard{
netlink: netLinker, netlink: netLinker,
@@ -79,11 +78,11 @@ func Test_Wireguard_addAddresses(t *testing.T) {
}, },
} }
}, },
err: errors.New("dummy: when adding address ::1234/64 to link a_bridge"), err: errors.New("dummy: when adding address ::1234/64 to link with index 1"),
}, },
"ignore IPv6": { "ignore IPv6": {
addrs: []netip.Prefix{ipNetTwo}, addrs: []netip.Prefix{ipNetTwo},
wgBuilder: func(_ *gomock.Controller, _ netlink.Link) *Wireguard { wgBuilder: func(_ *gomock.Controller, _ uint32) *Wireguard {
return &Wireguard{ return &Wireguard{
settings: Settings{ settings: Settings{
IPv6: ptrTo(false), IPv6: ptrTo(false),
@@ -98,9 +97,9 @@ func Test_Wireguard_addAddresses(t *testing.T) {
t.Parallel() t.Parallel()
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
wg := testCase.wgBuilder(ctrl, testCase.link) wg := testCase.wgBuilder(ctrl, testCase.linkIndex)
err := wg.addAddresses(testCase.link, testCase.addrs) err := wg.addAddresses(testCase.linkIndex, testCase.addrs)
if testCase.err != nil { if testCase.err != nil {
require.Error(t, err) require.Error(t, err)
+50
View File
@@ -1,3 +1,53 @@
package wireguard package wireguard
import (
"math/rand/v2"
"net/netip"
"github.com/qdm12/gluetun/internal/netlink"
)
func ptrTo[T any](x T) *T { return &x } func ptrTo[T any](x T) *T { return &x }
var rng = rand.New(rand.NewChaCha8([32]byte{})) //nolint:gosec,gochecknoglobals
func makeLinkName() string {
const alphabet = "abcdefghijklmnopqrstuvwxyz"
b := make([]byte, 8)
for i := range b {
b[i] = alphabet[rng.IntN(len(alphabet))]
}
return "test" + string(b)
}
func rulesAreEqual(a, b netlink.Rule) bool {
return ipPrefixesAreEqual(a.Src, b.Src) &&
ipPrefixesAreEqual(a.Dst, b.Dst) &&
ptrsEqual(a.Priority, b.Priority) &&
a.Table == b.Table &&
a.Family == b.Family &&
a.Flags == b.Flags &&
a.Action == b.Action &&
ptrsEqual(a.Mark, b.Mark)
}
func ipPrefixesAreEqual(a, b netip.Prefix) bool {
if !a.IsValid() && !b.IsValid() {
return true
}
if !a.IsValid() || !b.IsValid() {
return false
}
return a.Bits() == b.Bits() &&
a.Addr().Compare(b.Addr()) == 0
}
func ptrsEqual(a, b *uint32) bool {
if a == nil && b == nil {
return true
}
if a == nil || b == nil {
return false
}
return *a == *b
}
+34 -36
View File
@@ -1,4 +1,4 @@
//go:build netlink && linux //go:build linux
package wireguard package wireguard
@@ -10,13 +10,16 @@ import (
"github.com/qdm12/log" "github.com/qdm12/log"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/sys/unix"
) )
type noopDebugLogger struct{} type noopDebugLogger struct{}
func (n noopDebugLogger) Debugf(format string, args ...any) {} func (n noopDebugLogger) Debug(_ string) {}
func (n noopDebugLogger) Patch(options ...log.Option) {} func (n noopDebugLogger) Debugf(_ string, _ ...any) {}
func (n noopDebugLogger) Info(_ string) {}
func (n noopDebugLogger) Error(_ string) {}
func (n noopDebugLogger) Errorf(_ string, _ ...any) {}
func (n noopDebugLogger) Patch(_ ...log.Option) {}
func Test_netlink_Wireguard_addAddresses(t *testing.T) { func Test_netlink_Wireguard_addAddresses(t *testing.T) {
t.Parallel() t.Parallel()
@@ -24,15 +27,9 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) {
netlinker := netlink.New(&noopDebugLogger{}) netlinker := netlink.New(&noopDebugLogger{})
link := netlink.Link{ link := netlink.Link{
Type: "bridge", DeviceType: netlink.DeviceTypeNone,
Name: "test_8081", VirtualType: "bridge",
} Name: makeLinkName(),
// Remove any previously created test interface from a crashed/panic
// test or test suite run.
err := netlinker.LinkDel(link)
if err != nil && err.Error() != "invalid argument" {
require.NoError(t, err)
} }
linkIndex, err := netlinker.LinkAdd(link) linkIndex, err := netlinker.LinkAdd(link)
@@ -40,7 +37,7 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) {
link.Index = linkIndex link.Index = linkIndex
defer func() { defer func() {
err = netlinker.LinkDel(link) err = netlinker.LinkDel(linkIndex)
assert.NoError(t, err) assert.NoError(t, err)
}() }()
@@ -57,17 +54,15 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) {
} }
const addIterations = 2 // initial + replace const addIterations = 2 // initial + replace
for range addIterations {
for i := 0; i < addIterations; i++ { err = wg.addAddresses(link.Index, addresses)
err = wg.addAddresses(link, addresses)
require.NoError(t, err) require.NoError(t, err)
netlinkAddresses, err := netlinker.AddrList(link, netlink.FamilyAll) ipPrefixes, err := netlinker.AddrList(link.Index, netlink.FamilyAll)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, len(addresses), len(netlinkAddresses)) require.Equal(t, len(addresses), len(ipPrefixes))
for i, netlinkAddress := range netlinkAddresses { for i, ipPrefix := range ipPrefixes {
require.NotNil(t, netlinkAddress.Network) assert.Equal(t, addresses[i], ipPrefix)
assert.Equal(t, addresses[i], netlinkAddress.Network)
} }
} }
} }
@@ -78,38 +73,41 @@ func Test_netlink_Wireguard_addRule(t *testing.T) {
netlinker := netlink.New(&noopDebugLogger{}) netlinker := netlink.New(&noopDebugLogger{})
wg := &Wireguard{ wg := &Wireguard{
netlink: netlinker, netlink: netlinker,
logger: &noopDebugLogger{},
} }
rulePriority := 10000 // Unique combination for this test
const firewallMark = 999 const rulePriority uint32 = 10000
const family = unix.AF_INET // ipv4 const firewallMark uint32 = 12345
const family = netlink.FamilyV4
cleanup, err := wg.addRule(rulePriority, cleanup, err := wg.addRule(rulePriority,
firewallMark, family) firewallMark, family)
require.NoError(t, err) require.NoError(t, err)
defer func() { t.Cleanup(func() {
err := cleanup() err := cleanup()
assert.NoError(t, err) assert.NoError(t, err)
}() })
rules, err := netlinker.RuleList(netlink.FamilyV4) rules, err := netlinker.RuleList(netlink.FamilyV4)
require.NoError(t, err) require.NoError(t, err)
expectedRule := netlink.Rule{
Priority: ptrTo(rulePriority),
Family: netlink.FamilyV4,
Table: firewallMark,
Mark: ptrTo(firewallMark),
Flags: netlink.FlagInvert,
Action: netlink.ActionToTable,
}
var rule netlink.Rule var rule netlink.Rule
var ruleFound bool var ruleFound bool
for _, rule = range rules { for _, rule = range rules {
if rule.Mark == firewallMark { if rulesAreEqual(rule, expectedRule) {
ruleFound = true ruleFound = true
break break
} }
} }
require.True(t, ruleFound) require.True(t, ruleFound)
expectedRule := netlink.Rule{
Invert: true,
Priority: rulePriority,
Mark: firewallMark,
Table: firewallMark,
}
assert.Equal(t, expectedRule, rule)
// Existing rule cannot be added // Existing rule cannot be added
nilCleanup, err := wg.addRule(rulePriority, nilCleanup, err := wg.addRule(rulePriority,
@@ -118,5 +116,5 @@ func Test_netlink_Wireguard_addRule(t *testing.T) {
_ = nilCleanup() // in case it succeeds _ = nilCleanup() // in case it succeeds
} }
require.Error(t, err) require.Error(t, err)
assert.EqualError(t, err, "adding ip rule 10000: from all to all table 999: file exists") assert.EqualError(t, err, "adding ip rule 10000: from all to all table 12345: netlink receive: file exists")
} }
+12 -8
View File
@@ -1,19 +1,23 @@
package wireguard package wireguard
import "github.com/qdm12/gluetun/internal/netlink" import (
"net/netip"
"github.com/qdm12/gluetun/internal/netlink"
)
//go:generate mockgen -destination=netlinker_mock_test.go -package wireguard . NetLinker //go:generate mockgen -destination=netlinker_mock_test.go -package wireguard . NetLinker
type NetLinker interface { type NetLinker interface {
AddrReplace(link netlink.Link, addr netlink.Addr) error AddrReplace(linkIndex uint32, addr netip.Prefix) error
Router Router
Ruler Ruler
Linker Linker
IsWireguardSupported() bool IsWireguardSupported() (ok bool, err error)
} }
type Router interface { type Router interface {
RouteList(family int) (routes []netlink.Route, err error) RouteList(family uint8) (routes []netlink.Route, err error)
RouteAdd(route netlink.Route) error RouteAdd(route netlink.Route) error
} }
@@ -23,10 +27,10 @@ type Ruler interface {
} }
type Linker interface { type Linker interface {
LinkAdd(link netlink.Link) (linkIndex int, err error) LinkAdd(link netlink.Link) (linkIndex uint32, err error)
LinkList() (links []netlink.Link, err error) LinkList() (links []netlink.Link, err error)
LinkByName(name string) (link netlink.Link, err error) LinkByName(name string) (link netlink.Link, err error)
LinkSetUp(link netlink.Link) (linkIndex int, err error) LinkSetUp(linkIndex uint32) error
LinkSetDown(link netlink.Link) error LinkSetDown(linkIndex uint32) error
LinkDel(link netlink.Link) error LinkDel(linkIndex uint32) error
} }
+13 -12
View File
@@ -5,6 +5,7 @@
package wireguard package wireguard
import ( import (
netip "net/netip"
reflect "reflect" reflect "reflect"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
@@ -35,7 +36,7 @@ func (m *MockNetLinker) EXPECT() *MockNetLinkerMockRecorder {
} }
// AddrReplace mocks base method. // AddrReplace mocks base method.
func (m *MockNetLinker) AddrReplace(arg0 netlink.Link, arg1 netlink.Addr) error { func (m *MockNetLinker) AddrReplace(arg0 uint32, arg1 netip.Prefix) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddrReplace", arg0, arg1) ret := m.ctrl.Call(m, "AddrReplace", arg0, arg1)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
@@ -49,11 +50,12 @@ func (mr *MockNetLinkerMockRecorder) AddrReplace(arg0, arg1 interface{}) *gomock
} }
// IsWireguardSupported mocks base method. // IsWireguardSupported mocks base method.
func (m *MockNetLinker) IsWireguardSupported() bool { func (m *MockNetLinker) IsWireguardSupported() (bool, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IsWireguardSupported") ret := m.ctrl.Call(m, "IsWireguardSupported")
ret0, _ := ret[0].(bool) ret0, _ := ret[0].(bool)
return ret0 ret1, _ := ret[1].(error)
return ret0, ret1
} }
// IsWireguardSupported indicates an expected call of IsWireguardSupported. // IsWireguardSupported indicates an expected call of IsWireguardSupported.
@@ -63,10 +65,10 @@ func (mr *MockNetLinkerMockRecorder) IsWireguardSupported() *gomock.Call {
} }
// LinkAdd mocks base method. // LinkAdd mocks base method.
func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) (int, error) { func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) (uint32, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkAdd", arg0) ret := m.ctrl.Call(m, "LinkAdd", arg0)
ret0, _ := ret[0].(int) ret0, _ := ret[0].(uint32)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
@@ -93,7 +95,7 @@ func (mr *MockNetLinkerMockRecorder) LinkByName(arg0 interface{}) *gomock.Call {
} }
// LinkDel mocks base method. // LinkDel mocks base method.
func (m *MockNetLinker) LinkDel(arg0 netlink.Link) error { func (m *MockNetLinker) LinkDel(arg0 uint32) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkDel", arg0) ret := m.ctrl.Call(m, "LinkDel", arg0)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
@@ -122,7 +124,7 @@ func (mr *MockNetLinkerMockRecorder) LinkList() *gomock.Call {
} }
// LinkSetDown mocks base method. // LinkSetDown mocks base method.
func (m *MockNetLinker) LinkSetDown(arg0 netlink.Link) error { func (m *MockNetLinker) LinkSetDown(arg0 uint32) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkSetDown", arg0) ret := m.ctrl.Call(m, "LinkSetDown", arg0)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
@@ -136,12 +138,11 @@ func (mr *MockNetLinkerMockRecorder) LinkSetDown(arg0 interface{}) *gomock.Call
} }
// LinkSetUp mocks base method. // LinkSetUp mocks base method.
func (m *MockNetLinker) LinkSetUp(arg0 netlink.Link) (int, error) { func (m *MockNetLinker) LinkSetUp(arg0 uint32) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkSetUp", arg0) ret := m.ctrl.Call(m, "LinkSetUp", arg0)
ret0, _ := ret[0].(int) ret0, _ := ret[0].(error)
ret1, _ := ret[1].(error) return ret0
return ret0, ret1
} }
// LinkSetUp indicates an expected call of LinkSetUp. // LinkSetUp indicates an expected call of LinkSetUp.
@@ -165,7 +166,7 @@ func (mr *MockNetLinkerMockRecorder) RouteAdd(arg0 interface{}) *gomock.Call {
} }
// RouteList mocks base method. // RouteList mocks base method.
func (m *MockNetLinker) RouteList(arg0 int) ([]netlink.Route, error) { func (m *MockNetLinker) RouteList(arg0 byte) ([]netlink.Route, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RouteList", arg0) ret := m.ctrl.Call(m, "RouteList", arg0)
ret0, _ := ret[0].([]netlink.Route) ret0, _ := ret[0].([]netlink.Route)
+10 -7
View File
@@ -8,11 +8,11 @@ import (
"github.com/qdm12/gluetun/internal/netlink" "github.com/qdm12/gluetun/internal/netlink"
) )
func (w *Wireguard) addRoutes(link netlink.Link, destinations []netip.Prefix, func (w *Wireguard) addRoutes(linkIndex uint32, destinations []netip.Prefix,
firewallMark uint32, firewallMark uint32,
) (err error) { ) (err error) {
for _, dst := range destinations { for _, dst := range destinations {
err = w.addRoute(link, dst, firewallMark) err = w.addRoute(linkIndex, dst, firewallMark)
if err == nil { if err == nil {
continue continue
} }
@@ -29,7 +29,7 @@ func (w *Wireguard) addRoutes(link netlink.Link, destinations []netip.Prefix,
return nil return nil
} }
func (w *Wireguard) addRoute(link netlink.Link, dst netip.Prefix, func (w *Wireguard) addRoute(linkIndex uint32, dst netip.Prefix,
firewallMark uint32, firewallMark uint32,
) (err error) { ) (err error) {
family := netlink.FamilyV4 family := netlink.FamilyV4
@@ -37,17 +37,20 @@ func (w *Wireguard) addRoute(link netlink.Link, dst netip.Prefix,
family = netlink.FamilyV6 family = netlink.FamilyV6
} }
route := netlink.Route{ route := netlink.Route{
LinkIndex: link.Index, LinkIndex: linkIndex,
Dst: dst, Dst: dst,
Family: family, Family: family,
Table: int(firewallMark), Table: firewallMark,
Type: netlink.RouteTypeUnicast,
Scope: netlink.ScopeUniverse,
Proto: netlink.ProtoStatic,
} }
err = w.netlink.RouteAdd(route) err = w.netlink.RouteAdd(route)
if err != nil { if err != nil {
return fmt.Errorf( return fmt.Errorf(
"adding route for link %s, destination %s and table %d: %w", "adding route for link with index %d, destination %s and table %d: %w",
link.Name, dst, firewallMark, err) linkIndex, dst, firewallMark, err)
} }
return err return err
+8 -10
View File
@@ -23,38 +23,36 @@ func Test_Wireguard_addRoute(t *testing.T) {
errDummy := errors.New("dummy") errDummy := errors.New("dummy")
testCases := map[string]struct { testCases := map[string]struct {
link netlink.Link
dst netip.Prefix dst netip.Prefix
expectedRoute netlink.Route expectedRoute netlink.Route
routeAddErr error routeAddErr error
err error err error
}{ }{
"success": { "success": {
link: netlink.Link{
Index: linkIndex,
},
dst: ipPrefix, dst: ipPrefix,
expectedRoute: netlink.Route{ expectedRoute: netlink.Route{
LinkIndex: linkIndex, LinkIndex: linkIndex,
Dst: ipPrefix, Dst: ipPrefix,
Family: netlink.FamilyV4, Family: netlink.FamilyV4,
Table: firewallMark, Table: firewallMark,
Type: netlink.RouteTypeUnicast,
Scope: netlink.ScopeUniverse,
Proto: netlink.ProtoStatic,
}, },
}, },
"route add error": { "route add error": {
link: netlink.Link{
Name: "a_bridge",
Index: linkIndex,
},
dst: ipPrefix, dst: ipPrefix,
expectedRoute: netlink.Route{ expectedRoute: netlink.Route{
LinkIndex: linkIndex, LinkIndex: linkIndex,
Dst: ipPrefix, Dst: ipPrefix,
Family: netlink.FamilyV4, Family: netlink.FamilyV4,
Table: firewallMark, Table: firewallMark,
Type: netlink.RouteTypeUnicast,
Scope: netlink.ScopeUniverse,
Proto: netlink.ProtoStatic,
}, },
routeAddErr: errDummy, routeAddErr: errDummy,
err: errors.New("adding route for link a_bridge, destination 1.2.3.4/32 and table 51820: dummy"), //nolint:lll err: errors.New("adding route for link with index 88, destination 1.2.3.4/32 and table 51820: dummy"), //nolint:lll
}, },
} }
@@ -72,7 +70,7 @@ func Test_Wireguard_addRoute(t *testing.T) {
RouteAdd(testCase.expectedRoute). RouteAdd(testCase.expectedRoute).
Return(testCase.routeAddErr) Return(testCase.routeAddErr)
err := wg.addRoute(testCase.link, testCase.dst, firewallMark) err := wg.addRoute(linkIndex, testCase.dst, firewallMark)
if testCase.err != nil { if testCase.err != nil {
require.Error(t, err) require.Error(t, err)
+10 -8
View File
@@ -7,15 +7,17 @@ import (
"github.com/qdm12/gluetun/internal/netlink" "github.com/qdm12/gluetun/internal/netlink"
) )
func (w *Wireguard) addRule(rulePriority int, firewallMark uint32, func (w *Wireguard) addRule(rulePriority, firewallMark uint32,
family int, family uint8,
) (cleanup func() error, err error) { ) (cleanup func() error, err error) {
rule := netlink.NewRule() rule := netlink.Rule{
rule.Invert = true Priority: &rulePriority,
rule.Priority = rulePriority Family: family,
rule.Mark = firewallMark Table: firewallMark,
rule.Table = int(firewallMark) Mark: &firewallMark,
rule.Family = family Flags: netlink.FlagInvert,
Action: netlink.ActionToTable,
}
if err := w.netlink.RuleAdd(rule); err != nil { if err := w.netlink.RuleAdd(rule); err != nil {
if strings.HasSuffix(err.Error(), "file exists") { if strings.HasSuffix(err.Error(), "file exists") {
w.logger.Info("if you are using Kubernetes, this may fix the error below: " + w.logger.Info("if you are using Kubernetes, this may fix the error below: " +
+15 -13
View File
@@ -8,15 +8,14 @@ import (
"github.com/qdm12/gluetun/internal/netlink" "github.com/qdm12/gluetun/internal/netlink"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/sys/unix"
) )
func Test_Wireguard_addRule(t *testing.T) { func Test_Wireguard_addRule(t *testing.T) {
t.Parallel() t.Parallel()
const rulePriority = 987 const rulePriority uint32 = 987
const firewallMark = 456 const firewallMark uint32 = 456
const family = unix.AF_INET const family = netlink.FamilyV4
errDummy := errors.New("dummy") errDummy := errors.New("dummy")
@@ -29,31 +28,34 @@ func Test_Wireguard_addRule(t *testing.T) {
}{ }{
"success": { "success": {
expectedRule: netlink.Rule{ expectedRule: netlink.Rule{
Invert: true, Priority: ptrTo(rulePriority),
Priority: rulePriority, Mark: ptrTo(firewallMark),
Mark: firewallMark,
Table: firewallMark, Table: firewallMark,
Family: family, Family: family,
Flags: netlink.FlagInvert,
Action: netlink.ActionToTable,
}, },
}, },
"rule add error": { "rule add error": {
expectedRule: netlink.Rule{ expectedRule: netlink.Rule{
Invert: true, Priority: ptrTo(rulePriority),
Priority: rulePriority, Mark: ptrTo(firewallMark),
Mark: firewallMark,
Table: firewallMark, Table: firewallMark,
Family: family, Family: family,
Flags: netlink.FlagInvert,
Action: netlink.ActionToTable,
}, },
ruleAddErr: errDummy, ruleAddErr: errDummy,
err: errors.New("adding ip rule 987: from all to all table 456: dummy"), err: errors.New("adding ip rule 987: from all to all table 456: dummy"),
}, },
"rule delete error": { "rule delete error": {
expectedRule: netlink.Rule{ expectedRule: netlink.Rule{
Invert: true, Priority: ptrTo(rulePriority),
Priority: rulePriority, Mark: ptrTo(firewallMark),
Mark: firewallMark,
Table: firewallMark, Table: firewallMark,
Family: family, Family: family,
Flags: netlink.FlagInvert,
Action: netlink.ActionToTable,
}, },
ruleDelErr: errDummy, ruleDelErr: errDummy,
cleanupErr: errors.New("deleting rule ip rule 987: from all to all table 456: dummy"), cleanupErr: errors.New("deleting rule ip rule 987: from all to all table 456: dummy"),
+38 -35
View File
@@ -14,6 +14,7 @@ import (
) )
var ( var (
ErrDetectKernel = errors.New("cannot detect Kernel support")
ErrCreateTun = errors.New("cannot create TUN device") ErrCreateTun = errors.New("cannot create TUN device")
ErrAddLink = errors.New("cannot add Wireguard link") ErrAddLink = errors.New("cannot add Wireguard link")
ErrFindLink = errors.New("cannot find link") ErrFindLink = errors.New("cannot find link")
@@ -32,7 +33,11 @@ var (
// See https://git.zx2c4.com/wireguard-go/tree/main.go // See https://git.zx2c4.com/wireguard-go/tree/main.go
func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<- struct{}) { func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<- struct{}) {
kernelSupported := w.netlink.IsWireguardSupported() kernelSupported, err := w.netlink.IsWireguardSupported()
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrDetectKernel, err)
return
}
setupFunction := setupUserSpace setupFunction := setupUserSpace
switch w.settings.Implementation { switch w.settings.Implementation {
@@ -65,14 +70,14 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
defer closers.cleanup(w.logger) defer closers.cleanup(w.logger)
link, waitAndCleanup, err := setupFunction(ctx, linkIndex, waitAndCleanup, err := setupFunction(ctx,
w.settings.InterfaceName, w.netlink, w.settings.MTU, &closers, w.logger) w.settings.InterfaceName, w.netlink, w.settings.MTU, &closers, w.logger)
if err != nil { if err != nil {
waitError <- err waitError <- err
return return
} }
err = w.addAddresses(link, w.settings.Addresses) err = w.addAddresses(linkIndex, w.settings.Addresses)
if err != nil { if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrAddAddress, err) waitError <- fmt.Errorf("%w: %s", ErrAddAddress, err)
return return
@@ -85,17 +90,16 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
return return
} }
linkIndex, err := w.netlink.LinkSetUp(link) err = w.netlink.LinkSetUp(linkIndex)
if err != nil { if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrIfaceUp, err) waitError <- fmt.Errorf("%w: %s", ErrIfaceUp, err)
return return
} }
link.Index = linkIndex
closers.add("shutting down link", stepFour, func() error { closers.add("shutting down link", stepFour, func() error {
return w.netlink.LinkSetDown(link) return w.netlink.LinkSetDown(linkIndex)
}) })
err = w.addRoutes(link, w.settings.AllowedIPs, w.settings.FirewallMark) err = w.addRoutes(linkIndex, w.settings.AllowedIPs, w.settings.FirewallMark)
if err != nil { if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrRouteAdd, err) waitError <- fmt.Errorf("%w: %s", ErrRouteAdd, err)
return return
@@ -131,39 +135,38 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
type waitAndCleanupFunc func() error type waitAndCleanupFunc func() error
func setupKernelSpace(ctx context.Context, func setupKernelSpace(ctx context.Context,
interfaceName string, netLinker NetLinker, mtu uint16, interfaceName string, netLinker NetLinker, mtu uint32,
closers *closers, logger Logger) ( closers *closers, logger Logger) (
link netlink.Link, waitAndCleanup waitAndCleanupFunc, err error, linkIndex uint32, waitAndCleanup waitAndCleanupFunc, err error,
) { ) {
link = netlink.Link{
Type: "wireguard",
Name: interfaceName,
MTU: mtu,
}
links, err := netLinker.LinkList() links, err := netLinker.LinkList()
if err != nil { if err != nil {
return link, nil, fmt.Errorf("listing links: %w", err) return 0, nil, fmt.Errorf("listing links: %w", err)
} }
// Cleanup any previous Wireguard interface with the same name // Cleanup any previous Wireguard interface with the same name
// See https://github.com/qdm12/gluetun/issues/1669 // See https://github.com/qdm12/gluetun/issues/1669
for _, link := range links { for _, link := range links {
if link.Type == "wireguard" && link.Name == interfaceName { if link.VirtualType == "wireguard" && link.Name == interfaceName {
err = netLinker.LinkDel(link) err = netLinker.LinkDel(link.Index)
if err != nil { if err != nil {
return link, nil, fmt.Errorf("deleting previous Wireguard link %s: %w", return 0, nil, fmt.Errorf("deleting previous Wireguard link %s: %w",
interfaceName, err) interfaceName, err)
} }
} }
} }
linkIndex, err := netLinker.LinkAdd(link) link := netlink.Link{
if err != nil { VirtualType: "wireguard",
return link, nil, fmt.Errorf("%w: %s", ErrAddLink, err) Name: interfaceName,
MTU: mtu,
}
linkIndex, err = netLinker.LinkAdd(link)
if err != nil {
return 0, nil, fmt.Errorf("%w: %s", ErrAddLink, err)
} }
link.Index = linkIndex
closers.add("deleting link", stepFive, func() error { closers.add("deleting link", stepFive, func() error {
return netLinker.LinkDel(link) return netLinker.LinkDel(linkIndex)
}) })
waitAndCleanup = func() error { waitAndCleanup = func() error {
@@ -172,35 +175,35 @@ func setupKernelSpace(ctx context.Context,
return ctx.Err() return ctx.Err()
} }
return link, waitAndCleanup, nil return linkIndex, waitAndCleanup, nil
} }
func setupUserSpace(ctx context.Context, func setupUserSpace(ctx context.Context,
interfaceName string, netLinker NetLinker, mtu uint16, interfaceName string, netLinker NetLinker, mtu uint32,
closers *closers, logger Logger) ( closers *closers, logger Logger) (
link netlink.Link, waitAndCleanup waitAndCleanupFunc, err error, linkIndex uint32, waitAndCleanup waitAndCleanupFunc, err error,
) { ) {
tun, err := tun.CreateTUN(interfaceName, int(mtu)) tun, err := tun.CreateTUN(interfaceName, int(mtu))
if err != nil { if err != nil {
return link, nil, fmt.Errorf("%w: %s", ErrCreateTun, err) return 0, nil, fmt.Errorf("%w: %s", ErrCreateTun, err)
} }
closers.add("closing TUN device", stepSeven, tun.Close) closers.add("closing TUN device", stepSeven, tun.Close)
tunName, err := tun.Name() tunName, err := tun.Name()
if err != nil { if err != nil {
return link, nil, fmt.Errorf("%w: cannot get TUN name: %s", ErrCreateTun, err) return 0, nil, fmt.Errorf("%w: cannot get TUN name: %s", ErrCreateTun, err)
} else if tunName != interfaceName { } else if tunName != interfaceName {
return link, nil, fmt.Errorf("%w: names don't match: expected %q and got %q", return 0, nil, fmt.Errorf("%w: names don't match: expected %q and got %q",
ErrCreateTun, interfaceName, tunName) ErrCreateTun, interfaceName, tunName)
} }
link, err = netLinker.LinkByName(interfaceName) link, err := netLinker.LinkByName(interfaceName)
if err != nil { if err != nil {
return link, nil, fmt.Errorf("%w: %s: %s", ErrFindLink, interfaceName, err) return 0, nil, fmt.Errorf("%w: %s: %s", ErrFindLink, interfaceName, err)
} }
closers.add("deleting link", stepFive, func() error { closers.add("deleting link", stepFive, func() error {
return netLinker.LinkDel(link) return netLinker.LinkDel(link.Index)
}) })
bind := conn.NewDefaultBind() bind := conn.NewDefaultBind()
@@ -217,14 +220,14 @@ func setupUserSpace(ctx context.Context,
uapiFile, err := uapiOpen(interfaceName) uapiFile, err := uapiOpen(interfaceName)
if err != nil { if err != nil {
return link, nil, fmt.Errorf("%w: %s", ErrUAPISocketOpening, err) return 0, nil, fmt.Errorf("%w: %s", ErrUAPISocketOpening, err)
} }
closers.add("closing UAPI file", stepThree, uapiFile.Close) closers.add("closing UAPI file", stepThree, uapiFile.Close)
uapiListener, err := uapiListen(interfaceName, uapiFile) uapiListener, err := uapiListen(interfaceName, uapiFile)
if err != nil { if err != nil {
return link, nil, fmt.Errorf("%w: %s", ErrUAPIListen, err) return 0, nil, fmt.Errorf("%w: %s", ErrUAPIListen, err)
} }
closers.add("closing UAPI listener", stepTwo, uapiListener.Close) closers.add("closing UAPI listener", stepTwo, uapiListener.Close)
@@ -249,7 +252,7 @@ func setupUserSpace(ctx context.Context,
return err return err
} }
return link, waitAndCleanup, nil return link.Index, waitAndCleanup, nil
} }
func acceptAndHandle(uapi net.Listener, device *device.Device, func acceptAndHandle(uapi net.Listener, device *device.Device,
+2 -2
View File
@@ -38,10 +38,10 @@ type Settings struct {
FirewallMark uint32 FirewallMark uint32
// Maximum Transmission Unit (MTU) setting for the network interface. // Maximum Transmission Unit (MTU) setting for the network interface.
// It defaults to device.DefaultMTU from wireguard-go which is 1420 // It defaults to device.DefaultMTU from wireguard-go which is 1420
MTU uint16 MTU uint32
// RulePriority is the priority for the rule created with the // RulePriority is the priority for the rule created with the
// FirewallMark. // FirewallMark.
RulePriority int RulePriority uint32
// IPv6 can bet set to true if IPv6 should be handled. // IPv6 can bet set to true if IPv6 should be handled.
// It defaults to false if left unset. // It defaults to false if left unset.
IPv6 *bool IPv6 *bool