mirror of
https://github.com/qdm12/gluetun.git
synced 2026-05-07 04:20:12 +02:00
351 lines
9.2 KiB
Go
351 lines
9.2 KiB
Go
package internal
|
|
|
|
import (
|
|
"testing"
|
|
|
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
)
|
|
|
|
func Test_addCredential(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
wireguardPrivateKey, err := wgtypes.GeneratePrivateKey()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
testCases := map[string]struct {
|
|
initialCredentials map[string]providerCredentials
|
|
provider string
|
|
vpnType string
|
|
openvpnCredentials *openvpnCredentials
|
|
wireguardCreds *wireguardCredentials
|
|
expectedLength int
|
|
expectedOpenVPN bool
|
|
expectedWireGuard bool
|
|
}{
|
|
"adds_openvpn_credentials": {
|
|
initialCredentials: map[string]providerCredentials{},
|
|
provider: "protonvpn",
|
|
vpnType: "openvpn",
|
|
openvpnCredentials: &openvpnCredentials{Username: "user", Password: "pass"},
|
|
expectedLength: 1,
|
|
expectedOpenVPN: true,
|
|
},
|
|
"adds_wireguard_credentials": {
|
|
initialCredentials: map[string]providerCredentials{},
|
|
provider: "mullvad",
|
|
vpnType: "wireguard",
|
|
wireguardCreds: &wireguardCredentials{
|
|
PrivateKey: wireguardPrivateKey.String(),
|
|
Address: "10.0.0.2/32",
|
|
},
|
|
expectedLength: 1,
|
|
expectedWireGuard: true,
|
|
},
|
|
"preserves_other_protocol": {
|
|
initialCredentials: map[string]providerCredentials{
|
|
"protonvpn": {
|
|
WireGuard: &wireguardCredentials{PrivateKey: wireguardPrivateKey.String()},
|
|
},
|
|
},
|
|
provider: "protonvpn",
|
|
vpnType: "openvpn",
|
|
openvpnCredentials: &openvpnCredentials{Username: "user", Password: "pass"},
|
|
expectedLength: 1,
|
|
expectedOpenVPN: true,
|
|
expectedWireGuard: true,
|
|
},
|
|
}
|
|
|
|
for name, testCase := range testCases {
|
|
t.Run(name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
credentials := cloneCredentials(testCase.initialCredentials)
|
|
|
|
err := addCredential(credentials, testCase.provider, testCase.vpnType,
|
|
testCase.openvpnCredentials, testCase.wireguardCreds)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
|
|
providerCredentials := credentials[testCase.provider]
|
|
if len(credentials) != testCase.expectedLength {
|
|
t.Fatalf("expected %d providers, got %d", testCase.expectedLength, len(credentials))
|
|
}
|
|
if (providerCredentials.OpenVPN != nil) != testCase.expectedOpenVPN {
|
|
t.Fatalf("expected openvpn presence %t, got %t", testCase.expectedOpenVPN, providerCredentials.OpenVPN != nil)
|
|
}
|
|
if (providerCredentials.WireGuard != nil) != testCase.expectedWireGuard {
|
|
t.Fatalf("expected wireguard presence %t, got %t", testCase.expectedWireGuard, providerCredentials.WireGuard != nil)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_deleteCredential(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
wireguardPrivateKey, err := wgtypes.GeneratePrivateKey()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
testCases := map[string]struct {
|
|
initialCredentials map[string]providerCredentials
|
|
provider string
|
|
vpnType string
|
|
expectedLength int
|
|
expectedOpenVPN bool
|
|
expectedWireGuard bool
|
|
}{
|
|
"deletes_openvpn_only": {
|
|
initialCredentials: map[string]providerCredentials{
|
|
"protonvpn": {
|
|
OpenVPN: &openvpnCredentials{Username: "user", Password: "pass"},
|
|
WireGuard: &wireguardCredentials{PrivateKey: wireguardPrivateKey.String()},
|
|
},
|
|
},
|
|
provider: "protonvpn",
|
|
vpnType: "openvpn",
|
|
expectedLength: 1,
|
|
expectedWireGuard: true,
|
|
},
|
|
"deletes_last_protocol_and_provider": {
|
|
initialCredentials: map[string]providerCredentials{
|
|
"protonvpn": {
|
|
OpenVPN: &openvpnCredentials{Username: "user", Password: "pass"},
|
|
},
|
|
},
|
|
provider: "protonvpn",
|
|
vpnType: "openvpn",
|
|
expectedLength: 0,
|
|
},
|
|
}
|
|
|
|
for name, testCase := range testCases {
|
|
t.Run(name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
credentials := cloneCredentials(testCase.initialCredentials)
|
|
|
|
err := deleteCredential(credentials, testCase.provider, testCase.vpnType)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
|
|
if len(credentials) != testCase.expectedLength {
|
|
t.Fatalf("expected %d providers, got %d", testCase.expectedLength, len(credentials))
|
|
}
|
|
|
|
providerCredentials, exists := credentials[testCase.provider]
|
|
if !exists {
|
|
return
|
|
}
|
|
|
|
if (providerCredentials.OpenVPN != nil) != testCase.expectedOpenVPN {
|
|
t.Fatalf("expected openvpn presence %t, got %t", testCase.expectedOpenVPN, providerCredentials.OpenVPN != nil)
|
|
}
|
|
if (providerCredentials.WireGuard != nil) != testCase.expectedWireGuard {
|
|
t.Fatalf("expected wireguard presence %t, got %t", testCase.expectedWireGuard, providerCredentials.WireGuard != nil)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_validateCredentials(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
wireguardPrivateKey, err := wgtypes.GeneratePrivateKey()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
testCases := map[string]struct {
|
|
credentials map[string]providerCredentials
|
|
wantError bool
|
|
}{
|
|
"both_protocols_valid": {
|
|
credentials: map[string]providerCredentials{
|
|
"protonvpn": {
|
|
OpenVPN: &openvpnCredentials{Username: "user", Password: "pass"},
|
|
WireGuard: &wireguardCredentials{PrivateKey: wireguardPrivateKey.String()},
|
|
},
|
|
},
|
|
},
|
|
"invalid_wireguard_when_both_present": {
|
|
credentials: map[string]providerCredentials{
|
|
"protonvpn": {
|
|
OpenVPN: &openvpnCredentials{Username: "user", Password: "pass"},
|
|
WireGuard: &wireguardCredentials{PrivateKey: "invalid"},
|
|
},
|
|
},
|
|
wantError: true,
|
|
},
|
|
}
|
|
|
|
for name, testCase := range testCases {
|
|
t.Run(name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
err := validateCredentials(testCase.credentials)
|
|
if testCase.wantError && err == nil {
|
|
t.Fatal("expected an error but got nil")
|
|
}
|
|
if !testCase.wantError && err != nil {
|
|
t.Fatalf("expected no error, got %v", err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_marshalLoadCredentials(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
wireguardPrivateKey, err := wgtypes.GeneratePrivateKey()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
credentials := map[string]providerCredentials{
|
|
"mullvad": {
|
|
WireGuard: &wireguardCredentials{
|
|
PrivateKey: wireguardPrivateKey.String(),
|
|
Address: "10.0.0.2/32",
|
|
},
|
|
},
|
|
"protonvpn": {
|
|
OpenVPN: &openvpnCredentials{
|
|
Username: "user",
|
|
Password: "pass",
|
|
},
|
|
},
|
|
}
|
|
|
|
encoded, err := marshalCredentials(credentials)
|
|
if err != nil {
|
|
t.Fatalf("unexpected marshal error: %v", err)
|
|
}
|
|
|
|
decoded, err := loadCredentials(encoded)
|
|
if err != nil {
|
|
t.Fatalf("unexpected load error: %v", err)
|
|
}
|
|
|
|
if len(decoded) != len(credentials) {
|
|
t.Fatalf("expected %d providers, got %d", len(credentials), len(decoded))
|
|
}
|
|
|
|
if decoded["mullvad"].WireGuard == nil {
|
|
t.Fatal("expected mullvad wireguard credentials to be present")
|
|
}
|
|
if decoded["protonvpn"].OpenVPN == nil {
|
|
t.Fatal("expected protonvpn openvpn credentials to be present")
|
|
}
|
|
if decoded["protonvpn"].OpenVPN.Password != "pass" {
|
|
t.Fatalf("expected protonvpn password %q, got %q", "pass", decoded["protonvpn"].OpenVPN.Password)
|
|
}
|
|
}
|
|
|
|
func Test_formatCredentialForDump(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
testCases := map[string]struct {
|
|
provider string
|
|
vpnType string
|
|
providerCredentials providerCredentials
|
|
expectedOutput string
|
|
wantError bool
|
|
}{
|
|
"openvpn": {
|
|
provider: "protonvpn",
|
|
vpnType: vpnTypeOpenVPN,
|
|
providerCredentials: providerCredentials{
|
|
OpenVPN: &openvpnCredentials{
|
|
Username: "user",
|
|
Password: "pass",
|
|
Key: "key",
|
|
Cert: "cert",
|
|
},
|
|
},
|
|
expectedOutput: "provider: protonvpn\n" +
|
|
"vpn_type: openvpn\n" +
|
|
"username: user\n" +
|
|
"password: pass\n" +
|
|
"key: key\n" +
|
|
"cert: cert\n",
|
|
},
|
|
"wireguard": {
|
|
provider: "mullvad",
|
|
vpnType: vpnTypeWireGuard,
|
|
providerCredentials: providerCredentials{
|
|
WireGuard: &wireguardCredentials{
|
|
PrivateKey: "private",
|
|
Address: "10.0.0.2/32",
|
|
PresharedKey: "preshared",
|
|
},
|
|
},
|
|
expectedOutput: "provider: mullvad\n" +
|
|
"vpn_type: wireguard\n" +
|
|
"private_key: private\n" +
|
|
"address: 10.0.0.2/32\n" +
|
|
"preshared_key: preshared",
|
|
},
|
|
"missing_protocol": {
|
|
provider: "protonvpn",
|
|
vpnType: vpnTypeOpenVPN,
|
|
wantError: true,
|
|
},
|
|
"unknown_protocol": {
|
|
provider: "protonvpn",
|
|
vpnType: "other",
|
|
wantError: true,
|
|
},
|
|
}
|
|
|
|
for name, testCase := range testCases {
|
|
t.Run(name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
output, err := formatCredentialForDump(
|
|
testCase.provider,
|
|
testCase.vpnType,
|
|
testCase.providerCredentials,
|
|
)
|
|
|
|
if testCase.wantError {
|
|
if err == nil {
|
|
t.Fatal("expected an error but got nil")
|
|
}
|
|
return
|
|
}
|
|
|
|
if err != nil {
|
|
t.Fatalf("expected no error, got %v", err)
|
|
}
|
|
|
|
if output != testCase.expectedOutput {
|
|
t.Fatalf("expected output %q, got %q", testCase.expectedOutput, output)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func cloneCredentials(credentials map[string]providerCredentials) map[string]providerCredentials {
|
|
clone := make(map[string]providerCredentials, len(credentials))
|
|
for provider, providerCredentials := range credentials {
|
|
copied := providerCredentials
|
|
if providerCredentials.OpenVPN != nil {
|
|
openvpnCredentials := *providerCredentials.OpenVPN
|
|
copied.OpenVPN = &openvpnCredentials
|
|
}
|
|
if providerCredentials.WireGuard != nil {
|
|
wireguardCredentials := *providerCredentials.WireGuard
|
|
copied.WireGuard = &wireguardCredentials
|
|
}
|
|
clone[provider] = copied
|
|
}
|
|
return clone
|
|
}
|