chore!(amneziawg): refactor to be separate from wireguard

- amneziawg is now a VPN protocol and no longer a Wireguard implementation
- Use it with VPN_TYPE=amneziawg
- document AMNEZIAWG_* options in Dockerfile
- document amneziawg support in readme
- separate amneziawg settings and code from wireguard
- re-use code from wireguard whenever possible
This commit is contained in:
Quentin McGaw
2026-03-11 16:35:18 +00:00
parent efea169495
commit b04529c380
54 changed files with 1608 additions and 741 deletions
+4
View File
@@ -60,6 +60,10 @@ linters:
- linters:
- lll
source: "^// https://.+$"
- linters:
- mnd
source: "^ cleanups\\.Add.+$"
path: internal\/(wireguard|amneziawg)\/run\.go
- linters:
- err113
- mnd
+30
View File
@@ -112,6 +112,36 @@ ENV VPN_SERVICE_PROVIDER=pia \
WIREGUARD_ADDRESSES_SECRETFILE=/run/secrets/wireguard_addresses \
WIREGUARD_MTU= \
WIREGUARD_IMPLEMENTATION=auto \
# Amnezia
AMNEZIAWG_ENDPOINT_IP= \
AMNEZIAWG_ENDPOINT_PORT= \
AMNEZIAWG_CONF_SECRETFILE=/run/secrets/wg0.conf \
AMNEZIAWG_PRIVATE_KEY= \
AMNEZIAWG_PRIVATE_KEY_SECRETFILE=/run/secrets/wireguard_private_key \
AMNEZIAWG_PRESHARED_KEY= \
AMNEZIAWG_PRESHARED_KEY_SECRETFILE=/run/secrets/wireguard_preshared_key \
AMNEZIAWG_PUBLIC_KEY= \
AMNEZIAWG_ALLOWED_IPS= \
AMNEZIAWG_PERSISTENT_KEEPALIVE_INTERVAL=0 \
AMNEZIAWG_ADDRESSES= \
AMNEZIAWG_ADDRESSES_SECRETFILE=/run/secrets/wireguard_addresses \
AMNEZIAWG_MTU= \
AMNEZIAWG_JC=0 \
AMNEZIAWG_JMIN=0 \
AMNEZIAWG_JMAX=0 \
AMNEZIAWG_S1=0 \
AMNEZIAWG_S2=0 \
AMNEZIAWG_S3=0 \
AMNEZIAWG_S4=0 \
AMNEZIAWG_H1= \
AMNEZIAWG_H2= \
AMNEZIAWG_H3= \
AMNEZIAWG_H4= \
AMNEZIAWG_I1= \
AMNEZIAWG_I2= \
AMNEZIAWG_I3= \
AMNEZIAWG_I4= \
AMNEZIAWG_I5= \
# Wireguard AmneziaWG userspace obfuscation (requires WIREGUARD_IMPLEMENTATION=amneziawg)
AMNEZIAWG_JC=0 \
AMNEZIAWG_JMIN=0 \
+1
View File
@@ -67,6 +67,7 @@ Lightweight swiss-army-knife-like VPN client to multiple VPN service providers
- For **Cyberghost**, **Private Internet Access**, **PrivateVPN**, **PureVPN**, **Torguard**, **VPN Unlimited** and **VyprVPN** using [the custom provider](https://github.com/qdm12/gluetun-wiki/blob/main/setup/providers/custom.md)
- For custom Wireguard configurations using [the custom provider](https://github.com/qdm12/gluetun-wiki/blob/main/setup/providers/custom.md)
- More in progress, see [#134](https://github.com/qdm12/gluetun/issues/134)
- Supports AmneziaWG only with the custom provider for now
- DNS over TLS baked in with service provider(s) of your choice
- DNS fine blocking of malicious/ads/surveillance hostnames and IP addresses, with live update every 24 hours
- Choose the vpn network protocol, `udp` or `tcp`
+22
View File
@@ -0,0 +1,22 @@
package amneziawg
type Amneziawg struct {
logger Logger
settings Settings
netlink NetLinker
}
func New(settings Settings, netlink NetLinker,
logger Logger,
) (a *Amneziawg, err error) {
settings.SetDefaults()
if err := settings.Check(); err != nil {
return nil, err
}
return &Amneziawg{
logger: logger,
settings: settings,
netlink: netlink,
}, nil
}
+86
View File
@@ -0,0 +1,86 @@
package amneziawg
import (
"net/netip"
"testing"
"github.com/qdm12/gluetun/internal/wireguard"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/device"
)
func Test_New(t *testing.T) {
t.Parallel()
const validKeyString = "oMNSf/zJ0pt1ciy+qIRk8Rlyfs9accwuRLnKd85Yl1Q="
logger := NewMockLogger(nil)
netLinker := NewMockNetLinker(nil)
testCases := map[string]struct {
settings Settings
amneziawg *Amneziawg
err error
}{
"bad_settings": {
settings: Settings{
Wireguard: wireguard.Settings{
PrivateKey: "",
},
},
err: wireguard.ErrPrivateKeyMissing,
},
"minimal valid settings": {
settings: Settings{
Wireguard: wireguard.Settings{
PrivateKey: validKeyString,
PublicKey: validKeyString,
Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 0),
Addresses: []netip.Prefix{
netip.PrefixFrom(netip.AddrFrom4([4]byte{5, 6, 7, 8}), 32),
},
FirewallMark: 100,
},
},
amneziawg: &Amneziawg{
logger: logger,
netlink: netLinker,
settings: Settings{
Wireguard: wireguard.Settings{
InterfaceName: "wg0",
PrivateKey: validKeyString,
PublicKey: validKeyString,
Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 51820),
Addresses: []netip.Prefix{
netip.PrefixFrom(netip.AddrFrom4([4]byte{5, 6, 7, 8}), 32),
},
AllowedIPs: []netip.Prefix{
netip.MustParsePrefix("0.0.0.0/0"),
},
FirewallMark: 100,
MTU: device.DefaultMTU,
IPv6: ptrTo(false),
Implementation: "auto",
},
},
},
},
}
for name, testCase := range testCases {
t.Run(name, func(t *testing.T) {
t.Parallel()
wireguard, err := New(testCase.settings, netLinker, logger)
if testCase.err != nil {
require.Error(t, err)
assert.Equal(t, testCase.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
assert.Equal(t, testCase.amneziawg, wireguard)
})
}
}
+5
View File
@@ -0,0 +1,5 @@
package amneziawg
func ptrTo[T any](v T) *T {
return &v
}
+11
View File
@@ -0,0 +1,11 @@
package amneziawg
//go:generate mockgen -destination=log_mock_test.go -package amneziawg . Logger
type Logger interface {
Debug(s string)
Debugf(format string, args ...interface{})
Info(s string)
Error(s string)
Errorf(format string, args ...interface{})
}
+104
View File
@@ -0,0 +1,104 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/qdm12/gluetun/internal/amneziawg (interfaces: Logger)
// Package amneziawg is a generated GoMock package.
package amneziawg
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
)
// MockLogger is a mock of Logger interface.
type MockLogger struct {
ctrl *gomock.Controller
recorder *MockLoggerMockRecorder
}
// MockLoggerMockRecorder is the mock recorder for MockLogger.
type MockLoggerMockRecorder struct {
mock *MockLogger
}
// NewMockLogger creates a new mock instance.
func NewMockLogger(ctrl *gomock.Controller) *MockLogger {
mock := &MockLogger{ctrl: ctrl}
mock.recorder = &MockLoggerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockLogger) EXPECT() *MockLoggerMockRecorder {
return m.recorder
}
// Debug mocks base method.
func (m *MockLogger) Debug(arg0 string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Debug", arg0)
}
// Debug indicates an expected call of Debug.
func (mr *MockLoggerMockRecorder) Debug(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockLogger)(nil).Debug), arg0)
}
// Debugf mocks base method.
func (m *MockLogger) Debugf(arg0 string, arg1 ...interface{}) {
m.ctrl.T.Helper()
varargs := []interface{}{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
m.ctrl.Call(m, "Debugf", varargs...)
}
// Debugf indicates an expected call of Debugf.
func (mr *MockLoggerMockRecorder) Debugf(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]interface{}{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debugf", reflect.TypeOf((*MockLogger)(nil).Debugf), varargs...)
}
// Error mocks base method.
func (m *MockLogger) Error(arg0 string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Error", arg0)
}
// Error indicates an expected call of Error.
func (mr *MockLoggerMockRecorder) Error(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Error", reflect.TypeOf((*MockLogger)(nil).Error), arg0)
}
// Errorf mocks base method.
func (m *MockLogger) Errorf(arg0 string, arg1 ...interface{}) {
m.ctrl.T.Helper()
varargs := []interface{}{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
m.ctrl.Call(m, "Errorf", varargs...)
}
// Errorf indicates an expected call of Errorf.
func (mr *MockLoggerMockRecorder) Errorf(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]interface{}{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Errorf", reflect.TypeOf((*MockLogger)(nil).Errorf), varargs...)
}
// Info mocks base method.
func (m *MockLogger) Info(arg0 string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Info", arg0)
}
// Info indicates an expected call of Info.
func (mr *MockLoggerMockRecorder) Info(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Info", reflect.TypeOf((*MockLogger)(nil).Info), arg0)
}
+36
View File
@@ -0,0 +1,36 @@
package amneziawg
import (
"net/netip"
"github.com/qdm12/gluetun/internal/netlink"
)
//go:generate mockgen -destination=netlinker_mock_test.go -package amneziawg . NetLinker
type NetLinker interface {
AddrReplace(linkIndex uint32, addr netip.Prefix) error
Router
Ruler
Linker
IsWireguardSupported() (ok bool, err error)
}
type Router interface {
RouteList(family uint8) (routes []netlink.Route, err error)
RouteAdd(route netlink.Route) error
}
type Ruler interface {
RuleAdd(rule netlink.Rule) error
RuleDel(rule netlink.Rule) error
}
type Linker interface {
LinkAdd(link netlink.Link) (linkIndex uint32, err error)
LinkList() (links []netlink.Link, err error)
LinkByName(name string) (link netlink.Link, err error)
LinkSetUp(linkIndex uint32) error
LinkSetDown(linkIndex uint32) error
LinkDel(linkIndex uint32) error
}
+209
View File
@@ -0,0 +1,209 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/qdm12/gluetun/internal/amneziawg (interfaces: NetLinker)
// Package amneziawg is a generated GoMock package.
package amneziawg
import (
netip "net/netip"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
netlink "github.com/qdm12/gluetun/internal/netlink"
)
// MockNetLinker is a mock of NetLinker interface.
type MockNetLinker struct {
ctrl *gomock.Controller
recorder *MockNetLinkerMockRecorder
}
// MockNetLinkerMockRecorder is the mock recorder for MockNetLinker.
type MockNetLinkerMockRecorder struct {
mock *MockNetLinker
}
// NewMockNetLinker creates a new mock instance.
func NewMockNetLinker(ctrl *gomock.Controller) *MockNetLinker {
mock := &MockNetLinker{ctrl: ctrl}
mock.recorder = &MockNetLinkerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockNetLinker) EXPECT() *MockNetLinkerMockRecorder {
return m.recorder
}
// AddrReplace mocks base method.
func (m *MockNetLinker) AddrReplace(arg0 uint32, arg1 netip.Prefix) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddrReplace", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// AddrReplace indicates an expected call of AddrReplace.
func (mr *MockNetLinkerMockRecorder) AddrReplace(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddrReplace", reflect.TypeOf((*MockNetLinker)(nil).AddrReplace), arg0, arg1)
}
// IsWireguardSupported mocks base method.
func (m *MockNetLinker) IsWireguardSupported() (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IsWireguardSupported")
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// IsWireguardSupported indicates an expected call of IsWireguardSupported.
func (mr *MockNetLinkerMockRecorder) IsWireguardSupported() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsWireguardSupported", reflect.TypeOf((*MockNetLinker)(nil).IsWireguardSupported))
}
// LinkAdd mocks base method.
func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) (uint32, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkAdd", arg0)
ret0, _ := ret[0].(uint32)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// LinkAdd indicates an expected call of LinkAdd.
func (mr *MockNetLinkerMockRecorder) LinkAdd(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkAdd", reflect.TypeOf((*MockNetLinker)(nil).LinkAdd), arg0)
}
// LinkByName mocks base method.
func (m *MockNetLinker) LinkByName(arg0 string) (netlink.Link, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkByName", arg0)
ret0, _ := ret[0].(netlink.Link)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// LinkByName indicates an expected call of LinkByName.
func (mr *MockNetLinkerMockRecorder) LinkByName(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkByName", reflect.TypeOf((*MockNetLinker)(nil).LinkByName), arg0)
}
// LinkDel mocks base method.
func (m *MockNetLinker) LinkDel(arg0 uint32) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkDel", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// LinkDel indicates an expected call of LinkDel.
func (mr *MockNetLinkerMockRecorder) LinkDel(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkDel", reflect.TypeOf((*MockNetLinker)(nil).LinkDel), arg0)
}
// LinkList mocks base method.
func (m *MockNetLinker) LinkList() ([]netlink.Link, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkList")
ret0, _ := ret[0].([]netlink.Link)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// LinkList indicates an expected call of LinkList.
func (mr *MockNetLinkerMockRecorder) LinkList() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkList", reflect.TypeOf((*MockNetLinker)(nil).LinkList))
}
// LinkSetDown mocks base method.
func (m *MockNetLinker) LinkSetDown(arg0 uint32) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkSetDown", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// LinkSetDown indicates an expected call of LinkSetDown.
func (mr *MockNetLinkerMockRecorder) LinkSetDown(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkSetDown", reflect.TypeOf((*MockNetLinker)(nil).LinkSetDown), arg0)
}
// LinkSetUp mocks base method.
func (m *MockNetLinker) LinkSetUp(arg0 uint32) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkSetUp", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// LinkSetUp indicates an expected call of LinkSetUp.
func (mr *MockNetLinkerMockRecorder) LinkSetUp(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkSetUp", reflect.TypeOf((*MockNetLinker)(nil).LinkSetUp), arg0)
}
// RouteAdd mocks base method.
func (m *MockNetLinker) RouteAdd(arg0 netlink.Route) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RouteAdd", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// RouteAdd indicates an expected call of RouteAdd.
func (mr *MockNetLinkerMockRecorder) RouteAdd(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RouteAdd", reflect.TypeOf((*MockNetLinker)(nil).RouteAdd), arg0)
}
// RouteList mocks base method.
func (m *MockNetLinker) RouteList(arg0 byte) ([]netlink.Route, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RouteList", arg0)
ret0, _ := ret[0].([]netlink.Route)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// RouteList indicates an expected call of RouteList.
func (mr *MockNetLinkerMockRecorder) RouteList(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RouteList", reflect.TypeOf((*MockNetLinker)(nil).RouteList), arg0)
}
// RuleAdd mocks base method.
func (m *MockNetLinker) RuleAdd(arg0 netlink.Rule) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RuleAdd", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// RuleAdd indicates an expected call of RuleAdd.
func (mr *MockNetLinkerMockRecorder) RuleAdd(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RuleAdd", reflect.TypeOf((*MockNetLinker)(nil).RuleAdd), arg0)
}
// RuleDel mocks base method.
func (m *MockNetLinker) RuleDel(arg0 netlink.Rule) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RuleDel", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// RuleDel indicates an expected call of RuleDel.
func (mr *MockNetLinkerMockRecorder) RuleDel(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RuleDel", reflect.TypeOf((*MockNetLinker)(nil).RuleDel), arg0)
}
+133
View File
@@ -0,0 +1,133 @@
package amneziawg
import (
"context"
"errors"
"fmt"
"net"
amneziaconn "github.com/amnezia-vpn/amneziawg-go/conn"
amneziadevice "github.com/amnezia-vpn/amneziawg-go/device"
amneziatun "github.com/amnezia-vpn/amneziawg-go/tun"
"github.com/qdm12/gluetun/internal/cleanup"
"github.com/qdm12/gluetun/internal/wireguard"
)
var (
errTunNameMismatch = errors.New("TUN device name is mismatching")
errDeviceWaited = errors.New("device waited for")
)
// Run runs the amneziawg interface and waits until the context is done, then it cleans up the
// interface and returns any error that occurred during setup or waiting. It sends an error to
// waitError if any error occurs during setup or waiting, otherwise it sends nil when the context
// is done. It sends a signal to ready when the setup is complete and the interface is ready to use.
// See https://github.com/amnezia-vpn/amneziawg-go/blob/master/main.go
func (a *Amneziawg) Run(ctx context.Context, waitError chan<- error, ready chan<- struct{}) {
setup := func(ctx context.Context, cleanups *cleanup.Cleanups) (
linkIndex uint32, waitAndCleanup func() error, err error,
) {
return setupUserspace(ctx, a.settings.Wireguard.InterfaceName,
a.netlink, a.settings.Wireguard.MTU, cleanups, a.logger, a.settings)
}
wireguard.Run(ctx, waitError, ready, setup, a.settings.Wireguard, a.netlink, a.logger)
}
func setupUserspace(ctx context.Context,
interfaceName string, netLinker NetLinker, mtu uint32,
cleanups *cleanup.Cleanups, logger Logger,
settings Settings,
) (
linkIndex uint32, waitAndCleanup func() error, err error,
) {
tun, err := amneziatun.CreateTUN(interfaceName, int(mtu))
if err != nil {
return 0, nil, fmt.Errorf("creating TUN device: %w", err)
}
cleanups.Add("closing TUN device", 7, tun.Close)
tunName, err := tun.Name()
if err != nil {
return 0, nil, fmt.Errorf("getting created TUN device name: %w", err)
} else if tunName != interfaceName {
return 0, nil, fmt.Errorf("%w: expected %q and got %q",
errTunNameMismatch, interfaceName, tunName)
}
link, err := netLinker.LinkByName(interfaceName)
if err != nil {
return 0, nil, fmt.Errorf("finding link %s: %w", interfaceName, err)
}
cleanups.Add("deleting link", 5, func() error {
return netLinker.LinkDel(link.Index)
})
bind := amneziaconn.NewDefaultBind()
cleanups.Add("closing bind", 7, bind.Close)
deviceLogger := amneziadevice.Logger{
Verbosef: logger.Debugf,
Errorf: logger.Errorf,
}
device := amneziadevice.NewDevice(tun, bind, &deviceLogger)
cleanups.Add("closing Wireguard device", 6, func() error {
device.Close()
return nil
})
uapiFile, err := wireguard.UAPIOpen(interfaceName)
if err != nil {
return 0, nil, fmt.Errorf("opening UAPI socket: %w", err)
}
cleanups.Add("closing UAPI file", 3, uapiFile.Close)
uapiListener, err := wireguard.UAPIListen(interfaceName, uapiFile)
if err != nil {
return 0, nil, fmt.Errorf("listening on UAPI socket: %w", err)
}
cleanups.Add("closing UAPI listener", 2, uapiListener.Close)
uapiConfig := settings.uapiConfig()
err = device.IpcSet(uapiConfig)
if err != nil {
return 0, nil, fmt.Errorf("setting amneziawg uapi config: %w", err)
}
// acceptAndHandle exits when uapiListener is closed
uapiAcceptErrorCh := make(chan error)
go acceptAndHandle(uapiListener, device, uapiAcceptErrorCh)
waitAndCleanup = func() error {
select {
case <-ctx.Done():
err = ctx.Err()
case err = <-uapiAcceptErrorCh:
close(uapiAcceptErrorCh)
case <-device.Wait():
err = errDeviceWaited
}
cleanups.Cleanup(logger)
<-uapiAcceptErrorCh // wait for acceptAndHandle to exit
return err
}
return link.Index, waitAndCleanup, nil
}
func acceptAndHandle(uapi net.Listener, device *amneziadevice.Device,
uapiAcceptErrorCh chan<- error,
) {
for { // stopped by uapiFile.Close()
conn, err := uapi.Accept()
if err != nil {
uapiAcceptErrorCh <- err
return
}
go device.IpcHandle(conn)
}
}
@@ -1,11 +1,14 @@
package wireguard
package amneziawg
import (
"fmt"
"strings"
"github.com/qdm12/gluetun/internal/wireguard"
)
type AmneziaSettings struct {
type Settings struct {
Wireguard wireguard.Settings
JunkPacketCount uint16
JunkPacketMin uint16
JunkPacketMax uint16
@@ -24,7 +27,7 @@ type AmneziaSettings struct {
InitPacketI5 string
}
func (s AmneziaSettings) uapiConfig() string {
func (s Settings) uapiConfig() string {
uintFields := map[string]uint16{
"jc": s.JunkPacketCount,
"jmin": s.JunkPacketMin,
@@ -56,3 +59,11 @@ func (s AmneziaSettings) uapiConfig() string {
}
return strings.Join(lines, "\n")
}
func (s *Settings) SetDefaults() {
s.Wireguard.SetDefaults()
}
func (s *Settings) Check() error {
return s.Wireguard.Check()
}
+51
View File
@@ -0,0 +1,51 @@
package cleanup
import "sort"
type Cleanups []cleanup
type cleanup struct {
operation string
orderIndex uint
cleanup func() error
done bool
}
// Add adds a cleanup function to the list of cleanups, with a description of the
// operation being cleaned up, and an order index that determines the order in which
// the cleanup functions are run. The lower the order index, the earlier the cleanup
// function is run.
func (c *Cleanups) Add(operation string, orderIndex uint,
cleanupFunc func() error,
) {
closer := cleanup{
operation: operation,
orderIndex: orderIndex,
cleanup: cleanupFunc,
}
*c = append(*c, closer)
}
// Cleanup runs the cleanup functions in the order of their orderIndex,
// and logs any error that occurs during cleanup.
// It can also be re-called in case a cleanup fails, and already cleaned up
// functions will not be re-run.
func (c *Cleanups) Cleanup(logger Logger) {
closers := *c
sort.Slice(closers, func(i, j int) bool {
return closers[i].orderIndex < closers[j].orderIndex
})
for i, closer := range closers {
if closer.done {
continue
}
closers[i].done = true
logger.Debug(closer.operation + "...")
err := closer.cleanup()
if err != nil {
logger.Error("failed " + closer.operation + ": " + err.Error())
}
}
}
+57
View File
@@ -0,0 +1,57 @@
package cleanup
import (
"errors"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
)
func Test_Cleanups(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
var ACloseCalled, BCloseCalled, CCloseCalled bool
var (
AErr error
BErr = errors.New("B failed")
CErr = errors.New("C failed")
)
var cleanups Cleanups
cleanups.Add("cleaning up A", 5, func() error {
ACloseCalled = true
return AErr
})
cleanups.Add("cleaning up B", 3, func() error {
BCloseCalled = true
return BErr
})
cleanups.Add("cleaning up C", 2, func() error {
CCloseCalled = true
return CErr
})
logger := NewMockLogger(ctrl)
prevCall := logger.EXPECT().Debug("cleaning up C...")
prevCall = logger.EXPECT().Error("failed cleaning up C: C failed").After(prevCall)
prevCall = logger.EXPECT().Debug("cleaning up B...").After(prevCall)
prevCall = logger.EXPECT().Error("failed cleaning up B: B failed").After(prevCall)
logger.EXPECT().Debug("cleaning up A...").After(prevCall)
cleanups.Cleanup(logger)
cleanups.Cleanup(logger) // run twice should not close already closed
for _, cleanup := range cleanups {
assert.True(t, cleanup.done)
}
assert.True(t, ACloseCalled)
assert.True(t, BCloseCalled)
assert.True(t, CCloseCalled)
}
+6
View File
@@ -0,0 +1,6 @@
package cleanup
type Logger interface {
Debug(string)
Error(string)
}
+3
View File
@@ -0,0 +1,3 @@
package cleanup
//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . Logger
+58
View File
@@ -0,0 +1,58 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/qdm12/gluetun/internal/cleanup (interfaces: Logger)
// Package cleanup is a generated GoMock package.
package cleanup
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
)
// MockLogger is a mock of Logger interface.
type MockLogger struct {
ctrl *gomock.Controller
recorder *MockLoggerMockRecorder
}
// MockLoggerMockRecorder is the mock recorder for MockLogger.
type MockLoggerMockRecorder struct {
mock *MockLogger
}
// NewMockLogger creates a new mock instance.
func NewMockLogger(ctrl *gomock.Controller) *MockLogger {
mock := &MockLogger{ctrl: ctrl}
mock.recorder = &MockLoggerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockLogger) EXPECT() *MockLoggerMockRecorder {
return m.recorder
}
// Debug mocks base method.
func (m *MockLogger) Debug(arg0 string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Debug", arg0)
}
// Debug indicates an expected call of Debug.
func (mr *MockLoggerMockRecorder) Debug(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockLogger)(nil).Debug), arg0)
}
// Error mocks base method.
func (m *MockLogger) Error(arg0 string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Error", arg0)
}
// Error indicates an expected call of Error.
func (mr *MockLoggerMockRecorder) Error(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Error", reflect.TypeOf((*MockLogger)(nil).Error), arg0)
}
+119 -99
View File
@@ -12,6 +12,9 @@ import (
)
type AmneziaWg struct {
// Wireguard contains the configuration for Wireguard, given
// AmneziaWg is based on Wireguard
Wireguard Wireguard `json:"wireguard"`
JunkPacketCount *uint16 `json:"junk_packet_count"`
JunkPacketMin *uint16 `json:"junk_packet_min"`
JunkPacketMax *uint16 `json:"junk_packet_max"`
@@ -30,15 +33,21 @@ type AmneziaWg struct {
InitPacketI5 *string `json:"init_packet_i5"`
}
func (s *AmneziaWg) read(r *reader.Reader) (err error) {
func (a *AmneziaWg) read(r *reader.Reader) (err error) {
const amneziawg = true
err = a.Wireguard.read(r, amneziawg)
if err != nil {
return err // do not wrap this error
}
uint16Fields := map[string]**uint16{
"AMNEZIAWG_JC": &s.JunkPacketCount,
"AMNEZIAWG_JMIN": &s.JunkPacketMin,
"AMNEZIAWG_JMAX": &s.JunkPacketMax,
"AMNEZIAWG_S1": &s.PaddingS1,
"AMNEZIAWG_S2": &s.PaddingS2,
"AMNEZIAWG_S3": &s.PaddingS3,
"AMNEZIAWG_S4": &s.PaddingS4,
"AMNEZIAWG_JC": &a.JunkPacketCount,
"AMNEZIAWG_JMIN": &a.JunkPacketMin,
"AMNEZIAWG_JMAX": &a.JunkPacketMax,
"AMNEZIAWG_S1": &a.PaddingS1,
"AMNEZIAWG_S2": &a.PaddingS2,
"AMNEZIAWG_S3": &a.PaddingS3,
"AMNEZIAWG_S4": &a.PaddingS4,
}
for key, dst := range uint16Fields {
*dst, err = r.Uint16Ptr(key)
@@ -47,15 +56,15 @@ func (s *AmneziaWg) read(r *reader.Reader) (err error) {
}
}
stringFields := map[string]**string{
"AMNEZIAWG_H1": &s.HeaderH1,
"AMNEZIAWG_H2": &s.HeaderH2,
"AMNEZIAWG_H3": &s.HeaderH3,
"AMNEZIAWG_H4": &s.HeaderH4,
"AMNEZIAWG_I1": &s.InitPacketI1,
"AMNEZIAWG_I2": &s.InitPacketI2,
"AMNEZIAWG_I3": &s.InitPacketI3,
"AMNEZIAWG_I4": &s.InitPacketI4,
"AMNEZIAWG_I5": &s.InitPacketI5,
"AMNEZIAWG_H1": &a.HeaderH1,
"AMNEZIAWG_H2": &a.HeaderH2,
"AMNEZIAWG_H3": &a.HeaderH3,
"AMNEZIAWG_H4": &a.HeaderH4,
"AMNEZIAWG_I1": &a.InitPacketI1,
"AMNEZIAWG_I2": &a.InitPacketI2,
"AMNEZIAWG_I3": &a.InitPacketI3,
"AMNEZIAWG_I4": &a.InitPacketI4,
"AMNEZIAWG_I5": &a.InitPacketI5,
}
opt := reader.ForceLowercase(false)
for key, dst := range stringFields {
@@ -64,80 +73,84 @@ func (s *AmneziaWg) read(r *reader.Reader) (err error) {
return nil
}
func (s AmneziaWg) copy() (copied AmneziaWg) {
func (a AmneziaWg) copy() (copied AmneziaWg) {
return AmneziaWg{
JunkPacketCount: gosettings.CopyPointer(s.JunkPacketCount),
JunkPacketMin: gosettings.CopyPointer(s.JunkPacketMin),
JunkPacketMax: gosettings.CopyPointer(s.JunkPacketMax),
PaddingS1: gosettings.CopyPointer(s.PaddingS1),
PaddingS2: gosettings.CopyPointer(s.PaddingS2),
PaddingS3: gosettings.CopyPointer(s.PaddingS3),
PaddingS4: gosettings.CopyPointer(s.PaddingS4),
HeaderH1: gosettings.CopyPointer(s.HeaderH1),
HeaderH2: gosettings.CopyPointer(s.HeaderH2),
HeaderH3: gosettings.CopyPointer(s.HeaderH3),
HeaderH4: gosettings.CopyPointer(s.HeaderH4),
InitPacketI1: gosettings.CopyPointer(s.InitPacketI1),
InitPacketI2: gosettings.CopyPointer(s.InitPacketI2),
InitPacketI3: gosettings.CopyPointer(s.InitPacketI3),
InitPacketI4: gosettings.CopyPointer(s.InitPacketI4),
InitPacketI5: gosettings.CopyPointer(s.InitPacketI5),
Wireguard: a.Wireguard.copy(),
JunkPacketCount: gosettings.CopyPointer(a.JunkPacketCount),
JunkPacketMin: gosettings.CopyPointer(a.JunkPacketMin),
JunkPacketMax: gosettings.CopyPointer(a.JunkPacketMax),
PaddingS1: gosettings.CopyPointer(a.PaddingS1),
PaddingS2: gosettings.CopyPointer(a.PaddingS2),
PaddingS3: gosettings.CopyPointer(a.PaddingS3),
PaddingS4: gosettings.CopyPointer(a.PaddingS4),
HeaderH1: gosettings.CopyPointer(a.HeaderH1),
HeaderH2: gosettings.CopyPointer(a.HeaderH2),
HeaderH3: gosettings.CopyPointer(a.HeaderH3),
HeaderH4: gosettings.CopyPointer(a.HeaderH4),
InitPacketI1: gosettings.CopyPointer(a.InitPacketI1),
InitPacketI2: gosettings.CopyPointer(a.InitPacketI2),
InitPacketI3: gosettings.CopyPointer(a.InitPacketI3),
InitPacketI4: gosettings.CopyPointer(a.InitPacketI4),
InitPacketI5: gosettings.CopyPointer(a.InitPacketI5),
}
}
//nolint:dupl
func (s *AmneziaWg) overrideWith(other AmneziaWg) {
s.JunkPacketCount = gosettings.OverrideWithPointer(s.JunkPacketCount, other.JunkPacketCount)
s.JunkPacketMin = gosettings.OverrideWithPointer(s.JunkPacketMin, other.JunkPacketMin)
s.JunkPacketMax = gosettings.OverrideWithPointer(s.JunkPacketMax, other.JunkPacketMax)
s.PaddingS1 = gosettings.OverrideWithPointer(s.PaddingS1, other.PaddingS1)
s.PaddingS2 = gosettings.OverrideWithPointer(s.PaddingS2, other.PaddingS2)
s.PaddingS3 = gosettings.OverrideWithPointer(s.PaddingS3, other.PaddingS3)
s.PaddingS4 = gosettings.OverrideWithPointer(s.PaddingS4, other.PaddingS4)
s.HeaderH1 = gosettings.OverrideWithPointer(s.HeaderH1, other.HeaderH1)
s.HeaderH2 = gosettings.OverrideWithPointer(s.HeaderH2, other.HeaderH2)
s.HeaderH3 = gosettings.OverrideWithPointer(s.HeaderH3, other.HeaderH3)
s.HeaderH4 = gosettings.OverrideWithPointer(s.HeaderH4, other.HeaderH4)
s.InitPacketI1 = gosettings.OverrideWithPointer(s.InitPacketI1, other.InitPacketI1)
s.InitPacketI2 = gosettings.OverrideWithPointer(s.InitPacketI2, other.InitPacketI2)
s.InitPacketI3 = gosettings.OverrideWithPointer(s.InitPacketI3, other.InitPacketI3)
s.InitPacketI4 = gosettings.OverrideWithPointer(s.InitPacketI4, other.InitPacketI4)
s.InitPacketI5 = gosettings.OverrideWithPointer(s.InitPacketI5, other.InitPacketI5)
func (a *AmneziaWg) overrideWith(other AmneziaWg) {
a.Wireguard.overrideWith(other.Wireguard)
a.JunkPacketCount = gosettings.OverrideWithPointer(a.JunkPacketCount, other.JunkPacketCount)
a.JunkPacketMin = gosettings.OverrideWithPointer(a.JunkPacketMin, other.JunkPacketMin)
a.JunkPacketMax = gosettings.OverrideWithPointer(a.JunkPacketMax, other.JunkPacketMax)
a.PaddingS1 = gosettings.OverrideWithPointer(a.PaddingS1, other.PaddingS1)
a.PaddingS2 = gosettings.OverrideWithPointer(a.PaddingS2, other.PaddingS2)
a.PaddingS3 = gosettings.OverrideWithPointer(a.PaddingS3, other.PaddingS3)
a.PaddingS4 = gosettings.OverrideWithPointer(a.PaddingS4, other.PaddingS4)
a.HeaderH1 = gosettings.OverrideWithPointer(a.HeaderH1, other.HeaderH1)
a.HeaderH2 = gosettings.OverrideWithPointer(a.HeaderH2, other.HeaderH2)
a.HeaderH3 = gosettings.OverrideWithPointer(a.HeaderH3, other.HeaderH3)
a.HeaderH4 = gosettings.OverrideWithPointer(a.HeaderH4, other.HeaderH4)
a.InitPacketI1 = gosettings.OverrideWithPointer(a.InitPacketI1, other.InitPacketI1)
a.InitPacketI2 = gosettings.OverrideWithPointer(a.InitPacketI2, other.InitPacketI2)
a.InitPacketI3 = gosettings.OverrideWithPointer(a.InitPacketI3, other.InitPacketI3)
a.InitPacketI4 = gosettings.OverrideWithPointer(a.InitPacketI4, other.InitPacketI4)
a.InitPacketI5 = gosettings.OverrideWithPointer(a.InitPacketI5, other.InitPacketI5)
}
func (s *AmneziaWg) setDefaults() {
s.JunkPacketCount = gosettings.DefaultPointer(s.JunkPacketCount, 0)
s.JunkPacketMin = gosettings.DefaultPointer(s.JunkPacketMin, 0)
s.JunkPacketMax = gosettings.DefaultPointer(s.JunkPacketMax, 0)
s.PaddingS1 = gosettings.DefaultPointer(s.PaddingS1, 0)
s.PaddingS2 = gosettings.DefaultPointer(s.PaddingS2, 0)
s.PaddingS3 = gosettings.DefaultPointer(s.PaddingS3, 0)
s.PaddingS4 = gosettings.DefaultPointer(s.PaddingS4, 0)
s.HeaderH1 = gosettings.DefaultPointer(s.HeaderH1, "")
s.HeaderH2 = gosettings.DefaultPointer(s.HeaderH2, "")
s.HeaderH3 = gosettings.DefaultPointer(s.HeaderH3, "")
s.HeaderH4 = gosettings.DefaultPointer(s.HeaderH4, "")
s.InitPacketI1 = gosettings.DefaultPointer(s.InitPacketI1, "")
s.InitPacketI2 = gosettings.DefaultPointer(s.InitPacketI2, "")
s.InitPacketI3 = gosettings.DefaultPointer(s.InitPacketI3, "")
s.InitPacketI4 = gosettings.DefaultPointer(s.InitPacketI4, "")
s.InitPacketI5 = gosettings.DefaultPointer(s.InitPacketI5, "")
func (a *AmneziaWg) setDefaults(vpnProvider string) {
a.Wireguard.setDefaults(vpnProvider)
a.Wireguard.Implementation = "userspace" // unused except in logs
a.JunkPacketCount = gosettings.DefaultPointer(a.JunkPacketCount, 0)
a.JunkPacketMin = gosettings.DefaultPointer(a.JunkPacketMin, 0)
a.JunkPacketMax = gosettings.DefaultPointer(a.JunkPacketMax, 0)
a.PaddingS1 = gosettings.DefaultPointer(a.PaddingS1, 0)
a.PaddingS2 = gosettings.DefaultPointer(a.PaddingS2, 0)
a.PaddingS3 = gosettings.DefaultPointer(a.PaddingS3, 0)
a.PaddingS4 = gosettings.DefaultPointer(a.PaddingS4, 0)
a.HeaderH1 = gosettings.DefaultPointer(a.HeaderH1, "")
a.HeaderH2 = gosettings.DefaultPointer(a.HeaderH2, "")
a.HeaderH3 = gosettings.DefaultPointer(a.HeaderH3, "")
a.HeaderH4 = gosettings.DefaultPointer(a.HeaderH4, "")
a.InitPacketI1 = gosettings.DefaultPointer(a.InitPacketI1, "")
a.InitPacketI2 = gosettings.DefaultPointer(a.InitPacketI2, "")
a.InitPacketI3 = gosettings.DefaultPointer(a.InitPacketI3, "")
a.InitPacketI4 = gosettings.DefaultPointer(a.InitPacketI4, "")
a.InitPacketI5 = gosettings.DefaultPointer(a.InitPacketI5, "")
}
func (s AmneziaWg) toLinesNode() (node *gotree.Node) {
node = gotree.New("Amneziawg parameters:")
func (a AmneziaWg) toLinesNode() (node *gotree.Node) {
node = gotree.New("AmneziaWG settings:")
node.AppendNode(a.Wireguard.toLinesNode())
uintFields := []struct {
key string
val *uint16
}{
{"jc", s.JunkPacketCount},
{"jmin", s.JunkPacketMin},
{"jmax", s.JunkPacketMax},
{"s1", s.PaddingS1},
{"s2", s.PaddingS2},
{"s3", s.PaddingS3},
{"s4", s.PaddingS4},
{"JC", a.JunkPacketCount},
{"JMIN", a.JunkPacketMin},
{"JMAX", a.JunkPacketMax},
{"S1", a.PaddingS1},
{"S2", a.PaddingS2},
{"S3", a.PaddingS3},
{"S4", a.PaddingS4},
}
for _, f := range uintFields {
node.Appendf("%s: %d", f.key, *f.val)
@@ -147,15 +160,15 @@ func (s AmneziaWg) toLinesNode() (node *gotree.Node) {
key string
val *string
}{
{"h1", s.HeaderH1},
{"h2", s.HeaderH2},
{"h3", s.HeaderH3},
{"h4", s.HeaderH4},
{"i1", s.InitPacketI1},
{"i2", s.InitPacketI2},
{"i3", s.InitPacketI3},
{"i4", s.InitPacketI4},
{"i5", s.InitPacketI5},
{"H1", a.HeaderH1},
{"H2", a.HeaderH2},
{"H3", a.HeaderH3},
{"H4", a.HeaderH4},
{"I1", a.InitPacketI1},
{"I2", a.InitPacketI2},
{"I3", a.InitPacketI3},
{"I4", a.InitPacketI4},
{"I5", a.InitPacketI5},
}
for _, f := range stringFields {
node.Appendf("%s: %s", f.key, *f.val)
@@ -165,33 +178,40 @@ func (s AmneziaWg) toLinesNode() (node *gotree.Node) {
}
var (
ErrAmenziawgImplementationNotValid = errors.New("AmneziaWG implementation is not valid")
ErrJunkPacketBounds = errors.New("junk packet minimum must be lower than or equal to maximum")
ErrJunkPacketMinMaxNotSet = errors.New("junk packet min and max must be set when junk packet count is set")
ErrJunkPacketCountNotSet = errors.New("junk packet count must be set when junk packet min or max is set")
ErrHeaderRangeMalformed = errors.New("header range is malformed")
)
func (s AmneziaWg) validate() error {
if *s.JunkPacketCount == 0 {
if *s.JunkPacketMin != 0 || *s.JunkPacketMax != 0 {
func (a AmneziaWg) validate(vpnProvider string, ipv6Supported bool) error {
const amneziaWG = true
err := a.Wireguard.validate(vpnProvider, ipv6Supported, amneziaWG)
if err != nil {
return fmt.Errorf("wireguard settings: %w", err)
}
if *a.JunkPacketCount == 0 {
if *a.JunkPacketMin != 0 || *a.JunkPacketMax != 0 {
return fmt.Errorf("%w: jc=%d and jmin=%d and jmax=%d",
ErrJunkPacketCountNotSet, s.JunkPacketCount, *s.JunkPacketMin, *s.JunkPacketMax)
ErrJunkPacketCountNotSet, a.JunkPacketCount, *a.JunkPacketMin, *a.JunkPacketMax)
}
} else {
if *s.JunkPacketMin == 0 || *s.JunkPacketMax == 0 {
if *a.JunkPacketMin == 0 || *a.JunkPacketMax == 0 {
return fmt.Errorf("%w: jc=%d and jmin=%d and jmax=%d",
ErrJunkPacketMinMaxNotSet, s.JunkPacketCount, *s.JunkPacketMin, *s.JunkPacketMax)
} else if *s.JunkPacketMin > *s.JunkPacketMax {
ErrJunkPacketMinMaxNotSet, a.JunkPacketCount, *a.JunkPacketMin, *a.JunkPacketMax)
} else if *a.JunkPacketMin > *a.JunkPacketMax {
return fmt.Errorf("%w: jmin=%d and jmax=%d",
ErrJunkPacketBounds, *s.JunkPacketMin, *s.JunkPacketMax)
ErrJunkPacketBounds, *a.JunkPacketMin, *a.JunkPacketMax)
}
}
nameToHeaderRange := map[string]string{
"h1": *s.HeaderH1,
"h2": *s.HeaderH2,
"h3": *s.HeaderH3,
"h4": *s.HeaderH4,
"h1": *a.HeaderH1,
"h2": *a.HeaderH2,
"h3": *a.HeaderH3,
"h4": *a.HeaderH4,
}
for name, headerRange := range nameToHeaderRange {
if headerRange == "" {
@@ -268,8 +268,6 @@ func (o *OpenVPN) copy() (copied OpenVPN) {
// overrideWith overrides fields of the receiver
// settings object with any field set in the other
// settings.
//
//nolint:dupl
func (o *OpenVPN) overrideWith(other OpenVPN) {
o.Version = gosettings.OverrideWithComparable(o.Version, other.Version)
o.User = gosettings.OverrideWithPointer(o.User, other.User)
+6 -3
View File
@@ -30,7 +30,10 @@ type Provider struct {
func (p *Provider) validate(vpnType string, filterChoicesGetter FilterChoicesGetter, warner Warner) (err error) {
// Validate Name
var validNames []string
if vpnType == vpn.OpenVPN {
switch vpnType {
case vpn.AmneziaWg:
validNames = []string{providers.Custom}
case vpn.OpenVPN:
validNames = providers.AllWithCustom()
validNames = append(validNames, "pia") // Retro-compatibility
// Remove Mullvad since it no longer supports OpenVPN as of January 15th, 2026
@@ -38,7 +41,7 @@ func (p *Provider) validate(vpnType string, filterChoicesGetter FilterChoicesGet
validNames[mullvadIndex], validNames[len(validNames)-1] = validNames[len(validNames)-1], validNames[mullvadIndex]
validNames = validNames[:len(validNames)-1]
sort.Strings(validNames)
} else { // Wireguard
case vpn.Wireguard:
validNames = []string{
providers.Airvpn,
providers.Custom,
@@ -52,7 +55,7 @@ func (p *Provider) validate(vpnType string, filterChoicesGetter FilterChoicesGet
}
}
if err = validate.IsOneOf(p.Name, validNames...); err != nil {
return fmt.Errorf("%w for Wireguard: %w", ErrVPNProviderNameNotValid, err)
return fmt.Errorf("%w for %s: %w", ErrVPNProviderNameNotValid, vpnType, err)
}
err = p.ServerSelection.validate(p.Name, filterChoicesGetter, warner)
@@ -87,7 +87,7 @@ func (ss *ServerSelection) validate(vpnServiceProvider string,
filterChoicesGetter FilterChoicesGetter, warner Warner,
) (err error) {
switch ss.VPN {
case vpn.OpenVPN, vpn.Wireguard:
case vpn.AmneziaWg, vpn.OpenVPN, vpn.Wireguard:
default:
return fmt.Errorf("%w: %s", ErrVPNTypeNotValid, ss.VPN)
}
+29 -7
View File
@@ -16,6 +16,7 @@ type VPN struct {
// empty string in the internal state.
Type string `json:"type"`
Provider Provider `json:"provider"`
AmneziaWg AmneziaWg `json:"amneziawg"`
OpenVPN OpenVPN `json:"openvpn"`
Wireguard Wireguard `json:"wireguard"`
PMTUD PMTUD `json:"pmtud"`
@@ -29,10 +30,12 @@ type VPN struct {
DownCommand *string `json:"down_command"`
}
// Validate validates VPN settings, using the filter choices getter (aka servers data storage),
// and if IPv6 is supported or not.
// TODO v4 remove pointer for receiver (because of Surfshark).
func (v *VPN) Validate(filterChoicesGetter FilterChoicesGetter, ipv6Supported bool, warner Warner) (err error) {
// Validate Type
validVPNTypes := []string{vpn.OpenVPN, vpn.Wireguard}
validVPNTypes := []string{vpn.AmneziaWg, vpn.OpenVPN, vpn.Wireguard}
if err = validate.IsOneOf(v.Type, validVPNTypes...); err != nil {
return fmt.Errorf("%w: %w", ErrVPNTypeNotValid, err)
}
@@ -42,13 +45,20 @@ func (v *VPN) Validate(filterChoicesGetter FilterChoicesGetter, ipv6Supported bo
return fmt.Errorf("provider settings: %w", err)
}
if v.Type == vpn.OpenVPN {
switch v.Type {
case vpn.AmneziaWg:
err = v.AmneziaWg.validate(v.Provider.Name, ipv6Supported)
if err != nil {
return fmt.Errorf("AmneziaWG settings: %w", err)
}
case vpn.OpenVPN:
err := v.OpenVPN.validate(v.Provider.Name)
if err != nil {
return fmt.Errorf("OpenVPN settings: %w", err)
}
} else {
err := v.Wireguard.validate(v.Provider.Name, ipv6Supported)
case vpn.Wireguard:
const amneziawg = false
err := v.Wireguard.validate(v.Provider.Name, ipv6Supported, amneziawg)
if err != nil {
return fmt.Errorf("Wireguard settings: %w", err)
}
@@ -66,6 +76,7 @@ func (v *VPN) Copy() (copied VPN) {
return VPN{
Type: v.Type,
Provider: v.Provider.copy(),
AmneziaWg: v.AmneziaWg.copy(),
OpenVPN: v.OpenVPN.copy(),
Wireguard: v.Wireguard.copy(),
PMTUD: v.PMTUD.copy(),
@@ -77,6 +88,7 @@ func (v *VPN) Copy() (copied VPN) {
func (v *VPN) OverrideWith(other VPN) {
v.Type = gosettings.OverrideWithComparable(v.Type, other.Type)
v.Provider.overrideWith(other.Provider)
v.AmneziaWg.overrideWith(other.AmneziaWg)
v.OpenVPN.overrideWith(other.OpenVPN)
v.Wireguard.overrideWith(other.Wireguard)
v.PMTUD.overrideWith(other.PMTUD)
@@ -87,6 +99,7 @@ func (v *VPN) OverrideWith(other VPN) {
func (v *VPN) setDefaults() {
v.Type = gosettings.DefaultComparable(v.Type, vpn.OpenVPN)
v.Provider.setDefaults()
v.AmneziaWg.setDefaults(v.Provider.Name)
v.OpenVPN.setDefaults(v.Provider.Name)
v.Wireguard.setDefaults(v.Provider.Name)
v.PMTUD.setDefaults()
@@ -103,9 +116,12 @@ func (v VPN) toLinesNode() (node *gotree.Node) {
node.AppendNode(v.Provider.toLinesNode())
if v.Type == vpn.OpenVPN {
switch v.Type {
case vpn.AmneziaWg:
node.AppendNode(v.AmneziaWg.toLinesNode())
case vpn.OpenVPN:
node.AppendNode(v.OpenVPN.toLinesNode())
} else {
case vpn.Wireguard:
node.AppendNode(v.Wireguard.toLinesNode())
}
node.AppendNode(v.PMTUD.toLinesNode())
@@ -128,12 +144,18 @@ func (v *VPN) read(r *reader.Reader) (err error) {
return fmt.Errorf("VPN provider: %w", err)
}
err = v.AmneziaWg.read(r)
if err != nil {
return fmt.Errorf("AmneziaWG: %w", err)
}
err = v.OpenVPN.read(r)
if err != nil {
return fmt.Errorf("OpenVPN: %w", err)
}
err = v.Wireguard.read(r)
const amneziawg = false
err = v.Wireguard.read(r, amneziawg)
if err != nil {
return fmt.Errorf("wireguard: %w", err)
}
+20 -46
View File
@@ -7,7 +7,6 @@ import (
"strings"
"time"
"github.com/qdm12/gluetun/internal/configuration/settings/helpers"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gosettings"
"github.com/qdm12/gosettings/reader"
@@ -42,34 +41,17 @@ type Wireguard struct {
// 0 indicating to use PMTUD.
MTU *uint32 `json:"mtu"`
// Implementation is the Wireguard implementation to use.
// It can be "auto", "userspace", "kernelspace" or "amneziawg".
// It can be "auto", "userspace" or "kernelspace".
// It defaults to "auto" and cannot be the empty string
// in the internal state.
Implementation string `json:"implementation"`
// AmneziaWG contains obfuscation parameters
AmneziaWG AmneziaWg `json:"amneziawg"`
}
var regexpInterfaceName = regexp.MustCompile(`^[a-zA-Z0-9_]+$`)
// Validate validates Wireguard settings.
// It should only be ran if the VPN type chosen is Wireguard.
func (w Wireguard) validate(vpnProvider string, ipv6Supported bool) (err error) {
if !helpers.IsOneOf(vpnProvider,
providers.Airvpn,
providers.Custom,
providers.Fastestvpn,
providers.Ivpn,
providers.Mullvad,
providers.Nordvpn,
providers.Protonvpn,
providers.Surfshark,
providers.Windscribe,
) {
// do not validate for VPN provider not supporting Wireguard
return nil
}
// It should only be ran if the VPN type chosen is Wireguard or AmneziaWg.
func (w Wireguard) validate(vpnProvider string, ipv6Supported, amneziawg bool) (err error) {
// Validate PrivateKey
if *w.PrivateKey == "" {
return fmt.Errorf("%w", ErrWireguardPrivateKeyNotSet)
@@ -138,14 +120,11 @@ func (w Wireguard) validate(vpnProvider string, ipv6Supported bool) (err error)
ErrWireguardInterfaceNotValid, w.Interface, regexpInterfaceName)
}
validImplementations := []string{"auto", "userspace", "kernelspace", "amneziawg"}
if !amneziawg { // amneziawg should have its own Implementation field and ignore this one
validImplementations := []string{"auto", "userspace", "kernelspace"}
if err := validate.IsOneOf(w.Implementation, validImplementations...); err != nil {
return fmt.Errorf("%w: %w", ErrWireguardImplementationNotValid, err)
}
err = w.AmneziaWG.validate()
if err != nil {
return fmt.Errorf("amneziawg settings: %w", err)
}
return nil
@@ -161,7 +140,6 @@ func (w *Wireguard) copy() (copied Wireguard) {
Interface: w.Interface,
MTU: w.MTU,
Implementation: w.Implementation,
AmneziaWG: w.AmneziaWG.copy(),
}
}
@@ -175,7 +153,6 @@ func (w *Wireguard) overrideWith(other Wireguard) {
w.Interface = gosettings.OverrideWithComparable(w.Interface, other.Interface)
w.MTU = gosettings.OverrideWithComparable(w.MTU, other.MTU)
w.Implementation = gosettings.OverrideWithComparable(w.Implementation, other.Implementation)
w.AmneziaWG.overrideWith(other.AmneziaWG)
}
func (w *Wireguard) setDefaults(vpnProvider string) {
@@ -200,7 +177,6 @@ func (w *Wireguard) setDefaults(vpnProvider string) {
w.Interface = gosettings.DefaultComparable(w.Interface, "wg0")
w.MTU = gosettings.DefaultPointer(w.MTU, 0)
w.Implementation = gosettings.DefaultComparable(w.Implementation, "auto")
w.AmneziaWG.setDefaults()
}
func (w Wireguard) String() string {
@@ -242,29 +218,27 @@ func (w Wireguard) toLinesNode() (node *gotree.Node) {
}
if w.Implementation != "auto" {
implNode := node.Appendf("Implementation: %s", w.Implementation)
if w.Implementation == "amneziawg" {
implNode.AppendNode(w.AmneziaWG.toLinesNode())
}
node.Appendf("Implementation: %s", w.Implementation)
}
return node
}
func (w *Wireguard) read(r *reader.Reader) (err error) {
w.PrivateKey = r.Get("WIREGUARD_PRIVATE_KEY", reader.ForceLowercase(false))
w.PreSharedKey = r.Get("WIREGUARD_PRESHARED_KEY", reader.ForceLowercase(false))
func (w *Wireguard) read(r *reader.Reader, amneziaWG bool) (err error) {
prefix := "WIREGUARD"
if amneziaWG {
prefix = "AMNEZIAWG"
}
w.PrivateKey = r.Get(prefix+"_PRIVATE_KEY", reader.ForceLowercase(false))
w.PreSharedKey = r.Get(prefix+"_PRESHARED_KEY", reader.ForceLowercase(false))
w.Interface = r.String("VPN_INTERFACE",
reader.RetroKeys("WIREGUARD_INTERFACE"), reader.ForceLowercase(false))
w.Implementation = r.String("WIREGUARD_IMPLEMENTATION")
reader.RetroKeys(prefix+"_INTERFACE"), reader.ForceLowercase(false))
err = w.AmneziaWG.read(r)
if err != nil {
return err
if !amneziaWG {
w.Implementation = r.String("WIREGUARD_IMPLEMENTATION")
}
addressStrings := r.CSV("WIREGUARD_ADDRESSES", reader.RetroKeys("WIREGUARD_ADDRESS"))
addressStrings := r.CSV(prefix+"_ADDRESSES", reader.RetroKeys(prefix+"_ADDRESS"))
// WARNING: do not initialize w.Addresses to an empty slice
// or the defaults for nordvpn will not work.
for _, addressString := range addressStrings {
@@ -279,17 +253,17 @@ func (w *Wireguard) read(r *reader.Reader) (err error) {
w.Addresses = append(w.Addresses, address)
}
w.AllowedIPs, err = r.CSVNetipPrefixes("WIREGUARD_ALLOWED_IPS")
w.AllowedIPs, err = r.CSVNetipPrefixes(prefix + "_ALLOWED_IPS")
if err != nil {
return err // already wrapped
}
w.PersistentKeepaliveInterval, err = r.DurationPtr("WIREGUARD_PERSISTENT_KEEPALIVE_INTERVAL")
w.PersistentKeepaliveInterval, err = r.DurationPtr(prefix + "_PERSISTENT_KEEPALIVE_INTERVAL")
if err != nil {
return err
}
w.MTU, err = r.Uint32Ptr("WIREGUARD_MTU")
w.MTU, err = r.Uint32Ptr(prefix + "_MTU")
if err != nil {
return err
}
@@ -0,0 +1,84 @@
package files
import (
"errors"
"fmt"
"os"
"path/filepath"
"gopkg.in/ini.v1"
)
func (s *Source) lazyLoadAmneziawgConf() AmneziawgConfig {
if s.cached.amneziawgLoaded {
return s.cached.amneziawgConf
}
s.cached.amneziawgLoaded = true
var err error
s.cached.amneziawgConf, err = ParseAmneziawgConf(filepath.Join(s.rootDirectory, "amneziawg", "awg0.conf"))
if err != nil {
s.warner.Warnf("skipping Amneziawg config: %s", err)
}
return s.cached.amneziawgConf
}
type AmneziawgConfig struct {
Wireguard WireguardConfig
Jc *string
Jmin *string
Jmax *string
S1 *string
S2 *string
S3 *string
S4 *string
H1 *string
H2 *string
H3 *string
H4 *string
I1 *string
I2 *string
I3 *string
I4 *string
I5 *string
}
func ParseAmneziawgConf(path string) (config AmneziawgConfig, err error) {
iniFile, err := ini.InsensitiveLoad(path)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return AmneziawgConfig{}, nil
}
return AmneziawgConfig{}, fmt.Errorf("loading ini from reader: %w", err)
}
config.Wireguard, err = ParseWireguardConf(path)
if err != nil {
return AmneziawgConfig{}, err
}
interfaceSection, err := iniFile.GetSection("Interface")
if err != nil {
// can never happen
return AmneziawgConfig{}, fmt.Errorf("getting interface section: %w", err)
}
config.Jc = getINIKeyFromSection(interfaceSection, "Jc")
config.Jmin = getINIKeyFromSection(interfaceSection, "Jmin")
config.Jmax = getINIKeyFromSection(interfaceSection, "Jmax")
config.S1 = getINIKeyFromSection(interfaceSection, "S1")
config.S2 = getINIKeyFromSection(interfaceSection, "S2")
config.S3 = getINIKeyFromSection(interfaceSection, "S3")
config.S4 = getINIKeyFromSection(interfaceSection, "S4")
config.H1 = getINIKeyFromSection(interfaceSection, "H1")
config.H2 = getINIKeyFromSection(interfaceSection, "H2")
config.H3 = getINIKeyFromSection(interfaceSection, "H3")
config.H4 = getINIKeyFromSection(interfaceSection, "H4")
config.I1 = getINIKeyFromSection(interfaceSection, "I1")
config.I2 = getINIKeyFromSection(interfaceSection, "I2")
config.I3 = getINIKeyFromSection(interfaceSection, "I3")
config.I4 = getINIKeyFromSection(interfaceSection, "I4")
config.I5 = getINIKeyFromSection(interfaceSection, "I5")
return config, nil
}
@@ -0,0 +1,82 @@
package files
import (
"io/fs"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_Source_ParseAmneziawgConf(t *testing.T) {
t.Parallel()
t.Run("no_file", func(t *testing.T) {
t.Parallel()
noFile := filepath.Join(t.TempDir(), "doesnotexist")
wireguard, err := ParseAmneziawgConf(noFile)
assert.Equal(t, AmneziawgConfig{}, wireguard)
assert.NoError(t, err)
})
testCases := map[string]struct {
fileContent string
amneziawg AmneziawgConfig
errMessage string
}{
"ini_load_error": {
fileContent: "invalid",
errMessage: "loading ini from reader: key-value delimiter not found: invalid",
},
"empty_file": {
errMessage: `getting interface section: section "interface" does not exist`,
},
"success": {
fileContent: `
[Interface]
PrivateKey = QOlCgyA/Sn/c/+YNTIEohrjm8IZV+OZ2AUFIoX20sk8=
Address = 10.38.22.35/32
DNS = 193.138.218.74
Jc = 4
H1 = 721391205
I1 = <b 0x1234>
[Peer]
PresharedKey = YJ680VN+dGrdsWNjSFqZ6vvwuiNhbq502ZL3G7Q3o3g=
`,
amneziawg: AmneziawgConfig{
Wireguard: WireguardConfig{
PrivateKey: ptrTo("QOlCgyA/Sn/c/+YNTIEohrjm8IZV+OZ2AUFIoX20sk8="),
PreSharedKey: ptrTo("YJ680VN+dGrdsWNjSFqZ6vvwuiNhbq502ZL3G7Q3o3g="),
Addresses: ptrTo("10.38.22.35/32"),
},
Jc: ptrTo("4"),
H1: ptrTo("721391205"),
I1: ptrTo("<b 0x1234>"),
},
},
}
for testName, testCase := range testCases {
t.Run(testName, func(t *testing.T) {
t.Parallel()
configFile := filepath.Join(t.TempDir(), "awg.conf")
const permission = fs.FileMode(0o600)
err := os.WriteFile(configFile, []byte(testCase.fileContent), permission)
require.NoError(t, err)
wireguard, err := ParseAmneziawgConf(configFile)
assert.Equal(t, testCase.amneziawg, wireguard)
if testCase.errMessage != "" {
assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
}
})
}
}
+59 -32
View File
@@ -13,6 +13,8 @@ type Source struct {
cached struct {
wireguardLoaded bool
wireguardConf WireguardConfig
amneziawgLoaded bool
amneziawgConf AmneziawgConfig
}
}
@@ -69,38 +71,11 @@ func (s *Source) Get(key string) (value string, isSet bool) {
return strPtrToStringIsSet(s.lazyLoadWireguardConf().EndpointIP)
case "wireguard_endpoint_port":
return strPtrToStringIsSet(s.lazyLoadWireguardConf().EndpointPort)
case "wireguard_jc":
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.Jc)
case "wireguard_jmin":
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.Jmin)
case "wireguard_jmax":
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.Jmax)
case "wireguard_s1":
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.S1)
case "wireguard_s2":
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.S2)
case "wireguard_s3":
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.S3)
case "wireguard_s4":
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.S4)
case "wireguard_h1":
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.H1)
case "wireguard_h2":
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.H2)
case "wireguard_h3":
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.H3)
case "wireguard_h4":
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.H4)
case "wireguard_i1":
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.I1)
case "wireguard_i2":
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.I2)
case "wireguard_i3":
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.I3)
case "wireguard_i4":
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.I4)
case "wireguard_i5":
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.I5)
}
value, isSet, matched := s.getAmneziawgKey(key)
if matched {
return value, isSet
}
value, isSet, err := ReadFromFile(path)
@@ -110,6 +85,58 @@ func (s *Source) Get(key string) (value string, isSet bool) {
return value, isSet
}
func (s *Source) getAmneziawgKey(key string) (value string, isSet, matched bool) {
switch key {
case "amnezia_private_key":
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Wireguard.PrivateKey)
case "amnezia_preshared_key":
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Wireguard.PreSharedKey)
case "amnezia_addresses":
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Wireguard.Addresses)
case "amnezia_public_key":
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Wireguard.PublicKey)
case "amnezia_endpoint_ip":
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Wireguard.EndpointIP)
case "amnezia_endpoint_port":
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Wireguard.EndpointPort)
case "amnezia_jc":
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Jc)
case "amnezia_jmin":
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Jmin)
case "amnezia_jmax":
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Jmax)
case "amnezia_s1":
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().S1)
case "amnezia_s2":
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().S2)
case "amnezia_s3":
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().S3)
case "amnezia_s4":
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().S4)
case "amnezia_h1":
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().H1)
case "amnezia_h2":
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().H2)
case "amnezia_h3":
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().H3)
case "amnezia_h4":
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().H4)
case "amnezia_i1":
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().I1)
case "amnezia_i2":
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().I2)
case "amnezia_i3":
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().I3)
case "amnezia_i4":
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().I4)
case "amnezia_i5":
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().I5)
default:
return "", false, false
}
return value, isSet, true
}
func (s *Source) KeyTransform(key string) string {
switch key {
// TODO v4 remove these irregular cases
@@ -25,46 +25,6 @@ func (s *Source) lazyLoadWireguardConf() WireguardConfig {
return s.cached.wireguardConf
}
type amneziaWgConfig struct {
Jc *string
Jmin *string
Jmax *string
S1 *string
S2 *string
S3 *string
S4 *string
H1 *string
H2 *string
H3 *string
H4 *string
I1 *string
I2 *string
I3 *string
I4 *string
I5 *string
}
func parseWireguardAmneziaInterfaceSection(interfaceSection *ini.Section) amneziaWgConfig {
return amneziaWgConfig{
Jc: getINIKeyFromSection(interfaceSection, "Jc"),
Jmin: getINIKeyFromSection(interfaceSection, "Jmin"),
Jmax: getINIKeyFromSection(interfaceSection, "Jmax"),
S1: getINIKeyFromSection(interfaceSection, "S1"),
S2: getINIKeyFromSection(interfaceSection, "S2"),
S3: getINIKeyFromSection(interfaceSection, "S3"),
S4: getINIKeyFromSection(interfaceSection, "S4"),
H1: getINIKeyFromSection(interfaceSection, "H1"),
H2: getINIKeyFromSection(interfaceSection, "H2"),
H3: getINIKeyFromSection(interfaceSection, "H3"),
H4: getINIKeyFromSection(interfaceSection, "H4"),
I1: getINIKeyFromSection(interfaceSection, "I1"),
I2: getINIKeyFromSection(interfaceSection, "I2"),
I3: getINIKeyFromSection(interfaceSection, "I3"),
I4: getINIKeyFromSection(interfaceSection, "I4"),
I5: getINIKeyFromSection(interfaceSection, "I5"),
}
}
type WireguardConfig struct {
PrivateKey *string
PreSharedKey *string
@@ -72,7 +32,6 @@ type WireguardConfig struct {
PublicKey *string
EndpointIP *string
EndpointPort *string
AmneziaParams amneziaWgConfig
}
var regexINISectionNotExist = regexp.MustCompile(`^section ".+" does not exist$`)
@@ -89,7 +48,6 @@ func ParseWireguardConf(path string) (config WireguardConfig, err error) {
interfaceSection, err := iniFile.GetSection("Interface")
if err == nil {
config.PrivateKey, config.Addresses = parseWireguardInterfaceSection(interfaceSection)
config.AmneziaParams = parseWireguardAmneziaInterfaceSection(interfaceSection)
} else if !regexINISectionNotExist.MatchString(err.Error()) {
// can never happen
return WireguardConfig{}, fmt.Errorf("getting interface section: %w", err)
@@ -100,7 +100,6 @@ func Test_parseWireguardInterfaceSection(t *testing.T) {
iniData string
privateKey *string
addresses *string
amneziaParams amneziaWgConfig
}{
"no_fields": {
iniData: `[Interface]`,
@@ -116,17 +115,9 @@ PrivateKey = x
[Interface]
PrivateKey = QOlCgyA/Sn/c/+YNTIEohrjm8IZV+OZ2AUFIoX20sk8=
Address = 10.38.22.35/32
Jc = 4
H1 = 721391205
I1 = <b 0x1234>
`,
privateKey: ptrTo("QOlCgyA/Sn/c/+YNTIEohrjm8IZV+OZ2AUFIoX20sk8="),
addresses: ptrTo("10.38.22.35/32"),
amneziaParams: amneziaWgConfig{
Jc: ptrTo("4"),
H1: ptrTo("721391205"),
I1: ptrTo("<b 0x1234>"),
},
},
}
@@ -140,11 +131,9 @@ I1 = <b 0x1234>
require.NoError(t, err)
privateKey, addresses := parseWireguardInterfaceSection(iniSection)
amneziaWgConfig := parseWireguardAmneziaInterfaceSection(iniSection)
assert.Equal(t, testCase.privateKey, privateKey)
assert.Equal(t, testCase.addresses, addresses)
assert.Equal(t, testCase.amneziaParams, amneziaWgConfig)
})
}
}
@@ -0,0 +1,27 @@
package secrets
import (
"os"
"path/filepath"
"github.com/qdm12/gluetun/internal/configuration/sources/files"
)
func (s *Source) lazyLoadAmneziawgConf() files.AmneziawgConfig {
if s.cached.amneziawgLoaded {
return s.cached.amneziawgConf
}
path := os.Getenv("AMNEZIAWG_CONF_SECRETFILE")
if path == "" {
path = filepath.Join(s.rootDirectory, "amneziawg", "awg0.conf")
}
s.cached.amneziawgLoaded = true
var err error
s.cached.amneziawgConf, err = files.ParseAmneziawgConf(path)
if err != nil {
s.warner.Warnf("skipping Amneziawg config: %s", err)
}
return s.cached.amneziawgConf
}
@@ -15,6 +15,8 @@ type Source struct {
cached struct {
wireguardLoaded bool
wireguardConf files.WireguardConfig
amneziawgLoaded bool
amneziawgConf files.AmneziawgConfig
}
}
@@ -60,26 +62,26 @@ func (s *Source) Get(key string) (value string, isSet bool) {
s.warner.Warnf("skipping %s: parsing PEM: %s", path, err)
}
return value, isSet
case "wireguard_private_key":
case "wireguard_private_key", "amneziawg_private_key":
privateKey := s.lazyLoadWireguardConf().PrivateKey
if privateKey != nil {
return *privateKey, true
} // else continue to read from individual secret file
case "wireguard_preshared_key":
case "wireguard_preshared_key", "amneziawg_preshared_key":
preSharedKey := s.lazyLoadWireguardConf().PreSharedKey
if preSharedKey != nil {
return *preSharedKey, true
} // else continue to read from individual secret file
case "wireguard_addresses":
case "wireguard_addresses", "amneziawg_addresses":
addresses := s.lazyLoadWireguardConf().Addresses
if addresses != nil {
return *addresses, true
} // else continue to read from individual secret file
case "wireguard_public_key":
case "wireguard_public_key", "amneziawg_public_key":
return strPtrToStringIsSet(s.lazyLoadWireguardConf().PublicKey)
case "wireguard_endpoint_ip":
case "wireguard_endpoint_ip", "amneziawg_endpoint_ip":
return strPtrToStringIsSet(s.lazyLoadWireguardConf().EndpointIP)
case "wireguard_endpoint_port":
case "wireguard_endpoint_port", "amneziawg_endpoint_port":
return strPtrToStringIsSet(s.lazyLoadWireguardConf().EndpointPort)
}
@@ -112,38 +114,50 @@ func (s *Source) KeyTransform(key string) string {
func (s *Source) getAmneziaWg(key string) (value string, isSet, matched bool) {
switch key {
case "wireguard_jc":
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.Jc)
case "wireguard_jmin":
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.Jmin)
case "wireguard_jmax":
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.Jmax)
case "wireguard_s1":
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.S1)
case "wireguard_s2":
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.S2)
case "wireguard_s3":
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.S3)
case "wireguard_s4":
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.S4)
case "wireguard_h1":
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.H1)
case "wireguard_h2":
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.H2)
case "wireguard_h3":
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.H3)
case "wireguard_h4":
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.H4)
case "wireguard_i1":
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.I1)
case "wireguard_i2":
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.I2)
case "wireguard_i3":
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.I3)
case "wireguard_i4":
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.I4)
case "wireguard_i5":
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.I5)
case "amneziawg_private_key":
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Wireguard.PrivateKey)
case "amneziawg_preshared_key":
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Wireguard.PreSharedKey)
case "wireguard_addresses", "amneziawg_addresses":
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Wireguard.Addresses)
case "wireguard_public_key", "amneziawg_public_key":
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Wireguard.PublicKey)
case "wireguard_endpoint_ip", "amneziawg_endpoint_ip":
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Wireguard.EndpointIP)
case "wireguard_endpoint_port", "amneziawg_endpoint_port":
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Wireguard.EndpointPort)
case "amneziawg_jc":
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Jc)
case "amneziawg_jmin":
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Jmin)
case "amneziawg_jmax":
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Jmax)
case "amneziawg_s1":
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().S1)
case "amneziawg_s2":
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().S2)
case "amneziawg_s3":
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().S3)
case "amneziawg_s4":
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().S4)
case "amneziawg_h1":
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().H1)
case "amneziawg_h2":
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().H2)
case "amneziawg_h3":
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().H3)
case "amneziawg_h4":
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().H4)
case "amneziawg_i1":
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().I1)
case "amneziawg_i2":
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().I2)
case "amneziawg_i3":
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().I3)
case "amneziawg_i4":
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().I4)
case "amneziawg_i5":
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().I5)
default:
return "", false, false
}
+1
View File
@@ -1,6 +1,7 @@
package vpn
const (
AmneziaWg = "amneziawg"
OpenVPN = "openvpn"
Wireguard = "wireguard"
)
+67
View File
@@ -0,0 +1,67 @@
package vpn
import (
"context"
"fmt"
"github.com/qdm12/gluetun/internal/amneziawg"
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider"
"github.com/qdm12/gluetun/internal/wireguard"
"github.com/qdm12/gosettings"
)
// setupAmneziaWg sets AmneziaWG up using the configurators and settings given.
func setupAmneziaWg(ctx context.Context, netlinker NetLinker,
fw Firewall, providerConf provider.Provider,
settings settings.VPN, ipv6Supported bool, logger wireguard.Logger) (
amneziawger *amneziawg.Amneziawg, connection models.Connection, err error,
) {
connection, err = providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported)
if err != nil {
return nil, models.Connection{}, fmt.Errorf("finding a VPN server: %w", err)
}
amneziaWGSettings := buildAmneziaWgSettings(connection, settings.AmneziaWg, ipv6Supported)
logger.Debug("Amneziawg server public key: " + amneziaWGSettings.Wireguard.PublicKey)
logger.Debug("Amneziawg client private key: " + gosettings.ObfuscateKey(amneziaWGSettings.Wireguard.PrivateKey))
logger.Debug("Amneziawg pre-shared key: " + gosettings.ObfuscateKey(amneziaWGSettings.Wireguard.PreSharedKey))
amneziawger, err = amneziawg.New(amneziaWGSettings, netlinker, logger)
if err != nil {
return nil, models.Connection{}, fmt.Errorf("creating amneziawg: %w", err)
}
err = fw.SetVPNConnection(ctx, connection, settings.Wireguard.Interface)
if err != nil {
return nil, models.Connection{}, fmt.Errorf("setting firewall: %w", err)
}
return amneziawger, connection, nil
}
func buildAmneziaWgSettings(connection models.Connection,
userSettings settings.AmneziaWg, ipv6Supported bool,
) amneziawg.Settings {
return amneziawg.Settings{
Wireguard: buildWireguardSettings(connection, userSettings.Wireguard, ipv6Supported),
JunkPacketCount: *userSettings.JunkPacketCount,
JunkPacketMin: *userSettings.JunkPacketMin,
JunkPacketMax: *userSettings.JunkPacketMax,
PaddingS1: *userSettings.PaddingS1,
PaddingS2: *userSettings.PaddingS2,
PaddingS3: *userSettings.PaddingS3,
PaddingS4: *userSettings.PaddingS4,
HeaderH1: *userSettings.HeaderH1,
HeaderH2: *userSettings.HeaderH2,
HeaderH3: *userSettings.HeaderH3,
HeaderH4: *userSettings.HeaderH4,
InitPacketI1: *userSettings.InitPacketI1,
InitPacketI2: *userSettings.InitPacketI2,
InitPacketI3: *userSettings.InitPacketI3,
InitPacketI4: *userSettings.InitPacketI4,
InitPacketI5: *userSettings.InitPacketI5,
}
}
+9 -2
View File
@@ -33,14 +33,21 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
var connection models.Connection
var err error
subLogger := l.logger.New(log.SetComponent(settings.Type))
if settings.Type == vpn.OpenVPN {
switch settings.Type {
case vpn.AmneziaWg:
vpnInterface = settings.AmneziaWg.Wireguard.Interface
vpnRunner, connection, err = setupAmneziaWg(ctx, l.netLinker, l.fw,
providerConf, settings, l.ipv6Supported, subLogger)
case vpn.OpenVPN:
vpnInterface = settings.OpenVPN.Interface
vpnRunner, connection, err = setupOpenVPN(ctx, l.fw,
l.openvpnConf, providerConf, settings, l.ipv6Supported, l.cmder, subLogger)
} else { // Wireguard
case vpn.Wireguard:
vpnInterface = settings.Wireguard.Interface
vpnRunner, connection, err = setupWireguard(ctx, l.netLinker, l.fw,
providerConf, settings, l.ipv6Supported, subLogger)
default:
panic("vpn type not implemented: " + settings.Type)
}
if err != nil {
l.crashed(ctx, err)
+9
View File
@@ -8,6 +8,7 @@ import (
"time"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/constants/vpn"
"github.com/qdm12/gluetun/internal/netlink"
"github.com/qdm12/gluetun/internal/pmtud"
pconstants "github.com/qdm12/gluetun/internal/pmtud/constants"
@@ -48,6 +49,14 @@ type tunnelUpPMTUDData struct {
}
func (l *Loop) onTunnelUp(ctx, loopCtx context.Context, data tunnelUpData) {
switch vpnType := l.GetSettings().Type; vpnType {
case vpn.Wireguard, vpn.AmneziaWg:
l.logger.Infof("%s setup is complete. "+
"Note %s is a silent protocol and it may or may not work, without giving any error message. "+
"Typically i/o timeout errors indicate the %s connection is not working.",
vpnType, vpnType, vpnType)
}
l.client.CloseIdleConnections()
for _, vpnPort := range l.vpnInputPorts {
-22
View File
@@ -51,7 +51,6 @@ func buildWireguardSettings(connection models.Connection,
settings.PreSharedKey = *userSettings.PreSharedKey
settings.InterfaceName = userSettings.Interface
settings.Implementation = userSettings.Implementation
settings.AmneziaWG = buildAmneziaWgSettings(userSettings.AmneziaWG)
if *userSettings.MTU > 0 {
settings.MTU = *userSettings.MTU
} else {
@@ -91,24 +90,3 @@ func buildWireguardSettings(connection models.Connection,
return settings
}
func buildAmneziaWgSettings(s settings.AmneziaWg) wireguard.AmneziaSettings {
return wireguard.AmneziaSettings{
JunkPacketCount: *s.JunkPacketCount,
JunkPacketMin: *s.JunkPacketMin,
JunkPacketMax: *s.JunkPacketMax,
PaddingS1: *s.PaddingS1,
PaddingS2: *s.PaddingS2,
PaddingS3: *s.PaddingS3,
PaddingS4: *s.PaddingS4,
HeaderH1: *s.HeaderH1,
HeaderH2: *s.HeaderH2,
HeaderH3: *s.HeaderH3,
HeaderH4: *s.HeaderH4,
InitPacketI1: *s.InitPacketI1,
InitPacketI2: *s.InitPacketI2,
InitPacketI3: *s.InitPacketI3,
InitPacketI4: *s.InitPacketI4,
InitPacketI5: *s.InitPacketI5,
}
}
-22
View File
@@ -40,24 +40,6 @@ func Test_buildWireguardSettings(t *testing.T) {
PersistentKeepaliveInterval: ptrTo(time.Hour),
Interface: "wg1",
MTU: ptrTo(uint32(1000)),
AmneziaWG: settings.AmneziaWg{
JunkPacketCount: ptrTo(uint16(1)),
JunkPacketMin: ptrTo(uint16(0)),
JunkPacketMax: ptrTo(uint16(0)),
PaddingS1: ptrTo(uint16(0)),
PaddingS2: ptrTo(uint16(0)),
PaddingS3: ptrTo(uint16(0)),
PaddingS4: ptrTo(uint16(0)),
HeaderH1: ptrTo("x"),
HeaderH2: ptrTo(""),
HeaderH3: ptrTo(""),
HeaderH4: ptrTo(""),
InitPacketI1: ptrTo(""),
InitPacketI2: ptrTo(""),
InitPacketI3: ptrTo(""),
InitPacketI4: ptrTo(""),
InitPacketI5: ptrTo(""),
},
},
ipv6Supported: false,
settings: wireguard.Settings{
@@ -76,10 +58,6 @@ func Test_buildWireguardSettings(t *testing.T) {
RulePriority: 101,
IPv6: ptrTo(false),
MTU: 1000,
AmneziaWG: wireguard.AmneziaSettings{
JunkPacketCount: 1,
HeaderH1: "x",
},
},
},
}
+5 -4
View File
@@ -5,15 +5,16 @@ import (
"net/netip"
)
func (w *Wireguard) addAddresses(linkIndex uint32,
addresses []netip.Prefix,
func AddAddresses(linkIndex uint32,
addresses []netip.Prefix, ipv6 bool,
netlink NetLinker,
) (err error) {
for _, address := range addresses {
if !*w.settings.IPv6 && address.Addr().Is6() {
if !ipv6 && address.Addr().Is6() {
continue
}
err = w.netlink.AddrReplace(linkIndex, address)
err = netlink.AddrReplace(linkIndex, address)
if err != nil {
return fmt.Errorf("%w: when adding address %s to link with index %d",
err, address, linkIndex)
+16 -31
View File
@@ -10,7 +10,7 @@ import (
"github.com/stretchr/testify/require"
)
func Test_Wireguard_addAddresses(t *testing.T) {
func Test_AddAddresses(t *testing.T) {
t.Parallel()
ipNetOne := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 32)
@@ -21,13 +21,15 @@ func Test_Wireguard_addAddresses(t *testing.T) {
testCases := map[string]struct {
linkIndex uint32
addrs []netip.Prefix
wgBuilder func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard
ipv6 bool
netlinkBuilder func(ctrl *gomock.Controller, linkIndex uint32) *MockNetLinker
err error
}{
"success": {
linkIndex: 1,
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
wgBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard {
ipv6: true,
netlinkBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *MockNetLinker {
netLinker := NewMockNetLinker(ctrl)
firstCall := netLinker.EXPECT().
AddrReplace(linkIndex, ipNetOne).
@@ -35,35 +37,27 @@ func Test_Wireguard_addAddresses(t *testing.T) {
netLinker.EXPECT().
AddrReplace(linkIndex, ipNetTwo).
Return(nil).After(firstCall)
return &Wireguard{
netlink: netLinker,
settings: Settings{
IPv6: ptrTo(true),
},
}
return netLinker
},
},
"first add error": {
linkIndex: 1,
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
wgBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard {
ipv6: true,
netlinkBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *MockNetLinker {
netLinker := NewMockNetLinker(ctrl)
netLinker.EXPECT().
AddrReplace(linkIndex, ipNetOne).
Return(errDummy)
return &Wireguard{
netlink: netLinker,
settings: Settings{
IPv6: ptrTo(true),
},
}
return netLinker
},
err: errors.New("dummy: when adding address 1.2.3.4/32 to link with index 1"),
},
"second add error": {
linkIndex: 1,
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
wgBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard {
ipv6: true,
netlinkBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *MockNetLinker {
netLinker := NewMockNetLinker(ctrl)
firstCall := netLinker.EXPECT().
AddrReplace(linkIndex, ipNetOne).
@@ -71,23 +65,14 @@ func Test_Wireguard_addAddresses(t *testing.T) {
netLinker.EXPECT().
AddrReplace(linkIndex, ipNetTwo).
Return(errDummy).After(firstCall)
return &Wireguard{
netlink: netLinker,
settings: Settings{
IPv6: ptrTo(true),
},
}
return netLinker
},
err: errors.New("dummy: when adding address ::1234/64 to link with index 1"),
},
"ignore IPv6": {
addrs: []netip.Prefix{ipNetTwo},
wgBuilder: func(_ *gomock.Controller, _ uint32) *Wireguard {
return &Wireguard{
settings: Settings{
IPv6: ptrTo(false),
},
}
netlinkBuilder: func(_ *gomock.Controller, _ uint32) *MockNetLinker {
return NewMockNetLinker(nil)
},
},
}
@@ -97,9 +82,9 @@ func Test_Wireguard_addAddresses(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
wg := testCase.wgBuilder(ctrl, testCase.linkIndex)
netlink := testCase.netlinkBuilder(ctrl, testCase.linkIndex)
err := wg.addAddresses(testCase.linkIndex, testCase.addrs)
err := AddAddresses(testCase.linkIndex, testCase.addrs, testCase.ipv6, netlink)
if testCase.err != nil {
require.Error(t, err)
-63
View File
@@ -1,63 +0,0 @@
package wireguard
import "sort"
type closer struct {
operation string
step step
close func() error
closed bool
}
type closers []closer
func (c *closers) add(operation string, step step,
closeFunc func() error,
) {
closer := closer{
operation: operation,
step: step,
close: closeFunc,
}
*c = append(*c, closer)
}
func (c *closers) cleanup(logger Logger) {
closers := *c
sort.Slice(closers, func(i, j int) bool {
return closers[i].step < closers[j].step
})
for i, closer := range closers {
if closer.closed {
continue
}
closers[i].closed = true
logger.Debug(closer.operation + "...")
err := closer.close()
if err != nil {
logger.Error("failed " + closer.operation + ": " + err.Error())
}
}
}
type step int
const (
// stepOne closes the wireguard controller client,
// and removes the IP rule.
stepOne step = iota
// stepTwo closes the UAPI listener.
stepTwo
// stepThree closes the UAPI file.
stepThree
// stepFour shuts down the Wireguard link.
stepFour
// stepFive removes the Wireguard link.
stepFive
// stepSix closes the Wireguard device.
stepSix
// stepSeven closes the bind connection and the TUN device file.
stepSeven
)
-57
View File
@@ -1,57 +0,0 @@
package wireguard
import (
"errors"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
)
func Test_closers(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
var ACloseCalled, BCloseCalled, CCloseCalled bool
var (
AErr error
BErr = errors.New("B failed")
CErr = errors.New("C failed")
)
var closers closers
closers.add("closing A", stepFive, func() error {
ACloseCalled = true
return AErr
})
closers.add("closing B", stepThree, func() error {
BCloseCalled = true
return BErr
})
closers.add("closing C", stepTwo, func() error {
CCloseCalled = true
return CErr
})
logger := NewMockLogger(ctrl)
prevCall := logger.EXPECT().Debug("closing C...")
prevCall = logger.EXPECT().Error("failed closing C: C failed").After(prevCall)
prevCall = logger.EXPECT().Debug("closing B...").After(prevCall)
prevCall = logger.EXPECT().Error("failed closing B: B failed").After(prevCall)
logger.EXPECT().Debug("closing A...").After(prevCall)
closers.cleanup(logger)
closers.cleanup(logger) // run twice should not close already closed
for _, closer := range closers {
assert.True(t, closer.closed)
}
assert.True(t, ACloseCalled)
assert.True(t, BCloseCalled)
assert.True(t, CCloseCalled)
}
-28
View File
@@ -1,28 +0,0 @@
package wireguard
import (
"net"
)
type tunDevice interface {
Close() error
Name() (string, error)
}
type bind interface {
Close() error
}
type userspaceDevice interface {
Close()
Wait() chan struct{}
IpcHandle(net.Conn)
IpcSet(string) error
}
type userSpaceBackend struct {
createTun func(string, int) (tunDevice, error)
createBind func() bind
createDevice func(tunDevice, bind, Logger) userspaceDevice
preStart func(userspaceDevice, Settings) error
}
+1 -1
View File
@@ -10,7 +10,7 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
func configureDevice(client *wgctrl.Client, settings Settings) (err error) {
func ConfigureDevice(client *wgctrl.Client, settings Settings) (err error) {
deviceConfig, err := makeDeviceConfig(settings)
if err != nil {
return fmt.Errorf("making device configuration: %w", err)
+16 -1
View File
@@ -1,5 +1,9 @@
package wireguard
import (
"golang.zx2c4.com/wireguard/device"
)
//go:generate mockgen -destination=log_mock_test.go -package wireguard . Logger
type Logger interface {
@@ -7,5 +11,16 @@ type Logger interface {
Debugf(format string, args ...interface{})
Info(s string)
Error(s string)
Errorf(format string, args ...interface{})
Erroer
}
type Erroer interface {
Errorf(format string, args ...any)
}
func makeDeviceLogger(logger Logger) (deviceLogger *device.Logger) {
return &device.Logger{
Verbosef: logger.Debugf,
Errorf: logger.Errorf,
}
}
+23
View File
@@ -0,0 +1,23 @@
package wireguard
import (
"testing"
"github.com/golang/mock/gomock"
)
func Test_makeDeviceLogger(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
logger := NewMockLogger(ctrl)
deviceLogger := makeDeviceLogger(logger)
logger.EXPECT().Debugf("test %d", 1)
deviceLogger.Verbosef("test %d", 1)
logger.EXPECT().Errorf("test %d", 2)
deviceLogger.Errorf("test %d", 2)
}
+9 -12
View File
@@ -21,7 +21,7 @@ func (n noopDebugLogger) Error(_ string) {}
func (n noopDebugLogger) Errorf(_ string, _ ...any) {}
func (n noopDebugLogger) Patch(_ ...log.Option) {}
func Test_netlink_Wireguard_addAddresses(t *testing.T) {
func Test_AddAddresses_Integration(t *testing.T) {
t.Parallel()
netlinker := netlink.New(&noopDebugLogger{})
@@ -55,7 +55,7 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) {
const addIterations = 2 // initial + replace
for range addIterations {
err = wg.addAddresses(link.Index, addresses)
err = AddAddresses(link.Index, addresses, *wg.settings.IPv6, wg.netlink)
require.NoError(t, err)
ipPrefixes, err := netlinker.AddrList(link.Index, netlink.FamilyAll)
@@ -67,22 +67,19 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) {
}
}
func Test_netlink_Wireguard_addRule(t *testing.T) {
func Test_AddRule_Integration(t *testing.T) {
t.Parallel()
netlinker := netlink.New(&noopDebugLogger{})
wg := &Wireguard{
netlink: netlinker,
logger: &noopDebugLogger{},
}
logger := &noopDebugLogger{}
netlinker := netlink.New(logger)
// Unique combination for this test
const rulePriority uint32 = 10000
const firewallMark uint32 = 12345
const family = netlink.FamilyV4
cleanup, err := wg.addRule(rulePriority,
firewallMark, family)
cleanup, err := AddRule(rulePriority,
firewallMark, family, netlinker, logger)
require.NoError(t, err)
t.Cleanup(func() {
err := cleanup()
@@ -110,8 +107,8 @@ func Test_netlink_Wireguard_addRule(t *testing.T) {
require.True(t, ruleFound)
// Existing rule cannot be added
nilCleanup, err := wg.addRule(rulePriority,
firewallMark, family)
nilCleanup, err := AddRule(rulePriority,
firewallMark, family, netlinker, logger)
if nilCleanup != nil {
_ = nilCleanup() // in case it succeeds
}
+7 -7
View File
@@ -8,17 +8,17 @@ import (
"github.com/qdm12/gluetun/internal/netlink"
)
func (w *Wireguard) addRoutes(linkIndex uint32, destinations []netip.Prefix,
firewallMark uint32,
func AddRoutes(linkIndex uint32, destinations []netip.Prefix,
firewallMark uint32, netlinker NetLinker, logger Erroer,
) (err error) {
for _, dst := range destinations {
err = w.addRoute(linkIndex, dst, firewallMark)
err = addRoute(linkIndex, dst, firewallMark, netlinker)
if err == nil {
continue
}
if dst.Addr().Is6() && strings.Contains(err.Error(), "permission denied") {
w.logger.Errorf("cannot add route for IPv6 due to a permission denial. "+
logger.Errorf("cannot add route for IPv6 due to a permission denial. "+
"Ignoring and continuing execution; "+
"Please report to https://github.com/qdm12/gluetun/issues/998 if you find a fix. "+
"Full error string: %s", err)
@@ -29,8 +29,8 @@ func (w *Wireguard) addRoutes(linkIndex uint32, destinations []netip.Prefix,
return nil
}
func (w *Wireguard) addRoute(linkIndex uint32, dst netip.Prefix,
firewallMark uint32,
func addRoute(linkIndex uint32, dst netip.Prefix,
firewallMark uint32, netlinker NetLinker,
) (err error) {
family := netlink.FamilyV4
if dst.Addr().Is6() {
@@ -46,7 +46,7 @@ func (w *Wireguard) addRoute(linkIndex uint32, dst netip.Prefix,
Proto: netlink.ProtoStatic,
}
err = w.netlink.RouteAdd(route)
err = netlinker.RouteAdd(route)
if err != nil {
return fmt.Errorf(
"adding route for link with index %d, destination %s and table %d: %w",
+2 -6
View File
@@ -11,7 +11,7 @@ import (
"github.com/stretchr/testify/require"
)
func Test_Wireguard_addRoute(t *testing.T) {
func Test_addRoute(t *testing.T) {
t.Parallel()
const linkIndex = 88
@@ -62,15 +62,11 @@ func Test_Wireguard_addRoute(t *testing.T) {
ctrl := gomock.NewController(t)
netLinker := NewMockNetLinker(ctrl)
wg := Wireguard{
netlink: netLinker,
}
netLinker.EXPECT().
RouteAdd(testCase.expectedRoute).
Return(testCase.routeAddErr)
err := wg.addRoute(linkIndex, testCase.dst, firewallMark)
err := addRoute(linkIndex, testCase.dst, firewallMark, netLinker)
if testCase.err != nil {
require.Error(t, err)
+5 -5
View File
@@ -7,8 +7,8 @@ import (
"github.com/qdm12/gluetun/internal/netlink"
)
func (w *Wireguard) addRule(rulePriority, firewallMark uint32,
family uint8,
func AddRule(rulePriority, firewallMark uint32, family uint8,
netlinker NetLinker, logger Logger,
) (cleanup func() error, err error) {
rule := netlink.Rule{
Priority: &rulePriority,
@@ -18,16 +18,16 @@ func (w *Wireguard) addRule(rulePriority, firewallMark uint32,
Flags: netlink.FlagInvert,
Action: netlink.ActionToTable,
}
if err := w.netlink.RuleAdd(rule); err != nil {
if err := netlinker.RuleAdd(rule); err != nil {
if strings.HasSuffix(err.Error(), "file exists") {
w.logger.Info("if you are using Kubernetes, this may fix the error below: " +
logger.Info("if you are using Kubernetes, this may fix the error below: " +
"https://github.com/qdm12/gluetun-wiki/blob/main/setup/advanced/kubernetes.md#adding-ipv6-rule--file-exists")
}
return nil, fmt.Errorf("adding %s: %w", rule, err)
}
cleanup = func() error {
err := w.netlink.RuleDel(rule)
err := netlinker.RuleDel(rule)
if err != nil {
return fmt.Errorf("deleting rule %s: %w", rule, err)
}
+3 -5
View File
@@ -10,7 +10,7 @@ import (
"github.com/stretchr/testify/require"
)
func Test_Wireguard_addRule(t *testing.T) {
func Test_AddRule(t *testing.T) {
t.Parallel()
const rulePriority uint32 = 987
@@ -68,13 +68,11 @@ func Test_Wireguard_addRule(t *testing.T) {
ctrl := gomock.NewController(t)
netLinker := NewMockNetLinker(ctrl)
wg := Wireguard{
netlink: netLinker,
}
netLinker.EXPECT().RuleAdd(testCase.expectedRule).
Return(testCase.ruleAddErr)
cleanup, err := wg.addRule(rulePriority, firewallMark, family)
cleanup, err := AddRule(rulePriority, firewallMark, family,
netLinker, nil)
if testCase.err != nil {
require.Error(t, err)
assert.Equal(t, testCase.err.Error(), err.Error())
+85 -89
View File
@@ -6,39 +6,33 @@ import (
"fmt"
"net"
"github.com/qdm12/gluetun/internal/cleanup"
"github.com/qdm12/gluetun/internal/netlink"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/wgctrl"
)
var (
ErrDetectKernel = errors.New("cannot detect Kernel support")
ErrCreateTun = errors.New("cannot create TUN device")
ErrAddLink = errors.New("cannot add Wireguard link")
ErrFindLink = errors.New("cannot find link")
ErrFindDevice = errors.New("cannot find Wireguard device")
ErrUAPISocketOpening = errors.New("cannot open UAPI socket")
ErrWgctrlOpen = errors.New("cannot open wgctrl")
ErrUAPIListen = errors.New("cannot listen on UAPI socket")
ErrAddAddress = errors.New("cannot add address to wireguard interface")
ErrConfigure = errors.New("cannot configure wireguard interface")
ErrDeviceInfo = errors.New("cannot get wireguard device information")
ErrIfaceUp = errors.New("cannot set the interface to UP")
ErrRouteAdd = errors.New("cannot add route for interface")
ErrDeviceWaited = errors.New("device waited for")
ErrKernelSupport = errors.New("kernel does not support Wireguard")
ErrAmneziaConfigure = errors.New("cannot configure AmneziaWG")
errKernelSupport = errors.New("kernel does not support Wireguard")
errTunNameMismatch = errors.New("TUN device name is mismatching")
errDeviceWaited = errors.New("device waited for")
)
// Run runs the wireguard interface and waits until the context is done, then it cleans up the
// interface and returns any error that occurred during setup or waiting. It sends an error to
// waitError if any error occurs during setup or waiting, otherwise it sends nil when the context
// is done. It sends a signal to ready when the setup is complete and the interface is ready to use.
// See https://git.zx2c4.com/wireguard-go/tree/main.go
func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<- struct{}) {
kernelSupported, err := w.netlink.IsWireguardSupported()
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrDetectKernel, err)
waitError <- fmt.Errorf("detecting wireguard kernel support: %w", err)
return
}
userspaceBackend := defaultUserSpaceBackend()
setupFunction := setupUserSpaceCommon
setupFunction := setupUserSpace
switch w.settings.Implementation {
case "auto": //nolint:goconst
if !kernelSupported {
@@ -50,95 +44,105 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
case "userspace":
case "kernelspace":
if !kernelSupported {
waitError <- fmt.Errorf("%w", ErrKernelSupport)
waitError <- fmt.Errorf("%w", errKernelSupport)
return
}
setupFunction = setupKernelSpace
case "amneziawg":
userspaceBackend = amneziaUserSpaceBackend()
default:
panic(fmt.Sprintf("unknown implementation %q", w.settings.Implementation))
}
setup := func(ctx context.Context, cleanups *cleanup.Cleanups) (
linkIndex uint32, waitAndCleanup func() error, err error,
) {
return setupFunction(ctx,
w.settings.InterfaceName, w.netlink, w.settings.MTU, cleanups, w.logger)
}
Run(ctx, waitError, ready, setup, w.settings, w.netlink, w.logger)
}
func Run(ctx context.Context, waitError chan<- error, ready chan<- struct{},
setup func(ctx context.Context, cleanups *cleanup.Cleanups) (
linkIndex uint32, waitAndCleanup func() error, err error),
settings Settings, netlinker NetLinker, logger Logger,
) {
client, err := wgctrl.New()
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrWgctrlOpen, err)
waitError <- fmt.Errorf("opening wgctrl: %w", err)
return
}
var closers closers
closers.add("closing controller client", stepOne, client.Close)
var cleanups cleanup.Cleanups
cleanups.Add("closing controller client", 1, client.Close)
defer closers.cleanup(w.logger)
defer cleanups.Cleanup(logger)
linkIndex, waitAndCleanup, err := setupFunction(ctx,
w.settings.InterfaceName, w.netlink, w.settings.MTU, &closers, w.logger, w.settings, userspaceBackend)
linkIndex, waitAndCleanup, err := setup(ctx, &cleanups)
if err != nil {
waitError <- err
return
}
err = w.addAddresses(linkIndex, w.settings.Addresses)
err = AddAddresses(linkIndex, settings.Addresses, *settings.IPv6, netlinker)
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrAddAddress, err)
waitError <- fmt.Errorf("adding addresses to interface: %w", err)
return
}
w.logger.Info("Connecting to " + w.settings.Endpoint.String())
err = configureDevice(client, w.settings)
logger.Info("Connecting to " + settings.Endpoint.String())
err = ConfigureDevice(client, settings)
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrConfigure, err)
waitError <- fmt.Errorf("configuring interface: %w", err)
return
}
err = w.netlink.LinkSetUp(linkIndex)
err = netlinker.LinkSetUp(linkIndex)
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrIfaceUp, err)
waitError <- fmt.Errorf("setting the interface UP: %w", err)
return
}
closers.add("shutting down link", stepFour, func() error {
return w.netlink.LinkSetDown(linkIndex)
cleanups.Add("shutting down link", 4, func() error {
return netlinker.LinkSetDown(linkIndex)
})
err = w.addRoutes(linkIndex, w.settings.AllowedIPs, w.settings.FirewallMark)
err = AddRoutes(linkIndex, settings.AllowedIPs, settings.FirewallMark,
netlinker, logger)
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrRouteAdd, err)
waitError <- fmt.Errorf("adding routes for interface: %w", err)
return
}
if *w.settings.IPv6 {
if *settings.IPv6 {
// requires net.ipv6.conf.all.disable_ipv6=0
ruleCleanup6, err := w.addRule(w.settings.RulePriority,
w.settings.FirewallMark, netlink.FamilyV6)
ruleCleanup6, err := AddRule(settings.RulePriority,
settings.FirewallMark, netlink.FamilyV6,
netlinker, logger)
if err != nil {
waitError <- fmt.Errorf("adding IPv6 rule: %w", err)
return
}
closers.add("removing IPv6 rule", stepOne, ruleCleanup6)
cleanups.Add("removing IPv6 rule", 1, ruleCleanup6)
}
ruleCleanup, err := w.addRule(w.settings.RulePriority,
w.settings.FirewallMark, netlink.FamilyV4)
ruleCleanup, err := AddRule(settings.RulePriority,
settings.FirewallMark, netlink.FamilyV4,
netlinker, logger)
if err != nil {
waitError <- fmt.Errorf("adding IPv4 rule: %w", err)
return
}
closers.add("removing IPv4 rule", stepOne, ruleCleanup)
w.logger.Info("Wireguard setup is complete. " +
"Note Wireguard is a silent protocol and it may or may not work, without giving any error message. " +
"Typically i/o timeout errors indicate the Wireguard connection is not working.")
cleanups.Add("removing IPv4 rule", 1, ruleCleanup)
ready <- struct{}{}
waitError <- waitAndCleanup()
}
type waitAndCleanupFunc func() error
func setupKernelSpace(ctx context.Context,
interfaceName string, netLinker NetLinker, mtu uint32,
closers *closers, logger Logger, _ Settings, _ userSpaceBackend) (
linkIndex uint32, waitAndCleanup waitAndCleanupFunc, err error,
cleanups *cleanup.Cleanups, logger Logger) (
linkIndex uint32, waitAndCleanup func() error, err error,
) {
links, err := netLinker.LinkList()
if err != nil {
@@ -164,82 +168,74 @@ func setupKernelSpace(ctx context.Context,
}
linkIndex, err = netLinker.LinkAdd(link)
if err != nil {
return 0, nil, fmt.Errorf("%w: %s", ErrAddLink, err)
return 0, nil, fmt.Errorf("adding link: %w", err)
}
closers.add("deleting link", stepFive, func() error {
cleanups.Add("deleting link", 5, func() error {
return netLinker.LinkDel(linkIndex)
})
waitAndCleanup = func() error {
<-ctx.Done()
closers.cleanup(logger)
cleanups.Cleanup(logger)
return ctx.Err()
}
return linkIndex, waitAndCleanup, nil
}
func setupUserSpaceCommon(ctx context.Context,
func setupUserSpace(ctx context.Context,
interfaceName string, netLinker NetLinker, mtu uint32,
closers *closers, logger Logger,
settings Settings, b userSpaceBackend,
) (
linkIndex uint32, waitAndCleanup waitAndCleanupFunc, err error,
cleanups *cleanup.Cleanups, logger Logger) (
linkIndex uint32, waitAndCleanup func() error, err error,
) {
tun, err := b.createTun(interfaceName, int(mtu))
tun, err := tun.CreateTUN(interfaceName, int(mtu))
if err != nil {
return 0, nil, fmt.Errorf("%w: %s", ErrCreateTun, err)
return 0, nil, fmt.Errorf("creating TUN device: %w", err)
}
closers.add("closing TUN device", stepSeven, tun.Close)
cleanups.Add("closing TUN device", 7, tun.Close)
tunName, err := tun.Name()
if err != nil {
return 0, nil, fmt.Errorf("%w: cannot get TUN name: %s", ErrCreateTun, err)
return 0, nil, fmt.Errorf("getting created TUN device name: %w", err)
} else if tunName != interfaceName {
return 0, nil, fmt.Errorf("%w: names don't match: expected %q and got %q",
ErrCreateTun, interfaceName, tunName)
return 0, nil, fmt.Errorf("%w: expected %q and got %q",
errTunNameMismatch, interfaceName, tunName)
}
link, err := netLinker.LinkByName(interfaceName)
if err != nil {
return 0, nil, fmt.Errorf("%w: %s: %s", ErrFindLink, interfaceName, err)
return 0, nil, fmt.Errorf("finding link %s: %w", interfaceName, err)
}
closers.add("deleting link", stepFive, func() error {
cleanups.Add("deleting link", 5, func() error {
return netLinker.LinkDel(link.Index)
})
bind := b.createBind()
bind := conn.NewDefaultBind()
closers.add("closing bind", stepSeven, bind.Close)
cleanups.Add("closing bind", 7, bind.Close)
device := b.createDevice(tun, bind, logger)
deviceLogger := makeDeviceLogger(logger)
device := device.NewDevice(tun, bind, deviceLogger)
closers.add("closing Wireguard device", stepSix, func() error {
cleanups.Add("closing Wireguard device", 6, func() error {
device.Close()
return nil
})
uapiFile, err := uapiOpen(interfaceName)
uapiFile, err := UAPIOpen(interfaceName)
if err != nil {
return 0, nil, fmt.Errorf("%w: %s", ErrUAPISocketOpening, err)
return 0, nil, fmt.Errorf("opening UAPI socket: %w", err)
}
closers.add("closing UAPI file", stepThree, uapiFile.Close)
cleanups.Add("closing UAPI file", 3, uapiFile.Close)
uapiListener, err := uapiListen(interfaceName, uapiFile)
uapiListener, err := UAPIListen(interfaceName, uapiFile)
if err != nil {
return 0, nil, fmt.Errorf("%w: %s", ErrUAPIListen, err)
return 0, nil, fmt.Errorf("listening on UAPI socket: %w", err)
}
closers.add("closing UAPI listener", stepTwo, uapiListener.Close)
if b.preStart != nil {
err = b.preStart(device, settings)
if err != nil {
return 0, nil, err
}
}
cleanups.Add("closing UAPI listener", 2, uapiListener.Close)
// acceptAndHandle exits when uapiListener is closed
uapiAcceptErrorCh := make(chan error)
@@ -251,10 +247,10 @@ func setupUserSpaceCommon(ctx context.Context,
case err = <-uapiAcceptErrorCh:
close(uapiAcceptErrorCh)
case <-device.Wait():
err = ErrDeviceWaited
err = errDeviceWaited
}
closers.cleanup(logger)
cleanups.Cleanup(logger)
<-uapiAcceptErrorCh // wait for acceptAndHandle to exit
@@ -264,7 +260,7 @@ func setupUserSpaceCommon(ctx context.Context,
return link.Index, waitAndCleanup, nil
}
func acceptAndHandle(uapi net.Listener, device userspaceDevice,
func acceptAndHandle(uapi net.Listener, device *device.Device,
uapiAcceptErrorCh chan<- error,
) {
for { // stopped by uapiFile.Close()
+2 -5
View File
@@ -46,11 +46,8 @@ type Settings struct {
// It defaults to false if left unset.
IPv6 *bool
// Implementation is the implementation to use.
// It can be auto, kernelspace, userspace or amneziawg,
// and defaults to auto.
// It can be auto, kernelspace or userspace, and defaults to auto.
Implementation string
// AmneziaWG settings are extra obfuscation parameters
AmneziaWG AmneziaSettings
}
func (s *Settings) SetDefaults() {
@@ -181,7 +178,7 @@ func (s *Settings) Check() (err error) {
}
switch s.Implementation {
case "auto", "kernelspace", "userspace", "amneziawg":
case "auto", "kernelspace", "userspace":
default:
return fmt.Errorf("%w: %s", ErrImplementationInvalid, s.Implementation)
}
-58
View File
@@ -1,58 +0,0 @@
package wireguard
import (
amneziaconn "github.com/amnezia-vpn/amneziawg-go/conn"
amneziadevice "github.com/amnezia-vpn/amneziawg-go/device"
amneziatun "github.com/amnezia-vpn/amneziawg-go/tun"
wgconn "golang.zx2c4.com/wireguard/conn"
wgdevice "golang.zx2c4.com/wireguard/device"
wgtun "golang.zx2c4.com/wireguard/tun"
)
func defaultUserSpaceBackend() userSpaceBackend {
return userSpaceBackend{
createTun: func(name string, mtu int) (tunDevice, error) {
return wgtun.CreateTUN(name, mtu)
},
createBind: func() bind {
return wgconn.NewDefaultBind()
},
createDevice: func(td tunDevice, b bind, logger Logger) userspaceDevice {
wgtun, _ := td.(wgtun.Device)
wgBind, _ := b.(wgconn.Bind)
wgLogger := wgdevice.Logger{
Verbosef: logger.Debugf,
Errorf: logger.Errorf,
}
device := wgdevice.NewDevice(wgtun, wgBind, &wgLogger)
return device
},
preStart: nil,
}
}
func amneziaUserSpaceBackend() userSpaceBackend {
return userSpaceBackend{
createTun: func(name string, mtu int) (tunDevice, error) {
return amneziatun.CreateTUN(name, mtu)
},
createBind: func() bind {
return amneziaconn.NewDefaultBind()
},
createDevice: func(td tunDevice, b bind, logger Logger) userspaceDevice {
wgamneziaTun, _ := td.(amneziatun.Device)
wgamneziaBind, _ := b.(amneziaconn.Bind)
wgamneziaLogger := amneziadevice.Logger{
Verbosef: logger.Debugf,
Errorf: logger.Errorf,
}
device := amneziadevice.NewDevice(wgamneziaTun, wgamneziaBind, &wgamneziaLogger)
return device
},
preStart: func(ud userspaceDevice, s Settings) error {
uapiConfig := s.AmneziaWG.uapiConfig()
err := ud.IpcSet(uapiConfig)
return err
},
}
}
+2 -2
View File
@@ -7,10 +7,10 @@ import (
"golang.zx2c4.com/wireguard/ipc"
)
func uapiOpen(name string) (*os.File, error) {
func UAPIOpen(name string) (*os.File, error) {
return ipc.UAPIOpen(name)
}
func uapiListen(interfaceName string, uapiFile *os.File) (net.Listener, error) {
func UAPIListen(interfaceName string, uapiFile *os.File) (net.Listener, error) {
return ipc.UAPIListen(interfaceName, uapiFile)
}
+2 -2
View File
@@ -7,10 +7,10 @@ import (
"os"
)
func uapiOpen(name string) (*os.File, error) {
func UAPIOpen(name string) (*os.File, error) {
panic("not implemented")
}
func uapiListen(interfaceName string, uapiFile *os.File) (net.Listener, error) {
func UAPIListen(interfaceName string, uapiFile *os.File) (net.Listener, error) {
panic("not implemented")
}