mirror of
https://github.com/qdm12/gluetun.git
synced 2026-05-06 20:10:11 +02:00
feat(devrun): add initial implementation of devrun tool
See ./devrun/README.md for more details.
This commit is contained in:
@@ -7,3 +7,4 @@ Dockerfile
|
|||||||
LICENSE
|
LICENSE
|
||||||
README.md
|
README.md
|
||||||
title.svg
|
title.svg
|
||||||
|
devrun
|
||||||
|
|||||||
@@ -0,0 +1 @@
|
|||||||
|
credentials
|
||||||
@@ -0,0 +1,152 @@
|
|||||||
|
# devrun
|
||||||
|
|
||||||
|
`devrun` is a small development helper for starting a local `qmcgaw/gluetun` Docker container with provider credentials stored in an encrypted file.
|
||||||
|
|
||||||
|
It solves two practical problems for local development:
|
||||||
|
|
||||||
|
- keeping VPN credentials out of the shell history and out of a plaintext file once setup is complete;
|
||||||
|
- quickly starting a Gluetun container for a specific provider and VPN type with a small set of extra Docker runtime options.
|
||||||
|
|
||||||
|
The tool has four commands:
|
||||||
|
|
||||||
|
- `add-cred`: add or replace credentials for one provider and one VPN type in the encrypted store `credentials`;
|
||||||
|
- `delete-cred`: remove credentials for one provider and one VPN type from the encrypted store `credentials`;
|
||||||
|
- `dump-cred`: print credentials for one provider and one VPN type from the encrypted store `credentials`;
|
||||||
|
- `run`: decrypt credentials on demand, build the required Gluetun environment variables, and run a `qmcgaw/gluetun` container.
|
||||||
|
|
||||||
|
## Prerequisites
|
||||||
|
|
||||||
|
- Go installed locally
|
||||||
|
- Docker installed and a daemon available to the Docker client
|
||||||
|
- an interactive terminal, since the tool prompts for passwords without echoing them
|
||||||
|
|
||||||
|
The Docker client is created from the standard Docker environment, so settings such as `DOCKER_HOST` are honored.
|
||||||
|
|
||||||
|
## Quick start
|
||||||
|
|
||||||
|
### Add credentials
|
||||||
|
|
||||||
|
Add one credential entry to the encrypted store:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
go run ./cmd/main.go add-cred protonvpn openvpn
|
||||||
|
go run ./cmd/main.go add-cred mullvad wireguard
|
||||||
|
```
|
||||||
|
|
||||||
|
Behavior:
|
||||||
|
|
||||||
|
- if `credentials` does not exist yet, `add-cred` asks for a new credentials password and creates the encrypted store;
|
||||||
|
- if `credentials` already exists, `add-cred` asks for the existing password first, decrypts the store, updates it, and writes it back encrypted;
|
||||||
|
- sensitive fields are read from stdin without echo.
|
||||||
|
|
||||||
|
Prompted values depend on the VPN type:
|
||||||
|
|
||||||
|
- `openvpn`: username and password
|
||||||
|
- `wireguard`: private key, optional address, optional preshared key
|
||||||
|
|
||||||
|
Running `add-cred` again for the same provider and VPN type replaces the existing values for that entry.
|
||||||
|
|
||||||
|
### Delete credentials
|
||||||
|
|
||||||
|
Remove one credential entry from the encrypted store:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
go run ./cmd/main.go delete-cred protonvpn openvpn
|
||||||
|
```
|
||||||
|
|
||||||
|
This asks for the credentials password first, decrypts the store, removes the requested provider and VPN type, and writes the store back encrypted.
|
||||||
|
|
||||||
|
### Dump credentials
|
||||||
|
|
||||||
|
Print one credential entry from the encrypted store:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
go run ./cmd/main.go dump-cred protonvpn openvpn
|
||||||
|
```
|
||||||
|
|
||||||
|
This asks for the credentials password first and then prints the selected provider and VPN type values.
|
||||||
|
|
||||||
|
### Container run
|
||||||
|
|
||||||
|
Run a container using the image `qmcgaw/gluetun` and the encrypted credentials with the `run` command.
|
||||||
|
For example:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
go run ./cmd/main.go run mullvad wireguard
|
||||||
|
go run ./cmd/main.go run protonvpn wireguard -e PORT_FORWARDING=on -p 8000:8000/tcp
|
||||||
|
```
|
||||||
|
|
||||||
|
You will be prompted for the credentials password, the file `credentials` will be decrypted in memory, and the container will be started.
|
||||||
|
|
||||||
|
The following environment variables are always added by the tool:
|
||||||
|
|
||||||
|
- `VPN_SERVICE_PROVIDER=<provider>`
|
||||||
|
- `VPN_TYPE=<vpn-type>`
|
||||||
|
- `LOG_LEVEL=debug`
|
||||||
|
|
||||||
|
The tool also adds `NET_ADMIN` to the container capabilities by default.
|
||||||
|
|
||||||
|
## Credential model
|
||||||
|
|
||||||
|
Internally, the encrypted file stores a binary-encoded map keyed by provider name. Each provider can define `openvpn`, `wireguard`, or both.
|
||||||
|
|
||||||
|
Conceptually, the stored data looks like this:
|
||||||
|
|
||||||
|
- provider `mullvad`: contains `wireguard`
|
||||||
|
- provider `protonvpn`: contains `wireguard`
|
||||||
|
- provider `protonvpn`: contains `openvpn`
|
||||||
|
|
||||||
|
You do not edit this directly. It is stored as encrypted binary data in `credentials`.
|
||||||
|
|
||||||
|
### OpenVPN fields
|
||||||
|
|
||||||
|
- `username` is required;
|
||||||
|
- `password` is required;
|
||||||
|
|
||||||
|
At runtime these map to:
|
||||||
|
|
||||||
|
- `OPENVPN_USER`
|
||||||
|
- `OPENVPN_PASSWORD`
|
||||||
|
|
||||||
|
### WireGuard fields
|
||||||
|
|
||||||
|
- `private_key` is required and must be a valid WireGuard private key;
|
||||||
|
- `address` is optional and must be a valid network prefix if set;
|
||||||
|
- `preshared_key` is optional and must be a valid WireGuard key if set.
|
||||||
|
|
||||||
|
At runtime these map to:
|
||||||
|
|
||||||
|
- `WIREGUARD_PRIVATE_KEY`
|
||||||
|
- `WIREGUARD_ADDRESSES` when `address` is set
|
||||||
|
- `WIREGUARD_PRESHARED_KEY` when `preshared_key` is set
|
||||||
|
|
||||||
|
## Supported extra Docker flags
|
||||||
|
|
||||||
|
The `run` command only accepts a focused subset of Docker-style runtime flags. Unsupported flags return an error.
|
||||||
|
|
||||||
|
Supported flags:
|
||||||
|
|
||||||
|
- `-e`, `--env KEY=VALUE`
|
||||||
|
- `-v`, `--volume SOURCE:TARGET[:mode]`
|
||||||
|
- `-p`, `--publish HOSTPORT:CONTAINERPORT[/proto]`
|
||||||
|
- `--dns IP`
|
||||||
|
- `--device SPEC`
|
||||||
|
- `--label KEY=VALUE`
|
||||||
|
- `--cap-add CAPABILITY`
|
||||||
|
|
||||||
|
## Signals and shutdown
|
||||||
|
|
||||||
|
While the container is running:
|
||||||
|
|
||||||
|
- the first `Ctrl+C` requests a graceful stop with a 5 second timeout;
|
||||||
|
- the second `Ctrl+C` sends a kill signal to the container;
|
||||||
|
- a further interrupt exits the tool immediately.
|
||||||
|
|
||||||
|
## Notes and limitations
|
||||||
|
|
||||||
|
- The container image is fixed to `qmcgaw/gluetun`.
|
||||||
|
- The container name is fixed to `gluetun`.
|
||||||
|
- Credentials are decrypted in memory only during execution.
|
||||||
|
- If the requested provider or VPN type is not present in the encrypted credentials file, the command fails with an explicit error.
|
||||||
|
- The encrypted credential store file is named `credentials`.
|
||||||
|
- This tool is intended for local development convenience, not as a general replacement for `docker run`.
|
||||||
@@ -0,0 +1,156 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
|
||||||
|
"github.com/qdm12/gluetun/devrun/internal"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
const minArgs = 2
|
||||||
|
if len(os.Args) < minArgs {
|
||||||
|
printUsage()
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch os.Args[1] {
|
||||||
|
case "add-cred":
|
||||||
|
const addCredMinArgs = 4
|
||||||
|
if len(os.Args) < addCredMinArgs {
|
||||||
|
fmt.Fprintf(os.Stderr,
|
||||||
|
`Usage: %s add-cred <provider> <vpn-type>
|
||||||
|
Example: %s add-cred protonvpn wireguard`, os.Args[0], os.Args[0])
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
provider := os.Args[2]
|
||||||
|
vpnType := os.Args[3]
|
||||||
|
err := runWithSignals(func(ctx context.Context, _ <-chan struct{}) error {
|
||||||
|
return internal.AddCredential(ctx, provider, vpnType)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, "add-cred failed:", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
case "delete-cred":
|
||||||
|
const deleteCredMinArgs = 4
|
||||||
|
if len(os.Args) < deleteCredMinArgs {
|
||||||
|
fmt.Fprintf(os.Stderr,
|
||||||
|
`Usage: %s delete-cred <provider> <vpn-type>
|
||||||
|
Example: %s delete-cred protonvpn openvpn`, os.Args[0], os.Args[0])
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
provider := os.Args[2]
|
||||||
|
vpnType := os.Args[3]
|
||||||
|
err := runWithSignals(func(ctx context.Context, _ <-chan struct{}) error {
|
||||||
|
return internal.DeleteCredential(ctx, provider, vpnType)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, "delete-cred failed:", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
case "dump-cred":
|
||||||
|
const dumpCredMinArgs = 4
|
||||||
|
if len(os.Args) < dumpCredMinArgs {
|
||||||
|
fmt.Fprintf(os.Stderr,
|
||||||
|
`Usage: %s dump-cred <provider> <vpn-type>
|
||||||
|
Example: %s dump-cred protonvpn wireguard`, os.Args[0], os.Args[0])
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
provider := os.Args[2]
|
||||||
|
vpnType := os.Args[3]
|
||||||
|
err := runWithSignals(func(ctx context.Context, _ <-chan struct{}) error {
|
||||||
|
return internal.DumpCredential(ctx, provider, vpnType)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, "dump-cred failed:", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
case "run":
|
||||||
|
const runMinArgs = 4
|
||||||
|
if len(os.Args) < runMinArgs {
|
||||||
|
fmt.Fprintf(os.Stderr,
|
||||||
|
`Usage: %s run <provider> <vpn-type> [extra docker flags...]
|
||||||
|
Example: %s run mullvad wireguard -e SERVER_COUNTRIES=USA`, os.Args[0], os.Args[0])
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
provider := os.Args[2]
|
||||||
|
vpnType := os.Args[3]
|
||||||
|
extraArgs := os.Args[4:]
|
||||||
|
err := runWithSignals(func(ctx context.Context, forceKill <-chan struct{}) error {
|
||||||
|
return internal.Run(ctx, provider, vpnType, extraArgs, forceKill)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, "run failed:", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
fmt.Fprintln(os.Stderr, "unknown command:", os.Args[1])
|
||||||
|
printUsage()
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func printUsage() {
|
||||||
|
fmt.Fprintf(os.Stderr, `Usage: %s <command> [args...]
|
||||||
|
|
||||||
|
Commands:
|
||||||
|
add-cred <provider> <vpn-type>
|
||||||
|
Add or replace credentials in the encrypted credentials store.
|
||||||
|
delete-cred <provider> <vpn-type>
|
||||||
|
Delete credentials from the encrypted credentials store.
|
||||||
|
dump-cred <provider> <vpn-type>
|
||||||
|
Print credentials for a provider and VPN type pair.
|
||||||
|
run <provider> <vpn-type> [flags...]
|
||||||
|
Decrypt credentials and run a Gluetun container.
|
||||||
|
Extra flags (e.g. -e PORT_FORWARDING=on) are passed to docker run.`,
|
||||||
|
os.Args[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
func runWithSignals(runFn func(ctx context.Context, forceKill <-chan struct{}) error) error {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
const signalBufferSize = 3
|
||||||
|
sigCh := make(chan os.Signal, signalBufferSize)
|
||||||
|
signal.Notify(sigCh, os.Interrupt)
|
||||||
|
defer signal.Stop(sigCh)
|
||||||
|
|
||||||
|
forceKill := make(chan struct{})
|
||||||
|
stopSignalLoop := make(chan struct{})
|
||||||
|
signalLoopDone := make(chan struct{})
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer close(signalLoopDone)
|
||||||
|
|
||||||
|
const secondInterrupt = 2
|
||||||
|
interruptCount := uint(0)
|
||||||
|
forceKillSent := false
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-stopSignalLoop:
|
||||||
|
return
|
||||||
|
case <-sigCh:
|
||||||
|
interruptCount++
|
||||||
|
switch interruptCount {
|
||||||
|
case 1:
|
||||||
|
cancel()
|
||||||
|
case secondInterrupt:
|
||||||
|
if !forceKillSent {
|
||||||
|
close(forceKill)
|
||||||
|
forceKillSent = true
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := runFn(ctx, forceKill)
|
||||||
|
close(stopSignalLoop)
|
||||||
|
<-signalLoopDone
|
||||||
|
return err
|
||||||
|
}
|
||||||
@@ -0,0 +1,40 @@
|
|||||||
|
module github.com/qdm12/gluetun/devrun
|
||||||
|
|
||||||
|
go 1.25.0
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/docker/docker v28.5.2+incompatible
|
||||||
|
github.com/docker/go-connections v0.7.0
|
||||||
|
github.com/opencontainers/image-spec v1.1.1
|
||||||
|
golang.org/x/crypto v0.50.0
|
||||||
|
golang.org/x/term v0.42.0
|
||||||
|
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
|
||||||
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/Microsoft/go-winio v0.6.2 // indirect
|
||||||
|
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||||
|
github.com/containerd/errdefs v1.0.0 // indirect
|
||||||
|
github.com/containerd/errdefs/pkg v0.3.0 // indirect
|
||||||
|
github.com/containerd/log v0.1.0 // indirect
|
||||||
|
github.com/distribution/reference v0.6.0 // indirect
|
||||||
|
github.com/docker/go-units v0.5.0 // indirect
|
||||||
|
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||||
|
github.com/go-logr/logr v1.4.3 // indirect
|
||||||
|
github.com/go-logr/stdr v1.2.2 // indirect
|
||||||
|
github.com/moby/docker-image-spec v1.3.1 // indirect
|
||||||
|
github.com/moby/sys/atomicwriter v0.1.0 // indirect
|
||||||
|
github.com/moby/term v0.5.2 // indirect
|
||||||
|
github.com/morikuni/aec v1.1.0 // indirect
|
||||||
|
github.com/opencontainers/go-digest v1.0.0 // indirect
|
||||||
|
github.com/pkg/errors v0.9.1 // indirect
|
||||||
|
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
|
||||||
|
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.68.0 // indirect
|
||||||
|
go.opentelemetry.io/otel v1.43.0 // indirect
|
||||||
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.43.0 // indirect
|
||||||
|
go.opentelemetry.io/otel/metric v1.43.0 // indirect
|
||||||
|
go.opentelemetry.io/otel/trace v1.43.0 // indirect
|
||||||
|
golang.org/x/sys v0.43.0 // indirect
|
||||||
|
golang.org/x/time v0.15.0 // indirect
|
||||||
|
gotest.tools/v3 v3.5.2 // indirect
|
||||||
|
)
|
||||||
+105
@@ -0,0 +1,105 @@
|
|||||||
|
github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c h1:udKWzYgxTojEKWjV8V+WSxDXJ4NFATAsZjh8iIbsQIg=
|
||||||
|
github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
|
||||||
|
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
|
||||||
|
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
||||||
|
github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM=
|
||||||
|
github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
|
||||||
|
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||||
|
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||||
|
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
|
||||||
|
github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M=
|
||||||
|
github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE=
|
||||||
|
github.com/containerd/errdefs/pkg v0.3.0/go.mod h1:NJw6s9HwNuRhnjJhM7pylWwMyAkmCQvQ4GpJHEqRLVk=
|
||||||
|
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
|
||||||
|
github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo=
|
||||||
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
||||||
|
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
|
||||||
|
github.com/docker/docker v28.5.2+incompatible h1:DBX0Y0zAjZbSrm1uzOkdr1onVghKaftjlSWt4AFexzM=
|
||||||
|
github.com/docker/docker v28.5.2+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
|
||||||
|
github.com/docker/go-connections v0.7.0 h1:6SsRfJddP22WMrCkj19x9WKjEDTB+ahsdiGYf0mN39c=
|
||||||
|
github.com/docker/go-connections v0.7.0/go.mod h1:no1qkHdjq7kLMGUXYAduOhYPSJxxvgWBh7ogVvptn3Q=
|
||||||
|
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
|
||||||
|
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
|
||||||
|
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
|
||||||
|
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
||||||
|
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||||
|
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
|
||||||
|
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||||
|
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||||
|
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||||
|
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||||
|
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||||
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
|
github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 h1:HWRh5R2+9EifMyIHV7ZV+MIZqgz+PMpZ14Jynv3O2Zs=
|
||||||
|
github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0/go.mod h1:JfhWUomR1baixubs02l85lZYYOm7LV6om4ceouMv45c=
|
||||||
|
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
|
||||||
|
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
|
||||||
|
github.com/moby/sys/atomicwriter v0.1.0 h1:kw5D/EqkBwsBFi0ss9v1VG3wIkVhzGvLklJ+w3A14Sw=
|
||||||
|
github.com/moby/sys/atomicwriter v0.1.0/go.mod h1:Ul8oqv2ZMNHOceF643P6FKPXeCmYtlQMvpizfsSoaWs=
|
||||||
|
github.com/moby/sys/sequential v0.6.0 h1:qrx7XFUd/5DxtqcoH1h438hF5TmOvzC/lspjy7zgvCU=
|
||||||
|
github.com/moby/sys/sequential v0.6.0/go.mod h1:uyv8EUTrca5PnDsdMGXhZe6CCe8U/UiTWd+lL+7b/Ko=
|
||||||
|
github.com/moby/term v0.5.2 h1:6qk3FJAFDs6i/q3W/pQ97SX192qKfZgGjCQqfCJkgzQ=
|
||||||
|
github.com/moby/term v0.5.2/go.mod h1:d3djjFCrjnB+fl8NJux+EJzu0msscUP+f8it8hPkFLc=
|
||||||
|
github.com/morikuni/aec v1.1.0 h1:vBBl0pUnvi/Je71dsRrhMBtreIqNMYErSAbEeb8jrXQ=
|
||||||
|
github.com/morikuni/aec v1.1.0/go.mod h1:xDRgiq/iw5l+zkao76YTKzKttOp2cwPEne25HDkJnBw=
|
||||||
|
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
||||||
|
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
||||||
|
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
||||||
|
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
|
||||||
|
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||||
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||||
|
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||||
|
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||||
|
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||||
|
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
|
||||||
|
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
|
||||||
|
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.68.0 h1:CqXxU8VOmDefoh0+ztfGaymYbhdB/tT3zs79QaZTNGY=
|
||||||
|
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.68.0/go.mod h1:BuhAPThV8PBHBvg8ZzZ/Ok3idOdhWIodywz2xEcRbJo=
|
||||||
|
go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I=
|
||||||
|
go.opentelemetry.io/otel v1.43.0/go.mod h1:JuG+u74mvjvcm8vj8pI5XiHy1zDeoCS2LB1spIq7Ay0=
|
||||||
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.43.0 h1:88Y4s2C8oTui1LGM6bTWkw0ICGcOLCAI5l6zsD1j20k=
|
||||||
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.43.0/go.mod h1:Vl1/iaggsuRlrHf/hfPJPvVag77kKyvrLeD10kpMl+A=
|
||||||
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.43.0 h1:3iZJKlCZufyRzPzlQhUIWVmfltrXuGyfjREgGP3UUjc=
|
||||||
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.43.0/go.mod h1:/G+nUPfhq2e+qiXMGxMwumDrP5jtzU+mWN7/sjT2rak=
|
||||||
|
go.opentelemetry.io/otel/metric v1.43.0 h1:d7638QeInOnuwOONPp4JAOGfbCEpYb+K6DVWvdxGzgM=
|
||||||
|
go.opentelemetry.io/otel/metric v1.43.0/go.mod h1:RDnPtIxvqlgO8GRW18W6Z/4P462ldprJtfxHxyKd2PY=
|
||||||
|
go.opentelemetry.io/otel/sdk v1.43.0 h1:pi5mE86i5rTeLXqoF/hhiBtUNcrAGHLKQdhg4h4V9Dg=
|
||||||
|
go.opentelemetry.io/otel/sdk v1.43.0/go.mod h1:P+IkVU3iWukmiit/Yf9AWvpyRDlUeBaRg6Y+C58QHzg=
|
||||||
|
go.opentelemetry.io/otel/sdk/metric v1.43.0 h1:S88dyqXjJkuBNLeMcVPRFXpRw2fuwdvfCGLEo89fDkw=
|
||||||
|
go.opentelemetry.io/otel/sdk/metric v1.43.0/go.mod h1:C/RJtwSEJ5hzTiUz5pXF1kILHStzb9zFlIEe85bhj6A=
|
||||||
|
go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A=
|
||||||
|
go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0=
|
||||||
|
go.opentelemetry.io/proto/otlp v1.10.0 h1:IQRWgT5srOCYfiWnpqUYz9CVmbO8bFmKcwYxpuCSL2g=
|
||||||
|
go.opentelemetry.io/proto/otlp v1.10.0/go.mod h1:/CV4QoCR/S9yaPj8utp3lvQPoqMtxXdzn7ozvvozVqk=
|
||||||
|
golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI=
|
||||||
|
golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q=
|
||||||
|
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
|
||||||
|
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
|
||||||
|
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
|
||||||
|
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||||
|
golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY=
|
||||||
|
golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY=
|
||||||
|
golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg=
|
||||||
|
golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=
|
||||||
|
golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U=
|
||||||
|
golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno=
|
||||||
|
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdIg1ozBNLgPy4SLT84nfcBjr6rhGtXYtrkWLU=
|
||||||
|
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ=
|
||||||
|
google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 h1:VPWxll4HlMw1Vs/qXtN7BvhZqsS9cdAittCNvVENElA=
|
||||||
|
google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:7QBABkRtR8z+TEnmXTqIqwJLlzrZKVfAUm7tY3yGv0M=
|
||||||
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 h1:m8qni9SQFH0tJc1X0vmnpw/0t+AImlSvp30sEupozUg=
|
||||||
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
|
||||||
|
google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM=
|
||||||
|
google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4=
|
||||||
|
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
||||||
|
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
|
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
|
||||||
|
gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA=
|
||||||
@@ -0,0 +1,240 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/gob"
|
||||||
|
"fmt"
|
||||||
|
"maps"
|
||||||
|
"net/netip"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
)
|
||||||
|
|
||||||
|
const credentialsFilename = "credentials"
|
||||||
|
|
||||||
|
const (
|
||||||
|
vpnTypeOpenVPN = "openvpn"
|
||||||
|
vpnTypeWireGuard = "wireguard"
|
||||||
|
)
|
||||||
|
|
||||||
|
type providerCredentials struct {
|
||||||
|
OpenVPN *openvpnCredentials
|
||||||
|
WireGuard *wireguardCredentials
|
||||||
|
}
|
||||||
|
|
||||||
|
type openvpnCredentials struct {
|
||||||
|
Username string
|
||||||
|
Password string
|
||||||
|
}
|
||||||
|
|
||||||
|
type wireguardCredentials struct {
|
||||||
|
PrivateKey string
|
||||||
|
Address string
|
||||||
|
PresharedKey string
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadCredentials(data []byte) (map[string]providerCredentials, error) {
|
||||||
|
credentials := make(map[string]providerCredentials)
|
||||||
|
err := gob.NewDecoder(bytes.NewReader(data)).Decode(&credentials)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("decoding credentials: %w", err)
|
||||||
|
}
|
||||||
|
return credentials, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func marshalCredentials(credentials map[string]providerCredentials) ([]byte, error) {
|
||||||
|
buffer := bytes.NewBuffer(nil)
|
||||||
|
err := gob.NewEncoder(buffer).Encode(credentials)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("encoding credentials: %w", err)
|
||||||
|
}
|
||||||
|
return buffer.Bytes(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateCredentials(providerNameToCredentials map[string]providerCredentials) error {
|
||||||
|
for provider, credentials := range providerNameToCredentials {
|
||||||
|
if credentials.OpenVPN == nil && credentials.WireGuard == nil {
|
||||||
|
return fmt.Errorf("provider %q has no openvpn or wireguard credentials", provider)
|
||||||
|
}
|
||||||
|
if credentials.OpenVPN != nil {
|
||||||
|
err := validateOpenvpnCredentials(provider, credentials.OpenVPN)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if credentials.WireGuard != nil {
|
||||||
|
err := validateWireguardCredentials(provider, credentials.WireGuard)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateOpenvpnCredentials(provider string, creds *openvpnCredentials) error {
|
||||||
|
switch {
|
||||||
|
case creds.Username == "":
|
||||||
|
return fmt.Errorf("provider %q openvpn credentials are missing the username", provider)
|
||||||
|
case creds.Password == "":
|
||||||
|
return fmt.Errorf("provider %q openvpn credentials are missing the password", provider)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateWireguardCredentials(provider string, creds *wireguardCredentials) error {
|
||||||
|
if creds.PrivateKey == "" {
|
||||||
|
return fmt.Errorf("provider %q wireguard credentials are missing the private key", provider)
|
||||||
|
} else if _, err := wgtypes.ParseKey(creds.PrivateKey); err != nil {
|
||||||
|
return fmt.Errorf("provider %q wireguard credentials have an invalid private key: %w", provider, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if creds.Address != "" {
|
||||||
|
_, err := netip.ParsePrefix(creds.Address)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("provider %q wireguard credentials have an invalid address %q: %w", provider, creds.Address, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if creds.PresharedKey != "" {
|
||||||
|
if _, err := wgtypes.ParseKey(creds.PresharedKey); err != nil {
|
||||||
|
return fmt.Errorf("provider %q wireguard credentials have an invalid preshared key: %w", provider, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func lookupCredentials(credentials map[string]providerCredentials, provider, vpnType string) ([]string, error) {
|
||||||
|
providerCreds, exists := credentials[provider]
|
||||||
|
if !exists {
|
||||||
|
existing := slices.Collect(maps.Keys(credentials))
|
||||||
|
return nil, fmt.Errorf("no credentials found for provider %q, available providers are: %s",
|
||||||
|
provider, strings.Join(existing, ", "))
|
||||||
|
}
|
||||||
|
|
||||||
|
switch vpnType {
|
||||||
|
case vpnTypeWireGuard:
|
||||||
|
if providerCreds.WireGuard == nil {
|
||||||
|
return nil, fmt.Errorf("no wireguard credentials found for provider %q", provider)
|
||||||
|
}
|
||||||
|
return buildWireGuardEnv(providerCreds.WireGuard), nil
|
||||||
|
case vpnTypeOpenVPN:
|
||||||
|
if providerCreds.OpenVPN == nil {
|
||||||
|
return nil, fmt.Errorf("no openvpn credentials found for provider %q", provider)
|
||||||
|
}
|
||||||
|
return buildOpenvpnEnv(providerCreds.OpenVPN), nil
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unknown vpn type %q, must be wireguard or openvpn", vpnType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildWireGuardEnv(creds *wireguardCredentials) []string {
|
||||||
|
envVars := []string{
|
||||||
|
"WIREGUARD_PRIVATE_KEY=" + creds.PrivateKey,
|
||||||
|
}
|
||||||
|
if creds.Address != "" {
|
||||||
|
envVars = append(envVars, "WIREGUARD_ADDRESSES="+creds.Address)
|
||||||
|
}
|
||||||
|
if creds.PresharedKey != "" {
|
||||||
|
envVars = append(envVars, "WIREGUARD_PRESHARED_KEY="+creds.PresharedKey)
|
||||||
|
}
|
||||||
|
return envVars
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildOpenvpnEnv(creds *openvpnCredentials) []string {
|
||||||
|
return []string{
|
||||||
|
"OPENVPN_USER=" + creds.Username,
|
||||||
|
"OPENVPN_PASSWORD=" + creds.Password,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func addCredential(credentials map[string]providerCredentials, provider, vpnType string,
|
||||||
|
openvpnCredentials *openvpnCredentials, wireguardCredentials *wireguardCredentials,
|
||||||
|
) error {
|
||||||
|
providerCredentials := credentials[provider]
|
||||||
|
|
||||||
|
switch vpnType {
|
||||||
|
case vpnTypeOpenVPN:
|
||||||
|
providerCredentials.OpenVPN = openvpnCredentials
|
||||||
|
case vpnTypeWireGuard:
|
||||||
|
providerCredentials.WireGuard = wireguardCredentials
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unknown vpn type %q, must be wireguard or openvpn", vpnType)
|
||||||
|
}
|
||||||
|
|
||||||
|
credentials[provider] = providerCredentials
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func deleteCredential(credentials map[string]providerCredentials, provider, vpnType string) error {
|
||||||
|
providerCredentials, exists := credentials[provider]
|
||||||
|
if !exists {
|
||||||
|
return fmt.Errorf("provider %q does not exist", provider)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch vpnType {
|
||||||
|
case vpnTypeOpenVPN:
|
||||||
|
if providerCredentials.OpenVPN == nil {
|
||||||
|
return fmt.Errorf("provider %q has no openvpn credentials", provider)
|
||||||
|
}
|
||||||
|
providerCredentials.OpenVPN = nil
|
||||||
|
case vpnTypeWireGuard:
|
||||||
|
if providerCredentials.WireGuard == nil {
|
||||||
|
return fmt.Errorf("provider %q has no wireguard credentials", provider)
|
||||||
|
}
|
||||||
|
providerCredentials.WireGuard = nil
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unknown vpn type %q, must be wireguard or openvpn", vpnType)
|
||||||
|
}
|
||||||
|
|
||||||
|
if providerCredentials.OpenVPN == nil && providerCredentials.WireGuard == nil {
|
||||||
|
delete(credentials, provider)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
credentials[provider] = providerCredentials
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatCredentialForDump(provider, vpnType string,
|
||||||
|
providerCredentials providerCredentials,
|
||||||
|
) (output string, err error) {
|
||||||
|
var builder strings.Builder
|
||||||
|
|
||||||
|
builder.WriteString("provider: ")
|
||||||
|
builder.WriteString(provider)
|
||||||
|
builder.WriteString("\n")
|
||||||
|
builder.WriteString("vpn_type: ")
|
||||||
|
builder.WriteString(vpnType)
|
||||||
|
builder.WriteString("\n")
|
||||||
|
|
||||||
|
switch vpnType {
|
||||||
|
case vpnTypeOpenVPN:
|
||||||
|
if providerCredentials.OpenVPN == nil {
|
||||||
|
return "", fmt.Errorf("no openvpn credentials found for provider %q", provider)
|
||||||
|
}
|
||||||
|
builder.WriteString("username: ")
|
||||||
|
builder.WriteString(providerCredentials.OpenVPN.Username)
|
||||||
|
builder.WriteString("\n")
|
||||||
|
builder.WriteString("password: ")
|
||||||
|
builder.WriteString(providerCredentials.OpenVPN.Password)
|
||||||
|
case vpnTypeWireGuard:
|
||||||
|
if providerCredentials.WireGuard == nil {
|
||||||
|
return "", fmt.Errorf("no wireguard credentials found for provider %q", provider)
|
||||||
|
}
|
||||||
|
builder.WriteString("private_key: ")
|
||||||
|
builder.WriteString(providerCredentials.WireGuard.PrivateKey)
|
||||||
|
builder.WriteString("\n")
|
||||||
|
builder.WriteString("address: ")
|
||||||
|
builder.WriteString(providerCredentials.WireGuard.Address)
|
||||||
|
builder.WriteString("\n")
|
||||||
|
builder.WriteString("preshared_key: ")
|
||||||
|
builder.WriteString(providerCredentials.WireGuard.PresharedKey)
|
||||||
|
default:
|
||||||
|
return "", fmt.Errorf("unknown vpn type %q, must be wireguard or openvpn", vpnType)
|
||||||
|
}
|
||||||
|
|
||||||
|
return builder.String(), nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,343 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_addCredential(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
wireguardPrivateKey, err := wgtypes.GeneratePrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := map[string]struct {
|
||||||
|
initialCredentials map[string]providerCredentials
|
||||||
|
provider string
|
||||||
|
vpnType string
|
||||||
|
openvpnCredentials *openvpnCredentials
|
||||||
|
wireguardCreds *wireguardCredentials
|
||||||
|
expectedLength int
|
||||||
|
expectedOpenVPN bool
|
||||||
|
expectedWireGuard bool
|
||||||
|
}{
|
||||||
|
"adds_openvpn_credentials": {
|
||||||
|
initialCredentials: map[string]providerCredentials{},
|
||||||
|
provider: "protonvpn",
|
||||||
|
vpnType: "openvpn",
|
||||||
|
openvpnCredentials: &openvpnCredentials{Username: "user", Password: "pass"},
|
||||||
|
expectedLength: 1,
|
||||||
|
expectedOpenVPN: true,
|
||||||
|
},
|
||||||
|
"adds_wireguard_credentials": {
|
||||||
|
initialCredentials: map[string]providerCredentials{},
|
||||||
|
provider: "mullvad",
|
||||||
|
vpnType: "wireguard",
|
||||||
|
wireguardCreds: &wireguardCredentials{
|
||||||
|
PrivateKey: wireguardPrivateKey.String(),
|
||||||
|
Address: "10.0.0.2/32",
|
||||||
|
},
|
||||||
|
expectedLength: 1,
|
||||||
|
expectedWireGuard: true,
|
||||||
|
},
|
||||||
|
"preserves_other_protocol": {
|
||||||
|
initialCredentials: map[string]providerCredentials{
|
||||||
|
"protonvpn": {
|
||||||
|
WireGuard: &wireguardCredentials{PrivateKey: wireguardPrivateKey.String()},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
provider: "protonvpn",
|
||||||
|
vpnType: "openvpn",
|
||||||
|
openvpnCredentials: &openvpnCredentials{Username: "user", Password: "pass"},
|
||||||
|
expectedLength: 1,
|
||||||
|
expectedOpenVPN: true,
|
||||||
|
expectedWireGuard: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, testCase := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
credentials := cloneCredentials(testCase.initialCredentials)
|
||||||
|
|
||||||
|
err := addCredential(credentials, testCase.provider, testCase.vpnType,
|
||||||
|
testCase.openvpnCredentials, testCase.wireguardCreds)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
providerCredentials := credentials[testCase.provider]
|
||||||
|
if len(credentials) != testCase.expectedLength {
|
||||||
|
t.Fatalf("expected %d providers, got %d", testCase.expectedLength, len(credentials))
|
||||||
|
}
|
||||||
|
if (providerCredentials.OpenVPN != nil) != testCase.expectedOpenVPN {
|
||||||
|
t.Fatalf("expected openvpn presence %t, got %t", testCase.expectedOpenVPN, providerCredentials.OpenVPN != nil)
|
||||||
|
}
|
||||||
|
if (providerCredentials.WireGuard != nil) != testCase.expectedWireGuard {
|
||||||
|
t.Fatalf("expected wireguard presence %t, got %t", testCase.expectedWireGuard, providerCredentials.WireGuard != nil)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_deleteCredential(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
wireguardPrivateKey, err := wgtypes.GeneratePrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := map[string]struct {
|
||||||
|
initialCredentials map[string]providerCredentials
|
||||||
|
provider string
|
||||||
|
vpnType string
|
||||||
|
expectedLength int
|
||||||
|
expectedOpenVPN bool
|
||||||
|
expectedWireGuard bool
|
||||||
|
}{
|
||||||
|
"deletes_openvpn_only": {
|
||||||
|
initialCredentials: map[string]providerCredentials{
|
||||||
|
"protonvpn": {
|
||||||
|
OpenVPN: &openvpnCredentials{Username: "user", Password: "pass"},
|
||||||
|
WireGuard: &wireguardCredentials{PrivateKey: wireguardPrivateKey.String()},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
provider: "protonvpn",
|
||||||
|
vpnType: "openvpn",
|
||||||
|
expectedLength: 1,
|
||||||
|
expectedWireGuard: true,
|
||||||
|
},
|
||||||
|
"deletes_last_protocol_and_provider": {
|
||||||
|
initialCredentials: map[string]providerCredentials{
|
||||||
|
"protonvpn": {
|
||||||
|
OpenVPN: &openvpnCredentials{Username: "user", Password: "pass"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
provider: "protonvpn",
|
||||||
|
vpnType: "openvpn",
|
||||||
|
expectedLength: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, testCase := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
credentials := cloneCredentials(testCase.initialCredentials)
|
||||||
|
|
||||||
|
err := deleteCredential(credentials, testCase.provider, testCase.vpnType)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(credentials) != testCase.expectedLength {
|
||||||
|
t.Fatalf("expected %d providers, got %d", testCase.expectedLength, len(credentials))
|
||||||
|
}
|
||||||
|
|
||||||
|
providerCredentials, exists := credentials[testCase.provider]
|
||||||
|
if !exists {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if (providerCredentials.OpenVPN != nil) != testCase.expectedOpenVPN {
|
||||||
|
t.Fatalf("expected openvpn presence %t, got %t", testCase.expectedOpenVPN, providerCredentials.OpenVPN != nil)
|
||||||
|
}
|
||||||
|
if (providerCredentials.WireGuard != nil) != testCase.expectedWireGuard {
|
||||||
|
t.Fatalf("expected wireguard presence %t, got %t", testCase.expectedWireGuard, providerCredentials.WireGuard != nil)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_validateCredentials(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
wireguardPrivateKey, err := wgtypes.GeneratePrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := map[string]struct {
|
||||||
|
credentials map[string]providerCredentials
|
||||||
|
wantError bool
|
||||||
|
}{
|
||||||
|
"both_protocols_valid": {
|
||||||
|
credentials: map[string]providerCredentials{
|
||||||
|
"protonvpn": {
|
||||||
|
OpenVPN: &openvpnCredentials{Username: "user", Password: "pass"},
|
||||||
|
WireGuard: &wireguardCredentials{PrivateKey: wireguardPrivateKey.String()},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"invalid_wireguard_when_both_present": {
|
||||||
|
credentials: map[string]providerCredentials{
|
||||||
|
"protonvpn": {
|
||||||
|
OpenVPN: &openvpnCredentials{Username: "user", Password: "pass"},
|
||||||
|
WireGuard: &wireguardCredentials{PrivateKey: "invalid"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, testCase := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
err := validateCredentials(testCase.credentials)
|
||||||
|
if testCase.wantError && err == nil {
|
||||||
|
t.Fatal("expected an error but got nil")
|
||||||
|
}
|
||||||
|
if !testCase.wantError && err != nil {
|
||||||
|
t.Fatalf("expected no error, got %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_marshalLoadCredentials(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
wireguardPrivateKey, err := wgtypes.GeneratePrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
credentials := map[string]providerCredentials{
|
||||||
|
"mullvad": {
|
||||||
|
WireGuard: &wireguardCredentials{
|
||||||
|
PrivateKey: wireguardPrivateKey.String(),
|
||||||
|
Address: "10.0.0.2/32",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"protonvpn": {
|
||||||
|
OpenVPN: &openvpnCredentials{
|
||||||
|
Username: "user",
|
||||||
|
Password: "pass",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
encoded, err := marshalCredentials(credentials)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected marshal error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
decoded, err := loadCredentials(encoded)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected load error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(decoded) != len(credentials) {
|
||||||
|
t.Fatalf("expected %d providers, got %d", len(credentials), len(decoded))
|
||||||
|
}
|
||||||
|
|
||||||
|
if decoded["mullvad"].WireGuard == nil {
|
||||||
|
t.Fatal("expected mullvad wireguard credentials to be present")
|
||||||
|
}
|
||||||
|
if decoded["protonvpn"].OpenVPN == nil {
|
||||||
|
t.Fatal("expected protonvpn openvpn credentials to be present")
|
||||||
|
}
|
||||||
|
if decoded["protonvpn"].OpenVPN.Password != "pass" {
|
||||||
|
t.Fatalf("expected protonvpn password %q, got %q", "pass", decoded["protonvpn"].OpenVPN.Password)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_formatCredentialForDump(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
testCases := map[string]struct {
|
||||||
|
provider string
|
||||||
|
vpnType string
|
||||||
|
providerCredentials providerCredentials
|
||||||
|
expectedOutput string
|
||||||
|
wantError bool
|
||||||
|
}{
|
||||||
|
"openvpn": {
|
||||||
|
provider: "protonvpn",
|
||||||
|
vpnType: vpnTypeOpenVPN,
|
||||||
|
providerCredentials: providerCredentials{
|
||||||
|
OpenVPN: &openvpnCredentials{Username: "user", Password: "pass"},
|
||||||
|
},
|
||||||
|
expectedOutput: "provider: protonvpn\n" +
|
||||||
|
"vpn_type: openvpn\n" +
|
||||||
|
"username: user\n" +
|
||||||
|
"password: pass",
|
||||||
|
},
|
||||||
|
"wireguard": {
|
||||||
|
provider: "mullvad",
|
||||||
|
vpnType: vpnTypeWireGuard,
|
||||||
|
providerCredentials: providerCredentials{
|
||||||
|
WireGuard: &wireguardCredentials{
|
||||||
|
PrivateKey: "private",
|
||||||
|
Address: "10.0.0.2/32",
|
||||||
|
PresharedKey: "preshared",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedOutput: "provider: mullvad\n" +
|
||||||
|
"vpn_type: wireguard\n" +
|
||||||
|
"private_key: private\n" +
|
||||||
|
"address: 10.0.0.2/32\n" +
|
||||||
|
"preshared_key: preshared",
|
||||||
|
},
|
||||||
|
"missing_protocol": {
|
||||||
|
provider: "protonvpn",
|
||||||
|
vpnType: vpnTypeOpenVPN,
|
||||||
|
wantError: true,
|
||||||
|
},
|
||||||
|
"unknown_protocol": {
|
||||||
|
provider: "protonvpn",
|
||||||
|
vpnType: "other",
|
||||||
|
wantError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, testCase := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
output, err := formatCredentialForDump(
|
||||||
|
testCase.provider,
|
||||||
|
testCase.vpnType,
|
||||||
|
testCase.providerCredentials,
|
||||||
|
)
|
||||||
|
|
||||||
|
if testCase.wantError {
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected an error but got nil")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected no error, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if output != testCase.expectedOutput {
|
||||||
|
t.Fatalf("expected output %q, got %q", testCase.expectedOutput, output)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneCredentials(credentials map[string]providerCredentials) map[string]providerCredentials {
|
||||||
|
clone := make(map[string]providerCredentials, len(credentials))
|
||||||
|
for provider, providerCredentials := range credentials {
|
||||||
|
copied := providerCredentials
|
||||||
|
if providerCredentials.OpenVPN != nil {
|
||||||
|
openvpnCredentials := *providerCredentials.OpenVPN
|
||||||
|
copied.OpenVPN = &openvpnCredentials
|
||||||
|
}
|
||||||
|
if providerCredentials.WireGuard != nil {
|
||||||
|
wireguardCredentials := *providerCredentials.WireGuard
|
||||||
|
copied.WireGuard = &wireguardCredentials
|
||||||
|
}
|
||||||
|
clone[provider] = copied
|
||||||
|
}
|
||||||
|
return clone
|
||||||
|
}
|
||||||
@@ -0,0 +1,521 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"context"
|
||||||
|
"crypto/aes"
|
||||||
|
"crypto/cipher"
|
||||||
|
"crypto/rand"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"maps"
|
||||||
|
"os"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/scrypt"
|
||||||
|
"golang.org/x/term"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Encryption format: [16-byte salt][12-byte nonce][AES-256-GCM ciphertext+tag]
|
||||||
|
// Key derivation: scrypt(password, salt, N=32768, r=8, p=1, keyLen=32)
|
||||||
|
|
||||||
|
const (
|
||||||
|
saltSize = 16
|
||||||
|
nonceSize = 12
|
||||||
|
keySize = 32
|
||||||
|
scryptN = 32768
|
||||||
|
scryptR = 8
|
||||||
|
scryptP = 1
|
||||||
|
)
|
||||||
|
|
||||||
|
// AddCredential prompts for credential values and stores them in the encrypted credentials file.
|
||||||
|
func AddCredential(ctx context.Context, provider, vpnType string) error {
|
||||||
|
credentials, password, err := loadCredentialsForMutation(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = promptAndAddCredential(ctx, credentials, provider, vpnType)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = validateCredentials(credentials)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("validating credentials: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = writeEncryptedCredentials(credentials, password)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf(
|
||||||
|
"Credentials for provider %q and vpn type %q saved to %s\n",
|
||||||
|
provider, vpnType, credentialsFilename,
|
||||||
|
)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteCredential removes credentials for a provider and VPN type
|
||||||
|
// from the encrypted credentials file.
|
||||||
|
func DeleteCredential(ctx context.Context, provider, vpnType string) error {
|
||||||
|
credentials, password, err := loadExistingCredentialsForMutation(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = deleteCredential(credentials, provider, vpnType)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = writeEncryptedCredentials(credentials, password)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf(
|
||||||
|
"Credentials for provider %q and vpn type %q removed from %s\n",
|
||||||
|
provider, vpnType, credentialsFilename,
|
||||||
|
)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DumpCredential decrypts the credential store and prints one provider/vpn-type entry.
|
||||||
|
func DumpCredential(ctx context.Context, provider, vpnType string) error {
|
||||||
|
credentials, err := decryptCredentials(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
providerCredentials, exists := credentials[provider]
|
||||||
|
if !exists {
|
||||||
|
existingProviders := slices.Collect(maps.Keys(credentials))
|
||||||
|
return fmt.Errorf("provider %q does not exist, available providers are: %s",
|
||||||
|
provider, strings.Join(existingProviders, ", "))
|
||||||
|
}
|
||||||
|
|
||||||
|
output, err := formatCredentialForDump(provider, vpnType, providerCredentials)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println(output)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// decryptCredentials reads the encrypted credentials file,
|
||||||
|
// prompts for a password, and returns the decrypted credentials.
|
||||||
|
func decryptCredentials(ctx context.Context) (map[string]providerCredentials, error) {
|
||||||
|
password, err := readSecret(ctx, "Enter credentials password: ", false)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("reading password: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
plaintext, err := decryptCredentialsFile(password)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
credentials, err := loadCredentials(plaintext)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("loading credentials: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return credentials, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadCredentialsForMutation(ctx context.Context) (
|
||||||
|
credentials map[string]providerCredentials,
|
||||||
|
password []byte,
|
||||||
|
err error,
|
||||||
|
) {
|
||||||
|
_, err = os.Stat(credentialsFilename)
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
password, err = readPasswordConfirmed(ctx,
|
||||||
|
"Enter new credentials password: ",
|
||||||
|
"Confirm new credentials password: ",
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("reading password: %w", err)
|
||||||
|
}
|
||||||
|
return make(map[string]providerCredentials), password, nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("stating %s: %w", credentialsFilename, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
password, err = readSecret(ctx, "Enter credentials password: ", false)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("reading password: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
plaintext, err := decryptCredentialsFile(password)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
credentials, err = loadCredentials(plaintext)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("loading credentials: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return credentials, password, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadExistingCredentialsForMutation(ctx context.Context) (
|
||||||
|
credentials map[string]providerCredentials,
|
||||||
|
password []byte,
|
||||||
|
err error,
|
||||||
|
) {
|
||||||
|
_, err = os.Stat(credentialsFilename)
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return nil, nil, fmt.Errorf("%s does not exist", credentialsFilename)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("stating %s: %w", credentialsFilename, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
password, err = readSecret(ctx, "Enter credentials password: ", false)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("reading password: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
plaintext, err := decryptCredentialsFile(password)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
credentials, err = loadCredentials(plaintext)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("loading credentials: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return credentials, password, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func promptAndAddCredential(
|
||||||
|
ctx context.Context,
|
||||||
|
credentials map[string]providerCredentials,
|
||||||
|
provider, vpnType string,
|
||||||
|
) error {
|
||||||
|
switch vpnType {
|
||||||
|
case vpnTypeOpenVPN:
|
||||||
|
username, err := readLine(ctx, "OpenVPN username: ", false)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("reading username: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
password, err := readSecret(ctx, "OpenVPN password: ", false)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("reading password: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
openvpnCredentials := &openvpnCredentials{
|
||||||
|
Username: username,
|
||||||
|
Password: string(password),
|
||||||
|
}
|
||||||
|
err = validateOpenvpnCredentials(provider, openvpnCredentials)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return addCredential(credentials, provider, vpnType, openvpnCredentials, nil)
|
||||||
|
|
||||||
|
case vpnTypeWireGuard:
|
||||||
|
privateKey, err := readSecret(ctx, "WireGuard private key: ", false)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("reading private key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
address, err := readLine(ctx, "WireGuard address (optional): ", true)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("reading address: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
presharedKey, err := readSecret(
|
||||||
|
ctx,
|
||||||
|
"WireGuard preshared key (optional): ",
|
||||||
|
true,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("reading preshared key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
wireguardCredentials := &wireguardCredentials{
|
||||||
|
PrivateKey: string(privateKey),
|
||||||
|
Address: address,
|
||||||
|
PresharedKey: string(presharedKey),
|
||||||
|
}
|
||||||
|
err = validateWireguardCredentials(provider, wireguardCredentials)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return addCredential(credentials, provider, vpnType, nil, wireguardCredentials)
|
||||||
|
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unknown vpn type %q, must be wireguard or openvpn", vpnType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeEncryptedCredentials(
|
||||||
|
credentials map[string]providerCredentials,
|
||||||
|
password []byte,
|
||||||
|
) error {
|
||||||
|
plaintext, err := marshalCredentials(credentials)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("encoding credentials: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
encrypted, err := encryptData(plaintext, password)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("encrypting credentials: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
const filePerms = 0o600
|
||||||
|
err = os.WriteFile(credentialsFilename, encrypted, filePerms)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("writing %s: %w", credentialsFilename, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func decryptCredentialsFile(password []byte) ([]byte, error) {
|
||||||
|
encryptedData, err := os.ReadFile(credentialsFilename)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("reading %s: %w", credentialsFilename, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
plaintext, err := decryptData(encryptedData, password)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("decrypting credentials: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return plaintext, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func readSecret(ctx context.Context, prompt string, allowEmpty bool) ([]byte, error) {
|
||||||
|
fmt.Print(prompt)
|
||||||
|
|
||||||
|
passwordFD, err := syscall.Dup(syscall.Stdin)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println()
|
||||||
|
return nil, fmt.Errorf("duplicating stdin file descriptor: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var closeFDOnce sync.Once
|
||||||
|
closePasswordFD := func() {
|
||||||
|
closeFDOnce.Do(func() {
|
||||||
|
_ = syscall.Close(passwordFD)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
passwordResult := make(chan struct {
|
||||||
|
password []byte
|
||||||
|
err error
|
||||||
|
})
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
password, err := term.ReadPassword(passwordFD)
|
||||||
|
closePasswordFD()
|
||||||
|
result := struct {
|
||||||
|
password []byte
|
||||||
|
err error
|
||||||
|
}{
|
||||||
|
password: password,
|
||||||
|
err: err,
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case passwordResult <- result:
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
closePasswordFD()
|
||||||
|
fmt.Println()
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case result := <-passwordResult:
|
||||||
|
closePasswordFD()
|
||||||
|
fmt.Println()
|
||||||
|
if result.err != nil {
|
||||||
|
return nil, fmt.Errorf("reading hidden input from terminal: %w", result.err)
|
||||||
|
}
|
||||||
|
if len(result.password) == 0 && !allowEmpty {
|
||||||
|
return nil, fmt.Errorf("value cannot be empty")
|
||||||
|
}
|
||||||
|
return result.password, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func readLine(ctx context.Context, prompt string, allowEmpty bool) (string, error) {
|
||||||
|
fmt.Print(prompt)
|
||||||
|
|
||||||
|
inputFD, err := syscall.Dup(syscall.Stdin)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println()
|
||||||
|
return "", fmt.Errorf("duplicating stdin file descriptor: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var closeFDOnce sync.Once
|
||||||
|
closeInputFD := func() {
|
||||||
|
closeFDOnce.Do(func() {
|
||||||
|
_ = syscall.Close(inputFD)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
inputResult := make(chan struct {
|
||||||
|
value string
|
||||||
|
err error
|
||||||
|
})
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
inputFile := os.NewFile(uintptr(inputFD), "stdin")
|
||||||
|
reader := bufio.NewReader(inputFile)
|
||||||
|
value, err := reader.ReadString('\n')
|
||||||
|
closeInputFD()
|
||||||
|
value = strings.TrimRight(value, "\r\n")
|
||||||
|
if err == io.EOF {
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
result := struct {
|
||||||
|
value string
|
||||||
|
err error
|
||||||
|
}{
|
||||||
|
value: value,
|
||||||
|
err: err,
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case inputResult <- result:
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
closeInputFD()
|
||||||
|
fmt.Println()
|
||||||
|
return "", ctx.Err()
|
||||||
|
case result := <-inputResult:
|
||||||
|
closeInputFD()
|
||||||
|
if result.err != nil {
|
||||||
|
return "", fmt.Errorf("reading line from terminal: %w", result.err)
|
||||||
|
}
|
||||||
|
if result.value == "" && !allowEmpty {
|
||||||
|
return "", fmt.Errorf("value cannot be empty")
|
||||||
|
}
|
||||||
|
return result.value, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func readPasswordConfirmed(
|
||||||
|
ctx context.Context,
|
||||||
|
prompt, confirmationPrompt string,
|
||||||
|
) ([]byte, error) {
|
||||||
|
password, err := readSecret(ctx, prompt, false)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
confirmation, err := readSecret(ctx, confirmationPrompt, false)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(password) != string(confirmation) {
|
||||||
|
return nil, fmt.Errorf("passwords do not match")
|
||||||
|
}
|
||||||
|
|
||||||
|
return password, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func deriveKey(password, salt []byte) ([]byte, error) {
|
||||||
|
key, err := scrypt.Key(password, salt, scryptN, scryptR, scryptP, keySize)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("deriving key with scrypt: %w", err)
|
||||||
|
}
|
||||||
|
return key, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func encryptData(plaintext, password []byte) ([]byte, error) {
|
||||||
|
salt := make([]byte, saltSize)
|
||||||
|
_, err := io.ReadFull(rand.Reader, salt)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("generating salt: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
key, err := deriveKey(password, salt)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
block, err := aes.NewCipher(key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("creating AES cipher: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
gcm, err := cipher.NewGCM(block)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("creating GCM: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
nonce := make([]byte, nonceSize)
|
||||||
|
_, err = io.ReadFull(rand.Reader, nonce)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("generating nonce: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ciphertext := gcm.Seal(nil, nonce, plaintext, nil)
|
||||||
|
|
||||||
|
result := make([]byte, 0, saltSize+nonceSize+len(ciphertext))
|
||||||
|
result = append(result, salt...)
|
||||||
|
result = append(result, nonce...)
|
||||||
|
result = append(result, ciphertext...)
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func decryptData(data, password []byte) ([]byte, error) {
|
||||||
|
const minSize = saltSize + nonceSize + 16 // 16 is the GCM tag size
|
||||||
|
if len(data) < minSize {
|
||||||
|
return nil, fmt.Errorf("encrypted data too short: %d bytes", len(data))
|
||||||
|
}
|
||||||
|
|
||||||
|
salt := data[:saltSize]
|
||||||
|
nonce := data[saltSize : saltSize+nonceSize]
|
||||||
|
ciphertext := data[saltSize+nonceSize:]
|
||||||
|
|
||||||
|
key, err := deriveKey(password, salt)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
block, err := aes.NewCipher(key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("creating AES cipher: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
gcm, err := cipher.NewGCM(block)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("creating GCM: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("decrypting data (wrong password?): %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return plaintext, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,351 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/docker/docker/api/types/container"
|
||||||
|
"github.com/docker/docker/api/types/mount"
|
||||||
|
"github.com/docker/docker/api/types/network"
|
||||||
|
"github.com/docker/docker/client"
|
||||||
|
"github.com/docker/docker/pkg/stdcopy"
|
||||||
|
"github.com/docker/go-connections/nat"
|
||||||
|
v1 "github.com/opencontainers/image-spec/specs-go/v1"
|
||||||
|
"golang.org/x/term"
|
||||||
|
)
|
||||||
|
|
||||||
|
type containerOptions struct {
|
||||||
|
env []string
|
||||||
|
binds []string
|
||||||
|
ports nat.PortMap
|
||||||
|
dns []string
|
||||||
|
devices []container.DeviceMapping
|
||||||
|
labels map[string]string
|
||||||
|
capAdd []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run decrypts credentials, builds the container environment, and runs a Gluetun container.
|
||||||
|
// extraArgs is the list of additional flags (e.g. ["-e", "PORT_FORWARDING=on", "-v", "/host:/container"]).
|
||||||
|
func Run(ctx context.Context, provider, vpnType string, extraArgs []string,
|
||||||
|
forceKill <-chan struct{},
|
||||||
|
) error {
|
||||||
|
credentials, err := decryptCredentials(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("loading credentials: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
credentialEnvVars, err := lookupCredentials(credentials, provider, vpnType)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
extraOpts, err := parseExtraArgs(extraArgs)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parsing extra flags: %w", err)
|
||||||
|
}
|
||||||
|
opts := extraOpts
|
||||||
|
opts.env = append(opts.env,
|
||||||
|
"VPN_SERVICE_PROVIDER="+provider,
|
||||||
|
"VPN_TYPE="+vpnType,
|
||||||
|
"LOG_LEVEL=debug",
|
||||||
|
)
|
||||||
|
opts.env = append(opts.env, credentialEnvVars...)
|
||||||
|
opts.capAdd = append(opts.capAdd, "NET_ADMIN")
|
||||||
|
|
||||||
|
return runContainer(ctx, opts, forceKill)
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseExtraArgs parses extra arguments and maps them to container options.
|
||||||
|
// Supported flags:
|
||||||
|
//
|
||||||
|
// -e, --env KEY=VALUE - environment variable
|
||||||
|
// -v, --volume SPEC - volume mount (e.g., "/host:/container" or "name:/container")
|
||||||
|
// -p, --publish PORT:PORT - port mapping
|
||||||
|
// --dns IP - DNS server
|
||||||
|
// --device SPEC - device access (e.g., "/dev/net/tun")
|
||||||
|
// --label KEY=VALUE - container label
|
||||||
|
// --cap-add CAPABILITY - add Linux capability (e.g., "SYS_PTRACE")
|
||||||
|
func parseExtraArgs(args []string) (opts containerOptions, err error) { //nolint:gocognit,gocyclo
|
||||||
|
opts = containerOptions{
|
||||||
|
ports: make(nat.PortMap),
|
||||||
|
labels: make(map[string]string),
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < len(args); i++ {
|
||||||
|
arg := args[i]
|
||||||
|
switch {
|
||||||
|
case arg == "-e" || arg == "--env":
|
||||||
|
if i+1 >= len(args) {
|
||||||
|
return opts, fmt.Errorf("flag %q requires an argument", arg)
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
opts.env = append(opts.env, args[i])
|
||||||
|
case strings.HasPrefix(arg, "-e="):
|
||||||
|
opts.env = append(opts.env, strings.TrimPrefix(arg, "-e="))
|
||||||
|
case strings.HasPrefix(arg, "--env="):
|
||||||
|
opts.env = append(opts.env, strings.TrimPrefix(arg, "--env="))
|
||||||
|
|
||||||
|
case arg == "-v" || arg == "--volume":
|
||||||
|
if i+1 >= len(args) {
|
||||||
|
return opts, fmt.Errorf("flag %q requires an argument", arg)
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
opts.binds = append(opts.binds, args[i])
|
||||||
|
case strings.HasPrefix(arg, "-v="):
|
||||||
|
opts.binds = append(opts.binds, strings.TrimPrefix(arg, "-v="))
|
||||||
|
case strings.HasPrefix(arg, "--volume="):
|
||||||
|
opts.binds = append(opts.binds, strings.TrimPrefix(arg, "--volume="))
|
||||||
|
|
||||||
|
case arg == "-p" || arg == "--publish":
|
||||||
|
if i+1 >= len(args) {
|
||||||
|
return opts, fmt.Errorf("flag %q requires an argument", arg)
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
if err := parsePortMapping(opts.ports, args[i]); err != nil {
|
||||||
|
return opts, fmt.Errorf("parsing port mapping: %w", err)
|
||||||
|
}
|
||||||
|
case strings.HasPrefix(arg, "-p="):
|
||||||
|
if err := parsePortMapping(opts.ports, strings.TrimPrefix(arg, "-p=")); err != nil {
|
||||||
|
return opts, fmt.Errorf("parsing port mapping: %w", err)
|
||||||
|
}
|
||||||
|
case strings.HasPrefix(arg, "--publish="):
|
||||||
|
if err := parsePortMapping(opts.ports, strings.TrimPrefix(arg, "--publish=")); err != nil {
|
||||||
|
return opts, fmt.Errorf("parsing port mapping: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
case arg == "--dns":
|
||||||
|
if i+1 >= len(args) {
|
||||||
|
return opts, fmt.Errorf("flag %q requires an argument", arg)
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
opts.dns = append(opts.dns, args[i])
|
||||||
|
case strings.HasPrefix(arg, "--dns="):
|
||||||
|
opts.dns = append(opts.dns, strings.TrimPrefix(arg, "--dns="))
|
||||||
|
|
||||||
|
case arg == "--device":
|
||||||
|
if i+1 >= len(args) {
|
||||||
|
return opts, fmt.Errorf("flag %q requires an argument", arg)
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
parseDeviceMapping(&opts.devices, args[i])
|
||||||
|
case strings.HasPrefix(arg, "--device="):
|
||||||
|
parseDeviceMapping(&opts.devices, strings.TrimPrefix(arg, "--device="))
|
||||||
|
|
||||||
|
case arg == "--label":
|
||||||
|
if i+1 >= len(args) {
|
||||||
|
return opts, fmt.Errorf("flag %q requires an argument", arg)
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
parseLabel(opts.labels, args[i])
|
||||||
|
case strings.HasPrefix(arg, "--label="):
|
||||||
|
parseLabel(opts.labels, strings.TrimPrefix(arg, "--label="))
|
||||||
|
|
||||||
|
case arg == "--cap-add":
|
||||||
|
if i+1 >= len(args) {
|
||||||
|
return opts, fmt.Errorf("flag %q requires an argument", arg)
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
opts.capAdd = append(opts.capAdd, args[i])
|
||||||
|
case strings.HasPrefix(arg, "--cap-add="):
|
||||||
|
opts.capAdd = append(opts.capAdd, strings.TrimPrefix(arg, "--cap-add="))
|
||||||
|
|
||||||
|
default:
|
||||||
|
return opts, fmt.Errorf("unsupported flag %q", arg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return opts, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parsePortMapping(portMap nat.PortMap, spec string) error {
|
||||||
|
port, bindings, err := nat.ParsePortSpecs([]string{spec})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for p, binding := range bindings {
|
||||||
|
portMap[p] = binding
|
||||||
|
}
|
||||||
|
for p := range port {
|
||||||
|
if _, exists := portMap[p]; !exists {
|
||||||
|
portMap[p] = []nat.PortBinding{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseDeviceMapping(devices *[]container.DeviceMapping, spec string) {
|
||||||
|
parts := strings.SplitN(spec, ":", 3) //nolint:mnd
|
||||||
|
pathOnHost := parts[0]
|
||||||
|
pathInContainer := pathOnHost
|
||||||
|
permissions := "rwm"
|
||||||
|
|
||||||
|
if len(parts) >= 2 { //nolint:mnd
|
||||||
|
pathInContainer = parts[1]
|
||||||
|
}
|
||||||
|
if len(parts) >= 3 { //nolint:mnd
|
||||||
|
permissions = parts[2]
|
||||||
|
}
|
||||||
|
|
||||||
|
*devices = append(*devices, container.DeviceMapping{
|
||||||
|
PathOnHost: pathOnHost,
|
||||||
|
PathInContainer: pathInContainer,
|
||||||
|
CgroupPermissions: permissions,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseLabel(labels map[string]string, kv string) {
|
||||||
|
parts := strings.SplitN(kv, "=", 2) //nolint:mnd
|
||||||
|
key := parts[0]
|
||||||
|
value := ""
|
||||||
|
if len(parts) > 1 {
|
||||||
|
value = parts[1]
|
||||||
|
}
|
||||||
|
labels[key] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
func runContainer(ctx context.Context, opts containerOptions, forceKill <-chan struct{}) error {
|
||||||
|
dockerClient, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("creating docker client: %w", err)
|
||||||
|
}
|
||||||
|
defer dockerClient.Close()
|
||||||
|
|
||||||
|
hasTTY := term.IsTerminal(int(os.Stdout.Fd()))
|
||||||
|
|
||||||
|
containerConfig := &container.Config{
|
||||||
|
Image: "qmcgaw/gluetun",
|
||||||
|
Env: opts.env,
|
||||||
|
Labels: opts.labels,
|
||||||
|
Tty: hasTTY,
|
||||||
|
}
|
||||||
|
|
||||||
|
mounts := make([]mount.Mount, 0, len(opts.binds))
|
||||||
|
for _, bind := range opts.binds {
|
||||||
|
m, err := parseBindMount(bind)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parsing bind mount %q: %w", bind, err)
|
||||||
|
}
|
||||||
|
mounts = append(mounts, m)
|
||||||
|
}
|
||||||
|
|
||||||
|
hostConfig := &container.HostConfig{
|
||||||
|
AutoRemove: true,
|
||||||
|
CapAdd: opts.capAdd,
|
||||||
|
Binds: opts.binds,
|
||||||
|
Mounts: mounts,
|
||||||
|
PortBindings: opts.ports,
|
||||||
|
DNS: opts.dns,
|
||||||
|
}
|
||||||
|
hostConfig.Devices = opts.devices
|
||||||
|
|
||||||
|
networkConfig := &network.NetworkingConfig{}
|
||||||
|
|
||||||
|
platform := (*v1.Platform)(nil)
|
||||||
|
|
||||||
|
const containerName = "gluetun"
|
||||||
|
response, err := dockerClient.ContainerCreate(ctx, containerConfig, hostConfig, networkConfig, platform, containerName)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("creating container: %w", err)
|
||||||
|
}
|
||||||
|
for _, warning := range response.Warnings {
|
||||||
|
fmt.Fprintln(os.Stderr, "container creation warning:", warning)
|
||||||
|
}
|
||||||
|
containerID := response.ID
|
||||||
|
|
||||||
|
err = dockerClient.ContainerStart(ctx, containerID, container.StartOptions{})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("starting container: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Container started (id: %.12s)\n", containerID)
|
||||||
|
|
||||||
|
streamLogsErr := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
streamLogsErr <- streamLogs(context.Background(), dockerClient, containerID, hasTTY)
|
||||||
|
}()
|
||||||
|
|
||||||
|
contextDone := ctx.Done()
|
||||||
|
forceKillSignal := forceKill
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case err := <-streamLogsErr:
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
case <-contextDone:
|
||||||
|
fmt.Fprintln(os.Stderr, "\nReceived interrupt, stopping container (5s timeout)...")
|
||||||
|
err = stopContainer(dockerClient, containerID)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, "stopping container:", err)
|
||||||
|
}
|
||||||
|
contextDone = nil
|
||||||
|
case <-forceKillSignal:
|
||||||
|
fmt.Fprintln(os.Stderr, "\nReceived second interrupt, killing container...")
|
||||||
|
err = killContainer(dockerClient, containerID)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, "killing container:", err)
|
||||||
|
}
|
||||||
|
forceKillSignal = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseBindMount(bind string) (mount.Mount, error) {
|
||||||
|
parts := strings.SplitN(bind, ":", 3) //nolint:mnd
|
||||||
|
if len(parts) < 2 { //nolint:mnd
|
||||||
|
return mount.Mount{}, fmt.Errorf("invalid bind mount format: %q (expected source:target[:mode])", bind)
|
||||||
|
}
|
||||||
|
|
||||||
|
source := parts[0]
|
||||||
|
target := parts[1]
|
||||||
|
readOnly := len(parts) > 2 && strings.Contains(parts[2], "ro") //nolint:mnd
|
||||||
|
|
||||||
|
return mount.Mount{
|
||||||
|
Type: mount.TypeBind,
|
||||||
|
Source: source,
|
||||||
|
Target: target,
|
||||||
|
ReadOnly: readOnly,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func stopContainer(dockerClient *client.Client, containerID string) error {
|
||||||
|
const stopTimeout = 5 * time.Second
|
||||||
|
stopCtx, stopCancel := context.WithTimeout(context.Background(), stopTimeout)
|
||||||
|
defer stopCancel()
|
||||||
|
timeoutSeconds := int(stopTimeout.Seconds())
|
||||||
|
return dockerClient.ContainerStop(stopCtx, containerID, container.StopOptions{Timeout: &timeoutSeconds})
|
||||||
|
}
|
||||||
|
|
||||||
|
func killContainer(dockerClient *client.Client, containerID string) error {
|
||||||
|
return dockerClient.ContainerKill(context.Background(), containerID, "KILL")
|
||||||
|
}
|
||||||
|
|
||||||
|
func streamLogs(ctx context.Context, dockerClient *client.Client, containerID string, hasTTY bool) error {
|
||||||
|
logOptions := container.LogsOptions{
|
||||||
|
ShowStdout: true,
|
||||||
|
ShowStderr: true,
|
||||||
|
Follow: true,
|
||||||
|
Timestamps: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
reader, err := dockerClient.ContainerLogs(ctx, containerID, logOptions)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("getting container logs: %w", err)
|
||||||
|
}
|
||||||
|
defer reader.Close()
|
||||||
|
|
||||||
|
if hasTTY {
|
||||||
|
_, err = io.Copy(os.Stdout, reader)
|
||||||
|
} else {
|
||||||
|
_, err = stdcopy.StdCopy(os.Stdout, os.Stderr, reader)
|
||||||
|
}
|
||||||
|
if err != nil && err != io.EOF {
|
||||||
|
return fmt.Errorf("streaming container logs: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user