mirror of
https://github.com/qdm12/gluetun.git
synced 2026-05-06 20:10:11 +02:00
chore!(amneziawg): refactor to be separate from wireguard
- amneziawg is now a VPN protocol and no longer a Wireguard implementation - Use it with VPN_TYPE=amneziawg - document AMNEZIAWG_* options in Dockerfile - document amneziawg support in readme - separate amneziawg settings and code from wireguard - re-use code from wireguard whenever possible
This commit is contained in:
@@ -60,6 +60,10 @@ linters:
|
|||||||
- linters:
|
- linters:
|
||||||
- lll
|
- lll
|
||||||
source: "^// https://.+$"
|
source: "^// https://.+$"
|
||||||
|
- linters:
|
||||||
|
- mnd
|
||||||
|
source: "^ cleanups\\.Add.+$"
|
||||||
|
path: internal\/(wireguard|amneziawg)\/run\.go
|
||||||
- linters:
|
- linters:
|
||||||
- err113
|
- err113
|
||||||
- mnd
|
- mnd
|
||||||
|
|||||||
+30
@@ -112,6 +112,36 @@ ENV VPN_SERVICE_PROVIDER=pia \
|
|||||||
WIREGUARD_ADDRESSES_SECRETFILE=/run/secrets/wireguard_addresses \
|
WIREGUARD_ADDRESSES_SECRETFILE=/run/secrets/wireguard_addresses \
|
||||||
WIREGUARD_MTU= \
|
WIREGUARD_MTU= \
|
||||||
WIREGUARD_IMPLEMENTATION=auto \
|
WIREGUARD_IMPLEMENTATION=auto \
|
||||||
|
# Amnezia
|
||||||
|
AMNEZIAWG_ENDPOINT_IP= \
|
||||||
|
AMNEZIAWG_ENDPOINT_PORT= \
|
||||||
|
AMNEZIAWG_CONF_SECRETFILE=/run/secrets/wg0.conf \
|
||||||
|
AMNEZIAWG_PRIVATE_KEY= \
|
||||||
|
AMNEZIAWG_PRIVATE_KEY_SECRETFILE=/run/secrets/wireguard_private_key \
|
||||||
|
AMNEZIAWG_PRESHARED_KEY= \
|
||||||
|
AMNEZIAWG_PRESHARED_KEY_SECRETFILE=/run/secrets/wireguard_preshared_key \
|
||||||
|
AMNEZIAWG_PUBLIC_KEY= \
|
||||||
|
AMNEZIAWG_ALLOWED_IPS= \
|
||||||
|
AMNEZIAWG_PERSISTENT_KEEPALIVE_INTERVAL=0 \
|
||||||
|
AMNEZIAWG_ADDRESSES= \
|
||||||
|
AMNEZIAWG_ADDRESSES_SECRETFILE=/run/secrets/wireguard_addresses \
|
||||||
|
AMNEZIAWG_MTU= \
|
||||||
|
AMNEZIAWG_JC=0 \
|
||||||
|
AMNEZIAWG_JMIN=0 \
|
||||||
|
AMNEZIAWG_JMAX=0 \
|
||||||
|
AMNEZIAWG_S1=0 \
|
||||||
|
AMNEZIAWG_S2=0 \
|
||||||
|
AMNEZIAWG_S3=0 \
|
||||||
|
AMNEZIAWG_S4=0 \
|
||||||
|
AMNEZIAWG_H1= \
|
||||||
|
AMNEZIAWG_H2= \
|
||||||
|
AMNEZIAWG_H3= \
|
||||||
|
AMNEZIAWG_H4= \
|
||||||
|
AMNEZIAWG_I1= \
|
||||||
|
AMNEZIAWG_I2= \
|
||||||
|
AMNEZIAWG_I3= \
|
||||||
|
AMNEZIAWG_I4= \
|
||||||
|
AMNEZIAWG_I5= \
|
||||||
# Wireguard AmneziaWG userspace obfuscation (requires WIREGUARD_IMPLEMENTATION=amneziawg)
|
# Wireguard AmneziaWG userspace obfuscation (requires WIREGUARD_IMPLEMENTATION=amneziawg)
|
||||||
AMNEZIAWG_JC=0 \
|
AMNEZIAWG_JC=0 \
|
||||||
AMNEZIAWG_JMIN=0 \
|
AMNEZIAWG_JMIN=0 \
|
||||||
|
|||||||
@@ -67,6 +67,7 @@ Lightweight swiss-army-knife-like VPN client to multiple VPN service providers
|
|||||||
- For **Cyberghost**, **Private Internet Access**, **PrivateVPN**, **PureVPN**, **Torguard**, **VPN Unlimited** and **VyprVPN** using [the custom provider](https://github.com/qdm12/gluetun-wiki/blob/main/setup/providers/custom.md)
|
- For **Cyberghost**, **Private Internet Access**, **PrivateVPN**, **PureVPN**, **Torguard**, **VPN Unlimited** and **VyprVPN** using [the custom provider](https://github.com/qdm12/gluetun-wiki/blob/main/setup/providers/custom.md)
|
||||||
- For custom Wireguard configurations using [the custom provider](https://github.com/qdm12/gluetun-wiki/blob/main/setup/providers/custom.md)
|
- For custom Wireguard configurations using [the custom provider](https://github.com/qdm12/gluetun-wiki/blob/main/setup/providers/custom.md)
|
||||||
- More in progress, see [#134](https://github.com/qdm12/gluetun/issues/134)
|
- More in progress, see [#134](https://github.com/qdm12/gluetun/issues/134)
|
||||||
|
- Supports AmneziaWG only with the custom provider for now
|
||||||
- DNS over TLS baked in with service provider(s) of your choice
|
- DNS over TLS baked in with service provider(s) of your choice
|
||||||
- DNS fine blocking of malicious/ads/surveillance hostnames and IP addresses, with live update every 24 hours
|
- DNS fine blocking of malicious/ads/surveillance hostnames and IP addresses, with live update every 24 hours
|
||||||
- Choose the vpn network protocol, `udp` or `tcp`
|
- Choose the vpn network protocol, `udp` or `tcp`
|
||||||
|
|||||||
@@ -0,0 +1,22 @@
|
|||||||
|
package amneziawg
|
||||||
|
|
||||||
|
type Amneziawg struct {
|
||||||
|
logger Logger
|
||||||
|
settings Settings
|
||||||
|
netlink NetLinker
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(settings Settings, netlink NetLinker,
|
||||||
|
logger Logger,
|
||||||
|
) (a *Amneziawg, err error) {
|
||||||
|
settings.SetDefaults()
|
||||||
|
if err := settings.Check(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Amneziawg{
|
||||||
|
logger: logger,
|
||||||
|
settings: settings,
|
||||||
|
netlink: netlink,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,86 @@
|
|||||||
|
package amneziawg
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/qdm12/gluetun/internal/wireguard"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.zx2c4.com/wireguard/device"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_New(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
const validKeyString = "oMNSf/zJ0pt1ciy+qIRk8Rlyfs9accwuRLnKd85Yl1Q="
|
||||||
|
logger := NewMockLogger(nil)
|
||||||
|
netLinker := NewMockNetLinker(nil)
|
||||||
|
|
||||||
|
testCases := map[string]struct {
|
||||||
|
settings Settings
|
||||||
|
amneziawg *Amneziawg
|
||||||
|
err error
|
||||||
|
}{
|
||||||
|
"bad_settings": {
|
||||||
|
settings: Settings{
|
||||||
|
Wireguard: wireguard.Settings{
|
||||||
|
PrivateKey: "",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: wireguard.ErrPrivateKeyMissing,
|
||||||
|
},
|
||||||
|
"minimal valid settings": {
|
||||||
|
settings: Settings{
|
||||||
|
Wireguard: wireguard.Settings{
|
||||||
|
PrivateKey: validKeyString,
|
||||||
|
PublicKey: validKeyString,
|
||||||
|
Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 0),
|
||||||
|
Addresses: []netip.Prefix{
|
||||||
|
netip.PrefixFrom(netip.AddrFrom4([4]byte{5, 6, 7, 8}), 32),
|
||||||
|
},
|
||||||
|
FirewallMark: 100,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
amneziawg: &Amneziawg{
|
||||||
|
logger: logger,
|
||||||
|
netlink: netLinker,
|
||||||
|
settings: Settings{
|
||||||
|
Wireguard: wireguard.Settings{
|
||||||
|
InterfaceName: "wg0",
|
||||||
|
PrivateKey: validKeyString,
|
||||||
|
PublicKey: validKeyString,
|
||||||
|
Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 51820),
|
||||||
|
Addresses: []netip.Prefix{
|
||||||
|
netip.PrefixFrom(netip.AddrFrom4([4]byte{5, 6, 7, 8}), 32),
|
||||||
|
},
|
||||||
|
AllowedIPs: []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("0.0.0.0/0"),
|
||||||
|
},
|
||||||
|
FirewallMark: 100,
|
||||||
|
MTU: device.DefaultMTU,
|
||||||
|
IPv6: ptrTo(false),
|
||||||
|
Implementation: "auto",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, testCase := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
wireguard, err := New(testCase.settings, netLinker, logger)
|
||||||
|
|
||||||
|
if testCase.err != nil {
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Equal(t, testCase.err.Error(), err.Error())
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, testCase.amneziawg, wireguard)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
package amneziawg
|
||||||
|
|
||||||
|
func ptrTo[T any](v T) *T {
|
||||||
|
return &v
|
||||||
|
}
|
||||||
@@ -0,0 +1,11 @@
|
|||||||
|
package amneziawg
|
||||||
|
|
||||||
|
//go:generate mockgen -destination=log_mock_test.go -package amneziawg . Logger
|
||||||
|
|
||||||
|
type Logger interface {
|
||||||
|
Debug(s string)
|
||||||
|
Debugf(format string, args ...interface{})
|
||||||
|
Info(s string)
|
||||||
|
Error(s string)
|
||||||
|
Errorf(format string, args ...interface{})
|
||||||
|
}
|
||||||
@@ -0,0 +1,104 @@
|
|||||||
|
// Code generated by MockGen. DO NOT EDIT.
|
||||||
|
// Source: github.com/qdm12/gluetun/internal/amneziawg (interfaces: Logger)
|
||||||
|
|
||||||
|
// Package amneziawg is a generated GoMock package.
|
||||||
|
package amneziawg
|
||||||
|
|
||||||
|
import (
|
||||||
|
reflect "reflect"
|
||||||
|
|
||||||
|
gomock "github.com/golang/mock/gomock"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockLogger is a mock of Logger interface.
|
||||||
|
type MockLogger struct {
|
||||||
|
ctrl *gomock.Controller
|
||||||
|
recorder *MockLoggerMockRecorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockLoggerMockRecorder is the mock recorder for MockLogger.
|
||||||
|
type MockLoggerMockRecorder struct {
|
||||||
|
mock *MockLogger
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockLogger creates a new mock instance.
|
||||||
|
func NewMockLogger(ctrl *gomock.Controller) *MockLogger {
|
||||||
|
mock := &MockLogger{ctrl: ctrl}
|
||||||
|
mock.recorder = &MockLoggerMockRecorder{mock}
|
||||||
|
return mock
|
||||||
|
}
|
||||||
|
|
||||||
|
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||||
|
func (m *MockLogger) EXPECT() *MockLoggerMockRecorder {
|
||||||
|
return m.recorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// Debug mocks base method.
|
||||||
|
func (m *MockLogger) Debug(arg0 string) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
m.ctrl.Call(m, "Debug", arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Debug indicates an expected call of Debug.
|
||||||
|
func (mr *MockLoggerMockRecorder) Debug(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockLogger)(nil).Debug), arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Debugf mocks base method.
|
||||||
|
func (m *MockLogger) Debugf(arg0 string, arg1 ...interface{}) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
varargs := []interface{}{arg0}
|
||||||
|
for _, a := range arg1 {
|
||||||
|
varargs = append(varargs, a)
|
||||||
|
}
|
||||||
|
m.ctrl.Call(m, "Debugf", varargs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Debugf indicates an expected call of Debugf.
|
||||||
|
func (mr *MockLoggerMockRecorder) Debugf(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
varargs := append([]interface{}{arg0}, arg1...)
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debugf", reflect.TypeOf((*MockLogger)(nil).Debugf), varargs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error mocks base method.
|
||||||
|
func (m *MockLogger) Error(arg0 string) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
m.ctrl.Call(m, "Error", arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error indicates an expected call of Error.
|
||||||
|
func (mr *MockLoggerMockRecorder) Error(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Error", reflect.TypeOf((*MockLogger)(nil).Error), arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Errorf mocks base method.
|
||||||
|
func (m *MockLogger) Errorf(arg0 string, arg1 ...interface{}) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
varargs := []interface{}{arg0}
|
||||||
|
for _, a := range arg1 {
|
||||||
|
varargs = append(varargs, a)
|
||||||
|
}
|
||||||
|
m.ctrl.Call(m, "Errorf", varargs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Errorf indicates an expected call of Errorf.
|
||||||
|
func (mr *MockLoggerMockRecorder) Errorf(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
varargs := append([]interface{}{arg0}, arg1...)
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Errorf", reflect.TypeOf((*MockLogger)(nil).Errorf), varargs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Info mocks base method.
|
||||||
|
func (m *MockLogger) Info(arg0 string) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
m.ctrl.Call(m, "Info", arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Info indicates an expected call of Info.
|
||||||
|
func (mr *MockLoggerMockRecorder) Info(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Info", reflect.TypeOf((*MockLogger)(nil).Info), arg0)
|
||||||
|
}
|
||||||
@@ -0,0 +1,36 @@
|
|||||||
|
package amneziawg
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/qdm12/gluetun/internal/netlink"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:generate mockgen -destination=netlinker_mock_test.go -package amneziawg . NetLinker
|
||||||
|
|
||||||
|
type NetLinker interface {
|
||||||
|
AddrReplace(linkIndex uint32, addr netip.Prefix) error
|
||||||
|
Router
|
||||||
|
Ruler
|
||||||
|
Linker
|
||||||
|
IsWireguardSupported() (ok bool, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Router interface {
|
||||||
|
RouteList(family uint8) (routes []netlink.Route, err error)
|
||||||
|
RouteAdd(route netlink.Route) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type Ruler interface {
|
||||||
|
RuleAdd(rule netlink.Rule) error
|
||||||
|
RuleDel(rule netlink.Rule) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type Linker interface {
|
||||||
|
LinkAdd(link netlink.Link) (linkIndex uint32, err error)
|
||||||
|
LinkList() (links []netlink.Link, err error)
|
||||||
|
LinkByName(name string) (link netlink.Link, err error)
|
||||||
|
LinkSetUp(linkIndex uint32) error
|
||||||
|
LinkSetDown(linkIndex uint32) error
|
||||||
|
LinkDel(linkIndex uint32) error
|
||||||
|
}
|
||||||
@@ -0,0 +1,209 @@
|
|||||||
|
// Code generated by MockGen. DO NOT EDIT.
|
||||||
|
// Source: github.com/qdm12/gluetun/internal/amneziawg (interfaces: NetLinker)
|
||||||
|
|
||||||
|
// Package amneziawg is a generated GoMock package.
|
||||||
|
package amneziawg
|
||||||
|
|
||||||
|
import (
|
||||||
|
netip "net/netip"
|
||||||
|
reflect "reflect"
|
||||||
|
|
||||||
|
gomock "github.com/golang/mock/gomock"
|
||||||
|
netlink "github.com/qdm12/gluetun/internal/netlink"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockNetLinker is a mock of NetLinker interface.
|
||||||
|
type MockNetLinker struct {
|
||||||
|
ctrl *gomock.Controller
|
||||||
|
recorder *MockNetLinkerMockRecorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockNetLinkerMockRecorder is the mock recorder for MockNetLinker.
|
||||||
|
type MockNetLinkerMockRecorder struct {
|
||||||
|
mock *MockNetLinker
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockNetLinker creates a new mock instance.
|
||||||
|
func NewMockNetLinker(ctrl *gomock.Controller) *MockNetLinker {
|
||||||
|
mock := &MockNetLinker{ctrl: ctrl}
|
||||||
|
mock.recorder = &MockNetLinkerMockRecorder{mock}
|
||||||
|
return mock
|
||||||
|
}
|
||||||
|
|
||||||
|
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||||
|
func (m *MockNetLinker) EXPECT() *MockNetLinkerMockRecorder {
|
||||||
|
return m.recorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddrReplace mocks base method.
|
||||||
|
func (m *MockNetLinker) AddrReplace(arg0 uint32, arg1 netip.Prefix) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "AddrReplace", arg0, arg1)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddrReplace indicates an expected call of AddrReplace.
|
||||||
|
func (mr *MockNetLinkerMockRecorder) AddrReplace(arg0, arg1 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddrReplace", reflect.TypeOf((*MockNetLinker)(nil).AddrReplace), arg0, arg1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsWireguardSupported mocks base method.
|
||||||
|
func (m *MockNetLinker) IsWireguardSupported() (bool, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "IsWireguardSupported")
|
||||||
|
ret0, _ := ret[0].(bool)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsWireguardSupported indicates an expected call of IsWireguardSupported.
|
||||||
|
func (mr *MockNetLinkerMockRecorder) IsWireguardSupported() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsWireguardSupported", reflect.TypeOf((*MockNetLinker)(nil).IsWireguardSupported))
|
||||||
|
}
|
||||||
|
|
||||||
|
// LinkAdd mocks base method.
|
||||||
|
func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) (uint32, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "LinkAdd", arg0)
|
||||||
|
ret0, _ := ret[0].(uint32)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// LinkAdd indicates an expected call of LinkAdd.
|
||||||
|
func (mr *MockNetLinkerMockRecorder) LinkAdd(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkAdd", reflect.TypeOf((*MockNetLinker)(nil).LinkAdd), arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LinkByName mocks base method.
|
||||||
|
func (m *MockNetLinker) LinkByName(arg0 string) (netlink.Link, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "LinkByName", arg0)
|
||||||
|
ret0, _ := ret[0].(netlink.Link)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// LinkByName indicates an expected call of LinkByName.
|
||||||
|
func (mr *MockNetLinkerMockRecorder) LinkByName(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkByName", reflect.TypeOf((*MockNetLinker)(nil).LinkByName), arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LinkDel mocks base method.
|
||||||
|
func (m *MockNetLinker) LinkDel(arg0 uint32) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "LinkDel", arg0)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// LinkDel indicates an expected call of LinkDel.
|
||||||
|
func (mr *MockNetLinkerMockRecorder) LinkDel(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkDel", reflect.TypeOf((*MockNetLinker)(nil).LinkDel), arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LinkList mocks base method.
|
||||||
|
func (m *MockNetLinker) LinkList() ([]netlink.Link, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "LinkList")
|
||||||
|
ret0, _ := ret[0].([]netlink.Link)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// LinkList indicates an expected call of LinkList.
|
||||||
|
func (mr *MockNetLinkerMockRecorder) LinkList() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkList", reflect.TypeOf((*MockNetLinker)(nil).LinkList))
|
||||||
|
}
|
||||||
|
|
||||||
|
// LinkSetDown mocks base method.
|
||||||
|
func (m *MockNetLinker) LinkSetDown(arg0 uint32) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "LinkSetDown", arg0)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// LinkSetDown indicates an expected call of LinkSetDown.
|
||||||
|
func (mr *MockNetLinkerMockRecorder) LinkSetDown(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkSetDown", reflect.TypeOf((*MockNetLinker)(nil).LinkSetDown), arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LinkSetUp mocks base method.
|
||||||
|
func (m *MockNetLinker) LinkSetUp(arg0 uint32) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "LinkSetUp", arg0)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// LinkSetUp indicates an expected call of LinkSetUp.
|
||||||
|
func (mr *MockNetLinkerMockRecorder) LinkSetUp(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkSetUp", reflect.TypeOf((*MockNetLinker)(nil).LinkSetUp), arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RouteAdd mocks base method.
|
||||||
|
func (m *MockNetLinker) RouteAdd(arg0 netlink.Route) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "RouteAdd", arg0)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// RouteAdd indicates an expected call of RouteAdd.
|
||||||
|
func (mr *MockNetLinkerMockRecorder) RouteAdd(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RouteAdd", reflect.TypeOf((*MockNetLinker)(nil).RouteAdd), arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RouteList mocks base method.
|
||||||
|
func (m *MockNetLinker) RouteList(arg0 byte) ([]netlink.Route, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "RouteList", arg0)
|
||||||
|
ret0, _ := ret[0].([]netlink.Route)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// RouteList indicates an expected call of RouteList.
|
||||||
|
func (mr *MockNetLinkerMockRecorder) RouteList(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RouteList", reflect.TypeOf((*MockNetLinker)(nil).RouteList), arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RuleAdd mocks base method.
|
||||||
|
func (m *MockNetLinker) RuleAdd(arg0 netlink.Rule) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "RuleAdd", arg0)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// RuleAdd indicates an expected call of RuleAdd.
|
||||||
|
func (mr *MockNetLinkerMockRecorder) RuleAdd(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RuleAdd", reflect.TypeOf((*MockNetLinker)(nil).RuleAdd), arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RuleDel mocks base method.
|
||||||
|
func (m *MockNetLinker) RuleDel(arg0 netlink.Rule) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "RuleDel", arg0)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// RuleDel indicates an expected call of RuleDel.
|
||||||
|
func (mr *MockNetLinkerMockRecorder) RuleDel(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RuleDel", reflect.TypeOf((*MockNetLinker)(nil).RuleDel), arg0)
|
||||||
|
}
|
||||||
@@ -0,0 +1,133 @@
|
|||||||
|
package amneziawg
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
amneziaconn "github.com/amnezia-vpn/amneziawg-go/conn"
|
||||||
|
amneziadevice "github.com/amnezia-vpn/amneziawg-go/device"
|
||||||
|
amneziatun "github.com/amnezia-vpn/amneziawg-go/tun"
|
||||||
|
"github.com/qdm12/gluetun/internal/cleanup"
|
||||||
|
"github.com/qdm12/gluetun/internal/wireguard"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errTunNameMismatch = errors.New("TUN device name is mismatching")
|
||||||
|
errDeviceWaited = errors.New("device waited for")
|
||||||
|
)
|
||||||
|
|
||||||
|
// Run runs the amneziawg interface and waits until the context is done, then it cleans up the
|
||||||
|
// interface and returns any error that occurred during setup or waiting. It sends an error to
|
||||||
|
// waitError if any error occurs during setup or waiting, otherwise it sends nil when the context
|
||||||
|
// is done. It sends a signal to ready when the setup is complete and the interface is ready to use.
|
||||||
|
// See https://github.com/amnezia-vpn/amneziawg-go/blob/master/main.go
|
||||||
|
func (a *Amneziawg) Run(ctx context.Context, waitError chan<- error, ready chan<- struct{}) {
|
||||||
|
setup := func(ctx context.Context, cleanups *cleanup.Cleanups) (
|
||||||
|
linkIndex uint32, waitAndCleanup func() error, err error,
|
||||||
|
) {
|
||||||
|
return setupUserspace(ctx, a.settings.Wireguard.InterfaceName,
|
||||||
|
a.netlink, a.settings.Wireguard.MTU, cleanups, a.logger, a.settings)
|
||||||
|
}
|
||||||
|
|
||||||
|
wireguard.Run(ctx, waitError, ready, setup, a.settings.Wireguard, a.netlink, a.logger)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupUserspace(ctx context.Context,
|
||||||
|
interfaceName string, netLinker NetLinker, mtu uint32,
|
||||||
|
cleanups *cleanup.Cleanups, logger Logger,
|
||||||
|
settings Settings,
|
||||||
|
) (
|
||||||
|
linkIndex uint32, waitAndCleanup func() error, err error,
|
||||||
|
) {
|
||||||
|
tun, err := amneziatun.CreateTUN(interfaceName, int(mtu))
|
||||||
|
if err != nil {
|
||||||
|
return 0, nil, fmt.Errorf("creating TUN device: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cleanups.Add("closing TUN device", 7, tun.Close)
|
||||||
|
|
||||||
|
tunName, err := tun.Name()
|
||||||
|
if err != nil {
|
||||||
|
return 0, nil, fmt.Errorf("getting created TUN device name: %w", err)
|
||||||
|
} else if tunName != interfaceName {
|
||||||
|
return 0, nil, fmt.Errorf("%w: expected %q and got %q",
|
||||||
|
errTunNameMismatch, interfaceName, tunName)
|
||||||
|
}
|
||||||
|
|
||||||
|
link, err := netLinker.LinkByName(interfaceName)
|
||||||
|
if err != nil {
|
||||||
|
return 0, nil, fmt.Errorf("finding link %s: %w", interfaceName, err)
|
||||||
|
}
|
||||||
|
cleanups.Add("deleting link", 5, func() error {
|
||||||
|
return netLinker.LinkDel(link.Index)
|
||||||
|
})
|
||||||
|
|
||||||
|
bind := amneziaconn.NewDefaultBind()
|
||||||
|
cleanups.Add("closing bind", 7, bind.Close)
|
||||||
|
|
||||||
|
deviceLogger := amneziadevice.Logger{
|
||||||
|
Verbosef: logger.Debugf,
|
||||||
|
Errorf: logger.Errorf,
|
||||||
|
}
|
||||||
|
device := amneziadevice.NewDevice(tun, bind, &deviceLogger)
|
||||||
|
|
||||||
|
cleanups.Add("closing Wireguard device", 6, func() error {
|
||||||
|
device.Close()
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
uapiFile, err := wireguard.UAPIOpen(interfaceName)
|
||||||
|
if err != nil {
|
||||||
|
return 0, nil, fmt.Errorf("opening UAPI socket: %w", err)
|
||||||
|
}
|
||||||
|
cleanups.Add("closing UAPI file", 3, uapiFile.Close)
|
||||||
|
|
||||||
|
uapiListener, err := wireguard.UAPIListen(interfaceName, uapiFile)
|
||||||
|
if err != nil {
|
||||||
|
return 0, nil, fmt.Errorf("listening on UAPI socket: %w", err)
|
||||||
|
}
|
||||||
|
cleanups.Add("closing UAPI listener", 2, uapiListener.Close)
|
||||||
|
|
||||||
|
uapiConfig := settings.uapiConfig()
|
||||||
|
err = device.IpcSet(uapiConfig)
|
||||||
|
if err != nil {
|
||||||
|
return 0, nil, fmt.Errorf("setting amneziawg uapi config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// acceptAndHandle exits when uapiListener is closed
|
||||||
|
uapiAcceptErrorCh := make(chan error)
|
||||||
|
go acceptAndHandle(uapiListener, device, uapiAcceptErrorCh)
|
||||||
|
waitAndCleanup = func() error {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
err = ctx.Err()
|
||||||
|
case err = <-uapiAcceptErrorCh:
|
||||||
|
close(uapiAcceptErrorCh)
|
||||||
|
case <-device.Wait():
|
||||||
|
err = errDeviceWaited
|
||||||
|
}
|
||||||
|
|
||||||
|
cleanups.Cleanup(logger)
|
||||||
|
|
||||||
|
<-uapiAcceptErrorCh // wait for acceptAndHandle to exit
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return link.Index, waitAndCleanup, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func acceptAndHandle(uapi net.Listener, device *amneziadevice.Device,
|
||||||
|
uapiAcceptErrorCh chan<- error,
|
||||||
|
) {
|
||||||
|
for { // stopped by uapiFile.Close()
|
||||||
|
conn, err := uapi.Accept()
|
||||||
|
if err != nil {
|
||||||
|
uapiAcceptErrorCh <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go device.IpcHandle(conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,11 +1,14 @@
|
|||||||
package wireguard
|
package amneziawg
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/qdm12/gluetun/internal/wireguard"
|
||||||
)
|
)
|
||||||
|
|
||||||
type AmneziaSettings struct {
|
type Settings struct {
|
||||||
|
Wireguard wireguard.Settings
|
||||||
JunkPacketCount uint16
|
JunkPacketCount uint16
|
||||||
JunkPacketMin uint16
|
JunkPacketMin uint16
|
||||||
JunkPacketMax uint16
|
JunkPacketMax uint16
|
||||||
@@ -24,7 +27,7 @@ type AmneziaSettings struct {
|
|||||||
InitPacketI5 string
|
InitPacketI5 string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s AmneziaSettings) uapiConfig() string {
|
func (s Settings) uapiConfig() string {
|
||||||
uintFields := map[string]uint16{
|
uintFields := map[string]uint16{
|
||||||
"jc": s.JunkPacketCount,
|
"jc": s.JunkPacketCount,
|
||||||
"jmin": s.JunkPacketMin,
|
"jmin": s.JunkPacketMin,
|
||||||
@@ -56,3 +59,11 @@ func (s AmneziaSettings) uapiConfig() string {
|
|||||||
}
|
}
|
||||||
return strings.Join(lines, "\n")
|
return strings.Join(lines, "\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Settings) SetDefaults() {
|
||||||
|
s.Wireguard.SetDefaults()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Settings) Check() error {
|
||||||
|
return s.Wireguard.Check()
|
||||||
|
}
|
||||||
@@ -0,0 +1,51 @@
|
|||||||
|
package cleanup
|
||||||
|
|
||||||
|
import "sort"
|
||||||
|
|
||||||
|
type Cleanups []cleanup
|
||||||
|
|
||||||
|
type cleanup struct {
|
||||||
|
operation string
|
||||||
|
orderIndex uint
|
||||||
|
cleanup func() error
|
||||||
|
done bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add adds a cleanup function to the list of cleanups, with a description of the
|
||||||
|
// operation being cleaned up, and an order index that determines the order in which
|
||||||
|
// the cleanup functions are run. The lower the order index, the earlier the cleanup
|
||||||
|
// function is run.
|
||||||
|
func (c *Cleanups) Add(operation string, orderIndex uint,
|
||||||
|
cleanupFunc func() error,
|
||||||
|
) {
|
||||||
|
closer := cleanup{
|
||||||
|
operation: operation,
|
||||||
|
orderIndex: orderIndex,
|
||||||
|
cleanup: cleanupFunc,
|
||||||
|
}
|
||||||
|
*c = append(*c, closer)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cleanup runs the cleanup functions in the order of their orderIndex,
|
||||||
|
// and logs any error that occurs during cleanup.
|
||||||
|
// It can also be re-called in case a cleanup fails, and already cleaned up
|
||||||
|
// functions will not be re-run.
|
||||||
|
func (c *Cleanups) Cleanup(logger Logger) {
|
||||||
|
closers := *c
|
||||||
|
|
||||||
|
sort.Slice(closers, func(i, j int) bool {
|
||||||
|
return closers[i].orderIndex < closers[j].orderIndex
|
||||||
|
})
|
||||||
|
|
||||||
|
for i, closer := range closers {
|
||||||
|
if closer.done {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
closers[i].done = true
|
||||||
|
logger.Debug(closer.operation + "...")
|
||||||
|
err := closer.cleanup()
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("failed " + closer.operation + ": " + err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,57 @@
|
|||||||
|
package cleanup
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_Cleanups(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
|
||||||
|
var ACloseCalled, BCloseCalled, CCloseCalled bool
|
||||||
|
var (
|
||||||
|
AErr error
|
||||||
|
BErr = errors.New("B failed")
|
||||||
|
CErr = errors.New("C failed")
|
||||||
|
)
|
||||||
|
|
||||||
|
var cleanups Cleanups
|
||||||
|
cleanups.Add("cleaning up A", 5, func() error {
|
||||||
|
ACloseCalled = true
|
||||||
|
return AErr
|
||||||
|
})
|
||||||
|
|
||||||
|
cleanups.Add("cleaning up B", 3, func() error {
|
||||||
|
BCloseCalled = true
|
||||||
|
return BErr
|
||||||
|
})
|
||||||
|
|
||||||
|
cleanups.Add("cleaning up C", 2, func() error {
|
||||||
|
CCloseCalled = true
|
||||||
|
return CErr
|
||||||
|
})
|
||||||
|
|
||||||
|
logger := NewMockLogger(ctrl)
|
||||||
|
prevCall := logger.EXPECT().Debug("cleaning up C...")
|
||||||
|
prevCall = logger.EXPECT().Error("failed cleaning up C: C failed").After(prevCall)
|
||||||
|
prevCall = logger.EXPECT().Debug("cleaning up B...").After(prevCall)
|
||||||
|
prevCall = logger.EXPECT().Error("failed cleaning up B: B failed").After(prevCall)
|
||||||
|
logger.EXPECT().Debug("cleaning up A...").After(prevCall)
|
||||||
|
|
||||||
|
cleanups.Cleanup(logger)
|
||||||
|
|
||||||
|
cleanups.Cleanup(logger) // run twice should not close already closed
|
||||||
|
|
||||||
|
for _, cleanup := range cleanups {
|
||||||
|
assert.True(t, cleanup.done)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.True(t, ACloseCalled)
|
||||||
|
assert.True(t, BCloseCalled)
|
||||||
|
assert.True(t, CCloseCalled)
|
||||||
|
}
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
package cleanup
|
||||||
|
|
||||||
|
type Logger interface {
|
||||||
|
Debug(string)
|
||||||
|
Error(string)
|
||||||
|
}
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
package cleanup
|
||||||
|
|
||||||
|
//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . Logger
|
||||||
@@ -0,0 +1,58 @@
|
|||||||
|
// Code generated by MockGen. DO NOT EDIT.
|
||||||
|
// Source: github.com/qdm12/gluetun/internal/cleanup (interfaces: Logger)
|
||||||
|
|
||||||
|
// Package cleanup is a generated GoMock package.
|
||||||
|
package cleanup
|
||||||
|
|
||||||
|
import (
|
||||||
|
reflect "reflect"
|
||||||
|
|
||||||
|
gomock "github.com/golang/mock/gomock"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockLogger is a mock of Logger interface.
|
||||||
|
type MockLogger struct {
|
||||||
|
ctrl *gomock.Controller
|
||||||
|
recorder *MockLoggerMockRecorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockLoggerMockRecorder is the mock recorder for MockLogger.
|
||||||
|
type MockLoggerMockRecorder struct {
|
||||||
|
mock *MockLogger
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockLogger creates a new mock instance.
|
||||||
|
func NewMockLogger(ctrl *gomock.Controller) *MockLogger {
|
||||||
|
mock := &MockLogger{ctrl: ctrl}
|
||||||
|
mock.recorder = &MockLoggerMockRecorder{mock}
|
||||||
|
return mock
|
||||||
|
}
|
||||||
|
|
||||||
|
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||||
|
func (m *MockLogger) EXPECT() *MockLoggerMockRecorder {
|
||||||
|
return m.recorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// Debug mocks base method.
|
||||||
|
func (m *MockLogger) Debug(arg0 string) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
m.ctrl.Call(m, "Debug", arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Debug indicates an expected call of Debug.
|
||||||
|
func (mr *MockLoggerMockRecorder) Debug(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockLogger)(nil).Debug), arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error mocks base method.
|
||||||
|
func (m *MockLogger) Error(arg0 string) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
m.ctrl.Call(m, "Error", arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error indicates an expected call of Error.
|
||||||
|
func (mr *MockLoggerMockRecorder) Error(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Error", reflect.TypeOf((*MockLogger)(nil).Error), arg0)
|
||||||
|
}
|
||||||
@@ -12,33 +12,42 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type AmneziaWg struct {
|
type AmneziaWg struct {
|
||||||
JunkPacketCount *uint16 `json:"junk_packet_count"`
|
// Wireguard contains the configuration for Wireguard, given
|
||||||
JunkPacketMin *uint16 `json:"junk_packet_min"`
|
// AmneziaWg is based on Wireguard
|
||||||
JunkPacketMax *uint16 `json:"junk_packet_max"`
|
Wireguard Wireguard `json:"wireguard"`
|
||||||
PaddingS1 *uint16 `json:"padding_s1"`
|
JunkPacketCount *uint16 `json:"junk_packet_count"`
|
||||||
PaddingS2 *uint16 `json:"padding_s2"`
|
JunkPacketMin *uint16 `json:"junk_packet_min"`
|
||||||
PaddingS3 *uint16 `json:"padding_s3"`
|
JunkPacketMax *uint16 `json:"junk_packet_max"`
|
||||||
PaddingS4 *uint16 `json:"padding_s4"`
|
PaddingS1 *uint16 `json:"padding_s1"`
|
||||||
HeaderH1 *string `json:"header_h1"`
|
PaddingS2 *uint16 `json:"padding_s2"`
|
||||||
HeaderH2 *string `json:"header_h2"`
|
PaddingS3 *uint16 `json:"padding_s3"`
|
||||||
HeaderH3 *string `json:"header_h3"`
|
PaddingS4 *uint16 `json:"padding_s4"`
|
||||||
HeaderH4 *string `json:"header_h4"`
|
HeaderH1 *string `json:"header_h1"`
|
||||||
InitPacketI1 *string `json:"init_packet_i1"`
|
HeaderH2 *string `json:"header_h2"`
|
||||||
InitPacketI2 *string `json:"init_packet_i2"`
|
HeaderH3 *string `json:"header_h3"`
|
||||||
InitPacketI3 *string `json:"init_packet_i3"`
|
HeaderH4 *string `json:"header_h4"`
|
||||||
InitPacketI4 *string `json:"init_packet_i4"`
|
InitPacketI1 *string `json:"init_packet_i1"`
|
||||||
InitPacketI5 *string `json:"init_packet_i5"`
|
InitPacketI2 *string `json:"init_packet_i2"`
|
||||||
|
InitPacketI3 *string `json:"init_packet_i3"`
|
||||||
|
InitPacketI4 *string `json:"init_packet_i4"`
|
||||||
|
InitPacketI5 *string `json:"init_packet_i5"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *AmneziaWg) read(r *reader.Reader) (err error) {
|
func (a *AmneziaWg) read(r *reader.Reader) (err error) {
|
||||||
|
const amneziawg = true
|
||||||
|
err = a.Wireguard.read(r, amneziawg)
|
||||||
|
if err != nil {
|
||||||
|
return err // do not wrap this error
|
||||||
|
}
|
||||||
|
|
||||||
uint16Fields := map[string]**uint16{
|
uint16Fields := map[string]**uint16{
|
||||||
"AMNEZIAWG_JC": &s.JunkPacketCount,
|
"AMNEZIAWG_JC": &a.JunkPacketCount,
|
||||||
"AMNEZIAWG_JMIN": &s.JunkPacketMin,
|
"AMNEZIAWG_JMIN": &a.JunkPacketMin,
|
||||||
"AMNEZIAWG_JMAX": &s.JunkPacketMax,
|
"AMNEZIAWG_JMAX": &a.JunkPacketMax,
|
||||||
"AMNEZIAWG_S1": &s.PaddingS1,
|
"AMNEZIAWG_S1": &a.PaddingS1,
|
||||||
"AMNEZIAWG_S2": &s.PaddingS2,
|
"AMNEZIAWG_S2": &a.PaddingS2,
|
||||||
"AMNEZIAWG_S3": &s.PaddingS3,
|
"AMNEZIAWG_S3": &a.PaddingS3,
|
||||||
"AMNEZIAWG_S4": &s.PaddingS4,
|
"AMNEZIAWG_S4": &a.PaddingS4,
|
||||||
}
|
}
|
||||||
for key, dst := range uint16Fields {
|
for key, dst := range uint16Fields {
|
||||||
*dst, err = r.Uint16Ptr(key)
|
*dst, err = r.Uint16Ptr(key)
|
||||||
@@ -47,15 +56,15 @@ func (s *AmneziaWg) read(r *reader.Reader) (err error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
stringFields := map[string]**string{
|
stringFields := map[string]**string{
|
||||||
"AMNEZIAWG_H1": &s.HeaderH1,
|
"AMNEZIAWG_H1": &a.HeaderH1,
|
||||||
"AMNEZIAWG_H2": &s.HeaderH2,
|
"AMNEZIAWG_H2": &a.HeaderH2,
|
||||||
"AMNEZIAWG_H3": &s.HeaderH3,
|
"AMNEZIAWG_H3": &a.HeaderH3,
|
||||||
"AMNEZIAWG_H4": &s.HeaderH4,
|
"AMNEZIAWG_H4": &a.HeaderH4,
|
||||||
"AMNEZIAWG_I1": &s.InitPacketI1,
|
"AMNEZIAWG_I1": &a.InitPacketI1,
|
||||||
"AMNEZIAWG_I2": &s.InitPacketI2,
|
"AMNEZIAWG_I2": &a.InitPacketI2,
|
||||||
"AMNEZIAWG_I3": &s.InitPacketI3,
|
"AMNEZIAWG_I3": &a.InitPacketI3,
|
||||||
"AMNEZIAWG_I4": &s.InitPacketI4,
|
"AMNEZIAWG_I4": &a.InitPacketI4,
|
||||||
"AMNEZIAWG_I5": &s.InitPacketI5,
|
"AMNEZIAWG_I5": &a.InitPacketI5,
|
||||||
}
|
}
|
||||||
opt := reader.ForceLowercase(false)
|
opt := reader.ForceLowercase(false)
|
||||||
for key, dst := range stringFields {
|
for key, dst := range stringFields {
|
||||||
@@ -64,80 +73,84 @@ func (s *AmneziaWg) read(r *reader.Reader) (err error) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s AmneziaWg) copy() (copied AmneziaWg) {
|
func (a AmneziaWg) copy() (copied AmneziaWg) {
|
||||||
return AmneziaWg{
|
return AmneziaWg{
|
||||||
JunkPacketCount: gosettings.CopyPointer(s.JunkPacketCount),
|
Wireguard: a.Wireguard.copy(),
|
||||||
JunkPacketMin: gosettings.CopyPointer(s.JunkPacketMin),
|
JunkPacketCount: gosettings.CopyPointer(a.JunkPacketCount),
|
||||||
JunkPacketMax: gosettings.CopyPointer(s.JunkPacketMax),
|
JunkPacketMin: gosettings.CopyPointer(a.JunkPacketMin),
|
||||||
PaddingS1: gosettings.CopyPointer(s.PaddingS1),
|
JunkPacketMax: gosettings.CopyPointer(a.JunkPacketMax),
|
||||||
PaddingS2: gosettings.CopyPointer(s.PaddingS2),
|
PaddingS1: gosettings.CopyPointer(a.PaddingS1),
|
||||||
PaddingS3: gosettings.CopyPointer(s.PaddingS3),
|
PaddingS2: gosettings.CopyPointer(a.PaddingS2),
|
||||||
PaddingS4: gosettings.CopyPointer(s.PaddingS4),
|
PaddingS3: gosettings.CopyPointer(a.PaddingS3),
|
||||||
HeaderH1: gosettings.CopyPointer(s.HeaderH1),
|
PaddingS4: gosettings.CopyPointer(a.PaddingS4),
|
||||||
HeaderH2: gosettings.CopyPointer(s.HeaderH2),
|
HeaderH1: gosettings.CopyPointer(a.HeaderH1),
|
||||||
HeaderH3: gosettings.CopyPointer(s.HeaderH3),
|
HeaderH2: gosettings.CopyPointer(a.HeaderH2),
|
||||||
HeaderH4: gosettings.CopyPointer(s.HeaderH4),
|
HeaderH3: gosettings.CopyPointer(a.HeaderH3),
|
||||||
InitPacketI1: gosettings.CopyPointer(s.InitPacketI1),
|
HeaderH4: gosettings.CopyPointer(a.HeaderH4),
|
||||||
InitPacketI2: gosettings.CopyPointer(s.InitPacketI2),
|
InitPacketI1: gosettings.CopyPointer(a.InitPacketI1),
|
||||||
InitPacketI3: gosettings.CopyPointer(s.InitPacketI3),
|
InitPacketI2: gosettings.CopyPointer(a.InitPacketI2),
|
||||||
InitPacketI4: gosettings.CopyPointer(s.InitPacketI4),
|
InitPacketI3: gosettings.CopyPointer(a.InitPacketI3),
|
||||||
InitPacketI5: gosettings.CopyPointer(s.InitPacketI5),
|
InitPacketI4: gosettings.CopyPointer(a.InitPacketI4),
|
||||||
|
InitPacketI5: gosettings.CopyPointer(a.InitPacketI5),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//nolint:dupl
|
func (a *AmneziaWg) overrideWith(other AmneziaWg) {
|
||||||
func (s *AmneziaWg) overrideWith(other AmneziaWg) {
|
a.Wireguard.overrideWith(other.Wireguard)
|
||||||
s.JunkPacketCount = gosettings.OverrideWithPointer(s.JunkPacketCount, other.JunkPacketCount)
|
a.JunkPacketCount = gosettings.OverrideWithPointer(a.JunkPacketCount, other.JunkPacketCount)
|
||||||
s.JunkPacketMin = gosettings.OverrideWithPointer(s.JunkPacketMin, other.JunkPacketMin)
|
a.JunkPacketMin = gosettings.OverrideWithPointer(a.JunkPacketMin, other.JunkPacketMin)
|
||||||
s.JunkPacketMax = gosettings.OverrideWithPointer(s.JunkPacketMax, other.JunkPacketMax)
|
a.JunkPacketMax = gosettings.OverrideWithPointer(a.JunkPacketMax, other.JunkPacketMax)
|
||||||
s.PaddingS1 = gosettings.OverrideWithPointer(s.PaddingS1, other.PaddingS1)
|
a.PaddingS1 = gosettings.OverrideWithPointer(a.PaddingS1, other.PaddingS1)
|
||||||
s.PaddingS2 = gosettings.OverrideWithPointer(s.PaddingS2, other.PaddingS2)
|
a.PaddingS2 = gosettings.OverrideWithPointer(a.PaddingS2, other.PaddingS2)
|
||||||
s.PaddingS3 = gosettings.OverrideWithPointer(s.PaddingS3, other.PaddingS3)
|
a.PaddingS3 = gosettings.OverrideWithPointer(a.PaddingS3, other.PaddingS3)
|
||||||
s.PaddingS4 = gosettings.OverrideWithPointer(s.PaddingS4, other.PaddingS4)
|
a.PaddingS4 = gosettings.OverrideWithPointer(a.PaddingS4, other.PaddingS4)
|
||||||
s.HeaderH1 = gosettings.OverrideWithPointer(s.HeaderH1, other.HeaderH1)
|
a.HeaderH1 = gosettings.OverrideWithPointer(a.HeaderH1, other.HeaderH1)
|
||||||
s.HeaderH2 = gosettings.OverrideWithPointer(s.HeaderH2, other.HeaderH2)
|
a.HeaderH2 = gosettings.OverrideWithPointer(a.HeaderH2, other.HeaderH2)
|
||||||
s.HeaderH3 = gosettings.OverrideWithPointer(s.HeaderH3, other.HeaderH3)
|
a.HeaderH3 = gosettings.OverrideWithPointer(a.HeaderH3, other.HeaderH3)
|
||||||
s.HeaderH4 = gosettings.OverrideWithPointer(s.HeaderH4, other.HeaderH4)
|
a.HeaderH4 = gosettings.OverrideWithPointer(a.HeaderH4, other.HeaderH4)
|
||||||
s.InitPacketI1 = gosettings.OverrideWithPointer(s.InitPacketI1, other.InitPacketI1)
|
a.InitPacketI1 = gosettings.OverrideWithPointer(a.InitPacketI1, other.InitPacketI1)
|
||||||
s.InitPacketI2 = gosettings.OverrideWithPointer(s.InitPacketI2, other.InitPacketI2)
|
a.InitPacketI2 = gosettings.OverrideWithPointer(a.InitPacketI2, other.InitPacketI2)
|
||||||
s.InitPacketI3 = gosettings.OverrideWithPointer(s.InitPacketI3, other.InitPacketI3)
|
a.InitPacketI3 = gosettings.OverrideWithPointer(a.InitPacketI3, other.InitPacketI3)
|
||||||
s.InitPacketI4 = gosettings.OverrideWithPointer(s.InitPacketI4, other.InitPacketI4)
|
a.InitPacketI4 = gosettings.OverrideWithPointer(a.InitPacketI4, other.InitPacketI4)
|
||||||
s.InitPacketI5 = gosettings.OverrideWithPointer(s.InitPacketI5, other.InitPacketI5)
|
a.InitPacketI5 = gosettings.OverrideWithPointer(a.InitPacketI5, other.InitPacketI5)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *AmneziaWg) setDefaults() {
|
func (a *AmneziaWg) setDefaults(vpnProvider string) {
|
||||||
s.JunkPacketCount = gosettings.DefaultPointer(s.JunkPacketCount, 0)
|
a.Wireguard.setDefaults(vpnProvider)
|
||||||
s.JunkPacketMin = gosettings.DefaultPointer(s.JunkPacketMin, 0)
|
a.Wireguard.Implementation = "userspace" // unused except in logs
|
||||||
s.JunkPacketMax = gosettings.DefaultPointer(s.JunkPacketMax, 0)
|
a.JunkPacketCount = gosettings.DefaultPointer(a.JunkPacketCount, 0)
|
||||||
s.PaddingS1 = gosettings.DefaultPointer(s.PaddingS1, 0)
|
a.JunkPacketMin = gosettings.DefaultPointer(a.JunkPacketMin, 0)
|
||||||
s.PaddingS2 = gosettings.DefaultPointer(s.PaddingS2, 0)
|
a.JunkPacketMax = gosettings.DefaultPointer(a.JunkPacketMax, 0)
|
||||||
s.PaddingS3 = gosettings.DefaultPointer(s.PaddingS3, 0)
|
a.PaddingS1 = gosettings.DefaultPointer(a.PaddingS1, 0)
|
||||||
s.PaddingS4 = gosettings.DefaultPointer(s.PaddingS4, 0)
|
a.PaddingS2 = gosettings.DefaultPointer(a.PaddingS2, 0)
|
||||||
s.HeaderH1 = gosettings.DefaultPointer(s.HeaderH1, "")
|
a.PaddingS3 = gosettings.DefaultPointer(a.PaddingS3, 0)
|
||||||
s.HeaderH2 = gosettings.DefaultPointer(s.HeaderH2, "")
|
a.PaddingS4 = gosettings.DefaultPointer(a.PaddingS4, 0)
|
||||||
s.HeaderH3 = gosettings.DefaultPointer(s.HeaderH3, "")
|
a.HeaderH1 = gosettings.DefaultPointer(a.HeaderH1, "")
|
||||||
s.HeaderH4 = gosettings.DefaultPointer(s.HeaderH4, "")
|
a.HeaderH2 = gosettings.DefaultPointer(a.HeaderH2, "")
|
||||||
s.InitPacketI1 = gosettings.DefaultPointer(s.InitPacketI1, "")
|
a.HeaderH3 = gosettings.DefaultPointer(a.HeaderH3, "")
|
||||||
s.InitPacketI2 = gosettings.DefaultPointer(s.InitPacketI2, "")
|
a.HeaderH4 = gosettings.DefaultPointer(a.HeaderH4, "")
|
||||||
s.InitPacketI3 = gosettings.DefaultPointer(s.InitPacketI3, "")
|
a.InitPacketI1 = gosettings.DefaultPointer(a.InitPacketI1, "")
|
||||||
s.InitPacketI4 = gosettings.DefaultPointer(s.InitPacketI4, "")
|
a.InitPacketI2 = gosettings.DefaultPointer(a.InitPacketI2, "")
|
||||||
s.InitPacketI5 = gosettings.DefaultPointer(s.InitPacketI5, "")
|
a.InitPacketI3 = gosettings.DefaultPointer(a.InitPacketI3, "")
|
||||||
|
a.InitPacketI4 = gosettings.DefaultPointer(a.InitPacketI4, "")
|
||||||
|
a.InitPacketI5 = gosettings.DefaultPointer(a.InitPacketI5, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s AmneziaWg) toLinesNode() (node *gotree.Node) {
|
func (a AmneziaWg) toLinesNode() (node *gotree.Node) {
|
||||||
node = gotree.New("Amneziawg parameters:")
|
node = gotree.New("AmneziaWG settings:")
|
||||||
|
node.AppendNode(a.Wireguard.toLinesNode())
|
||||||
|
|
||||||
uintFields := []struct {
|
uintFields := []struct {
|
||||||
key string
|
key string
|
||||||
val *uint16
|
val *uint16
|
||||||
}{
|
}{
|
||||||
{"jc", s.JunkPacketCount},
|
{"JC", a.JunkPacketCount},
|
||||||
{"jmin", s.JunkPacketMin},
|
{"JMIN", a.JunkPacketMin},
|
||||||
{"jmax", s.JunkPacketMax},
|
{"JMAX", a.JunkPacketMax},
|
||||||
{"s1", s.PaddingS1},
|
{"S1", a.PaddingS1},
|
||||||
{"s2", s.PaddingS2},
|
{"S2", a.PaddingS2},
|
||||||
{"s3", s.PaddingS3},
|
{"S3", a.PaddingS3},
|
||||||
{"s4", s.PaddingS4},
|
{"S4", a.PaddingS4},
|
||||||
}
|
}
|
||||||
for _, f := range uintFields {
|
for _, f := range uintFields {
|
||||||
node.Appendf("%s: %d", f.key, *f.val)
|
node.Appendf("%s: %d", f.key, *f.val)
|
||||||
@@ -147,15 +160,15 @@ func (s AmneziaWg) toLinesNode() (node *gotree.Node) {
|
|||||||
key string
|
key string
|
||||||
val *string
|
val *string
|
||||||
}{
|
}{
|
||||||
{"h1", s.HeaderH1},
|
{"H1", a.HeaderH1},
|
||||||
{"h2", s.HeaderH2},
|
{"H2", a.HeaderH2},
|
||||||
{"h3", s.HeaderH3},
|
{"H3", a.HeaderH3},
|
||||||
{"h4", s.HeaderH4},
|
{"H4", a.HeaderH4},
|
||||||
{"i1", s.InitPacketI1},
|
{"I1", a.InitPacketI1},
|
||||||
{"i2", s.InitPacketI2},
|
{"I2", a.InitPacketI2},
|
||||||
{"i3", s.InitPacketI3},
|
{"I3", a.InitPacketI3},
|
||||||
{"i4", s.InitPacketI4},
|
{"I4", a.InitPacketI4},
|
||||||
{"i5", s.InitPacketI5},
|
{"I5", a.InitPacketI5},
|
||||||
}
|
}
|
||||||
for _, f := range stringFields {
|
for _, f := range stringFields {
|
||||||
node.Appendf("%s: %s", f.key, *f.val)
|
node.Appendf("%s: %s", f.key, *f.val)
|
||||||
@@ -165,33 +178,40 @@ func (s AmneziaWg) toLinesNode() (node *gotree.Node) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrJunkPacketBounds = errors.New("junk packet minimum must be lower than or equal to maximum")
|
ErrAmenziawgImplementationNotValid = errors.New("AmneziaWG implementation is not valid")
|
||||||
ErrJunkPacketMinMaxNotSet = errors.New("junk packet min and max must be set when junk packet count is set")
|
ErrJunkPacketBounds = errors.New("junk packet minimum must be lower than or equal to maximum")
|
||||||
ErrJunkPacketCountNotSet = errors.New("junk packet count must be set when junk packet min or max is set")
|
ErrJunkPacketMinMaxNotSet = errors.New("junk packet min and max must be set when junk packet count is set")
|
||||||
ErrHeaderRangeMalformed = errors.New("header range is malformed")
|
ErrJunkPacketCountNotSet = errors.New("junk packet count must be set when junk packet min or max is set")
|
||||||
|
ErrHeaderRangeMalformed = errors.New("header range is malformed")
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s AmneziaWg) validate() error {
|
func (a AmneziaWg) validate(vpnProvider string, ipv6Supported bool) error {
|
||||||
if *s.JunkPacketCount == 0 {
|
const amneziaWG = true
|
||||||
if *s.JunkPacketMin != 0 || *s.JunkPacketMax != 0 {
|
err := a.Wireguard.validate(vpnProvider, ipv6Supported, amneziaWG)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("wireguard settings: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if *a.JunkPacketCount == 0 {
|
||||||
|
if *a.JunkPacketMin != 0 || *a.JunkPacketMax != 0 {
|
||||||
return fmt.Errorf("%w: jc=%d and jmin=%d and jmax=%d",
|
return fmt.Errorf("%w: jc=%d and jmin=%d and jmax=%d",
|
||||||
ErrJunkPacketCountNotSet, s.JunkPacketCount, *s.JunkPacketMin, *s.JunkPacketMax)
|
ErrJunkPacketCountNotSet, a.JunkPacketCount, *a.JunkPacketMin, *a.JunkPacketMax)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if *s.JunkPacketMin == 0 || *s.JunkPacketMax == 0 {
|
if *a.JunkPacketMin == 0 || *a.JunkPacketMax == 0 {
|
||||||
return fmt.Errorf("%w: jc=%d and jmin=%d and jmax=%d",
|
return fmt.Errorf("%w: jc=%d and jmin=%d and jmax=%d",
|
||||||
ErrJunkPacketMinMaxNotSet, s.JunkPacketCount, *s.JunkPacketMin, *s.JunkPacketMax)
|
ErrJunkPacketMinMaxNotSet, a.JunkPacketCount, *a.JunkPacketMin, *a.JunkPacketMax)
|
||||||
} else if *s.JunkPacketMin > *s.JunkPacketMax {
|
} else if *a.JunkPacketMin > *a.JunkPacketMax {
|
||||||
return fmt.Errorf("%w: jmin=%d and jmax=%d",
|
return fmt.Errorf("%w: jmin=%d and jmax=%d",
|
||||||
ErrJunkPacketBounds, *s.JunkPacketMin, *s.JunkPacketMax)
|
ErrJunkPacketBounds, *a.JunkPacketMin, *a.JunkPacketMax)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
nameToHeaderRange := map[string]string{
|
nameToHeaderRange := map[string]string{
|
||||||
"h1": *s.HeaderH1,
|
"h1": *a.HeaderH1,
|
||||||
"h2": *s.HeaderH2,
|
"h2": *a.HeaderH2,
|
||||||
"h3": *s.HeaderH3,
|
"h3": *a.HeaderH3,
|
||||||
"h4": *s.HeaderH4,
|
"h4": *a.HeaderH4,
|
||||||
}
|
}
|
||||||
for name, headerRange := range nameToHeaderRange {
|
for name, headerRange := range nameToHeaderRange {
|
||||||
if headerRange == "" {
|
if headerRange == "" {
|
||||||
|
|||||||
@@ -268,8 +268,6 @@ func (o *OpenVPN) copy() (copied OpenVPN) {
|
|||||||
// overrideWith overrides fields of the receiver
|
// overrideWith overrides fields of the receiver
|
||||||
// settings object with any field set in the other
|
// settings object with any field set in the other
|
||||||
// settings.
|
// settings.
|
||||||
//
|
|
||||||
//nolint:dupl
|
|
||||||
func (o *OpenVPN) overrideWith(other OpenVPN) {
|
func (o *OpenVPN) overrideWith(other OpenVPN) {
|
||||||
o.Version = gosettings.OverrideWithComparable(o.Version, other.Version)
|
o.Version = gosettings.OverrideWithComparable(o.Version, other.Version)
|
||||||
o.User = gosettings.OverrideWithPointer(o.User, other.User)
|
o.User = gosettings.OverrideWithPointer(o.User, other.User)
|
||||||
|
|||||||
@@ -30,7 +30,10 @@ type Provider struct {
|
|||||||
func (p *Provider) validate(vpnType string, filterChoicesGetter FilterChoicesGetter, warner Warner) (err error) {
|
func (p *Provider) validate(vpnType string, filterChoicesGetter FilterChoicesGetter, warner Warner) (err error) {
|
||||||
// Validate Name
|
// Validate Name
|
||||||
var validNames []string
|
var validNames []string
|
||||||
if vpnType == vpn.OpenVPN {
|
switch vpnType {
|
||||||
|
case vpn.AmneziaWg:
|
||||||
|
validNames = []string{providers.Custom}
|
||||||
|
case vpn.OpenVPN:
|
||||||
validNames = providers.AllWithCustom()
|
validNames = providers.AllWithCustom()
|
||||||
validNames = append(validNames, "pia") // Retro-compatibility
|
validNames = append(validNames, "pia") // Retro-compatibility
|
||||||
// Remove Mullvad since it no longer supports OpenVPN as of January 15th, 2026
|
// Remove Mullvad since it no longer supports OpenVPN as of January 15th, 2026
|
||||||
@@ -38,7 +41,7 @@ func (p *Provider) validate(vpnType string, filterChoicesGetter FilterChoicesGet
|
|||||||
validNames[mullvadIndex], validNames[len(validNames)-1] = validNames[len(validNames)-1], validNames[mullvadIndex]
|
validNames[mullvadIndex], validNames[len(validNames)-1] = validNames[len(validNames)-1], validNames[mullvadIndex]
|
||||||
validNames = validNames[:len(validNames)-1]
|
validNames = validNames[:len(validNames)-1]
|
||||||
sort.Strings(validNames)
|
sort.Strings(validNames)
|
||||||
} else { // Wireguard
|
case vpn.Wireguard:
|
||||||
validNames = []string{
|
validNames = []string{
|
||||||
providers.Airvpn,
|
providers.Airvpn,
|
||||||
providers.Custom,
|
providers.Custom,
|
||||||
@@ -52,7 +55,7 @@ func (p *Provider) validate(vpnType string, filterChoicesGetter FilterChoicesGet
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err = validate.IsOneOf(p.Name, validNames...); err != nil {
|
if err = validate.IsOneOf(p.Name, validNames...); err != nil {
|
||||||
return fmt.Errorf("%w for Wireguard: %w", ErrVPNProviderNameNotValid, err)
|
return fmt.Errorf("%w for %s: %w", ErrVPNProviderNameNotValid, vpnType, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = p.ServerSelection.validate(p.Name, filterChoicesGetter, warner)
|
err = p.ServerSelection.validate(p.Name, filterChoicesGetter, warner)
|
||||||
|
|||||||
@@ -87,7 +87,7 @@ func (ss *ServerSelection) validate(vpnServiceProvider string,
|
|||||||
filterChoicesGetter FilterChoicesGetter, warner Warner,
|
filterChoicesGetter FilterChoicesGetter, warner Warner,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
switch ss.VPN {
|
switch ss.VPN {
|
||||||
case vpn.OpenVPN, vpn.Wireguard:
|
case vpn.AmneziaWg, vpn.OpenVPN, vpn.Wireguard:
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("%w: %s", ErrVPNTypeNotValid, ss.VPN)
|
return fmt.Errorf("%w: %s", ErrVPNTypeNotValid, ss.VPN)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ type VPN struct {
|
|||||||
// empty string in the internal state.
|
// empty string in the internal state.
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
Provider Provider `json:"provider"`
|
Provider Provider `json:"provider"`
|
||||||
|
AmneziaWg AmneziaWg `json:"amneziawg"`
|
||||||
OpenVPN OpenVPN `json:"openvpn"`
|
OpenVPN OpenVPN `json:"openvpn"`
|
||||||
Wireguard Wireguard `json:"wireguard"`
|
Wireguard Wireguard `json:"wireguard"`
|
||||||
PMTUD PMTUD `json:"pmtud"`
|
PMTUD PMTUD `json:"pmtud"`
|
||||||
@@ -29,10 +30,12 @@ type VPN struct {
|
|||||||
DownCommand *string `json:"down_command"`
|
DownCommand *string `json:"down_command"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate validates VPN settings, using the filter choices getter (aka servers data storage),
|
||||||
|
// and if IPv6 is supported or not.
|
||||||
// TODO v4 remove pointer for receiver (because of Surfshark).
|
// TODO v4 remove pointer for receiver (because of Surfshark).
|
||||||
func (v *VPN) Validate(filterChoicesGetter FilterChoicesGetter, ipv6Supported bool, warner Warner) (err error) {
|
func (v *VPN) Validate(filterChoicesGetter FilterChoicesGetter, ipv6Supported bool, warner Warner) (err error) {
|
||||||
// Validate Type
|
// Validate Type
|
||||||
validVPNTypes := []string{vpn.OpenVPN, vpn.Wireguard}
|
validVPNTypes := []string{vpn.AmneziaWg, vpn.OpenVPN, vpn.Wireguard}
|
||||||
if err = validate.IsOneOf(v.Type, validVPNTypes...); err != nil {
|
if err = validate.IsOneOf(v.Type, validVPNTypes...); err != nil {
|
||||||
return fmt.Errorf("%w: %w", ErrVPNTypeNotValid, err)
|
return fmt.Errorf("%w: %w", ErrVPNTypeNotValid, err)
|
||||||
}
|
}
|
||||||
@@ -42,13 +45,20 @@ func (v *VPN) Validate(filterChoicesGetter FilterChoicesGetter, ipv6Supported bo
|
|||||||
return fmt.Errorf("provider settings: %w", err)
|
return fmt.Errorf("provider settings: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if v.Type == vpn.OpenVPN {
|
switch v.Type {
|
||||||
|
case vpn.AmneziaWg:
|
||||||
|
err = v.AmneziaWg.validate(v.Provider.Name, ipv6Supported)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("AmneziaWG settings: %w", err)
|
||||||
|
}
|
||||||
|
case vpn.OpenVPN:
|
||||||
err := v.OpenVPN.validate(v.Provider.Name)
|
err := v.OpenVPN.validate(v.Provider.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("OpenVPN settings: %w", err)
|
return fmt.Errorf("OpenVPN settings: %w", err)
|
||||||
}
|
}
|
||||||
} else {
|
case vpn.Wireguard:
|
||||||
err := v.Wireguard.validate(v.Provider.Name, ipv6Supported)
|
const amneziawg = false
|
||||||
|
err := v.Wireguard.validate(v.Provider.Name, ipv6Supported, amneziawg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Wireguard settings: %w", err)
|
return fmt.Errorf("Wireguard settings: %w", err)
|
||||||
}
|
}
|
||||||
@@ -66,6 +76,7 @@ func (v *VPN) Copy() (copied VPN) {
|
|||||||
return VPN{
|
return VPN{
|
||||||
Type: v.Type,
|
Type: v.Type,
|
||||||
Provider: v.Provider.copy(),
|
Provider: v.Provider.copy(),
|
||||||
|
AmneziaWg: v.AmneziaWg.copy(),
|
||||||
OpenVPN: v.OpenVPN.copy(),
|
OpenVPN: v.OpenVPN.copy(),
|
||||||
Wireguard: v.Wireguard.copy(),
|
Wireguard: v.Wireguard.copy(),
|
||||||
PMTUD: v.PMTUD.copy(),
|
PMTUD: v.PMTUD.copy(),
|
||||||
@@ -77,6 +88,7 @@ func (v *VPN) Copy() (copied VPN) {
|
|||||||
func (v *VPN) OverrideWith(other VPN) {
|
func (v *VPN) OverrideWith(other VPN) {
|
||||||
v.Type = gosettings.OverrideWithComparable(v.Type, other.Type)
|
v.Type = gosettings.OverrideWithComparable(v.Type, other.Type)
|
||||||
v.Provider.overrideWith(other.Provider)
|
v.Provider.overrideWith(other.Provider)
|
||||||
|
v.AmneziaWg.overrideWith(other.AmneziaWg)
|
||||||
v.OpenVPN.overrideWith(other.OpenVPN)
|
v.OpenVPN.overrideWith(other.OpenVPN)
|
||||||
v.Wireguard.overrideWith(other.Wireguard)
|
v.Wireguard.overrideWith(other.Wireguard)
|
||||||
v.PMTUD.overrideWith(other.PMTUD)
|
v.PMTUD.overrideWith(other.PMTUD)
|
||||||
@@ -87,6 +99,7 @@ func (v *VPN) OverrideWith(other VPN) {
|
|||||||
func (v *VPN) setDefaults() {
|
func (v *VPN) setDefaults() {
|
||||||
v.Type = gosettings.DefaultComparable(v.Type, vpn.OpenVPN)
|
v.Type = gosettings.DefaultComparable(v.Type, vpn.OpenVPN)
|
||||||
v.Provider.setDefaults()
|
v.Provider.setDefaults()
|
||||||
|
v.AmneziaWg.setDefaults(v.Provider.Name)
|
||||||
v.OpenVPN.setDefaults(v.Provider.Name)
|
v.OpenVPN.setDefaults(v.Provider.Name)
|
||||||
v.Wireguard.setDefaults(v.Provider.Name)
|
v.Wireguard.setDefaults(v.Provider.Name)
|
||||||
v.PMTUD.setDefaults()
|
v.PMTUD.setDefaults()
|
||||||
@@ -103,9 +116,12 @@ func (v VPN) toLinesNode() (node *gotree.Node) {
|
|||||||
|
|
||||||
node.AppendNode(v.Provider.toLinesNode())
|
node.AppendNode(v.Provider.toLinesNode())
|
||||||
|
|
||||||
if v.Type == vpn.OpenVPN {
|
switch v.Type {
|
||||||
|
case vpn.AmneziaWg:
|
||||||
|
node.AppendNode(v.AmneziaWg.toLinesNode())
|
||||||
|
case vpn.OpenVPN:
|
||||||
node.AppendNode(v.OpenVPN.toLinesNode())
|
node.AppendNode(v.OpenVPN.toLinesNode())
|
||||||
} else {
|
case vpn.Wireguard:
|
||||||
node.AppendNode(v.Wireguard.toLinesNode())
|
node.AppendNode(v.Wireguard.toLinesNode())
|
||||||
}
|
}
|
||||||
node.AppendNode(v.PMTUD.toLinesNode())
|
node.AppendNode(v.PMTUD.toLinesNode())
|
||||||
@@ -128,12 +144,18 @@ func (v *VPN) read(r *reader.Reader) (err error) {
|
|||||||
return fmt.Errorf("VPN provider: %w", err)
|
return fmt.Errorf("VPN provider: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = v.AmneziaWg.read(r)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("AmneziaWG: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
err = v.OpenVPN.read(r)
|
err = v.OpenVPN.read(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("OpenVPN: %w", err)
|
return fmt.Errorf("OpenVPN: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = v.Wireguard.read(r)
|
const amneziawg = false
|
||||||
|
err = v.Wireguard.read(r, amneziawg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("wireguard: %w", err)
|
return fmt.Errorf("wireguard: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/configuration/settings/helpers"
|
|
||||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||||
"github.com/qdm12/gosettings"
|
"github.com/qdm12/gosettings"
|
||||||
"github.com/qdm12/gosettings/reader"
|
"github.com/qdm12/gosettings/reader"
|
||||||
@@ -42,34 +41,17 @@ type Wireguard struct {
|
|||||||
// 0 indicating to use PMTUD.
|
// 0 indicating to use PMTUD.
|
||||||
MTU *uint32 `json:"mtu"`
|
MTU *uint32 `json:"mtu"`
|
||||||
// Implementation is the Wireguard implementation to use.
|
// Implementation is the Wireguard implementation to use.
|
||||||
// It can be "auto", "userspace", "kernelspace" or "amneziawg".
|
// It can be "auto", "userspace" or "kernelspace".
|
||||||
// It defaults to "auto" and cannot be the empty string
|
// It defaults to "auto" and cannot be the empty string
|
||||||
// in the internal state.
|
// in the internal state.
|
||||||
Implementation string `json:"implementation"`
|
Implementation string `json:"implementation"`
|
||||||
// AmneziaWG contains obfuscation parameters
|
|
||||||
AmneziaWG AmneziaWg `json:"amneziawg"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var regexpInterfaceName = regexp.MustCompile(`^[a-zA-Z0-9_]+$`)
|
var regexpInterfaceName = regexp.MustCompile(`^[a-zA-Z0-9_]+$`)
|
||||||
|
|
||||||
// Validate validates Wireguard settings.
|
// Validate validates Wireguard settings.
|
||||||
// It should only be ran if the VPN type chosen is Wireguard.
|
// It should only be ran if the VPN type chosen is Wireguard or AmneziaWg.
|
||||||
func (w Wireguard) validate(vpnProvider string, ipv6Supported bool) (err error) {
|
func (w Wireguard) validate(vpnProvider string, ipv6Supported, amneziawg bool) (err error) {
|
||||||
if !helpers.IsOneOf(vpnProvider,
|
|
||||||
providers.Airvpn,
|
|
||||||
providers.Custom,
|
|
||||||
providers.Fastestvpn,
|
|
||||||
providers.Ivpn,
|
|
||||||
providers.Mullvad,
|
|
||||||
providers.Nordvpn,
|
|
||||||
providers.Protonvpn,
|
|
||||||
providers.Surfshark,
|
|
||||||
providers.Windscribe,
|
|
||||||
) {
|
|
||||||
// do not validate for VPN provider not supporting Wireguard
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate PrivateKey
|
// Validate PrivateKey
|
||||||
if *w.PrivateKey == "" {
|
if *w.PrivateKey == "" {
|
||||||
return fmt.Errorf("%w", ErrWireguardPrivateKeyNotSet)
|
return fmt.Errorf("%w", ErrWireguardPrivateKeyNotSet)
|
||||||
@@ -138,14 +120,11 @@ func (w Wireguard) validate(vpnProvider string, ipv6Supported bool) (err error)
|
|||||||
ErrWireguardInterfaceNotValid, w.Interface, regexpInterfaceName)
|
ErrWireguardInterfaceNotValid, w.Interface, regexpInterfaceName)
|
||||||
}
|
}
|
||||||
|
|
||||||
validImplementations := []string{"auto", "userspace", "kernelspace", "amneziawg"}
|
if !amneziawg { // amneziawg should have its own Implementation field and ignore this one
|
||||||
if err := validate.IsOneOf(w.Implementation, validImplementations...); err != nil {
|
validImplementations := []string{"auto", "userspace", "kernelspace"}
|
||||||
return fmt.Errorf("%w: %w", ErrWireguardImplementationNotValid, err)
|
if err := validate.IsOneOf(w.Implementation, validImplementations...); err != nil {
|
||||||
}
|
return fmt.Errorf("%w: %w", ErrWireguardImplementationNotValid, err)
|
||||||
|
}
|
||||||
err = w.AmneziaWG.validate()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("amneziawg settings: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -161,7 +140,6 @@ func (w *Wireguard) copy() (copied Wireguard) {
|
|||||||
Interface: w.Interface,
|
Interface: w.Interface,
|
||||||
MTU: w.MTU,
|
MTU: w.MTU,
|
||||||
Implementation: w.Implementation,
|
Implementation: w.Implementation,
|
||||||
AmneziaWG: w.AmneziaWG.copy(),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -175,7 +153,6 @@ func (w *Wireguard) overrideWith(other Wireguard) {
|
|||||||
w.Interface = gosettings.OverrideWithComparable(w.Interface, other.Interface)
|
w.Interface = gosettings.OverrideWithComparable(w.Interface, other.Interface)
|
||||||
w.MTU = gosettings.OverrideWithComparable(w.MTU, other.MTU)
|
w.MTU = gosettings.OverrideWithComparable(w.MTU, other.MTU)
|
||||||
w.Implementation = gosettings.OverrideWithComparable(w.Implementation, other.Implementation)
|
w.Implementation = gosettings.OverrideWithComparable(w.Implementation, other.Implementation)
|
||||||
w.AmneziaWG.overrideWith(other.AmneziaWG)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *Wireguard) setDefaults(vpnProvider string) {
|
func (w *Wireguard) setDefaults(vpnProvider string) {
|
||||||
@@ -200,7 +177,6 @@ func (w *Wireguard) setDefaults(vpnProvider string) {
|
|||||||
w.Interface = gosettings.DefaultComparable(w.Interface, "wg0")
|
w.Interface = gosettings.DefaultComparable(w.Interface, "wg0")
|
||||||
w.MTU = gosettings.DefaultPointer(w.MTU, 0)
|
w.MTU = gosettings.DefaultPointer(w.MTU, 0)
|
||||||
w.Implementation = gosettings.DefaultComparable(w.Implementation, "auto")
|
w.Implementation = gosettings.DefaultComparable(w.Implementation, "auto")
|
||||||
w.AmneziaWG.setDefaults()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w Wireguard) String() string {
|
func (w Wireguard) String() string {
|
||||||
@@ -242,29 +218,27 @@ func (w Wireguard) toLinesNode() (node *gotree.Node) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if w.Implementation != "auto" {
|
if w.Implementation != "auto" {
|
||||||
implNode := node.Appendf("Implementation: %s", w.Implementation)
|
node.Appendf("Implementation: %s", w.Implementation)
|
||||||
|
|
||||||
if w.Implementation == "amneziawg" {
|
|
||||||
implNode.AppendNode(w.AmneziaWG.toLinesNode())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return node
|
return node
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *Wireguard) read(r *reader.Reader) (err error) {
|
func (w *Wireguard) read(r *reader.Reader, amneziaWG bool) (err error) {
|
||||||
w.PrivateKey = r.Get("WIREGUARD_PRIVATE_KEY", reader.ForceLowercase(false))
|
prefix := "WIREGUARD"
|
||||||
w.PreSharedKey = r.Get("WIREGUARD_PRESHARED_KEY", reader.ForceLowercase(false))
|
if amneziaWG {
|
||||||
|
prefix = "AMNEZIAWG"
|
||||||
|
}
|
||||||
|
w.PrivateKey = r.Get(prefix+"_PRIVATE_KEY", reader.ForceLowercase(false))
|
||||||
|
w.PreSharedKey = r.Get(prefix+"_PRESHARED_KEY", reader.ForceLowercase(false))
|
||||||
w.Interface = r.String("VPN_INTERFACE",
|
w.Interface = r.String("VPN_INTERFACE",
|
||||||
reader.RetroKeys("WIREGUARD_INTERFACE"), reader.ForceLowercase(false))
|
reader.RetroKeys(prefix+"_INTERFACE"), reader.ForceLowercase(false))
|
||||||
w.Implementation = r.String("WIREGUARD_IMPLEMENTATION")
|
|
||||||
|
|
||||||
err = w.AmneziaWG.read(r)
|
if !amneziaWG {
|
||||||
if err != nil {
|
w.Implementation = r.String("WIREGUARD_IMPLEMENTATION")
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
addressStrings := r.CSV("WIREGUARD_ADDRESSES", reader.RetroKeys("WIREGUARD_ADDRESS"))
|
addressStrings := r.CSV(prefix+"_ADDRESSES", reader.RetroKeys(prefix+"_ADDRESS"))
|
||||||
// WARNING: do not initialize w.Addresses to an empty slice
|
// WARNING: do not initialize w.Addresses to an empty slice
|
||||||
// or the defaults for nordvpn will not work.
|
// or the defaults for nordvpn will not work.
|
||||||
for _, addressString := range addressStrings {
|
for _, addressString := range addressStrings {
|
||||||
@@ -279,17 +253,17 @@ func (w *Wireguard) read(r *reader.Reader) (err error) {
|
|||||||
w.Addresses = append(w.Addresses, address)
|
w.Addresses = append(w.Addresses, address)
|
||||||
}
|
}
|
||||||
|
|
||||||
w.AllowedIPs, err = r.CSVNetipPrefixes("WIREGUARD_ALLOWED_IPS")
|
w.AllowedIPs, err = r.CSVNetipPrefixes(prefix + "_ALLOWED_IPS")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err // already wrapped
|
return err // already wrapped
|
||||||
}
|
}
|
||||||
|
|
||||||
w.PersistentKeepaliveInterval, err = r.DurationPtr("WIREGUARD_PERSISTENT_KEEPALIVE_INTERVAL")
|
w.PersistentKeepaliveInterval, err = r.DurationPtr(prefix + "_PERSISTENT_KEEPALIVE_INTERVAL")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
w.MTU, err = r.Uint32Ptr("WIREGUARD_MTU")
|
w.MTU, err = r.Uint32Ptr(prefix + "_MTU")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,84 @@
|
|||||||
|
package files
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"gopkg.in/ini.v1"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *Source) lazyLoadAmneziawgConf() AmneziawgConfig {
|
||||||
|
if s.cached.amneziawgLoaded {
|
||||||
|
return s.cached.amneziawgConf
|
||||||
|
}
|
||||||
|
|
||||||
|
s.cached.amneziawgLoaded = true
|
||||||
|
var err error
|
||||||
|
s.cached.amneziawgConf, err = ParseAmneziawgConf(filepath.Join(s.rootDirectory, "amneziawg", "awg0.conf"))
|
||||||
|
if err != nil {
|
||||||
|
s.warner.Warnf("skipping Amneziawg config: %s", err)
|
||||||
|
}
|
||||||
|
return s.cached.amneziawgConf
|
||||||
|
}
|
||||||
|
|
||||||
|
type AmneziawgConfig struct {
|
||||||
|
Wireguard WireguardConfig
|
||||||
|
Jc *string
|
||||||
|
Jmin *string
|
||||||
|
Jmax *string
|
||||||
|
S1 *string
|
||||||
|
S2 *string
|
||||||
|
S3 *string
|
||||||
|
S4 *string
|
||||||
|
H1 *string
|
||||||
|
H2 *string
|
||||||
|
H3 *string
|
||||||
|
H4 *string
|
||||||
|
I1 *string
|
||||||
|
I2 *string
|
||||||
|
I3 *string
|
||||||
|
I4 *string
|
||||||
|
I5 *string
|
||||||
|
}
|
||||||
|
|
||||||
|
func ParseAmneziawgConf(path string) (config AmneziawgConfig, err error) {
|
||||||
|
iniFile, err := ini.InsensitiveLoad(path)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, os.ErrNotExist) {
|
||||||
|
return AmneziawgConfig{}, nil
|
||||||
|
}
|
||||||
|
return AmneziawgConfig{}, fmt.Errorf("loading ini from reader: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
config.Wireguard, err = ParseWireguardConf(path)
|
||||||
|
if err != nil {
|
||||||
|
return AmneziawgConfig{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
interfaceSection, err := iniFile.GetSection("Interface")
|
||||||
|
if err != nil {
|
||||||
|
// can never happen
|
||||||
|
return AmneziawgConfig{}, fmt.Errorf("getting interface section: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
config.Jc = getINIKeyFromSection(interfaceSection, "Jc")
|
||||||
|
config.Jmin = getINIKeyFromSection(interfaceSection, "Jmin")
|
||||||
|
config.Jmax = getINIKeyFromSection(interfaceSection, "Jmax")
|
||||||
|
config.S1 = getINIKeyFromSection(interfaceSection, "S1")
|
||||||
|
config.S2 = getINIKeyFromSection(interfaceSection, "S2")
|
||||||
|
config.S3 = getINIKeyFromSection(interfaceSection, "S3")
|
||||||
|
config.S4 = getINIKeyFromSection(interfaceSection, "S4")
|
||||||
|
config.H1 = getINIKeyFromSection(interfaceSection, "H1")
|
||||||
|
config.H2 = getINIKeyFromSection(interfaceSection, "H2")
|
||||||
|
config.H3 = getINIKeyFromSection(interfaceSection, "H3")
|
||||||
|
config.H4 = getINIKeyFromSection(interfaceSection, "H4")
|
||||||
|
config.I1 = getINIKeyFromSection(interfaceSection, "I1")
|
||||||
|
config.I2 = getINIKeyFromSection(interfaceSection, "I2")
|
||||||
|
config.I3 = getINIKeyFromSection(interfaceSection, "I3")
|
||||||
|
config.I4 = getINIKeyFromSection(interfaceSection, "I4")
|
||||||
|
config.I5 = getINIKeyFromSection(interfaceSection, "I5")
|
||||||
|
|
||||||
|
return config, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,82 @@
|
|||||||
|
package files
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io/fs"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_Source_ParseAmneziawgConf(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
t.Run("no_file", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
noFile := filepath.Join(t.TempDir(), "doesnotexist")
|
||||||
|
wireguard, err := ParseAmneziawgConf(noFile)
|
||||||
|
assert.Equal(t, AmneziawgConfig{}, wireguard)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
testCases := map[string]struct {
|
||||||
|
fileContent string
|
||||||
|
amneziawg AmneziawgConfig
|
||||||
|
errMessage string
|
||||||
|
}{
|
||||||
|
"ini_load_error": {
|
||||||
|
fileContent: "invalid",
|
||||||
|
errMessage: "loading ini from reader: key-value delimiter not found: invalid",
|
||||||
|
},
|
||||||
|
"empty_file": {
|
||||||
|
errMessage: `getting interface section: section "interface" does not exist`,
|
||||||
|
},
|
||||||
|
"success": {
|
||||||
|
fileContent: `
|
||||||
|
[Interface]
|
||||||
|
PrivateKey = QOlCgyA/Sn/c/+YNTIEohrjm8IZV+OZ2AUFIoX20sk8=
|
||||||
|
Address = 10.38.22.35/32
|
||||||
|
DNS = 193.138.218.74
|
||||||
|
Jc = 4
|
||||||
|
H1 = 721391205
|
||||||
|
I1 = <b 0x1234>
|
||||||
|
|
||||||
|
[Peer]
|
||||||
|
PresharedKey = YJ680VN+dGrdsWNjSFqZ6vvwuiNhbq502ZL3G7Q3o3g=
|
||||||
|
`,
|
||||||
|
amneziawg: AmneziawgConfig{
|
||||||
|
Wireguard: WireguardConfig{
|
||||||
|
PrivateKey: ptrTo("QOlCgyA/Sn/c/+YNTIEohrjm8IZV+OZ2AUFIoX20sk8="),
|
||||||
|
PreSharedKey: ptrTo("YJ680VN+dGrdsWNjSFqZ6vvwuiNhbq502ZL3G7Q3o3g="),
|
||||||
|
Addresses: ptrTo("10.38.22.35/32"),
|
||||||
|
},
|
||||||
|
Jc: ptrTo("4"),
|
||||||
|
H1: ptrTo("721391205"),
|
||||||
|
I1: ptrTo("<b 0x1234>"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for testName, testCase := range testCases {
|
||||||
|
t.Run(testName, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
configFile := filepath.Join(t.TempDir(), "awg.conf")
|
||||||
|
const permission = fs.FileMode(0o600)
|
||||||
|
err := os.WriteFile(configFile, []byte(testCase.fileContent), permission)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
wireguard, err := ParseAmneziawgConf(configFile)
|
||||||
|
|
||||||
|
assert.Equal(t, testCase.amneziawg, wireguard)
|
||||||
|
if testCase.errMessage != "" {
|
||||||
|
assert.EqualError(t, err, testCase.errMessage)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -13,6 +13,8 @@ type Source struct {
|
|||||||
cached struct {
|
cached struct {
|
||||||
wireguardLoaded bool
|
wireguardLoaded bool
|
||||||
wireguardConf WireguardConfig
|
wireguardConf WireguardConfig
|
||||||
|
amneziawgLoaded bool
|
||||||
|
amneziawgConf AmneziawgConfig
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -69,38 +71,11 @@ func (s *Source) Get(key string) (value string, isSet bool) {
|
|||||||
return strPtrToStringIsSet(s.lazyLoadWireguardConf().EndpointIP)
|
return strPtrToStringIsSet(s.lazyLoadWireguardConf().EndpointIP)
|
||||||
case "wireguard_endpoint_port":
|
case "wireguard_endpoint_port":
|
||||||
return strPtrToStringIsSet(s.lazyLoadWireguardConf().EndpointPort)
|
return strPtrToStringIsSet(s.lazyLoadWireguardConf().EndpointPort)
|
||||||
case "wireguard_jc":
|
}
|
||||||
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.Jc)
|
|
||||||
case "wireguard_jmin":
|
value, isSet, matched := s.getAmneziawgKey(key)
|
||||||
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.Jmin)
|
if matched {
|
||||||
case "wireguard_jmax":
|
return value, isSet
|
||||||
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.Jmax)
|
|
||||||
case "wireguard_s1":
|
|
||||||
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.S1)
|
|
||||||
case "wireguard_s2":
|
|
||||||
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.S2)
|
|
||||||
case "wireguard_s3":
|
|
||||||
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.S3)
|
|
||||||
case "wireguard_s4":
|
|
||||||
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.S4)
|
|
||||||
case "wireguard_h1":
|
|
||||||
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.H1)
|
|
||||||
case "wireguard_h2":
|
|
||||||
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.H2)
|
|
||||||
case "wireguard_h3":
|
|
||||||
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.H3)
|
|
||||||
case "wireguard_h4":
|
|
||||||
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.H4)
|
|
||||||
case "wireguard_i1":
|
|
||||||
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.I1)
|
|
||||||
case "wireguard_i2":
|
|
||||||
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.I2)
|
|
||||||
case "wireguard_i3":
|
|
||||||
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.I3)
|
|
||||||
case "wireguard_i4":
|
|
||||||
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.I4)
|
|
||||||
case "wireguard_i5":
|
|
||||||
return strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.I5)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
value, isSet, err := ReadFromFile(path)
|
value, isSet, err := ReadFromFile(path)
|
||||||
@@ -110,6 +85,58 @@ func (s *Source) Get(key string) (value string, isSet bool) {
|
|||||||
return value, isSet
|
return value, isSet
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Source) getAmneziawgKey(key string) (value string, isSet, matched bool) {
|
||||||
|
switch key {
|
||||||
|
case "amnezia_private_key":
|
||||||
|
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Wireguard.PrivateKey)
|
||||||
|
case "amnezia_preshared_key":
|
||||||
|
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Wireguard.PreSharedKey)
|
||||||
|
case "amnezia_addresses":
|
||||||
|
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Wireguard.Addresses)
|
||||||
|
case "amnezia_public_key":
|
||||||
|
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Wireguard.PublicKey)
|
||||||
|
case "amnezia_endpoint_ip":
|
||||||
|
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Wireguard.EndpointIP)
|
||||||
|
case "amnezia_endpoint_port":
|
||||||
|
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Wireguard.EndpointPort)
|
||||||
|
case "amnezia_jc":
|
||||||
|
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Jc)
|
||||||
|
case "amnezia_jmin":
|
||||||
|
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Jmin)
|
||||||
|
case "amnezia_jmax":
|
||||||
|
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Jmax)
|
||||||
|
case "amnezia_s1":
|
||||||
|
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().S1)
|
||||||
|
case "amnezia_s2":
|
||||||
|
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().S2)
|
||||||
|
case "amnezia_s3":
|
||||||
|
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().S3)
|
||||||
|
case "amnezia_s4":
|
||||||
|
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().S4)
|
||||||
|
case "amnezia_h1":
|
||||||
|
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().H1)
|
||||||
|
case "amnezia_h2":
|
||||||
|
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().H2)
|
||||||
|
case "amnezia_h3":
|
||||||
|
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().H3)
|
||||||
|
case "amnezia_h4":
|
||||||
|
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().H4)
|
||||||
|
case "amnezia_i1":
|
||||||
|
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().I1)
|
||||||
|
case "amnezia_i2":
|
||||||
|
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().I2)
|
||||||
|
case "amnezia_i3":
|
||||||
|
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().I3)
|
||||||
|
case "amnezia_i4":
|
||||||
|
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().I4)
|
||||||
|
case "amnezia_i5":
|
||||||
|
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().I5)
|
||||||
|
default:
|
||||||
|
return "", false, false
|
||||||
|
}
|
||||||
|
return value, isSet, true
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Source) KeyTransform(key string) string {
|
func (s *Source) KeyTransform(key string) string {
|
||||||
switch key {
|
switch key {
|
||||||
// TODO v4 remove these irregular cases
|
// TODO v4 remove these irregular cases
|
||||||
|
|||||||
@@ -25,54 +25,13 @@ func (s *Source) lazyLoadWireguardConf() WireguardConfig {
|
|||||||
return s.cached.wireguardConf
|
return s.cached.wireguardConf
|
||||||
}
|
}
|
||||||
|
|
||||||
type amneziaWgConfig struct {
|
|
||||||
Jc *string
|
|
||||||
Jmin *string
|
|
||||||
Jmax *string
|
|
||||||
S1 *string
|
|
||||||
S2 *string
|
|
||||||
S3 *string
|
|
||||||
S4 *string
|
|
||||||
H1 *string
|
|
||||||
H2 *string
|
|
||||||
H3 *string
|
|
||||||
H4 *string
|
|
||||||
I1 *string
|
|
||||||
I2 *string
|
|
||||||
I3 *string
|
|
||||||
I4 *string
|
|
||||||
I5 *string
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseWireguardAmneziaInterfaceSection(interfaceSection *ini.Section) amneziaWgConfig {
|
|
||||||
return amneziaWgConfig{
|
|
||||||
Jc: getINIKeyFromSection(interfaceSection, "Jc"),
|
|
||||||
Jmin: getINIKeyFromSection(interfaceSection, "Jmin"),
|
|
||||||
Jmax: getINIKeyFromSection(interfaceSection, "Jmax"),
|
|
||||||
S1: getINIKeyFromSection(interfaceSection, "S1"),
|
|
||||||
S2: getINIKeyFromSection(interfaceSection, "S2"),
|
|
||||||
S3: getINIKeyFromSection(interfaceSection, "S3"),
|
|
||||||
S4: getINIKeyFromSection(interfaceSection, "S4"),
|
|
||||||
H1: getINIKeyFromSection(interfaceSection, "H1"),
|
|
||||||
H2: getINIKeyFromSection(interfaceSection, "H2"),
|
|
||||||
H3: getINIKeyFromSection(interfaceSection, "H3"),
|
|
||||||
H4: getINIKeyFromSection(interfaceSection, "H4"),
|
|
||||||
I1: getINIKeyFromSection(interfaceSection, "I1"),
|
|
||||||
I2: getINIKeyFromSection(interfaceSection, "I2"),
|
|
||||||
I3: getINIKeyFromSection(interfaceSection, "I3"),
|
|
||||||
I4: getINIKeyFromSection(interfaceSection, "I4"),
|
|
||||||
I5: getINIKeyFromSection(interfaceSection, "I5"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type WireguardConfig struct {
|
type WireguardConfig struct {
|
||||||
PrivateKey *string
|
PrivateKey *string
|
||||||
PreSharedKey *string
|
PreSharedKey *string
|
||||||
Addresses *string
|
Addresses *string
|
||||||
PublicKey *string
|
PublicKey *string
|
||||||
EndpointIP *string
|
EndpointIP *string
|
||||||
EndpointPort *string
|
EndpointPort *string
|
||||||
AmneziaParams amneziaWgConfig
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var regexINISectionNotExist = regexp.MustCompile(`^section ".+" does not exist$`)
|
var regexINISectionNotExist = regexp.MustCompile(`^section ".+" does not exist$`)
|
||||||
@@ -89,7 +48,6 @@ func ParseWireguardConf(path string) (config WireguardConfig, err error) {
|
|||||||
interfaceSection, err := iniFile.GetSection("Interface")
|
interfaceSection, err := iniFile.GetSection("Interface")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
config.PrivateKey, config.Addresses = parseWireguardInterfaceSection(interfaceSection)
|
config.PrivateKey, config.Addresses = parseWireguardInterfaceSection(interfaceSection)
|
||||||
config.AmneziaParams = parseWireguardAmneziaInterfaceSection(interfaceSection)
|
|
||||||
} else if !regexINISectionNotExist.MatchString(err.Error()) {
|
} else if !regexINISectionNotExist.MatchString(err.Error()) {
|
||||||
// can never happen
|
// can never happen
|
||||||
return WireguardConfig{}, fmt.Errorf("getting interface section: %w", err)
|
return WireguardConfig{}, fmt.Errorf("getting interface section: %w", err)
|
||||||
|
|||||||
@@ -97,10 +97,9 @@ func Test_parseWireguardInterfaceSection(t *testing.T) {
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
iniData string
|
iniData string
|
||||||
privateKey *string
|
privateKey *string
|
||||||
addresses *string
|
addresses *string
|
||||||
amneziaParams amneziaWgConfig
|
|
||||||
}{
|
}{
|
||||||
"no_fields": {
|
"no_fields": {
|
||||||
iniData: `[Interface]`,
|
iniData: `[Interface]`,
|
||||||
@@ -116,17 +115,9 @@ PrivateKey = x
|
|||||||
[Interface]
|
[Interface]
|
||||||
PrivateKey = QOlCgyA/Sn/c/+YNTIEohrjm8IZV+OZ2AUFIoX20sk8=
|
PrivateKey = QOlCgyA/Sn/c/+YNTIEohrjm8IZV+OZ2AUFIoX20sk8=
|
||||||
Address = 10.38.22.35/32
|
Address = 10.38.22.35/32
|
||||||
Jc = 4
|
|
||||||
H1 = 721391205
|
|
||||||
I1 = <b 0x1234>
|
|
||||||
`,
|
`,
|
||||||
privateKey: ptrTo("QOlCgyA/Sn/c/+YNTIEohrjm8IZV+OZ2AUFIoX20sk8="),
|
privateKey: ptrTo("QOlCgyA/Sn/c/+YNTIEohrjm8IZV+OZ2AUFIoX20sk8="),
|
||||||
addresses: ptrTo("10.38.22.35/32"),
|
addresses: ptrTo("10.38.22.35/32"),
|
||||||
amneziaParams: amneziaWgConfig{
|
|
||||||
Jc: ptrTo("4"),
|
|
||||||
H1: ptrTo("721391205"),
|
|
||||||
I1: ptrTo("<b 0x1234>"),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -140,11 +131,9 @@ I1 = <b 0x1234>
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
privateKey, addresses := parseWireguardInterfaceSection(iniSection)
|
privateKey, addresses := parseWireguardInterfaceSection(iniSection)
|
||||||
amneziaWgConfig := parseWireguardAmneziaInterfaceSection(iniSection)
|
|
||||||
|
|
||||||
assert.Equal(t, testCase.privateKey, privateKey)
|
assert.Equal(t, testCase.privateKey, privateKey)
|
||||||
assert.Equal(t, testCase.addresses, addresses)
|
assert.Equal(t, testCase.addresses, addresses)
|
||||||
assert.Equal(t, testCase.amneziaParams, amneziaWgConfig)
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,27 @@
|
|||||||
|
package secrets
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/qdm12/gluetun/internal/configuration/sources/files"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *Source) lazyLoadAmneziawgConf() files.AmneziawgConfig {
|
||||||
|
if s.cached.amneziawgLoaded {
|
||||||
|
return s.cached.amneziawgConf
|
||||||
|
}
|
||||||
|
|
||||||
|
path := os.Getenv("AMNEZIAWG_CONF_SECRETFILE")
|
||||||
|
if path == "" {
|
||||||
|
path = filepath.Join(s.rootDirectory, "amneziawg", "awg0.conf")
|
||||||
|
}
|
||||||
|
|
||||||
|
s.cached.amneziawgLoaded = true
|
||||||
|
var err error
|
||||||
|
s.cached.amneziawgConf, err = files.ParseAmneziawgConf(path)
|
||||||
|
if err != nil {
|
||||||
|
s.warner.Warnf("skipping Amneziawg config: %s", err)
|
||||||
|
}
|
||||||
|
return s.cached.amneziawgConf
|
||||||
|
}
|
||||||
@@ -15,6 +15,8 @@ type Source struct {
|
|||||||
cached struct {
|
cached struct {
|
||||||
wireguardLoaded bool
|
wireguardLoaded bool
|
||||||
wireguardConf files.WireguardConfig
|
wireguardConf files.WireguardConfig
|
||||||
|
amneziawgLoaded bool
|
||||||
|
amneziawgConf files.AmneziawgConfig
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -60,26 +62,26 @@ func (s *Source) Get(key string) (value string, isSet bool) {
|
|||||||
s.warner.Warnf("skipping %s: parsing PEM: %s", path, err)
|
s.warner.Warnf("skipping %s: parsing PEM: %s", path, err)
|
||||||
}
|
}
|
||||||
return value, isSet
|
return value, isSet
|
||||||
case "wireguard_private_key":
|
case "wireguard_private_key", "amneziawg_private_key":
|
||||||
privateKey := s.lazyLoadWireguardConf().PrivateKey
|
privateKey := s.lazyLoadWireguardConf().PrivateKey
|
||||||
if privateKey != nil {
|
if privateKey != nil {
|
||||||
return *privateKey, true
|
return *privateKey, true
|
||||||
} // else continue to read from individual secret file
|
} // else continue to read from individual secret file
|
||||||
case "wireguard_preshared_key":
|
case "wireguard_preshared_key", "amneziawg_preshared_key":
|
||||||
preSharedKey := s.lazyLoadWireguardConf().PreSharedKey
|
preSharedKey := s.lazyLoadWireguardConf().PreSharedKey
|
||||||
if preSharedKey != nil {
|
if preSharedKey != nil {
|
||||||
return *preSharedKey, true
|
return *preSharedKey, true
|
||||||
} // else continue to read from individual secret file
|
} // else continue to read from individual secret file
|
||||||
case "wireguard_addresses":
|
case "wireguard_addresses", "amneziawg_addresses":
|
||||||
addresses := s.lazyLoadWireguardConf().Addresses
|
addresses := s.lazyLoadWireguardConf().Addresses
|
||||||
if addresses != nil {
|
if addresses != nil {
|
||||||
return *addresses, true
|
return *addresses, true
|
||||||
} // else continue to read from individual secret file
|
} // else continue to read from individual secret file
|
||||||
case "wireguard_public_key":
|
case "wireguard_public_key", "amneziawg_public_key":
|
||||||
return strPtrToStringIsSet(s.lazyLoadWireguardConf().PublicKey)
|
return strPtrToStringIsSet(s.lazyLoadWireguardConf().PublicKey)
|
||||||
case "wireguard_endpoint_ip":
|
case "wireguard_endpoint_ip", "amneziawg_endpoint_ip":
|
||||||
return strPtrToStringIsSet(s.lazyLoadWireguardConf().EndpointIP)
|
return strPtrToStringIsSet(s.lazyLoadWireguardConf().EndpointIP)
|
||||||
case "wireguard_endpoint_port":
|
case "wireguard_endpoint_port", "amneziawg_endpoint_port":
|
||||||
return strPtrToStringIsSet(s.lazyLoadWireguardConf().EndpointPort)
|
return strPtrToStringIsSet(s.lazyLoadWireguardConf().EndpointPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -112,38 +114,50 @@ func (s *Source) KeyTransform(key string) string {
|
|||||||
|
|
||||||
func (s *Source) getAmneziaWg(key string) (value string, isSet, matched bool) {
|
func (s *Source) getAmneziaWg(key string) (value string, isSet, matched bool) {
|
||||||
switch key {
|
switch key {
|
||||||
case "wireguard_jc":
|
case "amneziawg_private_key":
|
||||||
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.Jc)
|
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Wireguard.PrivateKey)
|
||||||
case "wireguard_jmin":
|
case "amneziawg_preshared_key":
|
||||||
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.Jmin)
|
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Wireguard.PreSharedKey)
|
||||||
case "wireguard_jmax":
|
case "wireguard_addresses", "amneziawg_addresses":
|
||||||
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.Jmax)
|
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Wireguard.Addresses)
|
||||||
case "wireguard_s1":
|
case "wireguard_public_key", "amneziawg_public_key":
|
||||||
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.S1)
|
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Wireguard.PublicKey)
|
||||||
case "wireguard_s2":
|
case "wireguard_endpoint_ip", "amneziawg_endpoint_ip":
|
||||||
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.S2)
|
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Wireguard.EndpointIP)
|
||||||
case "wireguard_s3":
|
case "wireguard_endpoint_port", "amneziawg_endpoint_port":
|
||||||
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.S3)
|
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Wireguard.EndpointPort)
|
||||||
case "wireguard_s4":
|
case "amneziawg_jc":
|
||||||
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.S4)
|
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Jc)
|
||||||
case "wireguard_h1":
|
case "amneziawg_jmin":
|
||||||
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.H1)
|
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Jmin)
|
||||||
case "wireguard_h2":
|
case "amneziawg_jmax":
|
||||||
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.H2)
|
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().Jmax)
|
||||||
case "wireguard_h3":
|
case "amneziawg_s1":
|
||||||
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.H3)
|
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().S1)
|
||||||
case "wireguard_h4":
|
case "amneziawg_s2":
|
||||||
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.H4)
|
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().S2)
|
||||||
case "wireguard_i1":
|
case "amneziawg_s3":
|
||||||
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.I1)
|
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().S3)
|
||||||
case "wireguard_i2":
|
case "amneziawg_s4":
|
||||||
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.I2)
|
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().S4)
|
||||||
case "wireguard_i3":
|
case "amneziawg_h1":
|
||||||
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.I3)
|
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().H1)
|
||||||
case "wireguard_i4":
|
case "amneziawg_h2":
|
||||||
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.I4)
|
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().H2)
|
||||||
case "wireguard_i5":
|
case "amneziawg_h3":
|
||||||
value, isSet = strPtrToStringIsSet(s.lazyLoadWireguardConf().AmneziaParams.I5)
|
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().H3)
|
||||||
|
case "amneziawg_h4":
|
||||||
|
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().H4)
|
||||||
|
case "amneziawg_i1":
|
||||||
|
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().I1)
|
||||||
|
case "amneziawg_i2":
|
||||||
|
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().I2)
|
||||||
|
case "amneziawg_i3":
|
||||||
|
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().I3)
|
||||||
|
case "amneziawg_i4":
|
||||||
|
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().I4)
|
||||||
|
case "amneziawg_i5":
|
||||||
|
value, isSet = strPtrToStringIsSet(s.lazyLoadAmneziawgConf().I5)
|
||||||
default:
|
default:
|
||||||
return "", false, false
|
return "", false, false
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package vpn
|
package vpn
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
AmneziaWg = "amneziawg"
|
||||||
OpenVPN = "openvpn"
|
OpenVPN = "openvpn"
|
||||||
Wireguard = "wireguard"
|
Wireguard = "wireguard"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -0,0 +1,67 @@
|
|||||||
|
package vpn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/qdm12/gluetun/internal/amneziawg"
|
||||||
|
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||||
|
"github.com/qdm12/gluetun/internal/models"
|
||||||
|
"github.com/qdm12/gluetun/internal/provider"
|
||||||
|
"github.com/qdm12/gluetun/internal/wireguard"
|
||||||
|
"github.com/qdm12/gosettings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// setupAmneziaWg sets AmneziaWG up using the configurators and settings given.
|
||||||
|
func setupAmneziaWg(ctx context.Context, netlinker NetLinker,
|
||||||
|
fw Firewall, providerConf provider.Provider,
|
||||||
|
settings settings.VPN, ipv6Supported bool, logger wireguard.Logger) (
|
||||||
|
amneziawger *amneziawg.Amneziawg, connection models.Connection, err error,
|
||||||
|
) {
|
||||||
|
connection, err = providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported)
|
||||||
|
if err != nil {
|
||||||
|
return nil, models.Connection{}, fmt.Errorf("finding a VPN server: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
amneziaWGSettings := buildAmneziaWgSettings(connection, settings.AmneziaWg, ipv6Supported)
|
||||||
|
|
||||||
|
logger.Debug("Amneziawg server public key: " + amneziaWGSettings.Wireguard.PublicKey)
|
||||||
|
logger.Debug("Amneziawg client private key: " + gosettings.ObfuscateKey(amneziaWGSettings.Wireguard.PrivateKey))
|
||||||
|
logger.Debug("Amneziawg pre-shared key: " + gosettings.ObfuscateKey(amneziaWGSettings.Wireguard.PreSharedKey))
|
||||||
|
|
||||||
|
amneziawger, err = amneziawg.New(amneziaWGSettings, netlinker, logger)
|
||||||
|
if err != nil {
|
||||||
|
return nil, models.Connection{}, fmt.Errorf("creating amneziawg: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = fw.SetVPNConnection(ctx, connection, settings.Wireguard.Interface)
|
||||||
|
if err != nil {
|
||||||
|
return nil, models.Connection{}, fmt.Errorf("setting firewall: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return amneziawger, connection, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildAmneziaWgSettings(connection models.Connection,
|
||||||
|
userSettings settings.AmneziaWg, ipv6Supported bool,
|
||||||
|
) amneziawg.Settings {
|
||||||
|
return amneziawg.Settings{
|
||||||
|
Wireguard: buildWireguardSettings(connection, userSettings.Wireguard, ipv6Supported),
|
||||||
|
JunkPacketCount: *userSettings.JunkPacketCount,
|
||||||
|
JunkPacketMin: *userSettings.JunkPacketMin,
|
||||||
|
JunkPacketMax: *userSettings.JunkPacketMax,
|
||||||
|
PaddingS1: *userSettings.PaddingS1,
|
||||||
|
PaddingS2: *userSettings.PaddingS2,
|
||||||
|
PaddingS3: *userSettings.PaddingS3,
|
||||||
|
PaddingS4: *userSettings.PaddingS4,
|
||||||
|
HeaderH1: *userSettings.HeaderH1,
|
||||||
|
HeaderH2: *userSettings.HeaderH2,
|
||||||
|
HeaderH3: *userSettings.HeaderH3,
|
||||||
|
HeaderH4: *userSettings.HeaderH4,
|
||||||
|
InitPacketI1: *userSettings.InitPacketI1,
|
||||||
|
InitPacketI2: *userSettings.InitPacketI2,
|
||||||
|
InitPacketI3: *userSettings.InitPacketI3,
|
||||||
|
InitPacketI4: *userSettings.InitPacketI4,
|
||||||
|
InitPacketI5: *userSettings.InitPacketI5,
|
||||||
|
}
|
||||||
|
}
|
||||||
+9
-2
@@ -33,14 +33,21 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
|||||||
var connection models.Connection
|
var connection models.Connection
|
||||||
var err error
|
var err error
|
||||||
subLogger := l.logger.New(log.SetComponent(settings.Type))
|
subLogger := l.logger.New(log.SetComponent(settings.Type))
|
||||||
if settings.Type == vpn.OpenVPN {
|
switch settings.Type {
|
||||||
|
case vpn.AmneziaWg:
|
||||||
|
vpnInterface = settings.AmneziaWg.Wireguard.Interface
|
||||||
|
vpnRunner, connection, err = setupAmneziaWg(ctx, l.netLinker, l.fw,
|
||||||
|
providerConf, settings, l.ipv6Supported, subLogger)
|
||||||
|
case vpn.OpenVPN:
|
||||||
vpnInterface = settings.OpenVPN.Interface
|
vpnInterface = settings.OpenVPN.Interface
|
||||||
vpnRunner, connection, err = setupOpenVPN(ctx, l.fw,
|
vpnRunner, connection, err = setupOpenVPN(ctx, l.fw,
|
||||||
l.openvpnConf, providerConf, settings, l.ipv6Supported, l.cmder, subLogger)
|
l.openvpnConf, providerConf, settings, l.ipv6Supported, l.cmder, subLogger)
|
||||||
} else { // Wireguard
|
case vpn.Wireguard:
|
||||||
vpnInterface = settings.Wireguard.Interface
|
vpnInterface = settings.Wireguard.Interface
|
||||||
vpnRunner, connection, err = setupWireguard(ctx, l.netLinker, l.fw,
|
vpnRunner, connection, err = setupWireguard(ctx, l.netLinker, l.fw,
|
||||||
providerConf, settings, l.ipv6Supported, subLogger)
|
providerConf, settings, l.ipv6Supported, subLogger)
|
||||||
|
default:
|
||||||
|
panic("vpn type not implemented: " + settings.Type)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.crashed(ctx, err)
|
l.crashed(ctx, err)
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/constants"
|
"github.com/qdm12/gluetun/internal/constants"
|
||||||
|
"github.com/qdm12/gluetun/internal/constants/vpn"
|
||||||
"github.com/qdm12/gluetun/internal/netlink"
|
"github.com/qdm12/gluetun/internal/netlink"
|
||||||
"github.com/qdm12/gluetun/internal/pmtud"
|
"github.com/qdm12/gluetun/internal/pmtud"
|
||||||
pconstants "github.com/qdm12/gluetun/internal/pmtud/constants"
|
pconstants "github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||||
@@ -48,6 +49,14 @@ type tunnelUpPMTUDData struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (l *Loop) onTunnelUp(ctx, loopCtx context.Context, data tunnelUpData) {
|
func (l *Loop) onTunnelUp(ctx, loopCtx context.Context, data tunnelUpData) {
|
||||||
|
switch vpnType := l.GetSettings().Type; vpnType {
|
||||||
|
case vpn.Wireguard, vpn.AmneziaWg:
|
||||||
|
l.logger.Infof("%s setup is complete. "+
|
||||||
|
"Note %s is a silent protocol and it may or may not work, without giving any error message. "+
|
||||||
|
"Typically i/o timeout errors indicate the %s connection is not working.",
|
||||||
|
vpnType, vpnType, vpnType)
|
||||||
|
}
|
||||||
|
|
||||||
l.client.CloseIdleConnections()
|
l.client.CloseIdleConnections()
|
||||||
|
|
||||||
for _, vpnPort := range l.vpnInputPorts {
|
for _, vpnPort := range l.vpnInputPorts {
|
||||||
|
|||||||
@@ -51,7 +51,6 @@ func buildWireguardSettings(connection models.Connection,
|
|||||||
settings.PreSharedKey = *userSettings.PreSharedKey
|
settings.PreSharedKey = *userSettings.PreSharedKey
|
||||||
settings.InterfaceName = userSettings.Interface
|
settings.InterfaceName = userSettings.Interface
|
||||||
settings.Implementation = userSettings.Implementation
|
settings.Implementation = userSettings.Implementation
|
||||||
settings.AmneziaWG = buildAmneziaWgSettings(userSettings.AmneziaWG)
|
|
||||||
if *userSettings.MTU > 0 {
|
if *userSettings.MTU > 0 {
|
||||||
settings.MTU = *userSettings.MTU
|
settings.MTU = *userSettings.MTU
|
||||||
} else {
|
} else {
|
||||||
@@ -91,24 +90,3 @@ func buildWireguardSettings(connection models.Connection,
|
|||||||
|
|
||||||
return settings
|
return settings
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildAmneziaWgSettings(s settings.AmneziaWg) wireguard.AmneziaSettings {
|
|
||||||
return wireguard.AmneziaSettings{
|
|
||||||
JunkPacketCount: *s.JunkPacketCount,
|
|
||||||
JunkPacketMin: *s.JunkPacketMin,
|
|
||||||
JunkPacketMax: *s.JunkPacketMax,
|
|
||||||
PaddingS1: *s.PaddingS1,
|
|
||||||
PaddingS2: *s.PaddingS2,
|
|
||||||
PaddingS3: *s.PaddingS3,
|
|
||||||
PaddingS4: *s.PaddingS4,
|
|
||||||
HeaderH1: *s.HeaderH1,
|
|
||||||
HeaderH2: *s.HeaderH2,
|
|
||||||
HeaderH3: *s.HeaderH3,
|
|
||||||
HeaderH4: *s.HeaderH4,
|
|
||||||
InitPacketI1: *s.InitPacketI1,
|
|
||||||
InitPacketI2: *s.InitPacketI2,
|
|
||||||
InitPacketI3: *s.InitPacketI3,
|
|
||||||
InitPacketI4: *s.InitPacketI4,
|
|
||||||
InitPacketI5: *s.InitPacketI5,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -40,24 +40,6 @@ func Test_buildWireguardSettings(t *testing.T) {
|
|||||||
PersistentKeepaliveInterval: ptrTo(time.Hour),
|
PersistentKeepaliveInterval: ptrTo(time.Hour),
|
||||||
Interface: "wg1",
|
Interface: "wg1",
|
||||||
MTU: ptrTo(uint32(1000)),
|
MTU: ptrTo(uint32(1000)),
|
||||||
AmneziaWG: settings.AmneziaWg{
|
|
||||||
JunkPacketCount: ptrTo(uint16(1)),
|
|
||||||
JunkPacketMin: ptrTo(uint16(0)),
|
|
||||||
JunkPacketMax: ptrTo(uint16(0)),
|
|
||||||
PaddingS1: ptrTo(uint16(0)),
|
|
||||||
PaddingS2: ptrTo(uint16(0)),
|
|
||||||
PaddingS3: ptrTo(uint16(0)),
|
|
||||||
PaddingS4: ptrTo(uint16(0)),
|
|
||||||
HeaderH1: ptrTo("x"),
|
|
||||||
HeaderH2: ptrTo(""),
|
|
||||||
HeaderH3: ptrTo(""),
|
|
||||||
HeaderH4: ptrTo(""),
|
|
||||||
InitPacketI1: ptrTo(""),
|
|
||||||
InitPacketI2: ptrTo(""),
|
|
||||||
InitPacketI3: ptrTo(""),
|
|
||||||
InitPacketI4: ptrTo(""),
|
|
||||||
InitPacketI5: ptrTo(""),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
ipv6Supported: false,
|
ipv6Supported: false,
|
||||||
settings: wireguard.Settings{
|
settings: wireguard.Settings{
|
||||||
@@ -76,10 +58,6 @@ func Test_buildWireguardSettings(t *testing.T) {
|
|||||||
RulePriority: 101,
|
RulePriority: 101,
|
||||||
IPv6: ptrTo(false),
|
IPv6: ptrTo(false),
|
||||||
MTU: 1000,
|
MTU: 1000,
|
||||||
AmneziaWG: wireguard.AmneziaSettings{
|
|
||||||
JunkPacketCount: 1,
|
|
||||||
HeaderH1: "x",
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,15 +5,16 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (w *Wireguard) addAddresses(linkIndex uint32,
|
func AddAddresses(linkIndex uint32,
|
||||||
addresses []netip.Prefix,
|
addresses []netip.Prefix, ipv6 bool,
|
||||||
|
netlink NetLinker,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
for _, address := range addresses {
|
for _, address := range addresses {
|
||||||
if !*w.settings.IPv6 && address.Addr().Is6() {
|
if !ipv6 && address.Addr().Is6() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
err = w.netlink.AddrReplace(linkIndex, address)
|
err = netlink.AddrReplace(linkIndex, address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: when adding address %s to link with index %d",
|
return fmt.Errorf("%w: when adding address %s to link with index %d",
|
||||||
err, address, linkIndex)
|
err, address, linkIndex)
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_Wireguard_addAddresses(t *testing.T) {
|
func Test_AddAddresses(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
ipNetOne := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 32)
|
ipNetOne := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 32)
|
||||||
@@ -19,15 +19,17 @@ func Test_Wireguard_addAddresses(t *testing.T) {
|
|||||||
errDummy := errors.New("dummy")
|
errDummy := errors.New("dummy")
|
||||||
|
|
||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
linkIndex uint32
|
linkIndex uint32
|
||||||
addrs []netip.Prefix
|
addrs []netip.Prefix
|
||||||
wgBuilder func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard
|
ipv6 bool
|
||||||
err error
|
netlinkBuilder func(ctrl *gomock.Controller, linkIndex uint32) *MockNetLinker
|
||||||
|
err error
|
||||||
}{
|
}{
|
||||||
"success": {
|
"success": {
|
||||||
linkIndex: 1,
|
linkIndex: 1,
|
||||||
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
|
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
|
||||||
wgBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard {
|
ipv6: true,
|
||||||
|
netlinkBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *MockNetLinker {
|
||||||
netLinker := NewMockNetLinker(ctrl)
|
netLinker := NewMockNetLinker(ctrl)
|
||||||
firstCall := netLinker.EXPECT().
|
firstCall := netLinker.EXPECT().
|
||||||
AddrReplace(linkIndex, ipNetOne).
|
AddrReplace(linkIndex, ipNetOne).
|
||||||
@@ -35,35 +37,27 @@ func Test_Wireguard_addAddresses(t *testing.T) {
|
|||||||
netLinker.EXPECT().
|
netLinker.EXPECT().
|
||||||
AddrReplace(linkIndex, ipNetTwo).
|
AddrReplace(linkIndex, ipNetTwo).
|
||||||
Return(nil).After(firstCall)
|
Return(nil).After(firstCall)
|
||||||
return &Wireguard{
|
return netLinker
|
||||||
netlink: netLinker,
|
|
||||||
settings: Settings{
|
|
||||||
IPv6: ptrTo(true),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"first add error": {
|
"first add error": {
|
||||||
linkIndex: 1,
|
linkIndex: 1,
|
||||||
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
|
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
|
||||||
wgBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard {
|
ipv6: true,
|
||||||
|
netlinkBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *MockNetLinker {
|
||||||
netLinker := NewMockNetLinker(ctrl)
|
netLinker := NewMockNetLinker(ctrl)
|
||||||
netLinker.EXPECT().
|
netLinker.EXPECT().
|
||||||
AddrReplace(linkIndex, ipNetOne).
|
AddrReplace(linkIndex, ipNetOne).
|
||||||
Return(errDummy)
|
Return(errDummy)
|
||||||
return &Wireguard{
|
return netLinker
|
||||||
netlink: netLinker,
|
|
||||||
settings: Settings{
|
|
||||||
IPv6: ptrTo(true),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
err: errors.New("dummy: when adding address 1.2.3.4/32 to link with index 1"),
|
err: errors.New("dummy: when adding address 1.2.3.4/32 to link with index 1"),
|
||||||
},
|
},
|
||||||
"second add error": {
|
"second add error": {
|
||||||
linkIndex: 1,
|
linkIndex: 1,
|
||||||
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
|
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
|
||||||
wgBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *Wireguard {
|
ipv6: true,
|
||||||
|
netlinkBuilder: func(ctrl *gomock.Controller, linkIndex uint32) *MockNetLinker {
|
||||||
netLinker := NewMockNetLinker(ctrl)
|
netLinker := NewMockNetLinker(ctrl)
|
||||||
firstCall := netLinker.EXPECT().
|
firstCall := netLinker.EXPECT().
|
||||||
AddrReplace(linkIndex, ipNetOne).
|
AddrReplace(linkIndex, ipNetOne).
|
||||||
@@ -71,23 +65,14 @@ func Test_Wireguard_addAddresses(t *testing.T) {
|
|||||||
netLinker.EXPECT().
|
netLinker.EXPECT().
|
||||||
AddrReplace(linkIndex, ipNetTwo).
|
AddrReplace(linkIndex, ipNetTwo).
|
||||||
Return(errDummy).After(firstCall)
|
Return(errDummy).After(firstCall)
|
||||||
return &Wireguard{
|
return netLinker
|
||||||
netlink: netLinker,
|
|
||||||
settings: Settings{
|
|
||||||
IPv6: ptrTo(true),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
err: errors.New("dummy: when adding address ::1234/64 to link with index 1"),
|
err: errors.New("dummy: when adding address ::1234/64 to link with index 1"),
|
||||||
},
|
},
|
||||||
"ignore IPv6": {
|
"ignore IPv6": {
|
||||||
addrs: []netip.Prefix{ipNetTwo},
|
addrs: []netip.Prefix{ipNetTwo},
|
||||||
wgBuilder: func(_ *gomock.Controller, _ uint32) *Wireguard {
|
netlinkBuilder: func(_ *gomock.Controller, _ uint32) *MockNetLinker {
|
||||||
return &Wireguard{
|
return NewMockNetLinker(nil)
|
||||||
settings: Settings{
|
|
||||||
IPv6: ptrTo(false),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -97,9 +82,9 @@ func Test_Wireguard_addAddresses(t *testing.T) {
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
ctrl := gomock.NewController(t)
|
ctrl := gomock.NewController(t)
|
||||||
|
|
||||||
wg := testCase.wgBuilder(ctrl, testCase.linkIndex)
|
netlink := testCase.netlinkBuilder(ctrl, testCase.linkIndex)
|
||||||
|
|
||||||
err := wg.addAddresses(testCase.linkIndex, testCase.addrs)
|
err := AddAddresses(testCase.linkIndex, testCase.addrs, testCase.ipv6, netlink)
|
||||||
|
|
||||||
if testCase.err != nil {
|
if testCase.err != nil {
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
|
|||||||
@@ -1,63 +0,0 @@
|
|||||||
package wireguard
|
|
||||||
|
|
||||||
import "sort"
|
|
||||||
|
|
||||||
type closer struct {
|
|
||||||
operation string
|
|
||||||
step step
|
|
||||||
close func() error
|
|
||||||
closed bool
|
|
||||||
}
|
|
||||||
|
|
||||||
type closers []closer
|
|
||||||
|
|
||||||
func (c *closers) add(operation string, step step,
|
|
||||||
closeFunc func() error,
|
|
||||||
) {
|
|
||||||
closer := closer{
|
|
||||||
operation: operation,
|
|
||||||
step: step,
|
|
||||||
close: closeFunc,
|
|
||||||
}
|
|
||||||
*c = append(*c, closer)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *closers) cleanup(logger Logger) {
|
|
||||||
closers := *c
|
|
||||||
|
|
||||||
sort.Slice(closers, func(i, j int) bool {
|
|
||||||
return closers[i].step < closers[j].step
|
|
||||||
})
|
|
||||||
|
|
||||||
for i, closer := range closers {
|
|
||||||
if closer.closed {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
closers[i].closed = true
|
|
||||||
logger.Debug(closer.operation + "...")
|
|
||||||
err := closer.close()
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("failed " + closer.operation + ": " + err.Error())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type step int
|
|
||||||
|
|
||||||
const (
|
|
||||||
// stepOne closes the wireguard controller client,
|
|
||||||
// and removes the IP rule.
|
|
||||||
stepOne step = iota
|
|
||||||
// stepTwo closes the UAPI listener.
|
|
||||||
stepTwo
|
|
||||||
// stepThree closes the UAPI file.
|
|
||||||
stepThree
|
|
||||||
// stepFour shuts down the Wireguard link.
|
|
||||||
stepFour
|
|
||||||
// stepFive removes the Wireguard link.
|
|
||||||
stepFive
|
|
||||||
// stepSix closes the Wireguard device.
|
|
||||||
stepSix
|
|
||||||
// stepSeven closes the bind connection and the TUN device file.
|
|
||||||
stepSeven
|
|
||||||
)
|
|
||||||
@@ -1,57 +0,0 @@
|
|||||||
package wireguard
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func Test_closers(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ctrl := gomock.NewController(t)
|
|
||||||
|
|
||||||
var ACloseCalled, BCloseCalled, CCloseCalled bool
|
|
||||||
var (
|
|
||||||
AErr error
|
|
||||||
BErr = errors.New("B failed")
|
|
||||||
CErr = errors.New("C failed")
|
|
||||||
)
|
|
||||||
|
|
||||||
var closers closers
|
|
||||||
closers.add("closing A", stepFive, func() error {
|
|
||||||
ACloseCalled = true
|
|
||||||
return AErr
|
|
||||||
})
|
|
||||||
|
|
||||||
closers.add("closing B", stepThree, func() error {
|
|
||||||
BCloseCalled = true
|
|
||||||
return BErr
|
|
||||||
})
|
|
||||||
|
|
||||||
closers.add("closing C", stepTwo, func() error {
|
|
||||||
CCloseCalled = true
|
|
||||||
return CErr
|
|
||||||
})
|
|
||||||
|
|
||||||
logger := NewMockLogger(ctrl)
|
|
||||||
prevCall := logger.EXPECT().Debug("closing C...")
|
|
||||||
prevCall = logger.EXPECT().Error("failed closing C: C failed").After(prevCall)
|
|
||||||
prevCall = logger.EXPECT().Debug("closing B...").After(prevCall)
|
|
||||||
prevCall = logger.EXPECT().Error("failed closing B: B failed").After(prevCall)
|
|
||||||
logger.EXPECT().Debug("closing A...").After(prevCall)
|
|
||||||
|
|
||||||
closers.cleanup(logger)
|
|
||||||
|
|
||||||
closers.cleanup(logger) // run twice should not close already closed
|
|
||||||
|
|
||||||
for _, closer := range closers {
|
|
||||||
assert.True(t, closer.closed)
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.True(t, ACloseCalled)
|
|
||||||
assert.True(t, BCloseCalled)
|
|
||||||
assert.True(t, CCloseCalled)
|
|
||||||
}
|
|
||||||
@@ -1,28 +0,0 @@
|
|||||||
package wireguard
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
)
|
|
||||||
|
|
||||||
type tunDevice interface {
|
|
||||||
Close() error
|
|
||||||
Name() (string, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type bind interface {
|
|
||||||
Close() error
|
|
||||||
}
|
|
||||||
|
|
||||||
type userspaceDevice interface {
|
|
||||||
Close()
|
|
||||||
Wait() chan struct{}
|
|
||||||
IpcHandle(net.Conn)
|
|
||||||
IpcSet(string) error
|
|
||||||
}
|
|
||||||
|
|
||||||
type userSpaceBackend struct {
|
|
||||||
createTun func(string, int) (tunDevice, error)
|
|
||||||
createBind func() bind
|
|
||||||
createDevice func(tunDevice, bind, Logger) userspaceDevice
|
|
||||||
preStart func(userspaceDevice, Settings) error
|
|
||||||
}
|
|
||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
)
|
)
|
||||||
|
|
||||||
func configureDevice(client *wgctrl.Client, settings Settings) (err error) {
|
func ConfigureDevice(client *wgctrl.Client, settings Settings) (err error) {
|
||||||
deviceConfig, err := makeDeviceConfig(settings)
|
deviceConfig, err := makeDeviceConfig(settings)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("making device configuration: %w", err)
|
return fmt.Errorf("making device configuration: %w", err)
|
||||||
|
|||||||
@@ -1,5 +1,9 @@
|
|||||||
package wireguard
|
package wireguard
|
||||||
|
|
||||||
|
import (
|
||||||
|
"golang.zx2c4.com/wireguard/device"
|
||||||
|
)
|
||||||
|
|
||||||
//go:generate mockgen -destination=log_mock_test.go -package wireguard . Logger
|
//go:generate mockgen -destination=log_mock_test.go -package wireguard . Logger
|
||||||
|
|
||||||
type Logger interface {
|
type Logger interface {
|
||||||
@@ -7,5 +11,16 @@ type Logger interface {
|
|||||||
Debugf(format string, args ...interface{})
|
Debugf(format string, args ...interface{})
|
||||||
Info(s string)
|
Info(s string)
|
||||||
Error(s string)
|
Error(s string)
|
||||||
Errorf(format string, args ...interface{})
|
Erroer
|
||||||
|
}
|
||||||
|
|
||||||
|
type Erroer interface {
|
||||||
|
Errorf(format string, args ...any)
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeDeviceLogger(logger Logger) (deviceLogger *device.Logger) {
|
||||||
|
return &device.Logger{
|
||||||
|
Verbosef: logger.Debugf,
|
||||||
|
Errorf: logger.Errorf,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,23 @@
|
|||||||
|
package wireguard
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_makeDeviceLogger(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
|
||||||
|
logger := NewMockLogger(ctrl)
|
||||||
|
|
||||||
|
deviceLogger := makeDeviceLogger(logger)
|
||||||
|
|
||||||
|
logger.EXPECT().Debugf("test %d", 1)
|
||||||
|
deviceLogger.Verbosef("test %d", 1)
|
||||||
|
|
||||||
|
logger.EXPECT().Errorf("test %d", 2)
|
||||||
|
deviceLogger.Errorf("test %d", 2)
|
||||||
|
}
|
||||||
@@ -21,7 +21,7 @@ func (n noopDebugLogger) Error(_ string) {}
|
|||||||
func (n noopDebugLogger) Errorf(_ string, _ ...any) {}
|
func (n noopDebugLogger) Errorf(_ string, _ ...any) {}
|
||||||
func (n noopDebugLogger) Patch(_ ...log.Option) {}
|
func (n noopDebugLogger) Patch(_ ...log.Option) {}
|
||||||
|
|
||||||
func Test_netlink_Wireguard_addAddresses(t *testing.T) {
|
func Test_AddAddresses_Integration(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
netlinker := netlink.New(&noopDebugLogger{})
|
netlinker := netlink.New(&noopDebugLogger{})
|
||||||
@@ -55,7 +55,7 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) {
|
|||||||
|
|
||||||
const addIterations = 2 // initial + replace
|
const addIterations = 2 // initial + replace
|
||||||
for range addIterations {
|
for range addIterations {
|
||||||
err = wg.addAddresses(link.Index, addresses)
|
err = AddAddresses(link.Index, addresses, *wg.settings.IPv6, wg.netlink)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
ipPrefixes, err := netlinker.AddrList(link.Index, netlink.FamilyAll)
|
ipPrefixes, err := netlinker.AddrList(link.Index, netlink.FamilyAll)
|
||||||
@@ -67,22 +67,19 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_netlink_Wireguard_addRule(t *testing.T) {
|
func Test_AddRule_Integration(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
netlinker := netlink.New(&noopDebugLogger{})
|
logger := &noopDebugLogger{}
|
||||||
wg := &Wireguard{
|
netlinker := netlink.New(logger)
|
||||||
netlink: netlinker,
|
|
||||||
logger: &noopDebugLogger{},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Unique combination for this test
|
// Unique combination for this test
|
||||||
const rulePriority uint32 = 10000
|
const rulePriority uint32 = 10000
|
||||||
const firewallMark uint32 = 12345
|
const firewallMark uint32 = 12345
|
||||||
const family = netlink.FamilyV4
|
const family = netlink.FamilyV4
|
||||||
|
|
||||||
cleanup, err := wg.addRule(rulePriority,
|
cleanup, err := AddRule(rulePriority,
|
||||||
firewallMark, family)
|
firewallMark, family, netlinker, logger)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
err := cleanup()
|
err := cleanup()
|
||||||
@@ -110,8 +107,8 @@ func Test_netlink_Wireguard_addRule(t *testing.T) {
|
|||||||
require.True(t, ruleFound)
|
require.True(t, ruleFound)
|
||||||
|
|
||||||
// Existing rule cannot be added
|
// Existing rule cannot be added
|
||||||
nilCleanup, err := wg.addRule(rulePriority,
|
nilCleanup, err := AddRule(rulePriority,
|
||||||
firewallMark, family)
|
firewallMark, family, netlinker, logger)
|
||||||
if nilCleanup != nil {
|
if nilCleanup != nil {
|
||||||
_ = nilCleanup() // in case it succeeds
|
_ = nilCleanup() // in case it succeeds
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,17 +8,17 @@ import (
|
|||||||
"github.com/qdm12/gluetun/internal/netlink"
|
"github.com/qdm12/gluetun/internal/netlink"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (w *Wireguard) addRoutes(linkIndex uint32, destinations []netip.Prefix,
|
func AddRoutes(linkIndex uint32, destinations []netip.Prefix,
|
||||||
firewallMark uint32,
|
firewallMark uint32, netlinker NetLinker, logger Erroer,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
for _, dst := range destinations {
|
for _, dst := range destinations {
|
||||||
err = w.addRoute(linkIndex, dst, firewallMark)
|
err = addRoute(linkIndex, dst, firewallMark, netlinker)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if dst.Addr().Is6() && strings.Contains(err.Error(), "permission denied") {
|
if dst.Addr().Is6() && strings.Contains(err.Error(), "permission denied") {
|
||||||
w.logger.Errorf("cannot add route for IPv6 due to a permission denial. "+
|
logger.Errorf("cannot add route for IPv6 due to a permission denial. "+
|
||||||
"Ignoring and continuing execution; "+
|
"Ignoring and continuing execution; "+
|
||||||
"Please report to https://github.com/qdm12/gluetun/issues/998 if you find a fix. "+
|
"Please report to https://github.com/qdm12/gluetun/issues/998 if you find a fix. "+
|
||||||
"Full error string: %s", err)
|
"Full error string: %s", err)
|
||||||
@@ -29,8 +29,8 @@ func (w *Wireguard) addRoutes(linkIndex uint32, destinations []netip.Prefix,
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *Wireguard) addRoute(linkIndex uint32, dst netip.Prefix,
|
func addRoute(linkIndex uint32, dst netip.Prefix,
|
||||||
firewallMark uint32,
|
firewallMark uint32, netlinker NetLinker,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
family := netlink.FamilyV4
|
family := netlink.FamilyV4
|
||||||
if dst.Addr().Is6() {
|
if dst.Addr().Is6() {
|
||||||
@@ -46,7 +46,7 @@ func (w *Wireguard) addRoute(linkIndex uint32, dst netip.Prefix,
|
|||||||
Proto: netlink.ProtoStatic,
|
Proto: netlink.ProtoStatic,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = w.netlink.RouteAdd(route)
|
err = netlinker.RouteAdd(route)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf(
|
return fmt.Errorf(
|
||||||
"adding route for link with index %d, destination %s and table %d: %w",
|
"adding route for link with index %d, destination %s and table %d: %w",
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_Wireguard_addRoute(t *testing.T) {
|
func Test_addRoute(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
const linkIndex = 88
|
const linkIndex = 88
|
||||||
@@ -62,15 +62,11 @@ func Test_Wireguard_addRoute(t *testing.T) {
|
|||||||
ctrl := gomock.NewController(t)
|
ctrl := gomock.NewController(t)
|
||||||
|
|
||||||
netLinker := NewMockNetLinker(ctrl)
|
netLinker := NewMockNetLinker(ctrl)
|
||||||
wg := Wireguard{
|
|
||||||
netlink: netLinker,
|
|
||||||
}
|
|
||||||
|
|
||||||
netLinker.EXPECT().
|
netLinker.EXPECT().
|
||||||
RouteAdd(testCase.expectedRoute).
|
RouteAdd(testCase.expectedRoute).
|
||||||
Return(testCase.routeAddErr)
|
Return(testCase.routeAddErr)
|
||||||
|
|
||||||
err := wg.addRoute(linkIndex, testCase.dst, firewallMark)
|
err := addRoute(linkIndex, testCase.dst, firewallMark, netLinker)
|
||||||
|
|
||||||
if testCase.err != nil {
|
if testCase.err != nil {
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
|
|||||||
@@ -7,8 +7,8 @@ import (
|
|||||||
"github.com/qdm12/gluetun/internal/netlink"
|
"github.com/qdm12/gluetun/internal/netlink"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (w *Wireguard) addRule(rulePriority, firewallMark uint32,
|
func AddRule(rulePriority, firewallMark uint32, family uint8,
|
||||||
family uint8,
|
netlinker NetLinker, logger Logger,
|
||||||
) (cleanup func() error, err error) {
|
) (cleanup func() error, err error) {
|
||||||
rule := netlink.Rule{
|
rule := netlink.Rule{
|
||||||
Priority: &rulePriority,
|
Priority: &rulePriority,
|
||||||
@@ -18,16 +18,16 @@ func (w *Wireguard) addRule(rulePriority, firewallMark uint32,
|
|||||||
Flags: netlink.FlagInvert,
|
Flags: netlink.FlagInvert,
|
||||||
Action: netlink.ActionToTable,
|
Action: netlink.ActionToTable,
|
||||||
}
|
}
|
||||||
if err := w.netlink.RuleAdd(rule); err != nil {
|
if err := netlinker.RuleAdd(rule); err != nil {
|
||||||
if strings.HasSuffix(err.Error(), "file exists") {
|
if strings.HasSuffix(err.Error(), "file exists") {
|
||||||
w.logger.Info("if you are using Kubernetes, this may fix the error below: " +
|
logger.Info("if you are using Kubernetes, this may fix the error below: " +
|
||||||
"https://github.com/qdm12/gluetun-wiki/blob/main/setup/advanced/kubernetes.md#adding-ipv6-rule--file-exists")
|
"https://github.com/qdm12/gluetun-wiki/blob/main/setup/advanced/kubernetes.md#adding-ipv6-rule--file-exists")
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("adding %s: %w", rule, err)
|
return nil, fmt.Errorf("adding %s: %w", rule, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
cleanup = func() error {
|
cleanup = func() error {
|
||||||
err := w.netlink.RuleDel(rule)
|
err := netlinker.RuleDel(rule)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("deleting rule %s: %w", rule, err)
|
return fmt.Errorf("deleting rule %s: %w", rule, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_Wireguard_addRule(t *testing.T) {
|
func Test_AddRule(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
const rulePriority uint32 = 987
|
const rulePriority uint32 = 987
|
||||||
@@ -68,13 +68,11 @@ func Test_Wireguard_addRule(t *testing.T) {
|
|||||||
ctrl := gomock.NewController(t)
|
ctrl := gomock.NewController(t)
|
||||||
|
|
||||||
netLinker := NewMockNetLinker(ctrl)
|
netLinker := NewMockNetLinker(ctrl)
|
||||||
wg := Wireguard{
|
|
||||||
netlink: netLinker,
|
|
||||||
}
|
|
||||||
|
|
||||||
netLinker.EXPECT().RuleAdd(testCase.expectedRule).
|
netLinker.EXPECT().RuleAdd(testCase.expectedRule).
|
||||||
Return(testCase.ruleAddErr)
|
Return(testCase.ruleAddErr)
|
||||||
cleanup, err := wg.addRule(rulePriority, firewallMark, family)
|
cleanup, err := AddRule(rulePriority, firewallMark, family,
|
||||||
|
netLinker, nil)
|
||||||
if testCase.err != nil {
|
if testCase.err != nil {
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
assert.Equal(t, testCase.err.Error(), err.Error())
|
assert.Equal(t, testCase.err.Error(), err.Error())
|
||||||
|
|||||||
+85
-89
@@ -6,39 +6,33 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
|
"github.com/qdm12/gluetun/internal/cleanup"
|
||||||
"github.com/qdm12/gluetun/internal/netlink"
|
"github.com/qdm12/gluetun/internal/netlink"
|
||||||
|
"golang.zx2c4.com/wireguard/conn"
|
||||||
|
"golang.zx2c4.com/wireguard/device"
|
||||||
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl"
|
"golang.zx2c4.com/wireguard/wgctrl"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrDetectKernel = errors.New("cannot detect Kernel support")
|
errKernelSupport = errors.New("kernel does not support Wireguard")
|
||||||
ErrCreateTun = errors.New("cannot create TUN device")
|
errTunNameMismatch = errors.New("TUN device name is mismatching")
|
||||||
ErrAddLink = errors.New("cannot add Wireguard link")
|
errDeviceWaited = errors.New("device waited for")
|
||||||
ErrFindLink = errors.New("cannot find link")
|
|
||||||
ErrFindDevice = errors.New("cannot find Wireguard device")
|
|
||||||
ErrUAPISocketOpening = errors.New("cannot open UAPI socket")
|
|
||||||
ErrWgctrlOpen = errors.New("cannot open wgctrl")
|
|
||||||
ErrUAPIListen = errors.New("cannot listen on UAPI socket")
|
|
||||||
ErrAddAddress = errors.New("cannot add address to wireguard interface")
|
|
||||||
ErrConfigure = errors.New("cannot configure wireguard interface")
|
|
||||||
ErrDeviceInfo = errors.New("cannot get wireguard device information")
|
|
||||||
ErrIfaceUp = errors.New("cannot set the interface to UP")
|
|
||||||
ErrRouteAdd = errors.New("cannot add route for interface")
|
|
||||||
ErrDeviceWaited = errors.New("device waited for")
|
|
||||||
ErrKernelSupport = errors.New("kernel does not support Wireguard")
|
|
||||||
ErrAmneziaConfigure = errors.New("cannot configure AmneziaWG")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Run runs the wireguard interface and waits until the context is done, then it cleans up the
|
||||||
|
// interface and returns any error that occurred during setup or waiting. It sends an error to
|
||||||
|
// waitError if any error occurs during setup or waiting, otherwise it sends nil when the context
|
||||||
|
// is done. It sends a signal to ready when the setup is complete and the interface is ready to use.
|
||||||
// See https://git.zx2c4.com/wireguard-go/tree/main.go
|
// See https://git.zx2c4.com/wireguard-go/tree/main.go
|
||||||
func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<- struct{}) {
|
func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<- struct{}) {
|
||||||
kernelSupported, err := w.netlink.IsWireguardSupported()
|
kernelSupported, err := w.netlink.IsWireguardSupported()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
waitError <- fmt.Errorf("%w: %s", ErrDetectKernel, err)
|
waitError <- fmt.Errorf("detecting wireguard kernel support: %w", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
userspaceBackend := defaultUserSpaceBackend()
|
setupFunction := setupUserSpace
|
||||||
setupFunction := setupUserSpaceCommon
|
|
||||||
switch w.settings.Implementation {
|
switch w.settings.Implementation {
|
||||||
case "auto": //nolint:goconst
|
case "auto": //nolint:goconst
|
||||||
if !kernelSupported {
|
if !kernelSupported {
|
||||||
@@ -50,95 +44,105 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
|
|||||||
case "userspace":
|
case "userspace":
|
||||||
case "kernelspace":
|
case "kernelspace":
|
||||||
if !kernelSupported {
|
if !kernelSupported {
|
||||||
waitError <- fmt.Errorf("%w", ErrKernelSupport)
|
waitError <- fmt.Errorf("%w", errKernelSupport)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
setupFunction = setupKernelSpace
|
setupFunction = setupKernelSpace
|
||||||
case "amneziawg":
|
|
||||||
userspaceBackend = amneziaUserSpaceBackend()
|
|
||||||
default:
|
default:
|
||||||
panic(fmt.Sprintf("unknown implementation %q", w.settings.Implementation))
|
panic(fmt.Sprintf("unknown implementation %q", w.settings.Implementation))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
setup := func(ctx context.Context, cleanups *cleanup.Cleanups) (
|
||||||
|
linkIndex uint32, waitAndCleanup func() error, err error,
|
||||||
|
) {
|
||||||
|
return setupFunction(ctx,
|
||||||
|
w.settings.InterfaceName, w.netlink, w.settings.MTU, cleanups, w.logger)
|
||||||
|
}
|
||||||
|
|
||||||
|
Run(ctx, waitError, ready, setup, w.settings, w.netlink, w.logger)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Run(ctx context.Context, waitError chan<- error, ready chan<- struct{},
|
||||||
|
setup func(ctx context.Context, cleanups *cleanup.Cleanups) (
|
||||||
|
linkIndex uint32, waitAndCleanup func() error, err error),
|
||||||
|
settings Settings, netlinker NetLinker, logger Logger,
|
||||||
|
) {
|
||||||
client, err := wgctrl.New()
|
client, err := wgctrl.New()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
waitError <- fmt.Errorf("%w: %s", ErrWgctrlOpen, err)
|
waitError <- fmt.Errorf("opening wgctrl: %w", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var closers closers
|
var cleanups cleanup.Cleanups
|
||||||
closers.add("closing controller client", stepOne, client.Close)
|
cleanups.Add("closing controller client", 1, client.Close)
|
||||||
|
|
||||||
defer closers.cleanup(w.logger)
|
defer cleanups.Cleanup(logger)
|
||||||
|
|
||||||
linkIndex, waitAndCleanup, err := setupFunction(ctx,
|
linkIndex, waitAndCleanup, err := setup(ctx, &cleanups)
|
||||||
w.settings.InterfaceName, w.netlink, w.settings.MTU, &closers, w.logger, w.settings, userspaceBackend)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
waitError <- err
|
waitError <- err
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = w.addAddresses(linkIndex, w.settings.Addresses)
|
err = AddAddresses(linkIndex, settings.Addresses, *settings.IPv6, netlinker)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
waitError <- fmt.Errorf("%w: %s", ErrAddAddress, err)
|
waitError <- fmt.Errorf("adding addresses to interface: %w", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
w.logger.Info("Connecting to " + w.settings.Endpoint.String())
|
logger.Info("Connecting to " + settings.Endpoint.String())
|
||||||
err = configureDevice(client, w.settings)
|
err = ConfigureDevice(client, settings)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
waitError <- fmt.Errorf("%w: %s", ErrConfigure, err)
|
waitError <- fmt.Errorf("configuring interface: %w", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = w.netlink.LinkSetUp(linkIndex)
|
err = netlinker.LinkSetUp(linkIndex)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
waitError <- fmt.Errorf("%w: %s", ErrIfaceUp, err)
|
waitError <- fmt.Errorf("setting the interface UP: %w", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
closers.add("shutting down link", stepFour, func() error {
|
cleanups.Add("shutting down link", 4, func() error {
|
||||||
return w.netlink.LinkSetDown(linkIndex)
|
return netlinker.LinkSetDown(linkIndex)
|
||||||
})
|
})
|
||||||
|
|
||||||
err = w.addRoutes(linkIndex, w.settings.AllowedIPs, w.settings.FirewallMark)
|
err = AddRoutes(linkIndex, settings.AllowedIPs, settings.FirewallMark,
|
||||||
|
netlinker, logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
waitError <- fmt.Errorf("%w: %s", ErrRouteAdd, err)
|
waitError <- fmt.Errorf("adding routes for interface: %w", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if *w.settings.IPv6 {
|
if *settings.IPv6 {
|
||||||
// requires net.ipv6.conf.all.disable_ipv6=0
|
// requires net.ipv6.conf.all.disable_ipv6=0
|
||||||
ruleCleanup6, err := w.addRule(w.settings.RulePriority,
|
ruleCleanup6, err := AddRule(settings.RulePriority,
|
||||||
w.settings.FirewallMark, netlink.FamilyV6)
|
settings.FirewallMark, netlink.FamilyV6,
|
||||||
|
netlinker, logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
waitError <- fmt.Errorf("adding IPv6 rule: %w", err)
|
waitError <- fmt.Errorf("adding IPv6 rule: %w", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
closers.add("removing IPv6 rule", stepOne, ruleCleanup6)
|
cleanups.Add("removing IPv6 rule", 1, ruleCleanup6)
|
||||||
}
|
}
|
||||||
|
|
||||||
ruleCleanup, err := w.addRule(w.settings.RulePriority,
|
ruleCleanup, err := AddRule(settings.RulePriority,
|
||||||
w.settings.FirewallMark, netlink.FamilyV4)
|
settings.FirewallMark, netlink.FamilyV4,
|
||||||
|
netlinker, logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
waitError <- fmt.Errorf("adding IPv4 rule: %w", err)
|
waitError <- fmt.Errorf("adding IPv4 rule: %w", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
closers.add("removing IPv4 rule", stepOne, ruleCleanup)
|
cleanups.Add("removing IPv4 rule", 1, ruleCleanup)
|
||||||
w.logger.Info("Wireguard setup is complete. " +
|
|
||||||
"Note Wireguard is a silent protocol and it may or may not work, without giving any error message. " +
|
|
||||||
"Typically i/o timeout errors indicate the Wireguard connection is not working.")
|
|
||||||
ready <- struct{}{}
|
ready <- struct{}{}
|
||||||
|
|
||||||
waitError <- waitAndCleanup()
|
waitError <- waitAndCleanup()
|
||||||
}
|
}
|
||||||
|
|
||||||
type waitAndCleanupFunc func() error
|
|
||||||
|
|
||||||
func setupKernelSpace(ctx context.Context,
|
func setupKernelSpace(ctx context.Context,
|
||||||
interfaceName string, netLinker NetLinker, mtu uint32,
|
interfaceName string, netLinker NetLinker, mtu uint32,
|
||||||
closers *closers, logger Logger, _ Settings, _ userSpaceBackend) (
|
cleanups *cleanup.Cleanups, logger Logger) (
|
||||||
linkIndex uint32, waitAndCleanup waitAndCleanupFunc, err error,
|
linkIndex uint32, waitAndCleanup func() error, err error,
|
||||||
) {
|
) {
|
||||||
links, err := netLinker.LinkList()
|
links, err := netLinker.LinkList()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -164,82 +168,74 @@ func setupKernelSpace(ctx context.Context,
|
|||||||
}
|
}
|
||||||
linkIndex, err = netLinker.LinkAdd(link)
|
linkIndex, err = netLinker.LinkAdd(link)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, nil, fmt.Errorf("%w: %s", ErrAddLink, err)
|
return 0, nil, fmt.Errorf("adding link: %w", err)
|
||||||
}
|
}
|
||||||
closers.add("deleting link", stepFive, func() error {
|
cleanups.Add("deleting link", 5, func() error {
|
||||||
return netLinker.LinkDel(linkIndex)
|
return netLinker.LinkDel(linkIndex)
|
||||||
})
|
})
|
||||||
|
|
||||||
waitAndCleanup = func() error {
|
waitAndCleanup = func() error {
|
||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
closers.cleanup(logger)
|
cleanups.Cleanup(logger)
|
||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
return linkIndex, waitAndCleanup, nil
|
return linkIndex, waitAndCleanup, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupUserSpaceCommon(ctx context.Context,
|
func setupUserSpace(ctx context.Context,
|
||||||
interfaceName string, netLinker NetLinker, mtu uint32,
|
interfaceName string, netLinker NetLinker, mtu uint32,
|
||||||
closers *closers, logger Logger,
|
cleanups *cleanup.Cleanups, logger Logger) (
|
||||||
settings Settings, b userSpaceBackend,
|
linkIndex uint32, waitAndCleanup func() error, err error,
|
||||||
) (
|
|
||||||
linkIndex uint32, waitAndCleanup waitAndCleanupFunc, err error,
|
|
||||||
) {
|
) {
|
||||||
tun, err := b.createTun(interfaceName, int(mtu))
|
tun, err := tun.CreateTUN(interfaceName, int(mtu))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, nil, fmt.Errorf("%w: %s", ErrCreateTun, err)
|
return 0, nil, fmt.Errorf("creating TUN device: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
closers.add("closing TUN device", stepSeven, tun.Close)
|
cleanups.Add("closing TUN device", 7, tun.Close)
|
||||||
|
|
||||||
tunName, err := tun.Name()
|
tunName, err := tun.Name()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, nil, fmt.Errorf("%w: cannot get TUN name: %s", ErrCreateTun, err)
|
return 0, nil, fmt.Errorf("getting created TUN device name: %w", err)
|
||||||
} else if tunName != interfaceName {
|
} else if tunName != interfaceName {
|
||||||
return 0, nil, fmt.Errorf("%w: names don't match: expected %q and got %q",
|
return 0, nil, fmt.Errorf("%w: expected %q and got %q",
|
||||||
ErrCreateTun, interfaceName, tunName)
|
errTunNameMismatch, interfaceName, tunName)
|
||||||
}
|
}
|
||||||
|
|
||||||
link, err := netLinker.LinkByName(interfaceName)
|
link, err := netLinker.LinkByName(interfaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, nil, fmt.Errorf("%w: %s: %s", ErrFindLink, interfaceName, err)
|
return 0, nil, fmt.Errorf("finding link %s: %w", interfaceName, err)
|
||||||
}
|
}
|
||||||
closers.add("deleting link", stepFive, func() error {
|
cleanups.Add("deleting link", 5, func() error {
|
||||||
return netLinker.LinkDel(link.Index)
|
return netLinker.LinkDel(link.Index)
|
||||||
})
|
})
|
||||||
|
|
||||||
bind := b.createBind()
|
bind := conn.NewDefaultBind()
|
||||||
|
|
||||||
closers.add("closing bind", stepSeven, bind.Close)
|
cleanups.Add("closing bind", 7, bind.Close)
|
||||||
|
|
||||||
device := b.createDevice(tun, bind, logger)
|
deviceLogger := makeDeviceLogger(logger)
|
||||||
|
device := device.NewDevice(tun, bind, deviceLogger)
|
||||||
|
|
||||||
closers.add("closing Wireguard device", stepSix, func() error {
|
cleanups.Add("closing Wireguard device", 6, func() error {
|
||||||
device.Close()
|
device.Close()
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
uapiFile, err := uapiOpen(interfaceName)
|
uapiFile, err := UAPIOpen(interfaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, nil, fmt.Errorf("%w: %s", ErrUAPISocketOpening, err)
|
return 0, nil, fmt.Errorf("opening UAPI socket: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
closers.add("closing UAPI file", stepThree, uapiFile.Close)
|
cleanups.Add("closing UAPI file", 3, uapiFile.Close)
|
||||||
|
|
||||||
uapiListener, err := uapiListen(interfaceName, uapiFile)
|
uapiListener, err := UAPIListen(interfaceName, uapiFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, nil, fmt.Errorf("%w: %s", ErrUAPIListen, err)
|
return 0, nil, fmt.Errorf("listening on UAPI socket: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
closers.add("closing UAPI listener", stepTwo, uapiListener.Close)
|
cleanups.Add("closing UAPI listener", 2, uapiListener.Close)
|
||||||
|
|
||||||
if b.preStart != nil {
|
|
||||||
err = b.preStart(device, settings)
|
|
||||||
if err != nil {
|
|
||||||
return 0, nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// acceptAndHandle exits when uapiListener is closed
|
// acceptAndHandle exits when uapiListener is closed
|
||||||
uapiAcceptErrorCh := make(chan error)
|
uapiAcceptErrorCh := make(chan error)
|
||||||
@@ -251,10 +247,10 @@ func setupUserSpaceCommon(ctx context.Context,
|
|||||||
case err = <-uapiAcceptErrorCh:
|
case err = <-uapiAcceptErrorCh:
|
||||||
close(uapiAcceptErrorCh)
|
close(uapiAcceptErrorCh)
|
||||||
case <-device.Wait():
|
case <-device.Wait():
|
||||||
err = ErrDeviceWaited
|
err = errDeviceWaited
|
||||||
}
|
}
|
||||||
|
|
||||||
closers.cleanup(logger)
|
cleanups.Cleanup(logger)
|
||||||
|
|
||||||
<-uapiAcceptErrorCh // wait for acceptAndHandle to exit
|
<-uapiAcceptErrorCh // wait for acceptAndHandle to exit
|
||||||
|
|
||||||
@@ -264,7 +260,7 @@ func setupUserSpaceCommon(ctx context.Context,
|
|||||||
return link.Index, waitAndCleanup, nil
|
return link.Index, waitAndCleanup, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func acceptAndHandle(uapi net.Listener, device userspaceDevice,
|
func acceptAndHandle(uapi net.Listener, device *device.Device,
|
||||||
uapiAcceptErrorCh chan<- error,
|
uapiAcceptErrorCh chan<- error,
|
||||||
) {
|
) {
|
||||||
for { // stopped by uapiFile.Close()
|
for { // stopped by uapiFile.Close()
|
||||||
|
|||||||
@@ -46,11 +46,8 @@ type Settings struct {
|
|||||||
// It defaults to false if left unset.
|
// It defaults to false if left unset.
|
||||||
IPv6 *bool
|
IPv6 *bool
|
||||||
// Implementation is the implementation to use.
|
// Implementation is the implementation to use.
|
||||||
// It can be auto, kernelspace, userspace or amneziawg,
|
// It can be auto, kernelspace or userspace, and defaults to auto.
|
||||||
// and defaults to auto.
|
|
||||||
Implementation string
|
Implementation string
|
||||||
// AmneziaWG settings are extra obfuscation parameters
|
|
||||||
AmneziaWG AmneziaSettings
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Settings) SetDefaults() {
|
func (s *Settings) SetDefaults() {
|
||||||
@@ -181,7 +178,7 @@ func (s *Settings) Check() (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
switch s.Implementation {
|
switch s.Implementation {
|
||||||
case "auto", "kernelspace", "userspace", "amneziawg":
|
case "auto", "kernelspace", "userspace":
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("%w: %s", ErrImplementationInvalid, s.Implementation)
|
return fmt.Errorf("%w: %s", ErrImplementationInvalid, s.Implementation)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,58 +0,0 @@
|
|||||||
package wireguard
|
|
||||||
|
|
||||||
import (
|
|
||||||
amneziaconn "github.com/amnezia-vpn/amneziawg-go/conn"
|
|
||||||
amneziadevice "github.com/amnezia-vpn/amneziawg-go/device"
|
|
||||||
amneziatun "github.com/amnezia-vpn/amneziawg-go/tun"
|
|
||||||
wgconn "golang.zx2c4.com/wireguard/conn"
|
|
||||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
|
||||||
wgtun "golang.zx2c4.com/wireguard/tun"
|
|
||||||
)
|
|
||||||
|
|
||||||
func defaultUserSpaceBackend() userSpaceBackend {
|
|
||||||
return userSpaceBackend{
|
|
||||||
createTun: func(name string, mtu int) (tunDevice, error) {
|
|
||||||
return wgtun.CreateTUN(name, mtu)
|
|
||||||
},
|
|
||||||
createBind: func() bind {
|
|
||||||
return wgconn.NewDefaultBind()
|
|
||||||
},
|
|
||||||
createDevice: func(td tunDevice, b bind, logger Logger) userspaceDevice {
|
|
||||||
wgtun, _ := td.(wgtun.Device)
|
|
||||||
wgBind, _ := b.(wgconn.Bind)
|
|
||||||
wgLogger := wgdevice.Logger{
|
|
||||||
Verbosef: logger.Debugf,
|
|
||||||
Errorf: logger.Errorf,
|
|
||||||
}
|
|
||||||
device := wgdevice.NewDevice(wgtun, wgBind, &wgLogger)
|
|
||||||
return device
|
|
||||||
},
|
|
||||||
preStart: nil,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func amneziaUserSpaceBackend() userSpaceBackend {
|
|
||||||
return userSpaceBackend{
|
|
||||||
createTun: func(name string, mtu int) (tunDevice, error) {
|
|
||||||
return amneziatun.CreateTUN(name, mtu)
|
|
||||||
},
|
|
||||||
createBind: func() bind {
|
|
||||||
return amneziaconn.NewDefaultBind()
|
|
||||||
},
|
|
||||||
createDevice: func(td tunDevice, b bind, logger Logger) userspaceDevice {
|
|
||||||
wgamneziaTun, _ := td.(amneziatun.Device)
|
|
||||||
wgamneziaBind, _ := b.(amneziaconn.Bind)
|
|
||||||
wgamneziaLogger := amneziadevice.Logger{
|
|
||||||
Verbosef: logger.Debugf,
|
|
||||||
Errorf: logger.Errorf,
|
|
||||||
}
|
|
||||||
device := amneziadevice.NewDevice(wgamneziaTun, wgamneziaBind, &wgamneziaLogger)
|
|
||||||
return device
|
|
||||||
},
|
|
||||||
preStart: func(ud userspaceDevice, s Settings) error {
|
|
||||||
uapiConfig := s.AmneziaWG.uapiConfig()
|
|
||||||
err := ud.IpcSet(uapiConfig)
|
|
||||||
return err
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -7,10 +7,10 @@ import (
|
|||||||
"golang.zx2c4.com/wireguard/ipc"
|
"golang.zx2c4.com/wireguard/ipc"
|
||||||
)
|
)
|
||||||
|
|
||||||
func uapiOpen(name string) (*os.File, error) {
|
func UAPIOpen(name string) (*os.File, error) {
|
||||||
return ipc.UAPIOpen(name)
|
return ipc.UAPIOpen(name)
|
||||||
}
|
}
|
||||||
|
|
||||||
func uapiListen(interfaceName string, uapiFile *os.File) (net.Listener, error) {
|
func UAPIListen(interfaceName string, uapiFile *os.File) (net.Listener, error) {
|
||||||
return ipc.UAPIListen(interfaceName, uapiFile)
|
return ipc.UAPIListen(interfaceName, uapiFile)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,10 +7,10 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
)
|
)
|
||||||
|
|
||||||
func uapiOpen(name string) (*os.File, error) {
|
func UAPIOpen(name string) (*os.File, error) {
|
||||||
panic("not implemented")
|
panic("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func uapiListen(interfaceName string, uapiFile *os.File) (net.Listener, error) {
|
func UAPIListen(interfaceName string, uapiFile *os.File) (net.Listener, error) {
|
||||||
panic("not implemented")
|
panic("not implemented")
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user