mirror of
https://github.com/qdm12/gluetun.git
synced 2026-05-09 20:29:23 +02:00
chore!(amneziawg): refactor to be separate from wireguard
- amneziawg is now a VPN protocol and no longer a Wireguard implementation - Use it with VPN_TYPE=amneziawg - document AMNEZIAWG_* options in Dockerfile - document amneziawg support in readme - separate amneziawg settings and code from wireguard - re-use code from wireguard whenever possible
This commit is contained in:
@@ -5,15 +5,16 @@ import (
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
func (w *Wireguard) addAddresses(linkIndex uint32,
|
||||
addresses []netip.Prefix,
|
||||
func AddAddresses(linkIndex uint32,
|
||||
addresses []netip.Prefix, ipv6 bool,
|
||||
netlink NetLinker,
|
||||
) (err error) {
|
||||
for _, address := range addresses {
|
||||
if !*w.settings.IPv6 && address.Addr().Is6() {
|
||||
if !ipv6 && address.Addr().Is6() {
|
||||
continue
|
||||
}
|
||||
|
||||
err = w.netlink.AddrReplace(linkIndex, address)
|
||||
err = netlink.AddrReplace(linkIndex, address)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: when adding address %s to link with index %d",
|
||||
err, address, linkIndex)
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_Wireguard_addAddresses(t *testing.T) {
|
||||
func Test_AddAddresses(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ipNetOne := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 32)
|
||||
@@ -19,15 +19,17 @@ func Test_Wireguard_addAddresses(t *testing.T) {
|
||||
errDummy := errors.New("dummy")
|
||||
|
||||
testCases := map[string]struct {
|
||||
linkIndex uint32
|
||||
addrs []netip.Prefix
|
||||
wgBuilder func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard
|
||||
err error
|
||||
linkIndex uint32
|
||||
addrs []netip.Prefix
|
||||
ipv6 bool
|
||||
netlinkBuilder func(ctrl *gomock.Controller, linkIndex uint32) *MockNetLinker
|
||||
err error
|
||||
}{
|
||||
"success": {
|
||||
linkIndex: 1,
|
||||
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
|
||||
wgBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard {
|
||||
ipv6: true,
|
||||
netlinkBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *MockNetLinker {
|
||||
netLinker := NewMockNetLinker(ctrl)
|
||||
firstCall := netLinker.EXPECT().
|
||||
AddrReplace(linkIndex, ipNetOne).
|
||||
@@ -35,35 +37,27 @@ func Test_Wireguard_addAddresses(t *testing.T) {
|
||||
netLinker.EXPECT().
|
||||
AddrReplace(linkIndex, ipNetTwo).
|
||||
Return(nil).After(firstCall)
|
||||
return &Wireguard{
|
||||
netlink: netLinker,
|
||||
settings: Settings{
|
||||
IPv6: ptrTo(true),
|
||||
},
|
||||
}
|
||||
return netLinker
|
||||
},
|
||||
},
|
||||
"first add error": {
|
||||
linkIndex: 1,
|
||||
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
|
||||
wgBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard {
|
||||
ipv6: true,
|
||||
netlinkBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *MockNetLinker {
|
||||
netLinker := NewMockNetLinker(ctrl)
|
||||
netLinker.EXPECT().
|
||||
AddrReplace(linkIndex, ipNetOne).
|
||||
Return(errDummy)
|
||||
return &Wireguard{
|
||||
netlink: netLinker,
|
||||
settings: Settings{
|
||||
IPv6: ptrTo(true),
|
||||
},
|
||||
}
|
||||
return netLinker
|
||||
},
|
||||
err: errors.New("dummy: when adding address 1.2.3.4/32 to link with index 1"),
|
||||
},
|
||||
"second add error": {
|
||||
linkIndex: 1,
|
||||
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
|
||||
wgBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard {
|
||||
ipv6: true,
|
||||
netlinkBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *MockNetLinker {
|
||||
netLinker := NewMockNetLinker(ctrl)
|
||||
firstCall := netLinker.EXPECT().
|
||||
AddrReplace(linkIndex, ipNetOne).
|
||||
@@ -71,23 +65,14 @@ func Test_Wireguard_addAddresses(t *testing.T) {
|
||||
netLinker.EXPECT().
|
||||
AddrReplace(linkIndex, ipNetTwo).
|
||||
Return(errDummy).After(firstCall)
|
||||
return &Wireguard{
|
||||
netlink: netLinker,
|
||||
settings: Settings{
|
||||
IPv6: ptrTo(true),
|
||||
},
|
||||
}
|
||||
return netLinker
|
||||
},
|
||||
err: errors.New("dummy: when adding address ::1234/64 to link with index 1"),
|
||||
},
|
||||
"ignore IPv6": {
|
||||
addrs: []netip.Prefix{ipNetTwo},
|
||||
wgBuilder: func(_ *gomock.Controller, _ uint32) *Wireguard {
|
||||
return &Wireguard{
|
||||
settings: Settings{
|
||||
IPv6: ptrTo(false),
|
||||
},
|
||||
}
|
||||
netlinkBuilder: func(_ *gomock.Controller, _ uint32) *MockNetLinker {
|
||||
return NewMockNetLinker(nil)
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -97,9 +82,9 @@ func Test_Wireguard_addAddresses(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
wg := testCase.wgBuilder(ctrl, testCase.linkIndex)
|
||||
netlink := testCase.netlinkBuilder(ctrl, testCase.linkIndex)
|
||||
|
||||
err := wg.addAddresses(testCase.linkIndex, testCase.addrs)
|
||||
err := AddAddresses(testCase.linkIndex, testCase.addrs, testCase.ipv6, netlink)
|
||||
|
||||
if testCase.err != nil {
|
||||
require.Error(t, err)
|
||||
|
||||
@@ -1,58 +0,0 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type AmneziaSettings struct {
|
||||
JunkPacketCount uint16
|
||||
JunkPacketMin uint16
|
||||
JunkPacketMax uint16
|
||||
PaddingS1 uint16
|
||||
PaddingS2 uint16
|
||||
PaddingS3 uint16
|
||||
PaddingS4 uint16
|
||||
HeaderH1 string
|
||||
HeaderH2 string
|
||||
HeaderH3 string
|
||||
HeaderH4 string
|
||||
InitPacketI1 string
|
||||
InitPacketI2 string
|
||||
InitPacketI3 string
|
||||
InitPacketI4 string
|
||||
InitPacketI5 string
|
||||
}
|
||||
|
||||
func (s AmneziaSettings) uapiConfig() string {
|
||||
uintFields := map[string]uint16{
|
||||
"jc": s.JunkPacketCount,
|
||||
"jmin": s.JunkPacketMin,
|
||||
"jmax": s.JunkPacketMax,
|
||||
"s1": s.PaddingS1,
|
||||
"s2": s.PaddingS2,
|
||||
"s3": s.PaddingS3,
|
||||
"s4": s.PaddingS4,
|
||||
}
|
||||
stringFields := map[string]string{
|
||||
"h1": s.HeaderH1,
|
||||
"h2": s.HeaderH2,
|
||||
"h3": s.HeaderH3,
|
||||
"h4": s.HeaderH4,
|
||||
"i1": s.InitPacketI1,
|
||||
"i2": s.InitPacketI2,
|
||||
"i3": s.InitPacketI3,
|
||||
"i4": s.InitPacketI4,
|
||||
"i5": s.InitPacketI5,
|
||||
}
|
||||
lines := make([]string, 0, len(uintFields)+len(stringFields))
|
||||
|
||||
for key, val := range uintFields {
|
||||
lines = append(lines, fmt.Sprintf("%s=%d", key, val))
|
||||
}
|
||||
|
||||
for key, val := range stringFields {
|
||||
lines = append(lines, key+"="+val)
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
@@ -1,63 +0,0 @@
|
||||
package wireguard
|
||||
|
||||
import "sort"
|
||||
|
||||
type closer struct {
|
||||
operation string
|
||||
step step
|
||||
close func() error
|
||||
closed bool
|
||||
}
|
||||
|
||||
type closers []closer
|
||||
|
||||
func (c *closers) add(operation string, step step,
|
||||
closeFunc func() error,
|
||||
) {
|
||||
closer := closer{
|
||||
operation: operation,
|
||||
step: step,
|
||||
close: closeFunc,
|
||||
}
|
||||
*c = append(*c, closer)
|
||||
}
|
||||
|
||||
func (c *closers) cleanup(logger Logger) {
|
||||
closers := *c
|
||||
|
||||
sort.Slice(closers, func(i, j int) bool {
|
||||
return closers[i].step < closers[j].step
|
||||
})
|
||||
|
||||
for i, closer := range closers {
|
||||
if closer.closed {
|
||||
continue
|
||||
}
|
||||
closers[i].closed = true
|
||||
logger.Debug(closer.operation + "...")
|
||||
err := closer.close()
|
||||
if err != nil {
|
||||
logger.Error("failed " + closer.operation + ": " + err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type step int
|
||||
|
||||
const (
|
||||
// stepOne closes the wireguard controller client,
|
||||
// and removes the IP rule.
|
||||
stepOne step = iota
|
||||
// stepTwo closes the UAPI listener.
|
||||
stepTwo
|
||||
// stepThree closes the UAPI file.
|
||||
stepThree
|
||||
// stepFour shuts down the Wireguard link.
|
||||
stepFour
|
||||
// stepFive removes the Wireguard link.
|
||||
stepFive
|
||||
// stepSix closes the Wireguard device.
|
||||
stepSix
|
||||
// stepSeven closes the bind connection and the TUN device file.
|
||||
stepSeven
|
||||
)
|
||||
@@ -1,57 +0,0 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_closers(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
var ACloseCalled, BCloseCalled, CCloseCalled bool
|
||||
var (
|
||||
AErr error
|
||||
BErr = errors.New("B failed")
|
||||
CErr = errors.New("C failed")
|
||||
)
|
||||
|
||||
var closers closers
|
||||
closers.add("closing A", stepFive, func() error {
|
||||
ACloseCalled = true
|
||||
return AErr
|
||||
})
|
||||
|
||||
closers.add("closing B", stepThree, func() error {
|
||||
BCloseCalled = true
|
||||
return BErr
|
||||
})
|
||||
|
||||
closers.add("closing C", stepTwo, func() error {
|
||||
CCloseCalled = true
|
||||
return CErr
|
||||
})
|
||||
|
||||
logger := NewMockLogger(ctrl)
|
||||
prevCall := logger.EXPECT().Debug("closing C...")
|
||||
prevCall = logger.EXPECT().Error("failed closing C: C failed").After(prevCall)
|
||||
prevCall = logger.EXPECT().Debug("closing B...").After(prevCall)
|
||||
prevCall = logger.EXPECT().Error("failed closing B: B failed").After(prevCall)
|
||||
logger.EXPECT().Debug("closing A...").After(prevCall)
|
||||
|
||||
closers.cleanup(logger)
|
||||
|
||||
closers.cleanup(logger) // run twice should not close already closed
|
||||
|
||||
for _, closer := range closers {
|
||||
assert.True(t, closer.closed)
|
||||
}
|
||||
|
||||
assert.True(t, ACloseCalled)
|
||||
assert.True(t, BCloseCalled)
|
||||
assert.True(t, CCloseCalled)
|
||||
}
|
||||
@@ -1,28 +0,0 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
type tunDevice interface {
|
||||
Close() error
|
||||
Name() (string, error)
|
||||
}
|
||||
|
||||
type bind interface {
|
||||
Close() error
|
||||
}
|
||||
|
||||
type userspaceDevice interface {
|
||||
Close()
|
||||
Wait() chan struct{}
|
||||
IpcHandle(net.Conn)
|
||||
IpcSet(string) error
|
||||
}
|
||||
|
||||
type userSpaceBackend struct {
|
||||
createTun func(string, int) (tunDevice, error)
|
||||
createBind func() bind
|
||||
createDevice func(tunDevice, bind, Logger) userspaceDevice
|
||||
preStart func(userspaceDevice, Settings) error
|
||||
}
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
func configureDevice(client *wgctrl.Client, settings Settings) (err error) {
|
||||
func ConfigureDevice(client *wgctrl.Client, settings Settings) (err error) {
|
||||
deviceConfig, err := makeDeviceConfig(settings)
|
||||
if err != nil {
|
||||
return fmt.Errorf("making device configuration: %w", err)
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
)
|
||||
|
||||
//go:generate mockgen -destination=log_mock_test.go -package wireguard . Logger
|
||||
|
||||
type Logger interface {
|
||||
@@ -7,5 +11,16 @@ type Logger interface {
|
||||
Debugf(format string, args ...interface{})
|
||||
Info(s string)
|
||||
Error(s string)
|
||||
Errorf(format string, args ...interface{})
|
||||
Erroer
|
||||
}
|
||||
|
||||
type Erroer interface {
|
||||
Errorf(format string, args ...any)
|
||||
}
|
||||
|
||||
func makeDeviceLogger(logger Logger) (deviceLogger *device.Logger) {
|
||||
return &device.Logger{
|
||||
Verbosef: logger.Debugf,
|
||||
Errorf: logger.Errorf,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
func Test_makeDeviceLogger(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
logger := NewMockLogger(ctrl)
|
||||
|
||||
deviceLogger := makeDeviceLogger(logger)
|
||||
|
||||
logger.EXPECT().Debugf("test %d", 1)
|
||||
deviceLogger.Verbosef("test %d", 1)
|
||||
|
||||
logger.EXPECT().Errorf("test %d", 2)
|
||||
deviceLogger.Errorf("test %d", 2)
|
||||
}
|
||||
@@ -21,7 +21,7 @@ func (n noopDebugLogger) Error(_ string) {}
|
||||
func (n noopDebugLogger) Errorf(_ string, _ ...any) {}
|
||||
func (n noopDebugLogger) Patch(_ ...log.Option) {}
|
||||
|
||||
func Test_netlink_Wireguard_addAddresses(t *testing.T) {
|
||||
func Test_AddAddresses_Integration(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
netlinker := netlink.New(&noopDebugLogger{})
|
||||
@@ -55,7 +55,7 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) {
|
||||
|
||||
const addIterations = 2 // initial + replace
|
||||
for range addIterations {
|
||||
err = wg.addAddresses(link.Index, addresses)
|
||||
err = AddAddresses(link.Index, addresses, *wg.settings.IPv6, wg.netlink)
|
||||
require.NoError(t, err)
|
||||
|
||||
ipPrefixes, err := netlinker.AddrList(link.Index, netlink.FamilyAll)
|
||||
@@ -67,22 +67,19 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_netlink_Wireguard_addRule(t *testing.T) {
|
||||
func Test_AddRule_Integration(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
netlinker := netlink.New(&noopDebugLogger{})
|
||||
wg := &Wireguard{
|
||||
netlink: netlinker,
|
||||
logger: &noopDebugLogger{},
|
||||
}
|
||||
logger := &noopDebugLogger{}
|
||||
netlinker := netlink.New(logger)
|
||||
|
||||
// Unique combination for this test
|
||||
const rulePriority uint32 = 10000
|
||||
const firewallMark uint32 = 12345
|
||||
const family = netlink.FamilyV4
|
||||
|
||||
cleanup, err := wg.addRule(rulePriority,
|
||||
firewallMark, family)
|
||||
cleanup, err := AddRule(rulePriority,
|
||||
firewallMark, family, netlinker, logger)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
err := cleanup()
|
||||
@@ -110,8 +107,8 @@ func Test_netlink_Wireguard_addRule(t *testing.T) {
|
||||
require.True(t, ruleFound)
|
||||
|
||||
// Existing rule cannot be added
|
||||
nilCleanup, err := wg.addRule(rulePriority,
|
||||
firewallMark, family)
|
||||
nilCleanup, err := AddRule(rulePriority,
|
||||
firewallMark, family, netlinker, logger)
|
||||
if nilCleanup != nil {
|
||||
_ = nilCleanup() // in case it succeeds
|
||||
}
|
||||
|
||||
@@ -8,17 +8,17 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/netlink"
|
||||
)
|
||||
|
||||
func (w *Wireguard) addRoutes(linkIndex uint32, destinations []netip.Prefix,
|
||||
firewallMark uint32,
|
||||
func AddRoutes(linkIndex uint32, destinations []netip.Prefix,
|
||||
firewallMark uint32, netlinker NetLinker, logger Erroer,
|
||||
) (err error) {
|
||||
for _, dst := range destinations {
|
||||
err = w.addRoute(linkIndex, dst, firewallMark)
|
||||
err = addRoute(linkIndex, dst, firewallMark, netlinker)
|
||||
if err == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if dst.Addr().Is6() && strings.Contains(err.Error(), "permission denied") {
|
||||
w.logger.Errorf("cannot add route for IPv6 due to a permission denial. "+
|
||||
logger.Errorf("cannot add route for IPv6 due to a permission denial. "+
|
||||
"Ignoring and continuing execution; "+
|
||||
"Please report to https://github.com/qdm12/gluetun/issues/998 if you find a fix. "+
|
||||
"Full error string: %s", err)
|
||||
@@ -29,8 +29,8 @@ func (w *Wireguard) addRoutes(linkIndex uint32, destinations []netip.Prefix,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *Wireguard) addRoute(linkIndex uint32, dst netip.Prefix,
|
||||
firewallMark uint32,
|
||||
func addRoute(linkIndex uint32, dst netip.Prefix,
|
||||
firewallMark uint32, netlinker NetLinker,
|
||||
) (err error) {
|
||||
family := netlink.FamilyV4
|
||||
if dst.Addr().Is6() {
|
||||
@@ -46,7 +46,7 @@ func (w *Wireguard) addRoute(linkIndex uint32, dst netip.Prefix,
|
||||
Proto: netlink.ProtoStatic,
|
||||
}
|
||||
|
||||
err = w.netlink.RouteAdd(route)
|
||||
err = netlinker.RouteAdd(route)
|
||||
if err != nil {
|
||||
return fmt.Errorf(
|
||||
"adding route for link with index %d, destination %s and table %d: %w",
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_Wireguard_addRoute(t *testing.T) {
|
||||
func Test_addRoute(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const linkIndex = 88
|
||||
@@ -62,15 +62,11 @@ func Test_Wireguard_addRoute(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
netLinker := NewMockNetLinker(ctrl)
|
||||
wg := Wireguard{
|
||||
netlink: netLinker,
|
||||
}
|
||||
|
||||
netLinker.EXPECT().
|
||||
RouteAdd(testCase.expectedRoute).
|
||||
Return(testCase.routeAddErr)
|
||||
|
||||
err := wg.addRoute(linkIndex, testCase.dst, firewallMark)
|
||||
err := addRoute(linkIndex, testCase.dst, firewallMark, netLinker)
|
||||
|
||||
if testCase.err != nil {
|
||||
require.Error(t, err)
|
||||
|
||||
@@ -7,8 +7,8 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/netlink"
|
||||
)
|
||||
|
||||
func (w *Wireguard) addRule(rulePriority, firewallMark uint32,
|
||||
family uint8,
|
||||
func AddRule(rulePriority, firewallMark uint32, family uint8,
|
||||
netlinker NetLinker, logger Logger,
|
||||
) (cleanup func() error, err error) {
|
||||
rule := netlink.Rule{
|
||||
Priority: &rulePriority,
|
||||
@@ -18,16 +18,16 @@ func (w *Wireguard) addRule(rulePriority, firewallMark uint32,
|
||||
Flags: netlink.FlagInvert,
|
||||
Action: netlink.ActionToTable,
|
||||
}
|
||||
if err := w.netlink.RuleAdd(rule); err != nil {
|
||||
if err := netlinker.RuleAdd(rule); err != nil {
|
||||
if strings.HasSuffix(err.Error(), "file exists") {
|
||||
w.logger.Info("if you are using Kubernetes, this may fix the error below: " +
|
||||
logger.Info("if you are using Kubernetes, this may fix the error below: " +
|
||||
"https://github.com/qdm12/gluetun-wiki/blob/main/setup/advanced/kubernetes.md#adding-ipv6-rule--file-exists")
|
||||
}
|
||||
return nil, fmt.Errorf("adding %s: %w", rule, err)
|
||||
}
|
||||
|
||||
cleanup = func() error {
|
||||
err := w.netlink.RuleDel(rule)
|
||||
err := netlinker.RuleDel(rule)
|
||||
if err != nil {
|
||||
return fmt.Errorf("deleting rule %s: %w", rule, err)
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_Wireguard_addRule(t *testing.T) {
|
||||
func Test_AddRule(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const rulePriority uint32 = 987
|
||||
@@ -68,13 +68,11 @@ func Test_Wireguard_addRule(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
netLinker := NewMockNetLinker(ctrl)
|
||||
wg := Wireguard{
|
||||
netlink: netLinker,
|
||||
}
|
||||
|
||||
netLinker.EXPECT().RuleAdd(testCase.expectedRule).
|
||||
Return(testCase.ruleAddErr)
|
||||
cleanup, err := wg.addRule(rulePriority, firewallMark, family)
|
||||
cleanup, err := AddRule(rulePriority, firewallMark, family,
|
||||
netLinker, nil)
|
||||
if testCase.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, testCase.err.Error(), err.Error())
|
||||
|
||||
+85
-89
@@ -6,39 +6,33 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/cleanup"
|
||||
"github.com/qdm12/gluetun/internal/netlink"
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrDetectKernel = errors.New("cannot detect Kernel support")
|
||||
ErrCreateTun = errors.New("cannot create TUN device")
|
||||
ErrAddLink = errors.New("cannot add Wireguard link")
|
||||
ErrFindLink = errors.New("cannot find link")
|
||||
ErrFindDevice = errors.New("cannot find Wireguard device")
|
||||
ErrUAPISocketOpening = errors.New("cannot open UAPI socket")
|
||||
ErrWgctrlOpen = errors.New("cannot open wgctrl")
|
||||
ErrUAPIListen = errors.New("cannot listen on UAPI socket")
|
||||
ErrAddAddress = errors.New("cannot add address to wireguard interface")
|
||||
ErrConfigure = errors.New("cannot configure wireguard interface")
|
||||
ErrDeviceInfo = errors.New("cannot get wireguard device information")
|
||||
ErrIfaceUp = errors.New("cannot set the interface to UP")
|
||||
ErrRouteAdd = errors.New("cannot add route for interface")
|
||||
ErrDeviceWaited = errors.New("device waited for")
|
||||
ErrKernelSupport = errors.New("kernel does not support Wireguard")
|
||||
ErrAmneziaConfigure = errors.New("cannot configure AmneziaWG")
|
||||
errKernelSupport = errors.New("kernel does not support Wireguard")
|
||||
errTunNameMismatch = errors.New("TUN device name is mismatching")
|
||||
errDeviceWaited = errors.New("device waited for")
|
||||
)
|
||||
|
||||
// Run runs the wireguard interface and waits until the context is done, then it cleans up the
|
||||
// interface and returns any error that occurred during setup or waiting. It sends an error to
|
||||
// waitError if any error occurs during setup or waiting, otherwise it sends nil when the context
|
||||
// is done. It sends a signal to ready when the setup is complete and the interface is ready to use.
|
||||
// See https://git.zx2c4.com/wireguard-go/tree/main.go
|
||||
func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<- struct{}) {
|
||||
kernelSupported, err := w.netlink.IsWireguardSupported()
|
||||
if err != nil {
|
||||
waitError <- fmt.Errorf("%w: %s", ErrDetectKernel, err)
|
||||
waitError <- fmt.Errorf("detecting wireguard kernel support: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
userspaceBackend := defaultUserSpaceBackend()
|
||||
setupFunction := setupUserSpaceCommon
|
||||
setupFunction := setupUserSpace
|
||||
switch w.settings.Implementation {
|
||||
case "auto": //nolint:goconst
|
||||
if !kernelSupported {
|
||||
@@ -50,95 +44,105 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
|
||||
case "userspace":
|
||||
case "kernelspace":
|
||||
if !kernelSupported {
|
||||
waitError <- fmt.Errorf("%w", ErrKernelSupport)
|
||||
waitError <- fmt.Errorf("%w", errKernelSupport)
|
||||
return
|
||||
}
|
||||
setupFunction = setupKernelSpace
|
||||
case "amneziawg":
|
||||
userspaceBackend = amneziaUserSpaceBackend()
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown implementation %q", w.settings.Implementation))
|
||||
}
|
||||
|
||||
setup := func(ctx context.Context, cleanups *cleanup.Cleanups) (
|
||||
linkIndex uint32, waitAndCleanup func() error, err error,
|
||||
) {
|
||||
return setupFunction(ctx,
|
||||
w.settings.InterfaceName, w.netlink, w.settings.MTU, cleanups, w.logger)
|
||||
}
|
||||
|
||||
Run(ctx, waitError, ready, setup, w.settings, w.netlink, w.logger)
|
||||
}
|
||||
|
||||
func Run(ctx context.Context, waitError chan<- error, ready chan<- struct{},
|
||||
setup func(ctx context.Context, cleanups *cleanup.Cleanups) (
|
||||
linkIndex uint32, waitAndCleanup func() error, err error),
|
||||
settings Settings, netlinker NetLinker, logger Logger,
|
||||
) {
|
||||
client, err := wgctrl.New()
|
||||
if err != nil {
|
||||
waitError <- fmt.Errorf("%w: %s", ErrWgctrlOpen, err)
|
||||
waitError <- fmt.Errorf("opening wgctrl: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
var closers closers
|
||||
closers.add("closing controller client", stepOne, client.Close)
|
||||
var cleanups cleanup.Cleanups
|
||||
cleanups.Add("closing controller client", 1, client.Close)
|
||||
|
||||
defer closers.cleanup(w.logger)
|
||||
defer cleanups.Cleanup(logger)
|
||||
|
||||
linkIndex, waitAndCleanup, err := setupFunction(ctx,
|
||||
w.settings.InterfaceName, w.netlink, w.settings.MTU, &closers, w.logger, w.settings, userspaceBackend)
|
||||
linkIndex, waitAndCleanup, err := setup(ctx, &cleanups)
|
||||
if err != nil {
|
||||
waitError <- err
|
||||
return
|
||||
}
|
||||
|
||||
err = w.addAddresses(linkIndex, w.settings.Addresses)
|
||||
err = AddAddresses(linkIndex, settings.Addresses, *settings.IPv6, netlinker)
|
||||
if err != nil {
|
||||
waitError <- fmt.Errorf("%w: %s", ErrAddAddress, err)
|
||||
waitError <- fmt.Errorf("adding addresses to interface: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
w.logger.Info("Connecting to " + w.settings.Endpoint.String())
|
||||
err = configureDevice(client, w.settings)
|
||||
logger.Info("Connecting to " + settings.Endpoint.String())
|
||||
err = ConfigureDevice(client, settings)
|
||||
if err != nil {
|
||||
waitError <- fmt.Errorf("%w: %s", ErrConfigure, err)
|
||||
waitError <- fmt.Errorf("configuring interface: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = w.netlink.LinkSetUp(linkIndex)
|
||||
err = netlinker.LinkSetUp(linkIndex)
|
||||
if err != nil {
|
||||
waitError <- fmt.Errorf("%w: %s", ErrIfaceUp, err)
|
||||
waitError <- fmt.Errorf("setting the interface UP: %w", err)
|
||||
return
|
||||
}
|
||||
closers.add("shutting down link", stepFour, func() error {
|
||||
return w.netlink.LinkSetDown(linkIndex)
|
||||
cleanups.Add("shutting down link", 4, func() error {
|
||||
return netlinker.LinkSetDown(linkIndex)
|
||||
})
|
||||
|
||||
err = w.addRoutes(linkIndex, w.settings.AllowedIPs, w.settings.FirewallMark)
|
||||
err = AddRoutes(linkIndex, settings.AllowedIPs, settings.FirewallMark,
|
||||
netlinker, logger)
|
||||
if err != nil {
|
||||
waitError <- fmt.Errorf("%w: %s", ErrRouteAdd, err)
|
||||
waitError <- fmt.Errorf("adding routes for interface: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
if *w.settings.IPv6 {
|
||||
if *settings.IPv6 {
|
||||
// requires net.ipv6.conf.all.disable_ipv6=0
|
||||
ruleCleanup6, err := w.addRule(w.settings.RulePriority,
|
||||
w.settings.FirewallMark, netlink.FamilyV6)
|
||||
ruleCleanup6, err := AddRule(settings.RulePriority,
|
||||
settings.FirewallMark, netlink.FamilyV6,
|
||||
netlinker, logger)
|
||||
if err != nil {
|
||||
waitError <- fmt.Errorf("adding IPv6 rule: %w", err)
|
||||
return
|
||||
}
|
||||
closers.add("removing IPv6 rule", stepOne, ruleCleanup6)
|
||||
cleanups.Add("removing IPv6 rule", 1, ruleCleanup6)
|
||||
}
|
||||
|
||||
ruleCleanup, err := w.addRule(w.settings.RulePriority,
|
||||
w.settings.FirewallMark, netlink.FamilyV4)
|
||||
ruleCleanup, err := AddRule(settings.RulePriority,
|
||||
settings.FirewallMark, netlink.FamilyV4,
|
||||
netlinker, logger)
|
||||
if err != nil {
|
||||
waitError <- fmt.Errorf("adding IPv4 rule: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
closers.add("removing IPv4 rule", stepOne, ruleCleanup)
|
||||
w.logger.Info("Wireguard setup is complete. " +
|
||||
"Note Wireguard is a silent protocol and it may or may not work, without giving any error message. " +
|
||||
"Typically i/o timeout errors indicate the Wireguard connection is not working.")
|
||||
cleanups.Add("removing IPv4 rule", 1, ruleCleanup)
|
||||
ready <- struct{}{}
|
||||
|
||||
waitError <- waitAndCleanup()
|
||||
}
|
||||
|
||||
type waitAndCleanupFunc func() error
|
||||
|
||||
func setupKernelSpace(ctx context.Context,
|
||||
interfaceName string, netLinker NetLinker, mtu uint32,
|
||||
closers *closers, logger Logger, _ Settings, _ userSpaceBackend) (
|
||||
linkIndex uint32, waitAndCleanup waitAndCleanupFunc, err error,
|
||||
cleanups *cleanup.Cleanups, logger Logger) (
|
||||
linkIndex uint32, waitAndCleanup func() error, err error,
|
||||
) {
|
||||
links, err := netLinker.LinkList()
|
||||
if err != nil {
|
||||
@@ -164,82 +168,74 @@ func setupKernelSpace(ctx context.Context,
|
||||
}
|
||||
linkIndex, err = netLinker.LinkAdd(link)
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("%w: %s", ErrAddLink, err)
|
||||
return 0, nil, fmt.Errorf("adding link: %w", err)
|
||||
}
|
||||
closers.add("deleting link", stepFive, func() error {
|
||||
cleanups.Add("deleting link", 5, func() error {
|
||||
return netLinker.LinkDel(linkIndex)
|
||||
})
|
||||
|
||||
waitAndCleanup = func() error {
|
||||
<-ctx.Done()
|
||||
closers.cleanup(logger)
|
||||
cleanups.Cleanup(logger)
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
return linkIndex, waitAndCleanup, nil
|
||||
}
|
||||
|
||||
func setupUserSpaceCommon(ctx context.Context,
|
||||
func setupUserSpace(ctx context.Context,
|
||||
interfaceName string, netLinker NetLinker, mtu uint32,
|
||||
closers *closers, logger Logger,
|
||||
settings Settings, b userSpaceBackend,
|
||||
) (
|
||||
linkIndex uint32, waitAndCleanup waitAndCleanupFunc, err error,
|
||||
cleanups *cleanup.Cleanups, logger Logger) (
|
||||
linkIndex uint32, waitAndCleanup func() error, err error,
|
||||
) {
|
||||
tun, err := b.createTun(interfaceName, int(mtu))
|
||||
tun, err := tun.CreateTUN(interfaceName, int(mtu))
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("%w: %s", ErrCreateTun, err)
|
||||
return 0, nil, fmt.Errorf("creating TUN device: %w", err)
|
||||
}
|
||||
|
||||
closers.add("closing TUN device", stepSeven, tun.Close)
|
||||
cleanups.Add("closing TUN device", 7, tun.Close)
|
||||
|
||||
tunName, err := tun.Name()
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("%w: cannot get TUN name: %s", ErrCreateTun, err)
|
||||
return 0, nil, fmt.Errorf("getting created TUN device name: %w", err)
|
||||
} else if tunName != interfaceName {
|
||||
return 0, nil, fmt.Errorf("%w: names don't match: expected %q and got %q",
|
||||
ErrCreateTun, interfaceName, tunName)
|
||||
return 0, nil, fmt.Errorf("%w: expected %q and got %q",
|
||||
errTunNameMismatch, interfaceName, tunName)
|
||||
}
|
||||
|
||||
link, err := netLinker.LinkByName(interfaceName)
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("%w: %s: %s", ErrFindLink, interfaceName, err)
|
||||
return 0, nil, fmt.Errorf("finding link %s: %w", interfaceName, err)
|
||||
}
|
||||
closers.add("deleting link", stepFive, func() error {
|
||||
cleanups.Add("deleting link", 5, func() error {
|
||||
return netLinker.LinkDel(link.Index)
|
||||
})
|
||||
|
||||
bind := b.createBind()
|
||||
bind := conn.NewDefaultBind()
|
||||
|
||||
closers.add("closing bind", stepSeven, bind.Close)
|
||||
cleanups.Add("closing bind", 7, bind.Close)
|
||||
|
||||
device := b.createDevice(tun, bind, logger)
|
||||
deviceLogger := makeDeviceLogger(logger)
|
||||
device := device.NewDevice(tun, bind, deviceLogger)
|
||||
|
||||
closers.add("closing Wireguard device", stepSix, func() error {
|
||||
cleanups.Add("closing Wireguard device", 6, func() error {
|
||||
device.Close()
|
||||
return nil
|
||||
})
|
||||
|
||||
uapiFile, err := uapiOpen(interfaceName)
|
||||
uapiFile, err := UAPIOpen(interfaceName)
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("%w: %s", ErrUAPISocketOpening, err)
|
||||
return 0, nil, fmt.Errorf("opening UAPI socket: %w", err)
|
||||
}
|
||||
|
||||
closers.add("closing UAPI file", stepThree, uapiFile.Close)
|
||||
cleanups.Add("closing UAPI file", 3, uapiFile.Close)
|
||||
|
||||
uapiListener, err := uapiListen(interfaceName, uapiFile)
|
||||
uapiListener, err := UAPIListen(interfaceName, uapiFile)
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("%w: %s", ErrUAPIListen, err)
|
||||
return 0, nil, fmt.Errorf("listening on UAPI socket: %w", err)
|
||||
}
|
||||
|
||||
closers.add("closing UAPI listener", stepTwo, uapiListener.Close)
|
||||
|
||||
if b.preStart != nil {
|
||||
err = b.preStart(device, settings)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
}
|
||||
cleanups.Add("closing UAPI listener", 2, uapiListener.Close)
|
||||
|
||||
// acceptAndHandle exits when uapiListener is closed
|
||||
uapiAcceptErrorCh := make(chan error)
|
||||
@@ -251,10 +247,10 @@ func setupUserSpaceCommon(ctx context.Context,
|
||||
case err = <-uapiAcceptErrorCh:
|
||||
close(uapiAcceptErrorCh)
|
||||
case <-device.Wait():
|
||||
err = ErrDeviceWaited
|
||||
err = errDeviceWaited
|
||||
}
|
||||
|
||||
closers.cleanup(logger)
|
||||
cleanups.Cleanup(logger)
|
||||
|
||||
<-uapiAcceptErrorCh // wait for acceptAndHandle to exit
|
||||
|
||||
@@ -264,7 +260,7 @@ func setupUserSpaceCommon(ctx context.Context,
|
||||
return link.Index, waitAndCleanup, nil
|
||||
}
|
||||
|
||||
func acceptAndHandle(uapi net.Listener, device userspaceDevice,
|
||||
func acceptAndHandle(uapi net.Listener, device *device.Device,
|
||||
uapiAcceptErrorCh chan<- error,
|
||||
) {
|
||||
for { // stopped by uapiFile.Close()
|
||||
|
||||
@@ -46,11 +46,8 @@ type Settings struct {
|
||||
// It defaults to false if left unset.
|
||||
IPv6 *bool
|
||||
// Implementation is the implementation to use.
|
||||
// It can be auto, kernelspace, userspace or amneziawg,
|
||||
// and defaults to auto.
|
||||
// It can be auto, kernelspace or userspace, and defaults to auto.
|
||||
Implementation string
|
||||
// AmneziaWG settings are extra obfuscation parameters
|
||||
AmneziaWG AmneziaSettings
|
||||
}
|
||||
|
||||
func (s *Settings) SetDefaults() {
|
||||
@@ -181,7 +178,7 @@ func (s *Settings) Check() (err error) {
|
||||
}
|
||||
|
||||
switch s.Implementation {
|
||||
case "auto", "kernelspace", "userspace", "amneziawg":
|
||||
case "auto", "kernelspace", "userspace":
|
||||
default:
|
||||
return fmt.Errorf("%w: %s", ErrImplementationInvalid, s.Implementation)
|
||||
}
|
||||
|
||||
@@ -1,58 +0,0 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
amneziaconn "github.com/amnezia-vpn/amneziawg-go/conn"
|
||||
amneziadevice "github.com/amnezia-vpn/amneziawg-go/device"
|
||||
amneziatun "github.com/amnezia-vpn/amneziawg-go/tun"
|
||||
wgconn "golang.zx2c4.com/wireguard/conn"
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
wgtun "golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
func defaultUserSpaceBackend() userSpaceBackend {
|
||||
return userSpaceBackend{
|
||||
createTun: func(name string, mtu int) (tunDevice, error) {
|
||||
return wgtun.CreateTUN(name, mtu)
|
||||
},
|
||||
createBind: func() bind {
|
||||
return wgconn.NewDefaultBind()
|
||||
},
|
||||
createDevice: func(td tunDevice, b bind, logger Logger) userspaceDevice {
|
||||
wgtun, _ := td.(wgtun.Device)
|
||||
wgBind, _ := b.(wgconn.Bind)
|
||||
wgLogger := wgdevice.Logger{
|
||||
Verbosef: logger.Debugf,
|
||||
Errorf: logger.Errorf,
|
||||
}
|
||||
device := wgdevice.NewDevice(wgtun, wgBind, &wgLogger)
|
||||
return device
|
||||
},
|
||||
preStart: nil,
|
||||
}
|
||||
}
|
||||
|
||||
func amneziaUserSpaceBackend() userSpaceBackend {
|
||||
return userSpaceBackend{
|
||||
createTun: func(name string, mtu int) (tunDevice, error) {
|
||||
return amneziatun.CreateTUN(name, mtu)
|
||||
},
|
||||
createBind: func() bind {
|
||||
return amneziaconn.NewDefaultBind()
|
||||
},
|
||||
createDevice: func(td tunDevice, b bind, logger Logger) userspaceDevice {
|
||||
wgamneziaTun, _ := td.(amneziatun.Device)
|
||||
wgamneziaBind, _ := b.(amneziaconn.Bind)
|
||||
wgamneziaLogger := amneziadevice.Logger{
|
||||
Verbosef: logger.Debugf,
|
||||
Errorf: logger.Errorf,
|
||||
}
|
||||
device := amneziadevice.NewDevice(wgamneziaTun, wgamneziaBind, &wgamneziaLogger)
|
||||
return device
|
||||
},
|
||||
preStart: func(ud userspaceDevice, s Settings) error {
|
||||
uapiConfig := s.AmneziaWG.uapiConfig()
|
||||
err := ud.IpcSet(uapiConfig)
|
||||
return err
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -7,10 +7,10 @@ import (
|
||||
"golang.zx2c4.com/wireguard/ipc"
|
||||
)
|
||||
|
||||
func uapiOpen(name string) (*os.File, error) {
|
||||
func UAPIOpen(name string) (*os.File, error) {
|
||||
return ipc.UAPIOpen(name)
|
||||
}
|
||||
|
||||
func uapiListen(interfaceName string, uapiFile *os.File) (net.Listener, error) {
|
||||
func UAPIListen(interfaceName string, uapiFile *os.File) (net.Listener, error) {
|
||||
return ipc.UAPIListen(interfaceName, uapiFile)
|
||||
}
|
||||
|
||||
@@ -7,10 +7,10 @@ import (
|
||||
"os"
|
||||
)
|
||||
|
||||
func uapiOpen(name string) (*os.File, error) {
|
||||
func UAPIOpen(name string) (*os.File, error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func uapiListen(interfaceName string, uapiFile *os.File) (net.Listener, error) {
|
||||
func UAPIListen(interfaceName string, uapiFile *os.File) (net.Listener, error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user