mirror of
https://github.com/qdm12/gluetun.git
synced 2026-05-06 20:10:11 +02:00
chore!(amneziawg): refactor to be separate from wireguard
- amneziawg is now a VPN protocol and no longer a Wireguard implementation - Use it with VPN_TYPE=amneziawg - document AMNEZIAWG_* options in Dockerfile - document amneziawg support in readme - separate amneziawg settings and code from wireguard - re-use code from wireguard whenever possible
This commit is contained in:
@@ -60,6 +60,10 @@ linters:
|
||||
- linters:
|
||||
- lll
|
||||
source: "^// https://.+$"
|
||||
- linters:
|
||||
- mnd
|
||||
source: "^ cleanups\\.Add.+$"
|
||||
path: internal\/(wireguard|amneziawg)\/run\.go
|
||||
- linters:
|
||||
- err113
|
||||
- mnd
|
||||
|
||||
+30
@@ -112,6 +112,36 @@ ENV VPN_SERVICE_PROVIDER=pia \
|
||||
WIREGUARD_ADDRESSES_SECRETFILE=/run/secrets/wireguard_addresses \
|
||||
WIREGUARD_MTU= \
|
||||
WIREGUARD_IMPLEMENTATION=auto \
|
||||
# Amnezia
|
||||
AMNEZIAWG_ENDPOINT_IP= \
|
||||
AMNEZIAWG_ENDPOINT_PORT= \
|
||||
AMNEZIAWG_CONF_SECRETFILE=/run/secrets/wg0.conf \
|
||||
AMNEZIAWG_PRIVATE_KEY= \
|
||||
AMNEZIAWG_PRIVATE_KEY_SECRETFILE=/run/secrets/wireguard_private_key \
|
||||
AMNEZIAWG_PRESHARED_KEY= \
|
||||
AMNEZIAWG_PRESHARED_KEY_SECRETFILE=/run/secrets/wireguard_preshared_key \
|
||||
AMNEZIAWG_PUBLIC_KEY= \
|
||||
AMNEZIAWG_ALLOWED_IPS= \
|
||||
AMNEZIAWG_PERSISTENT_KEEPALIVE_INTERVAL=0 \
|
||||
AMNEZIAWG_ADDRESSES= \
|
||||
AMNEZIAWG_ADDRESSES_SECRETFILE=/run/secrets/wireguard_addresses \
|
||||
AMNEZIAWG_MTU= \
|
||||
AMNEZIAWG_JC=0 \
|
||||
AMNEZIAWG_JMIN=0 \
|
||||
AMNEZIAWG_JMAX=0 \
|
||||
AMNEZIAWG_S1=0 \
|
||||
AMNEZIAWG_S2=0 \
|
||||
AMNEZIAWG_S3=0 \
|
||||
AMNEZIAWG_S4=0 \
|
||||
AMNEZIAWG_H1= \
|
||||
AMNEZIAWG_H2= \
|
||||
AMNEZIAWG_H3= \
|
||||
AMNEZIAWG_H4= \
|
||||
AMNEZIAWG_I1= \
|
||||
AMNEZIAWG_I2= \
|
||||
AMNEZIAWG_I3= \
|
||||
AMNEZIAWG_I4= \
|
||||
AMNEZIAWG_I5= \
|
||||
# Wireguard AmneziaWG userspace obfuscation (requires WIREGUARD_IMPLEMENTATION=amneziawg)
|
||||
AMNEZIAWG_JC=0 \
|
||||
AMNEZIAWG_JMIN=0 \
|
||||
|
||||
@@ -67,6 +67,7 @@ Lightweight swiss-army-knife-like VPN client to multiple VPN service providers
|
||||
- For **Cyberghost**, **Private Internet Access**, **PrivateVPN**, **PureVPN**, **Torguard**, **VPN Unlimited** and **VyprVPN** using [the custom provider](https://github.com/qdm12/gluetun-wiki/blob/main/setup/providers/custom.md)
|
||||
- For custom Wireguard configurations using [the custom provider](https://github.com/qdm12/gluetun-wiki/blob/main/setup/providers/custom.md)
|
||||
- More in progress, see [#134](https://github.com/qdm12/gluetun/issues/134)
|
||||
- Supports AmneziaWG only with the custom provider for now
|
||||
- DNS over TLS baked in with service provider(s) of your choice
|
||||
- DNS fine blocking of malicious/ads/surveillance hostnames and IP addresses, with live update every 24 hours
|
||||
- Choose the vpn network protocol, `udp` or `tcp`
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
package amneziawg
|
||||
|
||||
type Amneziawg struct {
|
||||
logger Logger
|
||||
settings Settings
|
||||
netlink NetLinker
|
||||
}
|
||||
|
||||
func New(settings Settings, netlink NetLinker,
|
||||
logger Logger,
|
||||
) (a *Amneziawg, err error) {
|
||||
settings.SetDefaults()
|
||||
if err := settings.Check(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Amneziawg{
|
||||
logger: logger,
|
||||
settings: settings,
|
||||
netlink: netlink,
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,86 @@
|
||||
package amneziawg
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/wireguard"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
)
|
||||
|
||||
func Test_New(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const validKeyString = "oMNSf/zJ0pt1ciy+qIRk8Rlyfs9accwuRLnKd85Yl1Q="
|
||||
logger := NewMockLogger(nil)
|
||||
netLinker := NewMockNetLinker(nil)
|
||||
|
||||
testCases := map[string]struct {
|
||||
settings Settings
|
||||
amneziawg *Amneziawg
|
||||
err error
|
||||
}{
|
||||
"bad_settings": {
|
||||
settings: Settings{
|
||||
Wireguard: wireguard.Settings{
|
||||
PrivateKey: "",
|
||||
},
|
||||
},
|
||||
err: wireguard.ErrPrivateKeyMissing,
|
||||
},
|
||||
"minimal valid settings": {
|
||||
settings: Settings{
|
||||
Wireguard: wireguard.Settings{
|
||||
PrivateKey: validKeyString,
|
||||
PublicKey: validKeyString,
|
||||
Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 0),
|
||||
Addresses: []netip.Prefix{
|
||||
netip.PrefixFrom(netip.AddrFrom4([4]byte{5, 6, 7, 8}), 32),
|
||||
},
|
||||
FirewallMark: 100,
|
||||
},
|
||||
},
|
||||
amneziawg: &Amneziawg{
|
||||
logger: logger,
|
||||
netlink: netLinker,
|
||||
settings: Settings{
|
||||
Wireguard: wireguard.Settings{
|
||||
InterfaceName: "wg0",
|
||||
PrivateKey: validKeyString,
|
||||
PublicKey: validKeyString,
|
||||
Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 51820),
|
||||
Addresses: []netip.Prefix{
|
||||
netip.PrefixFrom(netip.AddrFrom4([4]byte{5, 6, 7, 8}), 32),
|
||||
},
|
||||
AllowedIPs: []netip.Prefix{
|
||||
netip.MustParsePrefix("0.0.0.0/0"),
|
||||
},
|
||||
FirewallMark: 100,
|
||||
MTU: device.DefaultMTU,
|
||||
IPv6: ptrTo(false),
|
||||
Implementation: "auto",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
wireguard, err := New(testCase.settings, netLinker, logger)
|
||||
|
||||
if testCase.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, testCase.err.Error(), err.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, testCase.amneziawg, wireguard)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
package amneziawg
|
||||
|
||||
func ptrTo[T any](v T) *T {
|
||||
return &v
|
||||
}
|
||||
@@ -0,0 +1,11 @@
|
||||
package amneziawg
|
||||
|
||||
//go:generate mockgen -destination=log_mock_test.go -package amneziawg . Logger
|
||||
|
||||
type Logger interface {
|
||||
Debug(s string)
|
||||
Debugf(format string, args ...interface{})
|
||||
Info(s string)
|
||||
Error(s string)
|
||||
Errorf(format string, args ...interface{})
|
||||
}
|
||||
@@ -0,0 +1,104 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/qdm12/gluetun/internal/amneziawg (interfaces: Logger)
|
||||
|
||||
// Package amneziawg is a generated GoMock package.
|
||||
package amneziawg
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// MockLogger is a mock of Logger interface.
|
||||
type MockLogger struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockLoggerMockRecorder
|
||||
}
|
||||
|
||||
// MockLoggerMockRecorder is the mock recorder for MockLogger.
|
||||
type MockLoggerMockRecorder struct {
|
||||
mock *MockLogger
|
||||
}
|
||||
|
||||
// NewMockLogger creates a new mock instance.
|
||||
func NewMockLogger(ctrl *gomock.Controller) *MockLogger {
|
||||
mock := &MockLogger{ctrl: ctrl}
|
||||
mock.recorder = &MockLoggerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockLogger) EXPECT() *MockLoggerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Debug mocks base method.
|
||||
func (m *MockLogger) Debug(arg0 string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Debug", arg0)
|
||||
}
|
||||
|
||||
// Debug indicates an expected call of Debug.
|
||||
func (mr *MockLoggerMockRecorder) Debug(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockLogger)(nil).Debug), arg0)
|
||||
}
|
||||
|
||||
// Debugf mocks base method.
|
||||
func (m *MockLogger) Debugf(arg0 string, arg1 ...interface{}) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []interface{}{arg0}
|
||||
for _, a := range arg1 {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
m.ctrl.Call(m, "Debugf", varargs...)
|
||||
}
|
||||
|
||||
// Debugf indicates an expected call of Debugf.
|
||||
func (mr *MockLoggerMockRecorder) Debugf(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]interface{}{arg0}, arg1...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debugf", reflect.TypeOf((*MockLogger)(nil).Debugf), varargs...)
|
||||
}
|
||||
|
||||
// Error mocks base method.
|
||||
func (m *MockLogger) Error(arg0 string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Error", arg0)
|
||||
}
|
||||
|
||||
// Error indicates an expected call of Error.
|
||||
func (mr *MockLoggerMockRecorder) Error(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Error", reflect.TypeOf((*MockLogger)(nil).Error), arg0)
|
||||
}
|
||||
|
||||
// Errorf mocks base method.
|
||||
func (m *MockLogger) Errorf(arg0 string, arg1 ...interface{}) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []interface{}{arg0}
|
||||
for _, a := range arg1 {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
m.ctrl.Call(m, "Errorf", varargs...)
|
||||
}
|
||||
|
||||
// Errorf indicates an expected call of Errorf.
|
||||
func (mr *MockLoggerMockRecorder) Errorf(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]interface{}{arg0}, arg1...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Errorf", reflect.TypeOf((*MockLogger)(nil).Errorf), varargs...)
|
||||
}
|
||||
|
||||
// Info mocks base method.
|
||||
func (m *MockLogger) Info(arg0 string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Info", arg0)
|
||||
}
|
||||
|
||||
// Info indicates an expected call of Info.
|
||||
func (mr *MockLoggerMockRecorder) Info(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Info", reflect.TypeOf((*MockLogger)(nil).Info), arg0)
|
||||
}
|
||||
@@ -0,0 +1,36 @@
|
||||
package amneziawg
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/netlink"
|
||||
)
|
||||
|
||||
//go:generate mockgen -destination=netlinker_mock_test.go -package amneziawg . NetLinker
|
||||
|
||||
type NetLinker interface {
|
||||
AddrReplace(linkIndex uint32, addr netip.Prefix) error
|
||||
Router
|
||||
Ruler
|
||||
Linker
|
||||
IsWireguardSupported() (ok bool, err error)
|
||||
}
|
||||
|
||||
type Router interface {
|
||||
RouteList(family uint8) (routes []netlink.Route, err error)
|
||||
RouteAdd(route netlink.Route) error
|
||||
}
|
||||
|
||||
type Ruler interface {
|
||||
RuleAdd(rule netlink.Rule) error
|
||||
RuleDel(rule netlink.Rule) error
|
||||
}
|
||||
|
||||
type Linker interface {
|
||||
LinkAdd(link netlink.Link) (linkIndex uint32, err error)
|
||||
LinkList() (links []netlink.Link, err error)
|
||||
LinkByName(name string) (link netlink.Link, err error)
|
||||
LinkSetUp(linkIndex uint32) error
|
||||
LinkSetDown(linkIndex uint32) error
|
||||
LinkDel(linkIndex uint32) error
|
||||
}
|
||||
@@ -0,0 +1,209 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/qdm12/gluetun/internal/amneziawg (interfaces: NetLinker)
|
||||
|
||||
// Package amneziawg is a generated GoMock package.
|
||||
package amneziawg
|
||||
|
||||
import (
|
||||
netip "net/netip"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
netlink "github.com/qdm12/gluetun/internal/netlink"
|
||||
)
|
||||
|
||||
// MockNetLinker is a mock of NetLinker interface.
|
||||
type MockNetLinker struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockNetLinkerMockRecorder
|
||||
}
|
||||
|
||||
// MockNetLinkerMockRecorder is the mock recorder for MockNetLinker.
|
||||
type MockNetLinkerMockRecorder struct {
|
||||
mock *MockNetLinker
|
||||
}
|
||||
|
||||
// NewMockNetLinker creates a new mock instance.
|
||||
func NewMockNetLinker(ctrl *gomock.Controller) *MockNetLinker {
|
||||
mock := &MockNetLinker{ctrl: ctrl}
|
||||
mock.recorder = &MockNetLinkerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockNetLinker) EXPECT() *MockNetLinkerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// AddrReplace mocks base method.
|
||||
func (m *MockNetLinker) AddrReplace(arg0 uint32, arg1 netip.Prefix) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AddrReplace", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// AddrReplace indicates an expected call of AddrReplace.
|
||||
func (mr *MockNetLinkerMockRecorder) AddrReplace(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddrReplace", reflect.TypeOf((*MockNetLinker)(nil).AddrReplace), arg0, arg1)
|
||||
}
|
||||
|
||||
// IsWireguardSupported mocks base method.
|
||||
func (m *MockNetLinker) IsWireguardSupported() (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "IsWireguardSupported")
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// IsWireguardSupported indicates an expected call of IsWireguardSupported.
|
||||
func (mr *MockNetLinkerMockRecorder) IsWireguardSupported() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsWireguardSupported", reflect.TypeOf((*MockNetLinker)(nil).IsWireguardSupported))
|
||||
}
|
||||
|
||||
// LinkAdd mocks base method.
|
||||
func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) (uint32, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "LinkAdd", arg0)
|
||||
ret0, _ := ret[0].(uint32)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// LinkAdd indicates an expected call of LinkAdd.
|
||||
func (mr *MockNetLinkerMockRecorder) LinkAdd(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkAdd", reflect.TypeOf((*MockNetLinker)(nil).LinkAdd), arg0)
|
||||
}
|
||||
|
||||
// LinkByName mocks base method.
|
||||
func (m *MockNetLinker) LinkByName(arg0 string) (netlink.Link, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "LinkByName", arg0)
|
||||
ret0, _ := ret[0].(netlink.Link)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// LinkByName indicates an expected call of LinkByName.
|
||||
func (mr *MockNetLinkerMockRecorder) LinkByName(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkByName", reflect.TypeOf((*MockNetLinker)(nil).LinkByName), arg0)
|
||||
}
|
||||
|
||||
// LinkDel mocks base method.
|
||||
func (m *MockNetLinker) LinkDel(arg0 uint32) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "LinkDel", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// LinkDel indicates an expected call of LinkDel.
|
||||
func (mr *MockNetLinkerMockRecorder) LinkDel(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkDel", reflect.TypeOf((*MockNetLinker)(nil).LinkDel), arg0)
|
||||
}
|
||||
|
||||
// LinkList mocks base method.
|
||||
func (m *MockNetLinker) LinkList() ([]netlink.Link, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "LinkList")
|
||||
ret0, _ := ret[0].([]netlink.Link)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// LinkList indicates an expected call of LinkList.
|
||||
func (mr *MockNetLinkerMockRecorder) LinkList() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkList", reflect.TypeOf((*MockNetLinker)(nil).LinkList))
|
||||
}
|
||||
|
||||
// LinkSetDown mocks base method.
|
||||
func (m *MockNetLinker) LinkSetDown(arg0 uint32) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "LinkSetDown", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// LinkSetDown indicates an expected call of LinkSetDown.
|
||||
func (mr *MockNetLinkerMockRecorder) LinkSetDown(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkSetDown", reflect.TypeOf((*MockNetLinker)(nil).LinkSetDown), arg0)
|
||||
}
|
||||
|
||||
// LinkSetUp mocks base method.
|
||||
func (m *MockNetLinker) LinkSetUp(arg0 uint32) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "LinkSetUp", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// LinkSetUp indicates an expected call of LinkSetUp.
|
||||
func (mr *MockNetLinkerMockRecorder) LinkSetUp(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkSetUp", reflect.TypeOf((*MockNetLinker)(nil).LinkSetUp), arg0)
|
||||
}
|
||||
|
||||
// RouteAdd mocks base method.
|
||||
func (m *MockNetLinker) RouteAdd(arg0 netlink.Route) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RouteAdd", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// RouteAdd indicates an expected call of RouteAdd.
|
||||
func (mr *MockNetLinkerMockRecorder) RouteAdd(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RouteAdd", reflect.TypeOf((*MockNetLinker)(nil).RouteAdd), arg0)
|
||||
}
|
||||
|
||||
// RouteList mocks base method.
|
||||
func (m *MockNetLinker) RouteList(arg0 byte) ([]netlink.Route, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RouteList", arg0)
|
||||
ret0, _ := ret[0].([]netlink.Route)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// RouteList indicates an expected call of RouteList.
|
||||
func (mr *MockNetLinkerMockRecorder) RouteList(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RouteList", reflect.TypeOf((*MockNetLinker)(nil).RouteList), arg0)
|
||||
}
|
||||
|
||||
// RuleAdd mocks base method.
|
||||
func (m *MockNetLinker) RuleAdd(arg0 netlink.Rule) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RuleAdd", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// RuleAdd indicates an expected call of RuleAdd.
|
||||
func (mr *MockNetLinkerMockRecorder) RuleAdd(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RuleAdd", reflect.TypeOf((*MockNetLinker)(nil).RuleAdd), arg0)
|
||||
}
|
||||
|
||||
// RuleDel mocks base method.
|
||||
func (m *MockNetLinker) RuleDel(arg0 netlink.Rule) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RuleDel", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// RuleDel indicates an expected call of RuleDel.
|
||||
func (mr *MockNetLinkerMockRecorder) RuleDel(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RuleDel", reflect.TypeOf((*MockNetLinker)(nil).RuleDel), arg0)
|
||||
}
|
||||
@@ -0,0 +1,133 @@
|
||||
package amneziawg
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
amneziaconn "github.com/amnezia-vpn/amneziawg-go/conn"
|
||||
amneziadevice "github.com/amnezia-vpn/amneziawg-go/device"
|
||||
amneziatun "github.com/amnezia-vpn/amneziawg-go/tun"
|
||||
"github.com/qdm12/gluetun/internal/cleanup"
|
||||
"github.com/qdm12/gluetun/internal/wireguard"
|
||||
)
|
||||
|
||||
var (
|
||||
errTunNameMismatch = errors.New("TUN device name is mismatching")
|
||||
errDeviceWaited = errors.New("device waited for")
|
||||
)
|
||||
|
||||
// Run runs the amneziawg interface and waits until the context is done, then it cleans up the
|
||||
// interface and returns any error that occurred during setup or waiting. It sends an error to
|
||||
// waitError if any error occurs during setup or waiting, otherwise it sends nil when the context
|
||||
// is done. It sends a signal to ready when the setup is complete and the interface is ready to use.
|
||||
// See https://github.com/amnezia-vpn/amneziawg-go/blob/master/main.go
|
||||
func (a *Amneziawg) Run(ctx context.Context, waitError chan<- error, ready chan<- struct{}) {
|
||||
setup := func(ctx context.Context, cleanups *cleanup.Cleanups) (
|
||||
linkIndex uint32, waitAndCleanup func() error, err error,
|
||||
) {
|
||||
return setupUserspace(ctx, a.settings.Wireguard.InterfaceName,
|
||||
a.netlink, a.settings.Wireguard.MTU, cleanups, a.logger, a.settings)
|
||||
}
|
||||
|
||||
wireguard.Run(ctx, waitError, ready, setup, a.settings.Wireguard, a.netlink, a.logger)
|
||||
}
|
||||
|
||||
func setupUserspace(ctx context.Context,
|
||||
interfaceName string, netLinker NetLinker, mtu uint32,
|
||||
cleanups *cleanup.Cleanups, logger Logger,
|
||||
settings Settings,
|
||||
) (
|
||||
linkIndex uint32, waitAndCleanup func() error, err error,
|
||||
) {
|
||||
tun, err := amneziatun.CreateTUN(interfaceName, int(mtu))
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("creating TUN device: %w", err)
|
||||
}
|
||||
|
||||
cleanups.Add("closing TUN device", 7, tun.Close)
|
||||
|
||||
tunName, err := tun.Name()
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("getting created TUN device name: %w", err)
|
||||
} else if tunName != interfaceName {
|
||||
return 0, nil, fmt.Errorf("%w: expected %q and got %q",
|
||||
errTunNameMismatch, interfaceName, tunName)
|
||||
}
|
||||
|
||||
link, err := netLinker.LinkByName(interfaceName)
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("finding link %s: %w", interfaceName, err)
|
||||
}
|
||||
cleanups.Add("deleting link", 5, func() error {
|
||||
return netLinker.LinkDel(link.Index)
|
||||
})
|
||||
|
||||
bind := amneziaconn.NewDefaultBind()
|
||||
cleanups.Add("closing bind", 7, bind.Close)
|
||||
|
||||
deviceLogger := amneziadevice.Logger{
|
||||
Verbosef: logger.Debugf,
|
||||
Errorf: logger.Errorf,
|
||||
}
|
||||
device := amneziadevice.NewDevice(tun, bind, &deviceLogger)
|
||||
|
||||
cleanups.Add("closing Wireguard device", 6, func() error {
|
||||
device.Close()
|
||||
return nil
|
||||
})
|
||||
|
||||
uapiFile, err := wireguard.UAPIOpen(interfaceName)
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("opening UAPI socket: %w", err)
|
||||
}
|
||||
cleanups.Add("closing UAPI file", 3, uapiFile.Close)
|
||||
|
||||
uapiListener, err := wireguard.UAPIListen(interfaceName, uapiFile)
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("listening on UAPI socket: %w", err)
|
||||
}
|
||||
cleanups.Add("closing UAPI listener", 2, uapiListener.Close)
|
||||
|
||||
uapiConfig := settings.uapiConfig()
|
||||
err = device.IpcSet(uapiConfig)
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("setting amneziawg uapi config: %w", err)
|
||||
}
|
||||
|
||||
// acceptAndHandle exits when uapiListener is closed
|
||||
uapiAcceptErrorCh := make(chan error)
|
||||
go acceptAndHandle(uapiListener, device, uapiAcceptErrorCh)
|
||||
waitAndCleanup = func() error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
err = ctx.Err()
|
||||
case err = <-uapiAcceptErrorCh:
|
||||
close(uapiAcceptErrorCh)
|
||||
case <-device.Wait():
|
||||
err = errDeviceWaited
|
||||
}
|
||||
|
||||
cleanups.Cleanup(logger)
|
||||
|
||||
<-uapiAcceptErrorCh // wait for acceptAndHandle to exit
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
return link.Index, waitAndCleanup, nil
|
||||
}
|
||||
|
||||
func acceptAndHandle(uapi net.Listener, device *amneziadevice.Device,
|
||||
uapiAcceptErrorCh chan<- error,
|
||||
) {
|
||||
for { // stopped by uapiFile.Close()
|
||||
conn, err := uapi.Accept()
|
||||
if err != nil {
|
||||
uapiAcceptErrorCh <- err
|
||||
return
|
||||
}
|
||||
go device.IpcHandle(conn)
|
||||
}
|
||||
}
|
||||
@@ -1,11 +1,14 @@
|
||||
package wireguard
|
||||
package amneziawg
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/wireguard"
|
||||
)
|
||||
|
||||
type AmneziaSettings struct {
|
||||
type Settings struct {
|
||||
Wireguard wireguard.Settings
|
||||
JunkPacketCount uint16
|
||||
JunkPacketMin uint16
|
||||
JunkPacketMax uint16
|
||||
@@ -24,7 +27,7 @@ type AmneziaSettings struct {
|
||||
InitPacketI5 string
|
||||
}
|
||||
|
||||
func (s AmneziaSettings) uapiConfig() string {
|
||||
func (s Settings) uapiConfig() string {
|
||||
uintFields := map[string]uint16{
|
||||
"jc": s.JunkPacketCount,
|
||||
"jmin": s.JunkPacketMin,
|
||||
@@ -56,3 +59,11 @@ func (s AmneziaSettings) uapiConfig() string {
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
func (s *Settings) SetDefaults() {
|
||||
s.Wireguard.SetDefaults()
|
||||
}
|
||||
|
||||
func (s *Settings) Check() error {
|
||||
return s.Wireguard.Check()
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
package cleanup
|
||||
|
||||
import "sort"
|
||||
|
||||
type Cleanups []cleanup
|
||||
|
||||
type cleanup struct {
|
||||
operation string
|
||||
orderIndex uint
|
||||
cleanup func() error
|
||||
done bool
|
||||
}
|
||||
|
||||
// Add adds a cleanup function to the list of cleanups, with a description of the
|
||||
// operation being cleaned up, and an order index that determines the order in which
|
||||
// the cleanup functions are run. The lower the order index, the earlier the cleanup
|
||||
// function is run.
|
||||
func (c *Cleanups) Add(operation string, orderIndex uint,
|
||||
cleanupFunc func() error,
|
||||
) {
|
||||
closer := cleanup{
|
||||
operation: operation,
|
||||
orderIndex: orderIndex,
|
||||
cleanup: cleanupFunc,
|
||||
}
|
||||
*c = append(*c, closer)
|
||||
}
|
||||
|
||||
// Cleanup runs the cleanup functions in the order of their orderIndex,
|
||||
// and logs any error that occurs during cleanup.
|
||||
// It can also be re-called in case a cleanup fails, and already cleaned up
|
||||
// functions will not be re-run.
|
||||
func (c *Cleanups) Cleanup(logger Logger) {
|
||||
closers := *c
|
||||
|
||||
sort.Slice(closers, func(i, j int) bool {
|
||||
return closers[i].orderIndex < closers[j].orderIndex
|
||||
})
|
||||
|
||||
for i, closer := range closers {
|
||||
if closer.done {
|
||||
continue
|
||||
}
|
||||
closers[i].done = true
|
||||
logger.Debug(closer.operation + "...")
|
||||
err := closer.cleanup()
|
||||
if err != nil {
|
||||
logger.Error("failed " + closer.operation + ": " + err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
package cleanup
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_Cleanups(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
var ACloseCalled, BCloseCalled, CCloseCalled bool
|
||||
var (
|
||||
AErr error
|
||||
BErr = errors.New("B failed")
|
||||
CErr = errors.New("C failed")
|
||||
)
|
||||
|
||||
var cleanups Cleanups
|
||||
cleanups.Add("cleaning up A", 5, func() error {
|
||||
ACloseCalled = true
|
||||
return AErr
|
||||
})
|
||||
|
||||
cleanups.Add("cleaning up B", 3, func() error {
|
||||
BCloseCalled = true
|
||||
return BErr
|
||||
})
|
||||
|
||||
cleanups.Add("cleaning up C", 2, func() error {
|
||||
CCloseCalled = true
|
||||
return CErr
|
||||
})
|
||||
|
||||
logger := NewMockLogger(ctrl)
|
||||
prevCall := logger.EXPECT().Debug("cleaning up C...")
|
||||
prevCall = logger.EXPECT().Error("failed cleaning up C: C failed").After(prevCall)
|
||||
prevCall = logger.EXPECT().Debug("cleaning up B...").After(prevCall)
|
||||
prevCall = logger.EXPECT().Error("failed cleaning up B: B failed").After(prevCall)
|
||||
logger.EXPECT().Debug("cleaning up A...").After(prevCall)
|
||||
|
||||
cleanups.Cleanup(logger)
|
||||
|
||||
cleanups.Cleanup(logger) // run twice should not close already closed
|
||||
|
||||
for _, cleanup := range cleanups {
|
||||
assert.True(t, cleanup.done)
|
||||
}
|
||||
|
||||
assert.True(t, ACloseCalled)
|
||||
assert.True(t, BCloseCalled)
|
||||
assert.True(t, CCloseCalled)
|
||||
}
|
||||
@@ -0,0 +1,6 @@
|
||||
package cleanup
|
||||
|
||||
type Logger interface {
|
||||
Debug(string)
|
||||
Error(string)
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
package cleanup
|
||||
|
||||
//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . Logger
|
||||
@@ -0,0 +1,58 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/qdm12/gluetun/internal/cleanup (interfaces: Logger)
|
||||
|
||||
// Package cleanup is a generated GoMock package.
|
||||
package cleanup
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// MockLogger is a mock of Logger interface.
|
||||
type MockLogger struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockLoggerMockRecorder
|
||||
}
|
||||
|
||||
// MockLoggerMockRecorder is the mock recorder for MockLogger.
|
||||
type MockLoggerMockRecorder struct {
|
||||
mock *MockLogger
|
||||
}
|
||||
|
||||
// NewMockLogger creates a new mock instance.
|
||||
func NewMockLogger(ctrl *gomock.Controller) *MockLogger {
|
||||
mock := &MockLogger{ctrl: ctrl}
|
||||
mock.recorder = &MockLoggerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockLogger) EXPECT() *MockLoggerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Debug mocks base method.
|
||||
func (m *MockLogger) Debug(arg0 string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Debug", arg0)
|
||||
}
|
||||
|
||||
// Debug indicates an expected call of Debug.
|
||||
func (mr *MockLoggerMockRecorder) Debug(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockLogger)(nil).Debug), arg0)
|
||||
}
|
||||
|
||||
// Error mocks base method.
|
||||
func (m *MockLogger) Error(arg0 string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Error", arg0)
|
||||
}
|
||||
|
||||
// Error indicates an expected call of Error.
|
||||
func (mr *MockLoggerMockRecorder) Error(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Error", reflect.TypeOf((*MockLogger)(nil).Error), arg0)
|
||||
}
|
||||
@@ -12,6 +12,9 @@ import (
|
||||
)
|
||||
|
||||
type AmneziaWg struct {
|
||||
// Wireguard contains the configuration for Wireguard, given
|
||||
// AmneziaWg is based on Wireguard
|
||||
Wireguard Wireguard `json:"wireguard"`
|
||||
JunkPacketCount *uint16 `json:"junk_packet_count"`
|
||||
JunkPacketMin *uint16 `json:"junk_packet_min"`
|
||||
JunkPacketMax *uint16 `json:"junk_packet_max"`
|
||||
@@ -30,15 +33,21 @@ type AmneziaWg struct {
|
||||
InitPacketI5 *string `json:"init_packet_i5"`
|
||||
}
|
||||
|
||||
func (s *AmneziaWg) read(r *reader.Reader) (err error) {
|
||||
func (a *AmneziaWg) read(r *reader.Reader) (err error) {
|
||||
const amneziawg = true
|
||||
err = a.Wireguard.read(r, amneziawg)
|
||||
if err != nil {
|
||||
return err // do not wrap this error
|
||||
}
|
||||
|
||||
uint16Fields := map[string]**uint16{
|
||||
"AMNEZIAWG_JC": &s.JunkPacketCount,
|
||||
"AMNEZIAWG_JMIN": &s.JunkPacketMin,
|
||||
"AMNEZIAWG_JMAX": &s.JunkPacketMax,
|
||||
"AMNEZIAWG_S1": &s.PaddingS1,
|
||||
"AMNEZIAWG_S2": &s.PaddingS2,
|
||||
"AMNEZIAWG_S3": &s.PaddingS3,
|
||||
"AMNEZIAWG_S4": &s.PaddingS4,
|
||||
"AMNEZIAWG_JC": &a.JunkPacketCount,
|
||||
"AMNEZIAWG_JMIN": &a.JunkPacketMin,
|
||||
"AMNEZIAWG_JMAX": &a.JunkPacketMax,
|
||||
"AMNEZIAWG_S1": &a.PaddingS1,
|
||||
"AMNEZIAWG_S2": &a.PaddingS2,
|
||||
"AMNEZIAWG_S3": &a.PaddingS3,
|
||||
"AMNEZIAWG_S4": &a.PaddingS4,
|
||||
}
|
||||
for key, dst := range uint16Fields {
|
||||
*dst, err = r.Uint16Ptr(key)
|
||||
@@ -47,15 +56,15 @@ func (s *AmneziaWg) read(r *reader.Reader) (err error) {
|
||||
}
|
||||
}
|
||||
stringFields := map[string]**string{
|
||||
"AMNEZIAWG_H1": &s.HeaderH1,
|
||||
"AMNEZIAWG_H2": &s.HeaderH2,
|
||||
"AMNEZIAWG_H3": &s.HeaderH3,
|
||||
"AMNEZIAWG_H4": &s.HeaderH4,
|
||||
"AMNEZIAWG_I1": &s.InitPacketI1,
|
||||
"AMNEZIAWG_I2": &s.InitPacketI2,
|
||||
"AMNEZIAWG_I3": &s.InitPacketI3,
|
||||
"AMNEZIAWG_I4": &s.InitPacketI4,
|
||||
"AMNEZIAWG_I5": &s.InitPacketI5,
|
||||
"AMNEZIAWG_H1": &a.HeaderH1,
|
||||
"AMNEZIAWG_H2": &a.HeaderH2,
|
||||
"AMNEZIAWG_H3": &a.HeaderH3,
|
||||
"AMNEZIAWG_H4": &a.HeaderH4,
|
||||
"AMNEZIAWG_I1": &a.InitPacketI1,
|
||||
"AMNEZIAWG_I2": &a.InitPacketI2,
|
||||
"AMNEZIAWG_I3": &a.InitPacketI3,
|
||||
"AMNEZIAWG_I4": &a.InitPacketI4,
|
||||
"AMNEZIAWG_I5": &a.InitPacketI5,
|
||||
}
|
||||
opt := reader.ForceLowercase(false)
|
||||
for key, dst := range stringFields {
|
||||
@@ -64,80 +73,84 @@ func (s *AmneziaWg) read(r *reader.Reader) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s AmneziaWg) copy() (copied AmneziaWg) {
|
||||
func (a AmneziaWg) copy() (copied AmneziaWg) {
|
||||
return AmneziaWg{
|
||||
JunkPacketCount: gosettings.CopyPointer(s.JunkPacketCount),
|
||||
JunkPacketMin: gosettings.CopyPointer(s.JunkPacketMin),
|
||||
JunkPacketMax: gosettings.CopyPointer(s.JunkPacketMax),
|
||||
PaddingS1: gosettings.CopyPointer(s.PaddingS1),
|
||||
PaddingS2: gosettings.CopyPointer(s.PaddingS2),
|
||||
PaddingS3: gosettings.CopyPointer(s.PaddingS3),
|
||||
PaddingS4: gosettings.CopyPointer(s.PaddingS4),
|
||||
HeaderH1: gosettings.CopyPointer(s.HeaderH1),
|
||||
HeaderH2: gosettings.CopyPointer(s.HeaderH2),
|
||||
HeaderH3: gosettings.CopyPointer(s.HeaderH3),
|
||||
HeaderH4: gosettings.CopyPointer(s.HeaderH4),
|
||||
InitPacketI1: gosettings.CopyPointer(s.InitPacketI1),
|
||||
InitPacketI2: gosettings.CopyPointer(s.InitPacketI2),
|
||||
InitPacketI3: gosettings.CopyPointer(s.InitPacketI3),
|
||||
InitPacketI4: gosettings.CopyPointer(s.InitPacketI4),
|
||||
InitPacketI5: gosettings.CopyPointer(s.InitPacketI5),
|
||||
Wireguard: a.Wireguard.copy(),
|
||||
JunkPacketCount: gosettings.CopyPointer(a.JunkPacketCount),
|
||||
JunkPacketMin: gosettings.CopyPointer(a.JunkPacketMin),
|
||||
JunkPacketMax: gosettings.CopyPointer(a.JunkPacketMax),
|
||||
PaddingS1: gosettings.CopyPointer(a.PaddingS1),
|
||||
PaddingS2: gosettings.CopyPointer(a.PaddingS2),
|
||||
PaddingS3: gosettings.CopyPointer(a.PaddingS3),
|
||||
PaddingS4: gosettings.CopyPointer(a.PaddingS4),
|
||||
HeaderH1: gosettings.CopyPointer(a.HeaderH1),
|
||||
HeaderH2: gosettings.CopyPointer(a.HeaderH2),
|
||||
HeaderH3: gosettings.CopyPointer(a.HeaderH3),
|
||||
HeaderH4: gosettings.CopyPointer(a.HeaderH4),
|
||||
InitPacketI1: gosettings.CopyPointer(a.InitPacketI1),
|
||||
InitPacketI2: gosettings.CopyPointer(a.InitPacketI2),
|
||||
InitPacketI3: gosettings.CopyPointer(a.InitPacketI3),
|
||||
InitPacketI4: gosettings.CopyPointer(a.InitPacketI4),
|
||||
InitPacketI5: gosettings.CopyPointer(a.InitPacketI5),
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:dupl
|
||||
func (s *AmneziaWg) overrideWith(other AmneziaWg) {
|
||||
s.JunkPacketCount = gosettings.OverrideWithPointer(s.JunkPacketCount, other.JunkPacketCount)
|
||||
s.JunkPacketMin = gosettings.OverrideWithPointer(s.JunkPacketMin, other.JunkPacketMin)
|
||||
s.JunkPacketMax = gosettings.OverrideWithPointer(s.JunkPacketMax, other.JunkPacketMax)
|
||||
s.PaddingS1 = gosettings.OverrideWithPointer(s.PaddingS1, other.PaddingS1)
|
||||
s.PaddingS2 = gosettings.OverrideWithPointer(s.PaddingS2, other.PaddingS2)
|
||||
s.PaddingS3 = gosettings.OverrideWithPointer(s.PaddingS3, other.PaddingS3)
|
||||
s.PaddingS4 = gosettings.OverrideWithPointer(s.PaddingS4, other.PaddingS4)
|
||||
s.HeaderH1 = gosettings.OverrideWithPointer(s.HeaderH1, other.HeaderH1)
|
||||
s.HeaderH2 = gosettings.OverrideWithPointer(s.HeaderH2, other.HeaderH2)
|
||||
s.HeaderH3 = gosettings.OverrideWithPointer(s.HeaderH3, other.HeaderH3)
|
||||
s.HeaderH4 = gosettings.OverrideWithPointer(s.HeaderH4, other.HeaderH4)
|
||||
s.InitPacketI1 = gosettings.OverrideWithPointer(s.InitPacketI1, other.InitPacketI1)
|
||||
s.InitPacketI2 = gosettings.OverrideWithPointer(s.InitPacketI2, other.InitPacketI2)
|
||||
s.InitPacketI3 = gosettings.OverrideWithPointer(s.InitPacketI3, other.InitPacketI3)
|
||||
s.InitPacketI4 = gosettings.OverrideWithPointer(s.InitPacketI4, other.InitPacketI4)
|
||||
s.InitPacketI5 = gosettings.OverrideWithPointer(s.InitPacketI5, other.InitPacketI5)
|
||||
func (a *AmneziaWg) overrideWith(other AmneziaWg) {
|
||||
a.Wireguard.overrideWith(other.Wireguard)
|
||||
a.JunkPacketCount = gosettings.OverrideWithPointer(a.JunkPacketCount, other.JunkPacketCount)
|
||||
a.JunkPacketMin = gosettings.OverrideWithPointer(a.JunkPacketMin, other.JunkPacketMin)
|
||||
a.JunkPacketMax = gosettings.OverrideWithPointer(a.JunkPacketMax, other.JunkPacketMax)
|
||||
a.PaddingS1 = gosettings.OverrideWithPointer(a.PaddingS1, other.PaddingS1)
|
||||
a.PaddingS2 = gosettings.OverrideWithPointer(a.PaddingS2, other.PaddingS2)
|
||||
a.PaddingS3 = gosettings.OverrideWithPointer(a.PaddingS3, other.PaddingS3)
|
||||
a.PaddingS4 = gosettings.OverrideWithPointer(a.PaddingS4, other.PaddingS4)
|
||||
a.HeaderH1 = gosettings.OverrideWithPointer(a.HeaderH1, other.HeaderH1)
|
||||
a.HeaderH2 = gosettings.OverrideWithPointer(a.HeaderH2, other.HeaderH2)
|
||||
a.HeaderH3 = gosettings.OverrideWithPointer(a.HeaderH3, other.HeaderH3)
|
||||
a.HeaderH4 = gosettings.OverrideWithPointer(a.HeaderH4, other.HeaderH4)
|
||||
a.InitPacketI1 = gosettings.OverrideWithPointer(a.InitPacketI1, other.InitPacketI1)
|
||||
a.InitPacketI2 = gosettings.OverrideWithPointer(a.InitPacketI2, other.InitPacketI2)
|
||||
a.InitPacketI3 = gosettings.OverrideWithPointer(a.InitPacketI3, other.InitPacketI3)
|
||||
a.InitPacketI4 = gosettings.OverrideWithPointer(a.InitPacketI4, other.InitPacketI4)
|
||||
a.InitPacketI5 = gosettings.OverrideWithPointer(a.InitPacketI5, other.InitPacketI5)
|
||||
}
|
||||
|
||||
func (s *AmneziaWg) setDefaults() {
|
||||
s.JunkPacketCount = gosettings.DefaultPointer(s.JunkPacketCount, 0)
|
||||
s.JunkPacketMin = gosettings.DefaultPointer(s.JunkPacketMin, 0)
|
||||
s.JunkPacketMax = gosettings.DefaultPointer(s.JunkPacketMax, 0)
|
||||
s.PaddingS1 = gosettings.DefaultPointer(s.PaddingS1, 0)
|
||||
s.PaddingS2 = gosettings.DefaultPointer(s.PaddingS2, 0)
|
||||
s.PaddingS3 = gosettings.DefaultPointer(s.PaddingS3, 0)
|
||||
s.PaddingS4 = gosettings.DefaultPointer(s.PaddingS4, 0)
|
||||
s.HeaderH1 = gosettings.DefaultPointer(s.HeaderH1, "")
|
||||
s.HeaderH2 = gosettings.DefaultPointer(s.HeaderH2, "")
|
||||
s.HeaderH3 = gosettings.DefaultPointer(s.HeaderH3, "")
|
||||
s.HeaderH4 = gosettings.DefaultPointer(s.HeaderH4, "")
|
||||
s.InitPacketI1 = gosettings.DefaultPointer(s.InitPacketI1, "")
|
||||
s.InitPacketI2 = gosettings.DefaultPointer(s.InitPacketI2, "")
|
||||
s.InitPacketI3 = gosettings.DefaultPointer(s.InitPacketI3, "")
|
||||
s.InitPacketI4 = gosettings.DefaultPointer(s.InitPacketI4, "")
|
||||
s.InitPacketI5 = gosettings.DefaultPointer(s.InitPacketI5, "")
|
||||
func (a *AmneziaWg) setDefaults(vpnProvider string) {
|
||||
a.Wireguard.setDefaults(vpnProvider)
|
||||
a.Wireguard.Implementation = "userspace" // unused except in logs
|
||||
a.JunkPacketCount = gosettings.DefaultPointer(a.JunkPacketCount, 0)
|
||||
a.JunkPacketMin = gosettings.DefaultPointer(a.JunkPacketMin, 0)
|
||||
a.JunkPacketMax = gosettings.DefaultPointer(a.JunkPacketMax, 0)
|
||||
a.PaddingS1 = gosettings.DefaultPointer(a.PaddingS1, 0)
|
||||
a.PaddingS2 = gosettings.DefaultPointer(a.PaddingS2, 0)
|
||||
a.PaddingS3 = gosettings.DefaultPointer(a.PaddingS3, 0)
|
||||
a.PaddingS4 = gosettings.DefaultPointer(a.PaddingS4, 0)
|
||||
a.HeaderH1 = gosettings.DefaultPointer(a.HeaderH1, "")
|
||||
a.HeaderH2 = gosettings.DefaultPointer(a.HeaderH2, "")
|
||||
a.HeaderH3 = gosettings.DefaultPointer(a.HeaderH3, "")
|
||||
a.HeaderH4 = gosettings.DefaultPointer(a.HeaderH4, "")
|
||||
a.InitPacketI1 = gosettings.DefaultPointer(a.InitPacketI1, "")
|
||||
a.InitPacketI2 = gosettings.DefaultPointer(a.InitPacketI2, "")
|
||||
a.InitPacketI3 = gosettings.DefaultPointer(a.InitPacketI3, "")
|
||||
a.InitPacketI4 = gosettings.DefaultPointer(a.InitPacketI4, "")
|
||||
a.InitPacketI5 = gosettings.DefaultPointer(a.InitPacketI5, "")
|
||||
}
|
||||
|
||||
func (s AmneziaWg) toLinesNode() (node *gotree.Node) {
|
||||
node = gotree.New("Amneziawg parameters:")
|
||||
func (a AmneziaWg) toLinesNode() (node *gotree.Node) {
|
||||
node = gotree.New("AmneziaWG settings:")
|
||||
node.AppendNode(a.Wireguard.toLinesNode())
|
||||
|
||||
uintFields := []struct {
|
||||
key string
|
||||
val *uint16
|
||||
}{
|
||||
{"jc", s.JunkPacketCount},
|
||||
{"jmin", s.JunkPacketMin},
|
||||
{"jmax", s.JunkPacketMax},
|
||||
{"s1", s.PaddingS1},
|
||||
{"s2", s.PaddingS2},
|
||||
{"s3", s.PaddingS3},
|
||||
{"s4", s.PaddingS4},
|
||||
{"JC", a.JunkPacketCount},
|
||||
{"JMIN", a.JunkPacketMin},
|
||||
{"JMAX", a.JunkPacketMax},
|
||||
{"S1", a.PaddingS1},
|
||||
{"S2", a.PaddingS2},
|
||||
{"S3", a.PaddingS3},
|
||||
{"S4", a.PaddingS4},
|
||||
}
|
||||
for _, f := range uintFields {
|
||||
node.Appendf("%s: %d", f.key, *f.val)
|
||||
@@ -147,15 +160,15 @@ func (s AmneziaWg) toLinesNode() (node *gotree.Node) {
|
||||
key string
|
||||
val *string
|
||||
}{
|
||||
{"h1", s.HeaderH1},
|
||||
{"h2", s.HeaderH2},
|
||||
{"h3", s.HeaderH3},
|
||||
{"h4", s.HeaderH4},
|
||||
{"i1", s.InitPacketI1},
|
||||
{"i2", s.InitPacketI2},
|
||||
{"i3", s.InitPacketI3},
|
||||
{"i4", s.InitPacketI4},
|
||||
{"i5", s.InitPacketI5},
|
||||
{"H1", a.HeaderH1},
|
||||
{"H2", a.HeaderH2},
|
||||
{"H3", a.HeaderH3},
|
||||
{"H4", a.HeaderH4},
|
||||
{"I1", a.InitPacketI1},
|
||||
{"I2", a.InitPacketI2},
|
||||
{"I3", a.InitPacketI3},
|
||||
{"I4", a.InitPacketI4},
|
||||
{"I5", a.InitPacketI5},
|
||||
}
|
||||
for _, f := range stringFields {
|
||||
node.Appendf("%s: %s", f.key, *f.val)
|
||||
@@ -165,33 +178,40 @@ func (s AmneziaWg) toLinesNode() (node *gotree.Node) {
|
||||
}
|
||||
|
||||
var (
|
||||
ErrAmenziawgImplementationNotValid = errors.New("AmneziaWG implementation is not valid")
|
||||
ErrJunkPacketBounds = errors.New("junk packet minimum must be lower than or equal to maximum")
|
||||
ErrJunkPacketMinMaxNotSet = errors.New("junk packet min and max must be set when junk packet count is set")
|
||||
ErrJunkPacketCountNotSet = errors.New("junk packet count must be set when junk packet min or max is set")
|
||||
ErrHeaderRangeMalformed = errors.New("header range is malformed")
|
||||
)
|
||||
|
||||
func (s AmneziaWg) validate() error {
|
||||
if *s.JunkPacketCount == 0 {
|
||||
if *s.JunkPacketMin != 0 || *s.JunkPacketMax != 0 {
|
||||
func (a AmneziaWg) validate(vpnProvider string, ipv6Supported bool) error {
|
||||
const amneziaWG = true
|
||||
err := a.Wireguard.validate(vpnProvider, ipv6Supported, amneziaWG)
|
||||
if err != nil {
|
||||
return fmt.Errorf("wireguard settings: %w", err)
|
||||
}
|
||||
|
||||
if *a.JunkPacketCount == 0 {
|
||||
if *a.JunkPacketMin != 0 || *a.JunkPacketMax != 0 {
|
||||
return fmt.Errorf("%w: jc=%d and jmin=%d and jmax=%d",
|
||||
ErrJunkPacketCountNotSet, s.JunkPacketCount, *s.JunkPacketMin, *s.JunkPacketMax)
|
||||
ErrJunkPacketCountNotSet, a.JunkPacketCount, *a.JunkPacketMin, *a.JunkPacketMax)
|
||||
}
|
||||
} else {
|
||||
if *s.JunkPacketMin == 0 || *s.JunkPacketMax == 0 {
|
||||
if *a.JunkPacketMin == 0 || *a.JunkPacketMax == 0 {
|
||||
return fmt.Errorf("%w: jc=%d and jmin=%d and jmax=%d",
|
||||
ErrJunkPacketMinMaxNotSet, s.JunkPacketCount, *s.JunkPacketMin, *s.JunkPacketMax)
|
||||
} else if *s.JunkPacketMin > *s.JunkPacketMax {
|
||||
ErrJunkPacketMinMaxNotSet, a.JunkPacketCount, *a.JunkPacketMin, *a.JunkPacketMax)
|
||||
} else if *a.JunkPacketMin > *a.JunkPacketMax {
|
||||
return fmt.Errorf("%w: jmin=%d and jmax=%d",
|
||||
ErrJunkPacketBounds, *s.JunkPacketMin, *s.JunkPacketMax)
|
||||
ErrJunkPacketBounds, *a.JunkPacketMin, *a.JunkPacketMax)
|
||||
}
|
||||
}
|
||||
|
||||
nameToHeaderRange := map[string]string{
|
||||
"h1": *s.HeaderH1,
|
||||
"h2": *s.HeaderH2,
|
||||
"h3": *s.HeaderH3,
|
||||
"h4": *s.HeaderH4,
|
||||
"h1": *a.HeaderH1,
|
||||
"h2": *a.HeaderH2,
|
||||
"h3": *a.HeaderH3,
|
||||
"h4": *a.HeaderH4,
|
||||
}
|
||||
for name, headerRange := range nameToHeaderRange {
|
||||
if headerRange == "" {
|
||||
|
||||
@@ -268,8 +268,6 @@ func (o *OpenVPN) copy() (copied OpenVPN) {
|
||||
// overrideWith overrides fields of the receiver
|
||||
// settings object with any field set in the other
|
||||
// settings.
|
||||
//
|
||||
//nolint:dupl
|
||||
func (o *OpenVPN) overrideWith(other OpenVPN) {
|
||||
o.Version = gosettings.OverrideWithComparable(o.Version, other.Version)
|
||||
o.User = gosettings.OverrideWithPointer(o.User, other.User)
|
||||
|
||||
@@ -30,7 +30,10 @@ type Provider struct {
|
||||
func (p *Provider) validate(vpnType string, filterChoicesGetter FilterChoicesGetter, warner Warner) (err error) {
|
||||
// Validate Name
|
||||
var validNames []string
|
||||
if vpnType == vpn.OpenVPN {
|
||||
switch vpnType {
|
||||
case vpn.AmneziaWg:
|
||||
validNames = []string{providers.Custom}
|
||||
case vpn.OpenVPN:
|
||||
validNames = providers.AllWithCustom()
|
||||
validNames = append(validNames, "pia") // Retro-compatibility
|
||||
// Remove Mullvad since it no longer supports OpenVPN as of January 15th, 2026
|
||||
@@ -38,7 +41,7 @@ func (p *Provider) validate(vpnType string, filterChoicesGetter FilterChoicesGet
|
||||
validNames[mullvadIndex], validNames[len(validNames)-1] = validNames[len(validNames)-1], validNames[mullvadIndex]
|
||||
validNames = validNames[:len(validNames)-1]
|
||||
sort.Strings(validNames)
|
||||
} else { // Wireguard
|
||||
case vpn.Wireguard:
|
||||
validNames = []string{
|
||||
providers.Airvpn,
|
||||
providers.Custom,
|
||||
@@ -52,7 +55,7 @@ func (p *Provider) validate(vpnType string, filterChoicesGetter FilterChoicesGet
|
||||
}
|
||||
}
|
||||
if err = validate.IsOneOf(p.Name, validNames...); err != nil {
|
||||
return fmt.Errorf("%w for Wireguard: %w", ErrVPNProviderNameNotValid, err)
|
||||
return fmt.Errorf("%w for %s: %w", ErrVPNProviderNameNotValid, vpnType, err)
|
||||
}
|
||||
|
||||
err = p.ServerSelection.validate(p.Name, filterChoicesGetter, warner)
|
||||
|
||||
@@ -87,7 +87,7 @@ func (ss *ServerSelection) validate(vpnServiceProvider string,
|
||||
filterChoicesGetter FilterChoicesGetter, warner Warner,
|
||||
) (err error) {
|
||||
switch ss.VPN {
|
||||
case vpn.OpenVPN, vpn.Wireguard:
|
||||
case vpn.AmneziaWg, vpn.OpenVPN, vpn.Wireguard:
|
||||
default:
|
||||
return fmt.Errorf("%w: %s", ErrVPNTypeNotValid, ss.VPN)
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ type VPN struct {
|
||||
// empty string in the internal state.
|
||||
Type string `json:"type"`
|
||||
Provider Provider `json:"provider"`
|
||||
AmneziaWg AmneziaWg `json:"amneziawg"`
|
||||
OpenVPN OpenVPN `json:"openvpn"`
|
||||
Wireguard Wireguard `json:"wireguard"`
|
||||
PMTUD PMTUD `json:"pmtud"`
|
||||
@@ -29,10 +30,12 @@ type VPN struct {
|
||||
DownCommand *string `json:"down_command"`
|
||||
}
|
||||
|
||||
// Validate validates VPN settings, using the filter choices getter (aka servers data storage),
|
||||
// and if IPv6 is supported or not.
|
||||
// TODO v4 remove pointer for receiver (because of Surfshark).
|
||||
func (v *VPN) Validate(filterChoicesGetter FilterChoicesGetter, ipv6Supported bool, warner Warner) (err error) {
|
||||
// Validate Type
|
||||
validVPNTypes := []string{vpn.OpenVPN, vpn.Wireguard}
|
||||
validVPNTypes := []string{vpn.AmneziaWg, vpn.OpenVPN, vpn.Wireguard}
|
||||
if err = validate.IsOneOf(v.Type, validVPNTypes...); err != nil {
|
||||
return fmt.Errorf("%w: %w", ErrVPNTypeNotValid, err)
|
||||
}
|
||||
@@ -42,13 +45,20 @@ func (v *VPN) Validate(filterChoicesGetter FilterChoicesGetter, ipv6Supported bo
|
||||
return fmt.Errorf("provider settings: %w", err)
|
||||
}
|
||||
|
||||
if v.Type == vpn.OpenVPN {
|
||||
switch v.Type {
|
||||
case vpn.AmneziaWg:
|
||||
err = v.AmneziaWg.validate(v.Provider.Name, ipv6Supported)
|
||||
if err != nil {
|
||||
return fmt.Errorf("AmneziaWG settings: %w", err)
|
||||
}
|
||||
case vpn.OpenVPN:
|
||||
err := v.OpenVPN.validate(v.Provider.Name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("OpenVPN settings: %w", err)
|
||||
}
|
||||
} else {
|
||||
err := v.Wireguard.validate(v.Provider.Name, ipv6Supported)
|
||||
case vpn.Wireguard:
|
||||
const amneziawg = false
|
||||
err := v.Wireguard.validate(v.Provider.Name, ipv6Supported, amneziawg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Wireguard settings: %w", err)
|
||||
}
|
||||
@@ -66,6 +76,7 @@ func (v *VPN) Copy() (copied VPN) {
|
||||
return VPN{
|
||||
Type: v.Type,
|
||||
Provider: v.Provider.copy(),
|
||||
AmneziaWg: v.AmneziaWg.copy(),
|
||||
OpenVPN: v.OpenVPN.copy(),
|
||||
Wireguard: v.Wireguard.copy(),
|
||||
PMTUD: v.PMTUD.copy(),
|
||||
@@ -77,6 +88,7 @@ func (v *VPN) Copy() (copied VPN) {
|
||||
func (v *VPN) OverrideWith(other VPN) {
|
||||
v.Type = gosettings.OverrideWithComparable(v.Type, other.Type)
|
||||
v.Provider.overrideWith(other.Provider)
|
||||
v.AmneziaWg.overrideWith(other.AmneziaWg)
|
||||
v.OpenVPN.overrideWith(other.OpenVPN)
|
||||
v.Wireguard.overrideWith(other.Wireguard)
|
||||
v.PMTUD.overrideWith(other.PMTUD)
|
||||
@@ -87,6 +99,7 @@ func (v *VPN) OverrideWith(other VPN) {
|
||||
func (v *VPN) setDefaults() {
|
||||
v.Type = gosettings.DefaultComparable(v.Type, vpn.OpenVPN)
|
||||
v.Provider.setDefaults()
|
||||
v.AmneziaWg.setDefaults(v.Provider.Name)
|
||||
v.OpenVPN.setDefaults(v.Provider.Name)
|
||||
v.Wireguard.setDefaults(v.Provider.Name)
|
||||
v.PMTUD.setDefaults()
|
||||
@@ -103,9 +116,12 @@ func (v VPN) toLinesNode() (node *gotree.Node) {
|
||||
|
||||
node.AppendNode(v.Provider.toLinesNode())
|
||||
|
||||
if v.Type == vpn.OpenVPN {
|
||||
switch v.Type {
|
||||
case vpn.AmneziaWg:
|
||||
node.AppendNode(v.AmneziaWg.toLinesNode())
|
||||
case vpn.OpenVPN:
|
||||
node.AppendNode(v.OpenVPN.toLinesNode())
|
||||
} else {
|
||||
case vpn.Wireguard:
|
||||
node.AppendNode(v.Wireguard.toLinesNode())
|
||||
}
|
||||
node.AppendNode(v.PMTUD.toLinesNode())
|
||||
@@ -128,12 +144,18 @@ func (v *VPN) read(r *reader.Reader) (err error) {
|
||||
return fmt.Errorf("VPN provider: %w", err)
|
||||
}
|
||||
|
||||
err = v.AmneziaWg.read(r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("AmneziaWG: %w", err)
|
||||
}
|
||||
|
||||
err = v.OpenVPN.read(r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("OpenVPN: %w", err)
|
||||
}
|
||||
|
||||
err = v.Wireguard.read(r)
|
||||
const amneziawg = false
|
||||
err = v.Wireguard.read(r, amneziawg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("wireguard: %w", err)
|
||||
}
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings/helpers"
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gosettings"
|
||||
"github.com/qdm12/gosettings/reader"
|
||||
@@ -42,34 +41,17 @@ type Wireguard struct {
|
||||
// 0 indicating to use PMTUD.
|
||||
MTU *uint32 `json:"mtu"`
|
||||
// Implementation is the Wireguard implementation to use.
|
||||
// It can be "auto", "userspace", "kernelspace" or "amneziawg".
|
||||
// It can be "auto", "userspace" or "kernelspace".
|
||||
// It defaults to "auto" and cannot be the empty string
|
||||
// in the internal state.
|
||||
Implementation string `json:"implementation"`
|
||||
// AmneziaWG contains obfuscation parameters
|
||||
AmneziaWG AmneziaWg `json:"amneziawg"`
|
||||
}
|
||||
|
||||
var regexpInterfaceName = regexp.MustCompile(`^[a-zA-Z0-9_]+$`)
|
||||
|
||||
// Validate validates Wireguard settings.
|
||||
// It should only be ran if the VPN type chosen is Wireguard.
|
||||
func (w Wireguard) validate(vpnProvider string, ipv6Supported bool) (err error) {
|
||||
if !helpers.IsOneOf(vpnProvider,
|
||||
providers.Airvpn,
|
||||
providers.Custom,
|
||||
providers.Fastestvpn,
|
||||
providers.Ivpn,
|
||||
providers.Mullvad,
|
||||
providers.Nordvpn,
|
||||
providers.Protonvpn,
|
||||
providers.Surfshark,
|
||||
providers.Windscribe,
|
||||
) {
|
||||
// do not validate for VPN provider not supporting Wireguard
|
||||
return nil
|
||||
}
|
||||
|
||||
// It should only be ran if the VPN type chosen is Wireguard or AmneziaWg.
|
||||
func (w Wireguard) validate(vpnProvider string, ipv6Supported, amneziawg bool) (err error) {
|
||||
// Validate PrivateKey
|
||||
if *w.PrivateKey == "" {
|
||||
return fmt.Errorf("%w", ErrWireguardPrivateKeyNotSet)
|
||||
@@ -138,14 +120,11 @@ func (w Wireguard) validate(vpnProvider string, ipv6Supported bool) (err error)
|
||||
ErrWireguardInterfaceNotValid, w.Interface, regexpInterfaceName)
|
||||
}
|
||||
|
||||
validImplementations := []string{"auto", "userspace", "kernelspace", "amneziawg"}
|
||||
if !amneziawg { // amneziawg should have its own Implementation field and ignore this one
|
||||
validImplementations := []string{"auto", "userspace", "kernelspace"}
|
||||
if err := validate.IsOneOf(w.Implementation, validImplementations...); err != nil {
|
||||
return fmt.Errorf("%w: %w", ErrWireguardImplementationNotValid, err)
|
||||
}
|
||||
|
||||
err = w.AmneziaWG.validate()
|
||||
if err != nil {
|
||||
return fmt.Errorf("amneziawg settings: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -161,7 +140,6 @@ func (w *Wireguard) copy() (copied Wireguard) {
|
||||
Interface: w.Interface,
|
||||
MTU: w.MTU,
|
||||
Implementation: w.Implementation,
|
||||
AmneziaWG: w.AmneziaWG.copy(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -175,7 +153,6 @@ func (w *Wireguard) overrideWith(other Wireguard) {
|
||||
w.Interface = gosettings.OverrideWithComparable(w.Interface, other.Interface)
|
||||
w.MTU = gosettings.OverrideWithComparable(w.MTU, other.MTU)
|
||||
w.Implementation = gosettings.OverrideWithComparable(w.Implementation, other.Implementation)
|
||||
w.AmneziaWG.overrideWith(other.AmneziaWG)
|
||||
}
|
||||
|
||||
func (w *Wireguard) setDefaults(vpnProvider string) {
|
||||
@@ -200,7 +177,6 @@ func (w *Wireguard) setDefaults(vpnProvider string) {
|
||||
w.Interface = gosettings.DefaultComparable(w.Interface, "wg0")
|
||||
w.MTU = gosettings.DefaultPointer(w.MTU, 0)
|
||||
w.Implementation = gosettings.DefaultComparable(w.Implementation, "auto")
|
||||
w.AmneziaWG.setDefaults()
|
||||
}
|
||||
|
||||
func (w Wireguard) String() string {
|
||||
@@ -242,29 +218,27 @@ func (w Wireguard) toLinesNode() (node *gotree.Node) {
|
||||
}
|
||||
|
||||
if w.Implementation != "auto" {
|
||||
implNode := node.Appendf("Implementation: %s", w.Implementation)
|
||||
|
||||
if w.Implementation == "amneziawg" {
|
||||
implNode.AppendNode(w.AmneziaWG.toLinesNode())
|
||||
}
|
||||
node.Appendf("Implementation: %s", w.Implementation)
|
||||
}
|
||||
|
||||
return node
|
||||
}
|
||||
|
||||
func (w *Wireguard) read(r *reader.Reader) (err error) {
|
||||
w.PrivateKey = r.Get("WIREGUARD_PRIVATE_KEY", reader.ForceLowercase(false))
|
||||
w.PreSharedKey = r.Get("WIREGUARD_PRESHARED_KEY", reader.ForceLowercase(false))
|
||||
func (w *Wireguard) read(r *reader.Reader, amneziaWG bool) (err error) {
|
||||
prefix := "WIREGUARD"
|
||||
if amneziaWG {
|
||||
prefix = "AMNEZIAWG"
|
||||
}
|
||||
w.PrivateKey = r.Get(prefix+"_PRIVATE_KEY", reader.ForceLowercase(false))
|
||||
w.PreSharedKey = r.Get(prefix+"_PRESHARED_KEY", reader.ForceLowercase(false))
|
||||
w.Interface = r.String("VPN_INTERFACE",
|
||||
reader.RetroKeys("WIREGUARD_INTERFACE"), reader.ForceLowercase(false))
|
||||
w.Implementation = r.String("WIREGUARD_IMPLEMENTATION")
|
||||
reader.RetroKeys(prefix+"_INTERFACE"), reader.ForceLowercase(false))
|
||||
|
||||
err = w.AmneziaWG.read(r)
|
||||
if err != nil {
|
||||
return err
|
||||
if !amneziaWG {
|
||||
w.Implementation = r.String("WIREGUARD_IMPLEMENTATION")
|
||||
}
|
||||
|
||||
addressStrings := r.CSV("WIREGUARD_ADDRESSES", reader.RetroKeys("WIREGUARD_ADDRESS"))
|
||||
addressStrings := r.CSV(prefix+"_ADDRESSES", reader.RetroKeys(prefix+"_ADDRESS"))
|
||||
// WARNING: do not initialize w.Addresses to an empty slice
|
||||
// or the defaults for nordvpn will not work.
|
||||
for _, addressString := range addressStrings {
|
||||
@@ -279,17 +253,17 @@ func (w *Wireguard) read(r *reader.Reader) (err error) {
|
||||
w.Addresses = append(w.Addresses, address)
|
||||
}
|
||||
|
||||
w.AllowedIPs, err = r.CSVNetipPrefixes("WIREGUARD_ALLOWED_IPS")
|
||||
w.AllowedIPs, err = r.CSVNetipPrefixes(prefix + "_ALLOWED_IPS")
|
||||
if err != nil {
|
||||
return err // already wrapped
|
||||
}
|
||||
|
||||
w.PersistentKeepaliveInterval, err = r.DurationPtr("WIREGUARD_PERSISTENT_KEEPALIVE_INTERVAL")
|
||||
w.PersistentKeepaliveInterval, err = r.DurationPtr(prefix + "_PERSISTENT_KEEPALIVE_INTERVAL")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
w.MTU, err = r.Uint32Ptr("WIREGUARD_MTU")
|
||||
w.MTU, err = r.Uint32Ptr(prefix + "_MTU")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -0,0 +1,84 @@
|
||||
package files
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"gopkg.in/ini.v1"
|
||||
)
|
||||
|
||||
func (s *Source) lazyLoadAmneziawgConf() AmneziawgConfig {
|
||||
if s.cached.amneziawgLoaded {
|
||||
return s.cached.amneziawgConf
|
||||
}
|
||||
|
||||
s.cached.amneziawgLoaded = true
|
||||
var err error
|
||||
s.cached.amneziawgConf, err = ParseAmneziawgConf(filepath.Join(s.rootDirectory, "amneziawg", "awg0.conf"))
|
||||
if err != nil {
|
||||
s.warner.Warnf("skipping Amneziawg config: %s", err)
|
||||
}
|
||||
return s.cached.amneziawgConf
|
||||
}
|
||||
|
||||
type AmneziawgConfig struct {
|
||||
Wireguard WireguardConfig
|
||||
Jc *string
|
||||
Jmin *string
|
||||
Jmax *string
|
||||
S1 *string
|
||||
S2 *string
|
||||
S3 *string
|
||||
S4 *string
|
||||
H1 *string
|
||||
H2 *string
|
||||
H3 *string
|
||||
H4 *string
|
||||
I1 *string
|
||||
I2 *string
|
||||
I3 *string
|
||||
I4 *string
|
||||
I5 *string
|
||||
}
|
||||
|
||||
func ParseAmneziawgConf(path string) (config AmneziawgConfig, err error) {
|
||||
iniFile, err := ini.InsensitiveLoad(path)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return AmneziawgConfig{}, nil
|
||||
}
|
||||
return AmneziawgConfig{}, fmt.Errorf("loading ini from reader: %w", err)
|
||||
}
|
||||
|
||||
config.Wireguard, err = ParseWireguardConf(path)
|
||||
if err != nil {
|
||||
return AmneziawgConfig{}, err
|
||||
}
|
||||
|
||||
interfaceSection, err := iniFile.GetSection("Interface")
|
||||
if err != nil {
|
||||
// can never happen
|
||||
return AmneziawgConfig{}, fmt.Errorf("getting interface section: %w", err)
|
||||
}
|
||||
|
||||
config.Jc = getINIKeyFromSection(interfaceSection, "Jc")
|
||||
config.Jmin = getINIKeyFromSection(interfaceSection, "Jmin")
|
||||
config.Jmax = getINIKeyFromSection(interfaceSection, "Jmax")
|
||||
config.S1 = getINIKeyFromSection(interfaceSection, "S1")
|
||||
config.S2 = getINIKeyFromSection(interfaceSection, "S2")
|
||||
config.S3 = getINIKeyFromSection(interfaceSection, "S3")
|
||||
config.S4 = getINIKeyFromSection(interfaceSection, "S4")
|
||||
config.H1 = getINIKeyFromSection(interfaceSection, "H1")
|
||||
config.H2 = getINIKeyFromSection(interfaceSection, "H2")
|
||||
config.H3 = getINIKeyFromSection(interfaceSection, "H3")
|
||||
config.H4 = getINIKeyFromSection(interfaceSection, "H4")
|
||||
config.I1 = getINIKeyFromSection(interfaceSection, "I1")
|
||||
config.I2 = getINIKeyFromSection(interfaceSection, "I2")
|
||||
config.I3 = getINIKeyFromSection(interfaceSection, "I3")
|
||||
config.I4 = getINIKeyFromSection(interfaceSection, "I4")
|
||||
config.I5 = getINIKeyFromSection(interfaceSection, "I5")
|
||||
|
||||
return config, nil
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
package files
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_Source_ParseAmneziawgConf(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("no_file", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
noFile := filepath.Join(t.TempDir(), "doesnotexist")
|
||||
wireguard, err := ParseAmneziawgConf(noFile)
|
||||
assert.Equal(t, AmneziawgConfig{}, wireguard)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
testCases := map[string]struct {
|
||||
fileContent string
|
||||
amneziawg AmneziawgConfig
|
||||
errMessage string
|
||||
}{
|
||||
"ini_load_error": {
|
||||
fileContent: "invalid",
|
||||
errMessage: "loading ini from reader: key-value delimiter not found: invalid",
|
||||
},
|
||||
"empty_file": {
|
||||
errMessage: `getting interface section: section "interface" does not exist`,
|
||||
},
|
||||
"success": {
|
||||
fileContent: `
|
||||
[Interface]
|
||||
PrivateKey = QOlCgyA/Sn/c/+YNTIEohrjm8IZV+OZ2AUFIoX20sk8=
|
||||
Address = 10.38.22.35/32
|
||||
DNS = 193.138.218.74
|
||||
Jc = 4
|
||||
H1 = 721391205
|
||||
I1 = <b 0x1234>
|
||||
|
||||
[Peer]
|
||||
PresharedKey = YJ680VN+dGrdsWNjSFqZ6vvwuiNhbq502ZL3G7Q3o3g=
|
||||
`,
|
||||
amneziawg: AmneziawgConfig{
|
||||
Wireguard: WireguardConfig{
|
||||
PrivateKey: ptrTo("QOlCgyA/Sn/c/+YNTIEohrjm8IZV+OZ2AUFIoX20sk8="),
|
||||
PreSharedKey: ptrTo("YJ680VN+dGrdsWNjSFqZ6vvwuiNhbq502ZL3G7Q3o3g="),
|
||||
Addresses: ptrTo("10.38.22.35/32"),
|
||||
},
|
||||
Jc: ptrTo("4"),
|
||||
H1: ptrTo("721391205"),
|
||||
I1: ptrTo("<b 0x1234>"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for testName, testCase := range testCases {
|
||||
t.Run(testName, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
configFile := filepath.Join(t.TempDir(), "awg.conf")
|
||||
const permission = fs.FileMode(0o600)
|
||||
err := os.WriteFile(configFile, []byte(testCase.fileContent), permission)
|
||||
require.NoError(t, err)
|
||||
|
||||
wireguard, err := ParseAmneziawgConf(configFile)
|
||||
|
||||
assert.Equal(t, testCase.amneziawg, wireguard)
|
||||
if testCase.errMessage != "" {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -13,6 +13,8 @@ type Source struct {
|
||||
cached struct {
|
||||
wireguardLoaded bool
|
||||
wireguardConf WireguardConfig
|
||||
amneziawgLoaded bool
|
||||
amneziawgConf AmneziawgConfig
|
||||
}
|
||||
}
|
||||
|
||||
@@ -69,38 +71,11 @@ func (s *Source) Get(key string) (value string, isSet bool) {
|
||||
return strPtrToStringIsSet(s.lazyLoadWireguardConf().EndpointIP)
|
||||
case "wireguard_endpoint_port":
|
||||
return strPtrToStringIsSet(s.lazyLoadWireguardConf().EndpointPort)
|
||||
case "wireguard_jc":
|
||||
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.Jc)
|
||||
case "wireguard_jmin":
|
||||
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.Jmin)
|
||||
case "wireguard_jmax":
|
||||
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.Jmax)
|
||||
case "wireguard_s1":
|
||||
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.S1)
|
||||
case "wireguard_s2":
|
||||
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.S2)
|
||||
case "wireguard_s3":
|
||||
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.S3)
|
||||
case "wireguard_s4":
|
||||
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.S4)
|
||||
case "wireguard_h1":
|
||||
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.H1)
|
||||
case "wireguard_h2":
|
||||
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.H2)
|
||||
case "wireguard_h3":
|
||||
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.H3)
|
||||
case "wireguard_h4":
|
||||
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.H4)
|
||||
case "wireguard_i1":
|
||||
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.I1)
|
||||
case "wireguard_i2":
|
||||
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.I2)
|
||||
case "wireguard_i3":
|
||||
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.I3)
|
||||
case "wireguard_i4":
|
||||
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.I4)
|
||||
case "wireguard_i5":
|
||||
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.I5)
|
||||
}
|
||||
|
||||
value, isSet, matched := s.getAmneziawgKey(key)
|
||||
if matched {
|
||||
return value, isSet
|
||||
}
|
||||
|
||||
value, isSet, err := ReadFromFile(path)
|
||||
@@ -110,6 +85,58 @@ func (s *Source) Get(key string) (value string, isSet bool) {
|
||||
return value, isSet
|
||||
}
|
||||
|
||||
func (s *Source) getAmneziawgKey(key string) (value string, isSet, matched bool) {
|
||||
switch key {
|
||||
case "amnezia_private_key":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Wireguard.PrivateKey)
|
||||
case "amnezia_preshared_key":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Wireguard.PreSharedKey)
|
||||
case "amnezia_addresses":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Wireguard.Addresses)
|
||||
case "amnezia_public_key":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Wireguard.PublicKey)
|
||||
case "amnezia_endpoint_ip":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Wireguard.EndpointIP)
|
||||
case "amnezia_endpoint_port":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Wireguard.EndpointPort)
|
||||
case "amnezia_jc":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Jc)
|
||||
case "amnezia_jmin":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Jmin)
|
||||
case "amnezia_jmax":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Jmax)
|
||||
case "amnezia_s1":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().S1)
|
||||
case "amnezia_s2":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().S2)
|
||||
case "amnezia_s3":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().S3)
|
||||
case "amnezia_s4":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().S4)
|
||||
case "amnezia_h1":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().H1)
|
||||
case "amnezia_h2":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().H2)
|
||||
case "amnezia_h3":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().H3)
|
||||
case "amnezia_h4":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().H4)
|
||||
case "amnezia_i1":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().I1)
|
||||
case "amnezia_i2":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().I2)
|
||||
case "amnezia_i3":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().I3)
|
||||
case "amnezia_i4":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().I4)
|
||||
case "amnezia_i5":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().I5)
|
||||
default:
|
||||
return "", false, false
|
||||
}
|
||||
return value, isSet, true
|
||||
}
|
||||
|
||||
func (s *Source) KeyTransform(key string) string {
|
||||
switch key {
|
||||
// TODO v4 remove these irregular cases
|
||||
|
||||
@@ -25,46 +25,6 @@ func (s *Source) lazyLoadWireguardConf() WireguardConfig {
|
||||
return s.cached.wireguardConf
|
||||
}
|
||||
|
||||
type amneziaWgConfig struct {
|
||||
Jc *string
|
||||
Jmin *string
|
||||
Jmax *string
|
||||
S1 *string
|
||||
S2 *string
|
||||
S3 *string
|
||||
S4 *string
|
||||
H1 *string
|
||||
H2 *string
|
||||
H3 *string
|
||||
H4 *string
|
||||
I1 *string
|
||||
I2 *string
|
||||
I3 *string
|
||||
I4 *string
|
||||
I5 *string
|
||||
}
|
||||
|
||||
func parseWireguardAmneziaInterfaceSection(interfaceSection *ini.Section) amneziaWgConfig {
|
||||
return amneziaWgConfig{
|
||||
Jc: getINIKeyFromSection(interfaceSection, "Jc"),
|
||||
Jmin: getINIKeyFromSection(interfaceSection, "Jmin"),
|
||||
Jmax: getINIKeyFromSection(interfaceSection, "Jmax"),
|
||||
S1: getINIKeyFromSection(interfaceSection, "S1"),
|
||||
S2: getINIKeyFromSection(interfaceSection, "S2"),
|
||||
S3: getINIKeyFromSection(interfaceSection, "S3"),
|
||||
S4: getINIKeyFromSection(interfaceSection, "S4"),
|
||||
H1: getINIKeyFromSection(interfaceSection, "H1"),
|
||||
H2: getINIKeyFromSection(interfaceSection, "H2"),
|
||||
H3: getINIKeyFromSection(interfaceSection, "H3"),
|
||||
H4: getINIKeyFromSection(interfaceSection, "H4"),
|
||||
I1: getINIKeyFromSection(interfaceSection, "I1"),
|
||||
I2: getINIKeyFromSection(interfaceSection, "I2"),
|
||||
I3: getINIKeyFromSection(interfaceSection, "I3"),
|
||||
I4: getINIKeyFromSection(interfaceSection, "I4"),
|
||||
I5: getINIKeyFromSection(interfaceSection, "I5"),
|
||||
}
|
||||
}
|
||||
|
||||
type WireguardConfig struct {
|
||||
PrivateKey *string
|
||||
PreSharedKey *string
|
||||
@@ -72,7 +32,6 @@ type WireguardConfig struct {
|
||||
PublicKey *string
|
||||
EndpointIP *string
|
||||
EndpointPort *string
|
||||
AmneziaParams amneziaWgConfig
|
||||
}
|
||||
|
||||
var regexINISectionNotExist = regexp.MustCompile(`^section ".+" does not exist$`)
|
||||
@@ -89,7 +48,6 @@ func ParseWireguardConf(path string) (config WireguardConfig, err error) {
|
||||
interfaceSection, err := iniFile.GetSection("Interface")
|
||||
if err == nil {
|
||||
config.PrivateKey, config.Addresses = parseWireguardInterfaceSection(interfaceSection)
|
||||
config.AmneziaParams = parseWireguardAmneziaInterfaceSection(interfaceSection)
|
||||
} else if !regexINISectionNotExist.MatchString(err.Error()) {
|
||||
// can never happen
|
||||
return WireguardConfig{}, fmt.Errorf("getting interface section: %w", err)
|
||||
|
||||
@@ -100,7 +100,6 @@ func Test_parseWireguardInterfaceSection(t *testing.T) {
|
||||
iniData string
|
||||
privateKey *string
|
||||
addresses *string
|
||||
amneziaParams amneziaWgConfig
|
||||
}{
|
||||
"no_fields": {
|
||||
iniData: `[Interface]`,
|
||||
@@ -116,17 +115,9 @@ PrivateKey = x
|
||||
[Interface]
|
||||
PrivateKey = QOlCgyA/Sn/c/+YNTIEohrjm8IZV+OZ2AUFIoX20sk8=
|
||||
Address = 10.38.22.35/32
|
||||
Jc = 4
|
||||
H1 = 721391205
|
||||
I1 = <b 0x1234>
|
||||
`,
|
||||
privateKey: ptrTo("QOlCgyA/Sn/c/+YNTIEohrjm8IZV+OZ2AUFIoX20sk8="),
|
||||
addresses: ptrTo("10.38.22.35/32"),
|
||||
amneziaParams: amneziaWgConfig{
|
||||
Jc: ptrTo("4"),
|
||||
H1: ptrTo("721391205"),
|
||||
I1: ptrTo("<b 0x1234>"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -140,11 +131,9 @@ I1 = <b 0x1234>
|
||||
require.NoError(t, err)
|
||||
|
||||
privateKey, addresses := parseWireguardInterfaceSection(iniSection)
|
||||
amneziaWgConfig := parseWireguardAmneziaInterfaceSection(iniSection)
|
||||
|
||||
assert.Equal(t, testCase.privateKey, privateKey)
|
||||
assert.Equal(t, testCase.addresses, addresses)
|
||||
assert.Equal(t, testCase.amneziaParams, amneziaWgConfig)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
package secrets
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration/sources/files"
|
||||
)
|
||||
|
||||
func (s *Source) lazyLoadAmneziawgConf() files.AmneziawgConfig {
|
||||
if s.cached.amneziawgLoaded {
|
||||
return s.cached.amneziawgConf
|
||||
}
|
||||
|
||||
path := os.Getenv("AMNEZIAWG_CONF_SECRETFILE")
|
||||
if path == "" {
|
||||
path = filepath.Join(s.rootDirectory, "amneziawg", "awg0.conf")
|
||||
}
|
||||
|
||||
s.cached.amneziawgLoaded = true
|
||||
var err error
|
||||
s.cached.amneziawgConf, err = files.ParseAmneziawgConf(path)
|
||||
if err != nil {
|
||||
s.warner.Warnf("skipping Amneziawg config: %s", err)
|
||||
}
|
||||
return s.cached.amneziawgConf
|
||||
}
|
||||
@@ -15,6 +15,8 @@ type Source struct {
|
||||
cached struct {
|
||||
wireguardLoaded bool
|
||||
wireguardConf files.WireguardConfig
|
||||
amneziawgLoaded bool
|
||||
amneziawgConf files.AmneziawgConfig
|
||||
}
|
||||
}
|
||||
|
||||
@@ -60,26 +62,26 @@ func (s *Source) Get(key string) (value string, isSet bool) {
|
||||
s.warner.Warnf("skipping %s: parsing PEM: %s", path, err)
|
||||
}
|
||||
return value, isSet
|
||||
case "wireguard_private_key":
|
||||
case "wireguard_private_key", "amneziawg_private_key":
|
||||
privateKey := s.lazyLoadWireguardConf().PrivateKey
|
||||
if privateKey != nil {
|
||||
return *privateKey, true
|
||||
} // else continue to read from individual secret file
|
||||
case "wireguard_preshared_key":
|
||||
case "wireguard_preshared_key", "amneziawg_preshared_key":
|
||||
preSharedKey := s.lazyLoadWireguardConf().PreSharedKey
|
||||
if preSharedKey != nil {
|
||||
return *preSharedKey, true
|
||||
} // else continue to read from individual secret file
|
||||
case "wireguard_addresses":
|
||||
case "wireguard_addresses", "amneziawg_addresses":
|
||||
addresses := s.lazyLoadWireguardConf().Addresses
|
||||
if addresses != nil {
|
||||
return *addresses, true
|
||||
} // else continue to read from individual secret file
|
||||
case "wireguard_public_key":
|
||||
case "wireguard_public_key", "amneziawg_public_key":
|
||||
return strPtrToStringIsSet(s.lazyLoadWireguardConf().PublicKey)
|
||||
case "wireguard_endpoint_ip":
|
||||
case "wireguard_endpoint_ip", "amneziawg_endpoint_ip":
|
||||
return strPtrToStringIsSet(s.lazyLoadWireguardConf().EndpointIP)
|
||||
case "wireguard_endpoint_port":
|
||||
case "wireguard_endpoint_port", "amneziawg_endpoint_port":
|
||||
return strPtrToStringIsSet(s.lazyLoadWireguardConf().EndpointPort)
|
||||
}
|
||||
|
||||
@@ -112,38 +114,50 @@ func (s *Source) KeyTransform(key string) string {
|
||||
|
||||
func (s *Source) getAmneziaWg(key string) (value string, isSet, matched bool) {
|
||||
switch key {
|
||||
case "wireguard_jc":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.Jc)
|
||||
case "wireguard_jmin":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.Jmin)
|
||||
case "wireguard_jmax":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.Jmax)
|
||||
case "wireguard_s1":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.S1)
|
||||
case "wireguard_s2":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.S2)
|
||||
case "wireguard_s3":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.S3)
|
||||
case "wireguard_s4":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.S4)
|
||||
case "wireguard_h1":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.H1)
|
||||
case "wireguard_h2":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.H2)
|
||||
case "wireguard_h3":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.H3)
|
||||
case "wireguard_h4":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.H4)
|
||||
case "wireguard_i1":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.I1)
|
||||
case "wireguard_i2":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.I2)
|
||||
case "wireguard_i3":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.I3)
|
||||
case "wireguard_i4":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.I4)
|
||||
case "wireguard_i5":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.I5)
|
||||
case "amneziawg_private_key":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Wireguard.PrivateKey)
|
||||
case "amneziawg_preshared_key":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Wireguard.PreSharedKey)
|
||||
case "wireguard_addresses", "amneziawg_addresses":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Wireguard.Addresses)
|
||||
case "wireguard_public_key", "amneziawg_public_key":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Wireguard.PublicKey)
|
||||
case "wireguard_endpoint_ip", "amneziawg_endpoint_ip":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Wireguard.EndpointIP)
|
||||
case "wireguard_endpoint_port", "amneziawg_endpoint_port":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Wireguard.EndpointPort)
|
||||
case "amneziawg_jc":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Jc)
|
||||
case "amneziawg_jmin":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Jmin)
|
||||
case "amneziawg_jmax":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Jmax)
|
||||
case "amneziawg_s1":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().S1)
|
||||
case "amneziawg_s2":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().S2)
|
||||
case "amneziawg_s3":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().S3)
|
||||
case "amneziawg_s4":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().S4)
|
||||
case "amneziawg_h1":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().H1)
|
||||
case "amneziawg_h2":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().H2)
|
||||
case "amneziawg_h3":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().H3)
|
||||
case "amneziawg_h4":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().H4)
|
||||
case "amneziawg_i1":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().I1)
|
||||
case "amneziawg_i2":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().I2)
|
||||
case "amneziawg_i3":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().I3)
|
||||
case "amneziawg_i4":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().I4)
|
||||
case "amneziawg_i5":
|
||||
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().I5)
|
||||
default:
|
||||
return "", false, false
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package vpn
|
||||
|
||||
const (
|
||||
AmneziaWg = "amneziawg"
|
||||
OpenVPN = "openvpn"
|
||||
Wireguard = "wireguard"
|
||||
)
|
||||
|
||||
@@ -0,0 +1,67 @@
|
||||
package vpn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/amneziawg"
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider"
|
||||
"github.com/qdm12/gluetun/internal/wireguard"
|
||||
"github.com/qdm12/gosettings"
|
||||
)
|
||||
|
||||
// setupAmneziaWg sets AmneziaWG up using the configurators and settings given.
|
||||
func setupAmneziaWg(ctx context.Context, netlinker NetLinker,
|
||||
fw Firewall, providerConf provider.Provider,
|
||||
settings settings.VPN, ipv6Supported bool, logger wireguard.Logger) (
|
||||
amneziawger *amneziawg.Amneziawg, connection models.Connection, err error,
|
||||
) {
|
||||
connection, err = providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported)
|
||||
if err != nil {
|
||||
return nil, models.Connection{}, fmt.Errorf("finding a VPN server: %w", err)
|
||||
}
|
||||
|
||||
amneziaWGSettings := buildAmneziaWgSettings(connection, settings.AmneziaWg, ipv6Supported)
|
||||
|
||||
logger.Debug("Amneziawg server public key: " + amneziaWGSettings.Wireguard.PublicKey)
|
||||
logger.Debug("Amneziawg client private key: " + gosettings.ObfuscateKey(amneziaWGSettings.Wireguard.PrivateKey))
|
||||
logger.Debug("Amneziawg pre-shared key: " + gosettings.ObfuscateKey(amneziaWGSettings.Wireguard.PreSharedKey))
|
||||
|
||||
amneziawger, err = amneziawg.New(amneziaWGSettings, netlinker, logger)
|
||||
if err != nil {
|
||||
return nil, models.Connection{}, fmt.Errorf("creating amneziawg: %w", err)
|
||||
}
|
||||
|
||||
err = fw.SetVPNConnection(ctx, connection, settings.Wireguard.Interface)
|
||||
if err != nil {
|
||||
return nil, models.Connection{}, fmt.Errorf("setting firewall: %w", err)
|
||||
}
|
||||
|
||||
return amneziawger, connection, nil
|
||||
}
|
||||
|
||||
func buildAmneziaWgSettings(connection models.Connection,
|
||||
userSettings settings.AmneziaWg, ipv6Supported bool,
|
||||
) amneziawg.Settings {
|
||||
return amneziawg.Settings{
|
||||
Wireguard: buildWireguardSettings(connection, userSettings.Wireguard, ipv6Supported),
|
||||
JunkPacketCount: *userSettings.JunkPacketCount,
|
||||
JunkPacketMin: *userSettings.JunkPacketMin,
|
||||
JunkPacketMax: *userSettings.JunkPacketMax,
|
||||
PaddingS1: *userSettings.PaddingS1,
|
||||
PaddingS2: *userSettings.PaddingS2,
|
||||
PaddingS3: *userSettings.PaddingS3,
|
||||
PaddingS4: *userSettings.PaddingS4,
|
||||
HeaderH1: *userSettings.HeaderH1,
|
||||
HeaderH2: *userSettings.HeaderH2,
|
||||
HeaderH3: *userSettings.HeaderH3,
|
||||
HeaderH4: *userSettings.HeaderH4,
|
||||
InitPacketI1: *userSettings.InitPacketI1,
|
||||
InitPacketI2: *userSettings.InitPacketI2,
|
||||
InitPacketI3: *userSettings.InitPacketI3,
|
||||
InitPacketI4: *userSettings.InitPacketI4,
|
||||
InitPacketI5: *userSettings.InitPacketI5,
|
||||
}
|
||||
}
|
||||
+9
-2
@@ -33,14 +33,21 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
||||
var connection models.Connection
|
||||
var err error
|
||||
subLogger := l.logger.New(log.SetComponent(settings.Type))
|
||||
if settings.Type == vpn.OpenVPN {
|
||||
switch settings.Type {
|
||||
case vpn.AmneziaWg:
|
||||
vpnInterface = settings.AmneziaWg.Wireguard.Interface
|
||||
vpnRunner, connection, err = setupAmneziaWg(ctx, l.netLinker, l.fw,
|
||||
providerConf, settings, l.ipv6Supported, subLogger)
|
||||
case vpn.OpenVPN:
|
||||
vpnInterface = settings.OpenVPN.Interface
|
||||
vpnRunner, connection, err = setupOpenVPN(ctx, l.fw,
|
||||
l.openvpnConf, providerConf, settings, l.ipv6Supported, l.cmder, subLogger)
|
||||
} else { // Wireguard
|
||||
case vpn.Wireguard:
|
||||
vpnInterface = settings.Wireguard.Interface
|
||||
vpnRunner, connection, err = setupWireguard(ctx, l.netLinker, l.fw,
|
||||
providerConf, settings, l.ipv6Supported, subLogger)
|
||||
default:
|
||||
panic("vpn type not implemented: " + settings.Type)
|
||||
}
|
||||
if err != nil {
|
||||
l.crashed(ctx, err)
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/constants/vpn"
|
||||
"github.com/qdm12/gluetun/internal/netlink"
|
||||
"github.com/qdm12/gluetun/internal/pmtud"
|
||||
pconstants "github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||
@@ -48,6 +49,14 @@ type tunnelUpPMTUDData struct {
|
||||
}
|
||||
|
||||
func (l *Loop) onTunnelUp(ctx, loopCtx context.Context, data tunnelUpData) {
|
||||
switch vpnType := l.GetSettings().Type; vpnType {
|
||||
case vpn.Wireguard, vpn.AmneziaWg:
|
||||
l.logger.Infof("%s setup is complete. "+
|
||||
"Note %s is a silent protocol and it may or may not work, without giving any error message. "+
|
||||
"Typically i/o timeout errors indicate the %s connection is not working.",
|
||||
vpnType, vpnType, vpnType)
|
||||
}
|
||||
|
||||
l.client.CloseIdleConnections()
|
||||
|
||||
for _, vpnPort := range l.vpnInputPorts {
|
||||
|
||||
@@ -51,7 +51,6 @@ func buildWireguardSettings(connection models.Connection,
|
||||
settings.PreSharedKey = *userSettings.PreSharedKey
|
||||
settings.InterfaceName = userSettings.Interface
|
||||
settings.Implementation = userSettings.Implementation
|
||||
settings.AmneziaWG = buildAmneziaWgSettings(userSettings.AmneziaWG)
|
||||
if *userSettings.MTU > 0 {
|
||||
settings.MTU = *userSettings.MTU
|
||||
} else {
|
||||
@@ -91,24 +90,3 @@ func buildWireguardSettings(connection models.Connection,
|
||||
|
||||
return settings
|
||||
}
|
||||
|
||||
func buildAmneziaWgSettings(s settings.AmneziaWg) wireguard.AmneziaSettings {
|
||||
return wireguard.AmneziaSettings{
|
||||
JunkPacketCount: *s.JunkPacketCount,
|
||||
JunkPacketMin: *s.JunkPacketMin,
|
||||
JunkPacketMax: *s.JunkPacketMax,
|
||||
PaddingS1: *s.PaddingS1,
|
||||
PaddingS2: *s.PaddingS2,
|
||||
PaddingS3: *s.PaddingS3,
|
||||
PaddingS4: *s.PaddingS4,
|
||||
HeaderH1: *s.HeaderH1,
|
||||
HeaderH2: *s.HeaderH2,
|
||||
HeaderH3: *s.HeaderH3,
|
||||
HeaderH4: *s.HeaderH4,
|
||||
InitPacketI1: *s.InitPacketI1,
|
||||
InitPacketI2: *s.InitPacketI2,
|
||||
InitPacketI3: *s.InitPacketI3,
|
||||
InitPacketI4: *s.InitPacketI4,
|
||||
InitPacketI5: *s.InitPacketI5,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -40,24 +40,6 @@ func Test_buildWireguardSettings(t *testing.T) {
|
||||
PersistentKeepaliveInterval: ptrTo(time.Hour),
|
||||
Interface: "wg1",
|
||||
MTU: ptrTo(uint32(1000)),
|
||||
AmneziaWG: settings.AmneziaWg{
|
||||
JunkPacketCount: ptrTo(uint16(1)),
|
||||
JunkPacketMin: ptrTo(uint16(0)),
|
||||
JunkPacketMax: ptrTo(uint16(0)),
|
||||
PaddingS1: ptrTo(uint16(0)),
|
||||
PaddingS2: ptrTo(uint16(0)),
|
||||
PaddingS3: ptrTo(uint16(0)),
|
||||
PaddingS4: ptrTo(uint16(0)),
|
||||
HeaderH1: ptrTo("x"),
|
||||
HeaderH2: ptrTo(""),
|
||||
HeaderH3: ptrTo(""),
|
||||
HeaderH4: ptrTo(""),
|
||||
InitPacketI1: ptrTo(""),
|
||||
InitPacketI2: ptrTo(""),
|
||||
InitPacketI3: ptrTo(""),
|
||||
InitPacketI4: ptrTo(""),
|
||||
InitPacketI5: ptrTo(""),
|
||||
},
|
||||
},
|
||||
ipv6Supported: false,
|
||||
settings: wireguard.Settings{
|
||||
@@ -76,10 +58,6 @@ func Test_buildWireguardSettings(t *testing.T) {
|
||||
RulePriority: 101,
|
||||
IPv6: ptrTo(false),
|
||||
MTU: 1000,
|
||||
AmneziaWG: wireguard.AmneziaSettings{
|
||||
JunkPacketCount: 1,
|
||||
HeaderH1: "x",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -5,15 +5,16 @@ import (
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
func (w *Wireguard) addAddresses(linkIndex uint32,
|
||||
addresses []netip.Prefix,
|
||||
func AddAddresses(linkIndex uint32,
|
||||
addresses []netip.Prefix, ipv6 bool,
|
||||
netlink NetLinker,
|
||||
) (err error) {
|
||||
for _, address := range addresses {
|
||||
if !*w.settings.IPv6 && address.Addr().Is6() {
|
||||
if !ipv6 && address.Addr().Is6() {
|
||||
continue
|
||||
}
|
||||
|
||||
err = w.netlink.AddrReplace(linkIndex, address)
|
||||
err = netlink.AddrReplace(linkIndex, address)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: when adding address %s to link with index %d",
|
||||
err, address, linkIndex)
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_Wireguard_addAddresses(t *testing.T) {
|
||||
func Test_AddAddresses(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ipNetOne := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 32)
|
||||
@@ -21,13 +21,15 @@ func Test_Wireguard_addAddresses(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
linkIndex uint32
|
||||
addrs []netip.Prefix
|
||||
wgBuilder func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard
|
||||
ipv6 bool
|
||||
netlinkBuilder func(ctrl *gomock.Controller, linkIndex uint32) *MockNetLinker
|
||||
err error
|
||||
}{
|
||||
"success": {
|
||||
linkIndex: 1,
|
||||
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
|
||||
wgBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard {
|
||||
ipv6: true,
|
||||
netlinkBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *MockNetLinker {
|
||||
netLinker := NewMockNetLinker(ctrl)
|
||||
firstCall := netLinker.EXPECT().
|
||||
AddrReplace(linkIndex, ipNetOne).
|
||||
@@ -35,35 +37,27 @@ func Test_Wireguard_addAddresses(t *testing.T) {
|
||||
netLinker.EXPECT().
|
||||
AddrReplace(linkIndex, ipNetTwo).
|
||||
Return(nil).After(firstCall)
|
||||
return &Wireguard{
|
||||
netlink: netLinker,
|
||||
settings: Settings{
|
||||
IPv6: ptrTo(true),
|
||||
},
|
||||
}
|
||||
return netLinker
|
||||
},
|
||||
},
|
||||
"first add error": {
|
||||
linkIndex: 1,
|
||||
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
|
||||
wgBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard {
|
||||
ipv6: true,
|
||||
netlinkBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *MockNetLinker {
|
||||
netLinker := NewMockNetLinker(ctrl)
|
||||
netLinker.EXPECT().
|
||||
AddrReplace(linkIndex, ipNetOne).
|
||||
Return(errDummy)
|
||||
return &Wireguard{
|
||||
netlink: netLinker,
|
||||
settings: Settings{
|
||||
IPv6: ptrTo(true),
|
||||
},
|
||||
}
|
||||
return netLinker
|
||||
},
|
||||
err: errors.New("dummy: when adding address 1.2.3.4/32 to link with index 1"),
|
||||
},
|
||||
"second add error": {
|
||||
linkIndex: 1,
|
||||
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
|
||||
wgBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard {
|
||||
ipv6: true,
|
||||
netlinkBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *MockNetLinker {
|
||||
netLinker := NewMockNetLinker(ctrl)
|
||||
firstCall := netLinker.EXPECT().
|
||||
AddrReplace(linkIndex, ipNetOne).
|
||||
@@ -71,23 +65,14 @@ func Test_Wireguard_addAddresses(t *testing.T) {
|
||||
netLinker.EXPECT().
|
||||
AddrReplace(linkIndex, ipNetTwo).
|
||||
Return(errDummy).After(firstCall)
|
||||
return &Wireguard{
|
||||
netlink: netLinker,
|
||||
settings: Settings{
|
||||
IPv6: ptrTo(true),
|
||||
},
|
||||
}
|
||||
return netLinker
|
||||
},
|
||||
err: errors.New("dummy: when adding address ::1234/64 to link with index 1"),
|
||||
},
|
||||
"ignore IPv6": {
|
||||
addrs: []netip.Prefix{ipNetTwo},
|
||||
wgBuilder: func(_ *gomock.Controller, _ uint32) *Wireguard {
|
||||
return &Wireguard{
|
||||
settings: Settings{
|
||||
IPv6: ptrTo(false),
|
||||
},
|
||||
}
|
||||
netlinkBuilder: func(_ *gomock.Controller, _ uint32) *MockNetLinker {
|
||||
return NewMockNetLinker(nil)
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -97,9 +82,9 @@ func Test_Wireguard_addAddresses(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
wg := testCase.wgBuilder(ctrl, testCase.linkIndex)
|
||||
netlink := testCase.netlinkBuilder(ctrl, testCase.linkIndex)
|
||||
|
||||
err := wg.addAddresses(testCase.linkIndex, testCase.addrs)
|
||||
err := AddAddresses(testCase.linkIndex, testCase.addrs, testCase.ipv6, netlink)
|
||||
|
||||
if testCase.err != nil {
|
||||
require.Error(t, err)
|
||||
|
||||
@@ -1,63 +0,0 @@
|
||||
package wireguard
|
||||
|
||||
import "sort"
|
||||
|
||||
type closer struct {
|
||||
operation string
|
||||
step step
|
||||
close func() error
|
||||
closed bool
|
||||
}
|
||||
|
||||
type closers []closer
|
||||
|
||||
func (c *closers) add(operation string, step step,
|
||||
closeFunc func() error,
|
||||
) {
|
||||
closer := closer{
|
||||
operation: operation,
|
||||
step: step,
|
||||
close: closeFunc,
|
||||
}
|
||||
*c = append(*c, closer)
|
||||
}
|
||||
|
||||
func (c *closers) cleanup(logger Logger) {
|
||||
closers := *c
|
||||
|
||||
sort.Slice(closers, func(i, j int) bool {
|
||||
return closers[i].step < closers[j].step
|
||||
})
|
||||
|
||||
for i, closer := range closers {
|
||||
if closer.closed {
|
||||
continue
|
||||
}
|
||||
closers[i].closed = true
|
||||
logger.Debug(closer.operation + "...")
|
||||
err := closer.close()
|
||||
if err != nil {
|
||||
logger.Error("failed " + closer.operation + ": " + err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type step int
|
||||
|
||||
const (
|
||||
// stepOne closes the wireguard controller client,
|
||||
// and removes the IP rule.
|
||||
stepOne step = iota
|
||||
// stepTwo closes the UAPI listener.
|
||||
stepTwo
|
||||
// stepThree closes the UAPI file.
|
||||
stepThree
|
||||
// stepFour shuts down the Wireguard link.
|
||||
stepFour
|
||||
// stepFive removes the Wireguard link.
|
||||
stepFive
|
||||
// stepSix closes the Wireguard device.
|
||||
stepSix
|
||||
// stepSeven closes the bind connection and the TUN device file.
|
||||
stepSeven
|
||||
)
|
||||
@@ -1,57 +0,0 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_closers(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
var ACloseCalled, BCloseCalled, CCloseCalled bool
|
||||
var (
|
||||
AErr error
|
||||
BErr = errors.New("B failed")
|
||||
CErr = errors.New("C failed")
|
||||
)
|
||||
|
||||
var closers closers
|
||||
closers.add("closing A", stepFive, func() error {
|
||||
ACloseCalled = true
|
||||
return AErr
|
||||
})
|
||||
|
||||
closers.add("closing B", stepThree, func() error {
|
||||
BCloseCalled = true
|
||||
return BErr
|
||||
})
|
||||
|
||||
closers.add("closing C", stepTwo, func() error {
|
||||
CCloseCalled = true
|
||||
return CErr
|
||||
})
|
||||
|
||||
logger := NewMockLogger(ctrl)
|
||||
prevCall := logger.EXPECT().Debug("closing C...")
|
||||
prevCall = logger.EXPECT().Error("failed closing C: C failed").After(prevCall)
|
||||
prevCall = logger.EXPECT().Debug("closing B...").After(prevCall)
|
||||
prevCall = logger.EXPECT().Error("failed closing B: B failed").After(prevCall)
|
||||
logger.EXPECT().Debug("closing A...").After(prevCall)
|
||||
|
||||
closers.cleanup(logger)
|
||||
|
||||
closers.cleanup(logger) // run twice should not close already closed
|
||||
|
||||
for _, closer := range closers {
|
||||
assert.True(t, closer.closed)
|
||||
}
|
||||
|
||||
assert.True(t, ACloseCalled)
|
||||
assert.True(t, BCloseCalled)
|
||||
assert.True(t, CCloseCalled)
|
||||
}
|
||||
@@ -1,28 +0,0 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
type tunDevice interface {
|
||||
Close() error
|
||||
Name() (string, error)
|
||||
}
|
||||
|
||||
type bind interface {
|
||||
Close() error
|
||||
}
|
||||
|
||||
type userspaceDevice interface {
|
||||
Close()
|
||||
Wait() chan struct{}
|
||||
IpcHandle(net.Conn)
|
||||
IpcSet(string) error
|
||||
}
|
||||
|
||||
type userSpaceBackend struct {
|
||||
createTun func(string, int) (tunDevice, error)
|
||||
createBind func() bind
|
||||
createDevice func(tunDevice, bind, Logger) userspaceDevice
|
||||
preStart func(userspaceDevice, Settings) error
|
||||
}
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
func configureDevice(client *wgctrl.Client, settings Settings) (err error) {
|
||||
func ConfigureDevice(client *wgctrl.Client, settings Settings) (err error) {
|
||||
deviceConfig, err := makeDeviceConfig(settings)
|
||||
if err != nil {
|
||||
return fmt.Errorf("making device configuration: %w", err)
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
)
|
||||
|
||||
//go:generate mockgen -destination=log_mock_test.go -package wireguard . Logger
|
||||
|
||||
type Logger interface {
|
||||
@@ -7,5 +11,16 @@ type Logger interface {
|
||||
Debugf(format string, args ...interface{})
|
||||
Info(s string)
|
||||
Error(s string)
|
||||
Errorf(format string, args ...interface{})
|
||||
Erroer
|
||||
}
|
||||
|
||||
type Erroer interface {
|
||||
Errorf(format string, args ...any)
|
||||
}
|
||||
|
||||
func makeDeviceLogger(logger Logger) (deviceLogger *device.Logger) {
|
||||
return &device.Logger{
|
||||
Verbosef: logger.Debugf,
|
||||
Errorf: logger.Errorf,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
func Test_makeDeviceLogger(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
logger := NewMockLogger(ctrl)
|
||||
|
||||
deviceLogger := makeDeviceLogger(logger)
|
||||
|
||||
logger.EXPECT().Debugf("test %d", 1)
|
||||
deviceLogger.Verbosef("test %d", 1)
|
||||
|
||||
logger.EXPECT().Errorf("test %d", 2)
|
||||
deviceLogger.Errorf("test %d", 2)
|
||||
}
|
||||
@@ -21,7 +21,7 @@ func (n noopDebugLogger) Error(_ string) {}
|
||||
func (n noopDebugLogger) Errorf(_ string, _ ...any) {}
|
||||
func (n noopDebugLogger) Patch(_ ...log.Option) {}
|
||||
|
||||
func Test_netlink_Wireguard_addAddresses(t *testing.T) {
|
||||
func Test_AddAddresses_Integration(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
netlinker := netlink.New(&noopDebugLogger{})
|
||||
@@ -55,7 +55,7 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) {
|
||||
|
||||
const addIterations = 2 // initial + replace
|
||||
for range addIterations {
|
||||
err = wg.addAddresses(link.Index, addresses)
|
||||
err = AddAddresses(link.Index, addresses, *wg.settings.IPv6, wg.netlink)
|
||||
require.NoError(t, err)
|
||||
|
||||
ipPrefixes, err := netlinker.AddrList(link.Index, netlink.FamilyAll)
|
||||
@@ -67,22 +67,19 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_netlink_Wireguard_addRule(t *testing.T) {
|
||||
func Test_AddRule_Integration(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
netlinker := netlink.New(&noopDebugLogger{})
|
||||
wg := &Wireguard{
|
||||
netlink: netlinker,
|
||||
logger: &noopDebugLogger{},
|
||||
}
|
||||
logger := &noopDebugLogger{}
|
||||
netlinker := netlink.New(logger)
|
||||
|
||||
// Unique combination for this test
|
||||
const rulePriority uint32 = 10000
|
||||
const firewallMark uint32 = 12345
|
||||
const family = netlink.FamilyV4
|
||||
|
||||
cleanup, err := wg.addRule(rulePriority,
|
||||
firewallMark, family)
|
||||
cleanup, err := AddRule(rulePriority,
|
||||
firewallMark, family, netlinker, logger)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
err := cleanup()
|
||||
@@ -110,8 +107,8 @@ func Test_netlink_Wireguard_addRule(t *testing.T) {
|
||||
require.True(t, ruleFound)
|
||||
|
||||
// Existing rule cannot be added
|
||||
nilCleanup, err := wg.addRule(rulePriority,
|
||||
firewallMark, family)
|
||||
nilCleanup, err := AddRule(rulePriority,
|
||||
firewallMark, family, netlinker, logger)
|
||||
if nilCleanup != nil {
|
||||
_ = nilCleanup() // in case it succeeds
|
||||
}
|
||||
|
||||
@@ -8,17 +8,17 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/netlink"
|
||||
)
|
||||
|
||||
func (w *Wireguard) addRoutes(linkIndex uint32, destinations []netip.Prefix,
|
||||
firewallMark uint32,
|
||||
func AddRoutes(linkIndex uint32, destinations []netip.Prefix,
|
||||
firewallMark uint32, netlinker NetLinker, logger Erroer,
|
||||
) (err error) {
|
||||
for _, dst := range destinations {
|
||||
err = w.addRoute(linkIndex, dst, firewallMark)
|
||||
err = addRoute(linkIndex, dst, firewallMark, netlinker)
|
||||
if err == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if dst.Addr().Is6() && strings.Contains(err.Error(), "permission denied") {
|
||||
w.logger.Errorf("cannot add route for IPv6 due to a permission denial. "+
|
||||
logger.Errorf("cannot add route for IPv6 due to a permission denial. "+
|
||||
"Ignoring and continuing execution; "+
|
||||
"Please report to https://github.com/qdm12/gluetun/issues/998 if you find a fix. "+
|
||||
"Full error string: %s", err)
|
||||
@@ -29,8 +29,8 @@ func (w *Wireguard) addRoutes(linkIndex uint32, destinations []netip.Prefix,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *Wireguard) addRoute(linkIndex uint32, dst netip.Prefix,
|
||||
firewallMark uint32,
|
||||
func addRoute(linkIndex uint32, dst netip.Prefix,
|
||||
firewallMark uint32, netlinker NetLinker,
|
||||
) (err error) {
|
||||
family := netlink.FamilyV4
|
||||
if dst.Addr().Is6() {
|
||||
@@ -46,7 +46,7 @@ func (w *Wireguard) addRoute(linkIndex uint32, dst netip.Prefix,
|
||||
Proto: netlink.ProtoStatic,
|
||||
}
|
||||
|
||||
err = w.netlink.RouteAdd(route)
|
||||
err = netlinker.RouteAdd(route)
|
||||
if err != nil {
|
||||
return fmt.Errorf(
|
||||
"adding route for link with index %d, destination %s and table %d: %w",
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_Wireguard_addRoute(t *testing.T) {
|
||||
func Test_addRoute(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const linkIndex = 88
|
||||
@@ -62,15 +62,11 @@ func Test_Wireguard_addRoute(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
netLinker := NewMockNetLinker(ctrl)
|
||||
wg := Wireguard{
|
||||
netlink: netLinker,
|
||||
}
|
||||
|
||||
netLinker.EXPECT().
|
||||
RouteAdd(testCase.expectedRoute).
|
||||
Return(testCase.routeAddErr)
|
||||
|
||||
err := wg.addRoute(linkIndex, testCase.dst, firewallMark)
|
||||
err := addRoute(linkIndex, testCase.dst, firewallMark, netLinker)
|
||||
|
||||
if testCase.err != nil {
|
||||
require.Error(t, err)
|
||||
|
||||
@@ -7,8 +7,8 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/netlink"
|
||||
)
|
||||
|
||||
func (w *Wireguard) addRule(rulePriority, firewallMark uint32,
|
||||
family uint8,
|
||||
func AddRule(rulePriority, firewallMark uint32, family uint8,
|
||||
netlinker NetLinker, logger Logger,
|
||||
) (cleanup func() error, err error) {
|
||||
rule := netlink.Rule{
|
||||
Priority: &rulePriority,
|
||||
@@ -18,16 +18,16 @@ func (w *Wireguard) addRule(rulePriority, firewallMark uint32,
|
||||
Flags: netlink.FlagInvert,
|
||||
Action: netlink.ActionToTable,
|
||||
}
|
||||
if err := w.netlink.RuleAdd(rule); err != nil {
|
||||
if err := netlinker.RuleAdd(rule); err != nil {
|
||||
if strings.HasSuffix(err.Error(), "file exists") {
|
||||
w.logger.Info("if you are using Kubernetes, this may fix the error below: " +
|
||||
logger.Info("if you are using Kubernetes, this may fix the error below: " +
|
||||
"https://github.com/qdm12/gluetun-wiki/blob/main/setup/advanced/kubernetes.md#adding-ipv6-rule--file-exists")
|
||||
}
|
||||
return nil, fmt.Errorf("adding %s: %w", rule, err)
|
||||
}
|
||||
|
||||
cleanup = func() error {
|
||||
err := w.netlink.RuleDel(rule)
|
||||
err := netlinker.RuleDel(rule)
|
||||
if err != nil {
|
||||
return fmt.Errorf("deleting rule %s: %w", rule, err)
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_Wireguard_addRule(t *testing.T) {
|
||||
func Test_AddRule(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const rulePriority uint32 = 987
|
||||
@@ -68,13 +68,11 @@ func Test_Wireguard_addRule(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
netLinker := NewMockNetLinker(ctrl)
|
||||
wg := Wireguard{
|
||||
netlink: netLinker,
|
||||
}
|
||||
|
||||
netLinker.EXPECT().RuleAdd(testCase.expectedRule).
|
||||
Return(testCase.ruleAddErr)
|
||||
cleanup, err := wg.addRule(rulePriority, firewallMark, family)
|
||||
cleanup, err := AddRule(rulePriority, firewallMark, family,
|
||||
netLinker, nil)
|
||||
if testCase.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, testCase.err.Error(), err.Error())
|
||||
|
||||
+85
-89
@@ -6,39 +6,33 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/cleanup"
|
||||
"github.com/qdm12/gluetun/internal/netlink"
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrDetectKernel = errors.New("cannot detect Kernel support")
|
||||
ErrCreateTun = errors.New("cannot create TUN device")
|
||||
ErrAddLink = errors.New("cannot add Wireguard link")
|
||||
ErrFindLink = errors.New("cannot find link")
|
||||
ErrFindDevice = errors.New("cannot find Wireguard device")
|
||||
ErrUAPISocketOpening = errors.New("cannot open UAPI socket")
|
||||
ErrWgctrlOpen = errors.New("cannot open wgctrl")
|
||||
ErrUAPIListen = errors.New("cannot listen on UAPI socket")
|
||||
ErrAddAddress = errors.New("cannot add address to wireguard interface")
|
||||
ErrConfigure = errors.New("cannot configure wireguard interface")
|
||||
ErrDeviceInfo = errors.New("cannot get wireguard device information")
|
||||
ErrIfaceUp = errors.New("cannot set the interface to UP")
|
||||
ErrRouteAdd = errors.New("cannot add route for interface")
|
||||
ErrDeviceWaited = errors.New("device waited for")
|
||||
ErrKernelSupport = errors.New("kernel does not support Wireguard")
|
||||
ErrAmneziaConfigure = errors.New("cannot configure AmneziaWG")
|
||||
errKernelSupport = errors.New("kernel does not support Wireguard")
|
||||
errTunNameMismatch = errors.New("TUN device name is mismatching")
|
||||
errDeviceWaited = errors.New("device waited for")
|
||||
)
|
||||
|
||||
// Run runs the wireguard interface and waits until the context is done, then it cleans up the
|
||||
// interface and returns any error that occurred during setup or waiting. It sends an error to
|
||||
// waitError if any error occurs during setup or waiting, otherwise it sends nil when the context
|
||||
// is done. It sends a signal to ready when the setup is complete and the interface is ready to use.
|
||||
// See https://git.zx2c4.com/wireguard-go/tree/main.go
|
||||
func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<- struct{}) {
|
||||
kernelSupported, err := w.netlink.IsWireguardSupported()
|
||||
if err != nil {
|
||||
waitError <- fmt.Errorf("%w: %s", ErrDetectKernel, err)
|
||||
waitError <- fmt.Errorf("detecting wireguard kernel support: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
userspaceBackend := defaultUserSpaceBackend()
|
||||
setupFunction := setupUserSpaceCommon
|
||||
setupFunction := setupUserSpace
|
||||
switch w.settings.Implementation {
|
||||
case "auto": //nolint:goconst
|
||||
if !kernelSupported {
|
||||
@@ -50,95 +44,105 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
|
||||
case "userspace":
|
||||
case "kernelspace":
|
||||
if !kernelSupported {
|
||||
waitError <- fmt.Errorf("%w", ErrKernelSupport)
|
||||
waitError <- fmt.Errorf("%w", errKernelSupport)
|
||||
return
|
||||
}
|
||||
setupFunction = setupKernelSpace
|
||||
case "amneziawg":
|
||||
userspaceBackend = amneziaUserSpaceBackend()
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown implementation %q", w.settings.Implementation))
|
||||
}
|
||||
|
||||
setup := func(ctx context.Context, cleanups *cleanup.Cleanups) (
|
||||
linkIndex uint32, waitAndCleanup func() error, err error,
|
||||
) {
|
||||
return setupFunction(ctx,
|
||||
w.settings.InterfaceName, w.netlink, w.settings.MTU, cleanups, w.logger)
|
||||
}
|
||||
|
||||
Run(ctx, waitError, ready, setup, w.settings, w.netlink, w.logger)
|
||||
}
|
||||
|
||||
func Run(ctx context.Context, waitError chan<- error, ready chan<- struct{},
|
||||
setup func(ctx context.Context, cleanups *cleanup.Cleanups) (
|
||||
linkIndex uint32, waitAndCleanup func() error, err error),
|
||||
settings Settings, netlinker NetLinker, logger Logger,
|
||||
) {
|
||||
client, err := wgctrl.New()
|
||||
if err != nil {
|
||||
waitError <- fmt.Errorf("%w: %s", ErrWgctrlOpen, err)
|
||||
waitError <- fmt.Errorf("opening wgctrl: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
var closers closers
|
||||
closers.add("closing controller client", stepOne, client.Close)
|
||||
var cleanups cleanup.Cleanups
|
||||
cleanups.Add("closing controller client", 1, client.Close)
|
||||
|
||||
defer closers.cleanup(w.logger)
|
||||
defer cleanups.Cleanup(logger)
|
||||
|
||||
linkIndex, waitAndCleanup, err := setupFunction(ctx,
|
||||
w.settings.InterfaceName, w.netlink, w.settings.MTU, &closers, w.logger, w.settings, userspaceBackend)
|
||||
linkIndex, waitAndCleanup, err := setup(ctx, &cleanups)
|
||||
if err != nil {
|
||||
waitError <- err
|
||||
return
|
||||
}
|
||||
|
||||
err = w.addAddresses(linkIndex, w.settings.Addresses)
|
||||
err = AddAddresses(linkIndex, settings.Addresses, *settings.IPv6, netlinker)
|
||||
if err != nil {
|
||||
waitError <- fmt.Errorf("%w: %s", ErrAddAddress, err)
|
||||
waitError <- fmt.Errorf("adding addresses to interface: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
w.logger.Info("Connecting to " + w.settings.Endpoint.String())
|
||||
err = configureDevice(client, w.settings)
|
||||
logger.Info("Connecting to " + settings.Endpoint.String())
|
||||
err = ConfigureDevice(client, settings)
|
||||
if err != nil {
|
||||
waitError <- fmt.Errorf("%w: %s", ErrConfigure, err)
|
||||
waitError <- fmt.Errorf("configuring interface: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = w.netlink.LinkSetUp(linkIndex)
|
||||
err = netlinker.LinkSetUp(linkIndex)
|
||||
if err != nil {
|
||||
waitError <- fmt.Errorf("%w: %s", ErrIfaceUp, err)
|
||||
waitError <- fmt.Errorf("setting the interface UP: %w", err)
|
||||
return
|
||||
}
|
||||
closers.add("shutting down link", stepFour, func() error {
|
||||
return w.netlink.LinkSetDown(linkIndex)
|
||||
cleanups.Add("shutting down link", 4, func() error {
|
||||
return netlinker.LinkSetDown(linkIndex)
|
||||
})
|
||||
|
||||
err = w.addRoutes(linkIndex, w.settings.AllowedIPs, w.settings.FirewallMark)
|
||||
err = AddRoutes(linkIndex, settings.AllowedIPs, settings.FirewallMark,
|
||||
netlinker, logger)
|
||||
if err != nil {
|
||||
waitError <- fmt.Errorf("%w: %s", ErrRouteAdd, err)
|
||||
waitError <- fmt.Errorf("adding routes for interface: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
if *w.settings.IPv6 {
|
||||
if *settings.IPv6 {
|
||||
// requires net.ipv6.conf.all.disable_ipv6=0
|
||||
ruleCleanup6, err := w.addRule(w.settings.RulePriority,
|
||||
w.settings.FirewallMark, netlink.FamilyV6)
|
||||
ruleCleanup6, err := AddRule(settings.RulePriority,
|
||||
settings.FirewallMark, netlink.FamilyV6,
|
||||
netlinker, logger)
|
||||
if err != nil {
|
||||
waitError <- fmt.Errorf("adding IPv6 rule: %w", err)
|
||||
return
|
||||
}
|
||||
closers.add("removing IPv6 rule", stepOne, ruleCleanup6)
|
||||
cleanups.Add("removing IPv6 rule", 1, ruleCleanup6)
|
||||
}
|
||||
|
||||
ruleCleanup, err := w.addRule(w.settings.RulePriority,
|
||||
w.settings.FirewallMark, netlink.FamilyV4)
|
||||
ruleCleanup, err := AddRule(settings.RulePriority,
|
||||
settings.FirewallMark, netlink.FamilyV4,
|
||||
netlinker, logger)
|
||||
if err != nil {
|
||||
waitError <- fmt.Errorf("adding IPv4 rule: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
closers.add("removing IPv4 rule", stepOne, ruleCleanup)
|
||||
w.logger.Info("Wireguard setup is complete. " +
|
||||
"Note Wireguard is a silent protocol and it may or may not work, without giving any error message. " +
|
||||
"Typically i/o timeout errors indicate the Wireguard connection is not working.")
|
||||
cleanups.Add("removing IPv4 rule", 1, ruleCleanup)
|
||||
ready <- struct{}{}
|
||||
|
||||
waitError <- waitAndCleanup()
|
||||
}
|
||||
|
||||
type waitAndCleanupFunc func() error
|
||||
|
||||
func setupKernelSpace(ctx context.Context,
|
||||
interfaceName string, netLinker NetLinker, mtu uint32,
|
||||
closers *closers, logger Logger, _ Settings, _ userSpaceBackend) (
|
||||
linkIndex uint32, waitAndCleanup waitAndCleanupFunc, err error,
|
||||
cleanups *cleanup.Cleanups, logger Logger) (
|
||||
linkIndex uint32, waitAndCleanup func() error, err error,
|
||||
) {
|
||||
links, err := netLinker.LinkList()
|
||||
if err != nil {
|
||||
@@ -164,82 +168,74 @@ func setupKernelSpace(ctx context.Context,
|
||||
}
|
||||
linkIndex, err = netLinker.LinkAdd(link)
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("%w: %s", ErrAddLink, err)
|
||||
return 0, nil, fmt.Errorf("adding link: %w", err)
|
||||
}
|
||||
closers.add("deleting link", stepFive, func() error {
|
||||
cleanups.Add("deleting link", 5, func() error {
|
||||
return netLinker.LinkDel(linkIndex)
|
||||
})
|
||||
|
||||
waitAndCleanup = func() error {
|
||||
<-ctx.Done()
|
||||
closers.cleanup(logger)
|
||||
cleanups.Cleanup(logger)
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
return linkIndex, waitAndCleanup, nil
|
||||
}
|
||||
|
||||
func setupUserSpaceCommon(ctx context.Context,
|
||||
func setupUserSpace(ctx context.Context,
|
||||
interfaceName string, netLinker NetLinker, mtu uint32,
|
||||
closers *closers, logger Logger,
|
||||
settings Settings, b userSpaceBackend,
|
||||
) (
|
||||
linkIndex uint32, waitAndCleanup waitAndCleanupFunc, err error,
|
||||
cleanups *cleanup.Cleanups, logger Logger) (
|
||||
linkIndex uint32, waitAndCleanup func() error, err error,
|
||||
) {
|
||||
tun, err := b.createTun(interfaceName, int(mtu))
|
||||
tun, err := tun.CreateTUN(interfaceName, int(mtu))
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("%w: %s", ErrCreateTun, err)
|
||||
return 0, nil, fmt.Errorf("creating TUN device: %w", err)
|
||||
}
|
||||
|
||||
closers.add("closing TUN device", stepSeven, tun.Close)
|
||||
cleanups.Add("closing TUN device", 7, tun.Close)
|
||||
|
||||
tunName, err := tun.Name()
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("%w: cannot get TUN name: %s", ErrCreateTun, err)
|
||||
return 0, nil, fmt.Errorf("getting created TUN device name: %w", err)
|
||||
} else if tunName != interfaceName {
|
||||
return 0, nil, fmt.Errorf("%w: names don't match: expected %q and got %q",
|
||||
ErrCreateTun, interfaceName, tunName)
|
||||
return 0, nil, fmt.Errorf("%w: expected %q and got %q",
|
||||
errTunNameMismatch, interfaceName, tunName)
|
||||
}
|
||||
|
||||
link, err := netLinker.LinkByName(interfaceName)
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("%w: %s: %s", ErrFindLink, interfaceName, err)
|
||||
return 0, nil, fmt.Errorf("finding link %s: %w", interfaceName, err)
|
||||
}
|
||||
closers.add("deleting link", stepFive, func() error {
|
||||
cleanups.Add("deleting link", 5, func() error {
|
||||
return netLinker.LinkDel(link.Index)
|
||||
})
|
||||
|
||||
bind := b.createBind()
|
||||
bind := conn.NewDefaultBind()
|
||||
|
||||
closers.add("closing bind", stepSeven, bind.Close)
|
||||
cleanups.Add("closing bind", 7, bind.Close)
|
||||
|
||||
device := b.createDevice(tun, bind, logger)
|
||||
deviceLogger := makeDeviceLogger(logger)
|
||||
device := device.NewDevice(tun, bind, deviceLogger)
|
||||
|
||||
closers.add("closing Wireguard device", stepSix, func() error {
|
||||
cleanups.Add("closing Wireguard device", 6, func() error {
|
||||
device.Close()
|
||||
return nil
|
||||
})
|
||||
|
||||
uapiFile, err := uapiOpen(interfaceName)
|
||||
uapiFile, err := UAPIOpen(interfaceName)
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("%w: %s", ErrUAPISocketOpening, err)
|
||||
return 0, nil, fmt.Errorf("opening UAPI socket: %w", err)
|
||||
}
|
||||
|
||||
closers.add("closing UAPI file", stepThree, uapiFile.Close)
|
||||
cleanups.Add("closing UAPI file", 3, uapiFile.Close)
|
||||
|
||||
uapiListener, err := uapiListen(interfaceName, uapiFile)
|
||||
uapiListener, err := UAPIListen(interfaceName, uapiFile)
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("%w: %s", ErrUAPIListen, err)
|
||||
return 0, nil, fmt.Errorf("listening on UAPI socket: %w", err)
|
||||
}
|
||||
|
||||
closers.add("closing UAPI listener", stepTwo, uapiListener.Close)
|
||||
|
||||
if b.preStart != nil {
|
||||
err = b.preStart(device, settings)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
}
|
||||
cleanups.Add("closing UAPI listener", 2, uapiListener.Close)
|
||||
|
||||
// acceptAndHandle exits when uapiListener is closed
|
||||
uapiAcceptErrorCh := make(chan error)
|
||||
@@ -251,10 +247,10 @@ func setupUserSpaceCommon(ctx context.Context,
|
||||
case err = <-uapiAcceptErrorCh:
|
||||
close(uapiAcceptErrorCh)
|
||||
case <-device.Wait():
|
||||
err = ErrDeviceWaited
|
||||
err = errDeviceWaited
|
||||
}
|
||||
|
||||
closers.cleanup(logger)
|
||||
cleanups.Cleanup(logger)
|
||||
|
||||
<-uapiAcceptErrorCh // wait for acceptAndHandle to exit
|
||||
|
||||
@@ -264,7 +260,7 @@ func setupUserSpaceCommon(ctx context.Context,
|
||||
return link.Index, waitAndCleanup, nil
|
||||
}
|
||||
|
||||
func acceptAndHandle(uapi net.Listener, device userspaceDevice,
|
||||
func acceptAndHandle(uapi net.Listener, device *device.Device,
|
||||
uapiAcceptErrorCh chan<- error,
|
||||
) {
|
||||
for { // stopped by uapiFile.Close()
|
||||
|
||||
@@ -46,11 +46,8 @@ type Settings struct {
|
||||
// It defaults to false if left unset.
|
||||
IPv6 *bool
|
||||
// Implementation is the implementation to use.
|
||||
// It can be auto, kernelspace, userspace or amneziawg,
|
||||
// and defaults to auto.
|
||||
// It can be auto, kernelspace or userspace, and defaults to auto.
|
||||
Implementation string
|
||||
// AmneziaWG settings are extra obfuscation parameters
|
||||
AmneziaWG AmneziaSettings
|
||||
}
|
||||
|
||||
func (s *Settings) SetDefaults() {
|
||||
@@ -181,7 +178,7 @@ func (s *Settings) Check() (err error) {
|
||||
}
|
||||
|
||||
switch s.Implementation {
|
||||
case "auto", "kernelspace", "userspace", "amneziawg":
|
||||
case "auto", "kernelspace", "userspace":
|
||||
default:
|
||||
return fmt.Errorf("%w: %s", ErrImplementationInvalid, s.Implementation)
|
||||
}
|
||||
|
||||
@@ -1,58 +0,0 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
amneziaconn "github.com/amnezia-vpn/amneziawg-go/conn"
|
||||
amneziadevice "github.com/amnezia-vpn/amneziawg-go/device"
|
||||
amneziatun "github.com/amnezia-vpn/amneziawg-go/tun"
|
||||
wgconn "golang.zx2c4.com/wireguard/conn"
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
wgtun "golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
func defaultUserSpaceBackend() userSpaceBackend {
|
||||
return userSpaceBackend{
|
||||
createTun: func(name string, mtu int) (tunDevice, error) {
|
||||
return wgtun.CreateTUN(name, mtu)
|
||||
},
|
||||
createBind: func() bind {
|
||||
return wgconn.NewDefaultBind()
|
||||
},
|
||||
createDevice: func(td tunDevice, b bind, logger Logger) userspaceDevice {
|
||||
wgtun, _ := td.(wgtun.Device)
|
||||
wgBind, _ := b.(wgconn.Bind)
|
||||
wgLogger := wgdevice.Logger{
|
||||
Verbosef: logger.Debugf,
|
||||
Errorf: logger.Errorf,
|
||||
}
|
||||
device := wgdevice.NewDevice(wgtun, wgBind, &wgLogger)
|
||||
return device
|
||||
},
|
||||
preStart: nil,
|
||||
}
|
||||
}
|
||||
|
||||
func amneziaUserSpaceBackend() userSpaceBackend {
|
||||
return userSpaceBackend{
|
||||
createTun: func(name string, mtu int) (tunDevice, error) {
|
||||
return amneziatun.CreateTUN(name, mtu)
|
||||
},
|
||||
createBind: func() bind {
|
||||
return amneziaconn.NewDefaultBind()
|
||||
},
|
||||
createDevice: func(td tunDevice, b bind, logger Logger) userspaceDevice {
|
||||
wgamneziaTun, _ := td.(amneziatun.Device)
|
||||
wgamneziaBind, _ := b.(amneziaconn.Bind)
|
||||
wgamneziaLogger := amneziadevice.Logger{
|
||||
Verbosef: logger.Debugf,
|
||||
Errorf: logger.Errorf,
|
||||
}
|
||||
device := amneziadevice.NewDevice(wgamneziaTun, wgamneziaBind, &wgamneziaLogger)
|
||||
return device
|
||||
},
|
||||
preStart: func(ud userspaceDevice, s Settings) error {
|
||||
uapiConfig := s.AmneziaWG.uapiConfig()
|
||||
err := ud.IpcSet(uapiConfig)
|
||||
return err
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -7,10 +7,10 @@ import (
|
||||
"golang.zx2c4.com/wireguard/ipc"
|
||||
)
|
||||
|
||||
func uapiOpen(name string) (*os.File, error) {
|
||||
func UAPIOpen(name string) (*os.File, error) {
|
||||
return ipc.UAPIOpen(name)
|
||||
}
|
||||
|
||||
func uapiListen(interfaceName string, uapiFile *os.File) (net.Listener, error) {
|
||||
func UAPIListen(interfaceName string, uapiFile *os.File) (net.Listener, error) {
|
||||
return ipc.UAPIListen(interfaceName, uapiFile)
|
||||
}
|
||||
|
||||
@@ -7,10 +7,10 @@ import (
|
||||
"os"
|
||||
)
|
||||
|
||||
func uapiOpen(name string) (*os.File, error) {
|
||||
func UAPIOpen(name string) (*os.File, error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func uapiListen(interfaceName string, uapiFile *os.File) (net.Listener, error) {
|
||||
func UAPIListen(interfaceName string, uapiFile *os.File) (net.Listener, error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user