diff --git a/.dockerignore b/.dockerignore index e8f23047..42307229 100644 --- a/.dockerignore +++ b/.dockerignore @@ -7,3 +7,4 @@ Dockerfile LICENSE README.md title.svg +devrun diff --git a/devrun/.gitignore b/devrun/.gitignore new file mode 100644 index 00000000..454b6550 --- /dev/null +++ b/devrun/.gitignore @@ -0,0 +1 @@ +credentials diff --git a/devrun/README.md b/devrun/README.md new file mode 100644 index 00000000..9921d0a1 --- /dev/null +++ b/devrun/README.md @@ -0,0 +1,152 @@ +# devrun + +`devrun` is a small development helper for starting a local `qmcgaw/gluetun` Docker container with provider credentials stored in an encrypted file. + +It solves two practical problems for local development: + +- keeping VPN credentials out of the shell history and out of a plaintext file once setup is complete; +- quickly starting a Gluetun container for a specific provider and VPN type with a small set of extra Docker runtime options. + +The tool has four commands: + +- `add-cred`: add or replace credentials for one provider and one VPN type in the encrypted store `credentials`; +- `delete-cred`: remove credentials for one provider and one VPN type from the encrypted store `credentials`; +- `dump-cred`: print credentials for one provider and one VPN type from the encrypted store `credentials`; +- `run`: decrypt credentials on demand, build the required Gluetun environment variables, and run a `qmcgaw/gluetun` container. + +## Prerequisites + +- Go installed locally +- Docker installed and a daemon available to the Docker client +- an interactive terminal, since the tool prompts for passwords without echoing them + +The Docker client is created from the standard Docker environment, so settings such as `DOCKER_HOST` are honored. + +## Quick start + +### Add credentials + +Add one credential entry to the encrypted store: + +```sh +go run ./cmd/main.go add-cred protonvpn openvpn +go run ./cmd/main.go add-cred mullvad wireguard +``` + +Behavior: + +- if `credentials` does not exist yet, `add-cred` asks for a new credentials password and creates the encrypted store; +- if `credentials` already exists, `add-cred` asks for the existing password first, decrypts the store, updates it, and writes it back encrypted; +- sensitive fields are read from stdin without echo. + +Prompted values depend on the VPN type: + +- `openvpn`: username and password +- `wireguard`: private key, optional address, optional preshared key + +Running `add-cred` again for the same provider and VPN type replaces the existing values for that entry. + +### Delete credentials + +Remove one credential entry from the encrypted store: + +```sh +go run ./cmd/main.go delete-cred protonvpn openvpn +``` + +This asks for the credentials password first, decrypts the store, removes the requested provider and VPN type, and writes the store back encrypted. + +### Dump credentials + +Print one credential entry from the encrypted store: + +```sh +go run ./cmd/main.go dump-cred protonvpn openvpn +``` + +This asks for the credentials password first and then prints the selected provider and VPN type values. + +### Container run + +Run a container using the image `qmcgaw/gluetun` and the encrypted credentials with the `run` command. +For example: + +```sh +go run ./cmd/main.go run mullvad wireguard +go run ./cmd/main.go run protonvpn wireguard -e PORT_FORWARDING=on -p 8000:8000/tcp +``` + +You will be prompted for the credentials password, the file `credentials` will be decrypted in memory, and the container will be started. + +The following environment variables are always added by the tool: + +- `VPN_SERVICE_PROVIDER=` +- `VPN_TYPE=` +- `LOG_LEVEL=debug` + +The tool also adds `NET_ADMIN` to the container capabilities by default. + +## Credential model + +Internally, the encrypted file stores a binary-encoded map keyed by provider name. Each provider can define `openvpn`, `wireguard`, or both. + +Conceptually, the stored data looks like this: + +- provider `mullvad`: contains `wireguard` +- provider `protonvpn`: contains `wireguard` +- provider `protonvpn`: contains `openvpn` + +You do not edit this directly. It is stored as encrypted binary data in `credentials`. + +### OpenVPN fields + +- `username` is required; +- `password` is required; + +At runtime these map to: + +- `OPENVPN_USER` +- `OPENVPN_PASSWORD` + +### WireGuard fields + +- `private_key` is required and must be a valid WireGuard private key; +- `address` is optional and must be a valid network prefix if set; +- `preshared_key` is optional and must be a valid WireGuard key if set. + +At runtime these map to: + +- `WIREGUARD_PRIVATE_KEY` +- `WIREGUARD_ADDRESSES` when `address` is set +- `WIREGUARD_PRESHARED_KEY` when `preshared_key` is set + +## Supported extra Docker flags + +The `run` command only accepts a focused subset of Docker-style runtime flags. Unsupported flags return an error. + +Supported flags: + +- `-e`, `--env KEY=VALUE` +- `-v`, `--volume SOURCE:TARGET[:mode]` +- `-p`, `--publish HOSTPORT:CONTAINERPORT[/proto]` +- `--dns IP` +- `--device SPEC` +- `--label KEY=VALUE` +- `--cap-add CAPABILITY` + +## Signals and shutdown + +While the container is running: + +- the first `Ctrl+C` requests a graceful stop with a 5 second timeout; +- the second `Ctrl+C` sends a kill signal to the container; +- a further interrupt exits the tool immediately. + +## Notes and limitations + +- The container image is fixed to `qmcgaw/gluetun`. +- The container name is fixed to `gluetun`. +- Credentials are decrypted in memory only during execution. +- If the requested provider or VPN type is not present in the encrypted credentials file, the command fails with an explicit error. +- The encrypted credential store file is named `credentials`. +- This tool is intended for local development convenience, not as a general replacement for `docker run`. diff --git a/devrun/cmd/main.go b/devrun/cmd/main.go new file mode 100644 index 00000000..22fd1bc7 --- /dev/null +++ b/devrun/cmd/main.go @@ -0,0 +1,156 @@ +package main + +import ( + "context" + "fmt" + "os" + "os/signal" + + "github.com/qdm12/gluetun/devrun/internal" +) + +func main() { + const minArgs = 2 + if len(os.Args) < minArgs { + printUsage() + os.Exit(1) + } + + switch os.Args[1] { + case "add-cred": + const addCredMinArgs = 4 + if len(os.Args) < addCredMinArgs { + fmt.Fprintf(os.Stderr, + `Usage: %s add-cred +Example: %s add-cred protonvpn wireguard`, os.Args[0], os.Args[0]) + os.Exit(1) + } + provider := os.Args[2] + vpnType := os.Args[3] + err := runWithSignals(func(ctx context.Context, _ <-chan struct{}) error { + return internal.AddCredential(ctx, provider, vpnType) + }) + if err != nil { + fmt.Fprintln(os.Stderr, "add-cred failed:", err) + os.Exit(1) + } + case "delete-cred": + const deleteCredMinArgs = 4 + if len(os.Args) < deleteCredMinArgs { + fmt.Fprintf(os.Stderr, + `Usage: %s delete-cred +Example: %s delete-cred protonvpn openvpn`, os.Args[0], os.Args[0]) + os.Exit(1) + } + provider := os.Args[2] + vpnType := os.Args[3] + err := runWithSignals(func(ctx context.Context, _ <-chan struct{}) error { + return internal.DeleteCredential(ctx, provider, vpnType) + }) + if err != nil { + fmt.Fprintln(os.Stderr, "delete-cred failed:", err) + os.Exit(1) + } + case "dump-cred": + const dumpCredMinArgs = 4 + if len(os.Args) < dumpCredMinArgs { + fmt.Fprintf(os.Stderr, + `Usage: %s dump-cred +Example: %s dump-cred protonvpn wireguard`, os.Args[0], os.Args[0]) + os.Exit(1) + } + provider := os.Args[2] + vpnType := os.Args[3] + err := runWithSignals(func(ctx context.Context, _ <-chan struct{}) error { + return internal.DumpCredential(ctx, provider, vpnType) + }) + if err != nil { + fmt.Fprintln(os.Stderr, "dump-cred failed:", err) + os.Exit(1) + } + case "run": + const runMinArgs = 4 + if len(os.Args) < runMinArgs { + fmt.Fprintf(os.Stderr, + `Usage: %s run [extra docker flags...] +Example: %s run mullvad wireguard -e SERVER_COUNTRIES=USA`, os.Args[0], os.Args[0]) + os.Exit(1) + } + provider := os.Args[2] + vpnType := os.Args[3] + extraArgs := os.Args[4:] + err := runWithSignals(func(ctx context.Context, forceKill <-chan struct{}) error { + return internal.Run(ctx, provider, vpnType, extraArgs, forceKill) + }) + if err != nil { + fmt.Fprintln(os.Stderr, "run failed:", err) + os.Exit(1) + } + default: + fmt.Fprintln(os.Stderr, "unknown command:", os.Args[1]) + printUsage() + os.Exit(1) + } +} + +func printUsage() { + fmt.Fprintf(os.Stderr, `Usage: %s [args...] + +Commands: + add-cred + Add or replace credentials in the encrypted credentials store. + delete-cred + Delete credentials from the encrypted credentials store. + dump-cred + Print credentials for a provider and VPN type pair. + run [flags...] + Decrypt credentials and run a Gluetun container. + Extra flags (e.g. -e PORT_FORWARDING=on) are passed to docker run.`, + os.Args[0]) +} + +func runWithSignals(runFn func(ctx context.Context, forceKill <-chan struct{}) error) error { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + const signalBufferSize = 3 + sigCh := make(chan os.Signal, signalBufferSize) + signal.Notify(sigCh, os.Interrupt) + defer signal.Stop(sigCh) + + forceKill := make(chan struct{}) + stopSignalLoop := make(chan struct{}) + signalLoopDone := make(chan struct{}) + + go func() { + defer close(signalLoopDone) + + const secondInterrupt = 2 + interruptCount := uint(0) + forceKillSent := false + for { + select { + case <-stopSignalLoop: + return + case <-sigCh: + interruptCount++ + switch interruptCount { + case 1: + cancel() + case secondInterrupt: + if !forceKillSent { + close(forceKill) + forceKillSent = true + } + default: + os.Exit(1) + } + } + } + }() + + err := runFn(ctx, forceKill) + close(stopSignalLoop) + <-signalLoopDone + return err +} diff --git a/devrun/go.mod b/devrun/go.mod new file mode 100644 index 00000000..9d27723a --- /dev/null +++ b/devrun/go.mod @@ -0,0 +1,40 @@ +module github.com/qdm12/gluetun/devrun + +go 1.25.0 + +require ( + github.com/docker/docker v28.5.2+incompatible + github.com/docker/go-connections v0.7.0 + github.com/opencontainers/image-spec v1.1.1 + golang.org/x/crypto v0.50.0 + golang.org/x/term v0.42.0 + golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 +) + +require ( + github.com/Microsoft/go-winio v0.6.2 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/containerd/errdefs v1.0.0 // indirect + github.com/containerd/errdefs/pkg v0.3.0 // indirect + github.com/containerd/log v0.1.0 // indirect + github.com/distribution/reference v0.6.0 // indirect + github.com/docker/go-units v0.5.0 // indirect + github.com/felixge/httpsnoop v1.0.4 // indirect + github.com/go-logr/logr v1.4.3 // indirect + github.com/go-logr/stdr v1.2.2 // indirect + github.com/moby/docker-image-spec v1.3.1 // indirect + github.com/moby/sys/atomicwriter v0.1.0 // indirect + github.com/moby/term v0.5.2 // indirect + github.com/morikuni/aec v1.1.0 // indirect + github.com/opencontainers/go-digest v1.0.0 // indirect + github.com/pkg/errors v0.9.1 // indirect + go.opentelemetry.io/auto/sdk v1.2.1 // indirect + go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.68.0 // indirect + go.opentelemetry.io/otel v1.43.0 // indirect + go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.43.0 // indirect + go.opentelemetry.io/otel/metric v1.43.0 // indirect + go.opentelemetry.io/otel/trace v1.43.0 // indirect + golang.org/x/sys v0.43.0 // indirect + golang.org/x/time v0.15.0 // indirect + gotest.tools/v3 v3.5.2 // indirect +) diff --git a/devrun/go.sum b/devrun/go.sum new file mode 100644 index 00000000..5e3a825e --- /dev/null +++ b/devrun/go.sum @@ -0,0 +1,105 @@ +github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c h1:udKWzYgxTojEKWjV8V+WSxDXJ4NFATAsZjh8iIbsQIg= +github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= +github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= +github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= +github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM= +github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI= +github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M= +github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE= +github.com/containerd/errdefs/pkg v0.3.0/go.mod h1:NJw6s9HwNuRhnjJhM7pylWwMyAkmCQvQ4GpJHEqRLVk= +github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I= +github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk= +github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E= +github.com/docker/docker v28.5.2+incompatible h1:DBX0Y0zAjZbSrm1uzOkdr1onVghKaftjlSWt4AFexzM= +github.com/docker/docker v28.5.2+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= +github.com/docker/go-connections v0.7.0 h1:6SsRfJddP22WMrCkj19x9WKjEDTB+ahsdiGYf0mN39c= +github.com/docker/go-connections v0.7.0/go.mod h1:no1qkHdjq7kLMGUXYAduOhYPSJxxvgWBh7ogVvptn3Q= +github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= +github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= +github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= +github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 h1:HWRh5R2+9EifMyIHV7ZV+MIZqgz+PMpZ14Jynv3O2Zs= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0/go.mod h1:JfhWUomR1baixubs02l85lZYYOm7LV6om4ceouMv45c= +github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0= +github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo= +github.com/moby/sys/atomicwriter v0.1.0 h1:kw5D/EqkBwsBFi0ss9v1VG3wIkVhzGvLklJ+w3A14Sw= +github.com/moby/sys/atomicwriter v0.1.0/go.mod h1:Ul8oqv2ZMNHOceF643P6FKPXeCmYtlQMvpizfsSoaWs= +github.com/moby/sys/sequential v0.6.0 h1:qrx7XFUd/5DxtqcoH1h438hF5TmOvzC/lspjy7zgvCU= +github.com/moby/sys/sequential v0.6.0/go.mod h1:uyv8EUTrca5PnDsdMGXhZe6CCe8U/UiTWd+lL+7b/Ko= +github.com/moby/term v0.5.2 h1:6qk3FJAFDs6i/q3W/pQ97SX192qKfZgGjCQqfCJkgzQ= +github.com/moby/term v0.5.2/go.mod h1:d3djjFCrjnB+fl8NJux+EJzu0msscUP+f8it8hPkFLc= +github.com/morikuni/aec v1.1.0 h1:vBBl0pUnvi/Je71dsRrhMBtreIqNMYErSAbEeb8jrXQ= +github.com/morikuni/aec v1.1.0/go.mod h1:xDRgiq/iw5l+zkao76YTKzKttOp2cwPEne25HDkJnBw= +github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= +github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= +github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= +github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= +go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.68.0 h1:CqXxU8VOmDefoh0+ztfGaymYbhdB/tT3zs79QaZTNGY= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.68.0/go.mod h1:BuhAPThV8PBHBvg8ZzZ/Ok3idOdhWIodywz2xEcRbJo= +go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I= +go.opentelemetry.io/otel v1.43.0/go.mod h1:JuG+u74mvjvcm8vj8pI5XiHy1zDeoCS2LB1spIq7Ay0= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.43.0 h1:88Y4s2C8oTui1LGM6bTWkw0ICGcOLCAI5l6zsD1j20k= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.43.0/go.mod h1:Vl1/iaggsuRlrHf/hfPJPvVag77kKyvrLeD10kpMl+A= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.43.0 h1:3iZJKlCZufyRzPzlQhUIWVmfltrXuGyfjREgGP3UUjc= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.43.0/go.mod h1:/G+nUPfhq2e+qiXMGxMwumDrP5jtzU+mWN7/sjT2rak= +go.opentelemetry.io/otel/metric v1.43.0 h1:d7638QeInOnuwOONPp4JAOGfbCEpYb+K6DVWvdxGzgM= +go.opentelemetry.io/otel/metric v1.43.0/go.mod h1:RDnPtIxvqlgO8GRW18W6Z/4P462ldprJtfxHxyKd2PY= +go.opentelemetry.io/otel/sdk v1.43.0 h1:pi5mE86i5rTeLXqoF/hhiBtUNcrAGHLKQdhg4h4V9Dg= +go.opentelemetry.io/otel/sdk v1.43.0/go.mod h1:P+IkVU3iWukmiit/Yf9AWvpyRDlUeBaRg6Y+C58QHzg= +go.opentelemetry.io/otel/sdk/metric v1.43.0 h1:S88dyqXjJkuBNLeMcVPRFXpRw2fuwdvfCGLEo89fDkw= +go.opentelemetry.io/otel/sdk/metric v1.43.0/go.mod h1:C/RJtwSEJ5hzTiUz5pXF1kILHStzb9zFlIEe85bhj6A= +go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A= +go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0= +go.opentelemetry.io/proto/otlp v1.10.0 h1:IQRWgT5srOCYfiWnpqUYz9CVmbO8bFmKcwYxpuCSL2g= +go.opentelemetry.io/proto/otlp v1.10.0/go.mod h1:/CV4QoCR/S9yaPj8utp3lvQPoqMtxXdzn7ozvvozVqk= +golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI= +golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q= +golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= +golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= +golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= +golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY= +golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY= +golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg= +golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164= +golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U= +golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno= +golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdIg1ozBNLgPy4SLT84nfcBjr6rhGtXYtrkWLU= +golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ= +google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 h1:VPWxll4HlMw1Vs/qXtN7BvhZqsS9cdAittCNvVENElA= +google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:7QBABkRtR8z+TEnmXTqIqwJLlzrZKVfAUm7tY3yGv0M= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 h1:m8qni9SQFH0tJc1X0vmnpw/0t+AImlSvp30sEupozUg= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8= +google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM= +google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4= +google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= +google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q= +gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA= diff --git a/devrun/internal/credentials.go b/devrun/internal/credentials.go new file mode 100644 index 00000000..9534ec3b --- /dev/null +++ b/devrun/internal/credentials.go @@ -0,0 +1,240 @@ +package internal + +import ( + "bytes" + "encoding/gob" + "fmt" + "maps" + "net/netip" + "slices" + "strings" + + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +const credentialsFilename = "credentials" + +const ( + vpnTypeOpenVPN = "openvpn" + vpnTypeWireGuard = "wireguard" +) + +type providerCredentials struct { + OpenVPN *openvpnCredentials + WireGuard *wireguardCredentials +} + +type openvpnCredentials struct { + Username string + Password string +} + +type wireguardCredentials struct { + PrivateKey string + Address string + PresharedKey string +} + +func loadCredentials(data []byte) (map[string]providerCredentials, error) { + credentials := make(map[string]providerCredentials) + err := gob.NewDecoder(bytes.NewReader(data)).Decode(&credentials) + if err != nil { + return nil, fmt.Errorf("decoding credentials: %w", err) + } + return credentials, nil +} + +func marshalCredentials(credentials map[string]providerCredentials) ([]byte, error) { + buffer := bytes.NewBuffer(nil) + err := gob.NewEncoder(buffer).Encode(credentials) + if err != nil { + return nil, fmt.Errorf("encoding credentials: %w", err) + } + return buffer.Bytes(), nil +} + +func validateCredentials(providerNameToCredentials map[string]providerCredentials) error { + for provider, credentials := range providerNameToCredentials { + if credentials.OpenVPN == nil && credentials.WireGuard == nil { + return fmt.Errorf("provider %q has no openvpn or wireguard credentials", provider) + } + if credentials.OpenVPN != nil { + err := validateOpenvpnCredentials(provider, credentials.OpenVPN) + if err != nil { + return err + } + } + if credentials.WireGuard != nil { + err := validateWireguardCredentials(provider, credentials.WireGuard) + if err != nil { + return err + } + } + } + return nil +} + +func validateOpenvpnCredentials(provider string, creds *openvpnCredentials) error { + switch { + case creds.Username == "": + return fmt.Errorf("provider %q openvpn credentials are missing the username", provider) + case creds.Password == "": + return fmt.Errorf("provider %q openvpn credentials are missing the password", provider) + } + return nil +} + +func validateWireguardCredentials(provider string, creds *wireguardCredentials) error { + if creds.PrivateKey == "" { + return fmt.Errorf("provider %q wireguard credentials are missing the private key", provider) + } else if _, err := wgtypes.ParseKey(creds.PrivateKey); err != nil { + return fmt.Errorf("provider %q wireguard credentials have an invalid private key: %w", provider, err) + } + + if creds.Address != "" { + _, err := netip.ParsePrefix(creds.Address) + if err != nil { + return fmt.Errorf("provider %q wireguard credentials have an invalid address %q: %w", provider, creds.Address, err) + } + } + + if creds.PresharedKey != "" { + if _, err := wgtypes.ParseKey(creds.PresharedKey); err != nil { + return fmt.Errorf("provider %q wireguard credentials have an invalid preshared key: %w", provider, err) + } + } + return nil +} + +func lookupCredentials(credentials map[string]providerCredentials, provider, vpnType string) ([]string, error) { + providerCreds, exists := credentials[provider] + if !exists { + existing := slices.Collect(maps.Keys(credentials)) + return nil, fmt.Errorf("no credentials found for provider %q, available providers are: %s", + provider, strings.Join(existing, ", ")) + } + + switch vpnType { + case vpnTypeWireGuard: + if providerCreds.WireGuard == nil { + return nil, fmt.Errorf("no wireguard credentials found for provider %q", provider) + } + return buildWireGuardEnv(providerCreds.WireGuard), nil + case vpnTypeOpenVPN: + if providerCreds.OpenVPN == nil { + return nil, fmt.Errorf("no openvpn credentials found for provider %q", provider) + } + return buildOpenvpnEnv(providerCreds.OpenVPN), nil + default: + return nil, fmt.Errorf("unknown vpn type %q, must be wireguard or openvpn", vpnType) + } +} + +func buildWireGuardEnv(creds *wireguardCredentials) []string { + envVars := []string{ + "WIREGUARD_PRIVATE_KEY=" + creds.PrivateKey, + } + if creds.Address != "" { + envVars = append(envVars, "WIREGUARD_ADDRESSES="+creds.Address) + } + if creds.PresharedKey != "" { + envVars = append(envVars, "WIREGUARD_PRESHARED_KEY="+creds.PresharedKey) + } + return envVars +} + +func buildOpenvpnEnv(creds *openvpnCredentials) []string { + return []string{ + "OPENVPN_USER=" + creds.Username, + "OPENVPN_PASSWORD=" + creds.Password, + } +} + +func addCredential(credentials map[string]providerCredentials, provider, vpnType string, + openvpnCredentials *openvpnCredentials, wireguardCredentials *wireguardCredentials, +) error { + providerCredentials := credentials[provider] + + switch vpnType { + case vpnTypeOpenVPN: + providerCredentials.OpenVPN = openvpnCredentials + case vpnTypeWireGuard: + providerCredentials.WireGuard = wireguardCredentials + default: + return fmt.Errorf("unknown vpn type %q, must be wireguard or openvpn", vpnType) + } + + credentials[provider] = providerCredentials + return nil +} + +func deleteCredential(credentials map[string]providerCredentials, provider, vpnType string) error { + providerCredentials, exists := credentials[provider] + if !exists { + return fmt.Errorf("provider %q does not exist", provider) + } + + switch vpnType { + case vpnTypeOpenVPN: + if providerCredentials.OpenVPN == nil { + return fmt.Errorf("provider %q has no openvpn credentials", provider) + } + providerCredentials.OpenVPN = nil + case vpnTypeWireGuard: + if providerCredentials.WireGuard == nil { + return fmt.Errorf("provider %q has no wireguard credentials", provider) + } + providerCredentials.WireGuard = nil + default: + return fmt.Errorf("unknown vpn type %q, must be wireguard or openvpn", vpnType) + } + + if providerCredentials.OpenVPN == nil && providerCredentials.WireGuard == nil { + delete(credentials, provider) + return nil + } + + credentials[provider] = providerCredentials + return nil +} + +func formatCredentialForDump(provider, vpnType string, + providerCredentials providerCredentials, +) (output string, err error) { + var builder strings.Builder + + builder.WriteString("provider: ") + builder.WriteString(provider) + builder.WriteString("\n") + builder.WriteString("vpn_type: ") + builder.WriteString(vpnType) + builder.WriteString("\n") + + switch vpnType { + case vpnTypeOpenVPN: + if providerCredentials.OpenVPN == nil { + return "", fmt.Errorf("no openvpn credentials found for provider %q", provider) + } + builder.WriteString("username: ") + builder.WriteString(providerCredentials.OpenVPN.Username) + builder.WriteString("\n") + builder.WriteString("password: ") + builder.WriteString(providerCredentials.OpenVPN.Password) + case vpnTypeWireGuard: + if providerCredentials.WireGuard == nil { + return "", fmt.Errorf("no wireguard credentials found for provider %q", provider) + } + builder.WriteString("private_key: ") + builder.WriteString(providerCredentials.WireGuard.PrivateKey) + builder.WriteString("\n") + builder.WriteString("address: ") + builder.WriteString(providerCredentials.WireGuard.Address) + builder.WriteString("\n") + builder.WriteString("preshared_key: ") + builder.WriteString(providerCredentials.WireGuard.PresharedKey) + default: + return "", fmt.Errorf("unknown vpn type %q, must be wireguard or openvpn", vpnType) + } + + return builder.String(), nil +} diff --git a/devrun/internal/credentials_test.go b/devrun/internal/credentials_test.go new file mode 100644 index 00000000..c9471c08 --- /dev/null +++ b/devrun/internal/credentials_test.go @@ -0,0 +1,343 @@ +package internal + +import ( + "testing" + + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +func Test_addCredential(t *testing.T) { + t.Parallel() + + wireguardPrivateKey, err := wgtypes.GeneratePrivateKey() + if err != nil { + t.Fatal(err) + } + + testCases := map[string]struct { + initialCredentials map[string]providerCredentials + provider string + vpnType string + openvpnCredentials *openvpnCredentials + wireguardCreds *wireguardCredentials + expectedLength int + expectedOpenVPN bool + expectedWireGuard bool + }{ + "adds_openvpn_credentials": { + initialCredentials: map[string]providerCredentials{}, + provider: "protonvpn", + vpnType: "openvpn", + openvpnCredentials: &openvpnCredentials{Username: "user", Password: "pass"}, + expectedLength: 1, + expectedOpenVPN: true, + }, + "adds_wireguard_credentials": { + initialCredentials: map[string]providerCredentials{}, + provider: "mullvad", + vpnType: "wireguard", + wireguardCreds: &wireguardCredentials{ + PrivateKey: wireguardPrivateKey.String(), + Address: "10.0.0.2/32", + }, + expectedLength: 1, + expectedWireGuard: true, + }, + "preserves_other_protocol": { + initialCredentials: map[string]providerCredentials{ + "protonvpn": { + WireGuard: &wireguardCredentials{PrivateKey: wireguardPrivateKey.String()}, + }, + }, + provider: "protonvpn", + vpnType: "openvpn", + openvpnCredentials: &openvpnCredentials{Username: "user", Password: "pass"}, + expectedLength: 1, + expectedOpenVPN: true, + expectedWireGuard: true, + }, + } + + for name, testCase := range testCases { + t.Run(name, func(t *testing.T) { + t.Parallel() + + credentials := cloneCredentials(testCase.initialCredentials) + + err := addCredential(credentials, testCase.provider, testCase.vpnType, + testCase.openvpnCredentials, testCase.wireguardCreds) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + providerCredentials := credentials[testCase.provider] + if len(credentials) != testCase.expectedLength { + t.Fatalf("expected %d providers, got %d", testCase.expectedLength, len(credentials)) + } + if (providerCredentials.OpenVPN != nil) != testCase.expectedOpenVPN { + t.Fatalf("expected openvpn presence %t, got %t", testCase.expectedOpenVPN, providerCredentials.OpenVPN != nil) + } + if (providerCredentials.WireGuard != nil) != testCase.expectedWireGuard { + t.Fatalf("expected wireguard presence %t, got %t", testCase.expectedWireGuard, providerCredentials.WireGuard != nil) + } + }) + } +} + +func Test_deleteCredential(t *testing.T) { + t.Parallel() + + wireguardPrivateKey, err := wgtypes.GeneratePrivateKey() + if err != nil { + t.Fatal(err) + } + + testCases := map[string]struct { + initialCredentials map[string]providerCredentials + provider string + vpnType string + expectedLength int + expectedOpenVPN bool + expectedWireGuard bool + }{ + "deletes_openvpn_only": { + initialCredentials: map[string]providerCredentials{ + "protonvpn": { + OpenVPN: &openvpnCredentials{Username: "user", Password: "pass"}, + WireGuard: &wireguardCredentials{PrivateKey: wireguardPrivateKey.String()}, + }, + }, + provider: "protonvpn", + vpnType: "openvpn", + expectedLength: 1, + expectedWireGuard: true, + }, + "deletes_last_protocol_and_provider": { + initialCredentials: map[string]providerCredentials{ + "protonvpn": { + OpenVPN: &openvpnCredentials{Username: "user", Password: "pass"}, + }, + }, + provider: "protonvpn", + vpnType: "openvpn", + expectedLength: 0, + }, + } + + for name, testCase := range testCases { + t.Run(name, func(t *testing.T) { + t.Parallel() + + credentials := cloneCredentials(testCase.initialCredentials) + + err := deleteCredential(credentials, testCase.provider, testCase.vpnType) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(credentials) != testCase.expectedLength { + t.Fatalf("expected %d providers, got %d", testCase.expectedLength, len(credentials)) + } + + providerCredentials, exists := credentials[testCase.provider] + if !exists { + return + } + + if (providerCredentials.OpenVPN != nil) != testCase.expectedOpenVPN { + t.Fatalf("expected openvpn presence %t, got %t", testCase.expectedOpenVPN, providerCredentials.OpenVPN != nil) + } + if (providerCredentials.WireGuard != nil) != testCase.expectedWireGuard { + t.Fatalf("expected wireguard presence %t, got %t", testCase.expectedWireGuard, providerCredentials.WireGuard != nil) + } + }) + } +} + +func Test_validateCredentials(t *testing.T) { + t.Parallel() + + wireguardPrivateKey, err := wgtypes.GeneratePrivateKey() + if err != nil { + t.Fatal(err) + } + + testCases := map[string]struct { + credentials map[string]providerCredentials + wantError bool + }{ + "both_protocols_valid": { + credentials: map[string]providerCredentials{ + "protonvpn": { + OpenVPN: &openvpnCredentials{Username: "user", Password: "pass"}, + WireGuard: &wireguardCredentials{PrivateKey: wireguardPrivateKey.String()}, + }, + }, + }, + "invalid_wireguard_when_both_present": { + credentials: map[string]providerCredentials{ + "protonvpn": { + OpenVPN: &openvpnCredentials{Username: "user", Password: "pass"}, + WireGuard: &wireguardCredentials{PrivateKey: "invalid"}, + }, + }, + wantError: true, + }, + } + + for name, testCase := range testCases { + t.Run(name, func(t *testing.T) { + t.Parallel() + + err := validateCredentials(testCase.credentials) + if testCase.wantError && err == nil { + t.Fatal("expected an error but got nil") + } + if !testCase.wantError && err != nil { + t.Fatalf("expected no error, got %v", err) + } + }) + } +} + +func Test_marshalLoadCredentials(t *testing.T) { + t.Parallel() + + wireguardPrivateKey, err := wgtypes.GeneratePrivateKey() + if err != nil { + t.Fatal(err) + } + + credentials := map[string]providerCredentials{ + "mullvad": { + WireGuard: &wireguardCredentials{ + PrivateKey: wireguardPrivateKey.String(), + Address: "10.0.0.2/32", + }, + }, + "protonvpn": { + OpenVPN: &openvpnCredentials{ + Username: "user", + Password: "pass", + }, + }, + } + + encoded, err := marshalCredentials(credentials) + if err != nil { + t.Fatalf("unexpected marshal error: %v", err) + } + + decoded, err := loadCredentials(encoded) + if err != nil { + t.Fatalf("unexpected load error: %v", err) + } + + if len(decoded) != len(credentials) { + t.Fatalf("expected %d providers, got %d", len(credentials), len(decoded)) + } + + if decoded["mullvad"].WireGuard == nil { + t.Fatal("expected mullvad wireguard credentials to be present") + } + if decoded["protonvpn"].OpenVPN == nil { + t.Fatal("expected protonvpn openvpn credentials to be present") + } + if decoded["protonvpn"].OpenVPN.Password != "pass" { + t.Fatalf("expected protonvpn password %q, got %q", "pass", decoded["protonvpn"].OpenVPN.Password) + } +} + +func Test_formatCredentialForDump(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + provider string + vpnType string + providerCredentials providerCredentials + expectedOutput string + wantError bool + }{ + "openvpn": { + provider: "protonvpn", + vpnType: vpnTypeOpenVPN, + providerCredentials: providerCredentials{ + OpenVPN: &openvpnCredentials{Username: "user", Password: "pass"}, + }, + expectedOutput: "provider: protonvpn\n" + + "vpn_type: openvpn\n" + + "username: user\n" + + "password: pass", + }, + "wireguard": { + provider: "mullvad", + vpnType: vpnTypeWireGuard, + providerCredentials: providerCredentials{ + WireGuard: &wireguardCredentials{ + PrivateKey: "private", + Address: "10.0.0.2/32", + PresharedKey: "preshared", + }, + }, + expectedOutput: "provider: mullvad\n" + + "vpn_type: wireguard\n" + + "private_key: private\n" + + "address: 10.0.0.2/32\n" + + "preshared_key: preshared", + }, + "missing_protocol": { + provider: "protonvpn", + vpnType: vpnTypeOpenVPN, + wantError: true, + }, + "unknown_protocol": { + provider: "protonvpn", + vpnType: "other", + wantError: true, + }, + } + + for name, testCase := range testCases { + t.Run(name, func(t *testing.T) { + t.Parallel() + + output, err := formatCredentialForDump( + testCase.provider, + testCase.vpnType, + testCase.providerCredentials, + ) + + if testCase.wantError { + if err == nil { + t.Fatal("expected an error but got nil") + } + return + } + + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + if output != testCase.expectedOutput { + t.Fatalf("expected output %q, got %q", testCase.expectedOutput, output) + } + }) + } +} + +func cloneCredentials(credentials map[string]providerCredentials) map[string]providerCredentials { + clone := make(map[string]providerCredentials, len(credentials)) + for provider, providerCredentials := range credentials { + copied := providerCredentials + if providerCredentials.OpenVPN != nil { + openvpnCredentials := *providerCredentials.OpenVPN + copied.OpenVPN = &openvpnCredentials + } + if providerCredentials.WireGuard != nil { + wireguardCredentials := *providerCredentials.WireGuard + copied.WireGuard = &wireguardCredentials + } + clone[provider] = copied + } + return clone +} diff --git a/devrun/internal/encrypt.go b/devrun/internal/encrypt.go new file mode 100644 index 00000000..76702a31 --- /dev/null +++ b/devrun/internal/encrypt.go @@ -0,0 +1,521 @@ +package internal + +import ( + "bufio" + "context" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "fmt" + "io" + "maps" + "os" + "slices" + "strings" + "sync" + "syscall" + + "golang.org/x/crypto/scrypt" + "golang.org/x/term" +) + +// Encryption format: [16-byte salt][12-byte nonce][AES-256-GCM ciphertext+tag] +// Key derivation: scrypt(password, salt, N=32768, r=8, p=1, keyLen=32) + +const ( + saltSize = 16 + nonceSize = 12 + keySize = 32 + scryptN = 32768 + scryptR = 8 + scryptP = 1 +) + +// AddCredential prompts for credential values and stores them in the encrypted credentials file. +func AddCredential(ctx context.Context, provider, vpnType string) error { + credentials, password, err := loadCredentialsForMutation(ctx) + if err != nil { + return err + } + + err = promptAndAddCredential(ctx, credentials, provider, vpnType) + if err != nil { + return err + } + + err = validateCredentials(credentials) + if err != nil { + return fmt.Errorf("validating credentials: %w", err) + } + + err = writeEncryptedCredentials(credentials, password) + if err != nil { + return err + } + + fmt.Printf( + "Credentials for provider %q and vpn type %q saved to %s\n", + provider, vpnType, credentialsFilename, + ) + return nil +} + +// DeleteCredential removes credentials for a provider and VPN type +// from the encrypted credentials file. +func DeleteCredential(ctx context.Context, provider, vpnType string) error { + credentials, password, err := loadExistingCredentialsForMutation(ctx) + if err != nil { + return err + } + + err = deleteCredential(credentials, provider, vpnType) + if err != nil { + return err + } + + err = writeEncryptedCredentials(credentials, password) + if err != nil { + return err + } + + fmt.Printf( + "Credentials for provider %q and vpn type %q removed from %s\n", + provider, vpnType, credentialsFilename, + ) + return nil +} + +// DumpCredential decrypts the credential store and prints one provider/vpn-type entry. +func DumpCredential(ctx context.Context, provider, vpnType string) error { + credentials, err := decryptCredentials(ctx) + if err != nil { + return err + } + + providerCredentials, exists := credentials[provider] + if !exists { + existingProviders := slices.Collect(maps.Keys(credentials)) + return fmt.Errorf("provider %q does not exist, available providers are: %s", + provider, strings.Join(existingProviders, ", ")) + } + + output, err := formatCredentialForDump(provider, vpnType, providerCredentials) + if err != nil { + return err + } + + fmt.Println(output) + return nil +} + +// decryptCredentials reads the encrypted credentials file, +// prompts for a password, and returns the decrypted credentials. +func decryptCredentials(ctx context.Context) (map[string]providerCredentials, error) { + password, err := readSecret(ctx, "Enter credentials password: ", false) + if err != nil { + return nil, fmt.Errorf("reading password: %w", err) + } + + plaintext, err := decryptCredentialsFile(password) + if err != nil { + return nil, err + } + + credentials, err := loadCredentials(plaintext) + if err != nil { + return nil, fmt.Errorf("loading credentials: %w", err) + } + + return credentials, nil +} + +func loadCredentialsForMutation(ctx context.Context) ( + credentials map[string]providerCredentials, + password []byte, + err error, +) { + _, err = os.Stat(credentialsFilename) + if os.IsNotExist(err) { + password, err = readPasswordConfirmed(ctx, + "Enter new credentials password: ", + "Confirm new credentials password: ", + ) + if err != nil { + return nil, nil, fmt.Errorf("reading password: %w", err) + } + return make(map[string]providerCredentials), password, nil + } + if err != nil { + return nil, nil, fmt.Errorf("stating %s: %w", credentialsFilename, err) + } + + password, err = readSecret(ctx, "Enter credentials password: ", false) + if err != nil { + return nil, nil, fmt.Errorf("reading password: %w", err) + } + + plaintext, err := decryptCredentialsFile(password) + if err != nil { + return nil, nil, err + } + + credentials, err = loadCredentials(plaintext) + if err != nil { + return nil, nil, fmt.Errorf("loading credentials: %w", err) + } + + return credentials, password, nil +} + +func loadExistingCredentialsForMutation(ctx context.Context) ( + credentials map[string]providerCredentials, + password []byte, + err error, +) { + _, err = os.Stat(credentialsFilename) + if os.IsNotExist(err) { + return nil, nil, fmt.Errorf("%s does not exist", credentialsFilename) + } + if err != nil { + return nil, nil, fmt.Errorf("stating %s: %w", credentialsFilename, err) + } + + password, err = readSecret(ctx, "Enter credentials password: ", false) + if err != nil { + return nil, nil, fmt.Errorf("reading password: %w", err) + } + + plaintext, err := decryptCredentialsFile(password) + if err != nil { + return nil, nil, err + } + + credentials, err = loadCredentials(plaintext) + if err != nil { + return nil, nil, fmt.Errorf("loading credentials: %w", err) + } + + return credentials, password, nil +} + +func promptAndAddCredential( + ctx context.Context, + credentials map[string]providerCredentials, + provider, vpnType string, +) error { + switch vpnType { + case vpnTypeOpenVPN: + username, err := readLine(ctx, "OpenVPN username: ", false) + if err != nil { + return fmt.Errorf("reading username: %w", err) + } + + password, err := readSecret(ctx, "OpenVPN password: ", false) + if err != nil { + return fmt.Errorf("reading password: %w", err) + } + + openvpnCredentials := &openvpnCredentials{ + Username: username, + Password: string(password), + } + err = validateOpenvpnCredentials(provider, openvpnCredentials) + if err != nil { + return err + } + + return addCredential(credentials, provider, vpnType, openvpnCredentials, nil) + + case vpnTypeWireGuard: + privateKey, err := readSecret(ctx, "WireGuard private key: ", false) + if err != nil { + return fmt.Errorf("reading private key: %w", err) + } + + address, err := readLine(ctx, "WireGuard address (optional): ", true) + if err != nil { + return fmt.Errorf("reading address: %w", err) + } + + presharedKey, err := readSecret( + ctx, + "WireGuard preshared key (optional): ", + true, + ) + if err != nil { + return fmt.Errorf("reading preshared key: %w", err) + } + + wireguardCredentials := &wireguardCredentials{ + PrivateKey: string(privateKey), + Address: address, + PresharedKey: string(presharedKey), + } + err = validateWireguardCredentials(provider, wireguardCredentials) + if err != nil { + return err + } + + return addCredential(credentials, provider, vpnType, nil, wireguardCredentials) + + default: + return fmt.Errorf("unknown vpn type %q, must be wireguard or openvpn", vpnType) + } +} + +func writeEncryptedCredentials( + credentials map[string]providerCredentials, + password []byte, +) error { + plaintext, err := marshalCredentials(credentials) + if err != nil { + return fmt.Errorf("encoding credentials: %w", err) + } + + encrypted, err := encryptData(plaintext, password) + if err != nil { + return fmt.Errorf("encrypting credentials: %w", err) + } + + const filePerms = 0o600 + err = os.WriteFile(credentialsFilename, encrypted, filePerms) + if err != nil { + return fmt.Errorf("writing %s: %w", credentialsFilename, err) + } + + return nil +} + +func decryptCredentialsFile(password []byte) ([]byte, error) { + encryptedData, err := os.ReadFile(credentialsFilename) + if err != nil { + return nil, fmt.Errorf("reading %s: %w", credentialsFilename, err) + } + + plaintext, err := decryptData(encryptedData, password) + if err != nil { + return nil, fmt.Errorf("decrypting credentials: %w", err) + } + + return plaintext, nil +} + +func readSecret(ctx context.Context, prompt string, allowEmpty bool) ([]byte, error) { + fmt.Print(prompt) + + passwordFD, err := syscall.Dup(syscall.Stdin) + if err != nil { + fmt.Println() + return nil, fmt.Errorf("duplicating stdin file descriptor: %w", err) + } + + var closeFDOnce sync.Once + closePasswordFD := func() { + closeFDOnce.Do(func() { + _ = syscall.Close(passwordFD) + }) + } + + passwordResult := make(chan struct { + password []byte + err error + }) + + go func() { + password, err := term.ReadPassword(passwordFD) + closePasswordFD() + result := struct { + password []byte + err error + }{ + password: password, + err: err, + } + + select { + case <-ctx.Done(): + return + case passwordResult <- result: + } + }() + + select { + case <-ctx.Done(): + closePasswordFD() + fmt.Println() + return nil, ctx.Err() + case result := <-passwordResult: + closePasswordFD() + fmt.Println() + if result.err != nil { + return nil, fmt.Errorf("reading hidden input from terminal: %w", result.err) + } + if len(result.password) == 0 && !allowEmpty { + return nil, fmt.Errorf("value cannot be empty") + } + return result.password, nil + } +} + +func readLine(ctx context.Context, prompt string, allowEmpty bool) (string, error) { + fmt.Print(prompt) + + inputFD, err := syscall.Dup(syscall.Stdin) + if err != nil { + fmt.Println() + return "", fmt.Errorf("duplicating stdin file descriptor: %w", err) + } + + var closeFDOnce sync.Once + closeInputFD := func() { + closeFDOnce.Do(func() { + _ = syscall.Close(inputFD) + }) + } + + inputResult := make(chan struct { + value string + err error + }) + + go func() { + inputFile := os.NewFile(uintptr(inputFD), "stdin") + reader := bufio.NewReader(inputFile) + value, err := reader.ReadString('\n') + closeInputFD() + value = strings.TrimRight(value, "\r\n") + if err == io.EOF { + err = nil + } + + result := struct { + value string + err error + }{ + value: value, + err: err, + } + + select { + case <-ctx.Done(): + return + case inputResult <- result: + } + }() + + select { + case <-ctx.Done(): + closeInputFD() + fmt.Println() + return "", ctx.Err() + case result := <-inputResult: + closeInputFD() + if result.err != nil { + return "", fmt.Errorf("reading line from terminal: %w", result.err) + } + if result.value == "" && !allowEmpty { + return "", fmt.Errorf("value cannot be empty") + } + return result.value, nil + } +} + +func readPasswordConfirmed( + ctx context.Context, + prompt, confirmationPrompt string, +) ([]byte, error) { + password, err := readSecret(ctx, prompt, false) + if err != nil { + return nil, err + } + + confirmation, err := readSecret(ctx, confirmationPrompt, false) + if err != nil { + return nil, err + } + + if string(password) != string(confirmation) { + return nil, fmt.Errorf("passwords do not match") + } + + return password, nil +} + +func deriveKey(password, salt []byte) ([]byte, error) { + key, err := scrypt.Key(password, salt, scryptN, scryptR, scryptP, keySize) + if err != nil { + return nil, fmt.Errorf("deriving key with scrypt: %w", err) + } + return key, nil +} + +func encryptData(plaintext, password []byte) ([]byte, error) { + salt := make([]byte, saltSize) + _, err := io.ReadFull(rand.Reader, salt) + if err != nil { + return nil, fmt.Errorf("generating salt: %w", err) + } + + key, err := deriveKey(password, salt) + if err != nil { + return nil, err + } + + block, err := aes.NewCipher(key) + if err != nil { + return nil, fmt.Errorf("creating AES cipher: %w", err) + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, fmt.Errorf("creating GCM: %w", err) + } + + nonce := make([]byte, nonceSize) + _, err = io.ReadFull(rand.Reader, nonce) + if err != nil { + return nil, fmt.Errorf("generating nonce: %w", err) + } + + ciphertext := gcm.Seal(nil, nonce, plaintext, nil) + + result := make([]byte, 0, saltSize+nonceSize+len(ciphertext)) + result = append(result, salt...) + result = append(result, nonce...) + result = append(result, ciphertext...) + + return result, nil +} + +func decryptData(data, password []byte) ([]byte, error) { + const minSize = saltSize + nonceSize + 16 // 16 is the GCM tag size + if len(data) < minSize { + return nil, fmt.Errorf("encrypted data too short: %d bytes", len(data)) + } + + salt := data[:saltSize] + nonce := data[saltSize : saltSize+nonceSize] + ciphertext := data[saltSize+nonceSize:] + + key, err := deriveKey(password, salt) + if err != nil { + return nil, err + } + + block, err := aes.NewCipher(key) + if err != nil { + return nil, fmt.Errorf("creating AES cipher: %w", err) + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, fmt.Errorf("creating GCM: %w", err) + } + + plaintext, err := gcm.Open(nil, nonce, ciphertext, nil) + if err != nil { + return nil, fmt.Errorf("decrypting data (wrong password?): %w", err) + } + + return plaintext, nil +} diff --git a/devrun/internal/runner.go b/devrun/internal/runner.go new file mode 100644 index 00000000..85a60fb7 --- /dev/null +++ b/devrun/internal/runner.go @@ -0,0 +1,351 @@ +package internal + +import ( + "context" + "fmt" + "io" + "os" + "strings" + "time" + + "github.com/docker/docker/api/types/container" + "github.com/docker/docker/api/types/mount" + "github.com/docker/docker/api/types/network" + "github.com/docker/docker/client" + "github.com/docker/docker/pkg/stdcopy" + "github.com/docker/go-connections/nat" + v1 "github.com/opencontainers/image-spec/specs-go/v1" + "golang.org/x/term" +) + +type containerOptions struct { + env []string + binds []string + ports nat.PortMap + dns []string + devices []container.DeviceMapping + labels map[string]string + capAdd []string +} + +// Run decrypts credentials, builds the container environment, and runs a Gluetun container. +// extraArgs is the list of additional flags (e.g. ["-e", "PORT_FORWARDING=on", "-v", "/host:/container"]). +func Run(ctx context.Context, provider, vpnType string, extraArgs []string, + forceKill <-chan struct{}, +) error { + credentials, err := decryptCredentials(ctx) + if err != nil { + return fmt.Errorf("loading credentials: %w", err) + } + + credentialEnvVars, err := lookupCredentials(credentials, provider, vpnType) + if err != nil { + return err + } + + extraOpts, err := parseExtraArgs(extraArgs) + if err != nil { + return fmt.Errorf("parsing extra flags: %w", err) + } + opts := extraOpts + opts.env = append(opts.env, + "VPN_SERVICE_PROVIDER="+provider, + "VPN_TYPE="+vpnType, + "LOG_LEVEL=debug", + ) + opts.env = append(opts.env, credentialEnvVars...) + opts.capAdd = append(opts.capAdd, "NET_ADMIN") + + return runContainer(ctx, opts, forceKill) +} + +// parseExtraArgs parses extra arguments and maps them to container options. +// Supported flags: +// +// -e, --env KEY=VALUE - environment variable +// -v, --volume SPEC - volume mount (e.g., "/host:/container" or "name:/container") +// -p, --publish PORT:PORT - port mapping +// --dns IP - DNS server +// --device SPEC - device access (e.g., "/dev/net/tun") +// --label KEY=VALUE - container label +// --cap-add CAPABILITY - add Linux capability (e.g., "SYS_PTRACE") +func parseExtraArgs(args []string) (opts containerOptions, err error) { //nolint:gocognit,gocyclo + opts = containerOptions{ + ports: make(nat.PortMap), + labels: make(map[string]string), + } + + for i := 0; i < len(args); i++ { + arg := args[i] + switch { + case arg == "-e" || arg == "--env": + if i+1 >= len(args) { + return opts, fmt.Errorf("flag %q requires an argument", arg) + } + i++ + opts.env = append(opts.env, args[i]) + case strings.HasPrefix(arg, "-e="): + opts.env = append(opts.env, strings.TrimPrefix(arg, "-e=")) + case strings.HasPrefix(arg, "--env="): + opts.env = append(opts.env, strings.TrimPrefix(arg, "--env=")) + + case arg == "-v" || arg == "--volume": + if i+1 >= len(args) { + return opts, fmt.Errorf("flag %q requires an argument", arg) + } + i++ + opts.binds = append(opts.binds, args[i]) + case strings.HasPrefix(arg, "-v="): + opts.binds = append(opts.binds, strings.TrimPrefix(arg, "-v=")) + case strings.HasPrefix(arg, "--volume="): + opts.binds = append(opts.binds, strings.TrimPrefix(arg, "--volume=")) + + case arg == "-p" || arg == "--publish": + if i+1 >= len(args) { + return opts, fmt.Errorf("flag %q requires an argument", arg) + } + i++ + if err := parsePortMapping(opts.ports, args[i]); err != nil { + return opts, fmt.Errorf("parsing port mapping: %w", err) + } + case strings.HasPrefix(arg, "-p="): + if err := parsePortMapping(opts.ports, strings.TrimPrefix(arg, "-p=")); err != nil { + return opts, fmt.Errorf("parsing port mapping: %w", err) + } + case strings.HasPrefix(arg, "--publish="): + if err := parsePortMapping(opts.ports, strings.TrimPrefix(arg, "--publish=")); err != nil { + return opts, fmt.Errorf("parsing port mapping: %w", err) + } + + case arg == "--dns": + if i+1 >= len(args) { + return opts, fmt.Errorf("flag %q requires an argument", arg) + } + i++ + opts.dns = append(opts.dns, args[i]) + case strings.HasPrefix(arg, "--dns="): + opts.dns = append(opts.dns, strings.TrimPrefix(arg, "--dns=")) + + case arg == "--device": + if i+1 >= len(args) { + return opts, fmt.Errorf("flag %q requires an argument", arg) + } + i++ + parseDeviceMapping(&opts.devices, args[i]) + case strings.HasPrefix(arg, "--device="): + parseDeviceMapping(&opts.devices, strings.TrimPrefix(arg, "--device=")) + + case arg == "--label": + if i+1 >= len(args) { + return opts, fmt.Errorf("flag %q requires an argument", arg) + } + i++ + parseLabel(opts.labels, args[i]) + case strings.HasPrefix(arg, "--label="): + parseLabel(opts.labels, strings.TrimPrefix(arg, "--label=")) + + case arg == "--cap-add": + if i+1 >= len(args) { + return opts, fmt.Errorf("flag %q requires an argument", arg) + } + i++ + opts.capAdd = append(opts.capAdd, args[i]) + case strings.HasPrefix(arg, "--cap-add="): + opts.capAdd = append(opts.capAdd, strings.TrimPrefix(arg, "--cap-add=")) + + default: + return opts, fmt.Errorf("unsupported flag %q", arg) + } + } + return opts, nil +} + +func parsePortMapping(portMap nat.PortMap, spec string) error { + port, bindings, err := nat.ParsePortSpecs([]string{spec}) + if err != nil { + return err + } + for p, binding := range bindings { + portMap[p] = binding + } + for p := range port { + if _, exists := portMap[p]; !exists { + portMap[p] = []nat.PortBinding{} + } + } + return nil +} + +func parseDeviceMapping(devices *[]container.DeviceMapping, spec string) { + parts := strings.SplitN(spec, ":", 3) //nolint:mnd + pathOnHost := parts[0] + pathInContainer := pathOnHost + permissions := "rwm" + + if len(parts) >= 2 { //nolint:mnd + pathInContainer = parts[1] + } + if len(parts) >= 3 { //nolint:mnd + permissions = parts[2] + } + + *devices = append(*devices, container.DeviceMapping{ + PathOnHost: pathOnHost, + PathInContainer: pathInContainer, + CgroupPermissions: permissions, + }) +} + +func parseLabel(labels map[string]string, kv string) { + parts := strings.SplitN(kv, "=", 2) //nolint:mnd + key := parts[0] + value := "" + if len(parts) > 1 { + value = parts[1] + } + labels[key] = value +} + +func runContainer(ctx context.Context, opts containerOptions, forceKill <-chan struct{}) error { + dockerClient, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation()) + if err != nil { + return fmt.Errorf("creating docker client: %w", err) + } + defer dockerClient.Close() + + hasTTY := term.IsTerminal(int(os.Stdout.Fd())) + + containerConfig := &container.Config{ + Image: "qmcgaw/gluetun", + Env: opts.env, + Labels: opts.labels, + Tty: hasTTY, + } + + mounts := make([]mount.Mount, 0, len(opts.binds)) + for _, bind := range opts.binds { + m, err := parseBindMount(bind) + if err != nil { + return fmt.Errorf("parsing bind mount %q: %w", bind, err) + } + mounts = append(mounts, m) + } + + hostConfig := &container.HostConfig{ + AutoRemove: true, + CapAdd: opts.capAdd, + Binds: opts.binds, + Mounts: mounts, + PortBindings: opts.ports, + DNS: opts.dns, + } + hostConfig.Devices = opts.devices + + networkConfig := &network.NetworkingConfig{} + + platform := (*v1.Platform)(nil) + + const containerName = "gluetun" + response, err := dockerClient.ContainerCreate(ctx, containerConfig, hostConfig, networkConfig, platform, containerName) + if err != nil { + return fmt.Errorf("creating container: %w", err) + } + for _, warning := range response.Warnings { + fmt.Fprintln(os.Stderr, "container creation warning:", warning) + } + containerID := response.ID + + err = dockerClient.ContainerStart(ctx, containerID, container.StartOptions{}) + if err != nil { + return fmt.Errorf("starting container: %w", err) + } + + fmt.Printf("Container started (id: %.12s)\n", containerID) + + streamLogsErr := make(chan error, 1) + go func() { + streamLogsErr <- streamLogs(context.Background(), dockerClient, containerID, hasTTY) + }() + + contextDone := ctx.Done() + forceKillSignal := forceKill + for { + select { + case err := <-streamLogsErr: + if err != nil { + return err + } + return nil + case <-contextDone: + fmt.Fprintln(os.Stderr, "\nReceived interrupt, stopping container (5s timeout)...") + err = stopContainer(dockerClient, containerID) + if err != nil { + fmt.Fprintln(os.Stderr, "stopping container:", err) + } + contextDone = nil + case <-forceKillSignal: + fmt.Fprintln(os.Stderr, "\nReceived second interrupt, killing container...") + err = killContainer(dockerClient, containerID) + if err != nil { + fmt.Fprintln(os.Stderr, "killing container:", err) + } + forceKillSignal = nil + } + } +} + +func parseBindMount(bind string) (mount.Mount, error) { + parts := strings.SplitN(bind, ":", 3) //nolint:mnd + if len(parts) < 2 { //nolint:mnd + return mount.Mount{}, fmt.Errorf("invalid bind mount format: %q (expected source:target[:mode])", bind) + } + + source := parts[0] + target := parts[1] + readOnly := len(parts) > 2 && strings.Contains(parts[2], "ro") //nolint:mnd + + return mount.Mount{ + Type: mount.TypeBind, + Source: source, + Target: target, + ReadOnly: readOnly, + }, nil +} + +func stopContainer(dockerClient *client.Client, containerID string) error { + const stopTimeout = 5 * time.Second + stopCtx, stopCancel := context.WithTimeout(context.Background(), stopTimeout) + defer stopCancel() + timeoutSeconds := int(stopTimeout.Seconds()) + return dockerClient.ContainerStop(stopCtx, containerID, container.StopOptions{Timeout: &timeoutSeconds}) +} + +func killContainer(dockerClient *client.Client, containerID string) error { + return dockerClient.ContainerKill(context.Background(), containerID, "KILL") +} + +func streamLogs(ctx context.Context, dockerClient *client.Client, containerID string, hasTTY bool) error { + logOptions := container.LogsOptions{ + ShowStdout: true, + ShowStderr: true, + Follow: true, + Timestamps: false, + } + + reader, err := dockerClient.ContainerLogs(ctx, containerID, logOptions) + if err != nil { + return fmt.Errorf("getting container logs: %w", err) + } + defer reader.Close() + + if hasTTY { + _, err = io.Copy(os.Stdout, reader) + } else { + _, err = stdcopy.StdCopy(os.Stdout, os.Stderr, reader) + } + if err != nil && err != io.EOF { + return fmt.Errorf("streaming container logs: %w", err) + } + + return nil +}