diff --git a/.golangci.yml b/.golangci.yml index 2d15bd21..d24f8fed 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -60,6 +60,10 @@ linters: - linters: - lll source: "^// https://.+$" + - linters: + - mnd + source: "^ cleanups\\.Add.+$" + path: internal\/(wireguard|amneziawg)\/run\.go - linters: - err113 - mnd diff --git a/Dockerfile b/Dockerfile index 1619e998..2c35db99 100644 --- a/Dockerfile +++ b/Dockerfile @@ -112,6 +112,36 @@ ENV VPN_SERVICE_PROVIDER=pia \ WIREGUARD_ADDRESSES_SECRETFILE=/run/secrets/wireguard_addresses \ WIREGUARD_MTU= \ WIREGUARD_IMPLEMENTATION=auto \ + # Amnezia + AMNEZIAWG_ENDPOINT_IP= \ + AMNEZIAWG_ENDPOINT_PORT= \ + AMNEZIAWG_CONF_SECRETFILE=/run/secrets/wg0.conf \ + AMNEZIAWG_PRIVATE_KEY= \ + AMNEZIAWG_PRIVATE_KEY_SECRETFILE=/run/secrets/wireguard_private_key \ + AMNEZIAWG_PRESHARED_KEY= \ + AMNEZIAWG_PRESHARED_KEY_SECRETFILE=/run/secrets/wireguard_preshared_key \ + AMNEZIAWG_PUBLIC_KEY= \ + AMNEZIAWG_ALLOWED_IPS= \ + AMNEZIAWG_PERSISTENT_KEEPALIVE_INTERVAL=0 \ + AMNEZIAWG_ADDRESSES= \ + AMNEZIAWG_ADDRESSES_SECRETFILE=/run/secrets/wireguard_addresses \ + AMNEZIAWG_MTU= \ + AMNEZIAWG_JC=0 \ + AMNEZIAWG_JMIN=0 \ + AMNEZIAWG_JMAX=0 \ + AMNEZIAWG_S1=0 \ + AMNEZIAWG_S2=0 \ + AMNEZIAWG_S3=0 \ + AMNEZIAWG_S4=0 \ + AMNEZIAWG_H1= \ + AMNEZIAWG_H2= \ + AMNEZIAWG_H3= \ + AMNEZIAWG_H4= \ + AMNEZIAWG_I1= \ + AMNEZIAWG_I2= \ + AMNEZIAWG_I3= \ + AMNEZIAWG_I4= \ + AMNEZIAWG_I5= \ # Wireguard AmneziaWG userspace obfuscation (requires WIREGUARD_IMPLEMENTATION=amneziawg) AMNEZIAWG_JC=0 \ AMNEZIAWG_JMIN=0 \ diff --git a/README.md b/README.md index 5a5e1c57..922688e4 100644 --- a/README.md +++ b/README.md @@ -67,6 +67,7 @@ Lightweight swiss-army-knife-like VPN client to multiple VPN service providers - For **Cyberghost**, **Private Internet Access**, **PrivateVPN**, **PureVPN**, **Torguard**, **VPN Unlimited** and **VyprVPN** using [the custom provider](https://github.com/qdm12/gluetun-wiki/blob/main/setup/providers/custom.md) - For custom Wireguard configurations using [the custom provider](https://github.com/qdm12/gluetun-wiki/blob/main/setup/providers/custom.md) - More in progress, see [#134](https://github.com/qdm12/gluetun/issues/134) +- Supports AmneziaWG only with the custom provider for now - DNS over TLS baked in with service provider(s) of your choice - DNS fine blocking of malicious/ads/surveillance hostnames and IP addresses, with live update every 24 hours - Choose the vpn network protocol, `udp` or `tcp` diff --git a/internal/amneziawg/constructor.go b/internal/amneziawg/constructor.go new file mode 100644 index 00000000..24987bf0 --- /dev/null +++ b/internal/amneziawg/constructor.go @@ -0,0 +1,22 @@ +package amneziawg + +type Amneziawg struct { + logger Logger + settings Settings + netlink NetLinker +} + +func New(settings Settings, netlink NetLinker, + logger Logger, +) (a *Amneziawg, err error) { + settings.SetDefaults() + if err := settings.Check(); err != nil { + return nil, err + } + + return &Amneziawg{ + logger: logger, + settings: settings, + netlink: netlink, + }, nil +} diff --git a/internal/amneziawg/constructor_test.go b/internal/amneziawg/constructor_test.go new file mode 100644 index 00000000..856af3ad --- /dev/null +++ b/internal/amneziawg/constructor_test.go @@ -0,0 +1,86 @@ +package amneziawg + +import ( + "net/netip" + "testing" + + "github.com/qdm12/gluetun/internal/wireguard" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.zx2c4.com/wireguard/device" +) + +func Test_New(t *testing.T) { + t.Parallel() + + const validKeyString = "oMNSf/zJ0pt1ciy+qIRk8Rlyfs9accwuRLnKd85Yl1Q=" + logger := NewMockLogger(nil) + netLinker := NewMockNetLinker(nil) + + testCases := map[string]struct { + settings Settings + amneziawg *Amneziawg + err error + }{ + "bad_settings": { + settings: Settings{ + Wireguard: wireguard.Settings{ + PrivateKey: "", + }, + }, + err: wireguard.ErrPrivateKeyMissing, + }, + "minimal valid settings": { + settings: Settings{ + Wireguard: wireguard.Settings{ + PrivateKey: validKeyString, + PublicKey: validKeyString, + Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 0), + Addresses: []netip.Prefix{ + netip.PrefixFrom(netip.AddrFrom4([4]byte{5, 6, 7, 8}), 32), + }, + FirewallMark: 100, + }, + }, + amneziawg: &Amneziawg{ + logger: logger, + netlink: netLinker, + settings: Settings{ + Wireguard: wireguard.Settings{ + InterfaceName: "wg0", + PrivateKey: validKeyString, + PublicKey: validKeyString, + Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 51820), + Addresses: []netip.Prefix{ + netip.PrefixFrom(netip.AddrFrom4([4]byte{5, 6, 7, 8}), 32), + }, + AllowedIPs: []netip.Prefix{ + netip.MustParsePrefix("0.0.0.0/0"), + }, + FirewallMark: 100, + MTU: device.DefaultMTU, + IPv6: ptrTo(false), + Implementation: "auto", + }, + }, + }, + }, + } + + for name, testCase := range testCases { + t.Run(name, func(t *testing.T) { + t.Parallel() + + wireguard, err := New(testCase.settings, netLinker, logger) + + if testCase.err != nil { + require.Error(t, err) + assert.Equal(t, testCase.err.Error(), err.Error()) + } else { + assert.NoError(t, err) + } + + assert.Equal(t, testCase.amneziawg, wireguard) + }) + } +} diff --git a/internal/amneziawg/helpers_test.go b/internal/amneziawg/helpers_test.go new file mode 100644 index 00000000..0d5beb37 --- /dev/null +++ b/internal/amneziawg/helpers_test.go @@ -0,0 +1,5 @@ +package amneziawg + +func ptrTo[T any](v T) *T { + return &v +} diff --git a/internal/amneziawg/log.go b/internal/amneziawg/log.go new file mode 100644 index 00000000..9ace7dce --- /dev/null +++ b/internal/amneziawg/log.go @@ -0,0 +1,11 @@ +package amneziawg + +//go:generate mockgen -destination=log_mock_test.go -package amneziawg . Logger + +type Logger interface { + Debug(s string) + Debugf(format string, args ...interface{}) + Info(s string) + Error(s string) + Errorf(format string, args ...interface{}) +} diff --git a/internal/amneziawg/log_mock_test.go b/internal/amneziawg/log_mock_test.go new file mode 100644 index 00000000..1d4153f8 --- /dev/null +++ b/internal/amneziawg/log_mock_test.go @@ -0,0 +1,104 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/qdm12/gluetun/internal/amneziawg (interfaces: Logger) + +// Package amneziawg is a generated GoMock package. +package amneziawg + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockLogger is a mock of Logger interface. +type MockLogger struct { + ctrl *gomock.Controller + recorder *MockLoggerMockRecorder +} + +// MockLoggerMockRecorder is the mock recorder for MockLogger. +type MockLoggerMockRecorder struct { + mock *MockLogger +} + +// NewMockLogger creates a new mock instance. +func NewMockLogger(ctrl *gomock.Controller) *MockLogger { + mock := &MockLogger{ctrl: ctrl} + mock.recorder = &MockLoggerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockLogger) EXPECT() *MockLoggerMockRecorder { + return m.recorder +} + +// Debug mocks base method. +func (m *MockLogger) Debug(arg0 string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Debug", arg0) +} + +// Debug indicates an expected call of Debug. +func (mr *MockLoggerMockRecorder) Debug(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockLogger)(nil).Debug), arg0) +} + +// Debugf mocks base method. +func (m *MockLogger) Debugf(arg0 string, arg1 ...interface{}) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + m.ctrl.Call(m, "Debugf", varargs...) +} + +// Debugf indicates an expected call of Debugf. +func (mr *MockLoggerMockRecorder) Debugf(arg0 interface{}, arg1 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debugf", reflect.TypeOf((*MockLogger)(nil).Debugf), varargs...) +} + +// Error mocks base method. +func (m *MockLogger) Error(arg0 string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Error", arg0) +} + +// Error indicates an expected call of Error. +func (mr *MockLoggerMockRecorder) Error(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Error", reflect.TypeOf((*MockLogger)(nil).Error), arg0) +} + +// Errorf mocks base method. +func (m *MockLogger) Errorf(arg0 string, arg1 ...interface{}) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + m.ctrl.Call(m, "Errorf", varargs...) +} + +// Errorf indicates an expected call of Errorf. +func (mr *MockLoggerMockRecorder) Errorf(arg0 interface{}, arg1 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Errorf", reflect.TypeOf((*MockLogger)(nil).Errorf), varargs...) +} + +// Info mocks base method. +func (m *MockLogger) Info(arg0 string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Info", arg0) +} + +// Info indicates an expected call of Info. +func (mr *MockLoggerMockRecorder) Info(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Info", reflect.TypeOf((*MockLogger)(nil).Info), arg0) +} diff --git a/internal/amneziawg/netlinker.go b/internal/amneziawg/netlinker.go new file mode 100644 index 00000000..8be2dcbf --- /dev/null +++ b/internal/amneziawg/netlinker.go @@ -0,0 +1,36 @@ +package amneziawg + +import ( + "net/netip" + + "github.com/qdm12/gluetun/internal/netlink" +) + +//go:generate mockgen -destination=netlinker_mock_test.go -package amneziawg . NetLinker + +type NetLinker interface { + AddrReplace(linkIndex uint32, addr netip.Prefix) error + Router + Ruler + Linker + IsWireguardSupported() (ok bool, err error) +} + +type Router interface { + RouteList(family uint8) (routes []netlink.Route, err error) + RouteAdd(route netlink.Route) error +} + +type Ruler interface { + RuleAdd(rule netlink.Rule) error + RuleDel(rule netlink.Rule) error +} + +type Linker interface { + LinkAdd(link netlink.Link) (linkIndex uint32, err error) + LinkList() (links []netlink.Link, err error) + LinkByName(name string) (link netlink.Link, err error) + LinkSetUp(linkIndex uint32) error + LinkSetDown(linkIndex uint32) error + LinkDel(linkIndex uint32) error +} diff --git a/internal/amneziawg/netlinker_mock_test.go b/internal/amneziawg/netlinker_mock_test.go new file mode 100644 index 00000000..34c9f2c4 --- /dev/null +++ b/internal/amneziawg/netlinker_mock_test.go @@ -0,0 +1,209 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/qdm12/gluetun/internal/amneziawg (interfaces: NetLinker) + +// Package amneziawg is a generated GoMock package. +package amneziawg + +import ( + netip "net/netip" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + netlink "github.com/qdm12/gluetun/internal/netlink" +) + +// MockNetLinker is a mock of NetLinker interface. +type MockNetLinker struct { + ctrl *gomock.Controller + recorder *MockNetLinkerMockRecorder +} + +// MockNetLinkerMockRecorder is the mock recorder for MockNetLinker. +type MockNetLinkerMockRecorder struct { + mock *MockNetLinker +} + +// NewMockNetLinker creates a new mock instance. +func NewMockNetLinker(ctrl *gomock.Controller) *MockNetLinker { + mock := &MockNetLinker{ctrl: ctrl} + mock.recorder = &MockNetLinkerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockNetLinker) EXPECT() *MockNetLinkerMockRecorder { + return m.recorder +} + +// AddrReplace mocks base method. +func (m *MockNetLinker) AddrReplace(arg0 uint32, arg1 netip.Prefix) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddrReplace", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// AddrReplace indicates an expected call of AddrReplace. +func (mr *MockNetLinkerMockRecorder) AddrReplace(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddrReplace", reflect.TypeOf((*MockNetLinker)(nil).AddrReplace), arg0, arg1) +} + +// IsWireguardSupported mocks base method. +func (m *MockNetLinker) IsWireguardSupported() (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsWireguardSupported") + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// IsWireguardSupported indicates an expected call of IsWireguardSupported. +func (mr *MockNetLinkerMockRecorder) IsWireguardSupported() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsWireguardSupported", reflect.TypeOf((*MockNetLinker)(nil).IsWireguardSupported)) +} + +// LinkAdd mocks base method. +func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) (uint32, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LinkAdd", arg0) + ret0, _ := ret[0].(uint32) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// LinkAdd indicates an expected call of LinkAdd. +func (mr *MockNetLinkerMockRecorder) LinkAdd(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkAdd", reflect.TypeOf((*MockNetLinker)(nil).LinkAdd), arg0) +} + +// LinkByName mocks base method. +func (m *MockNetLinker) LinkByName(arg0 string) (netlink.Link, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LinkByName", arg0) + ret0, _ := ret[0].(netlink.Link) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// LinkByName indicates an expected call of LinkByName. +func (mr *MockNetLinkerMockRecorder) LinkByName(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkByName", reflect.TypeOf((*MockNetLinker)(nil).LinkByName), arg0) +} + +// LinkDel mocks base method. +func (m *MockNetLinker) LinkDel(arg0 uint32) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LinkDel", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// LinkDel indicates an expected call of LinkDel. +func (mr *MockNetLinkerMockRecorder) LinkDel(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkDel", reflect.TypeOf((*MockNetLinker)(nil).LinkDel), arg0) +} + +// LinkList mocks base method. +func (m *MockNetLinker) LinkList() ([]netlink.Link, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LinkList") + ret0, _ := ret[0].([]netlink.Link) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// LinkList indicates an expected call of LinkList. +func (mr *MockNetLinkerMockRecorder) LinkList() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkList", reflect.TypeOf((*MockNetLinker)(nil).LinkList)) +} + +// LinkSetDown mocks base method. +func (m *MockNetLinker) LinkSetDown(arg0 uint32) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LinkSetDown", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// LinkSetDown indicates an expected call of LinkSetDown. +func (mr *MockNetLinkerMockRecorder) LinkSetDown(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkSetDown", reflect.TypeOf((*MockNetLinker)(nil).LinkSetDown), arg0) +} + +// LinkSetUp mocks base method. +func (m *MockNetLinker) LinkSetUp(arg0 uint32) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LinkSetUp", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// LinkSetUp indicates an expected call of LinkSetUp. +func (mr *MockNetLinkerMockRecorder) LinkSetUp(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkSetUp", reflect.TypeOf((*MockNetLinker)(nil).LinkSetUp), arg0) +} + +// RouteAdd mocks base method. +func (m *MockNetLinker) RouteAdd(arg0 netlink.Route) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RouteAdd", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// RouteAdd indicates an expected call of RouteAdd. +func (mr *MockNetLinkerMockRecorder) RouteAdd(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RouteAdd", reflect.TypeOf((*MockNetLinker)(nil).RouteAdd), arg0) +} + +// RouteList mocks base method. +func (m *MockNetLinker) RouteList(arg0 byte) ([]netlink.Route, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RouteList", arg0) + ret0, _ := ret[0].([]netlink.Route) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// RouteList indicates an expected call of RouteList. +func (mr *MockNetLinkerMockRecorder) RouteList(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RouteList", reflect.TypeOf((*MockNetLinker)(nil).RouteList), arg0) +} + +// RuleAdd mocks base method. +func (m *MockNetLinker) RuleAdd(arg0 netlink.Rule) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RuleAdd", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// RuleAdd indicates an expected call of RuleAdd. +func (mr *MockNetLinkerMockRecorder) RuleAdd(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RuleAdd", reflect.TypeOf((*MockNetLinker)(nil).RuleAdd), arg0) +} + +// RuleDel mocks base method. +func (m *MockNetLinker) RuleDel(arg0 netlink.Rule) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RuleDel", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// RuleDel indicates an expected call of RuleDel. +func (mr *MockNetLinkerMockRecorder) RuleDel(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RuleDel", reflect.TypeOf((*MockNetLinker)(nil).RuleDel), arg0) +} diff --git a/internal/amneziawg/run.go b/internal/amneziawg/run.go new file mode 100644 index 00000000..e85e7ecc --- /dev/null +++ b/internal/amneziawg/run.go @@ -0,0 +1,133 @@ +package amneziawg + +import ( + "context" + "errors" + "fmt" + "net" + + amneziaconn "github.com/amnezia-vpn/amneziawg-go/conn" + amneziadevice "github.com/amnezia-vpn/amneziawg-go/device" + amneziatun "github.com/amnezia-vpn/amneziawg-go/tun" + "github.com/qdm12/gluetun/internal/cleanup" + "github.com/qdm12/gluetun/internal/wireguard" +) + +var ( + errTunNameMismatch = errors.New("TUN device name is mismatching") + errDeviceWaited = errors.New("device waited for") +) + +// Run runs the amneziawg interface and waits until the context is done, then it cleans up the +// interface and returns any error that occurred during setup or waiting. It sends an error to +// waitError if any error occurs during setup or waiting, otherwise it sends nil when the context +// is done. It sends a signal to ready when the setup is complete and the interface is ready to use. +// See https://github.com/amnezia-vpn/amneziawg-go/blob/master/main.go +func (a *Amneziawg) Run(ctx context.Context, waitError chan<- error, ready chan<- struct{}) { + setup := func(ctx context.Context, cleanups *cleanup.Cleanups) ( + linkIndex uint32, waitAndCleanup func() error, err error, + ) { + return setupUserspace(ctx, a.settings.Wireguard.InterfaceName, + a.netlink, a.settings.Wireguard.MTU, cleanups, a.logger, a.settings) + } + + wireguard.Run(ctx, waitError, ready, setup, a.settings.Wireguard, a.netlink, a.logger) +} + +func setupUserspace(ctx context.Context, + interfaceName string, netLinker NetLinker, mtu uint32, + cleanups *cleanup.Cleanups, logger Logger, + settings Settings, +) ( + linkIndex uint32, waitAndCleanup func() error, err error, +) { + tun, err := amneziatun.CreateTUN(interfaceName, int(mtu)) + if err != nil { + return 0, nil, fmt.Errorf("creating TUN device: %w", err) + } + + cleanups.Add("closing TUN device", 7, tun.Close) + + tunName, err := tun.Name() + if err != nil { + return 0, nil, fmt.Errorf("getting created TUN device name: %w", err) + } else if tunName != interfaceName { + return 0, nil, fmt.Errorf("%w: expected %q and got %q", + errTunNameMismatch, interfaceName, tunName) + } + + link, err := netLinker.LinkByName(interfaceName) + if err != nil { + return 0, nil, fmt.Errorf("finding link %s: %w", interfaceName, err) + } + cleanups.Add("deleting link", 5, func() error { + return netLinker.LinkDel(link.Index) + }) + + bind := amneziaconn.NewDefaultBind() + cleanups.Add("closing bind", 7, bind.Close) + + deviceLogger := amneziadevice.Logger{ + Verbosef: logger.Debugf, + Errorf: logger.Errorf, + } + device := amneziadevice.NewDevice(tun, bind, &deviceLogger) + + cleanups.Add("closing Wireguard device", 6, func() error { + device.Close() + return nil + }) + + uapiFile, err := wireguard.UAPIOpen(interfaceName) + if err != nil { + return 0, nil, fmt.Errorf("opening UAPI socket: %w", err) + } + cleanups.Add("closing UAPI file", 3, uapiFile.Close) + + uapiListener, err := wireguard.UAPIListen(interfaceName, uapiFile) + if err != nil { + return 0, nil, fmt.Errorf("listening on UAPI socket: %w", err) + } + cleanups.Add("closing UAPI listener", 2, uapiListener.Close) + + uapiConfig := settings.uapiConfig() + err = device.IpcSet(uapiConfig) + if err != nil { + return 0, nil, fmt.Errorf("setting amneziawg uapi config: %w", err) + } + + // acceptAndHandle exits when uapiListener is closed + uapiAcceptErrorCh := make(chan error) + go acceptAndHandle(uapiListener, device, uapiAcceptErrorCh) + waitAndCleanup = func() error { + select { + case <-ctx.Done(): + err = ctx.Err() + case err = <-uapiAcceptErrorCh: + close(uapiAcceptErrorCh) + case <-device.Wait(): + err = errDeviceWaited + } + + cleanups.Cleanup(logger) + + <-uapiAcceptErrorCh // wait for acceptAndHandle to exit + + return err + } + + return link.Index, waitAndCleanup, nil +} + +func acceptAndHandle(uapi net.Listener, device *amneziadevice.Device, + uapiAcceptErrorCh chan<- error, +) { + for { // stopped by uapiFile.Close() + conn, err := uapi.Accept() + if err != nil { + uapiAcceptErrorCh <- err + return + } + go device.IpcHandle(conn) + } +} diff --git a/internal/wireguard/amnezia_settings.go b/internal/amneziawg/settings.go similarity index 79% rename from internal/wireguard/amnezia_settings.go rename to internal/amneziawg/settings.go index 8b7d670a..b50f022f 100644 --- a/internal/wireguard/amnezia_settings.go +++ b/internal/amneziawg/settings.go @@ -1,11 +1,14 @@ -package wireguard +package amneziawg import ( "fmt" "strings" + + "github.com/qdm12/gluetun/internal/wireguard" ) -type AmneziaSettings struct { +type Settings struct { + Wireguard wireguard.Settings JunkPacketCount uint16 JunkPacketMin uint16 JunkPacketMax uint16 @@ -24,7 +27,7 @@ type AmneziaSettings struct { InitPacketI5 string } -func (s AmneziaSettings) uapiConfig() string { +func (s Settings) uapiConfig() string { uintFields := map[string]uint16{ "jc": s.JunkPacketCount, "jmin": s.JunkPacketMin, @@ -56,3 +59,11 @@ func (s AmneziaSettings) uapiConfig() string { } return strings.Join(lines, "\n") } + +func (s *Settings) SetDefaults() { + s.Wireguard.SetDefaults() +} + +func (s *Settings) Check() error { + return s.Wireguard.Check() +} diff --git a/internal/cleanup/cleanup.go b/internal/cleanup/cleanup.go new file mode 100644 index 00000000..a643575d --- /dev/null +++ b/internal/cleanup/cleanup.go @@ -0,0 +1,51 @@ +package cleanup + +import "sort" + +type Cleanups []cleanup + +type cleanup struct { + operation string + orderIndex uint + cleanup func() error + done bool +} + +// Add adds a cleanup function to the list of cleanups, with a description of the +// operation being cleaned up, and an order index that determines the order in which +// the cleanup functions are run. The lower the order index, the earlier the cleanup +// function is run. +func (c *Cleanups) Add(operation string, orderIndex uint, + cleanupFunc func() error, +) { + closer := cleanup{ + operation: operation, + orderIndex: orderIndex, + cleanup: cleanupFunc, + } + *c = append(*c, closer) +} + +// Cleanup runs the cleanup functions in the order of their orderIndex, +// and logs any error that occurs during cleanup. +// It can also be re-called in case a cleanup fails, and already cleaned up +// functions will not be re-run. +func (c *Cleanups) Cleanup(logger Logger) { + closers := *c + + sort.Slice(closers, func(i, j int) bool { + return closers[i].orderIndex < closers[j].orderIndex + }) + + for i, closer := range closers { + if closer.done { + continue + } + closers[i].done = true + logger.Debug(closer.operation + "...") + err := closer.cleanup() + if err != nil { + logger.Error("failed " + closer.operation + ": " + err.Error()) + } + } +} diff --git a/internal/cleanup/cleanup_test.go b/internal/cleanup/cleanup_test.go new file mode 100644 index 00000000..7871e194 --- /dev/null +++ b/internal/cleanup/cleanup_test.go @@ -0,0 +1,57 @@ +package cleanup + +import ( + "errors" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" +) + +func Test_Cleanups(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + + var ACloseCalled, BCloseCalled, CCloseCalled bool + var ( + AErr error + BErr = errors.New("B failed") + CErr = errors.New("C failed") + ) + + var cleanups Cleanups + cleanups.Add("cleaning up A", 5, func() error { + ACloseCalled = true + return AErr + }) + + cleanups.Add("cleaning up B", 3, func() error { + BCloseCalled = true + return BErr + }) + + cleanups.Add("cleaning up C", 2, func() error { + CCloseCalled = true + return CErr + }) + + logger := NewMockLogger(ctrl) + prevCall := logger.EXPECT().Debug("cleaning up C...") + prevCall = logger.EXPECT().Error("failed cleaning up C: C failed").After(prevCall) + prevCall = logger.EXPECT().Debug("cleaning up B...").After(prevCall) + prevCall = logger.EXPECT().Error("failed cleaning up B: B failed").After(prevCall) + logger.EXPECT().Debug("cleaning up A...").After(prevCall) + + cleanups.Cleanup(logger) + + cleanups.Cleanup(logger) // run twice should not close already closed + + for _, cleanup := range cleanups { + assert.True(t, cleanup.done) + } + + assert.True(t, ACloseCalled) + assert.True(t, BCloseCalled) + assert.True(t, CCloseCalled) +} diff --git a/internal/cleanup/interfaces.go b/internal/cleanup/interfaces.go new file mode 100644 index 00000000..47d0d1bb --- /dev/null +++ b/internal/cleanup/interfaces.go @@ -0,0 +1,6 @@ +package cleanup + +type Logger interface { + Debug(string) + Error(string) +} diff --git a/internal/cleanup/mocks_generate_test.go b/internal/cleanup/mocks_generate_test.go new file mode 100644 index 00000000..de57f1da --- /dev/null +++ b/internal/cleanup/mocks_generate_test.go @@ -0,0 +1,3 @@ +package cleanup + +//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . Logger diff --git a/internal/cleanup/mocks_test.go b/internal/cleanup/mocks_test.go new file mode 100644 index 00000000..2f4a7b4d --- /dev/null +++ b/internal/cleanup/mocks_test.go @@ -0,0 +1,58 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/qdm12/gluetun/internal/cleanup (interfaces: Logger) + +// Package cleanup is a generated GoMock package. +package cleanup + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockLogger is a mock of Logger interface. +type MockLogger struct { + ctrl *gomock.Controller + recorder *MockLoggerMockRecorder +} + +// MockLoggerMockRecorder is the mock recorder for MockLogger. +type MockLoggerMockRecorder struct { + mock *MockLogger +} + +// NewMockLogger creates a new mock instance. +func NewMockLogger(ctrl *gomock.Controller) *MockLogger { + mock := &MockLogger{ctrl: ctrl} + mock.recorder = &MockLoggerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockLogger) EXPECT() *MockLoggerMockRecorder { + return m.recorder +} + +// Debug mocks base method. +func (m *MockLogger) Debug(arg0 string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Debug", arg0) +} + +// Debug indicates an expected call of Debug. +func (mr *MockLoggerMockRecorder) Debug(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockLogger)(nil).Debug), arg0) +} + +// Error mocks base method. +func (m *MockLogger) Error(arg0 string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Error", arg0) +} + +// Error indicates an expected call of Error. +func (mr *MockLoggerMockRecorder) Error(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Error", reflect.TypeOf((*MockLogger)(nil).Error), arg0) +} diff --git a/internal/configuration/settings/amneziawg.go b/internal/configuration/settings/amneziawg.go index 6311d427..107d32d1 100644 --- a/internal/configuration/settings/amneziawg.go +++ b/internal/configuration/settings/amneziawg.go @@ -12,33 +12,42 @@ import ( ) type AmneziaWg struct { - JunkPacketCount *uint16 `json:"junk_packet_count"` - JunkPacketMin *uint16 `json:"junk_packet_min"` - JunkPacketMax *uint16 `json:"junk_packet_max"` - PaddingS1 *uint16 `json:"padding_s1"` - PaddingS2 *uint16 `json:"padding_s2"` - PaddingS3 *uint16 `json:"padding_s3"` - PaddingS4 *uint16 `json:"padding_s4"` - HeaderH1 *string `json:"header_h1"` - HeaderH2 *string `json:"header_h2"` - HeaderH3 *string `json:"header_h3"` - HeaderH4 *string `json:"header_h4"` - InitPacketI1 *string `json:"init_packet_i1"` - InitPacketI2 *string `json:"init_packet_i2"` - InitPacketI3 *string `json:"init_packet_i3"` - InitPacketI4 *string `json:"init_packet_i4"` - InitPacketI5 *string `json:"init_packet_i5"` + // Wireguard contains the configuration for Wireguard, given + // AmneziaWg is based on Wireguard + Wireguard Wireguard `json:"wireguard"` + JunkPacketCount *uint16 `json:"junk_packet_count"` + JunkPacketMin *uint16 `json:"junk_packet_min"` + JunkPacketMax *uint16 `json:"junk_packet_max"` + PaddingS1 *uint16 `json:"padding_s1"` + PaddingS2 *uint16 `json:"padding_s2"` + PaddingS3 *uint16 `json:"padding_s3"` + PaddingS4 *uint16 `json:"padding_s4"` + HeaderH1 *string `json:"header_h1"` + HeaderH2 *string `json:"header_h2"` + HeaderH3 *string `json:"header_h3"` + HeaderH4 *string `json:"header_h4"` + InitPacketI1 *string `json:"init_packet_i1"` + InitPacketI2 *string `json:"init_packet_i2"` + InitPacketI3 *string `json:"init_packet_i3"` + InitPacketI4 *string `json:"init_packet_i4"` + InitPacketI5 *string `json:"init_packet_i5"` } -func (s *AmneziaWg) read(r *reader.Reader) (err error) { +func (a *AmneziaWg) read(r *reader.Reader) (err error) { + const amneziawg = true + err = a.Wireguard.read(r, amneziawg) + if err != nil { + return err // do not wrap this error + } + uint16Fields := map[string]**uint16{ - "AMNEZIAWG_JC": &s.JunkPacketCount, - "AMNEZIAWG_JMIN": &s.JunkPacketMin, - "AMNEZIAWG_JMAX": &s.JunkPacketMax, - "AMNEZIAWG_S1": &s.PaddingS1, - "AMNEZIAWG_S2": &s.PaddingS2, - "AMNEZIAWG_S3": &s.PaddingS3, - "AMNEZIAWG_S4": &s.PaddingS4, + "AMNEZIAWG_JC": &a.JunkPacketCount, + "AMNEZIAWG_JMIN": &a.JunkPacketMin, + "AMNEZIAWG_JMAX": &a.JunkPacketMax, + "AMNEZIAWG_S1": &a.PaddingS1, + "AMNEZIAWG_S2": &a.PaddingS2, + "AMNEZIAWG_S3": &a.PaddingS3, + "AMNEZIAWG_S4": &a.PaddingS4, } for key, dst := range uint16Fields { *dst, err = r.Uint16Ptr(key) @@ -47,15 +56,15 @@ func (s *AmneziaWg) read(r *reader.Reader) (err error) { } } stringFields := map[string]**string{ - "AMNEZIAWG_H1": &s.HeaderH1, - "AMNEZIAWG_H2": &s.HeaderH2, - "AMNEZIAWG_H3": &s.HeaderH3, - "AMNEZIAWG_H4": &s.HeaderH4, - "AMNEZIAWG_I1": &s.InitPacketI1, - "AMNEZIAWG_I2": &s.InitPacketI2, - "AMNEZIAWG_I3": &s.InitPacketI3, - "AMNEZIAWG_I4": &s.InitPacketI4, - "AMNEZIAWG_I5": &s.InitPacketI5, + "AMNEZIAWG_H1": &a.HeaderH1, + "AMNEZIAWG_H2": &a.HeaderH2, + "AMNEZIAWG_H3": &a.HeaderH3, + "AMNEZIAWG_H4": &a.HeaderH4, + "AMNEZIAWG_I1": &a.InitPacketI1, + "AMNEZIAWG_I2": &a.InitPacketI2, + "AMNEZIAWG_I3": &a.InitPacketI3, + "AMNEZIAWG_I4": &a.InitPacketI4, + "AMNEZIAWG_I5": &a.InitPacketI5, } opt := reader.ForceLowercase(false) for key, dst := range stringFields { @@ -64,80 +73,84 @@ func (s *AmneziaWg) read(r *reader.Reader) (err error) { return nil } -func (s AmneziaWg) copy() (copied AmneziaWg) { +func (a AmneziaWg) copy() (copied AmneziaWg) { return AmneziaWg{ - JunkPacketCount: gosettings.CopyPointer(s.JunkPacketCount), - JunkPacketMin: gosettings.CopyPointer(s.JunkPacketMin), - JunkPacketMax: gosettings.CopyPointer(s.JunkPacketMax), - PaddingS1: gosettings.CopyPointer(s.PaddingS1), - PaddingS2: gosettings.CopyPointer(s.PaddingS2), - PaddingS3: gosettings.CopyPointer(s.PaddingS3), - PaddingS4: gosettings.CopyPointer(s.PaddingS4), - HeaderH1: gosettings.CopyPointer(s.HeaderH1), - HeaderH2: gosettings.CopyPointer(s.HeaderH2), - HeaderH3: gosettings.CopyPointer(s.HeaderH3), - HeaderH4: gosettings.CopyPointer(s.HeaderH4), - InitPacketI1: gosettings.CopyPointer(s.InitPacketI1), - InitPacketI2: gosettings.CopyPointer(s.InitPacketI2), - InitPacketI3: gosettings.CopyPointer(s.InitPacketI3), - InitPacketI4: gosettings.CopyPointer(s.InitPacketI4), - InitPacketI5: gosettings.CopyPointer(s.InitPacketI5), + Wireguard: a.Wireguard.copy(), + JunkPacketCount: gosettings.CopyPointer(a.JunkPacketCount), + JunkPacketMin: gosettings.CopyPointer(a.JunkPacketMin), + JunkPacketMax: gosettings.CopyPointer(a.JunkPacketMax), + PaddingS1: gosettings.CopyPointer(a.PaddingS1), + PaddingS2: gosettings.CopyPointer(a.PaddingS2), + PaddingS3: gosettings.CopyPointer(a.PaddingS3), + PaddingS4: gosettings.CopyPointer(a.PaddingS4), + HeaderH1: gosettings.CopyPointer(a.HeaderH1), + HeaderH2: gosettings.CopyPointer(a.HeaderH2), + HeaderH3: gosettings.CopyPointer(a.HeaderH3), + HeaderH4: gosettings.CopyPointer(a.HeaderH4), + InitPacketI1: gosettings.CopyPointer(a.InitPacketI1), + InitPacketI2: gosettings.CopyPointer(a.InitPacketI2), + InitPacketI3: gosettings.CopyPointer(a.InitPacketI3), + InitPacketI4: gosettings.CopyPointer(a.InitPacketI4), + InitPacketI5: gosettings.CopyPointer(a.InitPacketI5), } } -//nolint:dupl -func (s *AmneziaWg) overrideWith(other AmneziaWg) { - s.JunkPacketCount = gosettings.OverrideWithPointer(s.JunkPacketCount, other.JunkPacketCount) - s.JunkPacketMin = gosettings.OverrideWithPointer(s.JunkPacketMin, other.JunkPacketMin) - s.JunkPacketMax = gosettings.OverrideWithPointer(s.JunkPacketMax, other.JunkPacketMax) - s.PaddingS1 = gosettings.OverrideWithPointer(s.PaddingS1, other.PaddingS1) - s.PaddingS2 = gosettings.OverrideWithPointer(s.PaddingS2, other.PaddingS2) - s.PaddingS3 = gosettings.OverrideWithPointer(s.PaddingS3, other.PaddingS3) - s.PaddingS4 = gosettings.OverrideWithPointer(s.PaddingS4, other.PaddingS4) - s.HeaderH1 = gosettings.OverrideWithPointer(s.HeaderH1, other.HeaderH1) - s.HeaderH2 = gosettings.OverrideWithPointer(s.HeaderH2, other.HeaderH2) - s.HeaderH3 = gosettings.OverrideWithPointer(s.HeaderH3, other.HeaderH3) - s.HeaderH4 = gosettings.OverrideWithPointer(s.HeaderH4, other.HeaderH4) - s.InitPacketI1 = gosettings.OverrideWithPointer(s.InitPacketI1, other.InitPacketI1) - s.InitPacketI2 = gosettings.OverrideWithPointer(s.InitPacketI2, other.InitPacketI2) - s.InitPacketI3 = gosettings.OverrideWithPointer(s.InitPacketI3, other.InitPacketI3) - s.InitPacketI4 = gosettings.OverrideWithPointer(s.InitPacketI4, other.InitPacketI4) - s.InitPacketI5 = gosettings.OverrideWithPointer(s.InitPacketI5, other.InitPacketI5) +func (a *AmneziaWg) overrideWith(other AmneziaWg) { + a.Wireguard.overrideWith(other.Wireguard) + a.JunkPacketCount = gosettings.OverrideWithPointer(a.JunkPacketCount, other.JunkPacketCount) + a.JunkPacketMin = gosettings.OverrideWithPointer(a.JunkPacketMin, other.JunkPacketMin) + a.JunkPacketMax = gosettings.OverrideWithPointer(a.JunkPacketMax, other.JunkPacketMax) + a.PaddingS1 = gosettings.OverrideWithPointer(a.PaddingS1, other.PaddingS1) + a.PaddingS2 = gosettings.OverrideWithPointer(a.PaddingS2, other.PaddingS2) + a.PaddingS3 = gosettings.OverrideWithPointer(a.PaddingS3, other.PaddingS3) + a.PaddingS4 = gosettings.OverrideWithPointer(a.PaddingS4, other.PaddingS4) + a.HeaderH1 = gosettings.OverrideWithPointer(a.HeaderH1, other.HeaderH1) + a.HeaderH2 = gosettings.OverrideWithPointer(a.HeaderH2, other.HeaderH2) + a.HeaderH3 = gosettings.OverrideWithPointer(a.HeaderH3, other.HeaderH3) + a.HeaderH4 = gosettings.OverrideWithPointer(a.HeaderH4, other.HeaderH4) + a.InitPacketI1 = gosettings.OverrideWithPointer(a.InitPacketI1, other.InitPacketI1) + a.InitPacketI2 = gosettings.OverrideWithPointer(a.InitPacketI2, other.InitPacketI2) + a.InitPacketI3 = gosettings.OverrideWithPointer(a.InitPacketI3, other.InitPacketI3) + a.InitPacketI4 = gosettings.OverrideWithPointer(a.InitPacketI4, other.InitPacketI4) + a.InitPacketI5 = gosettings.OverrideWithPointer(a.InitPacketI5, other.InitPacketI5) } -func (s *AmneziaWg) setDefaults() { - s.JunkPacketCount = gosettings.DefaultPointer(s.JunkPacketCount, 0) - s.JunkPacketMin = gosettings.DefaultPointer(s.JunkPacketMin, 0) - s.JunkPacketMax = gosettings.DefaultPointer(s.JunkPacketMax, 0) - s.PaddingS1 = gosettings.DefaultPointer(s.PaddingS1, 0) - s.PaddingS2 = gosettings.DefaultPointer(s.PaddingS2, 0) - s.PaddingS3 = gosettings.DefaultPointer(s.PaddingS3, 0) - s.PaddingS4 = gosettings.DefaultPointer(s.PaddingS4, 0) - s.HeaderH1 = gosettings.DefaultPointer(s.HeaderH1, "") - s.HeaderH2 = gosettings.DefaultPointer(s.HeaderH2, "") - s.HeaderH3 = gosettings.DefaultPointer(s.HeaderH3, "") - s.HeaderH4 = gosettings.DefaultPointer(s.HeaderH4, "") - s.InitPacketI1 = gosettings.DefaultPointer(s.InitPacketI1, "") - s.InitPacketI2 = gosettings.DefaultPointer(s.InitPacketI2, "") - s.InitPacketI3 = gosettings.DefaultPointer(s.InitPacketI3, "") - s.InitPacketI4 = gosettings.DefaultPointer(s.InitPacketI4, "") - s.InitPacketI5 = gosettings.DefaultPointer(s.InitPacketI5, "") +func (a *AmneziaWg) setDefaults(vpnProvider string) { + a.Wireguard.setDefaults(vpnProvider) + a.Wireguard.Implementation = "userspace" // unused except in logs + a.JunkPacketCount = gosettings.DefaultPointer(a.JunkPacketCount, 0) + a.JunkPacketMin = gosettings.DefaultPointer(a.JunkPacketMin, 0) + a.JunkPacketMax = gosettings.DefaultPointer(a.JunkPacketMax, 0) + a.PaddingS1 = gosettings.DefaultPointer(a.PaddingS1, 0) + a.PaddingS2 = gosettings.DefaultPointer(a.PaddingS2, 0) + a.PaddingS3 = gosettings.DefaultPointer(a.PaddingS3, 0) + a.PaddingS4 = gosettings.DefaultPointer(a.PaddingS4, 0) + a.HeaderH1 = gosettings.DefaultPointer(a.HeaderH1, "") + a.HeaderH2 = gosettings.DefaultPointer(a.HeaderH2, "") + a.HeaderH3 = gosettings.DefaultPointer(a.HeaderH3, "") + a.HeaderH4 = gosettings.DefaultPointer(a.HeaderH4, "") + a.InitPacketI1 = gosettings.DefaultPointer(a.InitPacketI1, "") + a.InitPacketI2 = gosettings.DefaultPointer(a.InitPacketI2, "") + a.InitPacketI3 = gosettings.DefaultPointer(a.InitPacketI3, "") + a.InitPacketI4 = gosettings.DefaultPointer(a.InitPacketI4, "") + a.InitPacketI5 = gosettings.DefaultPointer(a.InitPacketI5, "") } -func (s AmneziaWg) toLinesNode() (node *gotree.Node) { - node = gotree.New("Amneziawg parameters:") +func (a AmneziaWg) toLinesNode() (node *gotree.Node) { + node = gotree.New("AmneziaWG settings:") + node.AppendNode(a.Wireguard.toLinesNode()) uintFields := []struct { key string val *uint16 }{ - {"jc", s.JunkPacketCount}, - {"jmin", s.JunkPacketMin}, - {"jmax", s.JunkPacketMax}, - {"s1", s.PaddingS1}, - {"s2", s.PaddingS2}, - {"s3", s.PaddingS3}, - {"s4", s.PaddingS4}, + {"JC", a.JunkPacketCount}, + {"JMIN", a.JunkPacketMin}, + {"JMAX", a.JunkPacketMax}, + {"S1", a.PaddingS1}, + {"S2", a.PaddingS2}, + {"S3", a.PaddingS3}, + {"S4", a.PaddingS4}, } for _, f := range uintFields { node.Appendf("%s: %d", f.key, *f.val) @@ -147,15 +160,15 @@ func (s AmneziaWg) toLinesNode() (node *gotree.Node) { key string val *string }{ - {"h1", s.HeaderH1}, - {"h2", s.HeaderH2}, - {"h3", s.HeaderH3}, - {"h4", s.HeaderH4}, - {"i1", s.InitPacketI1}, - {"i2", s.InitPacketI2}, - {"i3", s.InitPacketI3}, - {"i4", s.InitPacketI4}, - {"i5", s.InitPacketI5}, + {"H1", a.HeaderH1}, + {"H2", a.HeaderH2}, + {"H3", a.HeaderH3}, + {"H4", a.HeaderH4}, + {"I1", a.InitPacketI1}, + {"I2", a.InitPacketI2}, + {"I3", a.InitPacketI3}, + {"I4", a.InitPacketI4}, + {"I5", a.InitPacketI5}, } for _, f := range stringFields { node.Appendf("%s: %s", f.key, *f.val) @@ -165,33 +178,40 @@ func (s AmneziaWg) toLinesNode() (node *gotree.Node) { } var ( - ErrJunkPacketBounds = errors.New("junk packet minimum must be lower than or equal to maximum") - ErrJunkPacketMinMaxNotSet = errors.New("junk packet min and max must be set when junk packet count is set") - ErrJunkPacketCountNotSet = errors.New("junk packet count must be set when junk packet min or max is set") - ErrHeaderRangeMalformed = errors.New("header range is malformed") + ErrAmenziawgImplementationNotValid = errors.New("AmneziaWG implementation is not valid") + ErrJunkPacketBounds = errors.New("junk packet minimum must be lower than or equal to maximum") + ErrJunkPacketMinMaxNotSet = errors.New("junk packet min and max must be set when junk packet count is set") + ErrJunkPacketCountNotSet = errors.New("junk packet count must be set when junk packet min or max is set") + ErrHeaderRangeMalformed = errors.New("header range is malformed") ) -func (s AmneziaWg) validate() error { - if *s.JunkPacketCount == 0 { - if *s.JunkPacketMin != 0 || *s.JunkPacketMax != 0 { +func (a AmneziaWg) validate(vpnProvider string, ipv6Supported bool) error { + const amneziaWG = true + err := a.Wireguard.validate(vpnProvider, ipv6Supported, amneziaWG) + if err != nil { + return fmt.Errorf("wireguard settings: %w", err) + } + + if *a.JunkPacketCount == 0 { + if *a.JunkPacketMin != 0 || *a.JunkPacketMax != 0 { return fmt.Errorf("%w: jc=%d and jmin=%d and jmax=%d", - ErrJunkPacketCountNotSet, s.JunkPacketCount, *s.JunkPacketMin, *s.JunkPacketMax) + ErrJunkPacketCountNotSet, a.JunkPacketCount, *a.JunkPacketMin, *a.JunkPacketMax) } } else { - if *s.JunkPacketMin == 0 || *s.JunkPacketMax == 0 { + if *a.JunkPacketMin == 0 || *a.JunkPacketMax == 0 { return fmt.Errorf("%w: jc=%d and jmin=%d and jmax=%d", - ErrJunkPacketMinMaxNotSet, s.JunkPacketCount, *s.JunkPacketMin, *s.JunkPacketMax) - } else if *s.JunkPacketMin > *s.JunkPacketMax { + ErrJunkPacketMinMaxNotSet, a.JunkPacketCount, *a.JunkPacketMin, *a.JunkPacketMax) + } else if *a.JunkPacketMin > *a.JunkPacketMax { return fmt.Errorf("%w: jmin=%d and jmax=%d", - ErrJunkPacketBounds, *s.JunkPacketMin, *s.JunkPacketMax) + ErrJunkPacketBounds, *a.JunkPacketMin, *a.JunkPacketMax) } } nameToHeaderRange := map[string]string{ - "h1": *s.HeaderH1, - "h2": *s.HeaderH2, - "h3": *s.HeaderH3, - "h4": *s.HeaderH4, + "h1": *a.HeaderH1, + "h2": *a.HeaderH2, + "h3": *a.HeaderH3, + "h4": *a.HeaderH4, } for name, headerRange := range nameToHeaderRange { if headerRange == "" { diff --git a/internal/configuration/settings/openvpn.go b/internal/configuration/settings/openvpn.go index b826bd49..7cffd4d3 100644 --- a/internal/configuration/settings/openvpn.go +++ b/internal/configuration/settings/openvpn.go @@ -268,8 +268,6 @@ func (o *OpenVPN) copy() (copied OpenVPN) { // overrideWith overrides fields of the receiver // settings object with any field set in the other // settings. -// -//nolint:dupl func (o *OpenVPN) overrideWith(other OpenVPN) { o.Version = gosettings.OverrideWithComparable(o.Version, other.Version) o.User = gosettings.OverrideWithPointer(o.User, other.User) diff --git a/internal/configuration/settings/provider.go b/internal/configuration/settings/provider.go index 4e7f6505..e4a18e77 100644 --- a/internal/configuration/settings/provider.go +++ b/internal/configuration/settings/provider.go @@ -30,7 +30,10 @@ type Provider struct { func (p *Provider) validate(vpnType string, filterChoicesGetter FilterChoicesGetter, warner Warner) (err error) { // Validate Name var validNames []string - if vpnType == vpn.OpenVPN { + switch vpnType { + case vpn.AmneziaWg: + validNames = []string{providers.Custom} + case vpn.OpenVPN: validNames = providers.AllWithCustom() validNames = append(validNames, "pia") // Retro-compatibility // Remove Mullvad since it no longer supports OpenVPN as of January 15th, 2026 @@ -38,7 +41,7 @@ func (p *Provider) validate(vpnType string, filterChoicesGetter FilterChoicesGet validNames[mullvadIndex], validNames[len(validNames)-1] = validNames[len(validNames)-1], validNames[mullvadIndex] validNames = validNames[:len(validNames)-1] sort.Strings(validNames) - } else { // Wireguard + case vpn.Wireguard: validNames = []string{ providers.Airvpn, providers.Custom, @@ -52,7 +55,7 @@ func (p *Provider) validate(vpnType string, filterChoicesGetter FilterChoicesGet } } if err = validate.IsOneOf(p.Name, validNames...); err != nil { - return fmt.Errorf("%w for Wireguard: %w", ErrVPNProviderNameNotValid, err) + return fmt.Errorf("%w for %s: %w", ErrVPNProviderNameNotValid, vpnType, err) } err = p.ServerSelection.validate(p.Name, filterChoicesGetter, warner) diff --git a/internal/configuration/settings/serverselection.go b/internal/configuration/settings/serverselection.go index c1056745..af334932 100644 --- a/internal/configuration/settings/serverselection.go +++ b/internal/configuration/settings/serverselection.go @@ -87,7 +87,7 @@ func (ss *ServerSelection) validate(vpnServiceProvider string, filterChoicesGetter FilterChoicesGetter, warner Warner, ) (err error) { switch ss.VPN { - case vpn.OpenVPN, vpn.Wireguard: + case vpn.AmneziaWg, vpn.OpenVPN, vpn.Wireguard: default: return fmt.Errorf("%w: %s", ErrVPNTypeNotValid, ss.VPN) } diff --git a/internal/configuration/settings/vpn.go b/internal/configuration/settings/vpn.go index 25336cfd..f0c894b6 100644 --- a/internal/configuration/settings/vpn.go +++ b/internal/configuration/settings/vpn.go @@ -16,6 +16,7 @@ type VPN struct { // empty string in the internal state. Type string `json:"type"` Provider Provider `json:"provider"` + AmneziaWg AmneziaWg `json:"amneziawg"` OpenVPN OpenVPN `json:"openvpn"` Wireguard Wireguard `json:"wireguard"` PMTUD PMTUD `json:"pmtud"` @@ -29,10 +30,12 @@ type VPN struct { DownCommand *string `json:"down_command"` } +// Validate validates VPN settings, using the filter choices getter (aka servers data storage), +// and if IPv6 is supported or not. // TODO v4 remove pointer for receiver (because of Surfshark). func (v *VPN) Validate(filterChoicesGetter FilterChoicesGetter, ipv6Supported bool, warner Warner) (err error) { // Validate Type - validVPNTypes := []string{vpn.OpenVPN, vpn.Wireguard} + validVPNTypes := []string{vpn.AmneziaWg, vpn.OpenVPN, vpn.Wireguard} if err = validate.IsOneOf(v.Type, validVPNTypes...); err != nil { return fmt.Errorf("%w: %w", ErrVPNTypeNotValid, err) } @@ -42,13 +45,20 @@ func (v *VPN) Validate(filterChoicesGetter FilterChoicesGetter, ipv6Supported bo return fmt.Errorf("provider settings: %w", err) } - if v.Type == vpn.OpenVPN { + switch v.Type { + case vpn.AmneziaWg: + err = v.AmneziaWg.validate(v.Provider.Name, ipv6Supported) + if err != nil { + return fmt.Errorf("AmneziaWG settings: %w", err) + } + case vpn.OpenVPN: err := v.OpenVPN.validate(v.Provider.Name) if err != nil { return fmt.Errorf("OpenVPN settings: %w", err) } - } else { - err := v.Wireguard.validate(v.Provider.Name, ipv6Supported) + case vpn.Wireguard: + const amneziawg = false + err := v.Wireguard.validate(v.Provider.Name, ipv6Supported, amneziawg) if err != nil { return fmt.Errorf("Wireguard settings: %w", err) } @@ -66,6 +76,7 @@ func (v *VPN) Copy() (copied VPN) { return VPN{ Type: v.Type, Provider: v.Provider.copy(), + AmneziaWg: v.AmneziaWg.copy(), OpenVPN: v.OpenVPN.copy(), Wireguard: v.Wireguard.copy(), PMTUD: v.PMTUD.copy(), @@ -77,6 +88,7 @@ func (v *VPN) Copy() (copied VPN) { func (v *VPN) OverrideWith(other VPN) { v.Type = gosettings.OverrideWithComparable(v.Type, other.Type) v.Provider.overrideWith(other.Provider) + v.AmneziaWg.overrideWith(other.AmneziaWg) v.OpenVPN.overrideWith(other.OpenVPN) v.Wireguard.overrideWith(other.Wireguard) v.PMTUD.overrideWith(other.PMTUD) @@ -87,6 +99,7 @@ func (v *VPN) OverrideWith(other VPN) { func (v *VPN) setDefaults() { v.Type = gosettings.DefaultComparable(v.Type, vpn.OpenVPN) v.Provider.setDefaults() + v.AmneziaWg.setDefaults(v.Provider.Name) v.OpenVPN.setDefaults(v.Provider.Name) v.Wireguard.setDefaults(v.Provider.Name) v.PMTUD.setDefaults() @@ -103,9 +116,12 @@ func (v VPN) toLinesNode() (node *gotree.Node) { node.AppendNode(v.Provider.toLinesNode()) - if v.Type == vpn.OpenVPN { + switch v.Type { + case vpn.AmneziaWg: + node.AppendNode(v.AmneziaWg.toLinesNode()) + case vpn.OpenVPN: node.AppendNode(v.OpenVPN.toLinesNode()) - } else { + case vpn.Wireguard: node.AppendNode(v.Wireguard.toLinesNode()) } node.AppendNode(v.PMTUD.toLinesNode()) @@ -128,12 +144,18 @@ func (v *VPN) read(r *reader.Reader) (err error) { return fmt.Errorf("VPN provider: %w", err) } + err = v.AmneziaWg.read(r) + if err != nil { + return fmt.Errorf("AmneziaWG: %w", err) + } + err = v.OpenVPN.read(r) if err != nil { return fmt.Errorf("OpenVPN: %w", err) } - err = v.Wireguard.read(r) + const amneziawg = false + err = v.Wireguard.read(r, amneziawg) if err != nil { return fmt.Errorf("wireguard: %w", err) } diff --git a/internal/configuration/settings/wireguard.go b/internal/configuration/settings/wireguard.go index e1940f73..f3ecf3cb 100644 --- a/internal/configuration/settings/wireguard.go +++ b/internal/configuration/settings/wireguard.go @@ -7,7 +7,6 @@ import ( "strings" "time" - "github.com/qdm12/gluetun/internal/configuration/settings/helpers" "github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gosettings" "github.com/qdm12/gosettings/reader" @@ -42,34 +41,17 @@ type Wireguard struct { // 0 indicating to use PMTUD. MTU *uint32 `json:"mtu"` // Implementation is the Wireguard implementation to use. - // It can be "auto", "userspace", "kernelspace" or "amneziawg". + // It can be "auto", "userspace" or "kernelspace". // It defaults to "auto" and cannot be the empty string // in the internal state. Implementation string `json:"implementation"` - // AmneziaWG contains obfuscation parameters - AmneziaWG AmneziaWg `json:"amneziawg"` } var regexpInterfaceName = regexp.MustCompile(`^[a-zA-Z0-9_]+$`) // Validate validates Wireguard settings. -// It should only be ran if the VPN type chosen is Wireguard. -func (w Wireguard) validate(vpnProvider string, ipv6Supported bool) (err error) { - if !helpers.IsOneOf(vpnProvider, - providers.Airvpn, - providers.Custom, - providers.Fastestvpn, - providers.Ivpn, - providers.Mullvad, - providers.Nordvpn, - providers.Protonvpn, - providers.Surfshark, - providers.Windscribe, - ) { - // do not validate for VPN provider not supporting Wireguard - return nil - } - +// It should only be ran if the VPN type chosen is Wireguard or AmneziaWg. +func (w Wireguard) validate(vpnProvider string, ipv6Supported, amneziawg bool) (err error) { // Validate PrivateKey if *w.PrivateKey == "" { return fmt.Errorf("%w", ErrWireguardPrivateKeyNotSet) @@ -138,14 +120,11 @@ func (w Wireguard) validate(vpnProvider string, ipv6Supported bool) (err error) ErrWireguardInterfaceNotValid, w.Interface, regexpInterfaceName) } - validImplementations := []string{"auto", "userspace", "kernelspace", "amneziawg"} - if err := validate.IsOneOf(w.Implementation, validImplementations...); err != nil { - return fmt.Errorf("%w: %w", ErrWireguardImplementationNotValid, err) - } - - err = w.AmneziaWG.validate() - if err != nil { - return fmt.Errorf("amneziawg settings: %w", err) + if !amneziawg { // amneziawg should have its own Implementation field and ignore this one + validImplementations := []string{"auto", "userspace", "kernelspace"} + if err := validate.IsOneOf(w.Implementation, validImplementations...); err != nil { + return fmt.Errorf("%w: %w", ErrWireguardImplementationNotValid, err) + } } return nil @@ -161,7 +140,6 @@ func (w *Wireguard) copy() (copied Wireguard) { Interface: w.Interface, MTU: w.MTU, Implementation: w.Implementation, - AmneziaWG: w.AmneziaWG.copy(), } } @@ -175,7 +153,6 @@ func (w *Wireguard) overrideWith(other Wireguard) { w.Interface = gosettings.OverrideWithComparable(w.Interface, other.Interface) w.MTU = gosettings.OverrideWithComparable(w.MTU, other.MTU) w.Implementation = gosettings.OverrideWithComparable(w.Implementation, other.Implementation) - w.AmneziaWG.overrideWith(other.AmneziaWG) } func (w *Wireguard) setDefaults(vpnProvider string) { @@ -200,7 +177,6 @@ func (w *Wireguard) setDefaults(vpnProvider string) { w.Interface = gosettings.DefaultComparable(w.Interface, "wg0") w.MTU = gosettings.DefaultPointer(w.MTU, 0) w.Implementation = gosettings.DefaultComparable(w.Implementation, "auto") - w.AmneziaWG.setDefaults() } func (w Wireguard) String() string { @@ -242,29 +218,27 @@ func (w Wireguard) toLinesNode() (node *gotree.Node) { } if w.Implementation != "auto" { - implNode := node.Appendf("Implementation: %s", w.Implementation) - - if w.Implementation == "amneziawg" { - implNode.AppendNode(w.AmneziaWG.toLinesNode()) - } + node.Appendf("Implementation: %s", w.Implementation) } return node } -func (w *Wireguard) read(r *reader.Reader) (err error) { - w.PrivateKey = r.Get("WIREGUARD_PRIVATE_KEY", reader.ForceLowercase(false)) - w.PreSharedKey = r.Get("WIREGUARD_PRESHARED_KEY", reader.ForceLowercase(false)) +func (w *Wireguard) read(r *reader.Reader, amneziaWG bool) (err error) { + prefix := "WIREGUARD" + if amneziaWG { + prefix = "AMNEZIAWG" + } + w.PrivateKey = r.Get(prefix+"_PRIVATE_KEY", reader.ForceLowercase(false)) + w.PreSharedKey = r.Get(prefix+"_PRESHARED_KEY", reader.ForceLowercase(false)) w.Interface = r.String("VPN_INTERFACE", - reader.RetroKeys("WIREGUARD_INTERFACE"), reader.ForceLowercase(false)) - w.Implementation = r.String("WIREGUARD_IMPLEMENTATION") + reader.RetroKeys(prefix+"_INTERFACE"), reader.ForceLowercase(false)) - err = w.AmneziaWG.read(r) - if err != nil { - return err + if !amneziaWG { + w.Implementation = r.String("WIREGUARD_IMPLEMENTATION") } - addressStrings := r.CSV("WIREGUARD_ADDRESSES", reader.RetroKeys("WIREGUARD_ADDRESS")) + addressStrings := r.CSV(prefix+"_ADDRESSES", reader.RetroKeys(prefix+"_ADDRESS")) // WARNING: do not initialize w.Addresses to an empty slice // or the defaults for nordvpn will not work. for _, addressString := range addressStrings { @@ -279,17 +253,17 @@ func (w *Wireguard) read(r *reader.Reader) (err error) { w.Addresses = append(w.Addresses, address) } - w.AllowedIPs, err = r.CSVNetipPrefixes("WIREGUARD_ALLOWED_IPS") + w.AllowedIPs, err = r.CSVNetipPrefixes(prefix + "_ALLOWED_IPS") if err != nil { return err // already wrapped } - w.PersistentKeepaliveInterval, err = r.DurationPtr("WIREGUARD_PERSISTENT_KEEPALIVE_INTERVAL") + w.PersistentKeepaliveInterval, err = r.DurationPtr(prefix + "_PERSISTENT_KEEPALIVE_INTERVAL") if err != nil { return err } - w.MTU, err = r.Uint32Ptr("WIREGUARD_MTU") + w.MTU, err = r.Uint32Ptr(prefix + "_MTU") if err != nil { return err } diff --git a/internal/configuration/sources/files/amneziawg.go b/internal/configuration/sources/files/amneziawg.go new file mode 100644 index 00000000..d34a12d9 --- /dev/null +++ b/internal/configuration/sources/files/amneziawg.go @@ -0,0 +1,84 @@ +package files + +import ( + "errors" + "fmt" + "os" + "path/filepath" + + "gopkg.in/ini.v1" +) + +func (s *Source) lazyLoadAmneziawgConf() AmneziawgConfig { + if s.cached.amneziawgLoaded { + return s.cached.amneziawgConf + } + + s.cached.amneziawgLoaded = true + var err error + s.cached.amneziawgConf, err = ParseAmneziawgConf(filepath.Join(s.rootDirectory, "amneziawg", "awg0.conf")) + if err != nil { + s.warner.Warnf("skipping Amneziawg config: %s", err) + } + return s.cached.amneziawgConf +} + +type AmneziawgConfig struct { + Wireguard WireguardConfig + Jc *string + Jmin *string + Jmax *string + S1 *string + S2 *string + S3 *string + S4 *string + H1 *string + H2 *string + H3 *string + H4 *string + I1 *string + I2 *string + I3 *string + I4 *string + I5 *string +} + +func ParseAmneziawgConf(path string) (config AmneziawgConfig, err error) { + iniFile, err := ini.InsensitiveLoad(path) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return AmneziawgConfig{}, nil + } + return AmneziawgConfig{}, fmt.Errorf("loading ini from reader: %w", err) + } + + config.Wireguard, err = ParseWireguardConf(path) + if err != nil { + return AmneziawgConfig{}, err + } + + interfaceSection, err := iniFile.GetSection("Interface") + if err != nil { + // can never happen + return AmneziawgConfig{}, fmt.Errorf("getting interface section: %w", err) + } + + config.Jc = getINIKeyFromSection(interfaceSection, "Jc") + config.Jmin = getINIKeyFromSection(interfaceSection, "Jmin") + config.Jmax = getINIKeyFromSection(interfaceSection, "Jmax") + config.S1 = getINIKeyFromSection(interfaceSection, "S1") + config.S2 = getINIKeyFromSection(interfaceSection, "S2") + config.S3 = getINIKeyFromSection(interfaceSection, "S3") + config.S4 = getINIKeyFromSection(interfaceSection, "S4") + config.H1 = getINIKeyFromSection(interfaceSection, "H1") + config.H2 = getINIKeyFromSection(interfaceSection, "H2") + config.H3 = getINIKeyFromSection(interfaceSection, "H3") + config.H4 = getINIKeyFromSection(interfaceSection, "H4") + config.I1 = getINIKeyFromSection(interfaceSection, "I1") + config.I2 = getINIKeyFromSection(interfaceSection, "I2") + config.I3 = getINIKeyFromSection(interfaceSection, "I3") + config.I4 = getINIKeyFromSection(interfaceSection, "I4") + config.I5 = getINIKeyFromSection(interfaceSection, "I5") + + return config, nil +} diff --git a/internal/configuration/sources/files/amneziawg_test.go b/internal/configuration/sources/files/amneziawg_test.go new file mode 100644 index 00000000..a1c43646 --- /dev/null +++ b/internal/configuration/sources/files/amneziawg_test.go @@ -0,0 +1,82 @@ +package files + +import ( + "io/fs" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_Source_ParseAmneziawgConf(t *testing.T) { + t.Parallel() + + t.Run("no_file", func(t *testing.T) { + t.Parallel() + + noFile := filepath.Join(t.TempDir(), "doesnotexist") + wireguard, err := ParseAmneziawgConf(noFile) + assert.Equal(t, AmneziawgConfig{}, wireguard) + assert.NoError(t, err) + }) + + testCases := map[string]struct { + fileContent string + amneziawg AmneziawgConfig + errMessage string + }{ + "ini_load_error": { + fileContent: "invalid", + errMessage: "loading ini from reader: key-value delimiter not found: invalid", + }, + "empty_file": { + errMessage: `getting interface section: section "interface" does not exist`, + }, + "success": { + fileContent: ` +[Interface] +PrivateKey = QOlCgyA/Sn/c/+YNTIEohrjm8IZV+OZ2AUFIoX20sk8= +Address = 10.38.22.35/32 +DNS = 193.138.218.74 +Jc = 4 +H1 = 721391205 +I1 = + +[Peer] +PresharedKey = YJ680VN+dGrdsWNjSFqZ6vvwuiNhbq502ZL3G7Q3o3g= +`, + amneziawg: AmneziawgConfig{ + Wireguard: WireguardConfig{ + PrivateKey: ptrTo("QOlCgyA/Sn/c/+YNTIEohrjm8IZV+OZ2AUFIoX20sk8="), + PreSharedKey: ptrTo("YJ680VN+dGrdsWNjSFqZ6vvwuiNhbq502ZL3G7Q3o3g="), + Addresses: ptrTo("10.38.22.35/32"), + }, + Jc: ptrTo("4"), + H1: ptrTo("721391205"), + I1: ptrTo(""), + }, + }, + } + + for testName, testCase := range testCases { + t.Run(testName, func(t *testing.T) { + t.Parallel() + + configFile := filepath.Join(t.TempDir(), "awg.conf") + const permission = fs.FileMode(0o600) + err := os.WriteFile(configFile, []byte(testCase.fileContent), permission) + require.NoError(t, err) + + wireguard, err := ParseAmneziawgConf(configFile) + + assert.Equal(t, testCase.amneziawg, wireguard) + if testCase.errMessage != "" { + assert.EqualError(t, err, testCase.errMessage) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/internal/configuration/sources/files/reader.go b/internal/configuration/sources/files/reader.go index 3ca658c7..3f4575da 100644 --- a/internal/configuration/sources/files/reader.go +++ b/internal/configuration/sources/files/reader.go @@ -13,6 +13,8 @@ type Source struct { cached struct { wireguardLoaded bool wireguardConf WireguardConfig + amneziawgLoaded bool + amneziawgConf AmneziawgConfig } } @@ -69,38 +71,11 @@ func (s *Source) Get(key string) (value string, isSet bool) { return strPtrToStringIsSet(s.lazyLoadWireguardConf().EndpointIP) case "wireguard_endpoint_port": return strPtrToStringIsSet(s.lazyLoadWireguardConf().EndpointPort) - case "wireguard_jc": - return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.Jc) - case "wireguard_jmin": - return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.Jmin) - case "wireguard_jmax": - return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.Jmax) - case "wireguard_s1": - return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.S1) - case "wireguard_s2": - return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.S2) - case "wireguard_s3": - return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.S3) - case "wireguard_s4": - return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.S4) - case "wireguard_h1": - return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.H1) - case "wireguard_h2": - return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.H2) - case "wireguard_h3": - return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.H3) - case "wireguard_h4": - return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.H4) - case "wireguard_i1": - return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.I1) - case "wireguard_i2": - return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.I2) - case "wireguard_i3": - return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.I3) - case "wireguard_i4": - return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.I4) - case "wireguard_i5": - return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.I5) + } + + value, isSet, matched := s.getAmneziawgKey(key) + if matched { + return value, isSet } value, isSet, err := ReadFromFile(path) @@ -110,6 +85,58 @@ func (s *Source) Get(key string) (value string, isSet bool) { return value, isSet } +func (s *Source) getAmneziawgKey(key string) (value string, isSet, matched bool) { + switch key { + case "amnezia_private_key": + value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Wireguard.PrivateKey) + case "amnezia_preshared_key": + value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Wireguard.PreSharedKey) + case "amnezia_addresses": + value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Wireguard.Addresses) + case "amnezia_public_key": + value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Wireguard.PublicKey) + case "amnezia_endpoint_ip": + value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Wireguard.EndpointIP) + case "amnezia_endpoint_port": + value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Wireguard.EndpointPort) + case "amnezia_jc": + value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Jc) + case "amnezia_jmin": + value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Jmin) + case "amnezia_jmax": + value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Jmax) + case "amnezia_s1": + value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().S1) + case "amnezia_s2": + value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().S2) + case "amnezia_s3": + value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().S3) + case "amnezia_s4": + value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().S4) + case "amnezia_h1": + value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().H1) + case "amnezia_h2": + value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().H2) + case "amnezia_h3": + value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().H3) + case "amnezia_h4": + value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().H4) + case "amnezia_i1": + value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().I1) + case "amnezia_i2": + value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().I2) + case "amnezia_i3": + value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().I3) + case "amnezia_i4": + value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().I4) + case "amnezia_i5": + value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().I5) + default: + return "", false, false + } + return value, isSet, true +} + func (s *Source) KeyTransform(key string) string { switch key { // TODO v4 remove these irregular cases diff --git a/internal/configuration/sources/files/wireguard.go b/internal/configuration/sources/files/wireguard.go index d65310bb..0b86f662 100644 --- a/internal/configuration/sources/files/wireguard.go +++ b/internal/configuration/sources/files/wireguard.go @@ -25,54 +25,13 @@ func (s *Source) lazyLoadWireguardConf() WireguardConfig { return s.cached.wireguardConf } -type amneziaWgConfig struct { - Jc *string - Jmin *string - Jmax *string - S1 *string - S2 *string - S3 *string - S4 *string - H1 *string - H2 *string - H3 *string - H4 *string - I1 *string - I2 *string - I3 *string - I4 *string - I5 *string -} - -func parseWireguardAmneziaInterfaceSection(interfaceSection *ini.Section) amneziaWgConfig { - return amneziaWgConfig{ - Jc: getINIKeyFromSection(interfaceSection, "Jc"), - Jmin: getINIKeyFromSection(interfaceSection, "Jmin"), - Jmax: getINIKeyFromSection(interfaceSection, "Jmax"), - S1: getINIKeyFromSection(interfaceSection, "S1"), - S2: getINIKeyFromSection(interfaceSection, "S2"), - S3: getINIKeyFromSection(interfaceSection, "S3"), - S4: getINIKeyFromSection(interfaceSection, "S4"), - H1: getINIKeyFromSection(interfaceSection, "H1"), - H2: getINIKeyFromSection(interfaceSection, "H2"), - H3: getINIKeyFromSection(interfaceSection, "H3"), - H4: getINIKeyFromSection(interfaceSection, "H4"), - I1: getINIKeyFromSection(interfaceSection, "I1"), - I2: getINIKeyFromSection(interfaceSection, "I2"), - I3: getINIKeyFromSection(interfaceSection, "I3"), - I4: getINIKeyFromSection(interfaceSection, "I4"), - I5: getINIKeyFromSection(interfaceSection, "I5"), - } -} - type WireguardConfig struct { - PrivateKey *string - PreSharedKey *string - Addresses *string - PublicKey *string - EndpointIP *string - EndpointPort *string - AmneziaParams amneziaWgConfig + PrivateKey *string + PreSharedKey *string + Addresses *string + PublicKey *string + EndpointIP *string + EndpointPort *string } var regexINISectionNotExist = regexp.MustCompile(`^section ".+" does not exist$`) @@ -89,7 +48,6 @@ func ParseWireguardConf(path string) (config WireguardConfig, err error) { interfaceSection, err := iniFile.GetSection("Interface") if err == nil { config.PrivateKey, config.Addresses = parseWireguardInterfaceSection(interfaceSection) - config.AmneziaParams = parseWireguardAmneziaInterfaceSection(interfaceSection) } else if !regexINISectionNotExist.MatchString(err.Error()) { // can never happen return WireguardConfig{}, fmt.Errorf("getting interface section: %w", err) diff --git a/internal/configuration/sources/files/wireguard_test.go b/internal/configuration/sources/files/wireguard_test.go index 2fb78de1..53d9c838 100644 --- a/internal/configuration/sources/files/wireguard_test.go +++ b/internal/configuration/sources/files/wireguard_test.go @@ -97,10 +97,9 @@ func Test_parseWireguardInterfaceSection(t *testing.T) { t.Parallel() testCases := map[string]struct { - iniData string - privateKey *string - addresses *string - amneziaParams amneziaWgConfig + iniData string + privateKey *string + addresses *string }{ "no_fields": { iniData: `[Interface]`, @@ -116,17 +115,9 @@ PrivateKey = x [Interface] PrivateKey = QOlCgyA/Sn/c/+YNTIEohrjm8IZV+OZ2AUFIoX20sk8= Address = 10.38.22.35/32 -Jc = 4 -H1 = 721391205 -I1 = `, privateKey: ptrTo("QOlCgyA/Sn/c/+YNTIEohrjm8IZV+OZ2AUFIoX20sk8="), addresses: ptrTo("10.38.22.35/32"), - amneziaParams: amneziaWgConfig{ - Jc: ptrTo("4"), - H1: ptrTo("721391205"), - I1: ptrTo(""), - }, }, } @@ -140,11 +131,9 @@ I1 = require.NoError(t, err) privateKey, addresses := parseWireguardInterfaceSection(iniSection) - amneziaWgConfig := parseWireguardAmneziaInterfaceSection(iniSection) assert.Equal(t, testCase.privateKey, privateKey) assert.Equal(t, testCase.addresses, addresses) - assert.Equal(t, testCase.amneziaParams, amneziaWgConfig) }) } } diff --git a/internal/configuration/sources/secrets/amneziawg.go b/internal/configuration/sources/secrets/amneziawg.go new file mode 100644 index 00000000..53eba65d --- /dev/null +++ b/internal/configuration/sources/secrets/amneziawg.go @@ -0,0 +1,27 @@ +package secrets + +import ( + "os" + "path/filepath" + + "github.com/qdm12/gluetun/internal/configuration/sources/files" +) + +func (s *Source) lazyLoadAmneziawgConf() files.AmneziawgConfig { + if s.cached.amneziawgLoaded { + return s.cached.amneziawgConf + } + + path := os.Getenv("AMNEZIAWG_CONF_SECRETFILE") + if path == "" { + path = filepath.Join(s.rootDirectory, "amneziawg", "awg0.conf") + } + + s.cached.amneziawgLoaded = true + var err error + s.cached.amneziawgConf, err = files.ParseAmneziawgConf(path) + if err != nil { + s.warner.Warnf("skipping Amneziawg config: %s", err) + } + return s.cached.amneziawgConf +} diff --git a/internal/configuration/sources/secrets/reader.go b/internal/configuration/sources/secrets/reader.go index d03126d6..d1c329a7 100644 --- a/internal/configuration/sources/secrets/reader.go +++ b/internal/configuration/sources/secrets/reader.go @@ -15,6 +15,8 @@ type Source struct { cached struct { wireguardLoaded bool wireguardConf files.WireguardConfig + amneziawgLoaded bool + amneziawgConf files.AmneziawgConfig } } @@ -60,26 +62,26 @@ func (s *Source) Get(key string) (value string, isSet bool) { s.warner.Warnf("skipping %s: parsing PEM: %s", path, err) } return value, isSet - case "wireguard_private_key": + case "wireguard_private_key", "amneziawg_private_key": privateKey := s.lazyLoadWireguardConf().PrivateKey if privateKey != nil { return *privateKey, true } // else continue to read from individual secret file - case "wireguard_preshared_key": + case "wireguard_preshared_key", "amneziawg_preshared_key": preSharedKey := s.lazyLoadWireguardConf().PreSharedKey if preSharedKey != nil { return *preSharedKey, true } // else continue to read from individual secret file - case "wireguard_addresses": + case "wireguard_addresses", "amneziawg_addresses": addresses := s.lazyLoadWireguardConf().Addresses if addresses != nil { return *addresses, true } // else continue to read from individual secret file - case "wireguard_public_key": + case "wireguard_public_key", "amneziawg_public_key": return strPtrToStringIsSet(s.lazyLoadWireguardConf().PublicKey) - case "wireguard_endpoint_ip": + case "wireguard_endpoint_ip", "amneziawg_endpoint_ip": return strPtrToStringIsSet(s.lazyLoadWireguardConf().EndpointIP) - case "wireguard_endpoint_port": + case "wireguard_endpoint_port", "amneziawg_endpoint_port": return strPtrToStringIsSet(s.lazyLoadWireguardConf().EndpointPort) } @@ -112,38 +114,50 @@ func (s *Source) KeyTransform(key string) string { func (s *Source) getAmneziaWg(key string) (value string, isSet, matched bool) { switch key { - case "wireguard_jc": - value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.Jc) - case "wireguard_jmin": - value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.Jmin) - case "wireguard_jmax": - value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.Jmax) - case "wireguard_s1": - value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.S1) - case "wireguard_s2": - value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.S2) - case "wireguard_s3": - value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.S3) - case "wireguard_s4": - value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.S4) - case "wireguard_h1": - value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.H1) - case "wireguard_h2": - value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.H2) - case "wireguard_h3": - value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.H3) - case "wireguard_h4": - value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.H4) - case "wireguard_i1": - value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.I1) - case "wireguard_i2": - value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.I2) - case "wireguard_i3": - value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.I3) - case "wireguard_i4": - value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.I4) - case "wireguard_i5": - value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.I5) + case "amneziawg_private_key": + value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Wireguard.PrivateKey) + case "amneziawg_preshared_key": + value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Wireguard.PreSharedKey) + case "wireguard_addresses", "amneziawg_addresses": + value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Wireguard.Addresses) + case "wireguard_public_key", "amneziawg_public_key": + value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Wireguard.PublicKey) + case "wireguard_endpoint_ip", "amneziawg_endpoint_ip": + value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Wireguard.EndpointIP) + case "wireguard_endpoint_port", "amneziawg_endpoint_port": + value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Wireguard.EndpointPort) + case "amneziawg_jc": + value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Jc) + case "amneziawg_jmin": + value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Jmin) + case "amneziawg_jmax": + value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Jmax) + case "amneziawg_s1": + value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().S1) + case "amneziawg_s2": + value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().S2) + case "amneziawg_s3": + value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().S3) + case "amneziawg_s4": + value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().S4) + case "amneziawg_h1": + value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().H1) + case "amneziawg_h2": + value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().H2) + case "amneziawg_h3": + value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().H3) + case "amneziawg_h4": + value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().H4) + case "amneziawg_i1": + value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().I1) + case "amneziawg_i2": + value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().I2) + case "amneziawg_i3": + value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().I3) + case "amneziawg_i4": + value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().I4) + case "amneziawg_i5": + value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().I5) default: return "", false, false } diff --git a/internal/constants/vpn/protocol.go b/internal/constants/vpn/protocol.go index 1fd36b88..48daf0f2 100644 --- a/internal/constants/vpn/protocol.go +++ b/internal/constants/vpn/protocol.go @@ -1,6 +1,7 @@ package vpn const ( + AmneziaWg = "amneziawg" OpenVPN = "openvpn" Wireguard = "wireguard" ) diff --git a/internal/vpn/amneziawg.go b/internal/vpn/amneziawg.go new file mode 100644 index 00000000..7e7a7966 --- /dev/null +++ b/internal/vpn/amneziawg.go @@ -0,0 +1,67 @@ +package vpn + +import ( + "context" + "fmt" + + "github.com/qdm12/gluetun/internal/amneziawg" + "github.com/qdm12/gluetun/internal/configuration/settings" + "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/provider" + "github.com/qdm12/gluetun/internal/wireguard" + "github.com/qdm12/gosettings" +) + +// setupAmneziaWg sets AmneziaWG up using the configurators and settings given. +func setupAmneziaWg(ctx context.Context, netlinker NetLinker, + fw Firewall, providerConf provider.Provider, + settings settings.VPN, ipv6Supported bool, logger wireguard.Logger) ( + amneziawger *amneziawg.Amneziawg, connection models.Connection, err error, +) { + connection, err = providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported) + if err != nil { + return nil, models.Connection{}, fmt.Errorf("finding a VPN server: %w", err) + } + + amneziaWGSettings := buildAmneziaWgSettings(connection, settings.AmneziaWg, ipv6Supported) + + logger.Debug("Amneziawg server public key: " + amneziaWGSettings.Wireguard.PublicKey) + logger.Debug("Amneziawg client private key: " + gosettings.ObfuscateKey(amneziaWGSettings.Wireguard.PrivateKey)) + logger.Debug("Amneziawg pre-shared key: " + gosettings.ObfuscateKey(amneziaWGSettings.Wireguard.PreSharedKey)) + + amneziawger, err = amneziawg.New(amneziaWGSettings, netlinker, logger) + if err != nil { + return nil, models.Connection{}, fmt.Errorf("creating amneziawg: %w", err) + } + + err = fw.SetVPNConnection(ctx, connection, settings.Wireguard.Interface) + if err != nil { + return nil, models.Connection{}, fmt.Errorf("setting firewall: %w", err) + } + + return amneziawger, connection, nil +} + +func buildAmneziaWgSettings(connection models.Connection, + userSettings settings.AmneziaWg, ipv6Supported bool, +) amneziawg.Settings { + return amneziawg.Settings{ + Wireguard: buildWireguardSettings(connection, userSettings.Wireguard, ipv6Supported), + JunkPacketCount: *userSettings.JunkPacketCount, + JunkPacketMin: *userSettings.JunkPacketMin, + JunkPacketMax: *userSettings.JunkPacketMax, + PaddingS1: *userSettings.PaddingS1, + PaddingS2: *userSettings.PaddingS2, + PaddingS3: *userSettings.PaddingS3, + PaddingS4: *userSettings.PaddingS4, + HeaderH1: *userSettings.HeaderH1, + HeaderH2: *userSettings.HeaderH2, + HeaderH3: *userSettings.HeaderH3, + HeaderH4: *userSettings.HeaderH4, + InitPacketI1: *userSettings.InitPacketI1, + InitPacketI2: *userSettings.InitPacketI2, + InitPacketI3: *userSettings.InitPacketI3, + InitPacketI4: *userSettings.InitPacketI4, + InitPacketI5: *userSettings.InitPacketI5, + } +} diff --git a/internal/vpn/run.go b/internal/vpn/run.go index 2fe33c03..c0b8dd3b 100644 --- a/internal/vpn/run.go +++ b/internal/vpn/run.go @@ -33,14 +33,21 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) { var connection models.Connection var err error subLogger := l.logger.New(log.SetComponent(settings.Type)) - if settings.Type == vpn.OpenVPN { + switch settings.Type { + case vpn.AmneziaWg: + vpnInterface = settings.AmneziaWg.Wireguard.Interface + vpnRunner, connection, err = setupAmneziaWg(ctx, l.netLinker, l.fw, + providerConf, settings, l.ipv6Supported, subLogger) + case vpn.OpenVPN: vpnInterface = settings.OpenVPN.Interface vpnRunner, connection, err = setupOpenVPN(ctx, l.fw, l.openvpnConf, providerConf, settings, l.ipv6Supported, l.cmder, subLogger) - } else { // Wireguard + case vpn.Wireguard: vpnInterface = settings.Wireguard.Interface vpnRunner, connection, err = setupWireguard(ctx, l.netLinker, l.fw, providerConf, settings, l.ipv6Supported, subLogger) + default: + panic("vpn type not implemented: " + settings.Type) } if err != nil { l.crashed(ctx, err) diff --git a/internal/vpn/tunnelup.go b/internal/vpn/tunnelup.go index 186375a1..372345d1 100644 --- a/internal/vpn/tunnelup.go +++ b/internal/vpn/tunnelup.go @@ -8,6 +8,7 @@ import ( "time" "github.com/qdm12/gluetun/internal/constants" + "github.com/qdm12/gluetun/internal/constants/vpn" "github.com/qdm12/gluetun/internal/netlink" "github.com/qdm12/gluetun/internal/pmtud" pconstants "github.com/qdm12/gluetun/internal/pmtud/constants" @@ -48,6 +49,14 @@ type tunnelUpPMTUDData struct { } func (l *Loop) onTunnelUp(ctx, loopCtx context.Context, data tunnelUpData) { + switch vpnType := l.GetSettings().Type; vpnType { + case vpn.Wireguard, vpn.AmneziaWg: + l.logger.Infof("%s setup is complete. "+ + "Note %s is a silent protocol and it may or may not work, without giving any error message. "+ + "Typically i/o timeout errors indicate the %s connection is not working.", + vpnType, vpnType, vpnType) + } + l.client.CloseIdleConnections() for _, vpnPort := range l.vpnInputPorts { diff --git a/internal/vpn/wireguard.go b/internal/vpn/wireguard.go index c7977522..1151b304 100644 --- a/internal/vpn/wireguard.go +++ b/internal/vpn/wireguard.go @@ -51,7 +51,6 @@ func buildWireguardSettings(connection models.Connection, settings.PreSharedKey = *userSettings.PreSharedKey settings.InterfaceName = userSettings.Interface settings.Implementation = userSettings.Implementation - settings.AmneziaWG = buildAmneziaWgSettings(userSettings.AmneziaWG) if *userSettings.MTU > 0 { settings.MTU = *userSettings.MTU } else { @@ -91,24 +90,3 @@ func buildWireguardSettings(connection models.Connection, return settings } - -func buildAmneziaWgSettings(s settings.AmneziaWg) wireguard.AmneziaSettings { - return wireguard.AmneziaSettings{ - JunkPacketCount: *s.JunkPacketCount, - JunkPacketMin: *s.JunkPacketMin, - JunkPacketMax: *s.JunkPacketMax, - PaddingS1: *s.PaddingS1, - PaddingS2: *s.PaddingS2, - PaddingS3: *s.PaddingS3, - PaddingS4: *s.PaddingS4, - HeaderH1: *s.HeaderH1, - HeaderH2: *s.HeaderH2, - HeaderH3: *s.HeaderH3, - HeaderH4: *s.HeaderH4, - InitPacketI1: *s.InitPacketI1, - InitPacketI2: *s.InitPacketI2, - InitPacketI3: *s.InitPacketI3, - InitPacketI4: *s.InitPacketI4, - InitPacketI5: *s.InitPacketI5, - } -} diff --git a/internal/vpn/wireguard_test.go b/internal/vpn/wireguard_test.go index 2d07b0ea..07417923 100644 --- a/internal/vpn/wireguard_test.go +++ b/internal/vpn/wireguard_test.go @@ -40,24 +40,6 @@ func Test_buildWireguardSettings(t *testing.T) { PersistentKeepaliveInterval: ptrTo(time.Hour), Interface: "wg1", MTU: ptrTo(uint32(1000)), - AmneziaWG: settings.AmneziaWg{ - JunkPacketCount: ptrTo(uint16(1)), - JunkPacketMin: ptrTo(uint16(0)), - JunkPacketMax: ptrTo(uint16(0)), - PaddingS1: ptrTo(uint16(0)), - PaddingS2: ptrTo(uint16(0)), - PaddingS3: ptrTo(uint16(0)), - PaddingS4: ptrTo(uint16(0)), - HeaderH1: ptrTo("x"), - HeaderH2: ptrTo(""), - HeaderH3: ptrTo(""), - HeaderH4: ptrTo(""), - InitPacketI1: ptrTo(""), - InitPacketI2: ptrTo(""), - InitPacketI3: ptrTo(""), - InitPacketI4: ptrTo(""), - InitPacketI5: ptrTo(""), - }, }, ipv6Supported: false, settings: wireguard.Settings{ @@ -76,10 +58,6 @@ func Test_buildWireguardSettings(t *testing.T) { RulePriority: 101, IPv6: ptrTo(false), MTU: 1000, - AmneziaWG: wireguard.AmneziaSettings{ - JunkPacketCount: 1, - HeaderH1: "x", - }, }, }, } diff --git a/internal/wireguard/address.go b/internal/wireguard/address.go index a85cd7f8..3c2291bd 100644 --- a/internal/wireguard/address.go +++ b/internal/wireguard/address.go @@ -5,15 +5,16 @@ import ( "net/netip" ) -func (w *Wireguard) addAddresses(linkIndex uint32, - addresses []netip.Prefix, +func AddAddresses(linkIndex uint32, + addresses []netip.Prefix, ipv6 bool, + netlink NetLinker, ) (err error) { for _, address := range addresses { - if !*w.settings.IPv6 && address.Addr().Is6() { + if !ipv6 && address.Addr().Is6() { continue } - err = w.netlink.AddrReplace(linkIndex, address) + err = netlink.AddrReplace(linkIndex, address) if err != nil { return fmt.Errorf("%w: when adding address %s to link with index %d", err, address, linkIndex) diff --git a/internal/wireguard/address_test.go b/internal/wireguard/address_test.go index 20707c4c..8c3647c9 100644 --- a/internal/wireguard/address_test.go +++ b/internal/wireguard/address_test.go @@ -10,7 +10,7 @@ import ( "github.com/stretchr/testify/require" ) -func Test_Wireguard_addAddresses(t *testing.T) { +func Test_AddAddresses(t *testing.T) { t.Parallel() ipNetOne := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 32) @@ -19,15 +19,17 @@ func Test_Wireguard_addAddresses(t *testing.T) { errDummy := errors.New("dummy") testCases := map[string]struct { - linkIndex uint32 - addrs []netip.Prefix - wgBuilder func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard - err error + linkIndex uint32 + addrs []netip.Prefix + ipv6 bool + netlinkBuilder func(ctrl *gomock.Controller, linkIndex uint32) *MockNetLinker + err error }{ "success": { linkIndex: 1, addrs: []netip.Prefix{ipNetOne, ipNetTwo}, - wgBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard { + ipv6: true, + netlinkBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *MockNetLinker { netLinker := NewMockNetLinker(ctrl) firstCall := netLinker.EXPECT(). AddrReplace(linkIndex, ipNetOne). @@ -35,35 +37,27 @@ func Test_Wireguard_addAddresses(t *testing.T) { netLinker.EXPECT(). AddrReplace(linkIndex, ipNetTwo). Return(nil).After(firstCall) - return &Wireguard{ - netlink: netLinker, - settings: Settings{ - IPv6: ptrTo(true), - }, - } + return netLinker }, }, "first add error": { linkIndex: 1, addrs: []netip.Prefix{ipNetOne, ipNetTwo}, - wgBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard { + ipv6: true, + netlinkBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *MockNetLinker { netLinker := NewMockNetLinker(ctrl) netLinker.EXPECT(). AddrReplace(linkIndex, ipNetOne). Return(errDummy) - return &Wireguard{ - netlink: netLinker, - settings: Settings{ - IPv6: ptrTo(true), - }, - } + return netLinker }, err: errors.New("dummy: when adding address 1.2.3.4/32 to link with index 1"), }, "second add error": { linkIndex: 1, addrs: []netip.Prefix{ipNetOne, ipNetTwo}, - wgBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard { + ipv6: true, + netlinkBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *MockNetLinker { netLinker := NewMockNetLinker(ctrl) firstCall := netLinker.EXPECT(). AddrReplace(linkIndex, ipNetOne). @@ -71,23 +65,14 @@ func Test_Wireguard_addAddresses(t *testing.T) { netLinker.EXPECT(). AddrReplace(linkIndex, ipNetTwo). Return(errDummy).After(firstCall) - return &Wireguard{ - netlink: netLinker, - settings: Settings{ - IPv6: ptrTo(true), - }, - } + return netLinker }, err: errors.New("dummy: when adding address ::1234/64 to link with index 1"), }, "ignore IPv6": { addrs: []netip.Prefix{ipNetTwo}, - wgBuilder: func(_ *gomock.Controller, _ uint32) *Wireguard { - return &Wireguard{ - settings: Settings{ - IPv6: ptrTo(false), - }, - } + netlinkBuilder: func(_ *gomock.Controller, _ uint32) *MockNetLinker { + return NewMockNetLinker(nil) }, }, } @@ -97,9 +82,9 @@ func Test_Wireguard_addAddresses(t *testing.T) { t.Parallel() ctrl := gomock.NewController(t) - wg := testCase.wgBuilder(ctrl, testCase.linkIndex) + netlink := testCase.netlinkBuilder(ctrl, testCase.linkIndex) - err := wg.addAddresses(testCase.linkIndex, testCase.addrs) + err := AddAddresses(testCase.linkIndex, testCase.addrs, testCase.ipv6, netlink) if testCase.err != nil { require.Error(t, err) diff --git a/internal/wireguard/cleanup.go b/internal/wireguard/cleanup.go deleted file mode 100644 index 668682fe..00000000 --- a/internal/wireguard/cleanup.go +++ /dev/null @@ -1,63 +0,0 @@ -package wireguard - -import "sort" - -type closer struct { - operation string - step step - close func() error - closed bool -} - -type closers []closer - -func (c *closers) add(operation string, step step, - closeFunc func() error, -) { - closer := closer{ - operation: operation, - step: step, - close: closeFunc, - } - *c = append(*c, closer) -} - -func (c *closers) cleanup(logger Logger) { - closers := *c - - sort.Slice(closers, func(i, j int) bool { - return closers[i].step < closers[j].step - }) - - for i, closer := range closers { - if closer.closed { - continue - } - closers[i].closed = true - logger.Debug(closer.operation + "...") - err := closer.close() - if err != nil { - logger.Error("failed " + closer.operation + ": " + err.Error()) - } - } -} - -type step int - -const ( - // stepOne closes the wireguard controller client, - // and removes the IP rule. - stepOne step = iota - // stepTwo closes the UAPI listener. - stepTwo - // stepThree closes the UAPI file. - stepThree - // stepFour shuts down the Wireguard link. - stepFour - // stepFive removes the Wireguard link. - stepFive - // stepSix closes the Wireguard device. - stepSix - // stepSeven closes the bind connection and the TUN device file. - stepSeven -) diff --git a/internal/wireguard/cleanup_test.go b/internal/wireguard/cleanup_test.go deleted file mode 100644 index 4968ad96..00000000 --- a/internal/wireguard/cleanup_test.go +++ /dev/null @@ -1,57 +0,0 @@ -package wireguard - -import ( - "errors" - "testing" - - "github.com/golang/mock/gomock" - "github.com/stretchr/testify/assert" -) - -func Test_closers(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - - var ACloseCalled, BCloseCalled, CCloseCalled bool - var ( - AErr error - BErr = errors.New("B failed") - CErr = errors.New("C failed") - ) - - var closers closers - closers.add("closing A", stepFive, func() error { - ACloseCalled = true - return AErr - }) - - closers.add("closing B", stepThree, func() error { - BCloseCalled = true - return BErr - }) - - closers.add("closing C", stepTwo, func() error { - CCloseCalled = true - return CErr - }) - - logger := NewMockLogger(ctrl) - prevCall := logger.EXPECT().Debug("closing C...") - prevCall = logger.EXPECT().Error("failed closing C: C failed").After(prevCall) - prevCall = logger.EXPECT().Debug("closing B...").After(prevCall) - prevCall = logger.EXPECT().Error("failed closing B: B failed").After(prevCall) - logger.EXPECT().Debug("closing A...").After(prevCall) - - closers.cleanup(logger) - - closers.cleanup(logger) // run twice should not close already closed - - for _, closer := range closers { - assert.True(t, closer.closed) - } - - assert.True(t, ACloseCalled) - assert.True(t, BCloseCalled) - assert.True(t, CCloseCalled) -} diff --git a/internal/wireguard/common.go b/internal/wireguard/common.go deleted file mode 100644 index a4d9fb56..00000000 --- a/internal/wireguard/common.go +++ /dev/null @@ -1,28 +0,0 @@ -package wireguard - -import ( - "net" -) - -type tunDevice interface { - Close() error - Name() (string, error) -} - -type bind interface { - Close() error -} - -type userspaceDevice interface { - Close() - Wait() chan struct{} - IpcHandle(net.Conn) - IpcSet(string) error -} - -type userSpaceBackend struct { - createTun func(string, int) (tunDevice, error) - createBind func() bind - createDevice func(tunDevice, bind, Logger) userspaceDevice - preStart func(userspaceDevice, Settings) error -} diff --git a/internal/wireguard/config.go b/internal/wireguard/config.go index 735a5973..a0752cd0 100644 --- a/internal/wireguard/config.go +++ b/internal/wireguard/config.go @@ -10,7 +10,7 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) -func configureDevice(client *wgctrl.Client, settings Settings) (err error) { +func ConfigureDevice(client *wgctrl.Client, settings Settings) (err error) { deviceConfig, err := makeDeviceConfig(settings) if err != nil { return fmt.Errorf("making device configuration: %w", err) diff --git a/internal/wireguard/log.go b/internal/wireguard/log.go index c5120dd0..14a22e09 100644 --- a/internal/wireguard/log.go +++ b/internal/wireguard/log.go @@ -1,5 +1,9 @@ package wireguard +import ( + "golang.zx2c4.com/wireguard/device" +) + //go:generate mockgen -destination=log_mock_test.go -package wireguard . Logger type Logger interface { @@ -7,5 +11,16 @@ type Logger interface { Debugf(format string, args ...interface{}) Info(s string) Error(s string) - Errorf(format string, args ...interface{}) + Erroer +} + +type Erroer interface { + Errorf(format string, args ...any) +} + +func makeDeviceLogger(logger Logger) (deviceLogger *device.Logger) { + return &device.Logger{ + Verbosef: logger.Debugf, + Errorf: logger.Errorf, + } } diff --git a/internal/wireguard/log_test.go b/internal/wireguard/log_test.go new file mode 100644 index 00000000..1491dd34 --- /dev/null +++ b/internal/wireguard/log_test.go @@ -0,0 +1,23 @@ +package wireguard + +import ( + "testing" + + "github.com/golang/mock/gomock" +) + +func Test_makeDeviceLogger(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + + logger := NewMockLogger(ctrl) + + deviceLogger := makeDeviceLogger(logger) + + logger.EXPECT().Debugf("test %d", 1) + deviceLogger.Verbosef("test %d", 1) + + logger.EXPECT().Errorf("test %d", 2) + deviceLogger.Errorf("test %d", 2) +} diff --git a/internal/wireguard/netlink_integration_test.go b/internal/wireguard/netlink_integration_test.go index e4883799..d57806a1 100644 --- a/internal/wireguard/netlink_integration_test.go +++ b/internal/wireguard/netlink_integration_test.go @@ -21,7 +21,7 @@ func (n noopDebugLogger) Error(_ string) {} func (n noopDebugLogger) Errorf(_ string, _ ...any) {} func (n noopDebugLogger) Patch(_ ...log.Option) {} -func Test_netlink_Wireguard_addAddresses(t *testing.T) { +func Test_AddAddresses_Integration(t *testing.T) { t.Parallel() netlinker := netlink.New(&noopDebugLogger{}) @@ -55,7 +55,7 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) { const addIterations = 2 // initial + replace for range addIterations { - err = wg.addAddresses(link.Index, addresses) + err = AddAddresses(link.Index, addresses, *wg.settings.IPv6, wg.netlink) require.NoError(t, err) ipPrefixes, err := netlinker.AddrList(link.Index, netlink.FamilyAll) @@ -67,22 +67,19 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) { } } -func Test_netlink_Wireguard_addRule(t *testing.T) { +func Test_AddRule_Integration(t *testing.T) { t.Parallel() - netlinker := netlink.New(&noopDebugLogger{}) - wg := &Wireguard{ - netlink: netlinker, - logger: &noopDebugLogger{}, - } + logger := &noopDebugLogger{} + netlinker := netlink.New(logger) // Unique combination for this test const rulePriority uint32 = 10000 const firewallMark uint32 = 12345 const family = netlink.FamilyV4 - cleanup, err := wg.addRule(rulePriority, - firewallMark, family) + cleanup, err := AddRule(rulePriority, + firewallMark, family, netlinker, logger) require.NoError(t, err) t.Cleanup(func() { err := cleanup() @@ -110,8 +107,8 @@ func Test_netlink_Wireguard_addRule(t *testing.T) { require.True(t, ruleFound) // Existing rule cannot be added - nilCleanup, err := wg.addRule(rulePriority, - firewallMark, family) + nilCleanup, err := AddRule(rulePriority, + firewallMark, family, netlinker, logger) if nilCleanup != nil { _ = nilCleanup() // in case it succeeds } diff --git a/internal/wireguard/route.go b/internal/wireguard/route.go index 101b39fd..3fecf2c3 100644 --- a/internal/wireguard/route.go +++ b/internal/wireguard/route.go @@ -8,17 +8,17 @@ import ( "github.com/qdm12/gluetun/internal/netlink" ) -func (w *Wireguard) addRoutes(linkIndex uint32, destinations []netip.Prefix, - firewallMark uint32, +func AddRoutes(linkIndex uint32, destinations []netip.Prefix, + firewallMark uint32, netlinker NetLinker, logger Erroer, ) (err error) { for _, dst := range destinations { - err = w.addRoute(linkIndex, dst, firewallMark) + err = addRoute(linkIndex, dst, firewallMark, netlinker) if err == nil { continue } if dst.Addr().Is6() && strings.Contains(err.Error(), "permission denied") { - w.logger.Errorf("cannot add route for IPv6 due to a permission denial. "+ + logger.Errorf("cannot add route for IPv6 due to a permission denial. "+ "Ignoring and continuing execution; "+ "Please report to https://github.com/qdm12/gluetun/issues/998 if you find a fix. "+ "Full error string: %s", err) @@ -29,8 +29,8 @@ func (w *Wireguard) addRoutes(linkIndex uint32, destinations []netip.Prefix, return nil } -func (w *Wireguard) addRoute(linkIndex uint32, dst netip.Prefix, - firewallMark uint32, +func addRoute(linkIndex uint32, dst netip.Prefix, + firewallMark uint32, netlinker NetLinker, ) (err error) { family := netlink.FamilyV4 if dst.Addr().Is6() { @@ -46,7 +46,7 @@ func (w *Wireguard) addRoute(linkIndex uint32, dst netip.Prefix, Proto: netlink.ProtoStatic, } - err = w.netlink.RouteAdd(route) + err = netlinker.RouteAdd(route) if err != nil { return fmt.Errorf( "adding route for link with index %d, destination %s and table %d: %w", diff --git a/internal/wireguard/route_test.go b/internal/wireguard/route_test.go index c75c8034..0fc37832 100644 --- a/internal/wireguard/route_test.go +++ b/internal/wireguard/route_test.go @@ -11,7 +11,7 @@ import ( "github.com/stretchr/testify/require" ) -func Test_Wireguard_addRoute(t *testing.T) { +func Test_addRoute(t *testing.T) { t.Parallel() const linkIndex = 88 @@ -62,15 +62,11 @@ func Test_Wireguard_addRoute(t *testing.T) { ctrl := gomock.NewController(t) netLinker := NewMockNetLinker(ctrl) - wg := Wireguard{ - netlink: netLinker, - } - netLinker.EXPECT(). RouteAdd(testCase.expectedRoute). Return(testCase.routeAddErr) - err := wg.addRoute(linkIndex, testCase.dst, firewallMark) + err := addRoute(linkIndex, testCase.dst, firewallMark, netLinker) if testCase.err != nil { require.Error(t, err) diff --git a/internal/wireguard/rule.go b/internal/wireguard/rule.go index 24249cb5..3319c5c9 100644 --- a/internal/wireguard/rule.go +++ b/internal/wireguard/rule.go @@ -7,8 +7,8 @@ import ( "github.com/qdm12/gluetun/internal/netlink" ) -func (w *Wireguard) addRule(rulePriority, firewallMark uint32, - family uint8, +func AddRule(rulePriority, firewallMark uint32, family uint8, + netlinker NetLinker, logger Logger, ) (cleanup func() error, err error) { rule := netlink.Rule{ Priority: &rulePriority, @@ -18,16 +18,16 @@ func (w *Wireguard) addRule(rulePriority, firewallMark uint32, Flags: netlink.FlagInvert, Action: netlink.ActionToTable, } - if err := w.netlink.RuleAdd(rule); err != nil { + if err := netlinker.RuleAdd(rule); err != nil { if strings.HasSuffix(err.Error(), "file exists") { - w.logger.Info("if you are using Kubernetes, this may fix the error below: " + + logger.Info("if you are using Kubernetes, this may fix the error below: " + "https://github.com/qdm12/gluetun-wiki/blob/main/setup/advanced/kubernetes.md#adding-ipv6-rule--file-exists") } return nil, fmt.Errorf("adding %s: %w", rule, err) } cleanup = func() error { - err := w.netlink.RuleDel(rule) + err := netlinker.RuleDel(rule) if err != nil { return fmt.Errorf("deleting rule %s: %w", rule, err) } diff --git a/internal/wireguard/rule_test.go b/internal/wireguard/rule_test.go index c6f02f36..8b8f4503 100644 --- a/internal/wireguard/rule_test.go +++ b/internal/wireguard/rule_test.go @@ -10,7 +10,7 @@ import ( "github.com/stretchr/testify/require" ) -func Test_Wireguard_addRule(t *testing.T) { +func Test_AddRule(t *testing.T) { t.Parallel() const rulePriority uint32 = 987 @@ -68,13 +68,11 @@ func Test_Wireguard_addRule(t *testing.T) { ctrl := gomock.NewController(t) netLinker := NewMockNetLinker(ctrl) - wg := Wireguard{ - netlink: netLinker, - } netLinker.EXPECT().RuleAdd(testCase.expectedRule). Return(testCase.ruleAddErr) - cleanup, err := wg.addRule(rulePriority, firewallMark, family) + cleanup, err := AddRule(rulePriority, firewallMark, family, + netLinker, nil) if testCase.err != nil { require.Error(t, err) assert.Equal(t, testCase.err.Error(), err.Error()) diff --git a/internal/wireguard/run.go b/internal/wireguard/run.go index b6f06aad..db23161b 100644 --- a/internal/wireguard/run.go +++ b/internal/wireguard/run.go @@ -6,39 +6,33 @@ import ( "fmt" "net" + "github.com/qdm12/gluetun/internal/cleanup" "github.com/qdm12/gluetun/internal/netlink" + "golang.zx2c4.com/wireguard/conn" + "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/wgctrl" ) var ( - ErrDetectKernel = errors.New("cannot detect Kernel support") - ErrCreateTun = errors.New("cannot create TUN device") - ErrAddLink = errors.New("cannot add Wireguard link") - ErrFindLink = errors.New("cannot find link") - ErrFindDevice = errors.New("cannot find Wireguard device") - ErrUAPISocketOpening = errors.New("cannot open UAPI socket") - ErrWgctrlOpen = errors.New("cannot open wgctrl") - ErrUAPIListen = errors.New("cannot listen on UAPI socket") - ErrAddAddress = errors.New("cannot add address to wireguard interface") - ErrConfigure = errors.New("cannot configure wireguard interface") - ErrDeviceInfo = errors.New("cannot get wireguard device information") - ErrIfaceUp = errors.New("cannot set the interface to UP") - ErrRouteAdd = errors.New("cannot add route for interface") - ErrDeviceWaited = errors.New("device waited for") - ErrKernelSupport = errors.New("kernel does not support Wireguard") - ErrAmneziaConfigure = errors.New("cannot configure AmneziaWG") + errKernelSupport = errors.New("kernel does not support Wireguard") + errTunNameMismatch = errors.New("TUN device name is mismatching") + errDeviceWaited = errors.New("device waited for") ) +// Run runs the wireguard interface and waits until the context is done, then it cleans up the +// interface and returns any error that occurred during setup or waiting. It sends an error to +// waitError if any error occurs during setup or waiting, otherwise it sends nil when the context +// is done. It sends a signal to ready when the setup is complete and the interface is ready to use. // See https://git.zx2c4.com/wireguard-go/tree/main.go func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<- struct{}) { kernelSupported, err := w.netlink.IsWireguardSupported() if err != nil { - waitError <- fmt.Errorf("%w: %s", ErrDetectKernel, err) + waitError <- fmt.Errorf("detecting wireguard kernel support: %w", err) return } - userspaceBackend := defaultUserSpaceBackend() - setupFunction := setupUserSpaceCommon + setupFunction := setupUserSpace switch w.settings.Implementation { case "auto": //nolint:goconst if !kernelSupported { @@ -50,95 +44,105 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan< case "userspace": case "kernelspace": if !kernelSupported { - waitError <- fmt.Errorf("%w", ErrKernelSupport) + waitError <- fmt.Errorf("%w", errKernelSupport) return } setupFunction = setupKernelSpace - case "amneziawg": - userspaceBackend = amneziaUserSpaceBackend() default: panic(fmt.Sprintf("unknown implementation %q", w.settings.Implementation)) } + setup := func(ctx context.Context, cleanups *cleanup.Cleanups) ( + linkIndex uint32, waitAndCleanup func() error, err error, + ) { + return setupFunction(ctx, + w.settings.InterfaceName, w.netlink, w.settings.MTU, cleanups, w.logger) + } + + Run(ctx, waitError, ready, setup, w.settings, w.netlink, w.logger) +} + +func Run(ctx context.Context, waitError chan<- error, ready chan<- struct{}, + setup func(ctx context.Context, cleanups *cleanup.Cleanups) ( + linkIndex uint32, waitAndCleanup func() error, err error), + settings Settings, netlinker NetLinker, logger Logger, +) { client, err := wgctrl.New() if err != nil { - waitError <- fmt.Errorf("%w: %s", ErrWgctrlOpen, err) + waitError <- fmt.Errorf("opening wgctrl: %w", err) return } - var closers closers - closers.add("closing controller client", stepOne, client.Close) + var cleanups cleanup.Cleanups + cleanups.Add("closing controller client", 1, client.Close) - defer closers.cleanup(w.logger) + defer cleanups.Cleanup(logger) - linkIndex, waitAndCleanup, err := setupFunction(ctx, - w.settings.InterfaceName, w.netlink, w.settings.MTU, &closers, w.logger, w.settings, userspaceBackend) + linkIndex, waitAndCleanup, err := setup(ctx, &cleanups) if err != nil { waitError <- err return } - err = w.addAddresses(linkIndex, w.settings.Addresses) + err = AddAddresses(linkIndex, settings.Addresses, *settings.IPv6, netlinker) if err != nil { - waitError <- fmt.Errorf("%w: %s", ErrAddAddress, err) + waitError <- fmt.Errorf("adding addresses to interface: %w", err) return } - w.logger.Info("Connecting to " + w.settings.Endpoint.String()) - err = configureDevice(client, w.settings) + logger.Info("Connecting to " + settings.Endpoint.String()) + err = ConfigureDevice(client, settings) if err != nil { - waitError <- fmt.Errorf("%w: %s", ErrConfigure, err) + waitError <- fmt.Errorf("configuring interface: %w", err) return } - err = w.netlink.LinkSetUp(linkIndex) + err = netlinker.LinkSetUp(linkIndex) if err != nil { - waitError <- fmt.Errorf("%w: %s", ErrIfaceUp, err) + waitError <- fmt.Errorf("setting the interface UP: %w", err) return } - closers.add("shutting down link", stepFour, func() error { - return w.netlink.LinkSetDown(linkIndex) + cleanups.Add("shutting down link", 4, func() error { + return netlinker.LinkSetDown(linkIndex) }) - err = w.addRoutes(linkIndex, w.settings.AllowedIPs, w.settings.FirewallMark) + err = AddRoutes(linkIndex, settings.AllowedIPs, settings.FirewallMark, + netlinker, logger) if err != nil { - waitError <- fmt.Errorf("%w: %s", ErrRouteAdd, err) + waitError <- fmt.Errorf("adding routes for interface: %w", err) return } - if *w.settings.IPv6 { + if *settings.IPv6 { // requires net.ipv6.conf.all.disable_ipv6=0 - ruleCleanup6, err := w.addRule(w.settings.RulePriority, - w.settings.FirewallMark, netlink.FamilyV6) + ruleCleanup6, err := AddRule(settings.RulePriority, + settings.FirewallMark, netlink.FamilyV6, + netlinker, logger) if err != nil { waitError <- fmt.Errorf("adding IPv6 rule: %w", err) return } - closers.add("removing IPv6 rule", stepOne, ruleCleanup6) + cleanups.Add("removing IPv6 rule", 1, ruleCleanup6) } - ruleCleanup, err := w.addRule(w.settings.RulePriority, - w.settings.FirewallMark, netlink.FamilyV4) + ruleCleanup, err := AddRule(settings.RulePriority, + settings.FirewallMark, netlink.FamilyV4, + netlinker, logger) if err != nil { waitError <- fmt.Errorf("adding IPv4 rule: %w", err) return } - closers.add("removing IPv4 rule", stepOne, ruleCleanup) - w.logger.Info("Wireguard setup is complete. " + - "Note Wireguard is a silent protocol and it may or may not work, without giving any error message. " + - "Typically i/o timeout errors indicate the Wireguard connection is not working.") + cleanups.Add("removing IPv4 rule", 1, ruleCleanup) ready <- struct{}{} waitError <- waitAndCleanup() } -type waitAndCleanupFunc func() error - func setupKernelSpace(ctx context.Context, interfaceName string, netLinker NetLinker, mtu uint32, - closers *closers, logger Logger, _ Settings, _ userSpaceBackend) ( - linkIndex uint32, waitAndCleanup waitAndCleanupFunc, err error, + cleanups *cleanup.Cleanups, logger Logger) ( + linkIndex uint32, waitAndCleanup func() error, err error, ) { links, err := netLinker.LinkList() if err != nil { @@ -164,82 +168,74 @@ func setupKernelSpace(ctx context.Context, } linkIndex, err = netLinker.LinkAdd(link) if err != nil { - return 0, nil, fmt.Errorf("%w: %s", ErrAddLink, err) + return 0, nil, fmt.Errorf("adding link: %w", err) } - closers.add("deleting link", stepFive, func() error { + cleanups.Add("deleting link", 5, func() error { return netLinker.LinkDel(linkIndex) }) waitAndCleanup = func() error { <-ctx.Done() - closers.cleanup(logger) + cleanups.Cleanup(logger) return ctx.Err() } return linkIndex, waitAndCleanup, nil } -func setupUserSpaceCommon(ctx context.Context, +func setupUserSpace(ctx context.Context, interfaceName string, netLinker NetLinker, mtu uint32, - closers *closers, logger Logger, - settings Settings, b userSpaceBackend, -) ( - linkIndex uint32, waitAndCleanup waitAndCleanupFunc, err error, + cleanups *cleanup.Cleanups, logger Logger) ( + linkIndex uint32, waitAndCleanup func() error, err error, ) { - tun, err := b.createTun(interfaceName, int(mtu)) + tun, err := tun.CreateTUN(interfaceName, int(mtu)) if err != nil { - return 0, nil, fmt.Errorf("%w: %s", ErrCreateTun, err) + return 0, nil, fmt.Errorf("creating TUN device: %w", err) } - closers.add("closing TUN device", stepSeven, tun.Close) + cleanups.Add("closing TUN device", 7, tun.Close) tunName, err := tun.Name() if err != nil { - return 0, nil, fmt.Errorf("%w: cannot get TUN name: %s", ErrCreateTun, err) + return 0, nil, fmt.Errorf("getting created TUN device name: %w", err) } else if tunName != interfaceName { - return 0, nil, fmt.Errorf("%w: names don't match: expected %q and got %q", - ErrCreateTun, interfaceName, tunName) + return 0, nil, fmt.Errorf("%w: expected %q and got %q", + errTunNameMismatch, interfaceName, tunName) } link, err := netLinker.LinkByName(interfaceName) if err != nil { - return 0, nil, fmt.Errorf("%w: %s: %s", ErrFindLink, interfaceName, err) + return 0, nil, fmt.Errorf("finding link %s: %w", interfaceName, err) } - closers.add("deleting link", stepFive, func() error { + cleanups.Add("deleting link", 5, func() error { return netLinker.LinkDel(link.Index) }) - bind := b.createBind() + bind := conn.NewDefaultBind() - closers.add("closing bind", stepSeven, bind.Close) + cleanups.Add("closing bind", 7, bind.Close) - device := b.createDevice(tun, bind, logger) + deviceLogger := makeDeviceLogger(logger) + device := device.NewDevice(tun, bind, deviceLogger) - closers.add("closing Wireguard device", stepSix, func() error { + cleanups.Add("closing Wireguard device", 6, func() error { device.Close() return nil }) - uapiFile, err := uapiOpen(interfaceName) + uapiFile, err := UAPIOpen(interfaceName) if err != nil { - return 0, nil, fmt.Errorf("%w: %s", ErrUAPISocketOpening, err) + return 0, nil, fmt.Errorf("opening UAPI socket: %w", err) } - closers.add("closing UAPI file", stepThree, uapiFile.Close) + cleanups.Add("closing UAPI file", 3, uapiFile.Close) - uapiListener, err := uapiListen(interfaceName, uapiFile) + uapiListener, err := UAPIListen(interfaceName, uapiFile) if err != nil { - return 0, nil, fmt.Errorf("%w: %s", ErrUAPIListen, err) + return 0, nil, fmt.Errorf("listening on UAPI socket: %w", err) } - closers.add("closing UAPI listener", stepTwo, uapiListener.Close) - - if b.preStart != nil { - err = b.preStart(device, settings) - if err != nil { - return 0, nil, err - } - } + cleanups.Add("closing UAPI listener", 2, uapiListener.Close) // acceptAndHandle exits when uapiListener is closed uapiAcceptErrorCh := make(chan error) @@ -251,10 +247,10 @@ func setupUserSpaceCommon(ctx context.Context, case err = <-uapiAcceptErrorCh: close(uapiAcceptErrorCh) case <-device.Wait(): - err = ErrDeviceWaited + err = errDeviceWaited } - closers.cleanup(logger) + cleanups.Cleanup(logger) <-uapiAcceptErrorCh // wait for acceptAndHandle to exit @@ -264,7 +260,7 @@ func setupUserSpaceCommon(ctx context.Context, return link.Index, waitAndCleanup, nil } -func acceptAndHandle(uapi net.Listener, device userspaceDevice, +func acceptAndHandle(uapi net.Listener, device *device.Device, uapiAcceptErrorCh chan<- error, ) { for { // stopped by uapiFile.Close() diff --git a/internal/wireguard/settings.go b/internal/wireguard/settings.go index c9c10216..66eee765 100644 --- a/internal/wireguard/settings.go +++ b/internal/wireguard/settings.go @@ -46,11 +46,8 @@ type Settings struct { // It defaults to false if left unset. IPv6 *bool // Implementation is the implementation to use. - // It can be auto, kernelspace, userspace or amneziawg, - // and defaults to auto. + // It can be auto, kernelspace or userspace, and defaults to auto. Implementation string - // AmneziaWG settings are extra obfuscation parameters - AmneziaWG AmneziaSettings } func (s *Settings) SetDefaults() { @@ -181,7 +178,7 @@ func (s *Settings) Check() (err error) { } switch s.Implementation { - case "auto", "kernelspace", "userspace", "amneziawg": + case "auto", "kernelspace", "userspace": default: return fmt.Errorf("%w: %s", ErrImplementationInvalid, s.Implementation) } diff --git a/internal/wireguard/userspaces.go b/internal/wireguard/userspaces.go deleted file mode 100644 index f5220f20..00000000 --- a/internal/wireguard/userspaces.go +++ /dev/null @@ -1,58 +0,0 @@ -package wireguard - -import ( - amneziaconn "github.com/amnezia-vpn/amneziawg-go/conn" - amneziadevice "github.com/amnezia-vpn/amneziawg-go/device" - amneziatun "github.com/amnezia-vpn/amneziawg-go/tun" - wgconn "golang.zx2c4.com/wireguard/conn" - wgdevice "golang.zx2c4.com/wireguard/device" - wgtun "golang.zx2c4.com/wireguard/tun" -) - -func defaultUserSpaceBackend() userSpaceBackend { - return userSpaceBackend{ - createTun: func(name string, mtu int) (tunDevice, error) { - return wgtun.CreateTUN(name, mtu) - }, - createBind: func() bind { - return wgconn.NewDefaultBind() - }, - createDevice: func(td tunDevice, b bind, logger Logger) userspaceDevice { - wgtun, _ := td.(wgtun.Device) - wgBind, _ := b.(wgconn.Bind) - wgLogger := wgdevice.Logger{ - Verbosef: logger.Debugf, - Errorf: logger.Errorf, - } - device := wgdevice.NewDevice(wgtun, wgBind, &wgLogger) - return device - }, - preStart: nil, - } -} - -func amneziaUserSpaceBackend() userSpaceBackend { - return userSpaceBackend{ - createTun: func(name string, mtu int) (tunDevice, error) { - return amneziatun.CreateTUN(name, mtu) - }, - createBind: func() bind { - return amneziaconn.NewDefaultBind() - }, - createDevice: func(td tunDevice, b bind, logger Logger) userspaceDevice { - wgamneziaTun, _ := td.(amneziatun.Device) - wgamneziaBind, _ := b.(amneziaconn.Bind) - wgamneziaLogger := amneziadevice.Logger{ - Verbosef: logger.Debugf, - Errorf: logger.Errorf, - } - device := amneziadevice.NewDevice(wgamneziaTun, wgamneziaBind, &wgamneziaLogger) - return device - }, - preStart: func(ud userspaceDevice, s Settings) error { - uapiConfig := s.AmneziaWG.uapiConfig() - err := ud.IpcSet(uapiConfig) - return err - }, - } -} diff --git a/internal/wireguard/wireguard_linux.go b/internal/wireguard/wireguard_linux.go index 6c066eb5..ba2085ce 100644 --- a/internal/wireguard/wireguard_linux.go +++ b/internal/wireguard/wireguard_linux.go @@ -7,10 +7,10 @@ import ( "golang.zx2c4.com/wireguard/ipc" ) -func uapiOpen(name string) (*os.File, error) { +func UAPIOpen(name string) (*os.File, error) { return ipc.UAPIOpen(name) } -func uapiListen(interfaceName string, uapiFile *os.File) (net.Listener, error) { +func UAPIListen(interfaceName string, uapiFile *os.File) (net.Listener, error) { return ipc.UAPIListen(interfaceName, uapiFile) } diff --git a/internal/wireguard/wireguard_unspecified.go b/internal/wireguard/wireguard_unspecified.go index e7619fd6..171cf38e 100644 --- a/internal/wireguard/wireguard_unspecified.go +++ b/internal/wireguard/wireguard_unspecified.go @@ -7,10 +7,10 @@ import ( "os" ) -func uapiOpen(name string) (*os.File, error) { +func UAPIOpen(name string) (*os.File, error) { panic("not implemented") } -func uapiListen(interfaceName string, uapiFile *os.File) (net.Listener, error) { +func UAPIListen(interfaceName string, uapiFile *os.File) (net.Listener, error) { panic("not implemented") }