mirror of
https://github.com/qdm12/gluetun.git
synced 2026-05-06 20:10:11 +02:00
chore(all): replace netlink library for more flexibility (#3107)
This commit is contained in:
@@ -3,26 +3,20 @@ package wireguard
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/netlink"
|
||||
)
|
||||
|
||||
func (w *Wireguard) addAddresses(link netlink.Link,
|
||||
func (w *Wireguard) addAddresses(linkIndex uint32,
|
||||
addresses []netip.Prefix,
|
||||
) (err error) {
|
||||
for _, ipNet := range addresses {
|
||||
if !*w.settings.IPv6 && ipNet.Addr().Is6() {
|
||||
for _, address := range addresses {
|
||||
if !*w.settings.IPv6 && address.Addr().Is6() {
|
||||
continue
|
||||
}
|
||||
|
||||
address := netlink.Addr{
|
||||
Network: ipNet,
|
||||
}
|
||||
|
||||
err = w.netlink.AddrReplace(link, address)
|
||||
err = w.netlink.AddrReplace(linkIndex, address)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: when adding address %s to link %s",
|
||||
err, address, link.Name)
|
||||
return fmt.Errorf("%w: when adding address %s to link with index %d",
|
||||
err, address, linkIndex)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/qdm12/gluetun/internal/netlink"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -20,21 +19,21 @@ func Test_Wireguard_addAddresses(t *testing.T) {
|
||||
errDummy := errors.New("dummy")
|
||||
|
||||
testCases := map[string]struct {
|
||||
link netlink.Link
|
||||
linkIndex uint32
|
||||
addrs []netip.Prefix
|
||||
wgBuilder func(ctrl *gomock.Controller, link netlink.Link) *Wireguard
|
||||
wgBuilder func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard
|
||||
err error
|
||||
}{
|
||||
"success": {
|
||||
link: netlink.Link{Type: "wireguard"},
|
||||
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
|
||||
wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard {
|
||||
linkIndex: 1,
|
||||
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
|
||||
wgBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard {
|
||||
netLinker := NewMockNetLinker(ctrl)
|
||||
firstCall := netLinker.EXPECT().
|
||||
AddrReplace(link, netlink.Addr{Network: ipNetOne}).
|
||||
AddrReplace(linkIndex, ipNetOne).
|
||||
Return(nil)
|
||||
netLinker.EXPECT().
|
||||
AddrReplace(link, netlink.Addr{Network: ipNetTwo}).
|
||||
AddrReplace(linkIndex, ipNetTwo).
|
||||
Return(nil).After(firstCall)
|
||||
return &Wireguard{
|
||||
netlink: netLinker,
|
||||
@@ -45,12 +44,12 @@ func Test_Wireguard_addAddresses(t *testing.T) {
|
||||
},
|
||||
},
|
||||
"first add error": {
|
||||
link: netlink.Link{Type: "wireguard", Name: "a_bridge"},
|
||||
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
|
||||
wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard {
|
||||
linkIndex: 1,
|
||||
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
|
||||
wgBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard {
|
||||
netLinker := NewMockNetLinker(ctrl)
|
||||
netLinker.EXPECT().
|
||||
AddrReplace(link, netlink.Addr{Network: ipNetOne}).
|
||||
AddrReplace(linkIndex, ipNetOne).
|
||||
Return(errDummy)
|
||||
return &Wireguard{
|
||||
netlink: netLinker,
|
||||
@@ -59,18 +58,18 @@ func Test_Wireguard_addAddresses(t *testing.T) {
|
||||
},
|
||||
}
|
||||
},
|
||||
err: errors.New("dummy: when adding address 1.2.3.4/32 to link a_bridge"),
|
||||
err: errors.New("dummy: when adding address 1.2.3.4/32 to link with index 1"),
|
||||
},
|
||||
"second add error": {
|
||||
link: netlink.Link{Type: "wireguard", Name: "a_bridge"},
|
||||
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
|
||||
wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard {
|
||||
linkIndex: 1,
|
||||
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
|
||||
wgBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard {
|
||||
netLinker := NewMockNetLinker(ctrl)
|
||||
firstCall := netLinker.EXPECT().
|
||||
AddrReplace(link, netlink.Addr{Network: ipNetOne}).
|
||||
AddrReplace(linkIndex, ipNetOne).
|
||||
Return(nil)
|
||||
netLinker.EXPECT().
|
||||
AddrReplace(link, netlink.Addr{Network: ipNetTwo}).
|
||||
AddrReplace(linkIndex, ipNetTwo).
|
||||
Return(errDummy).After(firstCall)
|
||||
return &Wireguard{
|
||||
netlink: netLinker,
|
||||
@@ -79,11 +78,11 @@ func Test_Wireguard_addAddresses(t *testing.T) {
|
||||
},
|
||||
}
|
||||
},
|
||||
err: errors.New("dummy: when adding address ::1234/64 to link a_bridge"),
|
||||
err: errors.New("dummy: when adding address ::1234/64 to link with index 1"),
|
||||
},
|
||||
"ignore IPv6": {
|
||||
addrs: []netip.Prefix{ipNetTwo},
|
||||
wgBuilder: func(_ *gomock.Controller, _ netlink.Link) *Wireguard {
|
||||
wgBuilder: func(_ *gomock.Controller, _ uint32) *Wireguard {
|
||||
return &Wireguard{
|
||||
settings: Settings{
|
||||
IPv6: ptrTo(false),
|
||||
@@ -98,9 +97,9 @@ func Test_Wireguard_addAddresses(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
wg := testCase.wgBuilder(ctrl, testCase.link)
|
||||
wg := testCase.wgBuilder(ctrl, testCase.linkIndex)
|
||||
|
||||
err := wg.addAddresses(testCase.link, testCase.addrs)
|
||||
err := wg.addAddresses(testCase.linkIndex, testCase.addrs)
|
||||
|
||||
if testCase.err != nil {
|
||||
require.Error(t, err)
|
||||
|
||||
@@ -1,3 +1,53 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"math/rand/v2"
|
||||
"net/netip"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/netlink"
|
||||
)
|
||||
|
||||
func ptrTo[T any](x T) *T { return &x }
|
||||
|
||||
var rng = rand.New(rand.NewChaCha8([32]byte{})) //nolint:gosec,gochecknoglobals
|
||||
|
||||
func makeLinkName() string {
|
||||
const alphabet = "abcdefghijklmnopqrstuvwxyz"
|
||||
b := make([]byte, 8)
|
||||
for i := range b {
|
||||
b[i] = alphabet[rng.IntN(len(alphabet))]
|
||||
}
|
||||
return "test" + string(b)
|
||||
}
|
||||
|
||||
func rulesAreEqual(a, b netlink.Rule) bool {
|
||||
return ipPrefixesAreEqual(a.Src, b.Src) &&
|
||||
ipPrefixesAreEqual(a.Dst, b.Dst) &&
|
||||
ptrsEqual(a.Priority, b.Priority) &&
|
||||
a.Table == b.Table &&
|
||||
a.Family == b.Family &&
|
||||
a.Flags == b.Flags &&
|
||||
a.Action == b.Action &&
|
||||
ptrsEqual(a.Mark, b.Mark)
|
||||
}
|
||||
|
||||
func ipPrefixesAreEqual(a, b netip.Prefix) bool {
|
||||
if !a.IsValid() && !b.IsValid() {
|
||||
return true
|
||||
}
|
||||
if !a.IsValid() || !b.IsValid() {
|
||||
return false
|
||||
}
|
||||
return a.Bits() == b.Bits() &&
|
||||
a.Addr().Compare(b.Addr()) == 0
|
||||
}
|
||||
|
||||
func ptrsEqual(a, b *uint32) bool {
|
||||
if a == nil && b == nil {
|
||||
return true
|
||||
}
|
||||
if a == nil || b == nil {
|
||||
return false
|
||||
}
|
||||
return *a == *b
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build netlink && linux
|
||||
//go:build linux
|
||||
|
||||
package wireguard
|
||||
|
||||
@@ -10,13 +10,16 @@ import (
|
||||
"github.com/qdm12/log"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
type noopDebugLogger struct{}
|
||||
|
||||
func (n noopDebugLogger) Debugf(format string, args ...any) {}
|
||||
func (n noopDebugLogger) Patch(options ...log.Option) {}
|
||||
func (n noopDebugLogger) Debug(_ string) {}
|
||||
func (n noopDebugLogger) Debugf(_ string, _ ...any) {}
|
||||
func (n noopDebugLogger) Info(_ string) {}
|
||||
func (n noopDebugLogger) Error(_ string) {}
|
||||
func (n noopDebugLogger) Errorf(_ string, _ ...any) {}
|
||||
func (n noopDebugLogger) Patch(_ ...log.Option) {}
|
||||
|
||||
func Test_netlink_Wireguard_addAddresses(t *testing.T) {
|
||||
t.Parallel()
|
||||
@@ -24,15 +27,9 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) {
|
||||
netlinker := netlink.New(&noopDebugLogger{})
|
||||
|
||||
link := netlink.Link{
|
||||
Type: "bridge",
|
||||
Name: "test_8081",
|
||||
}
|
||||
|
||||
// Remove any previously created test interface from a crashed/panic
|
||||
// test or test suite run.
|
||||
err := netlinker.LinkDel(link)
|
||||
if err != nil && err.Error() != "invalid argument" {
|
||||
require.NoError(t, err)
|
||||
DeviceType: netlink.DeviceTypeNone,
|
||||
VirtualType: "bridge",
|
||||
Name: makeLinkName(),
|
||||
}
|
||||
|
||||
linkIndex, err := netlinker.LinkAdd(link)
|
||||
@@ -40,7 +37,7 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) {
|
||||
link.Index = linkIndex
|
||||
|
||||
defer func() {
|
||||
err = netlinker.LinkDel(link)
|
||||
err = netlinker.LinkDel(linkIndex)
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
@@ -57,17 +54,15 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) {
|
||||
}
|
||||
|
||||
const addIterations = 2 // initial + replace
|
||||
|
||||
for i := 0; i < addIterations; i++ {
|
||||
err = wg.addAddresses(link, addresses)
|
||||
for range addIterations {
|
||||
err = wg.addAddresses(link.Index, addresses)
|
||||
require.NoError(t, err)
|
||||
|
||||
netlinkAddresses, err := netlinker.AddrList(link, netlink.FamilyAll)
|
||||
ipPrefixes, err := netlinker.AddrList(link.Index, netlink.FamilyAll)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, len(addresses), len(netlinkAddresses))
|
||||
for i, netlinkAddress := range netlinkAddresses {
|
||||
require.NotNil(t, netlinkAddress.Network)
|
||||
assert.Equal(t, addresses[i], netlinkAddress.Network)
|
||||
require.Equal(t, len(addresses), len(ipPrefixes))
|
||||
for i, ipPrefix := range ipPrefixes {
|
||||
assert.Equal(t, addresses[i], ipPrefix)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -78,38 +73,41 @@ func Test_netlink_Wireguard_addRule(t *testing.T) {
|
||||
netlinker := netlink.New(&noopDebugLogger{})
|
||||
wg := &Wireguard{
|
||||
netlink: netlinker,
|
||||
logger: &noopDebugLogger{},
|
||||
}
|
||||
|
||||
rulePriority := 10000
|
||||
const firewallMark = 999
|
||||
const family = unix.AF_INET // ipv4
|
||||
// Unique combination for this test
|
||||
const rulePriority uint32 = 10000
|
||||
const firewallMark uint32 = 12345
|
||||
const family = netlink.FamilyV4
|
||||
|
||||
cleanup, err := wg.addRule(rulePriority,
|
||||
firewallMark, family)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
t.Cleanup(func() {
|
||||
err := cleanup()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
})
|
||||
|
||||
rules, err := netlinker.RuleList(netlink.FamilyV4)
|
||||
require.NoError(t, err)
|
||||
expectedRule := netlink.Rule{
|
||||
Priority: ptrTo(rulePriority),
|
||||
Family: netlink.FamilyV4,
|
||||
Table: firewallMark,
|
||||
Mark: ptrTo(firewallMark),
|
||||
Flags: netlink.FlagInvert,
|
||||
Action: netlink.ActionToTable,
|
||||
}
|
||||
var rule netlink.Rule
|
||||
var ruleFound bool
|
||||
for _, rule = range rules {
|
||||
if rule.Mark == firewallMark {
|
||||
if rulesAreEqual(rule, expectedRule) {
|
||||
ruleFound = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, ruleFound)
|
||||
expectedRule := netlink.Rule{
|
||||
Invert: true,
|
||||
Priority: rulePriority,
|
||||
Mark: firewallMark,
|
||||
Table: firewallMark,
|
||||
}
|
||||
assert.Equal(t, expectedRule, rule)
|
||||
|
||||
// Existing rule cannot be added
|
||||
nilCleanup, err := wg.addRule(rulePriority,
|
||||
@@ -118,5 +116,5 @@ func Test_netlink_Wireguard_addRule(t *testing.T) {
|
||||
_ = nilCleanup() // in case it succeeds
|
||||
}
|
||||
require.Error(t, err)
|
||||
assert.EqualError(t, err, "adding ip rule 10000: from all to all table 999: file exists")
|
||||
assert.EqualError(t, err, "adding ip rule 10000: from all to all table 12345: netlink receive: file exists")
|
||||
}
|
||||
|
||||
@@ -1,19 +1,23 @@
|
||||
package wireguard
|
||||
|
||||
import "github.com/qdm12/gluetun/internal/netlink"
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/netlink"
|
||||
)
|
||||
|
||||
//go:generate mockgen -destination=netlinker_mock_test.go -package wireguard . NetLinker
|
||||
|
||||
type NetLinker interface {
|
||||
AddrReplace(link netlink.Link, addr netlink.Addr) error
|
||||
AddrReplace(linkIndex uint32, addr netip.Prefix) error
|
||||
Router
|
||||
Ruler
|
||||
Linker
|
||||
IsWireguardSupported() bool
|
||||
IsWireguardSupported() (ok bool, err error)
|
||||
}
|
||||
|
||||
type Router interface {
|
||||
RouteList(family int) (routes []netlink.Route, err error)
|
||||
RouteList(family uint8) (routes []netlink.Route, err error)
|
||||
RouteAdd(route netlink.Route) error
|
||||
}
|
||||
|
||||
@@ -23,10 +27,10 @@ type Ruler interface {
|
||||
}
|
||||
|
||||
type Linker interface {
|
||||
LinkAdd(link netlink.Link) (linkIndex int, err error)
|
||||
LinkAdd(link netlink.Link) (linkIndex uint32, err error)
|
||||
LinkList() (links []netlink.Link, err error)
|
||||
LinkByName(name string) (link netlink.Link, err error)
|
||||
LinkSetUp(link netlink.Link) (linkIndex int, err error)
|
||||
LinkSetDown(link netlink.Link) error
|
||||
LinkDel(link netlink.Link) error
|
||||
LinkSetUp(linkIndex uint32) error
|
||||
LinkSetDown(linkIndex uint32) error
|
||||
LinkDel(linkIndex uint32) error
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
netip "net/netip"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
@@ -35,7 +36,7 @@ func (m *MockNetLinker) EXPECT() *MockNetLinkerMockRecorder {
|
||||
}
|
||||
|
||||
// AddrReplace mocks base method.
|
||||
func (m *MockNetLinker) AddrReplace(arg0 netlink.Link, arg1 netlink.Addr) error {
|
||||
func (m *MockNetLinker) AddrReplace(arg0 uint32, arg1 netip.Prefix) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AddrReplace", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
@@ -49,11 +50,12 @@ func (mr *MockNetLinkerMockRecorder) AddrReplace(arg0, arg1 interface{}) *gomock
|
||||
}
|
||||
|
||||
// IsWireguardSupported mocks base method.
|
||||
func (m *MockNetLinker) IsWireguardSupported() bool {
|
||||
func (m *MockNetLinker) IsWireguardSupported() (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "IsWireguardSupported")
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// IsWireguardSupported indicates an expected call of IsWireguardSupported.
|
||||
@@ -63,10 +65,10 @@ func (mr *MockNetLinkerMockRecorder) IsWireguardSupported() *gomock.Call {
|
||||
}
|
||||
|
||||
// LinkAdd mocks base method.
|
||||
func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) (int, error) {
|
||||
func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) (uint32, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "LinkAdd", arg0)
|
||||
ret0, _ := ret[0].(int)
|
||||
ret0, _ := ret[0].(uint32)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
@@ -93,7 +95,7 @@ func (mr *MockNetLinkerMockRecorder) LinkByName(arg0 interface{}) *gomock.Call {
|
||||
}
|
||||
|
||||
// LinkDel mocks base method.
|
||||
func (m *MockNetLinker) LinkDel(arg0 netlink.Link) error {
|
||||
func (m *MockNetLinker) LinkDel(arg0 uint32) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "LinkDel", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
@@ -122,7 +124,7 @@ func (mr *MockNetLinkerMockRecorder) LinkList() *gomock.Call {
|
||||
}
|
||||
|
||||
// LinkSetDown mocks base method.
|
||||
func (m *MockNetLinker) LinkSetDown(arg0 netlink.Link) error {
|
||||
func (m *MockNetLinker) LinkSetDown(arg0 uint32) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "LinkSetDown", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
@@ -136,12 +138,11 @@ func (mr *MockNetLinkerMockRecorder) LinkSetDown(arg0 interface{}) *gomock.Call
|
||||
}
|
||||
|
||||
// LinkSetUp mocks base method.
|
||||
func (m *MockNetLinker) LinkSetUp(arg0 netlink.Link) (int, error) {
|
||||
func (m *MockNetLinker) LinkSetUp(arg0 uint32) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "LinkSetUp", arg0)
|
||||
ret0, _ := ret[0].(int)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// LinkSetUp indicates an expected call of LinkSetUp.
|
||||
@@ -165,7 +166,7 @@ func (mr *MockNetLinkerMockRecorder) RouteAdd(arg0 interface{}) *gomock.Call {
|
||||
}
|
||||
|
||||
// RouteList mocks base method.
|
||||
func (m *MockNetLinker) RouteList(arg0 int) ([]netlink.Route, error) {
|
||||
func (m *MockNetLinker) RouteList(arg0 byte) ([]netlink.Route, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RouteList", arg0)
|
||||
ret0, _ := ret[0].([]netlink.Route)
|
||||
|
||||
@@ -8,11 +8,11 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/netlink"
|
||||
)
|
||||
|
||||
func (w *Wireguard) addRoutes(link netlink.Link, destinations []netip.Prefix,
|
||||
func (w *Wireguard) addRoutes(linkIndex uint32, destinations []netip.Prefix,
|
||||
firewallMark uint32,
|
||||
) (err error) {
|
||||
for _, dst := range destinations {
|
||||
err = w.addRoute(link, dst, firewallMark)
|
||||
err = w.addRoute(linkIndex, dst, firewallMark)
|
||||
if err == nil {
|
||||
continue
|
||||
}
|
||||
@@ -29,7 +29,7 @@ func (w *Wireguard) addRoutes(link netlink.Link, destinations []netip.Prefix,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *Wireguard) addRoute(link netlink.Link, dst netip.Prefix,
|
||||
func (w *Wireguard) addRoute(linkIndex uint32, dst netip.Prefix,
|
||||
firewallMark uint32,
|
||||
) (err error) {
|
||||
family := netlink.FamilyV4
|
||||
@@ -37,17 +37,20 @@ func (w *Wireguard) addRoute(link netlink.Link, dst netip.Prefix,
|
||||
family = netlink.FamilyV6
|
||||
}
|
||||
route := netlink.Route{
|
||||
LinkIndex: link.Index,
|
||||
LinkIndex: linkIndex,
|
||||
Dst: dst,
|
||||
Family: family,
|
||||
Table: int(firewallMark),
|
||||
Table: firewallMark,
|
||||
Type: netlink.RouteTypeUnicast,
|
||||
Scope: netlink.ScopeUniverse,
|
||||
Proto: netlink.ProtoStatic,
|
||||
}
|
||||
|
||||
err = w.netlink.RouteAdd(route)
|
||||
if err != nil {
|
||||
return fmt.Errorf(
|
||||
"adding route for link %s, destination %s and table %d: %w",
|
||||
link.Name, dst, firewallMark, err)
|
||||
"adding route for link with index %d, destination %s and table %d: %w",
|
||||
linkIndex, dst, firewallMark, err)
|
||||
}
|
||||
|
||||
return err
|
||||
|
||||
@@ -23,38 +23,36 @@ func Test_Wireguard_addRoute(t *testing.T) {
|
||||
errDummy := errors.New("dummy")
|
||||
|
||||
testCases := map[string]struct {
|
||||
link netlink.Link
|
||||
dst netip.Prefix
|
||||
expectedRoute netlink.Route
|
||||
routeAddErr error
|
||||
err error
|
||||
}{
|
||||
"success": {
|
||||
link: netlink.Link{
|
||||
Index: linkIndex,
|
||||
},
|
||||
dst: ipPrefix,
|
||||
expectedRoute: netlink.Route{
|
||||
LinkIndex: linkIndex,
|
||||
Dst: ipPrefix,
|
||||
Family: netlink.FamilyV4,
|
||||
Table: firewallMark,
|
||||
Type: netlink.RouteTypeUnicast,
|
||||
Scope: netlink.ScopeUniverse,
|
||||
Proto: netlink.ProtoStatic,
|
||||
},
|
||||
},
|
||||
"route add error": {
|
||||
link: netlink.Link{
|
||||
Name: "a_bridge",
|
||||
Index: linkIndex,
|
||||
},
|
||||
dst: ipPrefix,
|
||||
expectedRoute: netlink.Route{
|
||||
LinkIndex: linkIndex,
|
||||
Dst: ipPrefix,
|
||||
Family: netlink.FamilyV4,
|
||||
Table: firewallMark,
|
||||
Type: netlink.RouteTypeUnicast,
|
||||
Scope: netlink.ScopeUniverse,
|
||||
Proto: netlink.ProtoStatic,
|
||||
},
|
||||
routeAddErr: errDummy,
|
||||
err: errors.New("adding route for link a_bridge, destination 1.2.3.4/32 and table 51820: dummy"), //nolint:lll
|
||||
err: errors.New("adding route for link with index 88, destination 1.2.3.4/32 and table 51820: dummy"), //nolint:lll
|
||||
},
|
||||
}
|
||||
|
||||
@@ -72,7 +70,7 @@ func Test_Wireguard_addRoute(t *testing.T) {
|
||||
RouteAdd(testCase.expectedRoute).
|
||||
Return(testCase.routeAddErr)
|
||||
|
||||
err := wg.addRoute(testCase.link, testCase.dst, firewallMark)
|
||||
err := wg.addRoute(linkIndex, testCase.dst, firewallMark)
|
||||
|
||||
if testCase.err != nil {
|
||||
require.Error(t, err)
|
||||
|
||||
@@ -7,15 +7,17 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/netlink"
|
||||
)
|
||||
|
||||
func (w *Wireguard) addRule(rulePriority int, firewallMark uint32,
|
||||
family int,
|
||||
func (w *Wireguard) addRule(rulePriority, firewallMark uint32,
|
||||
family uint8,
|
||||
) (cleanup func() error, err error) {
|
||||
rule := netlink.NewRule()
|
||||
rule.Invert = true
|
||||
rule.Priority = rulePriority
|
||||
rule.Mark = firewallMark
|
||||
rule.Table = int(firewallMark)
|
||||
rule.Family = family
|
||||
rule := netlink.Rule{
|
||||
Priority: &rulePriority,
|
||||
Family: family,
|
||||
Table: firewallMark,
|
||||
Mark: &firewallMark,
|
||||
Flags: netlink.FlagInvert,
|
||||
Action: netlink.ActionToTable,
|
||||
}
|
||||
if err := w.netlink.RuleAdd(rule); err != nil {
|
||||
if strings.HasSuffix(err.Error(), "file exists") {
|
||||
w.logger.Info("if you are using Kubernetes, this may fix the error below: " +
|
||||
|
||||
@@ -8,15 +8,14 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/netlink"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func Test_Wireguard_addRule(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const rulePriority = 987
|
||||
const firewallMark = 456
|
||||
const family = unix.AF_INET
|
||||
const rulePriority uint32 = 987
|
||||
const firewallMark uint32 = 456
|
||||
const family = netlink.FamilyV4
|
||||
|
||||
errDummy := errors.New("dummy")
|
||||
|
||||
@@ -29,31 +28,34 @@ func Test_Wireguard_addRule(t *testing.T) {
|
||||
}{
|
||||
"success": {
|
||||
expectedRule: netlink.Rule{
|
||||
Invert: true,
|
||||
Priority: rulePriority,
|
||||
Mark: firewallMark,
|
||||
Priority: ptrTo(rulePriority),
|
||||
Mark: ptrTo(firewallMark),
|
||||
Table: firewallMark,
|
||||
Family: family,
|
||||
Flags: netlink.FlagInvert,
|
||||
Action: netlink.ActionToTable,
|
||||
},
|
||||
},
|
||||
"rule add error": {
|
||||
expectedRule: netlink.Rule{
|
||||
Invert: true,
|
||||
Priority: rulePriority,
|
||||
Mark: firewallMark,
|
||||
Priority: ptrTo(rulePriority),
|
||||
Mark: ptrTo(firewallMark),
|
||||
Table: firewallMark,
|
||||
Family: family,
|
||||
Flags: netlink.FlagInvert,
|
||||
Action: netlink.ActionToTable,
|
||||
},
|
||||
ruleAddErr: errDummy,
|
||||
err: errors.New("adding ip rule 987: from all to all table 456: dummy"),
|
||||
},
|
||||
"rule delete error": {
|
||||
expectedRule: netlink.Rule{
|
||||
Invert: true,
|
||||
Priority: rulePriority,
|
||||
Mark: firewallMark,
|
||||
Priority: ptrTo(rulePriority),
|
||||
Mark: ptrTo(firewallMark),
|
||||
Table: firewallMark,
|
||||
Family: family,
|
||||
Flags: netlink.FlagInvert,
|
||||
Action: netlink.ActionToTable,
|
||||
},
|
||||
ruleDelErr: errDummy,
|
||||
cleanupErr: errors.New("deleting rule ip rule 987: from all to all table 456: dummy"),
|
||||
|
||||
+38
-35
@@ -14,6 +14,7 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrDetectKernel = errors.New("cannot detect Kernel support")
|
||||
ErrCreateTun = errors.New("cannot create TUN device")
|
||||
ErrAddLink = errors.New("cannot add Wireguard link")
|
||||
ErrFindLink = errors.New("cannot find link")
|
||||
@@ -32,7 +33,11 @@ var (
|
||||
|
||||
// See https://git.zx2c4.com/wireguard-go/tree/main.go
|
||||
func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<- struct{}) {
|
||||
kernelSupported := w.netlink.IsWireguardSupported()
|
||||
kernelSupported, err := w.netlink.IsWireguardSupported()
|
||||
if err != nil {
|
||||
waitError <- fmt.Errorf("%w: %s", ErrDetectKernel, err)
|
||||
return
|
||||
}
|
||||
|
||||
setupFunction := setupUserSpace
|
||||
switch w.settings.Implementation {
|
||||
@@ -65,14 +70,14 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
|
||||
|
||||
defer closers.cleanup(w.logger)
|
||||
|
||||
link, waitAndCleanup, err := setupFunction(ctx,
|
||||
linkIndex, waitAndCleanup, err := setupFunction(ctx,
|
||||
w.settings.InterfaceName, w.netlink, w.settings.MTU, &closers, w.logger)
|
||||
if err != nil {
|
||||
waitError <- err
|
||||
return
|
||||
}
|
||||
|
||||
err = w.addAddresses(link, w.settings.Addresses)
|
||||
err = w.addAddresses(linkIndex, w.settings.Addresses)
|
||||
if err != nil {
|
||||
waitError <- fmt.Errorf("%w: %s", ErrAddAddress, err)
|
||||
return
|
||||
@@ -85,17 +90,16 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
|
||||
return
|
||||
}
|
||||
|
||||
linkIndex, err := w.netlink.LinkSetUp(link)
|
||||
err = w.netlink.LinkSetUp(linkIndex)
|
||||
if err != nil {
|
||||
waitError <- fmt.Errorf("%w: %s", ErrIfaceUp, err)
|
||||
return
|
||||
}
|
||||
link.Index = linkIndex
|
||||
closers.add("shutting down link", stepFour, func() error {
|
||||
return w.netlink.LinkSetDown(link)
|
||||
return w.netlink.LinkSetDown(linkIndex)
|
||||
})
|
||||
|
||||
err = w.addRoutes(link, w.settings.AllowedIPs, w.settings.FirewallMark)
|
||||
err = w.addRoutes(linkIndex, w.settings.AllowedIPs, w.settings.FirewallMark)
|
||||
if err != nil {
|
||||
waitError <- fmt.Errorf("%w: %s", ErrRouteAdd, err)
|
||||
return
|
||||
@@ -131,39 +135,38 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
|
||||
type waitAndCleanupFunc func() error
|
||||
|
||||
func setupKernelSpace(ctx context.Context,
|
||||
interfaceName string, netLinker NetLinker, mtu uint16,
|
||||
interfaceName string, netLinker NetLinker, mtu uint32,
|
||||
closers *closers, logger Logger) (
|
||||
link netlink.Link, waitAndCleanup waitAndCleanupFunc, err error,
|
||||
linkIndex uint32, waitAndCleanup waitAndCleanupFunc, err error,
|
||||
) {
|
||||
link = netlink.Link{
|
||||
Type: "wireguard",
|
||||
Name: interfaceName,
|
||||
MTU: mtu,
|
||||
}
|
||||
links, err := netLinker.LinkList()
|
||||
if err != nil {
|
||||
return link, nil, fmt.Errorf("listing links: %w", err)
|
||||
return 0, nil, fmt.Errorf("listing links: %w", err)
|
||||
}
|
||||
|
||||
// Cleanup any previous Wireguard interface with the same name
|
||||
// See https://github.com/qdm12/gluetun/issues/1669
|
||||
for _, link := range links {
|
||||
if link.Type == "wireguard" && link.Name == interfaceName {
|
||||
err = netLinker.LinkDel(link)
|
||||
if link.VirtualType == "wireguard" && link.Name == interfaceName {
|
||||
err = netLinker.LinkDel(link.Index)
|
||||
if err != nil {
|
||||
return link, nil, fmt.Errorf("deleting previous Wireguard link %s: %w",
|
||||
return 0, nil, fmt.Errorf("deleting previous Wireguard link %s: %w",
|
||||
interfaceName, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
linkIndex, err := netLinker.LinkAdd(link)
|
||||
if err != nil {
|
||||
return link, nil, fmt.Errorf("%w: %s", ErrAddLink, err)
|
||||
link := netlink.Link{
|
||||
VirtualType: "wireguard",
|
||||
Name: interfaceName,
|
||||
MTU: mtu,
|
||||
}
|
||||
linkIndex, err = netLinker.LinkAdd(link)
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("%w: %s", ErrAddLink, err)
|
||||
}
|
||||
link.Index = linkIndex
|
||||
closers.add("deleting link", stepFive, func() error {
|
||||
return netLinker.LinkDel(link)
|
||||
return netLinker.LinkDel(linkIndex)
|
||||
})
|
||||
|
||||
waitAndCleanup = func() error {
|
||||
@@ -172,35 +175,35 @@ func setupKernelSpace(ctx context.Context,
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
return link, waitAndCleanup, nil
|
||||
return linkIndex, waitAndCleanup, nil
|
||||
}
|
||||
|
||||
func setupUserSpace(ctx context.Context,
|
||||
interfaceName string, netLinker NetLinker, mtu uint16,
|
||||
interfaceName string, netLinker NetLinker, mtu uint32,
|
||||
closers *closers, logger Logger) (
|
||||
link netlink.Link, waitAndCleanup waitAndCleanupFunc, err error,
|
||||
linkIndex uint32, waitAndCleanup waitAndCleanupFunc, err error,
|
||||
) {
|
||||
tun, err := tun.CreateTUN(interfaceName, int(mtu))
|
||||
if err != nil {
|
||||
return link, nil, fmt.Errorf("%w: %s", ErrCreateTun, err)
|
||||
return 0, nil, fmt.Errorf("%w: %s", ErrCreateTun, err)
|
||||
}
|
||||
|
||||
closers.add("closing TUN device", stepSeven, tun.Close)
|
||||
|
||||
tunName, err := tun.Name()
|
||||
if err != nil {
|
||||
return link, nil, fmt.Errorf("%w: cannot get TUN name: %s", ErrCreateTun, err)
|
||||
return 0, nil, fmt.Errorf("%w: cannot get TUN name: %s", ErrCreateTun, err)
|
||||
} else if tunName != interfaceName {
|
||||
return link, nil, fmt.Errorf("%w: names don't match: expected %q and got %q",
|
||||
return 0, nil, fmt.Errorf("%w: names don't match: expected %q and got %q",
|
||||
ErrCreateTun, interfaceName, tunName)
|
||||
}
|
||||
|
||||
link, err = netLinker.LinkByName(interfaceName)
|
||||
link, err := netLinker.LinkByName(interfaceName)
|
||||
if err != nil {
|
||||
return link, nil, fmt.Errorf("%w: %s: %s", ErrFindLink, interfaceName, err)
|
||||
return 0, nil, fmt.Errorf("%w: %s: %s", ErrFindLink, interfaceName, err)
|
||||
}
|
||||
closers.add("deleting link", stepFive, func() error {
|
||||
return netLinker.LinkDel(link)
|
||||
return netLinker.LinkDel(link.Index)
|
||||
})
|
||||
|
||||
bind := conn.NewDefaultBind()
|
||||
@@ -217,14 +220,14 @@ func setupUserSpace(ctx context.Context,
|
||||
|
||||
uapiFile, err := uapiOpen(interfaceName)
|
||||
if err != nil {
|
||||
return link, nil, fmt.Errorf("%w: %s", ErrUAPISocketOpening, err)
|
||||
return 0, nil, fmt.Errorf("%w: %s", ErrUAPISocketOpening, err)
|
||||
}
|
||||
|
||||
closers.add("closing UAPI file", stepThree, uapiFile.Close)
|
||||
|
||||
uapiListener, err := uapiListen(interfaceName, uapiFile)
|
||||
if err != nil {
|
||||
return link, nil, fmt.Errorf("%w: %s", ErrUAPIListen, err)
|
||||
return 0, nil, fmt.Errorf("%w: %s", ErrUAPIListen, err)
|
||||
}
|
||||
|
||||
closers.add("closing UAPI listener", stepTwo, uapiListener.Close)
|
||||
@@ -249,7 +252,7 @@ func setupUserSpace(ctx context.Context,
|
||||
return err
|
||||
}
|
||||
|
||||
return link, waitAndCleanup, nil
|
||||
return link.Index, waitAndCleanup, nil
|
||||
}
|
||||
|
||||
func acceptAndHandle(uapi net.Listener, device *device.Device,
|
||||
|
||||
@@ -38,10 +38,10 @@ type Settings struct {
|
||||
FirewallMark uint32
|
||||
// Maximum Transmission Unit (MTU) setting for the network interface.
|
||||
// It defaults to device.DefaultMTU from wireguard-go which is 1420
|
||||
MTU uint16
|
||||
MTU uint32
|
||||
// RulePriority is the priority for the rule created with the
|
||||
// FirewallMark.
|
||||
RulePriority int
|
||||
RulePriority uint32
|
||||
// IPv6 can bet set to true if IPv6 should be handled.
|
||||
// It defaults to false if left unset.
|
||||
IPv6 *bool
|
||||
|
||||
Reference in New Issue
Block a user