feat(devrun): add initial implementation of devrun tool

See ./devrun/README.md for more details.
This commit is contained in:
Quentin McGaw
2026-05-01 22:04:00 +00:00
parent 4a78989d9d
commit b1b991b84e
10 changed files with 1910 additions and 0 deletions
+240
View File
@@ -0,0 +1,240 @@
package internal
import (
"bytes"
"encoding/gob"
"fmt"
"maps"
"net/netip"
"slices"
"strings"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
const credentialsFilename = "credentials"
const (
vpnTypeOpenVPN = "openvpn"
vpnTypeWireGuard = "wireguard"
)
type providerCredentials struct {
OpenVPN *openvpnCredentials
WireGuard *wireguardCredentials
}
type openvpnCredentials struct {
Username string
Password string
}
type wireguardCredentials struct {
PrivateKey string
Address string
PresharedKey string
}
func loadCredentials(data []byte) (map[string]providerCredentials, error) {
credentials := make(map[string]providerCredentials)
err := gob.NewDecoder(bytes.NewReader(data)).Decode(&credentials)
if err != nil {
return nil, fmt.Errorf("decoding credentials: %w", err)
}
return credentials, nil
}
func marshalCredentials(credentials map[string]providerCredentials) ([]byte, error) {
buffer := bytes.NewBuffer(nil)
err := gob.NewEncoder(buffer).Encode(credentials)
if err != nil {
return nil, fmt.Errorf("encoding credentials: %w", err)
}
return buffer.Bytes(), nil
}
func validateCredentials(providerNameToCredentials map[string]providerCredentials) error {
for provider, credentials := range providerNameToCredentials {
if credentials.OpenVPN == nil && credentials.WireGuard == nil {
return fmt.Errorf("provider %q has no openvpn or wireguard credentials", provider)
}
if credentials.OpenVPN != nil {
err := validateOpenvpnCredentials(provider, credentials.OpenVPN)
if err != nil {
return err
}
}
if credentials.WireGuard != nil {
err := validateWireguardCredentials(provider, credentials.WireGuard)
if err != nil {
return err
}
}
}
return nil
}
func validateOpenvpnCredentials(provider string, creds *openvpnCredentials) error {
switch {
case creds.Username == "":
return fmt.Errorf("provider %q openvpn credentials are missing the username", provider)
case creds.Password == "":
return fmt.Errorf("provider %q openvpn credentials are missing the password", provider)
}
return nil
}
func validateWireguardCredentials(provider string, creds *wireguardCredentials) error {
if creds.PrivateKey == "" {
return fmt.Errorf("provider %q wireguard credentials are missing the private key", provider)
} else if _, err := wgtypes.ParseKey(creds.PrivateKey); err != nil {
return fmt.Errorf("provider %q wireguard credentials have an invalid private key: %w", provider, err)
}
if creds.Address != "" {
_, err := netip.ParsePrefix(creds.Address)
if err != nil {
return fmt.Errorf("provider %q wireguard credentials have an invalid address %q: %w", provider, creds.Address, err)
}
}
if creds.PresharedKey != "" {
if _, err := wgtypes.ParseKey(creds.PresharedKey); err != nil {
return fmt.Errorf("provider %q wireguard credentials have an invalid preshared key: %w", provider, err)
}
}
return nil
}
func lookupCredentials(credentials map[string]providerCredentials, provider, vpnType string) ([]string, error) {
providerCreds, exists := credentials[provider]
if !exists {
existing := slices.Collect(maps.Keys(credentials))
return nil, fmt.Errorf("no credentials found for provider %q, available providers are: %s",
provider, strings.Join(existing, ", "))
}
switch vpnType {
case vpnTypeWireGuard:
if providerCreds.WireGuard == nil {
return nil, fmt.Errorf("no wireguard credentials found for provider %q", provider)
}
return buildWireGuardEnv(providerCreds.WireGuard), nil
case vpnTypeOpenVPN:
if providerCreds.OpenVPN == nil {
return nil, fmt.Errorf("no openvpn credentials found for provider %q", provider)
}
return buildOpenvpnEnv(providerCreds.OpenVPN), nil
default:
return nil, fmt.Errorf("unknown vpn type %q, must be wireguard or openvpn", vpnType)
}
}
func buildWireGuardEnv(creds *wireguardCredentials) []string {
envVars := []string{
"WIREGUARD_PRIVATE_KEY=" + creds.PrivateKey,
}
if creds.Address != "" {
envVars = append(envVars, "WIREGUARD_ADDRESSES="+creds.Address)
}
if creds.PresharedKey != "" {
envVars = append(envVars, "WIREGUARD_PRESHARED_KEY="+creds.PresharedKey)
}
return envVars
}
func buildOpenvpnEnv(creds *openvpnCredentials) []string {
return []string{
"OPENVPN_USER=" + creds.Username,
"OPENVPN_PASSWORD=" + creds.Password,
}
}
func addCredential(credentials map[string]providerCredentials, provider, vpnType string,
openvpnCredentials *openvpnCredentials, wireguardCredentials *wireguardCredentials,
) error {
providerCredentials := credentials[provider]
switch vpnType {
case vpnTypeOpenVPN:
providerCredentials.OpenVPN = openvpnCredentials
case vpnTypeWireGuard:
providerCredentials.WireGuard = wireguardCredentials
default:
return fmt.Errorf("unknown vpn type %q, must be wireguard or openvpn", vpnType)
}
credentials[provider] = providerCredentials
return nil
}
func deleteCredential(credentials map[string]providerCredentials, provider, vpnType string) error {
providerCredentials, exists := credentials[provider]
if !exists {
return fmt.Errorf("provider %q does not exist", provider)
}
switch vpnType {
case vpnTypeOpenVPN:
if providerCredentials.OpenVPN == nil {
return fmt.Errorf("provider %q has no openvpn credentials", provider)
}
providerCredentials.OpenVPN = nil
case vpnTypeWireGuard:
if providerCredentials.WireGuard == nil {
return fmt.Errorf("provider %q has no wireguard credentials", provider)
}
providerCredentials.WireGuard = nil
default:
return fmt.Errorf("unknown vpn type %q, must be wireguard or openvpn", vpnType)
}
if providerCredentials.OpenVPN == nil && providerCredentials.WireGuard == nil {
delete(credentials, provider)
return nil
}
credentials[provider] = providerCredentials
return nil
}
func formatCredentialForDump(provider, vpnType string,
providerCredentials providerCredentials,
) (output string, err error) {
var builder strings.Builder
builder.WriteString("provider: ")
builder.WriteString(provider)
builder.WriteString("\n")
builder.WriteString("vpn_type: ")
builder.WriteString(vpnType)
builder.WriteString("\n")
switch vpnType {
case vpnTypeOpenVPN:
if providerCredentials.OpenVPN == nil {
return "", fmt.Errorf("no openvpn credentials found for provider %q", provider)
}
builder.WriteString("username: ")
builder.WriteString(providerCredentials.OpenVPN.Username)
builder.WriteString("\n")
builder.WriteString("password: ")
builder.WriteString(providerCredentials.OpenVPN.Password)
case vpnTypeWireGuard:
if providerCredentials.WireGuard == nil {
return "", fmt.Errorf("no wireguard credentials found for provider %q", provider)
}
builder.WriteString("private_key: ")
builder.WriteString(providerCredentials.WireGuard.PrivateKey)
builder.WriteString("\n")
builder.WriteString("address: ")
builder.WriteString(providerCredentials.WireGuard.Address)
builder.WriteString("\n")
builder.WriteString("preshared_key: ")
builder.WriteString(providerCredentials.WireGuard.PresharedKey)
default:
return "", fmt.Errorf("unknown vpn type %q, must be wireguard or openvpn", vpnType)
}
return builder.String(), nil
}
+343
View File
@@ -0,0 +1,343 @@
package internal
import (
"testing"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
func Test_addCredential(t *testing.T) {
t.Parallel()
wireguardPrivateKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatal(err)
}
testCases := map[string]struct {
initialCredentials map[string]providerCredentials
provider string
vpnType string
openvpnCredentials *openvpnCredentials
wireguardCreds *wireguardCredentials
expectedLength int
expectedOpenVPN bool
expectedWireGuard bool
}{
"adds_openvpn_credentials": {
initialCredentials: map[string]providerCredentials{},
provider: "protonvpn",
vpnType: "openvpn",
openvpnCredentials: &openvpnCredentials{Username: "user", Password: "pass"},
expectedLength: 1,
expectedOpenVPN: true,
},
"adds_wireguard_credentials": {
initialCredentials: map[string]providerCredentials{},
provider: "mullvad",
vpnType: "wireguard",
wireguardCreds: &wireguardCredentials{
PrivateKey: wireguardPrivateKey.String(),
Address: "10.0.0.2/32",
},
expectedLength: 1,
expectedWireGuard: true,
},
"preserves_other_protocol": {
initialCredentials: map[string]providerCredentials{
"protonvpn": {
WireGuard: &wireguardCredentials{PrivateKey: wireguardPrivateKey.String()},
},
},
provider: "protonvpn",
vpnType: "openvpn",
openvpnCredentials: &openvpnCredentials{Username: "user", Password: "pass"},
expectedLength: 1,
expectedOpenVPN: true,
expectedWireGuard: true,
},
}
for name, testCase := range testCases {
t.Run(name, func(t *testing.T) {
t.Parallel()
credentials := cloneCredentials(testCase.initialCredentials)
err := addCredential(credentials, testCase.provider, testCase.vpnType,
testCase.openvpnCredentials, testCase.wireguardCreds)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
providerCredentials := credentials[testCase.provider]
if len(credentials) != testCase.expectedLength {
t.Fatalf("expected %d providers, got %d", testCase.expectedLength, len(credentials))
}
if (providerCredentials.OpenVPN != nil) != testCase.expectedOpenVPN {
t.Fatalf("expected openvpn presence %t, got %t", testCase.expectedOpenVPN, providerCredentials.OpenVPN != nil)
}
if (providerCredentials.WireGuard != nil) != testCase.expectedWireGuard {
t.Fatalf("expected wireguard presence %t, got %t", testCase.expectedWireGuard, providerCredentials.WireGuard != nil)
}
})
}
}
func Test_deleteCredential(t *testing.T) {
t.Parallel()
wireguardPrivateKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatal(err)
}
testCases := map[string]struct {
initialCredentials map[string]providerCredentials
provider string
vpnType string
expectedLength int
expectedOpenVPN bool
expectedWireGuard bool
}{
"deletes_openvpn_only": {
initialCredentials: map[string]providerCredentials{
"protonvpn": {
OpenVPN: &openvpnCredentials{Username: "user", Password: "pass"},
WireGuard: &wireguardCredentials{PrivateKey: wireguardPrivateKey.String()},
},
},
provider: "protonvpn",
vpnType: "openvpn",
expectedLength: 1,
expectedWireGuard: true,
},
"deletes_last_protocol_and_provider": {
initialCredentials: map[string]providerCredentials{
"protonvpn": {
OpenVPN: &openvpnCredentials{Username: "user", Password: "pass"},
},
},
provider: "protonvpn",
vpnType: "openvpn",
expectedLength: 0,
},
}
for name, testCase := range testCases {
t.Run(name, func(t *testing.T) {
t.Parallel()
credentials := cloneCredentials(testCase.initialCredentials)
err := deleteCredential(credentials, testCase.provider, testCase.vpnType)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(credentials) != testCase.expectedLength {
t.Fatalf("expected %d providers, got %d", testCase.expectedLength, len(credentials))
}
providerCredentials, exists := credentials[testCase.provider]
if !exists {
return
}
if (providerCredentials.OpenVPN != nil) != testCase.expectedOpenVPN {
t.Fatalf("expected openvpn presence %t, got %t", testCase.expectedOpenVPN, providerCredentials.OpenVPN != nil)
}
if (providerCredentials.WireGuard != nil) != testCase.expectedWireGuard {
t.Fatalf("expected wireguard presence %t, got %t", testCase.expectedWireGuard, providerCredentials.WireGuard != nil)
}
})
}
}
func Test_validateCredentials(t *testing.T) {
t.Parallel()
wireguardPrivateKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatal(err)
}
testCases := map[string]struct {
credentials map[string]providerCredentials
wantError bool
}{
"both_protocols_valid": {
credentials: map[string]providerCredentials{
"protonvpn": {
OpenVPN: &openvpnCredentials{Username: "user", Password: "pass"},
WireGuard: &wireguardCredentials{PrivateKey: wireguardPrivateKey.String()},
},
},
},
"invalid_wireguard_when_both_present": {
credentials: map[string]providerCredentials{
"protonvpn": {
OpenVPN: &openvpnCredentials{Username: "user", Password: "pass"},
WireGuard: &wireguardCredentials{PrivateKey: "invalid"},
},
},
wantError: true,
},
}
for name, testCase := range testCases {
t.Run(name, func(t *testing.T) {
t.Parallel()
err := validateCredentials(testCase.credentials)
if testCase.wantError && err == nil {
t.Fatal("expected an error but got nil")
}
if !testCase.wantError && err != nil {
t.Fatalf("expected no error, got %v", err)
}
})
}
}
func Test_marshalLoadCredentials(t *testing.T) {
t.Parallel()
wireguardPrivateKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatal(err)
}
credentials := map[string]providerCredentials{
"mullvad": {
WireGuard: &wireguardCredentials{
PrivateKey: wireguardPrivateKey.String(),
Address: "10.0.0.2/32",
},
},
"protonvpn": {
OpenVPN: &openvpnCredentials{
Username: "user",
Password: "pass",
},
},
}
encoded, err := marshalCredentials(credentials)
if err != nil {
t.Fatalf("unexpected marshal error: %v", err)
}
decoded, err := loadCredentials(encoded)
if err != nil {
t.Fatalf("unexpected load error: %v", err)
}
if len(decoded) != len(credentials) {
t.Fatalf("expected %d providers, got %d", len(credentials), len(decoded))
}
if decoded["mullvad"].WireGuard == nil {
t.Fatal("expected mullvad wireguard credentials to be present")
}
if decoded["protonvpn"].OpenVPN == nil {
t.Fatal("expected protonvpn openvpn credentials to be present")
}
if decoded["protonvpn"].OpenVPN.Password != "pass" {
t.Fatalf("expected protonvpn password %q, got %q", "pass", decoded["protonvpn"].OpenVPN.Password)
}
}
func Test_formatCredentialForDump(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
provider string
vpnType string
providerCredentials providerCredentials
expectedOutput string
wantError bool
}{
"openvpn": {
provider: "protonvpn",
vpnType: vpnTypeOpenVPN,
providerCredentials: providerCredentials{
OpenVPN: &openvpnCredentials{Username: "user", Password: "pass"},
},
expectedOutput: "provider: protonvpn\n" +
"vpn_type: openvpn\n" +
"username: user\n" +
"password: pass",
},
"wireguard": {
provider: "mullvad",
vpnType: vpnTypeWireGuard,
providerCredentials: providerCredentials{
WireGuard: &wireguardCredentials{
PrivateKey: "private",
Address: "10.0.0.2/32",
PresharedKey: "preshared",
},
},
expectedOutput: "provider: mullvad\n" +
"vpn_type: wireguard\n" +
"private_key: private\n" +
"address: 10.0.0.2/32\n" +
"preshared_key: preshared",
},
"missing_protocol": {
provider: "protonvpn",
vpnType: vpnTypeOpenVPN,
wantError: true,
},
"unknown_protocol": {
provider: "protonvpn",
vpnType: "other",
wantError: true,
},
}
for name, testCase := range testCases {
t.Run(name, func(t *testing.T) {
t.Parallel()
output, err := formatCredentialForDump(
testCase.provider,
testCase.vpnType,
testCase.providerCredentials,
)
if testCase.wantError {
if err == nil {
t.Fatal("expected an error but got nil")
}
return
}
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if output != testCase.expectedOutput {
t.Fatalf("expected output %q, got %q", testCase.expectedOutput, output)
}
})
}
}
func cloneCredentials(credentials map[string]providerCredentials) map[string]providerCredentials {
clone := make(map[string]providerCredentials, len(credentials))
for provider, providerCredentials := range credentials {
copied := providerCredentials
if providerCredentials.OpenVPN != nil {
openvpnCredentials := *providerCredentials.OpenVPN
copied.OpenVPN = &openvpnCredentials
}
if providerCredentials.WireGuard != nil {
wireguardCredentials := *providerCredentials.WireGuard
copied.WireGuard = &wireguardCredentials
}
clone[provider] = copied
}
return clone
}
+521
View File
@@ -0,0 +1,521 @@
package internal
import (
"bufio"
"context"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"fmt"
"io"
"maps"
"os"
"slices"
"strings"
"sync"
"syscall"
"golang.org/x/crypto/scrypt"
"golang.org/x/term"
)
// Encryption format: [16-byte salt][12-byte nonce][AES-256-GCM ciphertext+tag]
// Key derivation: scrypt(password, salt, N=32768, r=8, p=1, keyLen=32)
const (
saltSize = 16
nonceSize = 12
keySize = 32
scryptN = 32768
scryptR = 8
scryptP = 1
)
// AddCredential prompts for credential values and stores them in the encrypted credentials file.
func AddCredential(ctx context.Context, provider, vpnType string) error {
credentials, password, err := loadCredentialsForMutation(ctx)
if err != nil {
return err
}
err = promptAndAddCredential(ctx, credentials, provider, vpnType)
if err != nil {
return err
}
err = validateCredentials(credentials)
if err != nil {
return fmt.Errorf("validating credentials: %w", err)
}
err = writeEncryptedCredentials(credentials, password)
if err != nil {
return err
}
fmt.Printf(
"Credentials for provider %q and vpn type %q saved to %s\n",
provider, vpnType, credentialsFilename,
)
return nil
}
// DeleteCredential removes credentials for a provider and VPN type
// from the encrypted credentials file.
func DeleteCredential(ctx context.Context, provider, vpnType string) error {
credentials, password, err := loadExistingCredentialsForMutation(ctx)
if err != nil {
return err
}
err = deleteCredential(credentials, provider, vpnType)
if err != nil {
return err
}
err = writeEncryptedCredentials(credentials, password)
if err != nil {
return err
}
fmt.Printf(
"Credentials for provider %q and vpn type %q removed from %s\n",
provider, vpnType, credentialsFilename,
)
return nil
}
// DumpCredential decrypts the credential store and prints one provider/vpn-type entry.
func DumpCredential(ctx context.Context, provider, vpnType string) error {
credentials, err := decryptCredentials(ctx)
if err != nil {
return err
}
providerCredentials, exists := credentials[provider]
if !exists {
existingProviders := slices.Collect(maps.Keys(credentials))
return fmt.Errorf("provider %q does not exist, available providers are: %s",
provider, strings.Join(existingProviders, ", "))
}
output, err := formatCredentialForDump(provider, vpnType, providerCredentials)
if err != nil {
return err
}
fmt.Println(output)
return nil
}
// decryptCredentials reads the encrypted credentials file,
// prompts for a password, and returns the decrypted credentials.
func decryptCredentials(ctx context.Context) (map[string]providerCredentials, error) {
password, err := readSecret(ctx, "Enter credentials password: ", false)
if err != nil {
return nil, fmt.Errorf("reading password: %w", err)
}
plaintext, err := decryptCredentialsFile(password)
if err != nil {
return nil, err
}
credentials, err := loadCredentials(plaintext)
if err != nil {
return nil, fmt.Errorf("loading credentials: %w", err)
}
return credentials, nil
}
func loadCredentialsForMutation(ctx context.Context) (
credentials map[string]providerCredentials,
password []byte,
err error,
) {
_, err = os.Stat(credentialsFilename)
if os.IsNotExist(err) {
password, err = readPasswordConfirmed(ctx,
"Enter new credentials password: ",
"Confirm new credentials password: ",
)
if err != nil {
return nil, nil, fmt.Errorf("reading password: %w", err)
}
return make(map[string]providerCredentials), password, nil
}
if err != nil {
return nil, nil, fmt.Errorf("stating %s: %w", credentialsFilename, err)
}
password, err = readSecret(ctx, "Enter credentials password: ", false)
if err != nil {
return nil, nil, fmt.Errorf("reading password: %w", err)
}
plaintext, err := decryptCredentialsFile(password)
if err != nil {
return nil, nil, err
}
credentials, err = loadCredentials(plaintext)
if err != nil {
return nil, nil, fmt.Errorf("loading credentials: %w", err)
}
return credentials, password, nil
}
func loadExistingCredentialsForMutation(ctx context.Context) (
credentials map[string]providerCredentials,
password []byte,
err error,
) {
_, err = os.Stat(credentialsFilename)
if os.IsNotExist(err) {
return nil, nil, fmt.Errorf("%s does not exist", credentialsFilename)
}
if err != nil {
return nil, nil, fmt.Errorf("stating %s: %w", credentialsFilename, err)
}
password, err = readSecret(ctx, "Enter credentials password: ", false)
if err != nil {
return nil, nil, fmt.Errorf("reading password: %w", err)
}
plaintext, err := decryptCredentialsFile(password)
if err != nil {
return nil, nil, err
}
credentials, err = loadCredentials(plaintext)
if err != nil {
return nil, nil, fmt.Errorf("loading credentials: %w", err)
}
return credentials, password, nil
}
func promptAndAddCredential(
ctx context.Context,
credentials map[string]providerCredentials,
provider, vpnType string,
) error {
switch vpnType {
case vpnTypeOpenVPN:
username, err := readLine(ctx, "OpenVPN username: ", false)
if err != nil {
return fmt.Errorf("reading username: %w", err)
}
password, err := readSecret(ctx, "OpenVPN password: ", false)
if err != nil {
return fmt.Errorf("reading password: %w", err)
}
openvpnCredentials := &openvpnCredentials{
Username: username,
Password: string(password),
}
err = validateOpenvpnCredentials(provider, openvpnCredentials)
if err != nil {
return err
}
return addCredential(credentials, provider, vpnType, openvpnCredentials, nil)
case vpnTypeWireGuard:
privateKey, err := readSecret(ctx, "WireGuard private key: ", false)
if err != nil {
return fmt.Errorf("reading private key: %w", err)
}
address, err := readLine(ctx, "WireGuard address (optional): ", true)
if err != nil {
return fmt.Errorf("reading address: %w", err)
}
presharedKey, err := readSecret(
ctx,
"WireGuard preshared key (optional): ",
true,
)
if err != nil {
return fmt.Errorf("reading preshared key: %w", err)
}
wireguardCredentials := &wireguardCredentials{
PrivateKey: string(privateKey),
Address: address,
PresharedKey: string(presharedKey),
}
err = validateWireguardCredentials(provider, wireguardCredentials)
if err != nil {
return err
}
return addCredential(credentials, provider, vpnType, nil, wireguardCredentials)
default:
return fmt.Errorf("unknown vpn type %q, must be wireguard or openvpn", vpnType)
}
}
func writeEncryptedCredentials(
credentials map[string]providerCredentials,
password []byte,
) error {
plaintext, err := marshalCredentials(credentials)
if err != nil {
return fmt.Errorf("encoding credentials: %w", err)
}
encrypted, err := encryptData(plaintext, password)
if err != nil {
return fmt.Errorf("encrypting credentials: %w", err)
}
const filePerms = 0o600
err = os.WriteFile(credentialsFilename, encrypted, filePerms)
if err != nil {
return fmt.Errorf("writing %s: %w", credentialsFilename, err)
}
return nil
}
func decryptCredentialsFile(password []byte) ([]byte, error) {
encryptedData, err := os.ReadFile(credentialsFilename)
if err != nil {
return nil, fmt.Errorf("reading %s: %w", credentialsFilename, err)
}
plaintext, err := decryptData(encryptedData, password)
if err != nil {
return nil, fmt.Errorf("decrypting credentials: %w", err)
}
return plaintext, nil
}
func readSecret(ctx context.Context, prompt string, allowEmpty bool) ([]byte, error) {
fmt.Print(prompt)
passwordFD, err := syscall.Dup(syscall.Stdin)
if err != nil {
fmt.Println()
return nil, fmt.Errorf("duplicating stdin file descriptor: %w", err)
}
var closeFDOnce sync.Once
closePasswordFD := func() {
closeFDOnce.Do(func() {
_ = syscall.Close(passwordFD)
})
}
passwordResult := make(chan struct {
password []byte
err error
})
go func() {
password, err := term.ReadPassword(passwordFD)
closePasswordFD()
result := struct {
password []byte
err error
}{
password: password,
err: err,
}
select {
case <-ctx.Done():
return
case passwordResult <- result:
}
}()
select {
case <-ctx.Done():
closePasswordFD()
fmt.Println()
return nil, ctx.Err()
case result := <-passwordResult:
closePasswordFD()
fmt.Println()
if result.err != nil {
return nil, fmt.Errorf("reading hidden input from terminal: %w", result.err)
}
if len(result.password) == 0 && !allowEmpty {
return nil, fmt.Errorf("value cannot be empty")
}
return result.password, nil
}
}
func readLine(ctx context.Context, prompt string, allowEmpty bool) (string, error) {
fmt.Print(prompt)
inputFD, err := syscall.Dup(syscall.Stdin)
if err != nil {
fmt.Println()
return "", fmt.Errorf("duplicating stdin file descriptor: %w", err)
}
var closeFDOnce sync.Once
closeInputFD := func() {
closeFDOnce.Do(func() {
_ = syscall.Close(inputFD)
})
}
inputResult := make(chan struct {
value string
err error
})
go func() {
inputFile := os.NewFile(uintptr(inputFD), "stdin")
reader := bufio.NewReader(inputFile)
value, err := reader.ReadString('\n')
closeInputFD()
value = strings.TrimRight(value, "\r\n")
if err == io.EOF {
err = nil
}
result := struct {
value string
err error
}{
value: value,
err: err,
}
select {
case <-ctx.Done():
return
case inputResult <- result:
}
}()
select {
case <-ctx.Done():
closeInputFD()
fmt.Println()
return "", ctx.Err()
case result := <-inputResult:
closeInputFD()
if result.err != nil {
return "", fmt.Errorf("reading line from terminal: %w", result.err)
}
if result.value == "" && !allowEmpty {
return "", fmt.Errorf("value cannot be empty")
}
return result.value, nil
}
}
func readPasswordConfirmed(
ctx context.Context,
prompt, confirmationPrompt string,
) ([]byte, error) {
password, err := readSecret(ctx, prompt, false)
if err != nil {
return nil, err
}
confirmation, err := readSecret(ctx, confirmationPrompt, false)
if err != nil {
return nil, err
}
if string(password) != string(confirmation) {
return nil, fmt.Errorf("passwords do not match")
}
return password, nil
}
func deriveKey(password, salt []byte) ([]byte, error) {
key, err := scrypt.Key(password, salt, scryptN, scryptR, scryptP, keySize)
if err != nil {
return nil, fmt.Errorf("deriving key with scrypt: %w", err)
}
return key, nil
}
func encryptData(plaintext, password []byte) ([]byte, error) {
salt := make([]byte, saltSize)
_, err := io.ReadFull(rand.Reader, salt)
if err != nil {
return nil, fmt.Errorf("generating salt: %w", err)
}
key, err := deriveKey(password, salt)
if err != nil {
return nil, err
}
block, err := aes.NewCipher(key)
if err != nil {
return nil, fmt.Errorf("creating AES cipher: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, fmt.Errorf("creating GCM: %w", err)
}
nonce := make([]byte, nonceSize)
_, err = io.ReadFull(rand.Reader, nonce)
if err != nil {
return nil, fmt.Errorf("generating nonce: %w", err)
}
ciphertext := gcm.Seal(nil, nonce, plaintext, nil)
result := make([]byte, 0, saltSize+nonceSize+len(ciphertext))
result = append(result, salt...)
result = append(result, nonce...)
result = append(result, ciphertext...)
return result, nil
}
func decryptData(data, password []byte) ([]byte, error) {
const minSize = saltSize + nonceSize + 16 // 16 is the GCM tag size
if len(data) < minSize {
return nil, fmt.Errorf("encrypted data too short: %d bytes", len(data))
}
salt := data[:saltSize]
nonce := data[saltSize : saltSize+nonceSize]
ciphertext := data[saltSize+nonceSize:]
key, err := deriveKey(password, salt)
if err != nil {
return nil, err
}
block, err := aes.NewCipher(key)
if err != nil {
return nil, fmt.Errorf("creating AES cipher: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, fmt.Errorf("creating GCM: %w", err)
}
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
if err != nil {
return nil, fmt.Errorf("decrypting data (wrong password?): %w", err)
}
return plaintext, nil
}
+351
View File
@@ -0,0 +1,351 @@
package internal
import (
"context"
"fmt"
"io"
"os"
"strings"
"time"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/mount"
"github.com/docker/docker/api/types/network"
"github.com/docker/docker/client"
"github.com/docker/docker/pkg/stdcopy"
"github.com/docker/go-connections/nat"
v1 "github.com/opencontainers/image-spec/specs-go/v1"
"golang.org/x/term"
)
type containerOptions struct {
env []string
binds []string
ports nat.PortMap
dns []string
devices []container.DeviceMapping
labels map[string]string
capAdd []string
}
// Run decrypts credentials, builds the container environment, and runs a Gluetun container.
// extraArgs is the list of additional flags (e.g. ["-e", "PORT_FORWARDING=on", "-v", "/host:/container"]).
func Run(ctx context.Context, provider, vpnType string, extraArgs []string,
forceKill <-chan struct{},
) error {
credentials, err := decryptCredentials(ctx)
if err != nil {
return fmt.Errorf("loading credentials: %w", err)
}
credentialEnvVars, err := lookupCredentials(credentials, provider, vpnType)
if err != nil {
return err
}
extraOpts, err := parseExtraArgs(extraArgs)
if err != nil {
return fmt.Errorf("parsing extra flags: %w", err)
}
opts := extraOpts
opts.env = append(opts.env,
"VPN_SERVICE_PROVIDER="+provider,
"VPN_TYPE="+vpnType,
"LOG_LEVEL=debug",
)
opts.env = append(opts.env, credentialEnvVars...)
opts.capAdd = append(opts.capAdd, "NET_ADMIN")
return runContainer(ctx, opts, forceKill)
}
// parseExtraArgs parses extra arguments and maps them to container options.
// Supported flags:
//
// -e, --env KEY=VALUE - environment variable
// -v, --volume SPEC - volume mount (e.g., "/host:/container" or "name:/container")
// -p, --publish PORT:PORT - port mapping
// --dns IP - DNS server
// --device SPEC - device access (e.g., "/dev/net/tun")
// --label KEY=VALUE - container label
// --cap-add CAPABILITY - add Linux capability (e.g., "SYS_PTRACE")
func parseExtraArgs(args []string) (opts containerOptions, err error) { //nolint:gocognit,gocyclo
opts = containerOptions{
ports: make(nat.PortMap),
labels: make(map[string]string),
}
for i := 0; i < len(args); i++ {
arg := args[i]
switch {
case arg == "-e" || arg == "--env":
if i+1 >= len(args) {
return opts, fmt.Errorf("flag %q requires an argument", arg)
}
i++
opts.env = append(opts.env, args[i])
case strings.HasPrefix(arg, "-e="):
opts.env = append(opts.env, strings.TrimPrefix(arg, "-e="))
case strings.HasPrefix(arg, "--env="):
opts.env = append(opts.env, strings.TrimPrefix(arg, "--env="))
case arg == "-v" || arg == "--volume":
if i+1 >= len(args) {
return opts, fmt.Errorf("flag %q requires an argument", arg)
}
i++
opts.binds = append(opts.binds, args[i])
case strings.HasPrefix(arg, "-v="):
opts.binds = append(opts.binds, strings.TrimPrefix(arg, "-v="))
case strings.HasPrefix(arg, "--volume="):
opts.binds = append(opts.binds, strings.TrimPrefix(arg, "--volume="))
case arg == "-p" || arg == "--publish":
if i+1 >= len(args) {
return opts, fmt.Errorf("flag %q requires an argument", arg)
}
i++
if err := parsePortMapping(opts.ports, args[i]); err != nil {
return opts, fmt.Errorf("parsing port mapping: %w", err)
}
case strings.HasPrefix(arg, "-p="):
if err := parsePortMapping(opts.ports, strings.TrimPrefix(arg, "-p=")); err != nil {
return opts, fmt.Errorf("parsing port mapping: %w", err)
}
case strings.HasPrefix(arg, "--publish="):
if err := parsePortMapping(opts.ports, strings.TrimPrefix(arg, "--publish=")); err != nil {
return opts, fmt.Errorf("parsing port mapping: %w", err)
}
case arg == "--dns":
if i+1 >= len(args) {
return opts, fmt.Errorf("flag %q requires an argument", arg)
}
i++
opts.dns = append(opts.dns, args[i])
case strings.HasPrefix(arg, "--dns="):
opts.dns = append(opts.dns, strings.TrimPrefix(arg, "--dns="))
case arg == "--device":
if i+1 >= len(args) {
return opts, fmt.Errorf("flag %q requires an argument", arg)
}
i++
parseDeviceMapping(&opts.devices, args[i])
case strings.HasPrefix(arg, "--device="):
parseDeviceMapping(&opts.devices, strings.TrimPrefix(arg, "--device="))
case arg == "--label":
if i+1 >= len(args) {
return opts, fmt.Errorf("flag %q requires an argument", arg)
}
i++
parseLabel(opts.labels, args[i])
case strings.HasPrefix(arg, "--label="):
parseLabel(opts.labels, strings.TrimPrefix(arg, "--label="))
case arg == "--cap-add":
if i+1 >= len(args) {
return opts, fmt.Errorf("flag %q requires an argument", arg)
}
i++
opts.capAdd = append(opts.capAdd, args[i])
case strings.HasPrefix(arg, "--cap-add="):
opts.capAdd = append(opts.capAdd, strings.TrimPrefix(arg, "--cap-add="))
default:
return opts, fmt.Errorf("unsupported flag %q", arg)
}
}
return opts, nil
}
func parsePortMapping(portMap nat.PortMap, spec string) error {
port, bindings, err := nat.ParsePortSpecs([]string{spec})
if err != nil {
return err
}
for p, binding := range bindings {
portMap[p] = binding
}
for p := range port {
if _, exists := portMap[p]; !exists {
portMap[p] = []nat.PortBinding{}
}
}
return nil
}
func parseDeviceMapping(devices *[]container.DeviceMapping, spec string) {
parts := strings.SplitN(spec, ":", 3) //nolint:mnd
pathOnHost := parts[0]
pathInContainer := pathOnHost
permissions := "rwm"
if len(parts) >= 2 { //nolint:mnd
pathInContainer = parts[1]
}
if len(parts) >= 3 { //nolint:mnd
permissions = parts[2]
}
*devices = append(*devices, container.DeviceMapping{
PathOnHost: pathOnHost,
PathInContainer: pathInContainer,
CgroupPermissions: permissions,
})
}
func parseLabel(labels map[string]string, kv string) {
parts := strings.SplitN(kv, "=", 2) //nolint:mnd
key := parts[0]
value := ""
if len(parts) > 1 {
value = parts[1]
}
labels[key] = value
}
func runContainer(ctx context.Context, opts containerOptions, forceKill <-chan struct{}) error {
dockerClient, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation())
if err != nil {
return fmt.Errorf("creating docker client: %w", err)
}
defer dockerClient.Close()
hasTTY := term.IsTerminal(int(os.Stdout.Fd()))
containerConfig := &container.Config{
Image: "qmcgaw/gluetun",
Env: opts.env,
Labels: opts.labels,
Tty: hasTTY,
}
mounts := make([]mount.Mount, 0, len(opts.binds))
for _, bind := range opts.binds {
m, err := parseBindMount(bind)
if err != nil {
return fmt.Errorf("parsing bind mount %q: %w", bind, err)
}
mounts = append(mounts, m)
}
hostConfig := &container.HostConfig{
AutoRemove: true,
CapAdd: opts.capAdd,
Binds: opts.binds,
Mounts: mounts,
PortBindings: opts.ports,
DNS: opts.dns,
}
hostConfig.Devices = opts.devices
networkConfig := &network.NetworkingConfig{}
platform := (*v1.Platform)(nil)
const containerName = "gluetun"
response, err := dockerClient.ContainerCreate(ctx, containerConfig, hostConfig, networkConfig, platform, containerName)
if err != nil {
return fmt.Errorf("creating container: %w", err)
}
for _, warning := range response.Warnings {
fmt.Fprintln(os.Stderr, "container creation warning:", warning)
}
containerID := response.ID
err = dockerClient.ContainerStart(ctx, containerID, container.StartOptions{})
if err != nil {
return fmt.Errorf("starting container: %w", err)
}
fmt.Printf("Container started (id: %.12s)\n", containerID)
streamLogsErr := make(chan error, 1)
go func() {
streamLogsErr <- streamLogs(context.Background(), dockerClient, containerID, hasTTY)
}()
contextDone := ctx.Done()
forceKillSignal := forceKill
for {
select {
case err := <-streamLogsErr:
if err != nil {
return err
}
return nil
case <-contextDone:
fmt.Fprintln(os.Stderr, "\nReceived interrupt, stopping container (5s timeout)...")
err = stopContainer(dockerClient, containerID)
if err != nil {
fmt.Fprintln(os.Stderr, "stopping container:", err)
}
contextDone = nil
case <-forceKillSignal:
fmt.Fprintln(os.Stderr, "\nReceived second interrupt, killing container...")
err = killContainer(dockerClient, containerID)
if err != nil {
fmt.Fprintln(os.Stderr, "killing container:", err)
}
forceKillSignal = nil
}
}
}
func parseBindMount(bind string) (mount.Mount, error) {
parts := strings.SplitN(bind, ":", 3) //nolint:mnd
if len(parts) < 2 { //nolint:mnd
return mount.Mount{}, fmt.Errorf("invalid bind mount format: %q (expected source:target[:mode])", bind)
}
source := parts[0]
target := parts[1]
readOnly := len(parts) > 2 && strings.Contains(parts[2], "ro") //nolint:mnd
return mount.Mount{
Type: mount.TypeBind,
Source: source,
Target: target,
ReadOnly: readOnly,
}, nil
}
func stopContainer(dockerClient *client.Client, containerID string) error {
const stopTimeout = 5 * time.Second
stopCtx, stopCancel := context.WithTimeout(context.Background(), stopTimeout)
defer stopCancel()
timeoutSeconds := int(stopTimeout.Seconds())
return dockerClient.ContainerStop(stopCtx, containerID, container.StopOptions{Timeout: &timeoutSeconds})
}
func killContainer(dockerClient *client.Client, containerID string) error {
return dockerClient.ContainerKill(context.Background(), containerID, "KILL")
}
func streamLogs(ctx context.Context, dockerClient *client.Client, containerID string, hasTTY bool) error {
logOptions := container.LogsOptions{
ShowStdout: true,
ShowStderr: true,
Follow: true,
Timestamps: false,
}
reader, err := dockerClient.ContainerLogs(ctx, containerID, logOptions)
if err != nil {
return fmt.Errorf("getting container logs: %w", err)
}
defer reader.Close()
if hasTTY {
_, err = io.Copy(os.Stdout, reader)
} else {
_, err = stdcopy.StdCopy(os.Stdout, os.Stderr, reader)
}
if err != nil && err != io.EOF {
return fmt.Errorf("streaming container logs: %w", err)
}
return nil
}