chore!(amneziawg): refactor to be separate from wireguard

- amneziawg is now a VPN protocol and no longer a Wireguard implementation
- Use it with VPN_TYPE=amneziawg
- document AMNEZIAWG_* options in Dockerfile
- document amneziawg support in readme
- separate amneziawg settings and code from wireguard
- re-use code from wireguard whenever possible
This commit is contained in:
Quentin McGaw
2026-03-11 16:35:18 +00:00
parent efea169495
commit b04529c380
54 changed files with 1608 additions and 741 deletions
+5 -4
View File
@@ -5,15 +5,16 @@ import (
"net/netip"
)
func (w *Wireguard) addAddresses(linkIndex uint32,
addresses []netip.Prefix,
func AddAddresses(linkIndex uint32,
addresses []netip.Prefix, ipv6 bool,
netlink NetLinker,
) (err error) {
for _, address := range addresses {
if !*w.settings.IPv6 && address.Addr().Is6() {
if !ipv6 && address.Addr().Is6() {
continue
}
err = w.netlink.AddrReplace(linkIndex, address)
err = netlink.AddrReplace(linkIndex, address)
if err != nil {
return fmt.Errorf("%w: when adding address %s to link with index %d",
err, address, linkIndex)
+19 -34
View File
@@ -10,7 +10,7 @@ import (
"github.com/stretchr/testify/require"
)
func Test_Wireguard_addAddresses(t *testing.T) {
func Test_AddAddresses(t *testing.T) {
t.Parallel()
ipNetOne := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 32)
@@ -19,15 +19,17 @@ func Test_Wireguard_addAddresses(t *testing.T) {
errDummy := errors.New("dummy")
testCases := map[string]struct {
linkIndex uint32
addrs []netip.Prefix
wgBuilder func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard
err error
linkIndex uint32
addrs []netip.Prefix
ipv6 bool
netlinkBuilder func(ctrl *gomock.Controller, linkIndex uint32) *MockNetLinker
err error
}{
"success": {
linkIndex: 1,
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
wgBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard {
ipv6: true,
netlinkBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *MockNetLinker {
netLinker := NewMockNetLinker(ctrl)
firstCall := netLinker.EXPECT().
AddrReplace(linkIndex, ipNetOne).
@@ -35,35 +37,27 @@ func Test_Wireguard_addAddresses(t *testing.T) {
netLinker.EXPECT().
AddrReplace(linkIndex, ipNetTwo).
Return(nil).After(firstCall)
return &Wireguard{
netlink: netLinker,
settings: Settings{
IPv6: ptrTo(true),
},
}
return netLinker
},
},
"first add error": {
linkIndex: 1,
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
wgBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard {
ipv6: true,
netlinkBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *MockNetLinker {
netLinker := NewMockNetLinker(ctrl)
netLinker.EXPECT().
AddrReplace(linkIndex, ipNetOne).
Return(errDummy)
return &Wireguard{
netlink: netLinker,
settings: Settings{
IPv6: ptrTo(true),
},
}
return netLinker
},
err: errors.New("dummy: when adding address 1.2.3.4/32 to link with index 1"),
},
"second add error": {
linkIndex: 1,
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
wgBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard {
ipv6: true,
netlinkBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *MockNetLinker {
netLinker := NewMockNetLinker(ctrl)
firstCall := netLinker.EXPECT().
AddrReplace(linkIndex, ipNetOne).
@@ -71,23 +65,14 @@ func Test_Wireguard_addAddresses(t *testing.T) {
netLinker.EXPECT().
AddrReplace(linkIndex, ipNetTwo).
Return(errDummy).After(firstCall)
return &Wireguard{
netlink: netLinker,
settings: Settings{
IPv6: ptrTo(true),
},
}
return netLinker
},
err: errors.New("dummy: when adding address ::1234/64 to link with index 1"),
},
"ignore IPv6": {
addrs: []netip.Prefix{ipNetTwo},
wgBuilder: func(_ *gomock.Controller, _ uint32) *Wireguard {
return &Wireguard{
settings: Settings{
IPv6: ptrTo(false),
},
}
netlinkBuilder: func(_ *gomock.Controller, _ uint32) *MockNetLinker {
return NewMockNetLinker(nil)
},
},
}
@@ -97,9 +82,9 @@ func Test_Wireguard_addAddresses(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
wg := testCase.wgBuilder(ctrl, testCase.linkIndex)
netlink := testCase.netlinkBuilder(ctrl, testCase.linkIndex)
err := wg.addAddresses(testCase.linkIndex, testCase.addrs)
err := AddAddresses(testCase.linkIndex, testCase.addrs, testCase.ipv6, netlink)
if testCase.err != nil {
require.Error(t, err)
-58
View File
@@ -1,58 +0,0 @@
package wireguard
import (
"fmt"
"strings"
)
type AmneziaSettings struct {
JunkPacketCount uint16
JunkPacketMin uint16
JunkPacketMax uint16
PaddingS1 uint16
PaddingS2 uint16
PaddingS3 uint16
PaddingS4 uint16
HeaderH1 string
HeaderH2 string
HeaderH3 string
HeaderH4 string
InitPacketI1 string
InitPacketI2 string
InitPacketI3 string
InitPacketI4 string
InitPacketI5 string
}
func (s AmneziaSettings) uapiConfig() string {
uintFields := map[string]uint16{
"jc": s.JunkPacketCount,
"jmin": s.JunkPacketMin,
"jmax": s.JunkPacketMax,
"s1": s.PaddingS1,
"s2": s.PaddingS2,
"s3": s.PaddingS3,
"s4": s.PaddingS4,
}
stringFields := map[string]string{
"h1": s.HeaderH1,
"h2": s.HeaderH2,
"h3": s.HeaderH3,
"h4": s.HeaderH4,
"i1": s.InitPacketI1,
"i2": s.InitPacketI2,
"i3": s.InitPacketI3,
"i4": s.InitPacketI4,
"i5": s.InitPacketI5,
}
lines := make([]string, 0, len(uintFields)+len(stringFields))
for key, val := range uintFields {
lines = append(lines, fmt.Sprintf("%s=%d", key, val))
}
for key, val := range stringFields {
lines = append(lines, key+"="+val)
}
return strings.Join(lines, "\n")
}
-63
View File
@@ -1,63 +0,0 @@
package wireguard
import "sort"
type closer struct {
operation string
step step
close func() error
closed bool
}
type closers []closer
func (c *closers) add(operation string, step step,
closeFunc func() error,
) {
closer := closer{
operation: operation,
step: step,
close: closeFunc,
}
*c = append(*c, closer)
}
func (c *closers) cleanup(logger Logger) {
closers := *c
sort.Slice(closers, func(i, j int) bool {
return closers[i].step < closers[j].step
})
for i, closer := range closers {
if closer.closed {
continue
}
closers[i].closed = true
logger.Debug(closer.operation + "...")
err := closer.close()
if err != nil {
logger.Error("failed " + closer.operation + ": " + err.Error())
}
}
}
type step int
const (
// stepOne closes the wireguard controller client,
// and removes the IP rule.
stepOne step = iota
// stepTwo closes the UAPI listener.
stepTwo
// stepThree closes the UAPI file.
stepThree
// stepFour shuts down the Wireguard link.
stepFour
// stepFive removes the Wireguard link.
stepFive
// stepSix closes the Wireguard device.
stepSix
// stepSeven closes the bind connection and the TUN device file.
stepSeven
)
-57
View File
@@ -1,57 +0,0 @@
package wireguard
import (
"errors"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
)
func Test_closers(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
var ACloseCalled, BCloseCalled, CCloseCalled bool
var (
AErr error
BErr = errors.New("B failed")
CErr = errors.New("C failed")
)
var closers closers
closers.add("closing A", stepFive, func() error {
ACloseCalled = true
return AErr
})
closers.add("closing B", stepThree, func() error {
BCloseCalled = true
return BErr
})
closers.add("closing C", stepTwo, func() error {
CCloseCalled = true
return CErr
})
logger := NewMockLogger(ctrl)
prevCall := logger.EXPECT().Debug("closing C...")
prevCall = logger.EXPECT().Error("failed closing C: C failed").After(prevCall)
prevCall = logger.EXPECT().Debug("closing B...").After(prevCall)
prevCall = logger.EXPECT().Error("failed closing B: B failed").After(prevCall)
logger.EXPECT().Debug("closing A...").After(prevCall)
closers.cleanup(logger)
closers.cleanup(logger) // run twice should not close already closed
for _, closer := range closers {
assert.True(t, closer.closed)
}
assert.True(t, ACloseCalled)
assert.True(t, BCloseCalled)
assert.True(t, CCloseCalled)
}
-28
View File
@@ -1,28 +0,0 @@
package wireguard
import (
"net"
)
type tunDevice interface {
Close() error
Name() (string, error)
}
type bind interface {
Close() error
}
type userspaceDevice interface {
Close()
Wait() chan struct{}
IpcHandle(net.Conn)
IpcSet(string) error
}
type userSpaceBackend struct {
createTun func(string, int) (tunDevice, error)
createBind func() bind
createDevice func(tunDevice, bind, Logger) userspaceDevice
preStart func(userspaceDevice, Settings) error
}
+1 -1
View File
@@ -10,7 +10,7 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
func configureDevice(client *wgctrl.Client, settings Settings) (err error) {
func ConfigureDevice(client *wgctrl.Client, settings Settings) (err error) {
deviceConfig, err := makeDeviceConfig(settings)
if err != nil {
return fmt.Errorf("making device configuration: %w", err)
+16 -1
View File
@@ -1,5 +1,9 @@
package wireguard
import (
"golang.zx2c4.com/wireguard/device"
)
//go:generate mockgen -destination=log_mock_test.go -package wireguard . Logger
type Logger interface {
@@ -7,5 +11,16 @@ type Logger interface {
Debugf(format string, args ...interface{})
Info(s string)
Error(s string)
Errorf(format string, args ...interface{})
Erroer
}
type Erroer interface {
Errorf(format string, args ...any)
}
func makeDeviceLogger(logger Logger) (deviceLogger *device.Logger) {
return &device.Logger{
Verbosef: logger.Debugf,
Errorf: logger.Errorf,
}
}
+23
View File
@@ -0,0 +1,23 @@
package wireguard
import (
"testing"
"github.com/golang/mock/gomock"
)
func Test_makeDeviceLogger(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
logger := NewMockLogger(ctrl)
deviceLogger := makeDeviceLogger(logger)
logger.EXPECT().Debugf("test %d", 1)
deviceLogger.Verbosef("test %d", 1)
logger.EXPECT().Errorf("test %d", 2)
deviceLogger.Errorf("test %d", 2)
}
+9 -12
View File
@@ -21,7 +21,7 @@ func (n noopDebugLogger) Error(_ string) {}
func (n noopDebugLogger) Errorf(_ string, _ ...any) {}
func (n noopDebugLogger) Patch(_ ...log.Option) {}
func Test_netlink_Wireguard_addAddresses(t *testing.T) {
func Test_AddAddresses_Integration(t *testing.T) {
t.Parallel()
netlinker := netlink.New(&noopDebugLogger{})
@@ -55,7 +55,7 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) {
const addIterations = 2 // initial + replace
for range addIterations {
err = wg.addAddresses(link.Index, addresses)
err = AddAddresses(link.Index, addresses, *wg.settings.IPv6, wg.netlink)
require.NoError(t, err)
ipPrefixes, err := netlinker.AddrList(link.Index, netlink.FamilyAll)
@@ -67,22 +67,19 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) {
}
}
func Test_netlink_Wireguard_addRule(t *testing.T) {
func Test_AddRule_Integration(t *testing.T) {
t.Parallel()
netlinker := netlink.New(&noopDebugLogger{})
wg := &Wireguard{
netlink: netlinker,
logger: &noopDebugLogger{},
}
logger := &noopDebugLogger{}
netlinker := netlink.New(logger)
// Unique combination for this test
const rulePriority uint32 = 10000
const firewallMark uint32 = 12345
const family = netlink.FamilyV4
cleanup, err := wg.addRule(rulePriority,
firewallMark, family)
cleanup, err := AddRule(rulePriority,
firewallMark, family, netlinker, logger)
require.NoError(t, err)
t.Cleanup(func() {
err := cleanup()
@@ -110,8 +107,8 @@ func Test_netlink_Wireguard_addRule(t *testing.T) {
require.True(t, ruleFound)
// Existing rule cannot be added
nilCleanup, err := wg.addRule(rulePriority,
firewallMark, family)
nilCleanup, err := AddRule(rulePriority,
firewallMark, family, netlinker, logger)
if nilCleanup != nil {
_ = nilCleanup() // in case it succeeds
}
+7 -7
View File
@@ -8,17 +8,17 @@ import (
"github.com/qdm12/gluetun/internal/netlink"
)
func (w *Wireguard) addRoutes(linkIndex uint32, destinations []netip.Prefix,
firewallMark uint32,
func AddRoutes(linkIndex uint32, destinations []netip.Prefix,
firewallMark uint32, netlinker NetLinker, logger Erroer,
) (err error) {
for _, dst := range destinations {
err = w.addRoute(linkIndex, dst, firewallMark)
err = addRoute(linkIndex, dst, firewallMark, netlinker)
if err == nil {
continue
}
if dst.Addr().Is6() && strings.Contains(err.Error(), "permission denied") {
w.logger.Errorf("cannot add route for IPv6 due to a permission denial. "+
logger.Errorf("cannot add route for IPv6 due to a permission denial. "+
"Ignoring and continuing execution; "+
"Please report to https://github.com/qdm12/gluetun/issues/998 if you find a fix. "+
"Full error string: %s", err)
@@ -29,8 +29,8 @@ func (w *Wireguard) addRoutes(linkIndex uint32, destinations []netip.Prefix,
return nil
}
func (w *Wireguard) addRoute(linkIndex uint32, dst netip.Prefix,
firewallMark uint32,
func addRoute(linkIndex uint32, dst netip.Prefix,
firewallMark uint32, netlinker NetLinker,
) (err error) {
family := netlink.FamilyV4
if dst.Addr().Is6() {
@@ -46,7 +46,7 @@ func (w *Wireguard) addRoute(linkIndex uint32, dst netip.Prefix,
Proto: netlink.ProtoStatic,
}
err = w.netlink.RouteAdd(route)
err = netlinker.RouteAdd(route)
if err != nil {
return fmt.Errorf(
"adding route for link with index %d, destination %s and table %d: %w",
+2 -6
View File
@@ -11,7 +11,7 @@ import (
"github.com/stretchr/testify/require"
)
func Test_Wireguard_addRoute(t *testing.T) {
func Test_addRoute(t *testing.T) {
t.Parallel()
const linkIndex = 88
@@ -62,15 +62,11 @@ func Test_Wireguard_addRoute(t *testing.T) {
ctrl := gomock.NewController(t)
netLinker := NewMockNetLinker(ctrl)
wg := Wireguard{
netlink: netLinker,
}
netLinker.EXPECT().
RouteAdd(testCase.expectedRoute).
Return(testCase.routeAddErr)
err := wg.addRoute(linkIndex, testCase.dst, firewallMark)
err := addRoute(linkIndex, testCase.dst, firewallMark, netLinker)
if testCase.err != nil {
require.Error(t, err)
+5 -5
View File
@@ -7,8 +7,8 @@ import (
"github.com/qdm12/gluetun/internal/netlink"
)
func (w *Wireguard) addRule(rulePriority, firewallMark uint32,
family uint8,
func AddRule(rulePriority, firewallMark uint32, family uint8,
netlinker NetLinker, logger Logger,
) (cleanup func() error, err error) {
rule := netlink.Rule{
Priority: &rulePriority,
@@ -18,16 +18,16 @@ func (w *Wireguard) addRule(rulePriority, firewallMark uint32,
Flags: netlink.FlagInvert,
Action: netlink.ActionToTable,
}
if err := w.netlink.RuleAdd(rule); err != nil {
if err := netlinker.RuleAdd(rule); err != nil {
if strings.HasSuffix(err.Error(), "file exists") {
w.logger.Info("if you are using Kubernetes, this may fix the error below: " +
logger.Info("if you are using Kubernetes, this may fix the error below: " +
"https://github.com/qdm12/gluetun-wiki/blob/main/setup/advanced/kubernetes.md#adding-ipv6-rule--file-exists")
}
return nil, fmt.Errorf("adding %s: %w", rule, err)
}
cleanup = func() error {
err := w.netlink.RuleDel(rule)
err := netlinker.RuleDel(rule)
if err != nil {
return fmt.Errorf("deleting rule %s: %w", rule, err)
}
+3 -5
View File
@@ -10,7 +10,7 @@ import (
"github.com/stretchr/testify/require"
)
func Test_Wireguard_addRule(t *testing.T) {
func Test_AddRule(t *testing.T) {
t.Parallel()
const rulePriority uint32 = 987
@@ -68,13 +68,11 @@ func Test_Wireguard_addRule(t *testing.T) {
ctrl := gomock.NewController(t)
netLinker := NewMockNetLinker(ctrl)
wg := Wireguard{
netlink: netLinker,
}
netLinker.EXPECT().RuleAdd(testCase.expectedRule).
Return(testCase.ruleAddErr)
cleanup, err := wg.addRule(rulePriority, firewallMark, family)
cleanup, err := AddRule(rulePriority, firewallMark, family,
netLinker, nil)
if testCase.err != nil {
require.Error(t, err)
assert.Equal(t, testCase.err.Error(), err.Error())
+85 -89
View File
@@ -6,39 +6,33 @@ import (
"fmt"
"net"
"github.com/qdm12/gluetun/internal/cleanup"
"github.com/qdm12/gluetun/internal/netlink"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/wgctrl"
)
var (
ErrDetectKernel = errors.New("cannot detect Kernel support")
ErrCreateTun = errors.New("cannot create TUN device")
ErrAddLink = errors.New("cannot add Wireguard link")
ErrFindLink = errors.New("cannot find link")
ErrFindDevice = errors.New("cannot find Wireguard device")
ErrUAPISocketOpening = errors.New("cannot open UAPI socket")
ErrWgctrlOpen = errors.New("cannot open wgctrl")
ErrUAPIListen = errors.New("cannot listen on UAPI socket")
ErrAddAddress = errors.New("cannot add address to wireguard interface")
ErrConfigure = errors.New("cannot configure wireguard interface")
ErrDeviceInfo = errors.New("cannot get wireguard device information")
ErrIfaceUp = errors.New("cannot set the interface to UP")
ErrRouteAdd = errors.New("cannot add route for interface")
ErrDeviceWaited = errors.New("device waited for")
ErrKernelSupport = errors.New("kernel does not support Wireguard")
ErrAmneziaConfigure = errors.New("cannot configure AmneziaWG")
errKernelSupport = errors.New("kernel does not support Wireguard")
errTunNameMismatch = errors.New("TUN device name is mismatching")
errDeviceWaited = errors.New("device waited for")
)
// Run runs the wireguard interface and waits until the context is done, then it cleans up the
// interface and returns any error that occurred during setup or waiting. It sends an error to
// waitError if any error occurs during setup or waiting, otherwise it sends nil when the context
// is done. It sends a signal to ready when the setup is complete and the interface is ready to use.
// See https://git.zx2c4.com/wireguard-go/tree/main.go
func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<- struct{}) {
kernelSupported, err := w.netlink.IsWireguardSupported()
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrDetectKernel, err)
waitError <- fmt.Errorf("detecting wireguard kernel support: %w", err)
return
}
userspaceBackend := defaultUserSpaceBackend()
setupFunction := setupUserSpaceCommon
setupFunction := setupUserSpace
switch w.settings.Implementation {
case "auto": //nolint:goconst
if !kernelSupported {
@@ -50,95 +44,105 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
case "userspace":
case "kernelspace":
if !kernelSupported {
waitError <- fmt.Errorf("%w", ErrKernelSupport)
waitError <- fmt.Errorf("%w", errKernelSupport)
return
}
setupFunction = setupKernelSpace
case "amneziawg":
userspaceBackend = amneziaUserSpaceBackend()
default:
panic(fmt.Sprintf("unknown implementation %q", w.settings.Implementation))
}
setup := func(ctx context.Context, cleanups *cleanup.Cleanups) (
linkIndex uint32, waitAndCleanup func() error, err error,
) {
return setupFunction(ctx,
w.settings.InterfaceName, w.netlink, w.settings.MTU, cleanups, w.logger)
}
Run(ctx, waitError, ready, setup, w.settings, w.netlink, w.logger)
}
func Run(ctx context.Context, waitError chan<- error, ready chan<- struct{},
setup func(ctx context.Context, cleanups *cleanup.Cleanups) (
linkIndex uint32, waitAndCleanup func() error, err error),
settings Settings, netlinker NetLinker, logger Logger,
) {
client, err := wgctrl.New()
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrWgctrlOpen, err)
waitError <- fmt.Errorf("opening wgctrl: %w", err)
return
}
var closers closers
closers.add("closing controller client", stepOne, client.Close)
var cleanups cleanup.Cleanups
cleanups.Add("closing controller client", 1, client.Close)
defer closers.cleanup(w.logger)
defer cleanups.Cleanup(logger)
linkIndex, waitAndCleanup, err := setupFunction(ctx,
w.settings.InterfaceName, w.netlink, w.settings.MTU, &closers, w.logger, w.settings, userspaceBackend)
linkIndex, waitAndCleanup, err := setup(ctx, &cleanups)
if err != nil {
waitError <- err
return
}
err = w.addAddresses(linkIndex, w.settings.Addresses)
err = AddAddresses(linkIndex, settings.Addresses, *settings.IPv6, netlinker)
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrAddAddress, err)
waitError <- fmt.Errorf("adding addresses to interface: %w", err)
return
}
w.logger.Info("Connecting to " + w.settings.Endpoint.String())
err = configureDevice(client, w.settings)
logger.Info("Connecting to " + settings.Endpoint.String())
err = ConfigureDevice(client, settings)
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrConfigure, err)
waitError <- fmt.Errorf("configuring interface: %w", err)
return
}
err = w.netlink.LinkSetUp(linkIndex)
err = netlinker.LinkSetUp(linkIndex)
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrIfaceUp, err)
waitError <- fmt.Errorf("setting the interface UP: %w", err)
return
}
closers.add("shutting down link", stepFour, func() error {
return w.netlink.LinkSetDown(linkIndex)
cleanups.Add("shutting down link", 4, func() error {
return netlinker.LinkSetDown(linkIndex)
})
err = w.addRoutes(linkIndex, w.settings.AllowedIPs, w.settings.FirewallMark)
err = AddRoutes(linkIndex, settings.AllowedIPs, settings.FirewallMark,
netlinker, logger)
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrRouteAdd, err)
waitError <- fmt.Errorf("adding routes for interface: %w", err)
return
}
if *w.settings.IPv6 {
if *settings.IPv6 {
// requires net.ipv6.conf.all.disable_ipv6=0
ruleCleanup6, err := w.addRule(w.settings.RulePriority,
w.settings.FirewallMark, netlink.FamilyV6)
ruleCleanup6, err := AddRule(settings.RulePriority,
settings.FirewallMark, netlink.FamilyV6,
netlinker, logger)
if err != nil {
waitError <- fmt.Errorf("adding IPv6 rule: %w", err)
return
}
closers.add("removing IPv6 rule", stepOne, ruleCleanup6)
cleanups.Add("removing IPv6 rule", 1, ruleCleanup6)
}
ruleCleanup, err := w.addRule(w.settings.RulePriority,
w.settings.FirewallMark, netlink.FamilyV4)
ruleCleanup, err := AddRule(settings.RulePriority,
settings.FirewallMark, netlink.FamilyV4,
netlinker, logger)
if err != nil {
waitError <- fmt.Errorf("adding IPv4 rule: %w", err)
return
}
closers.add("removing IPv4 rule", stepOne, ruleCleanup)
w.logger.Info("Wireguard setup is complete. " +
"Note Wireguard is a silent protocol and it may or may not work, without giving any error message. " +
"Typically i/o timeout errors indicate the Wireguard connection is not working.")
cleanups.Add("removing IPv4 rule", 1, ruleCleanup)
ready <- struct{}{}
waitError <- waitAndCleanup()
}
type waitAndCleanupFunc func() error
func setupKernelSpace(ctx context.Context,
interfaceName string, netLinker NetLinker, mtu uint32,
closers *closers, logger Logger, _ Settings, _ userSpaceBackend) (
linkIndex uint32, waitAndCleanup waitAndCleanupFunc, err error,
cleanups *cleanup.Cleanups, logger Logger) (
linkIndex uint32, waitAndCleanup func() error, err error,
) {
links, err := netLinker.LinkList()
if err != nil {
@@ -164,82 +168,74 @@ func setupKernelSpace(ctx context.Context,
}
linkIndex, err = netLinker.LinkAdd(link)
if err != nil {
return 0, nil, fmt.Errorf("%w: %s", ErrAddLink, err)
return 0, nil, fmt.Errorf("adding link: %w", err)
}
closers.add("deleting link", stepFive, func() error {
cleanups.Add("deleting link", 5, func() error {
return netLinker.LinkDel(linkIndex)
})
waitAndCleanup = func() error {
<-ctx.Done()
closers.cleanup(logger)
cleanups.Cleanup(logger)
return ctx.Err()
}
return linkIndex, waitAndCleanup, nil
}
func setupUserSpaceCommon(ctx context.Context,
func setupUserSpace(ctx context.Context,
interfaceName string, netLinker NetLinker, mtu uint32,
closers *closers, logger Logger,
settings Settings, b userSpaceBackend,
) (
linkIndex uint32, waitAndCleanup waitAndCleanupFunc, err error,
cleanups *cleanup.Cleanups, logger Logger) (
linkIndex uint32, waitAndCleanup func() error, err error,
) {
tun, err := b.createTun(interfaceName, int(mtu))
tun, err := tun.CreateTUN(interfaceName, int(mtu))
if err != nil {
return 0, nil, fmt.Errorf("%w: %s", ErrCreateTun, err)
return 0, nil, fmt.Errorf("creating TUN device: %w", err)
}
closers.add("closing TUN device", stepSeven, tun.Close)
cleanups.Add("closing TUN device", 7, tun.Close)
tunName, err := tun.Name()
if err != nil {
return 0, nil, fmt.Errorf("%w: cannot get TUN name: %s", ErrCreateTun, err)
return 0, nil, fmt.Errorf("getting created TUN device name: %w", err)
} else if tunName != interfaceName {
return 0, nil, fmt.Errorf("%w: names don't match: expected %q and got %q",
ErrCreateTun, interfaceName, tunName)
return 0, nil, fmt.Errorf("%w: expected %q and got %q",
errTunNameMismatch, interfaceName, tunName)
}
link, err := netLinker.LinkByName(interfaceName)
if err != nil {
return 0, nil, fmt.Errorf("%w: %s: %s", ErrFindLink, interfaceName, err)
return 0, nil, fmt.Errorf("finding link %s: %w", interfaceName, err)
}
closers.add("deleting link", stepFive, func() error {
cleanups.Add("deleting link", 5, func() error {
return netLinker.LinkDel(link.Index)
})
bind := b.createBind()
bind := conn.NewDefaultBind()
closers.add("closing bind", stepSeven, bind.Close)
cleanups.Add("closing bind", 7, bind.Close)
device := b.createDevice(tun, bind, logger)
deviceLogger := makeDeviceLogger(logger)
device := device.NewDevice(tun, bind, deviceLogger)
closers.add("closing Wireguard device", stepSix, func() error {
cleanups.Add("closing Wireguard device", 6, func() error {
device.Close()
return nil
})
uapiFile, err := uapiOpen(interfaceName)
uapiFile, err := UAPIOpen(interfaceName)
if err != nil {
return 0, nil, fmt.Errorf("%w: %s", ErrUAPISocketOpening, err)
return 0, nil, fmt.Errorf("opening UAPI socket: %w", err)
}
closers.add("closing UAPI file", stepThree, uapiFile.Close)
cleanups.Add("closing UAPI file", 3, uapiFile.Close)
uapiListener, err := uapiListen(interfaceName, uapiFile)
uapiListener, err := UAPIListen(interfaceName, uapiFile)
if err != nil {
return 0, nil, fmt.Errorf("%w: %s", ErrUAPIListen, err)
return 0, nil, fmt.Errorf("listening on UAPI socket: %w", err)
}
closers.add("closing UAPI listener", stepTwo, uapiListener.Close)
if b.preStart != nil {
err = b.preStart(device, settings)
if err != nil {
return 0, nil, err
}
}
cleanups.Add("closing UAPI listener", 2, uapiListener.Close)
// acceptAndHandle exits when uapiListener is closed
uapiAcceptErrorCh := make(chan error)
@@ -251,10 +247,10 @@ func setupUserSpaceCommon(ctx context.Context,
case err = <-uapiAcceptErrorCh:
close(uapiAcceptErrorCh)
case <-device.Wait():
err = ErrDeviceWaited
err = errDeviceWaited
}
closers.cleanup(logger)
cleanups.Cleanup(logger)
<-uapiAcceptErrorCh // wait for acceptAndHandle to exit
@@ -264,7 +260,7 @@ func setupUserSpaceCommon(ctx context.Context,
return link.Index, waitAndCleanup, nil
}
func acceptAndHandle(uapi net.Listener, device userspaceDevice,
func acceptAndHandle(uapi net.Listener, device *device.Device,
uapiAcceptErrorCh chan<- error,
) {
for { // stopped by uapiFile.Close()
+2 -5
View File
@@ -46,11 +46,8 @@ type Settings struct {
// It defaults to false if left unset.
IPv6 *bool
// Implementation is the implementation to use.
// It can be auto, kernelspace, userspace or amneziawg,
// and defaults to auto.
// It can be auto, kernelspace or userspace, and defaults to auto.
Implementation string
// AmneziaWG settings are extra obfuscation parameters
AmneziaWG AmneziaSettings
}
func (s *Settings) SetDefaults() {
@@ -181,7 +178,7 @@ func (s *Settings) Check() (err error) {
}
switch s.Implementation {
case "auto", "kernelspace", "userspace", "amneziawg":
case "auto", "kernelspace", "userspace":
default:
return fmt.Errorf("%w: %s", ErrImplementationInvalid, s.Implementation)
}
-58
View File
@@ -1,58 +0,0 @@
package wireguard
import (
amneziaconn "github.com/amnezia-vpn/amneziawg-go/conn"
amneziadevice "github.com/amnezia-vpn/amneziawg-go/device"
amneziatun "github.com/amnezia-vpn/amneziawg-go/tun"
wgconn "golang.zx2c4.com/wireguard/conn"
wgdevice "golang.zx2c4.com/wireguard/device"
wgtun "golang.zx2c4.com/wireguard/tun"
)
func defaultUserSpaceBackend() userSpaceBackend {
return userSpaceBackend{
createTun: func(name string, mtu int) (tunDevice, error) {
return wgtun.CreateTUN(name, mtu)
},
createBind: func() bind {
return wgconn.NewDefaultBind()
},
createDevice: func(td tunDevice, b bind, logger Logger) userspaceDevice {
wgtun, _ := td.(wgtun.Device)
wgBind, _ := b.(wgconn.Bind)
wgLogger := wgdevice.Logger{
Verbosef: logger.Debugf,
Errorf: logger.Errorf,
}
device := wgdevice.NewDevice(wgtun, wgBind, &wgLogger)
return device
},
preStart: nil,
}
}
func amneziaUserSpaceBackend() userSpaceBackend {
return userSpaceBackend{
createTun: func(name string, mtu int) (tunDevice, error) {
return amneziatun.CreateTUN(name, mtu)
},
createBind: func() bind {
return amneziaconn.NewDefaultBind()
},
createDevice: func(td tunDevice, b bind, logger Logger) userspaceDevice {
wgamneziaTun, _ := td.(amneziatun.Device)
wgamneziaBind, _ := b.(amneziaconn.Bind)
wgamneziaLogger := amneziadevice.Logger{
Verbosef: logger.Debugf,
Errorf: logger.Errorf,
}
device := amneziadevice.NewDevice(wgamneziaTun, wgamneziaBind, &wgamneziaLogger)
return device
},
preStart: func(ud userspaceDevice, s Settings) error {
uapiConfig := s.AmneziaWG.uapiConfig()
err := ud.IpcSet(uapiConfig)
return err
},
}
}
+2 -2
View File
@@ -7,10 +7,10 @@ import (
"golang.zx2c4.com/wireguard/ipc"
)
func uapiOpen(name string) (*os.File, error) {
func UAPIOpen(name string) (*os.File, error) {
return ipc.UAPIOpen(name)
}
func uapiListen(interfaceName string, uapiFile *os.File) (net.Listener, error) {
func UAPIListen(interfaceName string, uapiFile *os.File) (net.Listener, error) {
return ipc.UAPIListen(interfaceName, uapiFile)
}
+2 -2
View File
@@ -7,10 +7,10 @@ import (
"os"
)
func uapiOpen(name string) (*os.File, error) {
func UAPIOpen(name string) (*os.File, error) {
panic("not implemented")
}
func uapiListen(interfaceName string, uapiFile *os.File) (net.Listener, error) {
func UAPIListen(interfaceName string, uapiFile *os.File) (net.Listener, error) {
panic("not implemented")
}