chore(all): replace netlink library for more flexibility (#3107)

This commit is contained in:
Quentin McGaw
2026-01-27 10:11:39 +01:00
committed by GitHub
parent e292a4c9be
commit facc6df3be
50 changed files with 1074 additions and 579 deletions
+6 -12
View File
@@ -3,26 +3,20 @@ package wireguard
import (
"fmt"
"net/netip"
"github.com/qdm12/gluetun/internal/netlink"
)
func (w *Wireguard) addAddresses(link netlink.Link,
func (w *Wireguard) addAddresses(linkIndex uint32,
addresses []netip.Prefix,
) (err error) {
for _, ipNet := range addresses {
if !*w.settings.IPv6 && ipNet.Addr().Is6() {
for _, address := range addresses {
if !*w.settings.IPv6 && address.Addr().Is6() {
continue
}
address := netlink.Addr{
Network: ipNet,
}
err = w.netlink.AddrReplace(link, address)
err = w.netlink.AddrReplace(linkIndex, address)
if err != nil {
return fmt.Errorf("%w: when adding address %s to link %s",
err, address, link.Name)
return fmt.Errorf("%w: when adding address %s to link with index %d",
err, address, linkIndex)
}
}
+21 -22
View File
@@ -6,7 +6,6 @@ import (
"testing"
"github.com/golang/mock/gomock"
"github.com/qdm12/gluetun/internal/netlink"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -20,21 +19,21 @@ func Test_Wireguard_addAddresses(t *testing.T) {
errDummy := errors.New("dummy")
testCases := map[string]struct {
link netlink.Link
linkIndex uint32
addrs []netip.Prefix
wgBuilder func(ctrl *gomock.Controller, link netlink.Link) *Wireguard
wgBuilder func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard
err error
}{
"success": {
link: netlink.Link{Type: "wireguard"},
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard {
linkIndex: 1,
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
wgBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard {
netLinker := NewMockNetLinker(ctrl)
firstCall := netLinker.EXPECT().
AddrReplace(link, netlink.Addr{Network: ipNetOne}).
AddrReplace(linkIndex, ipNetOne).
Return(nil)
netLinker.EXPECT().
AddrReplace(link, netlink.Addr{Network: ipNetTwo}).
AddrReplace(linkIndex, ipNetTwo).
Return(nil).After(firstCall)
return &Wireguard{
netlink: netLinker,
@@ -45,12 +44,12 @@ func Test_Wireguard_addAddresses(t *testing.T) {
},
},
"first add error": {
link: netlink.Link{Type: "wireguard", Name: "a_bridge"},
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard {
linkIndex: 1,
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
wgBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard {
netLinker := NewMockNetLinker(ctrl)
netLinker.EXPECT().
AddrReplace(link, netlink.Addr{Network: ipNetOne}).
AddrReplace(linkIndex, ipNetOne).
Return(errDummy)
return &Wireguard{
netlink: netLinker,
@@ -59,18 +58,18 @@ func Test_Wireguard_addAddresses(t *testing.T) {
},
}
},
err: errors.New("dummy: when adding address 1.2.3.4/32 to link a_bridge"),
err: errors.New("dummy: when adding address 1.2.3.4/32 to link with index 1"),
},
"second add error": {
link: netlink.Link{Type: "wireguard", Name: "a_bridge"},
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard {
linkIndex: 1,
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
wgBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard {
netLinker := NewMockNetLinker(ctrl)
firstCall := netLinker.EXPECT().
AddrReplace(link, netlink.Addr{Network: ipNetOne}).
AddrReplace(linkIndex, ipNetOne).
Return(nil)
netLinker.EXPECT().
AddrReplace(link, netlink.Addr{Network: ipNetTwo}).
AddrReplace(linkIndex, ipNetTwo).
Return(errDummy).After(firstCall)
return &Wireguard{
netlink: netLinker,
@@ -79,11 +78,11 @@ func Test_Wireguard_addAddresses(t *testing.T) {
},
}
},
err: errors.New("dummy: when adding address ::1234/64 to link a_bridge"),
err: errors.New("dummy: when adding address ::1234/64 to link with index 1"),
},
"ignore IPv6": {
addrs: []netip.Prefix{ipNetTwo},
wgBuilder: func(_ *gomock.Controller, _ netlink.Link) *Wireguard {
wgBuilder: func(_ *gomock.Controller, _ uint32) *Wireguard {
return &Wireguard{
settings: Settings{
IPv6: ptrTo(false),
@@ -98,9 +97,9 @@ func Test_Wireguard_addAddresses(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
wg := testCase.wgBuilder(ctrl, testCase.link)
wg := testCase.wgBuilder(ctrl, testCase.linkIndex)
err := wg.addAddresses(testCase.link, testCase.addrs)
err := wg.addAddresses(testCase.linkIndex, testCase.addrs)
if testCase.err != nil {
require.Error(t, err)
+50
View File
@@ -1,3 +1,53 @@
package wireguard
import (
"math/rand/v2"
"net/netip"
"github.com/qdm12/gluetun/internal/netlink"
)
func ptrTo[T any](x T) *T { return &x }
var rng = rand.New(rand.NewChaCha8([32]byte{})) //nolint:gosec,gochecknoglobals
func makeLinkName() string {
const alphabet = "abcdefghijklmnopqrstuvwxyz"
b := make([]byte, 8)
for i := range b {
b[i] = alphabet[rng.IntN(len(alphabet))]
}
return "test" + string(b)
}
func rulesAreEqual(a, b netlink.Rule) bool {
return ipPrefixesAreEqual(a.Src, b.Src) &&
ipPrefixesAreEqual(a.Dst, b.Dst) &&
ptrsEqual(a.Priority, b.Priority) &&
a.Table == b.Table &&
a.Family == b.Family &&
a.Flags == b.Flags &&
a.Action == b.Action &&
ptrsEqual(a.Mark, b.Mark)
}
func ipPrefixesAreEqual(a, b netip.Prefix) bool {
if !a.IsValid() && !b.IsValid() {
return true
}
if !a.IsValid() || !b.IsValid() {
return false
}
return a.Bits() == b.Bits() &&
a.Addr().Compare(b.Addr()) == 0
}
func ptrsEqual(a, b *uint32) bool {
if a == nil && b == nil {
return true
}
if a == nil || b == nil {
return false
}
return *a == *b
}
+34 -36
View File
@@ -1,4 +1,4 @@
//go:build netlink && linux
//go:build linux
package wireguard
@@ -10,13 +10,16 @@ import (
"github.com/qdm12/log"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sys/unix"
)
type noopDebugLogger struct{}
func (n noopDebugLogger) Debugf(format string, args ...any) {}
func (n noopDebugLogger) Patch(options ...log.Option) {}
func (n noopDebugLogger) Debug(_ string) {}
func (n noopDebugLogger) Debugf(_ string, _ ...any) {}
func (n noopDebugLogger) Info(_ string) {}
func (n noopDebugLogger) Error(_ string) {}
func (n noopDebugLogger) Errorf(_ string, _ ...any) {}
func (n noopDebugLogger) Patch(_ ...log.Option) {}
func Test_netlink_Wireguard_addAddresses(t *testing.T) {
t.Parallel()
@@ -24,15 +27,9 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) {
netlinker := netlink.New(&noopDebugLogger{})
link := netlink.Link{
Type: "bridge",
Name: "test_8081",
}
// Remove any previously created test interface from a crashed/panic
// test or test suite run.
err := netlinker.LinkDel(link)
if err != nil && err.Error() != "invalid argument" {
require.NoError(t, err)
DeviceType: netlink.DeviceTypeNone,
VirtualType: "bridge",
Name: makeLinkName(),
}
linkIndex, err := netlinker.LinkAdd(link)
@@ -40,7 +37,7 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) {
link.Index = linkIndex
defer func() {
err = netlinker.LinkDel(link)
err = netlinker.LinkDel(linkIndex)
assert.NoError(t, err)
}()
@@ -57,17 +54,15 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) {
}
const addIterations = 2 // initial + replace
for i := 0; i < addIterations; i++ {
err = wg.addAddresses(link, addresses)
for range addIterations {
err = wg.addAddresses(link.Index, addresses)
require.NoError(t, err)
netlinkAddresses, err := netlinker.AddrList(link, netlink.FamilyAll)
ipPrefixes, err := netlinker.AddrList(link.Index, netlink.FamilyAll)
require.NoError(t, err)
require.Equal(t, len(addresses), len(netlinkAddresses))
for i, netlinkAddress := range netlinkAddresses {
require.NotNil(t, netlinkAddress.Network)
assert.Equal(t, addresses[i], netlinkAddress.Network)
require.Equal(t, len(addresses), len(ipPrefixes))
for i, ipPrefix := range ipPrefixes {
assert.Equal(t, addresses[i], ipPrefix)
}
}
}
@@ -78,38 +73,41 @@ func Test_netlink_Wireguard_addRule(t *testing.T) {
netlinker := netlink.New(&noopDebugLogger{})
wg := &Wireguard{
netlink: netlinker,
logger: &noopDebugLogger{},
}
rulePriority := 10000
const firewallMark = 999
const family = unix.AF_INET // ipv4
// Unique combination for this test
const rulePriority uint32 = 10000
const firewallMark uint32 = 12345
const family = netlink.FamilyV4
cleanup, err := wg.addRule(rulePriority,
firewallMark, family)
require.NoError(t, err)
defer func() {
t.Cleanup(func() {
err := cleanup()
assert.NoError(t, err)
}()
})
rules, err := netlinker.RuleList(netlink.FamilyV4)
require.NoError(t, err)
expectedRule := netlink.Rule{
Priority: ptrTo(rulePriority),
Family: netlink.FamilyV4,
Table: firewallMark,
Mark: ptrTo(firewallMark),
Flags: netlink.FlagInvert,
Action: netlink.ActionToTable,
}
var rule netlink.Rule
var ruleFound bool
for _, rule = range rules {
if rule.Mark == firewallMark {
if rulesAreEqual(rule, expectedRule) {
ruleFound = true
break
}
}
require.True(t, ruleFound)
expectedRule := netlink.Rule{
Invert: true,
Priority: rulePriority,
Mark: firewallMark,
Table: firewallMark,
}
assert.Equal(t, expectedRule, rule)
// Existing rule cannot be added
nilCleanup, err := wg.addRule(rulePriority,
@@ -118,5 +116,5 @@ func Test_netlink_Wireguard_addRule(t *testing.T) {
_ = nilCleanup() // in case it succeeds
}
require.Error(t, err)
assert.EqualError(t, err, "adding ip rule 10000: from all to all table 999: file exists")
assert.EqualError(t, err, "adding ip rule 10000: from all to all table 12345: netlink receive: file exists")
}
+12 -8
View File
@@ -1,19 +1,23 @@
package wireguard
import "github.com/qdm12/gluetun/internal/netlink"
import (
"net/netip"
"github.com/qdm12/gluetun/internal/netlink"
)
//go:generate mockgen -destination=netlinker_mock_test.go -package wireguard . NetLinker
type NetLinker interface {
AddrReplace(link netlink.Link, addr netlink.Addr) error
AddrReplace(linkIndex uint32, addr netip.Prefix) error
Router
Ruler
Linker
IsWireguardSupported() bool
IsWireguardSupported() (ok bool, err error)
}
type Router interface {
RouteList(family int) (routes []netlink.Route, err error)
RouteList(family uint8) (routes []netlink.Route, err error)
RouteAdd(route netlink.Route) error
}
@@ -23,10 +27,10 @@ type Ruler interface {
}
type Linker interface {
LinkAdd(link netlink.Link) (linkIndex int, err error)
LinkAdd(link netlink.Link) (linkIndex uint32, err error)
LinkList() (links []netlink.Link, err error)
LinkByName(name string) (link netlink.Link, err error)
LinkSetUp(link netlink.Link) (linkIndex int, err error)
LinkSetDown(link netlink.Link) error
LinkDel(link netlink.Link) error
LinkSetUp(linkIndex uint32) error
LinkSetDown(linkIndex uint32) error
LinkDel(linkIndex uint32) error
}
+13 -12
View File
@@ -5,6 +5,7 @@
package wireguard
import (
netip "net/netip"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
@@ -35,7 +36,7 @@ func (m *MockNetLinker) EXPECT() *MockNetLinkerMockRecorder {
}
// AddrReplace mocks base method.
func (m *MockNetLinker) AddrReplace(arg0 netlink.Link, arg1 netlink.Addr) error {
func (m *MockNetLinker) AddrReplace(arg0 uint32, arg1 netip.Prefix) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddrReplace", arg0, arg1)
ret0, _ := ret[0].(error)
@@ -49,11 +50,12 @@ func (mr *MockNetLinkerMockRecorder) AddrReplace(arg0, arg1 interface{}) *gomock
}
// IsWireguardSupported mocks base method.
func (m *MockNetLinker) IsWireguardSupported() bool {
func (m *MockNetLinker) IsWireguardSupported() (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IsWireguardSupported")
ret0, _ := ret[0].(bool)
return ret0
ret1, _ := ret[1].(error)
return ret0, ret1
}
// IsWireguardSupported indicates an expected call of IsWireguardSupported.
@@ -63,10 +65,10 @@ func (mr *MockNetLinkerMockRecorder) IsWireguardSupported() *gomock.Call {
}
// LinkAdd mocks base method.
func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) (int, error) {
func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) (uint32, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkAdd", arg0)
ret0, _ := ret[0].(int)
ret0, _ := ret[0].(uint32)
ret1, _ := ret[1].(error)
return ret0, ret1
}
@@ -93,7 +95,7 @@ func (mr *MockNetLinkerMockRecorder) LinkByName(arg0 interface{}) *gomock.Call {
}
// LinkDel mocks base method.
func (m *MockNetLinker) LinkDel(arg0 netlink.Link) error {
func (m *MockNetLinker) LinkDel(arg0 uint32) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkDel", arg0)
ret0, _ := ret[0].(error)
@@ -122,7 +124,7 @@ func (mr *MockNetLinkerMockRecorder) LinkList() *gomock.Call {
}
// LinkSetDown mocks base method.
func (m *MockNetLinker) LinkSetDown(arg0 netlink.Link) error {
func (m *MockNetLinker) LinkSetDown(arg0 uint32) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkSetDown", arg0)
ret0, _ := ret[0].(error)
@@ -136,12 +138,11 @@ func (mr *MockNetLinkerMockRecorder) LinkSetDown(arg0 interface{}) *gomock.Call
}
// LinkSetUp mocks base method.
func (m *MockNetLinker) LinkSetUp(arg0 netlink.Link) (int, error) {
func (m *MockNetLinker) LinkSetUp(arg0 uint32) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkSetUp", arg0)
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
ret0, _ := ret[0].(error)
return ret0
}
// LinkSetUp indicates an expected call of LinkSetUp.
@@ -165,7 +166,7 @@ func (mr *MockNetLinkerMockRecorder) RouteAdd(arg0 interface{}) *gomock.Call {
}
// RouteList mocks base method.
func (m *MockNetLinker) RouteList(arg0 int) ([]netlink.Route, error) {
func (m *MockNetLinker) RouteList(arg0 byte) ([]netlink.Route, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RouteList", arg0)
ret0, _ := ret[0].([]netlink.Route)
+10 -7
View File
@@ -8,11 +8,11 @@ import (
"github.com/qdm12/gluetun/internal/netlink"
)
func (w *Wireguard) addRoutes(link netlink.Link, destinations []netip.Prefix,
func (w *Wireguard) addRoutes(linkIndex uint32, destinations []netip.Prefix,
firewallMark uint32,
) (err error) {
for _, dst := range destinations {
err = w.addRoute(link, dst, firewallMark)
err = w.addRoute(linkIndex, dst, firewallMark)
if err == nil {
continue
}
@@ -29,7 +29,7 @@ func (w *Wireguard) addRoutes(link netlink.Link, destinations []netip.Prefix,
return nil
}
func (w *Wireguard) addRoute(link netlink.Link, dst netip.Prefix,
func (w *Wireguard) addRoute(linkIndex uint32, dst netip.Prefix,
firewallMark uint32,
) (err error) {
family := netlink.FamilyV4
@@ -37,17 +37,20 @@ func (w *Wireguard) addRoute(link netlink.Link, dst netip.Prefix,
family = netlink.FamilyV6
}
route := netlink.Route{
LinkIndex: link.Index,
LinkIndex: linkIndex,
Dst: dst,
Family: family,
Table: int(firewallMark),
Table: firewallMark,
Type: netlink.RouteTypeUnicast,
Scope: netlink.ScopeUniverse,
Proto: netlink.ProtoStatic,
}
err = w.netlink.RouteAdd(route)
if err != nil {
return fmt.Errorf(
"adding route for link %s, destination %s and table %d: %w",
link.Name, dst, firewallMark, err)
"adding route for link with index %d, destination %s and table %d: %w",
linkIndex, dst, firewallMark, err)
}
return err
+8 -10
View File
@@ -23,38 +23,36 @@ func Test_Wireguard_addRoute(t *testing.T) {
errDummy := errors.New("dummy")
testCases := map[string]struct {
link netlink.Link
dst netip.Prefix
expectedRoute netlink.Route
routeAddErr error
err error
}{
"success": {
link: netlink.Link{
Index: linkIndex,
},
dst: ipPrefix,
expectedRoute: netlink.Route{
LinkIndex: linkIndex,
Dst: ipPrefix,
Family: netlink.FamilyV4,
Table: firewallMark,
Type: netlink.RouteTypeUnicast,
Scope: netlink.ScopeUniverse,
Proto: netlink.ProtoStatic,
},
},
"route add error": {
link: netlink.Link{
Name: "a_bridge",
Index: linkIndex,
},
dst: ipPrefix,
expectedRoute: netlink.Route{
LinkIndex: linkIndex,
Dst: ipPrefix,
Family: netlink.FamilyV4,
Table: firewallMark,
Type: netlink.RouteTypeUnicast,
Scope: netlink.ScopeUniverse,
Proto: netlink.ProtoStatic,
},
routeAddErr: errDummy,
err: errors.New("adding route for link a_bridge, destination 1.2.3.4/32 and table 51820: dummy"), //nolint:lll
err: errors.New("adding route for link with index 88, destination 1.2.3.4/32 and table 51820: dummy"), //nolint:lll
},
}
@@ -72,7 +70,7 @@ func Test_Wireguard_addRoute(t *testing.T) {
RouteAdd(testCase.expectedRoute).
Return(testCase.routeAddErr)
err := wg.addRoute(testCase.link, testCase.dst, firewallMark)
err := wg.addRoute(linkIndex, testCase.dst, firewallMark)
if testCase.err != nil {
require.Error(t, err)
+10 -8
View File
@@ -7,15 +7,17 @@ import (
"github.com/qdm12/gluetun/internal/netlink"
)
func (w *Wireguard) addRule(rulePriority int, firewallMark uint32,
family int,
func (w *Wireguard) addRule(rulePriority, firewallMark uint32,
family uint8,
) (cleanup func() error, err error) {
rule := netlink.NewRule()
rule.Invert = true
rule.Priority = rulePriority
rule.Mark = firewallMark
rule.Table = int(firewallMark)
rule.Family = family
rule := netlink.Rule{
Priority: &rulePriority,
Family: family,
Table: firewallMark,
Mark: &firewallMark,
Flags: netlink.FlagInvert,
Action: netlink.ActionToTable,
}
if err := w.netlink.RuleAdd(rule); err != nil {
if strings.HasSuffix(err.Error(), "file exists") {
w.logger.Info("if you are using Kubernetes, this may fix the error below: " +
+15 -13
View File
@@ -8,15 +8,14 @@ import (
"github.com/qdm12/gluetun/internal/netlink"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sys/unix"
)
func Test_Wireguard_addRule(t *testing.T) {
t.Parallel()
const rulePriority = 987
const firewallMark = 456
const family = unix.AF_INET
const rulePriority uint32 = 987
const firewallMark uint32 = 456
const family = netlink.FamilyV4
errDummy := errors.New("dummy")
@@ -29,31 +28,34 @@ func Test_Wireguard_addRule(t *testing.T) {
}{
"success": {
expectedRule: netlink.Rule{
Invert: true,
Priority: rulePriority,
Mark: firewallMark,
Priority: ptrTo(rulePriority),
Mark: ptrTo(firewallMark),
Table: firewallMark,
Family: family,
Flags: netlink.FlagInvert,
Action: netlink.ActionToTable,
},
},
"rule add error": {
expectedRule: netlink.Rule{
Invert: true,
Priority: rulePriority,
Mark: firewallMark,
Priority: ptrTo(rulePriority),
Mark: ptrTo(firewallMark),
Table: firewallMark,
Family: family,
Flags: netlink.FlagInvert,
Action: netlink.ActionToTable,
},
ruleAddErr: errDummy,
err: errors.New("adding ip rule 987: from all to all table 456: dummy"),
},
"rule delete error": {
expectedRule: netlink.Rule{
Invert: true,
Priority: rulePriority,
Mark: firewallMark,
Priority: ptrTo(rulePriority),
Mark: ptrTo(firewallMark),
Table: firewallMark,
Family: family,
Flags: netlink.FlagInvert,
Action: netlink.ActionToTable,
},
ruleDelErr: errDummy,
cleanupErr: errors.New("deleting rule ip rule 987: from all to all table 456: dummy"),
+38 -35
View File
@@ -14,6 +14,7 @@ import (
)
var (
ErrDetectKernel = errors.New("cannot detect Kernel support")
ErrCreateTun = errors.New("cannot create TUN device")
ErrAddLink = errors.New("cannot add Wireguard link")
ErrFindLink = errors.New("cannot find link")
@@ -32,7 +33,11 @@ var (
// See https://git.zx2c4.com/wireguard-go/tree/main.go
func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<- struct{}) {
kernelSupported := w.netlink.IsWireguardSupported()
kernelSupported, err := w.netlink.IsWireguardSupported()
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrDetectKernel, err)
return
}
setupFunction := setupUserSpace
switch w.settings.Implementation {
@@ -65,14 +70,14 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
defer closers.cleanup(w.logger)
link, waitAndCleanup, err := setupFunction(ctx,
linkIndex, waitAndCleanup, err := setupFunction(ctx,
w.settings.InterfaceName, w.netlink, w.settings.MTU, &closers, w.logger)
if err != nil {
waitError <- err
return
}
err = w.addAddresses(link, w.settings.Addresses)
err = w.addAddresses(linkIndex, w.settings.Addresses)
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrAddAddress, err)
return
@@ -85,17 +90,16 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
return
}
linkIndex, err := w.netlink.LinkSetUp(link)
err = w.netlink.LinkSetUp(linkIndex)
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrIfaceUp, err)
return
}
link.Index = linkIndex
closers.add("shutting down link", stepFour, func() error {
return w.netlink.LinkSetDown(link)
return w.netlink.LinkSetDown(linkIndex)
})
err = w.addRoutes(link, w.settings.AllowedIPs, w.settings.FirewallMark)
err = w.addRoutes(linkIndex, w.settings.AllowedIPs, w.settings.FirewallMark)
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrRouteAdd, err)
return
@@ -131,39 +135,38 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
type waitAndCleanupFunc func() error
func setupKernelSpace(ctx context.Context,
interfaceName string, netLinker NetLinker, mtu uint16,
interfaceName string, netLinker NetLinker, mtu uint32,
closers *closers, logger Logger) (
link netlink.Link, waitAndCleanup waitAndCleanupFunc, err error,
linkIndex uint32, waitAndCleanup waitAndCleanupFunc, err error,
) {
link = netlink.Link{
Type: "wireguard",
Name: interfaceName,
MTU: mtu,
}
links, err := netLinker.LinkList()
if err != nil {
return link, nil, fmt.Errorf("listing links: %w", err)
return 0, nil, fmt.Errorf("listing links: %w", err)
}
// Cleanup any previous Wireguard interface with the same name
// See https://github.com/qdm12/gluetun/issues/1669
for _, link := range links {
if link.Type == "wireguard" && link.Name == interfaceName {
err = netLinker.LinkDel(link)
if link.VirtualType == "wireguard" && link.Name == interfaceName {
err = netLinker.LinkDel(link.Index)
if err != nil {
return link, nil, fmt.Errorf("deleting previous Wireguard link %s: %w",
return 0, nil, fmt.Errorf("deleting previous Wireguard link %s: %w",
interfaceName, err)
}
}
}
linkIndex, err := netLinker.LinkAdd(link)
if err != nil {
return link, nil, fmt.Errorf("%w: %s", ErrAddLink, err)
link := netlink.Link{
VirtualType: "wireguard",
Name: interfaceName,
MTU: mtu,
}
linkIndex, err = netLinker.LinkAdd(link)
if err != nil {
return 0, nil, fmt.Errorf("%w: %s", ErrAddLink, err)
}
link.Index = linkIndex
closers.add("deleting link", stepFive, func() error {
return netLinker.LinkDel(link)
return netLinker.LinkDel(linkIndex)
})
waitAndCleanup = func() error {
@@ -172,35 +175,35 @@ func setupKernelSpace(ctx context.Context,
return ctx.Err()
}
return link, waitAndCleanup, nil
return linkIndex, waitAndCleanup, nil
}
func setupUserSpace(ctx context.Context,
interfaceName string, netLinker NetLinker, mtu uint16,
interfaceName string, netLinker NetLinker, mtu uint32,
closers *closers, logger Logger) (
link netlink.Link, waitAndCleanup waitAndCleanupFunc, err error,
linkIndex uint32, waitAndCleanup waitAndCleanupFunc, err error,
) {
tun, err := tun.CreateTUN(interfaceName, int(mtu))
if err != nil {
return link, nil, fmt.Errorf("%w: %s", ErrCreateTun, err)
return 0, nil, fmt.Errorf("%w: %s", ErrCreateTun, err)
}
closers.add("closing TUN device", stepSeven, tun.Close)
tunName, err := tun.Name()
if err != nil {
return link, nil, fmt.Errorf("%w: cannot get TUN name: %s", ErrCreateTun, err)
return 0, nil, fmt.Errorf("%w: cannot get TUN name: %s", ErrCreateTun, err)
} else if tunName != interfaceName {
return link, nil, fmt.Errorf("%w: names don't match: expected %q and got %q",
return 0, nil, fmt.Errorf("%w: names don't match: expected %q and got %q",
ErrCreateTun, interfaceName, tunName)
}
link, err = netLinker.LinkByName(interfaceName)
link, err := netLinker.LinkByName(interfaceName)
if err != nil {
return link, nil, fmt.Errorf("%w: %s: %s", ErrFindLink, interfaceName, err)
return 0, nil, fmt.Errorf("%w: %s: %s", ErrFindLink, interfaceName, err)
}
closers.add("deleting link", stepFive, func() error {
return netLinker.LinkDel(link)
return netLinker.LinkDel(link.Index)
})
bind := conn.NewDefaultBind()
@@ -217,14 +220,14 @@ func setupUserSpace(ctx context.Context,
uapiFile, err := uapiOpen(interfaceName)
if err != nil {
return link, nil, fmt.Errorf("%w: %s", ErrUAPISocketOpening, err)
return 0, nil, fmt.Errorf("%w: %s", ErrUAPISocketOpening, err)
}
closers.add("closing UAPI file", stepThree, uapiFile.Close)
uapiListener, err := uapiListen(interfaceName, uapiFile)
if err != nil {
return link, nil, fmt.Errorf("%w: %s", ErrUAPIListen, err)
return 0, nil, fmt.Errorf("%w: %s", ErrUAPIListen, err)
}
closers.add("closing UAPI listener", stepTwo, uapiListener.Close)
@@ -249,7 +252,7 @@ func setupUserSpace(ctx context.Context,
return err
}
return link, waitAndCleanup, nil
return link.Index, waitAndCleanup, nil
}
func acceptAndHandle(uapi net.Listener, device *device.Device,
+2 -2
View File
@@ -38,10 +38,10 @@ type Settings struct {
FirewallMark uint32
// Maximum Transmission Unit (MTU) setting for the network interface.
// It defaults to device.DefaultMTU from wireguard-go which is 1420
MTU uint16
MTU uint32
// RulePriority is the priority for the rule created with the
// FirewallMark.
RulePriority int
RulePriority uint32
// IPv6 can bet set to true if IPv6 should be handled.
// It defaults to false if left unset.
IPv6 *bool