mirror of
https://github.com/qdm12/gluetun.git
synced 2026-07-04 17:49:51 +02:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 35137cfba0 |
@@ -7,4 +7,3 @@ Dockerfile
|
||||
LICENSE
|
||||
README.md
|
||||
title.svg
|
||||
devrun
|
||||
|
||||
@@ -70,26 +70,6 @@ jobs:
|
||||
- name: Build final image
|
||||
run: docker build -t final-image .
|
||||
|
||||
verify-tools:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
actions: read
|
||||
contents: read
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version-file: ./devrun/go.mod
|
||||
- run: go test ./...
|
||||
working-directory: ./devrun
|
||||
|
||||
- uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version-file: ./ci/go.mod
|
||||
- run: go test ./...
|
||||
working-directory: ./ci
|
||||
|
||||
verify-private:
|
||||
if: |
|
||||
github.repository == 'qdm12/gluetun' &&
|
||||
@@ -98,7 +78,7 @@ jobs:
|
||||
github.event_name == 'release' ||
|
||||
(github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository && github.actor != 'dependabot[bot]')
|
||||
)
|
||||
needs: [ verify ]
|
||||
needs: [verify]
|
||||
runs-on: ubuntu-latest
|
||||
environment: secrets
|
||||
steps:
|
||||
@@ -115,34 +95,10 @@ jobs:
|
||||
run: go build -C ./ci -o runner ./cmd/main.go
|
||||
|
||||
- name: Run Gluetun container with Mullvad configuration
|
||||
run: echo -e "${{ secrets.MULLVAD_WIREGUARD_PRIVATE_KEY }}\n${{
|
||||
secrets.MULLVAD_WIREGUARD_ADDRESS }}" | ./ci/runner mullvad
|
||||
run: echo -e "${{ secrets.MULLVAD_WIREGUARD_PRIVATE_KEY }}\n${{ secrets.MULLVAD_WIREGUARD_ADDRESS }}" | ./ci/runner mullvad
|
||||
|
||||
- name: Run Gluetun container with ProtonVPN Wireguard and port forwarding
|
||||
configuration
|
||||
run: echo -e "${{ secrets.PROTONVPN_WIREGUARD_PRIVATE_KEY }}" | ./ci/runner
|
||||
protonvpn-wireguard-port-forwarding
|
||||
|
||||
- name: Run Gluetun container with ProtonVPN OpenVPN and port forwarding
|
||||
configuration
|
||||
run: echo -e "${{ secrets.PROTONVPN_OPENVPN_USER }}\n${{
|
||||
secrets.PROTONVPN_OPENVPN_PASSWORD }}" | ./ci/runner
|
||||
protonvpn-openvpn-port-forwarding
|
||||
|
||||
- name: Run Gluetun container with Private Internet Access OpenVPN and port
|
||||
forwarding configuration
|
||||
run: echo -e "${{ secrets.PRIVATEINTERNETACCESS_OPENVPN_USER }}\n${{
|
||||
secrets.PRIVATEINTERNETACCESS_OPENVPN_PASSWORD }}" | ./ci/runner
|
||||
private-internet-access-openvpn-port-forwarding
|
||||
|
||||
- name: Run Gluetun container with AirVPN Wireguard configuration
|
||||
run: echo -e "${{ secrets.AIRVPN_WIREGUARD_PRIVATE_KEY }}\n${{
|
||||
secrets.AIRVPN_WIREGUARD_PRESHARED_KEY }}\n${{
|
||||
secrets.AIRVPN_WIREGUARD_ADDRESSES }}" | ./ci/runner airvpn-wireguard
|
||||
|
||||
- name: Run Gluetun container with AirVPN OpenVPN configuration
|
||||
run: echo -e "${{ secrets.AIRVPN_OPENVPN_KEY }}\n${{ secrets.AIRVPN_OPENVPN_CERT
|
||||
}}" | ./ci/runner airvpn-openvpn
|
||||
- name: Run Gluetun container with ProtonVPN configuration
|
||||
run: echo -e "${{ secrets.PROTONVPN_WIREGUARD_PRIVATE_KEY }}" | ./ci/runner protonvpn
|
||||
|
||||
codeql:
|
||||
runs-on: ubuntu-latest
|
||||
@@ -169,7 +125,7 @@ jobs:
|
||||
github.event_name == 'release' ||
|
||||
(github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository && github.actor != 'dependabot[bot]')
|
||||
)
|
||||
needs: [ verify, verify-private, codeql ]
|
||||
needs: [verify, verify-private, codeql]
|
||||
permissions:
|
||||
actions: read
|
||||
contents: read
|
||||
|
||||
@@ -2,9 +2,6 @@
|
||||
"ignorePatterns": [
|
||||
{
|
||||
"pattern": "^https://console.substack.com/p/console-72$"
|
||||
},
|
||||
{
|
||||
"pattern": "^https://github.com/passteque/gluetun$"
|
||||
}
|
||||
],
|
||||
"timeout": "20s",
|
||||
|
||||
@@ -68,9 +68,6 @@ linters:
|
||||
- err113
|
||||
- mnd
|
||||
path: ci\/.+\.go
|
||||
- linters:
|
||||
- err113
|
||||
text: "do not define dynamic errors, use wrapped static errors instead"
|
||||
|
||||
paths:
|
||||
- third_party$
|
||||
|
||||
Vendored
-30
@@ -24,15 +24,6 @@
|
||||
"${input:githubRemoteUsername}",
|
||||
"git@github.com:${input:githubRemoteUsername}/gluetun.git"
|
||||
],
|
||||
},
|
||||
{
|
||||
"label": "Devrun",
|
||||
"type": "shell",
|
||||
"command": "go run ./cmd/main.go run ${input:devrunProvider} ${input:devrunVPNProtocol} ${input:devrunExtraFlags}",
|
||||
"options": {
|
||||
"cwd": "${workspaceFolder}/devrun"
|
||||
},
|
||||
"problemMatcher": []
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
@@ -56,26 +47,5 @@
|
||||
"type": "promptString",
|
||||
"description": "Please enter a Github username",
|
||||
},
|
||||
{
|
||||
"id": "devrunProvider",
|
||||
"type": "promptString",
|
||||
"description": "Please enter a single provider",
|
||||
},
|
||||
{
|
||||
"id": "devrunVPNProtocol",
|
||||
"type": "pickString",
|
||||
"description": "VPN protocol to use",
|
||||
"options": [
|
||||
"wireguard",
|
||||
"openvpn"
|
||||
],
|
||||
"default": "wireguard"
|
||||
},
|
||||
{
|
||||
"id": "devrunExtraFlags",
|
||||
"type": "promptString",
|
||||
"description": "Extra flags (optional)",
|
||||
"default": ""
|
||||
},
|
||||
]
|
||||
}
|
||||
+19
-4
@@ -142,6 +142,23 @@ ENV VPN_SERVICE_PROVIDER=pia \
|
||||
AMNEZIAWG_I3= \
|
||||
AMNEZIAWG_I4= \
|
||||
AMNEZIAWG_I5= \
|
||||
# Wireguard AmneziaWG userspace obfuscation (requires WIREGUARD_IMPLEMENTATION=amneziawg)
|
||||
AMNEZIAWG_JC=0 \
|
||||
AMNEZIAWG_JMIN=0 \
|
||||
AMNEZIAWG_JMAX=0 \
|
||||
AMNEZIAWG_S1=0 \
|
||||
AMNEZIAWG_S2=0 \
|
||||
AMNEZIAWG_S3=0 \
|
||||
AMNEZIAWG_S4=0 \
|
||||
AMNEZIAWG_H1= \
|
||||
AMNEZIAWG_H2= \
|
||||
AMNEZIAWG_H3= \
|
||||
AMNEZIAWG_H4= \
|
||||
AMNEZIAWG_I1= \
|
||||
AMNEZIAWG_I2= \
|
||||
AMNEZIAWG_I3= \
|
||||
AMNEZIAWG_I4= \
|
||||
AMNEZIAWG_I5= \
|
||||
# VPN server port forwarding
|
||||
VPN_PORT_FORWARDING=off \
|
||||
VPN_PORT_FORWARDING_PROVIDER= \
|
||||
@@ -209,7 +226,6 @@ ENV VPN_SERVICE_PROVIDER=pia \
|
||||
HEALTH_SMALL_CHECK_TYPE=icmp \
|
||||
HEALTH_RESTART_VPN=on \
|
||||
# DNS
|
||||
DNS_SERVER=on \
|
||||
DNS_UPSTREAM_RESOLVER_TYPE=DoT \
|
||||
# Note: DNS_UPSTREAM_RESOLVERS defaults to cloudflare in code if DNS_UPSTREAM_PLAIN_ADDRESSES is empty
|
||||
DNS_UPSTREAM_RESOLVERS= \
|
||||
@@ -249,7 +265,6 @@ ENV VPN_SERVICE_PROVIDER=pia \
|
||||
UPDATER_PERIOD=0 \
|
||||
UPDATER_MIN_RATIO=0.8 \
|
||||
UPDATER_VPN_SERVICE_PROVIDERS= \
|
||||
UPDATER_PREFER_DIRECT_DOWNLOAD=no \
|
||||
UPDATER_PROTONVPN_EMAIL= \
|
||||
UPDATER_PROTONVPN_PASSWORD= \
|
||||
# Public IP
|
||||
@@ -258,8 +273,7 @@ ENV VPN_SERVICE_PROVIDER=pia \
|
||||
PUBLICIP_API=ipinfo,ifconfigco,ip2location,cloudflare \
|
||||
PUBLICIP_API_TOKEN= \
|
||||
# Storage
|
||||
STORAGE_SERVERS_ENABLED=on \
|
||||
STORAGE_SERVERS_DIRECTORY_PATH=/gluetun/servers/ \
|
||||
STORAGE_FILEPATH=/gluetun/servers.json \
|
||||
# Pprof
|
||||
PPROF_ENABLED=no \
|
||||
PPROF_BLOCK_PROFILE_RATE=0 \
|
||||
@@ -267,6 +281,7 @@ ENV VPN_SERVICE_PROVIDER=pia \
|
||||
PPROF_HTTP_SERVER_ADDRESS=":6060" \
|
||||
# Extras
|
||||
VERSION_INFORMATION=on \
|
||||
BORINGPOLL_GLUETUNCOM=off \
|
||||
TZ= \
|
||||
PUID=1000 \
|
||||
PGID=1000
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
# Gluetun VPN client
|
||||
|
||||
Lightweight swiss-army-knife-like VPN client to multiple VPN service providers
|
||||
|
||||
⚠️ This and [gluetun-wiki](https://github.com/qdm12/gluetun-wiki) are the only websites for Gluetun, other websites claiming to be official are scams ⚠️
|
||||
|
||||
🗯️ this repository will be migrated to [github.com/passteque/gluetun](https://github.com/passteque/gluetun) on 2026-05-21, which is a Github organization under my sole control, so don't get alarmed if you get redirected in the coming days 😉 Reason being migrating Github sponsors to the Open source collective due to my personal situation, basically annoying paperwork. On the plus side, it will be more transparent and funds donated will only be used for the project. The Docker image names will remain the same.
|
||||
💁 You can optionally set `BORINGPOLL_GLUETUNCOM=on` to... [poll](./internal/boringpoll/boringpoll.go) that **scammy AI slop** website every few minutes so it costs them too much to keep it up. My gentle email reminders to take it down are being grossly ignored 🤷 This would make me very happy and serve this community.
|
||||
|
||||
Lightweight swiss-army-knife-like VPN client to multiple VPN service providers
|
||||
|
||||

|
||||
|
||||
|
||||
+2
-10
@@ -23,16 +23,8 @@ func main() {
|
||||
switch os.Args[1] {
|
||||
case "mullvad":
|
||||
err = internal.MullvadTest(ctx, logger)
|
||||
case "protonvpn-wireguard-port-forwarding":
|
||||
err = internal.ProtonVPNWireguardPortForwardingTest(ctx, logger)
|
||||
case "protonvpn-openvpn-port-forwarding":
|
||||
err = internal.ProtonVPNOpenVPNPortForwardingTest(ctx, logger)
|
||||
case "private-internet-access-openvpn-port-forwarding":
|
||||
err = internal.PrivateInternetAccessOpenVPNPortForwardingTest(ctx, logger)
|
||||
case "airvpn-wireguard":
|
||||
err = internal.AirVPNWireguardTest(ctx, logger)
|
||||
case "airvpn-openvpn":
|
||||
err = internal.AirVPNOpenVPNTest(ctx, logger)
|
||||
case "protonvpn":
|
||||
err = internal.ProtonVPNTest(ctx, logger)
|
||||
default:
|
||||
err = fmt.Errorf("unknown command: %s", os.Args[1])
|
||||
}
|
||||
|
||||
@@ -1,54 +0,0 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"time"
|
||||
)
|
||||
|
||||
func AirVPNWireguardTest(ctx context.Context, logger Logger) error {
|
||||
expectedSecrets := []string{
|
||||
"Wireguard private key",
|
||||
"Wireguard preshared key",
|
||||
"Wireguard addresses",
|
||||
}
|
||||
secrets, err := readSecrets(ctx, expectedSecrets, logger)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading secrets: %w", err)
|
||||
}
|
||||
|
||||
env := []string{
|
||||
"VPN_SERVICE_PROVIDER=airvpn",
|
||||
"VPN_TYPE=wireguard",
|
||||
"LOG_LEVEL=debug",
|
||||
"SERVER_COUNTRIES=United States",
|
||||
"WIREGUARD_PRIVATE_KEY=" + secrets[0],
|
||||
"WIREGUARD_PRESHARED_KEY=" + secrets[1],
|
||||
"WIREGUARD_ADDRESSES=" + secrets[2],
|
||||
}
|
||||
const timeout = 60 * time.Second
|
||||
return runContainerTest(ctx, env, []*regexp.Regexp{successRegexp}, timeout, logger)
|
||||
}
|
||||
|
||||
func AirVPNOpenVPNTest(ctx context.Context, logger Logger) error {
|
||||
expectedSecrets := []string{
|
||||
"OpenVPN key",
|
||||
"OpenVPN cert",
|
||||
}
|
||||
secrets, err := readSecrets(ctx, expectedSecrets, logger)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading secrets: %w", err)
|
||||
}
|
||||
|
||||
env := []string{
|
||||
"VPN_SERVICE_PROVIDER=airvpn",
|
||||
"VPN_TYPE=openvpn",
|
||||
"LOG_LEVEL=debug",
|
||||
"SERVER_COUNTRIES=United States",
|
||||
"OPENVPN_KEY=" + secrets[0],
|
||||
"OPENVPN_CERT=" + secrets[1],
|
||||
}
|
||||
const timeout = 60 * time.Second
|
||||
return runContainerTest(ctx, env, []*regexp.Regexp{successRegexp}, timeout, logger)
|
||||
}
|
||||
@@ -3,8 +3,6 @@ package internal
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"time"
|
||||
)
|
||||
|
||||
func MullvadTest(ctx context.Context, logger Logger) error {
|
||||
@@ -25,6 +23,5 @@ func MullvadTest(ctx context.Context, logger Logger) error {
|
||||
"WIREGUARD_PRIVATE_KEY=" + secrets[0],
|
||||
"WIREGUARD_ADDRESSES=" + secrets[1],
|
||||
}
|
||||
const timeout = 60 * time.Second
|
||||
return runContainerTest(ctx, env, []*regexp.Regexp{successRegexp}, timeout, logger)
|
||||
return simpleTest(ctx, env, logger)
|
||||
}
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"time"
|
||||
)
|
||||
|
||||
func PrivateInternetAccessOpenVPNPortForwardingTest(ctx context.Context, logger Logger) error {
|
||||
expectedSecrets := []string{
|
||||
"OpenVPN username",
|
||||
"OpenVPN password",
|
||||
}
|
||||
secrets, err := readSecrets(ctx, expectedSecrets, logger)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading secrets: %w", err)
|
||||
}
|
||||
|
||||
env := []string{
|
||||
"VPN_SERVICE_PROVIDER=private internet access",
|
||||
"VPN_TYPE=openvpn",
|
||||
"LOG_LEVEL=debug",
|
||||
"SERVER_REGIONS=CA Montreal",
|
||||
"OPENVPN_USER=" + secrets[0],
|
||||
"OPENVPN_PASSWORD=" + secrets[1],
|
||||
"VPN_PORT_FORWARDING=on",
|
||||
}
|
||||
const timeout = 80 * time.Second
|
||||
return runContainerTest(ctx, env, []*regexp.Regexp{successRegexp, portForwardingRegexp}, timeout, logger)
|
||||
}
|
||||
@@ -3,11 +3,9 @@ package internal
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"time"
|
||||
)
|
||||
|
||||
func ProtonVPNWireguardPortForwardingTest(ctx context.Context, logger Logger) error {
|
||||
func ProtonVPNTest(ctx context.Context, logger Logger) error {
|
||||
expectedSecrets := []string{
|
||||
"Wireguard private key",
|
||||
}
|
||||
@@ -22,31 +20,6 @@ func ProtonVPNWireguardPortForwardingTest(ctx context.Context, logger Logger) er
|
||||
"LOG_LEVEL=debug",
|
||||
"SERVER_COUNTRIES=United States",
|
||||
"WIREGUARD_PRIVATE_KEY=" + secrets[0],
|
||||
"VPN_PORT_FORWARDING=on",
|
||||
}
|
||||
const timeout = 80 * time.Second
|
||||
return runContainerTest(ctx, env, []*regexp.Regexp{successRegexp, portForwardingRegexp}, timeout, logger)
|
||||
}
|
||||
|
||||
func ProtonVPNOpenVPNPortForwardingTest(ctx context.Context, logger Logger) error {
|
||||
expectedSecrets := []string{
|
||||
"OpenVPN username",
|
||||
"OpenVPN password",
|
||||
}
|
||||
secrets, err := readSecrets(ctx, expectedSecrets, logger)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading secrets: %w", err)
|
||||
}
|
||||
|
||||
env := []string{
|
||||
"VPN_SERVICE_PROVIDER=protonvpn",
|
||||
"VPN_TYPE=openvpn",
|
||||
"LOG_LEVEL=debug",
|
||||
"SERVER_COUNTRIES=United States",
|
||||
"OPENVPN_USER=" + secrets[0],
|
||||
"OPENVPN_PASSWORD=" + secrets[1],
|
||||
"VPN_PORT_FORWARDING=on",
|
||||
}
|
||||
const timeout = 80 * time.Second
|
||||
return runContainerTest(ctx, env, []*regexp.Regexp{successRegexp, portForwardingRegexp}, timeout, logger)
|
||||
return simpleTest(ctx, env, logger)
|
||||
}
|
||||
|
||||
+13
-23
@@ -16,14 +16,8 @@ import (
|
||||
|
||||
func ptrTo[T any](v T) *T { return &v }
|
||||
|
||||
var (
|
||||
successRegexp = regexp.MustCompile(`^.+Public IP address is .+$`)
|
||||
portForwardingRegexp = regexp.MustCompile(`port forwarded is \d`)
|
||||
)
|
||||
|
||||
func runContainerTest(ctx context.Context, env []string,
|
||||
regexps []*regexp.Regexp, timeout time.Duration, logger Logger,
|
||||
) error {
|
||||
func simpleTest(ctx context.Context, env []string, logger Logger) error {
|
||||
const timeout = 60 * time.Second
|
||||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
@@ -63,7 +57,7 @@ func runContainerTest(ctx context.Context, env []string,
|
||||
return fmt.Errorf("starting container: %w", err)
|
||||
}
|
||||
|
||||
return waitForLogLines(ctx, client, containerID, beforeStartTime, regexps, logger)
|
||||
return waitForLogLine(ctx, client, containerID, beforeStartTime, logger)
|
||||
}
|
||||
|
||||
func stopContainer(client *client.Client, containerID string) {
|
||||
@@ -77,8 +71,10 @@ func stopContainer(client *client.Client, containerID string) {
|
||||
}
|
||||
}
|
||||
|
||||
func waitForLogLines(ctx context.Context, client *client.Client, containerID string,
|
||||
beforeStartTime time.Time, regexps []*regexp.Regexp, logger Logger,
|
||||
var successRegexp = regexp.MustCompile(`^.+Public IP address is .+$`)
|
||||
|
||||
func waitForLogLine(ctx context.Context, client *client.Client, containerID string,
|
||||
beforeStartTime time.Time, logger Logger,
|
||||
) error {
|
||||
logOptions := container.LogsOptions{
|
||||
ShowStdout: true,
|
||||
@@ -92,8 +88,6 @@ func waitForLogLines(ctx context.Context, client *client.Client, containerID str
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
regexpMatched := 0
|
||||
|
||||
var linesSeen []string
|
||||
scanner := bufio.NewScanner(reader)
|
||||
for ctx.Err() == nil {
|
||||
@@ -103,25 +97,21 @@ func waitForLogLines(ctx context.Context, client *client.Client, containerID str
|
||||
line = line[8:]
|
||||
}
|
||||
linesSeen = append(linesSeen, line)
|
||||
regex := regexps[regexpMatched]
|
||||
if regex.MatchString(line) {
|
||||
fmt.Println("✅ Expected line logged:", line)
|
||||
if regexpMatched == len(regexps)-1 {
|
||||
return nil
|
||||
}
|
||||
regexpMatched++
|
||||
if successRegexp.MatchString(line) {
|
||||
fmt.Println("✅ Success line logged")
|
||||
return nil
|
||||
}
|
||||
continue
|
||||
}
|
||||
err := scanner.Err()
|
||||
if err != nil && err != io.EOF {
|
||||
logSeenLines(linesSeen)
|
||||
logSeenLines(logger, linesSeen)
|
||||
return fmt.Errorf("reading log stream: %w", err)
|
||||
}
|
||||
|
||||
// The scanner is either done or cannot read because of EOF
|
||||
logger.Info("the log scanner stopped")
|
||||
logSeenLines(linesSeen)
|
||||
logSeenLines(logger, linesSeen)
|
||||
|
||||
// Check if the container is still running
|
||||
inspect, err := client.ContainerInspect(ctx, containerID)
|
||||
@@ -136,7 +126,7 @@ func waitForLogLines(ctx context.Context, client *client.Client, containerID str
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
func logSeenLines(lines []string) {
|
||||
func logSeenLines(logger Logger, lines []string) {
|
||||
fmt.Println("Logs seen so far:")
|
||||
for _, line := range lines {
|
||||
fmt.Println(" " + line)
|
||||
|
||||
+29
-7
@@ -2,6 +2,7 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
@@ -42,6 +43,7 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/server"
|
||||
"github.com/qdm12/gluetun/internal/shadowsocks"
|
||||
"github.com/qdm12/gluetun/internal/storage"
|
||||
"github.com/qdm12/gluetun/internal/tun"
|
||||
updater "github.com/qdm12/gluetun/internal/updater/loop"
|
||||
"github.com/qdm12/gluetun/internal/updater/resolver"
|
||||
"github.com/qdm12/gluetun/internal/updater/unzip"
|
||||
@@ -78,6 +80,7 @@ func main() {
|
||||
logger := log.New(log.SetLevel(log.LevelInfo))
|
||||
|
||||
args := os.Args
|
||||
tun := tun.New()
|
||||
netLinkDebugLogger := logger.New(log.SetComponent("netlink"))
|
||||
netLinker := netlink.New(netLinkDebugLogger)
|
||||
cli := cli.New()
|
||||
@@ -97,7 +100,7 @@ func main() {
|
||||
|
||||
errorCh := make(chan error)
|
||||
go func() {
|
||||
errorCh <- _main(ctx, buildInfo, args, logger, reader, netLinker, cmder, cli)
|
||||
errorCh <- _main(ctx, buildInfo, args, logger, reader, tun, netLinker, cmder, cli)
|
||||
}()
|
||||
|
||||
// Wait for OS signal or run error
|
||||
@@ -139,10 +142,12 @@ func main() {
|
||||
}
|
||||
}
|
||||
|
||||
var errCommandUnknown = errors.New("command is unknown")
|
||||
|
||||
//nolint:gocognit,gocyclo,maintidx
|
||||
func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
args []string, logger log.LoggerInterface, reader *reader.Reader,
|
||||
netLinker netLinker, cmder RunStarter,
|
||||
tun Tun, netLinker netLinker, cmder RunStarter,
|
||||
cli clier,
|
||||
) error {
|
||||
if len(args) > 1 { // cli operation
|
||||
@@ -160,13 +165,13 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
case "genkey":
|
||||
return cli.GenKey(args[2:])
|
||||
default:
|
||||
return fmt.Errorf("command is unknown: %s", args[1])
|
||||
return fmt.Errorf("%w: %s", errCommandUnknown, args[1])
|
||||
}
|
||||
}
|
||||
|
||||
defer fmt.Println(gluetunLogo)
|
||||
|
||||
announcementExp, err := time.Parse(time.RFC3339, "2026-06-30T00:00:00Z")
|
||||
announcementExp, err := time.Parse(time.RFC3339, "2026-04-30T00:00:00Z")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -177,7 +182,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
Version: buildInfo.Version,
|
||||
Commit: buildInfo.Commit,
|
||||
Created: buildInfo.Created,
|
||||
Announcement: "Your servers data files are now migrated to /gluetun/servers/",
|
||||
Announcement: "Set BORINGPOLL_GLUETUNCOM=on to help combat AI slop and shutdown that scam website",
|
||||
AnnounceExp: announcementExp,
|
||||
// Sponsor information
|
||||
PaypalUser: "qmcgaw",
|
||||
@@ -240,8 +245,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
|
||||
// TODO run this in a loop or in openvpn to reload from file without restarting
|
||||
storageLogger := logger.New(log.SetComponent("storage"))
|
||||
storage, err := storage.New(storageLogger, *allSettings.Storage.ServersEnabled,
|
||||
allSettings.Storage.ServersPath, allSettings.Storage.LegacyServersFilepath)
|
||||
storage, err := storage.New(storageLogger, *allSettings.Storage.Filepath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -339,6 +343,19 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
return fmt.Errorf("adding local rules: %w", err)
|
||||
}
|
||||
|
||||
const tunDevice = "/dev/net/tun"
|
||||
err = tun.Check(tunDevice)
|
||||
if err != nil {
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
return fmt.Errorf("checking TUN device: %w (see the Wiki errors/tun page)", err)
|
||||
}
|
||||
logger.Info(err.Error() + "; creating it...")
|
||||
err = tun.Create(tunDevice)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating tun device: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, port := range allSettings.Firewall.InputPorts {
|
||||
for _, defaultRoute := range defaultRoutes {
|
||||
err = firewallConf.SetAllowedPort(ctx, port, defaultRoute.NetInterface)
|
||||
@@ -610,6 +627,11 @@ type clier interface {
|
||||
GenKey(args []string) error
|
||||
}
|
||||
|
||||
type Tun interface {
|
||||
Check(tunDevice string) error
|
||||
Create(tunDevice string) error
|
||||
}
|
||||
|
||||
type RunStarter interface {
|
||||
Run(cmd *exec.Cmd) (output string, err error)
|
||||
Start(cmd *exec.Cmd) (stdoutLines, stderrLines <-chan string,
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
credentials
|
||||
@@ -1,152 +0,0 @@
|
||||
# devrun
|
||||
|
||||
`devrun` is a small development helper for starting a local `qmcgaw/gluetun` Docker container with provider credentials stored in an encrypted file.
|
||||
|
||||
It solves two practical problems for local development:
|
||||
|
||||
- keeping VPN credentials out of the shell history and out of a plaintext file once setup is complete;
|
||||
- quickly starting a Gluetun container for a specific provider and VPN type with a small set of extra Docker runtime options.
|
||||
|
||||
The tool has four commands:
|
||||
|
||||
- `add-cred`: add or replace credentials for one provider and one VPN type in the encrypted store `credentials`;
|
||||
- `delete-cred`: remove credentials for one provider and one VPN type from the encrypted store `credentials`;
|
||||
- `dump-cred`: print credentials for one provider and one VPN type from the encrypted store `credentials`;
|
||||
- `run`: decrypt credentials on demand, build the required Gluetun environment variables, and run a `qmcgaw/gluetun` container.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- Go installed locally
|
||||
- Docker installed and a daemon available to the Docker client
|
||||
- an interactive terminal, since the tool prompts for passwords without echoing them
|
||||
|
||||
The Docker client is created from the standard Docker environment, so settings such as `DOCKER_HOST` are honored.
|
||||
|
||||
## Quick start
|
||||
|
||||
### Add credentials
|
||||
|
||||
Add one credential entry to the encrypted store:
|
||||
|
||||
```sh
|
||||
go run ./cmd/main.go add-cred protonvpn openvpn
|
||||
go run ./cmd/main.go add-cred mullvad wireguard
|
||||
```
|
||||
|
||||
Behavior:
|
||||
|
||||
- if `credentials` does not exist yet, `add-cred` asks for a new credentials password and creates the encrypted store;
|
||||
- if `credentials` already exists, `add-cred` asks for the existing password first, decrypts the store, updates it, and writes it back encrypted;
|
||||
- sensitive fields are read from stdin without echo.
|
||||
|
||||
Prompted values depend on the VPN type:
|
||||
|
||||
- `openvpn`: username and password
|
||||
- `wireguard`: private key, optional address, optional preshared key
|
||||
|
||||
Running `add-cred` again for the same provider and VPN type replaces the existing values for that entry.
|
||||
|
||||
### Delete credentials
|
||||
|
||||
Remove one credential entry from the encrypted store:
|
||||
|
||||
```sh
|
||||
go run ./cmd/main.go delete-cred protonvpn openvpn
|
||||
```
|
||||
|
||||
This asks for the credentials password first, decrypts the store, removes the requested provider and VPN type, and writes the store back encrypted.
|
||||
|
||||
### Dump credentials
|
||||
|
||||
Print one credential entry from the encrypted store:
|
||||
|
||||
```sh
|
||||
go run ./cmd/main.go dump-cred protonvpn openvpn
|
||||
```
|
||||
|
||||
This asks for the credentials password first and then prints the selected provider and VPN type values.
|
||||
|
||||
### Container run
|
||||
|
||||
Run a container using the image `qmcgaw/gluetun` and the encrypted credentials with the `run` command.
|
||||
For example:
|
||||
|
||||
```sh
|
||||
go run ./cmd/main.go run mullvad wireguard
|
||||
go run ./cmd/main.go run protonvpn wireguard -e PORT_FORWARDING=on -p 8000:8000/tcp
|
||||
```
|
||||
|
||||
You will be prompted for the credentials password, the file `credentials` will be decrypted in memory, and the container will be started.
|
||||
|
||||
The following environment variables are always added by the tool:
|
||||
|
||||
- `VPN_SERVICE_PROVIDER=<provider>`
|
||||
- `VPN_TYPE=<vpn-type>`
|
||||
- `LOG_LEVEL=debug`
|
||||
|
||||
The tool also adds `NET_ADMIN` to the container capabilities by default.
|
||||
|
||||
## Credential model
|
||||
|
||||
Internally, the encrypted file stores a binary-encoded map keyed by provider name. Each provider can define `openvpn`, `wireguard`, or both.
|
||||
|
||||
Conceptually, the stored data looks like this:
|
||||
|
||||
- provider `mullvad`: contains `wireguard`
|
||||
- provider `protonvpn`: contains `wireguard`
|
||||
- provider `protonvpn`: contains `openvpn`
|
||||
|
||||
You do not edit this directly. It is stored as encrypted binary data in `credentials`.
|
||||
|
||||
### OpenVPN fields
|
||||
|
||||
- `username` is required;
|
||||
- `password` is required;
|
||||
|
||||
At runtime these map to:
|
||||
|
||||
- `OPENVPN_USER`
|
||||
- `OPENVPN_PASSWORD`
|
||||
|
||||
### WireGuard fields
|
||||
|
||||
- `private_key` is required and must be a valid WireGuard private key;
|
||||
- `address` is optional and must be a valid network prefix if set;
|
||||
- `preshared_key` is optional and must be a valid WireGuard key if set.
|
||||
|
||||
At runtime these map to:
|
||||
|
||||
- `WIREGUARD_PRIVATE_KEY`
|
||||
- `WIREGUARD_ADDRESSES` when `address` is set
|
||||
- `WIREGUARD_PRESHARED_KEY` when `preshared_key` is set
|
||||
|
||||
## Supported extra Docker flags
|
||||
|
||||
The `run` command only accepts a focused subset of Docker-style runtime flags. Unsupported flags return an error.
|
||||
|
||||
Supported flags:
|
||||
|
||||
- `-e`, `--env KEY=VALUE`
|
||||
- `-v`, `--volume SOURCE:TARGET[:mode]`
|
||||
- `-p`, `--publish HOSTPORT:CONTAINERPORT[/proto]`
|
||||
- `--dns IP`
|
||||
- `--device SPEC`
|
||||
- `--label KEY=VALUE`
|
||||
- `--cap-add CAPABILITY`
|
||||
|
||||
## Signals and shutdown
|
||||
|
||||
While the container is running:
|
||||
|
||||
- the first `Ctrl+C` requests a graceful stop with a 5 second timeout;
|
||||
- the second `Ctrl+C` sends a kill signal to the container;
|
||||
- a further interrupt exits the tool immediately.
|
||||
|
||||
## Notes and limitations
|
||||
|
||||
- The container image is fixed to `qmcgaw/gluetun`.
|
||||
- The container name is fixed to `gluetun`.
|
||||
- Credentials are decrypted in memory only during execution.
|
||||
- If the requested provider or VPN type is not present in the encrypted credentials file, the command fails with an explicit error.
|
||||
- The encrypted credential store file is named `credentials`.
|
||||
- This tool is intended for local development convenience, not as a general replacement for `docker run`.
|
||||
@@ -1,156 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
|
||||
"github.com/qdm12/gluetun/devrun/internal"
|
||||
)
|
||||
|
||||
func main() {
|
||||
const minArgs = 2
|
||||
if len(os.Args) < minArgs {
|
||||
printUsage()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
switch os.Args[1] {
|
||||
case "add-cred":
|
||||
const addCredMinArgs = 4
|
||||
if len(os.Args) < addCredMinArgs {
|
||||
fmt.Fprintf(os.Stderr,
|
||||
`Usage: %s add-cred <provider> <vpn-type>
|
||||
Example: %s add-cred protonvpn wireguard`, os.Args[0], os.Args[0])
|
||||
os.Exit(1)
|
||||
}
|
||||
provider := os.Args[2]
|
||||
vpnType := os.Args[3]
|
||||
err := runWithSignals(func(ctx context.Context, _ <-chan struct{}) error {
|
||||
return internal.AddCredential(ctx, provider, vpnType)
|
||||
})
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "add-cred failed:", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
case "delete-cred":
|
||||
const deleteCredMinArgs = 4
|
||||
if len(os.Args) < deleteCredMinArgs {
|
||||
fmt.Fprintf(os.Stderr,
|
||||
`Usage: %s delete-cred <provider> <vpn-type>
|
||||
Example: %s delete-cred protonvpn openvpn`, os.Args[0], os.Args[0])
|
||||
os.Exit(1)
|
||||
}
|
||||
provider := os.Args[2]
|
||||
vpnType := os.Args[3]
|
||||
err := runWithSignals(func(ctx context.Context, _ <-chan struct{}) error {
|
||||
return internal.DeleteCredential(ctx, provider, vpnType)
|
||||
})
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "delete-cred failed:", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
case "dump-cred":
|
||||
const dumpCredMinArgs = 4
|
||||
if len(os.Args) < dumpCredMinArgs {
|
||||
fmt.Fprintf(os.Stderr,
|
||||
`Usage: %s dump-cred <provider> <vpn-type>
|
||||
Example: %s dump-cred protonvpn wireguard`, os.Args[0], os.Args[0])
|
||||
os.Exit(1)
|
||||
}
|
||||
provider := os.Args[2]
|
||||
vpnType := os.Args[3]
|
||||
err := runWithSignals(func(ctx context.Context, _ <-chan struct{}) error {
|
||||
return internal.DumpCredential(ctx, provider, vpnType)
|
||||
})
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "dump-cred failed:", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
case "run":
|
||||
const runMinArgs = 4
|
||||
if len(os.Args) < runMinArgs {
|
||||
fmt.Fprintf(os.Stderr,
|
||||
`Usage: %s run <provider> <vpn-type> [extra docker flags...]
|
||||
Example: %s run mullvad wireguard -e SERVER_COUNTRIES=USA`, os.Args[0], os.Args[0])
|
||||
os.Exit(1)
|
||||
}
|
||||
provider := os.Args[2]
|
||||
vpnType := os.Args[3]
|
||||
extraArgs := os.Args[4:]
|
||||
err := runWithSignals(func(ctx context.Context, forceKill <-chan struct{}) error {
|
||||
return internal.Run(ctx, provider, vpnType, extraArgs, forceKill)
|
||||
})
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "run failed:", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
default:
|
||||
fmt.Fprintln(os.Stderr, "unknown command:", os.Args[1])
|
||||
printUsage()
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func printUsage() {
|
||||
fmt.Fprintf(os.Stderr, `Usage: %s <command> [args...]
|
||||
|
||||
Commands:
|
||||
add-cred <provider> <vpn-type>
|
||||
Add or replace credentials in the encrypted credentials store.
|
||||
delete-cred <provider> <vpn-type>
|
||||
Delete credentials from the encrypted credentials store.
|
||||
dump-cred <provider> <vpn-type>
|
||||
Print credentials for a provider and VPN type pair.
|
||||
run <provider> <vpn-type> [flags...]
|
||||
Decrypt credentials and run a Gluetun container.
|
||||
Extra flags (e.g. -e PORT_FORWARDING=on) are passed to docker run.`,
|
||||
os.Args[0])
|
||||
}
|
||||
|
||||
func runWithSignals(runFn func(ctx context.Context, forceKill <-chan struct{}) error) error {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
const signalBufferSize = 3
|
||||
sigCh := make(chan os.Signal, signalBufferSize)
|
||||
signal.Notify(sigCh, os.Interrupt)
|
||||
defer signal.Stop(sigCh)
|
||||
|
||||
forceKill := make(chan struct{})
|
||||
stopSignalLoop := make(chan struct{})
|
||||
signalLoopDone := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer close(signalLoopDone)
|
||||
|
||||
const secondInterrupt = 2
|
||||
interruptCount := uint(0)
|
||||
forceKillSent := false
|
||||
for {
|
||||
select {
|
||||
case <-stopSignalLoop:
|
||||
return
|
||||
case <-sigCh:
|
||||
interruptCount++
|
||||
switch interruptCount {
|
||||
case 1:
|
||||
cancel()
|
||||
case secondInterrupt:
|
||||
if !forceKillSent {
|
||||
close(forceKill)
|
||||
forceKillSent = true
|
||||
}
|
||||
default:
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
err := runFn(ctx, forceKill)
|
||||
close(stopSignalLoop)
|
||||
<-signalLoopDone
|
||||
return err
|
||||
}
|
||||
@@ -1,40 +0,0 @@
|
||||
module github.com/qdm12/gluetun/devrun
|
||||
|
||||
go 1.25.0
|
||||
|
||||
require (
|
||||
github.com/docker/docker v28.5.2+incompatible
|
||||
github.com/docker/go-connections v0.7.0
|
||||
github.com/opencontainers/image-spec v1.1.1
|
||||
golang.org/x/crypto v0.50.0
|
||||
golang.org/x/term v0.42.0
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/Microsoft/go-winio v0.6.2 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/containerd/errdefs v1.0.0 // indirect
|
||||
github.com/containerd/errdefs/pkg v0.3.0 // indirect
|
||||
github.com/containerd/log v0.1.0 // indirect
|
||||
github.com/distribution/reference v0.6.0 // indirect
|
||||
github.com/docker/go-units v0.5.0 // indirect
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/go-logr/logr v1.4.3 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/moby/docker-image-spec v1.3.1 // indirect
|
||||
github.com/moby/sys/atomicwriter v0.1.0 // indirect
|
||||
github.com/moby/term v0.5.2 // indirect
|
||||
github.com/morikuni/aec v1.1.0 // indirect
|
||||
github.com/opencontainers/go-digest v1.0.0 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.68.0 // indirect
|
||||
go.opentelemetry.io/otel v1.43.0 // indirect
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.43.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.43.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.43.0 // indirect
|
||||
golang.org/x/sys v0.43.0 // indirect
|
||||
golang.org/x/time v0.15.0 // indirect
|
||||
gotest.tools/v3 v3.5.2 // indirect
|
||||
)
|
||||
-105
@@ -1,105 +0,0 @@
|
||||
github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c h1:udKWzYgxTojEKWjV8V+WSxDXJ4NFATAsZjh8iIbsQIg=
|
||||
github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
|
||||
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
|
||||
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
||||
github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM=
|
||||
github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
|
||||
github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M=
|
||||
github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE=
|
||||
github.com/containerd/errdefs/pkg v0.3.0/go.mod h1:NJw6s9HwNuRhnjJhM7pylWwMyAkmCQvQ4GpJHEqRLVk=
|
||||
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
|
||||
github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
||||
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
|
||||
github.com/docker/docker v28.5.2+incompatible h1:DBX0Y0zAjZbSrm1uzOkdr1onVghKaftjlSWt4AFexzM=
|
||||
github.com/docker/docker v28.5.2+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
|
||||
github.com/docker/go-connections v0.7.0 h1:6SsRfJddP22WMrCkj19x9WKjEDTB+ahsdiGYf0mN39c=
|
||||
github.com/docker/go-connections v0.7.0/go.mod h1:no1qkHdjq7kLMGUXYAduOhYPSJxxvgWBh7ogVvptn3Q=
|
||||
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
|
||||
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
|
||||
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
|
||||
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
|
||||
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 h1:HWRh5R2+9EifMyIHV7ZV+MIZqgz+PMpZ14Jynv3O2Zs=
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0/go.mod h1:JfhWUomR1baixubs02l85lZYYOm7LV6om4ceouMv45c=
|
||||
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
|
||||
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
|
||||
github.com/moby/sys/atomicwriter v0.1.0 h1:kw5D/EqkBwsBFi0ss9v1VG3wIkVhzGvLklJ+w3A14Sw=
|
||||
github.com/moby/sys/atomicwriter v0.1.0/go.mod h1:Ul8oqv2ZMNHOceF643P6FKPXeCmYtlQMvpizfsSoaWs=
|
||||
github.com/moby/sys/sequential v0.6.0 h1:qrx7XFUd/5DxtqcoH1h438hF5TmOvzC/lspjy7zgvCU=
|
||||
github.com/moby/sys/sequential v0.6.0/go.mod h1:uyv8EUTrca5PnDsdMGXhZe6CCe8U/UiTWd+lL+7b/Ko=
|
||||
github.com/moby/term v0.5.2 h1:6qk3FJAFDs6i/q3W/pQ97SX192qKfZgGjCQqfCJkgzQ=
|
||||
github.com/moby/term v0.5.2/go.mod h1:d3djjFCrjnB+fl8NJux+EJzu0msscUP+f8it8hPkFLc=
|
||||
github.com/morikuni/aec v1.1.0 h1:vBBl0pUnvi/Je71dsRrhMBtreIqNMYErSAbEeb8jrXQ=
|
||||
github.com/morikuni/aec v1.1.0/go.mod h1:xDRgiq/iw5l+zkao76YTKzKttOp2cwPEne25HDkJnBw=
|
||||
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
||||
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
||||
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
||||
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
|
||||
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.68.0 h1:CqXxU8VOmDefoh0+ztfGaymYbhdB/tT3zs79QaZTNGY=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.68.0/go.mod h1:BuhAPThV8PBHBvg8ZzZ/Ok3idOdhWIodywz2xEcRbJo=
|
||||
go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I=
|
||||
go.opentelemetry.io/otel v1.43.0/go.mod h1:JuG+u74mvjvcm8vj8pI5XiHy1zDeoCS2LB1spIq7Ay0=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.43.0 h1:88Y4s2C8oTui1LGM6bTWkw0ICGcOLCAI5l6zsD1j20k=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.43.0/go.mod h1:Vl1/iaggsuRlrHf/hfPJPvVag77kKyvrLeD10kpMl+A=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.43.0 h1:3iZJKlCZufyRzPzlQhUIWVmfltrXuGyfjREgGP3UUjc=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.43.0/go.mod h1:/G+nUPfhq2e+qiXMGxMwumDrP5jtzU+mWN7/sjT2rak=
|
||||
go.opentelemetry.io/otel/metric v1.43.0 h1:d7638QeInOnuwOONPp4JAOGfbCEpYb+K6DVWvdxGzgM=
|
||||
go.opentelemetry.io/otel/metric v1.43.0/go.mod h1:RDnPtIxvqlgO8GRW18W6Z/4P462ldprJtfxHxyKd2PY=
|
||||
go.opentelemetry.io/otel/sdk v1.43.0 h1:pi5mE86i5rTeLXqoF/hhiBtUNcrAGHLKQdhg4h4V9Dg=
|
||||
go.opentelemetry.io/otel/sdk v1.43.0/go.mod h1:P+IkVU3iWukmiit/Yf9AWvpyRDlUeBaRg6Y+C58QHzg=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.43.0 h1:S88dyqXjJkuBNLeMcVPRFXpRw2fuwdvfCGLEo89fDkw=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.43.0/go.mod h1:C/RJtwSEJ5hzTiUz5pXF1kILHStzb9zFlIEe85bhj6A=
|
||||
go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A=
|
||||
go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0=
|
||||
go.opentelemetry.io/proto/otlp v1.10.0 h1:IQRWgT5srOCYfiWnpqUYz9CVmbO8bFmKcwYxpuCSL2g=
|
||||
go.opentelemetry.io/proto/otlp v1.10.0/go.mod h1:/CV4QoCR/S9yaPj8utp3lvQPoqMtxXdzn7ozvvozVqk=
|
||||
golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI=
|
||||
golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q=
|
||||
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
|
||||
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
|
||||
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
|
||||
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY=
|
||||
golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY=
|
||||
golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg=
|
||||
golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=
|
||||
golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U=
|
||||
golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno=
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdIg1ozBNLgPy4SLT84nfcBjr6rhGtXYtrkWLU=
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 h1:VPWxll4HlMw1Vs/qXtN7BvhZqsS9cdAittCNvVENElA=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:7QBABkRtR8z+TEnmXTqIqwJLlzrZKVfAUm7tY3yGv0M=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 h1:m8qni9SQFH0tJc1X0vmnpw/0t+AImlSvp30sEupozUg=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
|
||||
google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM=
|
||||
google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4=
|
||||
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
||||
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
|
||||
gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA=
|
||||
@@ -1,251 +0,0 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/gob"
|
||||
"fmt"
|
||||
"maps"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
const credentialsFilename = "credentials"
|
||||
|
||||
const (
|
||||
vpnTypeOpenVPN = "openvpn"
|
||||
vpnTypeWireGuard = "wireguard"
|
||||
)
|
||||
|
||||
type providerCredentials struct {
|
||||
OpenVPN *openvpnCredentials
|
||||
WireGuard *wireguardCredentials
|
||||
}
|
||||
|
||||
type openvpnCredentials struct {
|
||||
Username string
|
||||
Password string
|
||||
Key string
|
||||
Cert string
|
||||
}
|
||||
|
||||
type wireguardCredentials struct {
|
||||
PrivateKey string
|
||||
Address string
|
||||
PresharedKey string
|
||||
}
|
||||
|
||||
func loadCredentials(data []byte) (map[string]providerCredentials, error) {
|
||||
credentials := make(map[string]providerCredentials)
|
||||
err := gob.NewDecoder(bytes.NewReader(data)).Decode(&credentials)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decoding credentials: %w", err)
|
||||
}
|
||||
return credentials, nil
|
||||
}
|
||||
|
||||
func marshalCredentials(credentials map[string]providerCredentials) ([]byte, error) {
|
||||
buffer := bytes.NewBuffer(nil)
|
||||
err := gob.NewEncoder(buffer).Encode(credentials)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("encoding credentials: %w", err)
|
||||
}
|
||||
return buffer.Bytes(), nil
|
||||
}
|
||||
|
||||
func validateCredentials(providerNameToCredentials map[string]providerCredentials) error {
|
||||
for provider, credentials := range providerNameToCredentials {
|
||||
if credentials.OpenVPN == nil && credentials.WireGuard == nil {
|
||||
return fmt.Errorf("provider %q has no openvpn or wireguard credentials", provider)
|
||||
}
|
||||
if credentials.OpenVPN != nil {
|
||||
err := validateOpenvpnCredentials(provider, credentials.OpenVPN)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if credentials.WireGuard != nil {
|
||||
err := validateWireguardCredentials(provider, credentials.WireGuard)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateOpenvpnCredentials(provider string, creds *openvpnCredentials) error {
|
||||
switch {
|
||||
case creds.Username == "" && creds.Password != "":
|
||||
return fmt.Errorf("provider %q openvpn credentials are missing the username", provider)
|
||||
case creds.Password == "" && creds.Username != "":
|
||||
return fmt.Errorf("provider %q openvpn credentials are missing the password", provider)
|
||||
case creds.Username == "" && creds.Password == "" && creds.Key == "" && creds.Cert == "":
|
||||
return fmt.Errorf("provider %q openvpn credentials are missing the username and password", provider)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateWireguardCredentials(provider string, creds *wireguardCredentials) error {
|
||||
if creds.PrivateKey == "" {
|
||||
return fmt.Errorf("provider %q wireguard credentials are missing the private key", provider)
|
||||
} else if _, err := wgtypes.ParseKey(creds.PrivateKey); err != nil {
|
||||
return fmt.Errorf("provider %q wireguard credentials have an invalid private key: %w", provider, err)
|
||||
}
|
||||
|
||||
if creds.Address != "" {
|
||||
_, err := netip.ParsePrefix(creds.Address)
|
||||
if err != nil {
|
||||
return fmt.Errorf("provider %q wireguard credentials have an invalid address %q: %w", provider, creds.Address, err)
|
||||
}
|
||||
}
|
||||
|
||||
if creds.PresharedKey != "" {
|
||||
if _, err := wgtypes.ParseKey(creds.PresharedKey); err != nil {
|
||||
return fmt.Errorf("provider %q wireguard credentials have an invalid preshared key: %w", provider, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func lookupCredentials(credentials map[string]providerCredentials, provider, vpnType string) ([]string, error) {
|
||||
providerCreds, exists := credentials[provider]
|
||||
if !exists {
|
||||
existing := slices.Collect(maps.Keys(credentials))
|
||||
return nil, fmt.Errorf("no credentials found for provider %q, available providers are: %s",
|
||||
provider, strings.Join(existing, ", "))
|
||||
}
|
||||
|
||||
switch vpnType {
|
||||
case vpnTypeWireGuard:
|
||||
if providerCreds.WireGuard == nil {
|
||||
return nil, fmt.Errorf("no wireguard credentials found for provider %q", provider)
|
||||
}
|
||||
return buildWireGuardEnv(providerCreds.WireGuard), nil
|
||||
case vpnTypeOpenVPN:
|
||||
if providerCreds.OpenVPN == nil {
|
||||
return nil, fmt.Errorf("no openvpn credentials found for provider %q", provider)
|
||||
}
|
||||
return buildOpenvpnEnv(providerCreds.OpenVPN), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown vpn type %q, must be wireguard or openvpn", vpnType)
|
||||
}
|
||||
}
|
||||
|
||||
func buildWireGuardEnv(creds *wireguardCredentials) []string {
|
||||
envVars := []string{
|
||||
"WIREGUARD_PRIVATE_KEY=" + creds.PrivateKey,
|
||||
}
|
||||
if creds.Address != "" {
|
||||
envVars = append(envVars, "WIREGUARD_ADDRESSES="+creds.Address)
|
||||
}
|
||||
if creds.PresharedKey != "" {
|
||||
envVars = append(envVars, "WIREGUARD_PRESHARED_KEY="+creds.PresharedKey)
|
||||
}
|
||||
return envVars
|
||||
}
|
||||
|
||||
func buildOpenvpnEnv(creds *openvpnCredentials) []string {
|
||||
return []string{
|
||||
"OPENVPN_USER=" + creds.Username,
|
||||
"OPENVPN_PASSWORD=" + creds.Password,
|
||||
"OPENVPN_KEY=" + creds.Key,
|
||||
"OPENVPN_CERT=" + creds.Cert,
|
||||
}
|
||||
}
|
||||
|
||||
func addCredential(credentials map[string]providerCredentials, provider, vpnType string,
|
||||
openvpnCredentials *openvpnCredentials, wireguardCredentials *wireguardCredentials,
|
||||
) error {
|
||||
providerCredentials := credentials[provider]
|
||||
|
||||
switch vpnType {
|
||||
case vpnTypeOpenVPN:
|
||||
providerCredentials.OpenVPN = openvpnCredentials
|
||||
case vpnTypeWireGuard:
|
||||
providerCredentials.WireGuard = wireguardCredentials
|
||||
default:
|
||||
return fmt.Errorf("unknown vpn type %q, must be wireguard or openvpn", vpnType)
|
||||
}
|
||||
|
||||
credentials[provider] = providerCredentials
|
||||
return nil
|
||||
}
|
||||
|
||||
func deleteCredential(credentials map[string]providerCredentials, provider, vpnType string) error {
|
||||
providerCredentials, exists := credentials[provider]
|
||||
if !exists {
|
||||
return fmt.Errorf("provider %q does not exist", provider)
|
||||
}
|
||||
|
||||
switch vpnType {
|
||||
case vpnTypeOpenVPN:
|
||||
if providerCredentials.OpenVPN == nil {
|
||||
return fmt.Errorf("provider %q has no openvpn credentials", provider)
|
||||
}
|
||||
providerCredentials.OpenVPN = nil
|
||||
case vpnTypeWireGuard:
|
||||
if providerCredentials.WireGuard == nil {
|
||||
return fmt.Errorf("provider %q has no wireguard credentials", provider)
|
||||
}
|
||||
providerCredentials.WireGuard = nil
|
||||
default:
|
||||
return fmt.Errorf("unknown vpn type %q, must be wireguard or openvpn", vpnType)
|
||||
}
|
||||
|
||||
if providerCredentials.OpenVPN == nil && providerCredentials.WireGuard == nil {
|
||||
delete(credentials, provider)
|
||||
return nil
|
||||
}
|
||||
|
||||
credentials[provider] = providerCredentials
|
||||
return nil
|
||||
}
|
||||
|
||||
func formatCredentialForDump(provider, vpnType string,
|
||||
providerCredentials providerCredentials,
|
||||
) (output string, err error) {
|
||||
var builder strings.Builder
|
||||
|
||||
builder.WriteString("provider: ")
|
||||
builder.WriteString(provider)
|
||||
builder.WriteString("\n")
|
||||
builder.WriteString("vpn_type: ")
|
||||
builder.WriteString(vpnType)
|
||||
builder.WriteString("\n")
|
||||
|
||||
switch vpnType {
|
||||
case vpnTypeOpenVPN:
|
||||
if providerCredentials.OpenVPN == nil {
|
||||
return "", fmt.Errorf("no openvpn credentials found for provider %q", provider)
|
||||
}
|
||||
builder.WriteString("username: ")
|
||||
builder.WriteString(providerCredentials.OpenVPN.Username)
|
||||
builder.WriteString("\n")
|
||||
builder.WriteString("password: ")
|
||||
builder.WriteString(providerCredentials.OpenVPN.Password)
|
||||
builder.WriteString("\nkey: ")
|
||||
builder.WriteString(providerCredentials.OpenVPN.Key)
|
||||
builder.WriteString("\ncert: ")
|
||||
builder.WriteString(providerCredentials.OpenVPN.Cert)
|
||||
builder.WriteString("\n")
|
||||
case vpnTypeWireGuard:
|
||||
if providerCredentials.WireGuard == nil {
|
||||
return "", fmt.Errorf("no wireguard credentials found for provider %q", provider)
|
||||
}
|
||||
builder.WriteString("private_key: ")
|
||||
builder.WriteString(providerCredentials.WireGuard.PrivateKey)
|
||||
builder.WriteString("\n")
|
||||
builder.WriteString("address: ")
|
||||
builder.WriteString(providerCredentials.WireGuard.Address)
|
||||
builder.WriteString("\n")
|
||||
builder.WriteString("preshared_key: ")
|
||||
builder.WriteString(providerCredentials.WireGuard.PresharedKey)
|
||||
default:
|
||||
return "", fmt.Errorf("unknown vpn type %q, must be wireguard or openvpn", vpnType)
|
||||
}
|
||||
|
||||
return builder.String(), nil
|
||||
}
|
||||
@@ -1,350 +0,0 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
func Test_addCredential(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
wireguardPrivateKey, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
testCases := map[string]struct {
|
||||
initialCredentials map[string]providerCredentials
|
||||
provider string
|
||||
vpnType string
|
||||
openvpnCredentials *openvpnCredentials
|
||||
wireguardCreds *wireguardCredentials
|
||||
expectedLength int
|
||||
expectedOpenVPN bool
|
||||
expectedWireGuard bool
|
||||
}{
|
||||
"adds_openvpn_credentials": {
|
||||
initialCredentials: map[string]providerCredentials{},
|
||||
provider: "protonvpn",
|
||||
vpnType: "openvpn",
|
||||
openvpnCredentials: &openvpnCredentials{Username: "user", Password: "pass"},
|
||||
expectedLength: 1,
|
||||
expectedOpenVPN: true,
|
||||
},
|
||||
"adds_wireguard_credentials": {
|
||||
initialCredentials: map[string]providerCredentials{},
|
||||
provider: "mullvad",
|
||||
vpnType: "wireguard",
|
||||
wireguardCreds: &wireguardCredentials{
|
||||
PrivateKey: wireguardPrivateKey.String(),
|
||||
Address: "10.0.0.2/32",
|
||||
},
|
||||
expectedLength: 1,
|
||||
expectedWireGuard: true,
|
||||
},
|
||||
"preserves_other_protocol": {
|
||||
initialCredentials: map[string]providerCredentials{
|
||||
"protonvpn": {
|
||||
WireGuard: &wireguardCredentials{PrivateKey: wireguardPrivateKey.String()},
|
||||
},
|
||||
},
|
||||
provider: "protonvpn",
|
||||
vpnType: "openvpn",
|
||||
openvpnCredentials: &openvpnCredentials{Username: "user", Password: "pass"},
|
||||
expectedLength: 1,
|
||||
expectedOpenVPN: true,
|
||||
expectedWireGuard: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
credentials := cloneCredentials(testCase.initialCredentials)
|
||||
|
||||
err := addCredential(credentials, testCase.provider, testCase.vpnType,
|
||||
testCase.openvpnCredentials, testCase.wireguardCreds)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
providerCredentials := credentials[testCase.provider]
|
||||
if len(credentials) != testCase.expectedLength {
|
||||
t.Fatalf("expected %d providers, got %d", testCase.expectedLength, len(credentials))
|
||||
}
|
||||
if (providerCredentials.OpenVPN != nil) != testCase.expectedOpenVPN {
|
||||
t.Fatalf("expected openvpn presence %t, got %t", testCase.expectedOpenVPN, providerCredentials.OpenVPN != nil)
|
||||
}
|
||||
if (providerCredentials.WireGuard != nil) != testCase.expectedWireGuard {
|
||||
t.Fatalf("expected wireguard presence %t, got %t", testCase.expectedWireGuard, providerCredentials.WireGuard != nil)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_deleteCredential(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
wireguardPrivateKey, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
testCases := map[string]struct {
|
||||
initialCredentials map[string]providerCredentials
|
||||
provider string
|
||||
vpnType string
|
||||
expectedLength int
|
||||
expectedOpenVPN bool
|
||||
expectedWireGuard bool
|
||||
}{
|
||||
"deletes_openvpn_only": {
|
||||
initialCredentials: map[string]providerCredentials{
|
||||
"protonvpn": {
|
||||
OpenVPN: &openvpnCredentials{Username: "user", Password: "pass"},
|
||||
WireGuard: &wireguardCredentials{PrivateKey: wireguardPrivateKey.String()},
|
||||
},
|
||||
},
|
||||
provider: "protonvpn",
|
||||
vpnType: "openvpn",
|
||||
expectedLength: 1,
|
||||
expectedWireGuard: true,
|
||||
},
|
||||
"deletes_last_protocol_and_provider": {
|
||||
initialCredentials: map[string]providerCredentials{
|
||||
"protonvpn": {
|
||||
OpenVPN: &openvpnCredentials{Username: "user", Password: "pass"},
|
||||
},
|
||||
},
|
||||
provider: "protonvpn",
|
||||
vpnType: "openvpn",
|
||||
expectedLength: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
credentials := cloneCredentials(testCase.initialCredentials)
|
||||
|
||||
err := deleteCredential(credentials, testCase.provider, testCase.vpnType)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(credentials) != testCase.expectedLength {
|
||||
t.Fatalf("expected %d providers, got %d", testCase.expectedLength, len(credentials))
|
||||
}
|
||||
|
||||
providerCredentials, exists := credentials[testCase.provider]
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
if (providerCredentials.OpenVPN != nil) != testCase.expectedOpenVPN {
|
||||
t.Fatalf("expected openvpn presence %t, got %t", testCase.expectedOpenVPN, providerCredentials.OpenVPN != nil)
|
||||
}
|
||||
if (providerCredentials.WireGuard != nil) != testCase.expectedWireGuard {
|
||||
t.Fatalf("expected wireguard presence %t, got %t", testCase.expectedWireGuard, providerCredentials.WireGuard != nil)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_validateCredentials(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
wireguardPrivateKey, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
testCases := map[string]struct {
|
||||
credentials map[string]providerCredentials
|
||||
wantError bool
|
||||
}{
|
||||
"both_protocols_valid": {
|
||||
credentials: map[string]providerCredentials{
|
||||
"protonvpn": {
|
||||
OpenVPN: &openvpnCredentials{Username: "user", Password: "pass"},
|
||||
WireGuard: &wireguardCredentials{PrivateKey: wireguardPrivateKey.String()},
|
||||
},
|
||||
},
|
||||
},
|
||||
"invalid_wireguard_when_both_present": {
|
||||
credentials: map[string]providerCredentials{
|
||||
"protonvpn": {
|
||||
OpenVPN: &openvpnCredentials{Username: "user", Password: "pass"},
|
||||
WireGuard: &wireguardCredentials{PrivateKey: "invalid"},
|
||||
},
|
||||
},
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err := validateCredentials(testCase.credentials)
|
||||
if testCase.wantError && err == nil {
|
||||
t.Fatal("expected an error but got nil")
|
||||
}
|
||||
if !testCase.wantError && err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_marshalLoadCredentials(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
wireguardPrivateKey, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
credentials := map[string]providerCredentials{
|
||||
"mullvad": {
|
||||
WireGuard: &wireguardCredentials{
|
||||
PrivateKey: wireguardPrivateKey.String(),
|
||||
Address: "10.0.0.2/32",
|
||||
},
|
||||
},
|
||||
"protonvpn": {
|
||||
OpenVPN: &openvpnCredentials{
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
encoded, err := marshalCredentials(credentials)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected marshal error: %v", err)
|
||||
}
|
||||
|
||||
decoded, err := loadCredentials(encoded)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected load error: %v", err)
|
||||
}
|
||||
|
||||
if len(decoded) != len(credentials) {
|
||||
t.Fatalf("expected %d providers, got %d", len(credentials), len(decoded))
|
||||
}
|
||||
|
||||
if decoded["mullvad"].WireGuard == nil {
|
||||
t.Fatal("expected mullvad wireguard credentials to be present")
|
||||
}
|
||||
if decoded["protonvpn"].OpenVPN == nil {
|
||||
t.Fatal("expected protonvpn openvpn credentials to be present")
|
||||
}
|
||||
if decoded["protonvpn"].OpenVPN.Password != "pass" {
|
||||
t.Fatalf("expected protonvpn password %q, got %q", "pass", decoded["protonvpn"].OpenVPN.Password)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_formatCredentialForDump(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
provider string
|
||||
vpnType string
|
||||
providerCredentials providerCredentials
|
||||
expectedOutput string
|
||||
wantError bool
|
||||
}{
|
||||
"openvpn": {
|
||||
provider: "protonvpn",
|
||||
vpnType: vpnTypeOpenVPN,
|
||||
providerCredentials: providerCredentials{
|
||||
OpenVPN: &openvpnCredentials{
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
Key: "key",
|
||||
Cert: "cert",
|
||||
},
|
||||
},
|
||||
expectedOutput: "provider: protonvpn\n" +
|
||||
"vpn_type: openvpn\n" +
|
||||
"username: user\n" +
|
||||
"password: pass\n" +
|
||||
"key: key\n" +
|
||||
"cert: cert\n",
|
||||
},
|
||||
"wireguard": {
|
||||
provider: "mullvad",
|
||||
vpnType: vpnTypeWireGuard,
|
||||
providerCredentials: providerCredentials{
|
||||
WireGuard: &wireguardCredentials{
|
||||
PrivateKey: "private",
|
||||
Address: "10.0.0.2/32",
|
||||
PresharedKey: "preshared",
|
||||
},
|
||||
},
|
||||
expectedOutput: "provider: mullvad\n" +
|
||||
"vpn_type: wireguard\n" +
|
||||
"private_key: private\n" +
|
||||
"address: 10.0.0.2/32\n" +
|
||||
"preshared_key: preshared",
|
||||
},
|
||||
"missing_protocol": {
|
||||
provider: "protonvpn",
|
||||
vpnType: vpnTypeOpenVPN,
|
||||
wantError: true,
|
||||
},
|
||||
"unknown_protocol": {
|
||||
provider: "protonvpn",
|
||||
vpnType: "other",
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
output, err := formatCredentialForDump(
|
||||
testCase.provider,
|
||||
testCase.vpnType,
|
||||
testCase.providerCredentials,
|
||||
)
|
||||
|
||||
if testCase.wantError {
|
||||
if err == nil {
|
||||
t.Fatal("expected an error but got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if output != testCase.expectedOutput {
|
||||
t.Fatalf("expected output %q, got %q", testCase.expectedOutput, output)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func cloneCredentials(credentials map[string]providerCredentials) map[string]providerCredentials {
|
||||
clone := make(map[string]providerCredentials, len(credentials))
|
||||
for provider, providerCredentials := range credentials {
|
||||
copied := providerCredentials
|
||||
if providerCredentials.OpenVPN != nil {
|
||||
openvpnCredentials := *providerCredentials.OpenVPN
|
||||
copied.OpenVPN = &openvpnCredentials
|
||||
}
|
||||
if providerCredentials.WireGuard != nil {
|
||||
wireguardCredentials := *providerCredentials.WireGuard
|
||||
copied.WireGuard = &wireguardCredentials
|
||||
}
|
||||
clone[provider] = copied
|
||||
}
|
||||
return clone
|
||||
}
|
||||
@@ -1,533 +0,0 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"io"
|
||||
"maps"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/crypto/scrypt"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
// Encryption format: [16-byte salt][12-byte nonce][AES-256-GCM ciphertext+tag]
|
||||
// Key derivation: scrypt(password, salt, N=32768, r=8, p=1, keyLen=32)
|
||||
|
||||
const (
|
||||
saltSize = 16
|
||||
nonceSize = 12
|
||||
keySize = 32
|
||||
scryptN = 32768
|
||||
scryptR = 8
|
||||
scryptP = 1
|
||||
)
|
||||
|
||||
// AddCredential prompts for credential values and stores them in the encrypted credentials file.
|
||||
func AddCredential(ctx context.Context, provider, vpnType string) error {
|
||||
credentials, password, err := loadCredentialsForMutation(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = promptAndAddCredential(ctx, credentials, provider, vpnType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = validateCredentials(credentials)
|
||||
if err != nil {
|
||||
return fmt.Errorf("validating credentials: %w", err)
|
||||
}
|
||||
|
||||
err = writeEncryptedCredentials(credentials, password)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Printf(
|
||||
"Credentials for provider %q and vpn type %q saved to %s\n",
|
||||
provider, vpnType, credentialsFilename,
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteCredential removes credentials for a provider and VPN type
|
||||
// from the encrypted credentials file.
|
||||
func DeleteCredential(ctx context.Context, provider, vpnType string) error {
|
||||
credentials, password, err := loadExistingCredentialsForMutation(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = deleteCredential(credentials, provider, vpnType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = writeEncryptedCredentials(credentials, password)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Printf(
|
||||
"Credentials for provider %q and vpn type %q removed from %s\n",
|
||||
provider, vpnType, credentialsFilename,
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
// DumpCredential decrypts the credential store and prints one provider/vpn-type entry.
|
||||
func DumpCredential(ctx context.Context, provider, vpnType string) error {
|
||||
credentials, err := decryptCredentials(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
providerCredentials, exists := credentials[provider]
|
||||
if !exists {
|
||||
existingProviders := slices.Collect(maps.Keys(credentials))
|
||||
return fmt.Errorf("provider %q does not exist, available providers are: %s",
|
||||
provider, strings.Join(existingProviders, ", "))
|
||||
}
|
||||
|
||||
output, err := formatCredentialForDump(provider, vpnType, providerCredentials)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Println(output)
|
||||
return nil
|
||||
}
|
||||
|
||||
// decryptCredentials reads the encrypted credentials file,
|
||||
// prompts for a password, and returns the decrypted credentials.
|
||||
func decryptCredentials(ctx context.Context) (map[string]providerCredentials, error) {
|
||||
password, err := readSecret(ctx, "Enter credentials password: ", false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading password: %w", err)
|
||||
}
|
||||
|
||||
plaintext, err := decryptCredentialsFile(password)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
credentials, err := loadCredentials(plaintext)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading credentials: %w", err)
|
||||
}
|
||||
|
||||
return credentials, nil
|
||||
}
|
||||
|
||||
func loadCredentialsForMutation(ctx context.Context) (
|
||||
credentials map[string]providerCredentials,
|
||||
password []byte,
|
||||
err error,
|
||||
) {
|
||||
_, err = os.Stat(credentialsFilename)
|
||||
if os.IsNotExist(err) {
|
||||
password, err = readPasswordConfirmed(ctx,
|
||||
"Enter new credentials password: ",
|
||||
"Confirm new credentials password: ",
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("reading password: %w", err)
|
||||
}
|
||||
return make(map[string]providerCredentials), password, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("stating %s: %w", credentialsFilename, err)
|
||||
}
|
||||
|
||||
password, err = readSecret(ctx, "Enter credentials password: ", false)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("reading password: %w", err)
|
||||
}
|
||||
|
||||
plaintext, err := decryptCredentialsFile(password)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
credentials, err = loadCredentials(plaintext)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("loading credentials: %w", err)
|
||||
}
|
||||
|
||||
return credentials, password, nil
|
||||
}
|
||||
|
||||
func loadExistingCredentialsForMutation(ctx context.Context) (
|
||||
credentials map[string]providerCredentials,
|
||||
password []byte,
|
||||
err error,
|
||||
) {
|
||||
_, err = os.Stat(credentialsFilename)
|
||||
if os.IsNotExist(err) {
|
||||
return nil, nil, fmt.Errorf("%s does not exist", credentialsFilename)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("stating %s: %w", credentialsFilename, err)
|
||||
}
|
||||
|
||||
password, err = readSecret(ctx, "Enter credentials password: ", false)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("reading password: %w", err)
|
||||
}
|
||||
|
||||
plaintext, err := decryptCredentialsFile(password)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
credentials, err = loadCredentials(plaintext)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("loading credentials: %w", err)
|
||||
}
|
||||
|
||||
return credentials, password, nil
|
||||
}
|
||||
|
||||
func promptAndAddCredential(
|
||||
ctx context.Context,
|
||||
credentials map[string]providerCredentials,
|
||||
provider, vpnType string,
|
||||
) error {
|
||||
switch vpnType {
|
||||
case vpnTypeOpenVPN:
|
||||
username, err := readLine(ctx, "OpenVPN username: ", true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading username: %w", err)
|
||||
}
|
||||
|
||||
password, err := readSecret(ctx, "OpenVPN password: ", username == "")
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading password: %w", err)
|
||||
}
|
||||
|
||||
key, err := readSecret(ctx, "OpenVPN key: ", true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading key: %w", err)
|
||||
}
|
||||
|
||||
cert, err := readSecret(ctx, "OpenVPN cert: ", true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading cert: %w", err)
|
||||
}
|
||||
|
||||
openvpnCredentials := &openvpnCredentials{
|
||||
Username: username,
|
||||
Password: string(password),
|
||||
Key: string(key),
|
||||
Cert: string(cert),
|
||||
}
|
||||
err = validateOpenvpnCredentials(provider, openvpnCredentials)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return addCredential(credentials, provider, vpnType, openvpnCredentials, nil)
|
||||
|
||||
case vpnTypeWireGuard:
|
||||
privateKey, err := readSecret(ctx, "WireGuard private key: ", false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading private key: %w", err)
|
||||
}
|
||||
|
||||
address, err := readLine(ctx, "WireGuard address (optional): ", true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading address: %w", err)
|
||||
}
|
||||
|
||||
presharedKey, err := readSecret(
|
||||
ctx,
|
||||
"WireGuard preshared key (optional): ",
|
||||
true,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading preshared key: %w", err)
|
||||
}
|
||||
|
||||
wireguardCredentials := &wireguardCredentials{
|
||||
PrivateKey: string(privateKey),
|
||||
Address: address,
|
||||
PresharedKey: string(presharedKey),
|
||||
}
|
||||
err = validateWireguardCredentials(provider, wireguardCredentials)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return addCredential(credentials, provider, vpnType, nil, wireguardCredentials)
|
||||
|
||||
default:
|
||||
return fmt.Errorf("unknown vpn type %q, must be wireguard or openvpn", vpnType)
|
||||
}
|
||||
}
|
||||
|
||||
func writeEncryptedCredentials(
|
||||
credentials map[string]providerCredentials,
|
||||
password []byte,
|
||||
) error {
|
||||
plaintext, err := marshalCredentials(credentials)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encoding credentials: %w", err)
|
||||
}
|
||||
|
||||
encrypted, err := encryptData(plaintext, password)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encrypting credentials: %w", err)
|
||||
}
|
||||
|
||||
const filePerms = 0o600
|
||||
err = os.WriteFile(credentialsFilename, encrypted, filePerms)
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing %s: %w", credentialsFilename, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func decryptCredentialsFile(password []byte) ([]byte, error) {
|
||||
encryptedData, err := os.ReadFile(credentialsFilename)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading %s: %w", credentialsFilename, err)
|
||||
}
|
||||
|
||||
plaintext, err := decryptData(encryptedData, password)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decrypting credentials: %w", err)
|
||||
}
|
||||
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
func readSecret(ctx context.Context, prompt string, allowEmpty bool) ([]byte, error) {
|
||||
fmt.Print(prompt)
|
||||
|
||||
passwordFD, err := syscall.Dup(syscall.Stdin)
|
||||
if err != nil {
|
||||
fmt.Println()
|
||||
return nil, fmt.Errorf("duplicating stdin file descriptor: %w", err)
|
||||
}
|
||||
|
||||
var closeFDOnce sync.Once
|
||||
closePasswordFD := func() {
|
||||
closeFDOnce.Do(func() {
|
||||
_ = syscall.Close(passwordFD)
|
||||
})
|
||||
}
|
||||
|
||||
passwordResult := make(chan struct {
|
||||
password []byte
|
||||
err error
|
||||
})
|
||||
|
||||
go func() {
|
||||
password, err := term.ReadPassword(passwordFD)
|
||||
closePasswordFD()
|
||||
result := struct {
|
||||
password []byte
|
||||
err error
|
||||
}{
|
||||
password: password,
|
||||
err: err,
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case passwordResult <- result:
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
closePasswordFD()
|
||||
fmt.Println()
|
||||
return nil, ctx.Err()
|
||||
case result := <-passwordResult:
|
||||
closePasswordFD()
|
||||
fmt.Println()
|
||||
if result.err != nil {
|
||||
return nil, fmt.Errorf("reading hidden input from terminal: %w", result.err)
|
||||
}
|
||||
if len(result.password) == 0 && !allowEmpty {
|
||||
return nil, fmt.Errorf("value cannot be empty")
|
||||
}
|
||||
return result.password, nil
|
||||
}
|
||||
}
|
||||
|
||||
func readLine(ctx context.Context, prompt string, allowEmpty bool) (string, error) {
|
||||
fmt.Print(prompt)
|
||||
|
||||
inputFD, err := syscall.Dup(syscall.Stdin)
|
||||
if err != nil {
|
||||
fmt.Println()
|
||||
return "", fmt.Errorf("duplicating stdin file descriptor: %w", err)
|
||||
}
|
||||
|
||||
var closeFDOnce sync.Once
|
||||
closeInputFD := func() {
|
||||
closeFDOnce.Do(func() {
|
||||
_ = syscall.Close(inputFD)
|
||||
})
|
||||
}
|
||||
|
||||
inputResult := make(chan struct {
|
||||
value string
|
||||
err error
|
||||
})
|
||||
|
||||
go func() {
|
||||
inputFile := os.NewFile(uintptr(inputFD), "stdin")
|
||||
reader := bufio.NewReader(inputFile)
|
||||
value, err := reader.ReadString('\n')
|
||||
closeInputFD()
|
||||
value = strings.TrimRight(value, "\r\n")
|
||||
if err == io.EOF {
|
||||
err = nil
|
||||
}
|
||||
|
||||
result := struct {
|
||||
value string
|
||||
err error
|
||||
}{
|
||||
value: value,
|
||||
err: err,
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case inputResult <- result:
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
closeInputFD()
|
||||
fmt.Println()
|
||||
return "", ctx.Err()
|
||||
case result := <-inputResult:
|
||||
closeInputFD()
|
||||
if result.err != nil {
|
||||
return "", fmt.Errorf("reading line from terminal: %w", result.err)
|
||||
}
|
||||
if result.value == "" && !allowEmpty {
|
||||
return "", fmt.Errorf("value cannot be empty")
|
||||
}
|
||||
return result.value, nil
|
||||
}
|
||||
}
|
||||
|
||||
func readPasswordConfirmed(
|
||||
ctx context.Context,
|
||||
prompt, confirmationPrompt string,
|
||||
) ([]byte, error) {
|
||||
password, err := readSecret(ctx, prompt, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
confirmation, err := readSecret(ctx, confirmationPrompt, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if string(password) != string(confirmation) {
|
||||
return nil, fmt.Errorf("passwords do not match")
|
||||
}
|
||||
|
||||
return password, nil
|
||||
}
|
||||
|
||||
func deriveKey(password, salt []byte) ([]byte, error) {
|
||||
key, err := scrypt.Key(password, salt, scryptN, scryptR, scryptP, keySize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("deriving key with scrypt: %w", err)
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func encryptData(plaintext, password []byte) ([]byte, error) {
|
||||
salt := make([]byte, saltSize)
|
||||
_, err := io.ReadFull(rand.Reader, salt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generating salt: %w", err)
|
||||
}
|
||||
|
||||
key, err := deriveKey(password, salt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating AES cipher: %w", err)
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating GCM: %w", err)
|
||||
}
|
||||
|
||||
nonce := make([]byte, nonceSize)
|
||||
_, err = io.ReadFull(rand.Reader, nonce)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generating nonce: %w", err)
|
||||
}
|
||||
|
||||
ciphertext := gcm.Seal(nil, nonce, plaintext, nil)
|
||||
|
||||
result := make([]byte, 0, saltSize+nonceSize+len(ciphertext))
|
||||
result = append(result, salt...)
|
||||
result = append(result, nonce...)
|
||||
result = append(result, ciphertext...)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func decryptData(data, password []byte) ([]byte, error) {
|
||||
const minSize = saltSize + nonceSize + 16 // 16 is the GCM tag size
|
||||
if len(data) < minSize {
|
||||
return nil, fmt.Errorf("encrypted data too short: %d bytes", len(data))
|
||||
}
|
||||
|
||||
salt := data[:saltSize]
|
||||
nonce := data[saltSize : saltSize+nonceSize]
|
||||
ciphertext := data[saltSize+nonceSize:]
|
||||
|
||||
key, err := deriveKey(password, salt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating AES cipher: %w", err)
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating GCM: %w", err)
|
||||
}
|
||||
|
||||
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decrypting data (wrong password?): %w", err)
|
||||
}
|
||||
|
||||
return plaintext, nil
|
||||
}
|
||||
@@ -1,351 +0,0 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/docker/docker/api/types/container"
|
||||
"github.com/docker/docker/api/types/mount"
|
||||
"github.com/docker/docker/api/types/network"
|
||||
"github.com/docker/docker/client"
|
||||
"github.com/docker/docker/pkg/stdcopy"
|
||||
"github.com/docker/go-connections/nat"
|
||||
v1 "github.com/opencontainers/image-spec/specs-go/v1"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
type containerOptions struct {
|
||||
env []string
|
||||
binds []string
|
||||
ports nat.PortMap
|
||||
dns []string
|
||||
devices []container.DeviceMapping
|
||||
labels map[string]string
|
||||
capAdd []string
|
||||
}
|
||||
|
||||
// Run decrypts credentials, builds the container environment, and runs a Gluetun container.
|
||||
// extraArgs is the list of additional flags (e.g. ["-e", "PORT_FORWARDING=on", "-v", "/host:/container"]).
|
||||
func Run(ctx context.Context, provider, vpnType string, extraArgs []string,
|
||||
forceKill <-chan struct{},
|
||||
) error {
|
||||
credentials, err := decryptCredentials(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("loading credentials: %w", err)
|
||||
}
|
||||
|
||||
credentialEnvVars, err := lookupCredentials(credentials, provider, vpnType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
extraOpts, err := parseExtraArgs(extraArgs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing extra flags: %w", err)
|
||||
}
|
||||
opts := extraOpts
|
||||
opts.env = append(opts.env,
|
||||
"VPN_SERVICE_PROVIDER="+provider,
|
||||
"VPN_TYPE="+vpnType,
|
||||
"LOG_LEVEL=debug",
|
||||
)
|
||||
opts.env = append(opts.env, credentialEnvVars...)
|
||||
opts.capAdd = append(opts.capAdd, "NET_ADMIN")
|
||||
|
||||
return runContainer(ctx, opts, forceKill)
|
||||
}
|
||||
|
||||
// parseExtraArgs parses extra arguments and maps them to container options.
|
||||
// Supported flags:
|
||||
//
|
||||
// -e, --env KEY=VALUE - environment variable
|
||||
// -v, --volume SPEC - volume mount (e.g., "/host:/container" or "name:/container")
|
||||
// -p, --publish PORT:PORT - port mapping
|
||||
// --dns IP - DNS server
|
||||
// --device SPEC - device access (e.g., "/dev/net/tun")
|
||||
// --label KEY=VALUE - container label
|
||||
// --cap-add CAPABILITY - add Linux capability (e.g., "SYS_PTRACE")
|
||||
func parseExtraArgs(args []string) (opts containerOptions, err error) { //nolint:gocognit,gocyclo
|
||||
opts = containerOptions{
|
||||
ports: make(nat.PortMap),
|
||||
labels: make(map[string]string),
|
||||
}
|
||||
|
||||
for i := 0; i < len(args); i++ {
|
||||
arg := args[i]
|
||||
switch {
|
||||
case arg == "-e" || arg == "--env":
|
||||
if i+1 >= len(args) {
|
||||
return opts, fmt.Errorf("flag %q requires an argument", arg)
|
||||
}
|
||||
i++
|
||||
opts.env = append(opts.env, args[i])
|
||||
case strings.HasPrefix(arg, "-e="):
|
||||
opts.env = append(opts.env, strings.TrimPrefix(arg, "-e="))
|
||||
case strings.HasPrefix(arg, "--env="):
|
||||
opts.env = append(opts.env, strings.TrimPrefix(arg, "--env="))
|
||||
|
||||
case arg == "-v" || arg == "--volume":
|
||||
if i+1 >= len(args) {
|
||||
return opts, fmt.Errorf("flag %q requires an argument", arg)
|
||||
}
|
||||
i++
|
||||
opts.binds = append(opts.binds, args[i])
|
||||
case strings.HasPrefix(arg, "-v="):
|
||||
opts.binds = append(opts.binds, strings.TrimPrefix(arg, "-v="))
|
||||
case strings.HasPrefix(arg, "--volume="):
|
||||
opts.binds = append(opts.binds, strings.TrimPrefix(arg, "--volume="))
|
||||
|
||||
case arg == "-p" || arg == "--publish":
|
||||
if i+1 >= len(args) {
|
||||
return opts, fmt.Errorf("flag %q requires an argument", arg)
|
||||
}
|
||||
i++
|
||||
if err := parsePortMapping(opts.ports, args[i]); err != nil {
|
||||
return opts, fmt.Errorf("parsing port mapping: %w", err)
|
||||
}
|
||||
case strings.HasPrefix(arg, "-p="):
|
||||
if err := parsePortMapping(opts.ports, strings.TrimPrefix(arg, "-p=")); err != nil {
|
||||
return opts, fmt.Errorf("parsing port mapping: %w", err)
|
||||
}
|
||||
case strings.HasPrefix(arg, "--publish="):
|
||||
if err := parsePortMapping(opts.ports, strings.TrimPrefix(arg, "--publish=")); err != nil {
|
||||
return opts, fmt.Errorf("parsing port mapping: %w", err)
|
||||
}
|
||||
|
||||
case arg == "--dns":
|
||||
if i+1 >= len(args) {
|
||||
return opts, fmt.Errorf("flag %q requires an argument", arg)
|
||||
}
|
||||
i++
|
||||
opts.dns = append(opts.dns, args[i])
|
||||
case strings.HasPrefix(arg, "--dns="):
|
||||
opts.dns = append(opts.dns, strings.TrimPrefix(arg, "--dns="))
|
||||
|
||||
case arg == "--device":
|
||||
if i+1 >= len(args) {
|
||||
return opts, fmt.Errorf("flag %q requires an argument", arg)
|
||||
}
|
||||
i++
|
||||
parseDeviceMapping(&opts.devices, args[i])
|
||||
case strings.HasPrefix(arg, "--device="):
|
||||
parseDeviceMapping(&opts.devices, strings.TrimPrefix(arg, "--device="))
|
||||
|
||||
case arg == "--label":
|
||||
if i+1 >= len(args) {
|
||||
return opts, fmt.Errorf("flag %q requires an argument", arg)
|
||||
}
|
||||
i++
|
||||
parseLabel(opts.labels, args[i])
|
||||
case strings.HasPrefix(arg, "--label="):
|
||||
parseLabel(opts.labels, strings.TrimPrefix(arg, "--label="))
|
||||
|
||||
case arg == "--cap-add":
|
||||
if i+1 >= len(args) {
|
||||
return opts, fmt.Errorf("flag %q requires an argument", arg)
|
||||
}
|
||||
i++
|
||||
opts.capAdd = append(opts.capAdd, args[i])
|
||||
case strings.HasPrefix(arg, "--cap-add="):
|
||||
opts.capAdd = append(opts.capAdd, strings.TrimPrefix(arg, "--cap-add="))
|
||||
|
||||
default:
|
||||
return opts, fmt.Errorf("unsupported flag %q", arg)
|
||||
}
|
||||
}
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
func parsePortMapping(portMap nat.PortMap, spec string) error {
|
||||
port, bindings, err := nat.ParsePortSpecs([]string{spec})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for p, binding := range bindings {
|
||||
portMap[p] = binding
|
||||
}
|
||||
for p := range port {
|
||||
if _, exists := portMap[p]; !exists {
|
||||
portMap[p] = []nat.PortBinding{}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseDeviceMapping(devices *[]container.DeviceMapping, spec string) {
|
||||
parts := strings.SplitN(spec, ":", 3) //nolint:mnd
|
||||
pathOnHost := parts[0]
|
||||
pathInContainer := pathOnHost
|
||||
permissions := "rwm"
|
||||
|
||||
if len(parts) >= 2 { //nolint:mnd
|
||||
pathInContainer = parts[1]
|
||||
}
|
||||
if len(parts) >= 3 { //nolint:mnd
|
||||
permissions = parts[2]
|
||||
}
|
||||
|
||||
*devices = append(*devices, container.DeviceMapping{
|
||||
PathOnHost: pathOnHost,
|
||||
PathInContainer: pathInContainer,
|
||||
CgroupPermissions: permissions,
|
||||
})
|
||||
}
|
||||
|
||||
func parseLabel(labels map[string]string, kv string) {
|
||||
parts := strings.SplitN(kv, "=", 2) //nolint:mnd
|
||||
key := parts[0]
|
||||
value := ""
|
||||
if len(parts) > 1 {
|
||||
value = parts[1]
|
||||
}
|
||||
labels[key] = value
|
||||
}
|
||||
|
||||
func runContainer(ctx context.Context, opts containerOptions, forceKill <-chan struct{}) error {
|
||||
dockerClient, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation())
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating docker client: %w", err)
|
||||
}
|
||||
defer dockerClient.Close()
|
||||
|
||||
hasTTY := term.IsTerminal(int(os.Stdout.Fd()))
|
||||
|
||||
containerConfig := &container.Config{
|
||||
Image: "qmcgaw/gluetun",
|
||||
Env: opts.env,
|
||||
Labels: opts.labels,
|
||||
Tty: hasTTY,
|
||||
}
|
||||
|
||||
mounts := make([]mount.Mount, 0, len(opts.binds))
|
||||
for _, bind := range opts.binds {
|
||||
m, err := parseBindMount(bind)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing bind mount %q: %w", bind, err)
|
||||
}
|
||||
mounts = append(mounts, m)
|
||||
}
|
||||
|
||||
hostConfig := &container.HostConfig{
|
||||
AutoRemove: true,
|
||||
CapAdd: opts.capAdd,
|
||||
Binds: opts.binds,
|
||||
Mounts: mounts,
|
||||
PortBindings: opts.ports,
|
||||
DNS: opts.dns,
|
||||
}
|
||||
hostConfig.Devices = opts.devices
|
||||
|
||||
networkConfig := &network.NetworkingConfig{}
|
||||
|
||||
platform := (*v1.Platform)(nil)
|
||||
|
||||
const containerName = "gluetun"
|
||||
response, err := dockerClient.ContainerCreate(ctx, containerConfig, hostConfig, networkConfig, platform, containerName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating container: %w", err)
|
||||
}
|
||||
for _, warning := range response.Warnings {
|
||||
fmt.Fprintln(os.Stderr, "container creation warning:", warning)
|
||||
}
|
||||
containerID := response.ID
|
||||
|
||||
err = dockerClient.ContainerStart(ctx, containerID, container.StartOptions{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("starting container: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Container started (id: %.12s)\n", containerID)
|
||||
|
||||
streamLogsErr := make(chan error, 1)
|
||||
go func() {
|
||||
streamLogsErr <- streamLogs(context.Background(), dockerClient, containerID, hasTTY)
|
||||
}()
|
||||
|
||||
contextDone := ctx.Done()
|
||||
forceKillSignal := forceKill
|
||||
for {
|
||||
select {
|
||||
case err := <-streamLogsErr:
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
case <-contextDone:
|
||||
fmt.Fprintln(os.Stderr, "\nReceived interrupt, stopping container (5s timeout)...")
|
||||
err = stopContainer(dockerClient, containerID)
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "stopping container:", err)
|
||||
}
|
||||
contextDone = nil
|
||||
case <-forceKillSignal:
|
||||
fmt.Fprintln(os.Stderr, "\nReceived second interrupt, killing container...")
|
||||
err = killContainer(dockerClient, containerID)
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "killing container:", err)
|
||||
}
|
||||
forceKillSignal = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func parseBindMount(bind string) (mount.Mount, error) {
|
||||
parts := strings.SplitN(bind, ":", 3) //nolint:mnd
|
||||
if len(parts) < 2 { //nolint:mnd
|
||||
return mount.Mount{}, fmt.Errorf("invalid bind mount format: %q (expected source:target[:mode])", bind)
|
||||
}
|
||||
|
||||
source := parts[0]
|
||||
target := parts[1]
|
||||
readOnly := len(parts) > 2 && strings.Contains(parts[2], "ro") //nolint:mnd
|
||||
|
||||
return mount.Mount{
|
||||
Type: mount.TypeBind,
|
||||
Source: source,
|
||||
Target: target,
|
||||
ReadOnly: readOnly,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func stopContainer(dockerClient *client.Client, containerID string) error {
|
||||
const stopTimeout = 5 * time.Second
|
||||
stopCtx, stopCancel := context.WithTimeout(context.Background(), stopTimeout)
|
||||
defer stopCancel()
|
||||
timeoutSeconds := int(stopTimeout.Seconds())
|
||||
return dockerClient.ContainerStop(stopCtx, containerID, container.StopOptions{Timeout: &timeoutSeconds})
|
||||
}
|
||||
|
||||
func killContainer(dockerClient *client.Client, containerID string) error {
|
||||
return dockerClient.ContainerKill(context.Background(), containerID, "KILL")
|
||||
}
|
||||
|
||||
func streamLogs(ctx context.Context, dockerClient *client.Client, containerID string, hasTTY bool) error {
|
||||
logOptions := container.LogsOptions{
|
||||
ShowStdout: true,
|
||||
ShowStderr: true,
|
||||
Follow: true,
|
||||
Timestamps: false,
|
||||
}
|
||||
|
||||
reader, err := dockerClient.ContainerLogs(ctx, containerID, logOptions)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting container logs: %w", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
if hasTTY {
|
||||
_, err = io.Copy(os.Stdout, reader)
|
||||
} else {
|
||||
_, err = stdcopy.StdCopy(os.Stdout, os.Stderr, reader)
|
||||
}
|
||||
if err != nil && err != io.EOF {
|
||||
return fmt.Errorf("streaming container logs: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -15,7 +15,6 @@ require (
|
||||
github.com/mdlayher/netlink v1.9.0
|
||||
github.com/pelletier/go-toml/v2 v2.2.4
|
||||
github.com/qdm12/dns/v2 v2.0.0-rc9.0.20260421173011-9de8e7fdbe3a
|
||||
github.com/qdm12/gluetun-servers v0.1.0
|
||||
github.com/qdm12/gosettings v0.4.4
|
||||
github.com/qdm12/goshutdown v0.3.0
|
||||
github.com/qdm12/gosplash v0.2.1-0.20260305164749-b713de4fee6c
|
||||
@@ -27,7 +26,6 @@ require (
|
||||
github.com/ulikunitz/xz v0.5.15
|
||||
github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a
|
||||
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c
|
||||
golang.org/x/mod v0.33.0
|
||||
golang.org/x/net v0.51.0
|
||||
golang.org/x/sys v0.42.0
|
||||
golang.org/x/text v0.35.0
|
||||
@@ -59,6 +57,7 @@ require (
|
||||
github.com/qdm12/goservices v0.1.1-0.20251104135713-6bee97bd4978 // indirect
|
||||
github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect
|
||||
golang.org/x/crypto v0.48.0 // indirect
|
||||
golang.org/x/mod v0.33.0 // indirect
|
||||
golang.org/x/sync v0.20.0 // indirect
|
||||
golang.org/x/tools v0.42.0 // indirect
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||
|
||||
@@ -76,8 +76,6 @@ github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0leargg
|
||||
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
|
||||
github.com/qdm12/dns/v2 v2.0.0-rc9.0.20260421173011-9de8e7fdbe3a h1:TE157yPQmAbVruH0MWCQzs0vTT/6t96DkoWUXd6PVuc=
|
||||
github.com/qdm12/dns/v2 v2.0.0-rc9.0.20260421173011-9de8e7fdbe3a/go.mod h1:98foWgXJZ+g8gJIuO+fdO+oWpFei5WShMFTeN4Im2lE=
|
||||
github.com/qdm12/gluetun-servers v0.1.0 h1:w9JLghKZwI0Gzpp9p5rNANgEYUUZ1dxdxsG6NKIojaY=
|
||||
github.com/qdm12/gluetun-servers v0.1.0/go.mod h1:acttuyHyoFDu6GTbf3kAV+QXeiX8oJeh0MBic67/9z8=
|
||||
github.com/qdm12/goservices v0.1.1-0.20251104135713-6bee97bd4978 h1:TRGpCU1l0lNwtogEUSs5U+RFceYxkAJUmrGabno7J5c=
|
||||
github.com/qdm12/goservices v0.1.1-0.20251104135713-6bee97bd4978/go.mod h1:D1Po4CRQLYjccnAR2JsVlN1sBMgQrcNLONbvyuzcdTg=
|
||||
github.com/qdm12/gosettings v0.4.4 h1:SM6tOZDf6k8qbjWU8KWyBF4mWIixfsKCfh9DGRLHlj4=
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package alpine
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
@@ -8,6 +9,8 @@ import (
|
||||
"strconv"
|
||||
)
|
||||
|
||||
var ErrUserAlreadyExists = errors.New("user already exists")
|
||||
|
||||
// CreateUser creates a user in Alpine with the given UID.
|
||||
func (a *Alpine) CreateUser(username string, uid int) (createdUsername string, err error) {
|
||||
UIDStr := strconv.Itoa(uid)
|
||||
@@ -31,8 +34,8 @@ func (a *Alpine) CreateUser(username string, uid int) (createdUsername string, e
|
||||
}
|
||||
|
||||
if u != nil {
|
||||
return "", fmt.Errorf("user already exists: with name %s for ID %s instead of %d",
|
||||
username, u.Uid, uid)
|
||||
return "", fmt.Errorf("%w: with name %s for ID %s instead of %d",
|
||||
ErrUserAlreadyExists, username, u.Uid, uid)
|
||||
}
|
||||
|
||||
const permission = fs.FileMode(0o644)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package amneziawg
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
@@ -29,7 +28,7 @@ func Test_New(t *testing.T) {
|
||||
PrivateKey: "",
|
||||
},
|
||||
},
|
||||
err: errors.New("private key is missing"),
|
||||
err: wireguard.ErrPrivateKeyMissing,
|
||||
},
|
||||
"minimal valid settings": {
|
||||
settings: Settings{
|
||||
|
||||
@@ -13,6 +13,11 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/wireguard"
|
||||
)
|
||||
|
||||
var (
|
||||
errTunNameMismatch = errors.New("TUN device name is mismatching")
|
||||
errDeviceWaited = errors.New("device waited for")
|
||||
)
|
||||
|
||||
// Run runs the amneziawg interface and waits until the context is done, then it cleans up the
|
||||
// interface and returns any error that occurred during setup or waiting. It sends an error to
|
||||
// waitError if any error occurs during setup or waiting, otherwise it sends nil when the context
|
||||
@@ -47,7 +52,8 @@ func setupUserspace(ctx context.Context,
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("getting created TUN device name: %w", err)
|
||||
} else if tunName != interfaceName {
|
||||
return 0, nil, fmt.Errorf("TUN device name is mismatching: expected %q and got %q", interfaceName, tunName)
|
||||
return 0, nil, fmt.Errorf("%w: expected %q and got %q",
|
||||
errTunNameMismatch, interfaceName, tunName)
|
||||
}
|
||||
|
||||
link, err := netLinker.LinkByName(interfaceName)
|
||||
@@ -100,7 +106,7 @@ func setupUserspace(ctx context.Context,
|
||||
case err = <-uapiAcceptErrorCh:
|
||||
close(uapiAcceptErrorCh)
|
||||
case <-device.Wait():
|
||||
err = errors.New("device waited for")
|
||||
err = errDeviceWaited
|
||||
}
|
||||
|
||||
cleanups.Cleanup(logger)
|
||||
|
||||
@@ -31,7 +31,7 @@ type urlData struct{}
|
||||
func New(client *http.Client, logger Logger, settings settings.BoringPoll) *BoringPoll {
|
||||
urlToData := make(map[string]*urlData)
|
||||
if *settings.GluetunCom {
|
||||
logger.Infof("gluetun.com is DOWN most likely thanks to you! so not doing anything anymore")
|
||||
urlToData["https://gluetun.com/wp-json"] = &urlData{}
|
||||
}
|
||||
return &BoringPoll{
|
||||
client: client,
|
||||
|
||||
+6
-2
@@ -1,7 +1,11 @@
|
||||
package cli
|
||||
|
||||
type CLI struct{}
|
||||
type CLI struct {
|
||||
repoServersPath string
|
||||
}
|
||||
|
||||
func New() *CLI {
|
||||
return &CLI{}
|
||||
return &CLI{
|
||||
repoServersPath: "./internal/storage/servers.json",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,11 +9,18 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/storage"
|
||||
"golang.org/x/text/cases"
|
||||
"golang.org/x/text/language"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrProviderUnspecified = errors.New("VPN provider to format was not specified")
|
||||
ErrMultipleProvidersToFormat = errors.New("more than one VPN provider to format were specified")
|
||||
)
|
||||
|
||||
func addProviderFlag(flagSet *flag.FlagSet, providerToFormat map[string]*bool,
|
||||
provider string, titleCaser cases.Caser,
|
||||
) {
|
||||
@@ -58,10 +65,11 @@ func (c *CLI) FormatServers(args []string) error {
|
||||
}
|
||||
switch len(providers) {
|
||||
case 0:
|
||||
return errors.New("VPN provider to format was not specified")
|
||||
return fmt.Errorf("%w", ErrProviderUnspecified)
|
||||
case 1:
|
||||
default:
|
||||
return fmt.Errorf("more than one VPN provider to format were specified: %d specified: %s", len(providers),
|
||||
return fmt.Errorf("%w: %d specified: %s",
|
||||
ErrMultipleProvidersToFormat, len(providers),
|
||||
strings.Join(providers, ", "))
|
||||
}
|
||||
|
||||
@@ -72,9 +80,10 @@ func (c *CLI) FormatServers(args []string) error {
|
||||
}
|
||||
}
|
||||
|
||||
storage, err := setupStorage(newNoopLogger())
|
||||
logger := newNoopLogger()
|
||||
storage, err := storage.New(logger, constants.ServersData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting up storage: %w", err)
|
||||
return fmt.Errorf("creating servers storage: %w", err)
|
||||
}
|
||||
|
||||
formatted, err := storage.Format(providerToFormat, format)
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/configuration/sources/files"
|
||||
"github.com/qdm12/gluetun/internal/configuration/sources/secrets"
|
||||
"github.com/qdm12/gluetun/internal/storage"
|
||||
"github.com/qdm12/gosettings/reader"
|
||||
"github.com/qdm12/gosettings/reader/sources/env"
|
||||
)
|
||||
|
||||
type storageSetupLogger interface {
|
||||
storage.Logger
|
||||
files.Warner
|
||||
}
|
||||
|
||||
func setupStorage(logger storageSetupLogger) (s *storage.Storage, err error) {
|
||||
settingsReader := reader.New(reader.Settings{
|
||||
Sources: []reader.Source{
|
||||
secrets.New(logger),
|
||||
files.New(logger),
|
||||
env.New(env.Settings{}),
|
||||
},
|
||||
})
|
||||
var settings settings.Storage
|
||||
err = settings.Read(settingsReader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading storage settings: %w", err)
|
||||
}
|
||||
settings.SetDefaults()
|
||||
storage, err := storage.New(logger, *settings.ServersEnabled, settings.ServersPath,
|
||||
settings.LegacyServersFilepath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating storage: %w", err)
|
||||
}
|
||||
return storage, nil
|
||||
}
|
||||
@@ -6,7 +6,5 @@ func newNoopLogger() *noopLogger {
|
||||
return new(noopLogger)
|
||||
}
|
||||
|
||||
func (l *noopLogger) Info(string) {}
|
||||
func (l *noopLogger) Infof(string, ...any) {}
|
||||
func (l *noopLogger) Warn(string) {}
|
||||
func (l *noopLogger) Warnf(string, ...any) {}
|
||||
func (l *noopLogger) Info(string) {}
|
||||
func (l *noopLogger) Warn(string) {}
|
||||
|
||||
@@ -9,10 +9,12 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/netlink"
|
||||
"github.com/qdm12/gluetun/internal/openvpn/extract"
|
||||
"github.com/qdm12/gluetun/internal/provider"
|
||||
"github.com/qdm12/gluetun/internal/storage"
|
||||
"github.com/qdm12/gluetun/internal/updater/resolver"
|
||||
"github.com/qdm12/gosettings/reader"
|
||||
)
|
||||
@@ -47,9 +49,9 @@ type IPv6Checker interface {
|
||||
func (c *CLI) OpenvpnConfig(logger OpenvpnConfigLogger, reader *reader.Reader,
|
||||
ipv6Checker IPv6Checker,
|
||||
) error {
|
||||
storage, err := setupStorage(newNoopLogger())
|
||||
storage, err := storage.New(logger, constants.ServersData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting up storage: %w", err)
|
||||
return err
|
||||
}
|
||||
|
||||
var allSettings settings.Settings
|
||||
|
||||
+32
-16
@@ -13,30 +13,38 @@ import (
|
||||
"github.com/qdm12/dns/v2/pkg/doh"
|
||||
dnsprovider "github.com/qdm12/dns/v2/pkg/provider"
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/openvpn/extract"
|
||||
"github.com/qdm12/gluetun/internal/provider"
|
||||
"github.com/qdm12/gluetun/internal/publicip/api"
|
||||
"github.com/qdm12/gluetun/internal/storage"
|
||||
"github.com/qdm12/gluetun/internal/updater"
|
||||
"github.com/qdm12/gluetun/internal/updater/resolver"
|
||||
"github.com/qdm12/gluetun/internal/updater/unzip"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrModeUnspecified = errors.New("at least one of -enduser or -maintainer must be specified")
|
||||
ErrNoProviderSpecified = errors.New("no provider was specified")
|
||||
ErrUsernameMissing = errors.New("username is required for this provider")
|
||||
ErrPasswordMissing = errors.New("password is required for this provider")
|
||||
)
|
||||
|
||||
type UpdaterLogger interface {
|
||||
Info(s string)
|
||||
Infof(format string, args ...any)
|
||||
Warn(s string)
|
||||
Warnf(format string, args ...any)
|
||||
Error(s string)
|
||||
}
|
||||
|
||||
func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) error {
|
||||
options := settings.Updater{}
|
||||
// TODO v4: remove flags below already present in standard settings
|
||||
var endUserMode, maintainerMode bool
|
||||
var updateAll bool
|
||||
var endUserMode, maintainerMode, updateAll bool
|
||||
var dnsServer, csvProviders, ipToken, protonUsername, protonEmail, protonPassword string
|
||||
flagSet := flag.NewFlagSet("update", flag.ExitOnError)
|
||||
flagSet.BoolVar(&endUserMode, "enduser", false, "Write results to /gluetun/servers.json (for end users)")
|
||||
flagSet.BoolVar(&maintainerMode, "maintainer", false,
|
||||
"Write results to ./internal/storage/servers.json to modify the program (for maintainers)")
|
||||
flagSet.StringVar(&dnsServer, "dns", "", "no longer used, your DNS will use DoH with Cloudflare and Google")
|
||||
const defaultMinRatio = 0.8
|
||||
flagSet.Float64Var(&options.MinRatio, "minratio", defaultMinRatio,
|
||||
@@ -48,26 +56,23 @@ func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) e
|
||||
"(Retro-compatibility) Username to use to authenticate with Proton. Use -proton-email instead.") // v4 remove this
|
||||
flagSet.StringVar(&protonEmail, "proton-email", "", "Email to use to authenticate with Proton")
|
||||
flagSet.StringVar(&protonPassword, "proton-password", "", "Password to use to authenticate with Proton")
|
||||
flagSet.BoolVar(&endUserMode, "enduser", false, "deprecated")
|
||||
flagSet.BoolVar(&maintainerMode, "maintainer", false, "deprecated")
|
||||
if err := flagSet.Parse(args); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch {
|
||||
case dnsServer != "":
|
||||
if dnsServer != "" {
|
||||
logger.Warn("The -dns flag is no longer used, your DNS will use DoH with Cloudflare and Google")
|
||||
case endUserMode:
|
||||
logger.Warn("The -enduser flag is now unused")
|
||||
case maintainerMode:
|
||||
logger.Warn("The -maintainer flag is now unused")
|
||||
}
|
||||
|
||||
if !endUserMode && !maintainerMode {
|
||||
return fmt.Errorf("%w", ErrModeUnspecified)
|
||||
}
|
||||
|
||||
if updateAll {
|
||||
options.Providers = providers.All()
|
||||
} else {
|
||||
if csvProviders == "" {
|
||||
return errors.New("no provider was specified")
|
||||
return fmt.Errorf("%w", ErrNoProviderSpecified)
|
||||
}
|
||||
options.Providers = strings.Split(csvProviders, ",")
|
||||
}
|
||||
@@ -89,7 +94,11 @@ func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) e
|
||||
return fmt.Errorf("options validation failed: %w", err)
|
||||
}
|
||||
|
||||
storage, err := setupStorage(logger)
|
||||
serversDataPath := constants.ServersData
|
||||
if maintainerMode {
|
||||
serversDataPath = ""
|
||||
}
|
||||
storage, err := storage.New(logger, serversDataPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating servers storage: %w", err)
|
||||
}
|
||||
@@ -125,11 +134,18 @@ func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) e
|
||||
providers := provider.NewProviders(storage, time.Now, logger, httpClient,
|
||||
unzipper, parallelResolver, ipFetcher, openvpnFileExtractor, options)
|
||||
|
||||
updater := updater.New(httpClient, storage, providers, logger, *options.PreferDirectDownload)
|
||||
updater := updater.New(httpClient, storage, providers, logger)
|
||||
err = updater.UpdateServers(ctx, options.Providers, options.MinRatio)
|
||||
if err != nil {
|
||||
return fmt.Errorf("updating server information: %w", err)
|
||||
}
|
||||
|
||||
if maintainerMode {
|
||||
err := storage.FlushToFile(c.repoServersPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing servers data to embedded JSON file: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -8,6 +8,13 @@ import (
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
var (
|
||||
errCommandEmpty = errors.New("command is empty")
|
||||
errSingleQuoteUnterminated = errors.New("unterminated single-quoted string")
|
||||
errDoubleQuoteUnterminated = errors.New("unterminated double-quoted string")
|
||||
errEscapeUnterminated = errors.New("unterminated backslash-escape")
|
||||
)
|
||||
|
||||
// split splits a command string into a slice of arguments.
|
||||
// This is especially important for commands such as:
|
||||
// /bin/sh -c "echo hello"
|
||||
@@ -18,7 +25,7 @@ import (
|
||||
// - expansion (brace, shell or pathname).
|
||||
func split(command string) (words []string, err error) {
|
||||
if command == "" {
|
||||
return nil, errors.New("command is empty")
|
||||
return nil, fmt.Errorf("%w", errCommandEmpty)
|
||||
}
|
||||
|
||||
const bufferSize = 1024
|
||||
@@ -35,7 +42,7 @@ func split(command string) (words []string, err error) {
|
||||
case character == '\\':
|
||||
// Look ahead to eventually skip an escaped newline
|
||||
if command[startIndex+runeSize:] == "" {
|
||||
return nil, fmt.Errorf("unterminated backslash-escape: %q", command)
|
||||
return nil, fmt.Errorf("%w: %q", errEscapeUnterminated, command)
|
||||
}
|
||||
character, runeSize := utf8.DecodeRuneInString(command[startIndex+runeSize:])
|
||||
if character == '\n' {
|
||||
@@ -112,7 +119,7 @@ func handleDoubleQuoted(input string, startIndex int, buffer *bytes.Buffer) (
|
||||
startIndex = cursor
|
||||
}
|
||||
}
|
||||
return "", 0, errors.New("unterminated double-quoted string")
|
||||
return "", 0, fmt.Errorf("%w", errDoubleQuoteUnterminated)
|
||||
}
|
||||
|
||||
func handleSingleQuoted(input string, startIndex int, buffer *bytes.Buffer) (
|
||||
@@ -120,7 +127,7 @@ func handleSingleQuoted(input string, startIndex int, buffer *bytes.Buffer) (
|
||||
) {
|
||||
closingQuoteIndex := strings.IndexRune(input[startIndex:], '\'')
|
||||
if closingQuoteIndex == -1 {
|
||||
return "", 0, errors.New("unterminated single-quoted string")
|
||||
return "", 0, fmt.Errorf("%w", errSingleQuoteUnterminated)
|
||||
}
|
||||
buffer.WriteString(input[startIndex : startIndex+closingQuoteIndex])
|
||||
const singleQuoteRuneLength = 1
|
||||
@@ -132,7 +139,7 @@ func handleEscaped(input string, startIndex int, buffer *bytes.Buffer) (
|
||||
word string, newStartIndex int, err error,
|
||||
) {
|
||||
if input[startIndex:] == "" {
|
||||
return "", 0, errors.New("unterminated backslash-escape")
|
||||
return "", 0, fmt.Errorf("%w", errEscapeUnterminated)
|
||||
}
|
||||
character, runeLength := utf8.DecodeRuneInString(input[startIndex:])
|
||||
if character != '\n' { // backslash-escaped newline is ignored
|
||||
|
||||
@@ -12,10 +12,12 @@ func Test_split(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
command string
|
||||
words []string
|
||||
errWrapped error
|
||||
errMessage string
|
||||
}{
|
||||
"empty": {
|
||||
command: "",
|
||||
errWrapped: errCommandEmpty,
|
||||
errMessage: "command is empty",
|
||||
},
|
||||
"concrete_sh_command": {
|
||||
@@ -72,18 +74,22 @@ func Test_split(t *testing.T) {
|
||||
},
|
||||
"unterminated_single_quote": {
|
||||
command: "'abc'\\''def",
|
||||
errWrapped: errSingleQuoteUnterminated,
|
||||
errMessage: `splitting word in "'abc'\\''def": unterminated single-quoted string`,
|
||||
},
|
||||
"unterminated_double_quote": {
|
||||
command: "\"abc'def",
|
||||
errWrapped: errDoubleQuoteUnterminated,
|
||||
errMessage: `splitting word in "\"abc'def": unterminated double-quoted string`,
|
||||
},
|
||||
"unterminated_escape": {
|
||||
command: "abc\\",
|
||||
errWrapped: errEscapeUnterminated,
|
||||
errMessage: `splitting word in "abc\\": unterminated backslash-escape`,
|
||||
},
|
||||
"unterminated_escape_only": {
|
||||
command: " \\",
|
||||
errWrapped: errEscapeUnterminated,
|
||||
errMessage: `unterminated backslash-escape: " \\"`,
|
||||
},
|
||||
}
|
||||
@@ -95,10 +101,9 @@ func Test_split(t *testing.T) {
|
||||
words, err := split(testCase.command)
|
||||
|
||||
assert.Equal(t, testCase.words, words)
|
||||
if testCase.errMessage != "" {
|
||||
assert.ErrorContains(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if testCase.errWrapped != nil {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -21,6 +21,7 @@ func (c *Cmder) Start(cmd *exec.Cmd) (
|
||||
func start(cmd execCmd) (stdoutLines, stderrLines <-chan string,
|
||||
waitError <-chan error, startErr error,
|
||||
) {
|
||||
stop := make(chan struct{})
|
||||
stdoutReady := make(chan struct{})
|
||||
stdoutLinesCh := make(chan string)
|
||||
stdoutDone := make(chan struct{})
|
||||
@@ -32,20 +33,22 @@ func start(cmd execCmd) (stdoutLines, stderrLines <-chan string,
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
go streamToChannel(stdoutReady, stdoutDone, stdout, stdoutLinesCh)
|
||||
go streamToChannel(stdoutReady, stop, stdoutDone, stdout, stdoutLinesCh)
|
||||
|
||||
stderr, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
_ = stdout.Close()
|
||||
close(stop)
|
||||
<-stdoutDone
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
go streamToChannel(stderrReady, stderrDone, stderr, stderrLinesCh)
|
||||
go streamToChannel(stderrReady, stop, stderrDone, stderr, stderrLinesCh)
|
||||
|
||||
err = cmd.Start()
|
||||
if err != nil {
|
||||
_ = stdout.Close()
|
||||
_ = stderr.Close()
|
||||
close(stop)
|
||||
<-stdoutDone
|
||||
<-stderrDone
|
||||
return nil, nil, nil, err
|
||||
@@ -54,20 +57,19 @@ func start(cmd execCmd) (stdoutLines, stderrLines <-chan string,
|
||||
waitErrorCh := make(chan error)
|
||||
go func() {
|
||||
err := cmd.Wait()
|
||||
<-stdoutDone
|
||||
<-stderrDone
|
||||
_ = stdout.Close()
|
||||
_ = stderr.Close()
|
||||
close(stop)
|
||||
<-stdoutDone
|
||||
<-stderrDone
|
||||
waitErrorCh <- err
|
||||
}()
|
||||
|
||||
<-stdoutReady
|
||||
<-stderrReady
|
||||
|
||||
return stdoutLinesCh, stderrLinesCh, waitErrorCh, nil
|
||||
}
|
||||
|
||||
func streamToChannel(ready chan<- struct{}, done chan<- struct{},
|
||||
func streamToChannel(ready chan<- struct{},
|
||||
stop <-chan struct{}, done chan<- struct{},
|
||||
stream io.Reader, lines chan<- string,
|
||||
) {
|
||||
defer close(done)
|
||||
@@ -87,5 +89,12 @@ func streamToChannel(ready chan<- struct{}, done chan<- struct{},
|
||||
if err == nil || errors.Is(err, os.ErrClosed) {
|
||||
return
|
||||
}
|
||||
lines <- "stream error: " + err.Error()
|
||||
|
||||
// ignore the error if it is stopped.
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
default:
|
||||
lines <- "stream error: " + err.Error()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package settings
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -176,6 +177,14 @@ func (a AmneziaWg) toLinesNode() (node *gotree.Node) {
|
||||
return node
|
||||
}
|
||||
|
||||
var (
|
||||
ErrAmenziawgImplementationNotValid = errors.New("AmneziaWG implementation is not valid")
|
||||
ErrJunkPacketBounds = errors.New("junk packet minimum must be lower than or equal to maximum")
|
||||
ErrJunkPacketMinMaxNotSet = errors.New("junk packet min and max must be set when junk packet count is set")
|
||||
ErrJunkPacketCountNotSet = errors.New("junk packet count must be set when junk packet min or max is set")
|
||||
ErrHeaderRangeMalformed = errors.New("header range is malformed")
|
||||
)
|
||||
|
||||
func (a AmneziaWg) validate(vpnProvider string, ipv6Supported bool) error {
|
||||
const amneziaWG = true
|
||||
err := a.Wireguard.validate(vpnProvider, ipv6Supported, amneziaWG)
|
||||
@@ -185,16 +194,16 @@ func (a AmneziaWg) validate(vpnProvider string, ipv6Supported bool) error {
|
||||
|
||||
if *a.JunkPacketCount == 0 {
|
||||
if *a.JunkPacketMin != 0 || *a.JunkPacketMax != 0 {
|
||||
return fmt.Errorf("junk packet count must be set when junk packet min or max is set: "+
|
||||
"jc=%d and jmin=%d and jmax=%d", a.JunkPacketCount, *a.JunkPacketMin, *a.JunkPacketMax)
|
||||
return fmt.Errorf("%w: jc=%d and jmin=%d and jmax=%d",
|
||||
ErrJunkPacketCountNotSet, a.JunkPacketCount, *a.JunkPacketMin, *a.JunkPacketMax)
|
||||
}
|
||||
} else {
|
||||
if *a.JunkPacketMin == 0 || *a.JunkPacketMax == 0 {
|
||||
return fmt.Errorf("junk packet min and max must be set when junk packet count is set: "+
|
||||
"jc=%d and jmin=%d and jmax=%d", a.JunkPacketCount, *a.JunkPacketMin, *a.JunkPacketMax)
|
||||
return fmt.Errorf("%w: jc=%d and jmin=%d and jmax=%d",
|
||||
ErrJunkPacketMinMaxNotSet, a.JunkPacketCount, *a.JunkPacketMin, *a.JunkPacketMax)
|
||||
} else if *a.JunkPacketMin > *a.JunkPacketMax {
|
||||
return fmt.Errorf("junk packet minimum must be lower than or equal to maximum: "+
|
||||
"jmin=%d and jmax=%d", *a.JunkPacketMin, *a.JunkPacketMax)
|
||||
return fmt.Errorf("%w: jmin=%d and jmax=%d",
|
||||
ErrJunkPacketBounds, *a.JunkPacketMin, *a.JunkPacketMax)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -213,20 +222,20 @@ func (a AmneziaWg) validate(vpnProvider string, ipv6Supported bool) error {
|
||||
case 1:
|
||||
_, err := strconv.Atoi(fields[0])
|
||||
if err != nil {
|
||||
return fmt.Errorf("header range is malformed: "+
|
||||
"%s value %s is not a number", name, headerRange)
|
||||
return fmt.Errorf("%w: %s value %s is not a number",
|
||||
ErrHeaderRangeMalformed, name, headerRange)
|
||||
}
|
||||
case 2: //nolint:mnd
|
||||
for _, field := range fields {
|
||||
_, err := strconv.Atoi(field)
|
||||
if err != nil {
|
||||
return fmt.Errorf("header range is malformed: "+
|
||||
"%s value %s is not a valid range", name, headerRange)
|
||||
return fmt.Errorf("%w: %s value %s is not a valid range",
|
||||
ErrHeaderRangeMalformed, name, headerRange)
|
||||
}
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("header range is malformed: "+
|
||||
"%s value %s must be in the form n or n-m", name, headerRange)
|
||||
return fmt.Errorf("%w: %s value %s must be in the form n or n-m",
|
||||
ErrHeaderRangeMalformed, name, headerRange)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -14,8 +14,10 @@ func readObsolete(r *reader.Reader) (warnings []string) {
|
||||
"DOT_VALIDATION_LOGLEVEL": "DOT_VALIDATION_LOGLEVEL is obsolete because DNSSEC validation is not implemented.",
|
||||
"HEALTH_VPN_DURATION_INITIAL": "HEALTH_VPN_DURATION_INITIAL is obsolete",
|
||||
"HEALTH_VPN_DURATION_ADDITION": "HEALTH_VPN_DURATION_ADDITION is obsolete",
|
||||
"DNS_KEEP_NAMESERVER": "DNS_KEEP_NAMESERVER is obsolete because you should use the built-in server which now " +
|
||||
"forwards local names to private DNS resolvers found in /etc/resolv.conf at container start",
|
||||
"DNS_SERVER": "DNS_SERVER is obsolete because the forwarding server is always enabled.",
|
||||
"DOT": "DOT is obsolete because the forwarding server is always enabled.",
|
||||
"DNS_KEEP_NAMESERVER": "DNS_KEEP_NAMESERVER is obsolete because the forwarding server is always used and " +
|
||||
"forwards local names to private DNS resolvers found in /etc/resolv.conf",
|
||||
}
|
||||
sortedKeys := maps.Keys(keyToMessage)
|
||||
slices.Sort(sortedKeys)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package settings
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"time"
|
||||
@@ -20,9 +21,6 @@ const (
|
||||
|
||||
// DNS contains settings to configure DNS.
|
||||
type DNS struct {
|
||||
// ServerEnabled indicates if the DNS server should be enabled.
|
||||
// It defaults to true and cannot be nil in the internal state.
|
||||
ServerEnabled *bool `json:"enabled"`
|
||||
// UpstreamType can be [DNSUpstreamTypeDot], [DNSUpstreamTypeDoh]
|
||||
// or [DNSUpstreamTypePlain]. It defaults to [DNSUpstreamTypeDot].
|
||||
UpstreamType string `json:"upstream_type"`
|
||||
@@ -50,22 +48,22 @@ type DNS struct {
|
||||
UpstreamPlainAddresses []netip.AddrPort
|
||||
}
|
||||
|
||||
var (
|
||||
ErrDNSUpstreamTypeNotValid = errors.New("DNS upstream type is not valid")
|
||||
ErrDNSUpdatePeriodTooShort = errors.New("update period is too short")
|
||||
ErrDNSUpstreamPlainNoIPv6 = errors.New("upstream plain addresses do not contain any IPv6 address")
|
||||
ErrDNSUpstreamPlainNoIPv4 = errors.New("upstream plain addresses do not contain any IPv4 address")
|
||||
)
|
||||
|
||||
func (d DNS) validate() (err error) {
|
||||
if !helpers.IsOneOf(d.UpstreamType, DNSUpstreamTypeDot, DNSUpstreamTypeDoh, DNSUpstreamTypePlain) {
|
||||
return fmt.Errorf("DNS upstream type is not valid: %s", d.UpstreamType)
|
||||
}
|
||||
|
||||
if !*d.ServerEnabled {
|
||||
err = d.validateForServerOff()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return fmt.Errorf("%w: %s", ErrDNSUpstreamTypeNotValid, d.UpstreamType)
|
||||
}
|
||||
|
||||
const minUpdatePeriod = 30 * time.Second
|
||||
if *d.UpdatePeriod != 0 && *d.UpdatePeriod < minUpdatePeriod {
|
||||
return fmt.Errorf("update period is too short: %s must be bigger than %s",
|
||||
*d.UpdatePeriod, minUpdatePeriod)
|
||||
return fmt.Errorf("%w: %s must be bigger than %s",
|
||||
ErrDNSUpdatePeriodTooShort, *d.UpdatePeriod, minUpdatePeriod)
|
||||
}
|
||||
|
||||
if d.UpstreamType == DNSUpstreamTypePlain {
|
||||
@@ -83,11 +81,9 @@ func (d DNS) validate() (err error) {
|
||||
}
|
||||
switch {
|
||||
case *d.IPv6 && !selectedHasPlainIPv6:
|
||||
return fmt.Errorf("upstream plain addresses do not contain any IPv6 address: "+
|
||||
"in %d addresses", len(d.UpstreamPlainAddresses))
|
||||
return fmt.Errorf("%w: in %d addresses", ErrDNSUpstreamPlainNoIPv6, len(d.UpstreamPlainAddresses))
|
||||
case !*d.IPv6 && !selectedHasPlainIPv4:
|
||||
return fmt.Errorf("upstream plain addresses do not contain any IPv4 address: "+
|
||||
"in %d addresses", len(d.UpstreamPlainAddresses))
|
||||
return fmt.Errorf("%w: in %d addresses", ErrDNSUpstreamPlainNoIPv4, len(d.UpstreamPlainAddresses))
|
||||
}
|
||||
}
|
||||
// Note: all DNS built in providers have both IPv4 and IPv6 addresses for all modes
|
||||
@@ -100,26 +96,8 @@ func (d DNS) validate() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d DNS) validateForServerOff() (err error) {
|
||||
switch {
|
||||
case d.UpstreamType != DNSUpstreamTypePlain:
|
||||
return fmt.Errorf("upstream type %s must be %s if the built-in DNS server is disabled",
|
||||
d.UpstreamType, DNSUpstreamTypePlain)
|
||||
case len(d.UpstreamPlainAddresses) == 0:
|
||||
return fmt.Errorf("if DNS is disabled, at least one upstream plain address must be set")
|
||||
}
|
||||
for _, addrPort := range d.UpstreamPlainAddresses {
|
||||
const defaultDNSPort = 53
|
||||
if addrPort.Port() != defaultDNSPort {
|
||||
return fmt.Errorf("invalid DNS port in %s: must be %d", addrPort, defaultDNSPort)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DNS) Copy() (copied DNS) {
|
||||
return DNS{
|
||||
ServerEnabled: gosettings.CopyPointer(d.ServerEnabled),
|
||||
UpstreamType: d.UpstreamType,
|
||||
UpdatePeriod: gosettings.CopyPointer(d.UpdatePeriod),
|
||||
Providers: gosettings.CopySlice(d.Providers),
|
||||
@@ -134,7 +112,6 @@ func (d *DNS) Copy() (copied DNS) {
|
||||
// settings object with any field set in the other
|
||||
// settings.
|
||||
func (d *DNS) overrideWith(other DNS) {
|
||||
d.ServerEnabled = gosettings.OverrideWithPointer(d.ServerEnabled, other.ServerEnabled)
|
||||
d.UpstreamType = gosettings.OverrideWithComparable(d.UpstreamType, other.UpstreamType)
|
||||
d.UpdatePeriod = gosettings.OverrideWithPointer(d.UpdatePeriod, other.UpdatePeriod)
|
||||
d.Providers = gosettings.OverrideWithSlice(d.Providers, other.Providers)
|
||||
@@ -145,12 +122,7 @@ func (d *DNS) overrideWith(other DNS) {
|
||||
}
|
||||
|
||||
func (d *DNS) setDefaults() {
|
||||
d.ServerEnabled = gosettings.DefaultPointer(d.ServerEnabled, true)
|
||||
defaultUpstreamType := DNSUpstreamTypeDot
|
||||
if !*d.ServerEnabled {
|
||||
defaultUpstreamType = DNSUpstreamTypePlain
|
||||
}
|
||||
d.UpstreamType = gosettings.DefaultComparable(d.UpstreamType, defaultUpstreamType)
|
||||
d.UpstreamType = gosettings.DefaultComparable(d.UpstreamType, DNSUpstreamTypeDot)
|
||||
const defaultUpdatePeriod = 24 * time.Hour
|
||||
d.UpdatePeriod = gosettings.DefaultPointer(d.UpdatePeriod, defaultUpdatePeriod)
|
||||
d.UpstreamPlainAddresses = gosettings.DefaultSlice(d.UpstreamPlainAddresses, []netip.AddrPort{})
|
||||
@@ -173,14 +145,6 @@ func (d DNS) String() string {
|
||||
func (d DNS) toLinesNode() (node *gotree.Node) {
|
||||
node = gotree.New("DNS settings:")
|
||||
|
||||
if !*d.ServerEnabled {
|
||||
plainServers := node.Append("Plain DNS servers to use directly:")
|
||||
for _, addr := range d.UpstreamPlainAddresses {
|
||||
plainServers.Append(addr.String())
|
||||
}
|
||||
return node
|
||||
}
|
||||
|
||||
node.Appendf("Upstream resolver type: %s", d.UpstreamType)
|
||||
|
||||
upstreamResolvers := node.Append("Upstream resolvers:")
|
||||
@@ -216,11 +180,6 @@ func (d DNS) toLinesNode() (node *gotree.Node) {
|
||||
}
|
||||
|
||||
func (d *DNS) read(r *reader.Reader) (err error) {
|
||||
d.ServerEnabled, err = r.BoolPtr("DNS_SERVER", reader.RetroKeys("DOT"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
d.UpstreamType = r.String("DNS_UPSTREAM_RESOLVER_TYPE")
|
||||
|
||||
d.UpdatePeriod, err = r.DurationPtr("DNS_UPDATE_PERIOD")
|
||||
@@ -254,7 +213,7 @@ func (d *DNS) read(r *reader.Reader) (err error) {
|
||||
}
|
||||
|
||||
func (d *DNS) readUpstreamPlainAddresses(r *reader.Reader) (err error) {
|
||||
// If DNS_UPSTREAM_PLAIN_ADDRESSES is set, the user must also set DNS_UPSTREAM_RESOLVER_TYPE=plain
|
||||
// If DNS_UPSTREAM_PLAIN_ADDRESSES is set, the user must also set DNS_UPSTREAM_TYPE=plain
|
||||
// for these to be used. This is an added safety measure to reduce misunderstandings, and
|
||||
// reduce odd settings overrides.
|
||||
d.UpstreamPlainAddresses, err = r.CSVNetipAddrPorts("DNS_UPSTREAM_PLAIN_ADDRESSES")
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package settings
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
@@ -36,16 +37,22 @@ func (b *DNSBlacklist) setDefaults() {
|
||||
|
||||
var hostRegex = regexp.MustCompile(`^([a-zA-Z0-9]|[a-zA-Z0-9_][a-zA-Z0-9\-_]{0,61}[a-zA-Z0-9_])(\.([a-zA-Z0-9]|[a-zA-Z0-9_][a-zA-Z0-9\-_]{0,61}[a-zA-Z0-9]))*$`) //nolint:lll
|
||||
|
||||
var (
|
||||
ErrAllowedHostNotValid = errors.New("allowed host is not valid")
|
||||
ErrBlockedHostNotValid = errors.New("blocked host is not valid")
|
||||
ErrRebindingProtectionExemptHostNotValid = errors.New("rebinding protection exempt host is not valid")
|
||||
)
|
||||
|
||||
func (b DNSBlacklist) validate() (err error) {
|
||||
for _, host := range b.AllowedHosts {
|
||||
if !hostRegex.MatchString(host) {
|
||||
return fmt.Errorf("allowed host is not valid: %s", host)
|
||||
return fmt.Errorf("%w: %s", ErrAllowedHostNotValid, host)
|
||||
}
|
||||
}
|
||||
|
||||
for _, host := range b.AddBlockedHosts {
|
||||
if !hostRegex.MatchString(host) {
|
||||
return fmt.Errorf("blocked host is not valid: %s", host)
|
||||
return fmt.Errorf("%w: %s", ErrBlockedHostNotValid, host)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -54,7 +61,7 @@ func (b DNSBlacklist) validate() (err error) {
|
||||
host = host[2:]
|
||||
}
|
||||
if !hostRegex.MatchString(host) {
|
||||
return fmt.Errorf("rebinding protection exempt host is not valid: %s", host)
|
||||
return fmt.Errorf("%w: %s", ErrRebindingProtectionExemptHostNotValid, host)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -202,6 +209,8 @@ func readDNSBlockedIPs(r *reader.Reader) (ips []netip.Addr,
|
||||
return ips, ipPrefixes, nil
|
||||
}
|
||||
|
||||
var ErrPrivateAddressNotValid = errors.New("private address is not a valid IP or CIDR range")
|
||||
|
||||
func readDNSPrivateAddresses(r *reader.Reader) (ips []netip.Addr,
|
||||
ipPrefixes []netip.Prefix, err error,
|
||||
) {
|
||||
@@ -227,9 +236,8 @@ func readDNSPrivateAddresses(r *reader.Reader) (ips []netip.Addr,
|
||||
}
|
||||
|
||||
return nil, nil, fmt.Errorf(
|
||||
"environment variable DOT_PRIVATE_ADDRESS: "+
|
||||
"private address is not a valid IP or CIDR range: %s",
|
||||
privateAddress)
|
||||
"environment variable DOT_PRIVATE_ADDRESS: %w: %s",
|
||||
ErrPrivateAddressNotValid, privateAddress)
|
||||
}
|
||||
|
||||
return ips, ipPrefixes, nil
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
package settings
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrValueUnknown = errors.New("value is unknown")
|
||||
ErrCityNotValid = errors.New("the city specified is not valid")
|
||||
ErrControlServerPrivilegedPort = errors.New("cannot use privileged port without running as root")
|
||||
ErrCategoryNotValid = errors.New("the category specified is not valid")
|
||||
ErrCountryNotValid = errors.New("the country specified is not valid")
|
||||
ErrFilepathMissing = errors.New("filepath is missing")
|
||||
ErrFirewallZeroPort = errors.New("cannot have a zero port")
|
||||
ErrFirewallPublicOutboundSubnet = errors.New("outbound subnet has an unspecified address")
|
||||
ErrHostnameNotValid = errors.New("the hostname specified is not valid")
|
||||
ErrISPNotValid = errors.New("the ISP specified is not valid")
|
||||
ErrMinRatioNotValid = errors.New("minimum ratio is not valid")
|
||||
ErrMissingValue = errors.New("missing value")
|
||||
ErrNameNotValid = errors.New("the server name specified is not valid")
|
||||
ErrOpenVPNClientKeyMissing = errors.New("client key is missing")
|
||||
ErrOpenVPNCustomPortNotAllowed = errors.New("custom endpoint port is not allowed")
|
||||
ErrOpenVPNEncryptionPresetNotValid = errors.New("PIA encryption preset is not valid")
|
||||
ErrOpenVPNInterfaceNotValid = errors.New("interface name is not valid")
|
||||
ErrOpenVPNKeyPassphraseIsEmpty = errors.New("key passphrase is empty")
|
||||
ErrOpenVPNMSSFixIsTooHigh = errors.New("mssfix option value is too high")
|
||||
ErrOpenVPNPasswordIsEmpty = errors.New("password is empty")
|
||||
ErrOpenVPNTCPNotSupported = errors.New("TCP protocol is not supported")
|
||||
ErrOpenVPNUserIsEmpty = errors.New("user is empty")
|
||||
ErrOpenVPNVerbosityIsOutOfBounds = errors.New("verbosity value is out of bounds")
|
||||
ErrOpenVPNVersionIsNotValid = errors.New("version is not valid")
|
||||
ErrPortForwardingEnabled = errors.New("port forwarding cannot be enabled")
|
||||
ErrPortForwardingUserEmpty = errors.New("port forwarding username is empty")
|
||||
ErrPortForwardingPasswordEmpty = errors.New("port forwarding password is empty")
|
||||
ErrRegionNotValid = errors.New("the region specified is not valid")
|
||||
ErrServerAddressNotValid = errors.New("server listening address is not valid")
|
||||
ErrSystemPGIDNotValid = errors.New("process group id is not valid")
|
||||
ErrSystemPUIDNotValid = errors.New("process user id is not valid")
|
||||
ErrSystemTimezoneNotValid = errors.New("timezone is not valid")
|
||||
ErrUpdaterPeriodTooSmall = errors.New("VPN server data updater period is too small")
|
||||
ErrUpdaterProtonPasswordMissing = errors.New("proton password is missing")
|
||||
ErrUpdaterProtonEmailMissing = errors.New("proton email is missing")
|
||||
ErrVPNProviderNameNotValid = errors.New("VPN provider name is not valid")
|
||||
ErrVPNTypeNotValid = errors.New("VPN type is not valid")
|
||||
ErrWireguardAllowedIPNotSet = errors.New("allowed IP is not set")
|
||||
ErrWireguardAllowedIPsNotSet = errors.New("allowed IPs is not set")
|
||||
ErrWireguardEndpointIPNotSet = errors.New("endpoint IP is not set")
|
||||
ErrWireguardEndpointPortNotAllowed = errors.New("endpoint port is not allowed")
|
||||
ErrWireguardEndpointPortNotSet = errors.New("endpoint port is not set")
|
||||
ErrWireguardEndpointPortSet = errors.New("endpoint port is set")
|
||||
ErrWireguardInterfaceAddressNotSet = errors.New("interface address is not set")
|
||||
ErrWireguardInterfaceAddressIPv6 = errors.New("interface address is IPv6 but IPv6 is not supported")
|
||||
ErrWireguardInterfaceNotValid = errors.New("interface name is not valid")
|
||||
ErrWireguardPreSharedKeyNotSet = errors.New("pre-shared key is not set")
|
||||
ErrWireguardPrivateKeyNotSet = errors.New("private key is not set")
|
||||
ErrWireguardPublicKeyNotSet = errors.New("public key is not set")
|
||||
ErrWireguardPublicKeyNotValid = errors.New("public key is not valid")
|
||||
ErrWireguardKeepAliveNegative = errors.New("persistent keep alive interval is negative")
|
||||
ErrWireguardImplementationNotValid = errors.New("implementation is not valid")
|
||||
)
|
||||
@@ -1,7 +1,6 @@
|
||||
package settings
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
@@ -21,16 +20,16 @@ type Firewall struct {
|
||||
|
||||
func (f Firewall) validate() (err error) {
|
||||
if hasZeroPort(f.VPNInputPorts) {
|
||||
return errors.New("VPN input ports: cannot have a zero port")
|
||||
return fmt.Errorf("VPN input ports: %w", ErrFirewallZeroPort)
|
||||
}
|
||||
|
||||
if hasZeroPort(f.InputPorts) {
|
||||
return errors.New("input ports: cannot have a zero port")
|
||||
return fmt.Errorf("input ports: %w", ErrFirewallZeroPort)
|
||||
}
|
||||
|
||||
for _, subnet := range f.OutboundSubnets {
|
||||
if subnet.Addr().IsUnspecified() {
|
||||
return fmt.Errorf("outbound subnet has an unspecified address: %s", subnet)
|
||||
return fmt.Errorf("%w: %s", ErrFirewallPublicOutboundSubnet, subnet)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -13,21 +13,25 @@ func Test_Firewall_validate(t *testing.T) {
|
||||
|
||||
testCases := map[string]struct {
|
||||
firewall Firewall
|
||||
errWrapped error
|
||||
errMessage string
|
||||
}{
|
||||
"empty": {
|
||||
errWrapped: log.ErrLevelNotRecognized,
|
||||
errMessage: "iptables settings: log level: level is not recognized: ",
|
||||
},
|
||||
"zero_vpn_input_port": {
|
||||
firewall: Firewall{
|
||||
VPNInputPorts: []uint16{0},
|
||||
},
|
||||
errWrapped: ErrFirewallZeroPort,
|
||||
errMessage: "VPN input ports: cannot have a zero port",
|
||||
},
|
||||
"zero_input_port": {
|
||||
firewall: Firewall{
|
||||
InputPorts: []uint16{0},
|
||||
},
|
||||
errWrapped: ErrFirewallZeroPort,
|
||||
errMessage: "input ports: cannot have a zero port",
|
||||
},
|
||||
"unspecified_outbound_subnet": {
|
||||
@@ -36,6 +40,7 @@ func Test_Firewall_validate(t *testing.T) {
|
||||
netip.MustParsePrefix("0.0.0.0/0"),
|
||||
},
|
||||
},
|
||||
errWrapped: ErrFirewallPublicOutboundSubnet,
|
||||
errMessage: "outbound subnet has an unspecified address: 0.0.0.0/0",
|
||||
},
|
||||
"public_outbound_subnet": {
|
||||
@@ -65,10 +70,9 @@ func Test_Firewall_validate(t *testing.T) {
|
||||
|
||||
err := testCase.firewall.validate()
|
||||
|
||||
if testCase.errMessage != "" {
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if testCase.errWrapped != nil {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -38,6 +38,12 @@ type Health struct {
|
||||
RestartVPN *bool
|
||||
}
|
||||
|
||||
var (
|
||||
ErrICMPTargetIPNotValid = errors.New("ICMP target IP address is not valid")
|
||||
ErrICMPTargetIPsNotCompatible = errors.New("ICMP target IP addresses are not compatible")
|
||||
ErrSmallCheckTypeNotValid = errors.New("small check type is not valid")
|
||||
)
|
||||
|
||||
func (h Health) Validate() (err error) {
|
||||
err = validate.ListeningAddress(h.ServerAddress, os.Getuid())
|
||||
if err != nil {
|
||||
@@ -47,16 +53,16 @@ func (h Health) Validate() (err error) {
|
||||
for _, ip := range h.ICMPTargetIPs {
|
||||
switch {
|
||||
case !ip.IsValid():
|
||||
return fmt.Errorf("ICMP target IP address is not valid: %s", ip)
|
||||
return fmt.Errorf("%w: %s", ErrICMPTargetIPNotValid, ip)
|
||||
case ip.IsUnspecified() && len(h.ICMPTargetIPs) > 1:
|
||||
return errors.New("ICMP target IP addresses are not compatible: " +
|
||||
"only a single IP address must be set if it is to be unspecified")
|
||||
return fmt.Errorf("%w: only a single IP address must be set if it is to be unspecified",
|
||||
ErrICMPTargetIPsNotCompatible)
|
||||
}
|
||||
}
|
||||
|
||||
err = validate.IsOneOf(h.SmallCheckType, "icmp", "dns")
|
||||
if err != nil {
|
||||
return fmt.Errorf("small check type is not valid: %w", err)
|
||||
return fmt.Errorf("%w: %s", ErrSmallCheckTypeNotValid, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -48,7 +48,7 @@ func (h HTTPProxy) validate() (err error) {
|
||||
// Do not validate user and password
|
||||
err = validate.ListeningAddress(h.ListeningAddress, os.Getuid())
|
||||
if err != nil {
|
||||
return fmt.Errorf("server listening address is not valid: %w", err)
|
||||
return fmt.Errorf("%w: %s", ErrServerAddressNotValid, h.ListeningAddress)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -176,6 +176,7 @@ func readHTTProxyLog(r *reader.Reader) (enabled *bool, err error) {
|
||||
case "disabled", "no", "off":
|
||||
return ptrTo(false), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("HTTP retro-compatible proxy log setting: value is unknown: %s", value)
|
||||
return nil, fmt.Errorf("HTTP retro-compatible proxy log setting: %w: %s",
|
||||
ErrValueUnknown, value)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package settings
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
@@ -93,7 +92,7 @@ func (o OpenVPN) validate(vpnProvider string) (err error) {
|
||||
// Validate version
|
||||
validVersions := []string{openvpn.Openvpn25, openvpn.Openvpn26}
|
||||
if err = validate.IsOneOf(o.Version, validVersions...); err != nil {
|
||||
return fmt.Errorf("version is not valid: %w", err)
|
||||
return fmt.Errorf("%w: %w", ErrOpenVPNVersionIsNotValid, err)
|
||||
}
|
||||
|
||||
isCustom := vpnProvider == providers.Custom
|
||||
@@ -102,14 +101,14 @@ func (o OpenVPN) validate(vpnProvider string) (err error) {
|
||||
vpnProvider != providers.VPNSecure
|
||||
|
||||
if isUserRequired && *o.User == "" {
|
||||
return errors.New("user is empty")
|
||||
return fmt.Errorf("%w", ErrOpenVPNUserIsEmpty)
|
||||
}
|
||||
|
||||
passwordRequired := isUserRequired &&
|
||||
(vpnProvider != providers.Ivpn || !ivpnAccountID.MatchString(*o.User))
|
||||
|
||||
if passwordRequired && *o.Password == "" {
|
||||
return errors.New("password is empty")
|
||||
return fmt.Errorf("%w", ErrOpenVPNPasswordIsEmpty)
|
||||
}
|
||||
|
||||
err = validateOpenVPNConfigFilepath(isCustom, *o.ConfFile)
|
||||
@@ -133,20 +132,23 @@ func (o OpenVPN) validate(vpnProvider string) (err error) {
|
||||
}
|
||||
|
||||
if *o.EncryptedKey != "" && *o.KeyPassphrase == "" {
|
||||
return errors.New("key passphrase is empty")
|
||||
return fmt.Errorf("%w", ErrOpenVPNKeyPassphraseIsEmpty)
|
||||
}
|
||||
|
||||
const maxMSSFix = 10000
|
||||
if *o.MSSFix > maxMSSFix {
|
||||
return fmt.Errorf("mssfix option value is too high: %d is over the maximum value of %d", *o.MSSFix, maxMSSFix)
|
||||
return fmt.Errorf("%w: %d is over the maximum value of %d",
|
||||
ErrOpenVPNMSSFixIsTooHigh, *o.MSSFix, maxMSSFix)
|
||||
}
|
||||
|
||||
if !regexpInterfaceName.MatchString(o.Interface) {
|
||||
return fmt.Errorf("interface name is not valid: '%s' does not match regex '%s'", o.Interface, regexpInterfaceName)
|
||||
return fmt.Errorf("%w: '%s' does not match regex '%s'",
|
||||
ErrOpenVPNInterfaceNotValid, o.Interface, regexpInterfaceName)
|
||||
}
|
||||
|
||||
if *o.Verbosity < 0 || *o.Verbosity > 6 {
|
||||
return fmt.Errorf("verbosity value is out of bounds: %d can only be between 0 and 5", o.Verbosity)
|
||||
return fmt.Errorf("%w: %d can only be between 0 and 5",
|
||||
ErrOpenVPNVerbosityIsOutOfBounds, o.Verbosity)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -160,7 +162,7 @@ func validateOpenVPNConfigFilepath(isCustom bool,
|
||||
}
|
||||
|
||||
if confFile == "" {
|
||||
return errors.New("filepath is missing")
|
||||
return fmt.Errorf("%w", ErrFilepathMissing)
|
||||
}
|
||||
|
||||
err = validate.FileExists(confFile)
|
||||
@@ -187,7 +189,7 @@ func validateOpenVPNClientCertificate(vpnProvider,
|
||||
providers.VPNSecure,
|
||||
providers.VPNUnlimited:
|
||||
if clientCert == "" {
|
||||
return errors.New("missing value")
|
||||
return fmt.Errorf("%w", ErrMissingValue)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -209,7 +211,7 @@ func validateOpenVPNClientKey(vpnProvider, clientKey string) (err error) {
|
||||
providers.Cyberghost,
|
||||
providers.VPNUnlimited:
|
||||
if clientKey == "" {
|
||||
return errors.New("missing value")
|
||||
return fmt.Errorf("%w", ErrMissingValue)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -228,7 +230,7 @@ func validateOpenVPNEncryptedKey(vpnProvider,
|
||||
encryptedPrivateKey string,
|
||||
) (err error) {
|
||||
if vpnProvider == providers.VPNSecure && encryptedPrivateKey == "" {
|
||||
return errors.New("missing value")
|
||||
return fmt.Errorf("%w", ErrMissingValue)
|
||||
}
|
||||
|
||||
if encryptedPrivateKey == "" {
|
||||
|
||||
@@ -62,7 +62,8 @@ func (o OpenVPNSelection) validate(vpnProvider string) (err error) {
|
||||
providers.Perfectprivacy,
|
||||
providers.Vyprvpn,
|
||||
) {
|
||||
return fmt.Errorf("TCP protocol is not supported: for VPN service provider %s", vpnProvider)
|
||||
return fmt.Errorf("%w: for VPN service provider %s",
|
||||
ErrOpenVPNTCPNotSupported, vpnProvider)
|
||||
}
|
||||
|
||||
// Validate CustomPort
|
||||
@@ -77,7 +78,8 @@ func (o OpenVPNSelection) validate(vpnProvider string) (err error) {
|
||||
providers.Nordvpn, providers.Purevpn,
|
||||
providers.Surfshark, providers.VPNSecure,
|
||||
providers.VPNUnlimited, providers.Vyprvpn:
|
||||
return fmt.Errorf("custom endpoint port is not allowed: for VPN service provider %s", vpnProvider)
|
||||
return fmt.Errorf("%w: for VPN service provider %s",
|
||||
ErrOpenVPNCustomPortNotAllowed, vpnProvider)
|
||||
default:
|
||||
var allowedTCP, allowedUDP []uint16
|
||||
switch vpnProvider {
|
||||
@@ -100,7 +102,7 @@ func (o OpenVPNSelection) validate(vpnProvider string) (err error) {
|
||||
allowedTCP = []uint16{443, 1194, 8080, 8443}
|
||||
allowedUDP = []uint16{443, 1194, 8080, 8443}
|
||||
case providers.PrivateInternetAccess:
|
||||
allowedTCP = []uint16{80, 110, 443, 501, 502, 8443}
|
||||
allowedTCP = []uint16{80, 110, 443}
|
||||
allowedUDP = []uint16{53, 1194, 1197, 1198, 8080, 9201}
|
||||
case providers.Protonvpn:
|
||||
allowedTCP = []uint16{443, 5995, 8443}
|
||||
@@ -121,7 +123,8 @@ func (o OpenVPNSelection) validate(vpnProvider string) (err error) {
|
||||
}
|
||||
err = validate.IsOneOf(*o.CustomPort, allowedPorts...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("custom endpoint port is not allowed: for VPN service provider %s: %w", vpnProvider, err)
|
||||
return fmt.Errorf("%w: for VPN service provider %s: %w",
|
||||
ErrOpenVPNCustomPortNotAllowed, vpnProvider, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -133,7 +136,7 @@ func (o OpenVPNSelection) validate(vpnProvider string) (err error) {
|
||||
presets.Strong,
|
||||
}
|
||||
if err = validate.IsOneOf(*o.PIAEncPreset, validEncryptionPresets...); err != nil {
|
||||
return fmt.Errorf("PIA encryption preset is not valid: %w", err)
|
||||
return fmt.Errorf("%w: %w", ErrOpenVPNEncryptionPresetNotValid, err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package settings
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
@@ -23,16 +24,21 @@ type PMTUD struct {
|
||||
TCPAddresses []netip.AddrPort `json:"tcp_addresses"`
|
||||
}
|
||||
|
||||
var (
|
||||
ErrPMTUDICMPAddressNotValid = errors.New("PMTUD ICMP address is not valid")
|
||||
ErrPMTUDTCPAddressNotValid = errors.New("PMTUD TCP address is not valid")
|
||||
)
|
||||
|
||||
// Validate validates PMTUD settings.
|
||||
func (p PMTUD) validate() (err error) {
|
||||
for i, addr := range p.ICMPAddresses {
|
||||
if !addr.IsValid() {
|
||||
return fmt.Errorf("PMTUD ICMP address is not valid: at index %d", i)
|
||||
return fmt.Errorf("%w: at index %d", ErrPMTUDICMPAddressNotValid, i)
|
||||
}
|
||||
}
|
||||
for i, addr := range p.TCPAddresses {
|
||||
if !addr.IsValid() {
|
||||
return fmt.Errorf("PMTUD TCP address is not valid: at index %d", i)
|
||||
return fmt.Errorf("%w: at index %d", ErrPMTUDTCPAddressNotValid, i)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -55,6 +55,12 @@ type PortForwarding struct {
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
var (
|
||||
ErrPortsCountTooHigh = errors.New("ports count too high")
|
||||
ErrListeningPortsLen = errors.New("listening ports length must be equal to ports count")
|
||||
ErrListeningPortZero = errors.New("listening port cannot be 0")
|
||||
)
|
||||
|
||||
func (p PortForwarding) Validate(vpnProvider string) (err error) {
|
||||
if !*p.Enabled {
|
||||
return nil
|
||||
@@ -72,7 +78,7 @@ func (p PortForwarding) Validate(vpnProvider string) (err error) {
|
||||
providers.Protonvpn,
|
||||
}
|
||||
if err = validate.IsOneOf(providerSelected, validProviders...); err != nil {
|
||||
return fmt.Errorf("port forwarding cannot be enabled: %w", err)
|
||||
return fmt.Errorf("%w: %w", ErrPortForwardingEnabled, err)
|
||||
}
|
||||
|
||||
// Validate Filepath
|
||||
@@ -88,31 +94,30 @@ func (p PortForwarding) Validate(vpnProvider string) (err error) {
|
||||
const maxPortsCount = 1
|
||||
switch {
|
||||
case p.PortsCount > maxPortsCount:
|
||||
return fmt.Errorf("ports count too high: %d > %d", p.PortsCount, maxPortsCount)
|
||||
return fmt.Errorf("%w: %d > %d", ErrPortsCountTooHigh, p.PortsCount, maxPortsCount)
|
||||
case p.Username == "":
|
||||
return errors.New("port forwarding username is empty")
|
||||
return fmt.Errorf("%w", ErrPortForwardingUserEmpty)
|
||||
case p.Password == "":
|
||||
return errors.New("port forwarding password is empty")
|
||||
return fmt.Errorf("%w", ErrPortForwardingPasswordEmpty)
|
||||
}
|
||||
case providers.Protonvpn:
|
||||
const maxPortsCount = 4
|
||||
if p.PortsCount > maxPortsCount {
|
||||
return fmt.Errorf("ports count too high: %d > %d", p.PortsCount, maxPortsCount)
|
||||
return fmt.Errorf("%w: %d > %d", ErrPortsCountTooHigh, p.PortsCount, maxPortsCount)
|
||||
}
|
||||
default:
|
||||
const maxPortsCount = 1
|
||||
if p.PortsCount > maxPortsCount {
|
||||
return fmt.Errorf("ports count too high: %d > %d", p.PortsCount, maxPortsCount)
|
||||
return fmt.Errorf("%w: %d > %d", ErrPortsCountTooHigh, p.PortsCount, maxPortsCount)
|
||||
}
|
||||
}
|
||||
|
||||
if !slices.Equal(p.ListeningPorts, []uint16{0}) {
|
||||
switch {
|
||||
case len(p.ListeningPorts) != int(p.PortsCount):
|
||||
return fmt.Errorf("listening ports length must be equal to ports count: "+
|
||||
"%d != %d", len(p.ListeningPorts), p.PortsCount)
|
||||
return fmt.Errorf("%w: %d != %d", ErrListeningPortsLen, len(p.ListeningPorts), p.PortsCount)
|
||||
case slices.Contains(p.ListeningPorts, 0):
|
||||
return fmt.Errorf("listening port cannot be 0: in %v", p.ListeningPorts)
|
||||
return fmt.Errorf("%w: in %v", ErrListeningPortZero, p.ListeningPorts)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@ func (p *Provider) validate(vpnType string, filterChoicesGetter FilterChoicesGet
|
||||
}
|
||||
}
|
||||
if err = validate.IsOneOf(p.Name, validNames...); err != nil {
|
||||
return fmt.Errorf("VPN provider name is not valid for %s: %w", vpnType, err)
|
||||
return fmt.Errorf("%w for %s: %w", ErrVPNProviderNameNotValid, vpnType, err)
|
||||
}
|
||||
|
||||
err = p.ServerSelection.validate(p.Name, filterChoicesGetter, warner)
|
||||
|
||||
@@ -15,6 +15,7 @@ func Test_PublicIP_read(t *testing.T) {
|
||||
makeReader func(ctrl *gomock.Controller) *reader.Reader
|
||||
makeWarner func(ctrl *gomock.Controller) Warner
|
||||
settings PublicIP
|
||||
errWrapped error
|
||||
errMessage string
|
||||
}{
|
||||
"nothing_read": {
|
||||
@@ -151,10 +152,9 @@ func Test_PublicIP_read(t *testing.T) {
|
||||
err := settings.read(reader, warner)
|
||||
|
||||
assert.Equal(t, testCase.settings, settings)
|
||||
if testCase.errMessage != "" {
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if testCase.errWrapped != nil {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -46,7 +46,8 @@ func (c ControlServer) validate() (err error) {
|
||||
uid := os.Getuid()
|
||||
const maxPrivilegedPort = 1023
|
||||
if uid != 0 && port != 0 && port <= maxPrivilegedPort {
|
||||
return fmt.Errorf("cannot use privileged port without running as root: %d when running with user ID %d", port, uid)
|
||||
return fmt.Errorf("%w: %d when running with user ID %d",
|
||||
ErrControlServerPrivilegedPort, port, uid)
|
||||
}
|
||||
|
||||
jsonDecoder := json.NewDecoder(bytes.NewBufferString(c.AuthDefaultRole))
|
||||
|
||||
@@ -71,13 +71,25 @@ type ServerSelection struct {
|
||||
Wireguard WireguardSelection `json:"wireguard"`
|
||||
}
|
||||
|
||||
var (
|
||||
ErrOwnedOnlyNotSupported = errors.New("owned only filter is not supported")
|
||||
ErrFreeOnlyNotSupported = errors.New("free only filter is not supported")
|
||||
ErrPremiumOnlyNotSupported = errors.New("premium only filter is not supported")
|
||||
ErrStreamOnlyNotSupported = errors.New("stream only filter is not supported")
|
||||
ErrMultiHopOnlyNotSupported = errors.New("multi hop only filter is not supported")
|
||||
ErrPortForwardOnlyNotSupported = errors.New("port forwarding only filter is not supported")
|
||||
ErrFreePremiumBothSet = errors.New("free only and premium only filters are both set")
|
||||
ErrSecureCoreOnlyNotSupported = errors.New("secure core only filter is not supported")
|
||||
ErrTorOnlyNotSupported = errors.New("tor only filter is not supported")
|
||||
)
|
||||
|
||||
func (ss *ServerSelection) validate(vpnServiceProvider string,
|
||||
filterChoicesGetter FilterChoicesGetter, warner Warner,
|
||||
) (err error) {
|
||||
switch ss.VPN {
|
||||
case vpn.AmneziaWg, vpn.OpenVPN, vpn.Wireguard:
|
||||
default:
|
||||
return fmt.Errorf("VPN type is not valid: %s", ss.VPN)
|
||||
return fmt.Errorf("%w: %s", ErrVPNTypeNotValid, ss.VPN)
|
||||
}
|
||||
|
||||
filterChoices, err := getLocationFilterChoices(vpnServiceProvider, ss, filterChoicesGetter, warner)
|
||||
@@ -138,7 +150,7 @@ func getLocationFilterChoices(vpnServiceProvider string,
|
||||
// Only return error comparing with newer regions, we don't want to confuse the user
|
||||
// with the retro regions in the error message.
|
||||
err = atLeastOneIsOneOfCaseInsensitive(ss.Regions, filterChoices.Regions, warner)
|
||||
return models.FilterChoices{}, fmt.Errorf("the region specified is not valid: %w", err)
|
||||
return models.FilterChoices{}, fmt.Errorf("%w: %w", ErrRegionNotValid, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -152,27 +164,27 @@ func validateServerFilters(settings ServerSelection, filterChoices models.Filter
|
||||
) (err error) {
|
||||
err = atLeastOneIsOneOfCaseInsensitive(settings.Countries, filterChoices.Countries, warner)
|
||||
if err != nil {
|
||||
return fmt.Errorf("the country specified is not valid: %w", err)
|
||||
return fmt.Errorf("%w: %w", ErrCountryNotValid, err)
|
||||
}
|
||||
|
||||
err = atLeastOneIsOneOfCaseInsensitive(settings.Regions, filterChoices.Regions, warner)
|
||||
if err != nil {
|
||||
return fmt.Errorf("the region specified is not valid: %w", err)
|
||||
return fmt.Errorf("%w: %w", ErrRegionNotValid, err)
|
||||
}
|
||||
|
||||
err = atLeastOneIsOneOfCaseInsensitive(settings.Cities, filterChoices.Cities, warner)
|
||||
if err != nil {
|
||||
return fmt.Errorf("the city specified is not valid: %w", err)
|
||||
return fmt.Errorf("%w: %w", ErrCityNotValid, err)
|
||||
}
|
||||
|
||||
err = atLeastOneIsOneOfCaseInsensitive(settings.ISPs, filterChoices.ISPs, warner)
|
||||
if err != nil {
|
||||
return fmt.Errorf("the ISP specified is not valid: %w", err)
|
||||
return fmt.Errorf("%w: %w", ErrISPNotValid, err)
|
||||
}
|
||||
|
||||
err = atLeastOneIsOneOfCaseInsensitive(settings.Hostnames, filterChoices.Hostnames, warner)
|
||||
if err != nil {
|
||||
return fmt.Errorf("the hostname specified is not valid: %w", err)
|
||||
return fmt.Errorf("%w: %w", ErrHostnameNotValid, err)
|
||||
}
|
||||
|
||||
if vpnServiceProvider == providers.Custom {
|
||||
@@ -184,19 +196,19 @@ func validateServerFilters(settings ServerSelection, filterChoices models.Filter
|
||||
// which requires a server name for TLS verification.
|
||||
filterChoices.Names = settings.Names
|
||||
default:
|
||||
return fmt.Errorf("name is not valid: "+
|
||||
"%d names specified instead of 0 or 1 for the custom provider",
|
||||
len(settings.Names))
|
||||
return fmt.Errorf("%w: %d names specified instead of "+
|
||||
"0 or 1 for the custom provider",
|
||||
ErrNameNotValid, len(settings.Names))
|
||||
}
|
||||
}
|
||||
err = atLeastOneIsOneOfCaseInsensitive(settings.Names, filterChoices.Names, warner)
|
||||
if err != nil {
|
||||
return fmt.Errorf("the server name specified is not valid: %w", err)
|
||||
return fmt.Errorf("%w: %w", ErrNameNotValid, err)
|
||||
}
|
||||
|
||||
err = atLeastOneIsOneOfCaseInsensitive(settings.Categories, filterChoices.Categories, warner)
|
||||
if err != nil {
|
||||
return fmt.Errorf("the category specified is not valid: %w", err)
|
||||
return fmt.Errorf("%w: %w", ErrCategoryNotValid, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -243,12 +255,12 @@ func validateSubscriptionTierFilters(settings ServerSelection, vpnServiceProvide
|
||||
switch {
|
||||
case *settings.FreeOnly &&
|
||||
!helpers.IsOneOf(vpnServiceProvider, providers.Protonvpn, providers.VPNUnlimited):
|
||||
return errors.New("free only filter is not supported")
|
||||
return fmt.Errorf("%w", ErrFreeOnlyNotSupported)
|
||||
case *settings.PremiumOnly &&
|
||||
!helpers.IsOneOf(vpnServiceProvider, providers.VPNSecure):
|
||||
return errors.New("premium only filter is not supported")
|
||||
return fmt.Errorf("%w", ErrPremiumOnlyNotSupported)
|
||||
case *settings.FreeOnly && *settings.PremiumOnly:
|
||||
return errors.New("free only and premium only filters are both set")
|
||||
return fmt.Errorf("%w", ErrFreePremiumBothSet)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
@@ -257,21 +269,21 @@ func validateSubscriptionTierFilters(settings ServerSelection, vpnServiceProvide
|
||||
func validateFeatureFilters(settings ServerSelection, vpnServiceProvider string) error {
|
||||
switch {
|
||||
case *settings.OwnedOnly && vpnServiceProvider != providers.Mullvad:
|
||||
return errors.New("owned only filter is not supported")
|
||||
return fmt.Errorf("%w", ErrOwnedOnlyNotSupported)
|
||||
case vpnServiceProvider == providers.Protonvpn && *settings.FreeOnly && *settings.PortForwardOnly:
|
||||
return errors.New("port forwarding only filter is not supported: together with free only filter")
|
||||
return fmt.Errorf("%w: together with free only filter", ErrPortForwardOnlyNotSupported)
|
||||
case *settings.StreamOnly &&
|
||||
!helpers.IsOneOf(vpnServiceProvider, providers.Protonvpn, providers.VPNUnlimited):
|
||||
return errors.New("stream only filter is not supported")
|
||||
return fmt.Errorf("%w", ErrStreamOnlyNotSupported)
|
||||
case *settings.MultiHopOnly && vpnServiceProvider != providers.Surfshark:
|
||||
return errors.New("multi hop only filter is not supported")
|
||||
return fmt.Errorf("%w", ErrMultiHopOnlyNotSupported)
|
||||
case *settings.PortForwardOnly &&
|
||||
!helpers.IsOneOf(vpnServiceProvider, providers.PrivateInternetAccess, providers.Protonvpn):
|
||||
return errors.New("port forwarding only filter is not supported")
|
||||
return fmt.Errorf("%w", ErrPortForwardOnlyNotSupported)
|
||||
case *settings.SecureCoreOnly && vpnServiceProvider != providers.Protonvpn:
|
||||
return errors.New("secure core only filter is not supported")
|
||||
return fmt.Errorf("%w", ErrSecureCoreOnlyNotSupported)
|
||||
case *settings.TorOnly && vpnServiceProvider != providers.Protonvpn:
|
||||
return errors.New("tor only filter is not supported")
|
||||
return fmt.Errorf("%w", ErrTorOnlyNotSupported)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -132,7 +132,7 @@ func (s *Settings) SetDefaults() {
|
||||
s.IPv6.setDefaults()
|
||||
s.PublicIP.setDefaults()
|
||||
s.Shadowsocks.setDefaults()
|
||||
s.Storage.SetDefaults()
|
||||
s.Storage.setDefaults()
|
||||
s.System.setDefaults()
|
||||
s.Version.setDefaults()
|
||||
s.VPN.setDefaults()
|
||||
@@ -213,7 +213,7 @@ func (s *Settings) Read(r *reader.Reader, warner Warner) (err error) {
|
||||
return s.PublicIP.read(r, warner)
|
||||
},
|
||||
"shadowsocks": s.Shadowsocks.read,
|
||||
"storage": s.Storage.Read,
|
||||
"storage": s.Storage.read,
|
||||
"system": s.System.read,
|
||||
"updater": s.Updater.read,
|
||||
"version": s.Version.read,
|
||||
|
||||
@@ -90,7 +90,7 @@ func Test_Settings_String(t *testing.T) {
|
||||
| ├── Logging: yes
|
||||
| └── Authentication file path: /gluetun/auth/config.toml
|
||||
├── Storage settings:
|
||||
| └── Servers directory path: /gluetun/servers/
|
||||
| └── Filepath: /gluetun/servers.json
|
||||
├── OS Alpine settings:
|
||||
| ├── Process UID: 1000
|
||||
| └── Process GID: 1000
|
||||
|
||||
@@ -11,26 +11,15 @@ import (
|
||||
|
||||
// Storage contains settings to configure the storage.
|
||||
type Storage struct {
|
||||
// ServersEnabled is whether to enable storage of servers on disk.
|
||||
// It defaults to true.
|
||||
ServersEnabled *bool
|
||||
// ServersPath is the path to the servers files directory, and cannot be
|
||||
// the empty string.
|
||||
ServersPath string
|
||||
// LegacyServersFilepath is the legacy "fat" JSON filepath to migrate from.
|
||||
// TODO v4: remove
|
||||
LegacyServersFilepath string
|
||||
// Filepath is the path to the servers.json file. An empty string disables on-disk storage.
|
||||
Filepath *string
|
||||
}
|
||||
|
||||
func (s Storage) validate() (err error) {
|
||||
if *s.ServersEnabled {
|
||||
_, err := filepath.Abs(s.ServersPath)
|
||||
if *s.Filepath != "" { // optional
|
||||
_, err := filepath.Abs(*s.Filepath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("servers path is not valid: %w", err)
|
||||
}
|
||||
_, err = filepath.Abs(s.LegacyServersFilepath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("legacy servers filepath is not valid: %w", err)
|
||||
return fmt.Errorf("filepath is not valid: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -38,25 +27,17 @@ func (s Storage) validate() (err error) {
|
||||
|
||||
func (s *Storage) copy() (copied Storage) {
|
||||
return Storage{
|
||||
ServersEnabled: gosettings.CopyPointer(s.ServersEnabled),
|
||||
ServersPath: s.ServersPath,
|
||||
LegacyServersFilepath: s.LegacyServersFilepath,
|
||||
Filepath: gosettings.CopyPointer(s.Filepath),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Storage) overrideWith(other Storage) {
|
||||
s.ServersEnabled = gosettings.OverrideWithPointer(s.ServersEnabled, other.ServersEnabled)
|
||||
s.ServersPath = gosettings.OverrideWithComparable(s.ServersPath, other.ServersPath)
|
||||
s.LegacyServersFilepath = gosettings.OverrideWithComparable(s.LegacyServersFilepath, other.LegacyServersFilepath)
|
||||
s.Filepath = gosettings.OverrideWithPointer(s.Filepath, other.Filepath)
|
||||
}
|
||||
|
||||
const defaultLegacyServersFilepath = "/gluetun/servers.json"
|
||||
|
||||
func (s *Storage) SetDefaults() {
|
||||
s.ServersEnabled = gosettings.DefaultPointer(s.ServersEnabled, true)
|
||||
const defaultServersPath = "/gluetun/servers/"
|
||||
s.ServersPath = gosettings.DefaultComparable(s.ServersPath, defaultServersPath)
|
||||
s.LegacyServersFilepath = gosettings.DefaultComparable(s.LegacyServersFilepath, defaultLegacyServersFilepath)
|
||||
func (s *Storage) setDefaults() {
|
||||
const defaultFilepath = "/gluetun/servers.json"
|
||||
s.Filepath = gosettings.DefaultPointer(s.Filepath, defaultFilepath)
|
||||
}
|
||||
|
||||
func (s Storage) String() string {
|
||||
@@ -64,33 +45,15 @@ func (s Storage) String() string {
|
||||
}
|
||||
|
||||
func (s Storage) toLinesNode() (node *gotree.Node) {
|
||||
if !*s.ServersEnabled {
|
||||
if *s.Filepath == "" {
|
||||
return gotree.New("Storage settings: disabled")
|
||||
}
|
||||
node = gotree.New("Storage settings:")
|
||||
node.Appendf("Servers directory path: %s", s.ServersPath)
|
||||
if s.LegacyServersFilepath != defaultLegacyServersFilepath {
|
||||
node.Appendf("Legacy servers filepath: %s", s.LegacyServersFilepath)
|
||||
}
|
||||
node.Appendf("Filepath: %s", *s.Filepath)
|
||||
return node
|
||||
}
|
||||
|
||||
func (s *Storage) Read(r *reader.Reader) (err error) {
|
||||
// Retro-compatibility:
|
||||
// TODO v4: remove support for STORAGE_FILEPATH
|
||||
filePath := r.Get("STORAGE_FILEPATH", reader.AcceptEmpty(true), reader.IsRetro("STORAGE_SERVERS_DIRECTORY_PATH"))
|
||||
if filePath != nil {
|
||||
if *filePath == "" {
|
||||
s.ServersEnabled = ptrTo(false)
|
||||
} else {
|
||||
s.LegacyServersFilepath = *filePath
|
||||
}
|
||||
} else {
|
||||
s.ServersEnabled, err = r.BoolPtr("STORAGE_SERVERS_ENABLED")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.ServersPath = r.String("STORAGE_SERVERS_DIRECTORY_PATH")
|
||||
}
|
||||
func (s *Storage) read(r *reader.Reader) (err error) {
|
||||
s.Filepath = r.Get("STORAGE_FILEPATH", reader.AcceptEmpty(true))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package settings
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
@@ -29,9 +28,6 @@ type Updater struct {
|
||||
// Providers is the list of VPN service providers
|
||||
// to update server information for.
|
||||
Providers []string
|
||||
// PreferDirectDownload is whether to prefer direct download of
|
||||
// server data from Github (recommended).
|
||||
PreferDirectDownload *bool
|
||||
// ProtonEmail is the email to authenticate with the Proton API.
|
||||
ProtonEmail *string
|
||||
// ProtonPassword is the password to authenticate with the Proton API.
|
||||
@@ -41,20 +37,20 @@ type Updater struct {
|
||||
func (u Updater) Validate() (err error) {
|
||||
const minPeriod = time.Minute
|
||||
if *u.Period > 0 && *u.Period < minPeriod {
|
||||
return fmt.Errorf("VPN server data updater period is too small: "+
|
||||
"%d must be larger than %s", *u.Period, minPeriod)
|
||||
return fmt.Errorf("%w: %d must be larger than %s",
|
||||
ErrUpdaterPeriodTooSmall, *u.Period, minPeriod)
|
||||
}
|
||||
|
||||
if u.MinRatio <= 0 || u.MinRatio > 1 {
|
||||
return fmt.Errorf("minimum ratio is not valid: "+
|
||||
"%.2f must be between 0+ and 1", u.MinRatio)
|
||||
return fmt.Errorf("%w: %.2f must be between 0+ and 1",
|
||||
ErrMinRatioNotValid, u.MinRatio)
|
||||
}
|
||||
|
||||
validProviders := providers.All()
|
||||
for _, provider := range u.Providers {
|
||||
err = validate.IsOneOf(provider, validProviders...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("VPN provider name is not valid: %w", err)
|
||||
return fmt.Errorf("%w: %w", ErrVPNProviderNameNotValid, err)
|
||||
}
|
||||
|
||||
if provider == providers.Protonvpn {
|
||||
@@ -62,9 +58,9 @@ func (u Updater) Validate() (err error) {
|
||||
if authenticatedAPI {
|
||||
switch {
|
||||
case *u.ProtonEmail == "":
|
||||
return errors.New("proton email is missing")
|
||||
return fmt.Errorf("%w", ErrUpdaterProtonEmailMissing)
|
||||
case *u.ProtonPassword == "":
|
||||
return errors.New("proton password is missing")
|
||||
return fmt.Errorf("%w", ErrUpdaterProtonPasswordMissing)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -75,12 +71,11 @@ func (u Updater) Validate() (err error) {
|
||||
|
||||
func (u *Updater) copy() (copied Updater) {
|
||||
return Updater{
|
||||
Period: gosettings.CopyPointer(u.Period),
|
||||
MinRatio: u.MinRatio,
|
||||
Providers: gosettings.CopySlice(u.Providers),
|
||||
PreferDirectDownload: gosettings.CopyPointer(u.PreferDirectDownload),
|
||||
ProtonEmail: gosettings.CopyPointer(u.ProtonEmail),
|
||||
ProtonPassword: gosettings.CopyPointer(u.ProtonPassword),
|
||||
Period: gosettings.CopyPointer(u.Period),
|
||||
MinRatio: u.MinRatio,
|
||||
Providers: gosettings.CopySlice(u.Providers),
|
||||
ProtonEmail: gosettings.CopyPointer(u.ProtonEmail),
|
||||
ProtonPassword: gosettings.CopyPointer(u.ProtonPassword),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -91,7 +86,6 @@ func (u *Updater) overrideWith(other Updater) {
|
||||
u.Period = gosettings.OverrideWithPointer(u.Period, other.Period)
|
||||
u.MinRatio = gosettings.OverrideWithComparable(u.MinRatio, other.MinRatio)
|
||||
u.Providers = gosettings.OverrideWithSlice(u.Providers, other.Providers)
|
||||
u.PreferDirectDownload = gosettings.OverrideWithPointer(u.PreferDirectDownload, other.PreferDirectDownload)
|
||||
u.ProtonEmail = gosettings.OverrideWithPointer(u.ProtonEmail, other.ProtonEmail)
|
||||
u.ProtonPassword = gosettings.OverrideWithPointer(u.ProtonPassword, other.ProtonPassword)
|
||||
}
|
||||
@@ -109,7 +103,6 @@ func (u *Updater) SetDefaults(vpnProvider string) {
|
||||
}
|
||||
|
||||
// Set these to empty strings to avoid nil pointer panics
|
||||
u.PreferDirectDownload = gosettings.DefaultPointer(u.PreferDirectDownload, false)
|
||||
u.ProtonEmail = gosettings.DefaultPointer(u.ProtonEmail, "")
|
||||
u.ProtonPassword = gosettings.DefaultPointer(u.ProtonPassword, "")
|
||||
}
|
||||
@@ -127,7 +120,6 @@ func (u Updater) toLinesNode() (node *gotree.Node) {
|
||||
node.Appendf("Update period: %s", *u.Period)
|
||||
node.Appendf("Minimum ratio: %.1f", u.MinRatio)
|
||||
node.Appendf("Providers to update: %s", strings.Join(u.Providers, ", "))
|
||||
node.Appendf("Prefer direct download: %s", gosettings.BoolToYesNo(u.PreferDirectDownload))
|
||||
if slices.Contains(u.Providers, providers.Protonvpn) {
|
||||
node.Appendf("Proton API email: %s", *u.ProtonEmail)
|
||||
node.Appendf("Proton API password: %s", gosettings.ObfuscateKey(*u.ProtonPassword))
|
||||
@@ -149,11 +141,6 @@ func (u *Updater) read(r *reader.Reader) (err error) {
|
||||
|
||||
u.Providers = r.CSV("UPDATER_VPN_SERVICE_PROVIDERS")
|
||||
|
||||
u.PreferDirectDownload, err = r.BoolPtr("UPDATER_PREFER_DIRECT_DOWNLOAD")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
u.ProtonEmail = r.Get("UPDATER_PROTONVPN_EMAIL")
|
||||
if u.ProtonEmail == nil {
|
||||
protonUsername := r.String("UPDATER_PROTONVPN_USERNAME", reader.IsRetro("UPDATER_PROTONVPN_EMAIL"))
|
||||
|
||||
@@ -37,7 +37,7 @@ func (v *VPN) Validate(filterChoicesGetter FilterChoicesGetter, ipv6Supported bo
|
||||
// Validate Type
|
||||
validVPNTypes := []string{vpn.AmneziaWg, vpn.OpenVPN, vpn.Wireguard}
|
||||
if err = validate.IsOneOf(v.Type, validVPNTypes...); err != nil {
|
||||
return fmt.Errorf("VPN type is not valid: %w", err)
|
||||
return fmt.Errorf("%w: %w", ErrVPNTypeNotValid, err)
|
||||
}
|
||||
|
||||
err = v.Provider.validate(v.Type, filterChoicesGetter, warner)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package settings
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"regexp"
|
||||
@@ -55,7 +54,7 @@ var regexpInterfaceName = regexp.MustCompile(`^[a-zA-Z0-9_]+$`)
|
||||
func (w Wireguard) validate(vpnProvider string, ipv6Supported, amneziawg bool) (err error) {
|
||||
// Validate PrivateKey
|
||||
if *w.PrivateKey == "" {
|
||||
return errors.New("private key is not set")
|
||||
return fmt.Errorf("%w", ErrWireguardPrivateKeyNotSet)
|
||||
}
|
||||
_, err = wgtypes.ParseKey(*w.PrivateKey)
|
||||
if err != nil {
|
||||
@@ -69,7 +68,7 @@ func (w Wireguard) validate(vpnProvider string, ipv6Supported, amneziawg bool) (
|
||||
|
||||
if vpnProvider == providers.Airvpn {
|
||||
if *w.PreSharedKey == "" {
|
||||
return errors.New("pre-shared key is not set")
|
||||
return fmt.Errorf("%w", ErrWireguardPreSharedKeyNotSet)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -83,15 +82,17 @@ func (w Wireguard) validate(vpnProvider string, ipv6Supported, amneziawg bool) (
|
||||
|
||||
// Validate Addresses
|
||||
if len(w.Addresses) == 0 {
|
||||
return errors.New("interface address is not set")
|
||||
return fmt.Errorf("%w", ErrWireguardInterfaceAddressNotSet)
|
||||
}
|
||||
for i, ipNet := range w.Addresses {
|
||||
if !ipNet.IsValid() {
|
||||
return fmt.Errorf("interface address is not set: for address at index %d", i)
|
||||
return fmt.Errorf("%w: for address at index %d",
|
||||
ErrWireguardInterfaceAddressNotSet, i)
|
||||
}
|
||||
|
||||
if !ipv6Supported && ipNet.Addr().Is6() {
|
||||
return fmt.Errorf("interface address is IPv6 but IPv6 is not supported: address %s", ipNet.String())
|
||||
return fmt.Errorf("%w: address %s",
|
||||
ErrWireguardInterfaceAddressIPv6, ipNet.String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -99,27 +100,30 @@ func (w Wireguard) validate(vpnProvider string, ipv6Supported, amneziawg bool) (
|
||||
// WARNING: do not check for IPv6 networks in the allowed IPs,
|
||||
// the wireguard code will take care to ignore it.
|
||||
if len(w.AllowedIPs) == 0 {
|
||||
return errors.New("allowed IPs is not set")
|
||||
return fmt.Errorf("%w", ErrWireguardAllowedIPsNotSet)
|
||||
}
|
||||
for i, allowedIP := range w.AllowedIPs {
|
||||
if !allowedIP.IsValid() {
|
||||
return fmt.Errorf("allowed IP is not set: for allowed ip %d of %d", i+1, len(w.AllowedIPs))
|
||||
return fmt.Errorf("%w: for allowed ip %d of %d",
|
||||
ErrWireguardAllowedIPNotSet, i+1, len(w.AllowedIPs))
|
||||
}
|
||||
}
|
||||
|
||||
if *w.PersistentKeepaliveInterval < 0 {
|
||||
return fmt.Errorf("persistent keep alive interval is negative: %s", *w.PersistentKeepaliveInterval)
|
||||
return fmt.Errorf("%w: %s", ErrWireguardKeepAliveNegative,
|
||||
*w.PersistentKeepaliveInterval)
|
||||
}
|
||||
|
||||
// Validate interface
|
||||
if !regexpInterfaceName.MatchString(w.Interface) {
|
||||
return fmt.Errorf("interface name is not valid: '%s' does not match regex '%s'", w.Interface, regexpInterfaceName)
|
||||
return fmt.Errorf("%w: '%s' does not match regex '%s'",
|
||||
ErrWireguardInterfaceNotValid, w.Interface, regexpInterfaceName)
|
||||
}
|
||||
|
||||
if !amneziawg { // amneziawg should have its own Implementation field and ignore this one
|
||||
validImplementations := []string{"auto", "userspace", "kernelspace"}
|
||||
if err := validate.IsOneOf(w.Implementation, validImplementations...); err != nil {
|
||||
return fmt.Errorf("implementation is not valid: %w", err)
|
||||
return fmt.Errorf("%w: %w", ErrWireguardImplementationNotValid, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -238,12 +242,10 @@ func (w *Wireguard) read(r *reader.Reader, amneziaWG bool) (err error) {
|
||||
// WARNING: do not initialize w.Addresses to an empty slice
|
||||
// or the defaults for nordvpn will not work.
|
||||
for _, addressString := range addressStrings {
|
||||
addressString = strings.TrimSpace(addressString)
|
||||
if addressString == "" {
|
||||
continue
|
||||
} else if !strings.ContainsRune(addressString, '/') {
|
||||
if !strings.ContainsRune(addressString, '/') {
|
||||
addressString += "/32"
|
||||
}
|
||||
addressString = strings.TrimSpace(addressString)
|
||||
address, err := netip.ParsePrefix(addressString)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing address: %w", err)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package settings
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
@@ -45,7 +44,7 @@ func (w WireguardSelection) validate(vpnProvider string) (err error) {
|
||||
// endpoint IP addresses are baked in
|
||||
case providers.Custom:
|
||||
if !w.EndpointIP.IsValid() || w.EndpointIP.IsUnspecified() {
|
||||
return errors.New("endpoint IP is not set")
|
||||
return fmt.Errorf("%w", ErrWireguardEndpointIPNotSet)
|
||||
}
|
||||
default: // Providers not supporting Wireguard
|
||||
}
|
||||
@@ -55,13 +54,13 @@ func (w WireguardSelection) validate(vpnProvider string) (err error) {
|
||||
// EndpointPort is required
|
||||
case providers.Custom:
|
||||
if *w.EndpointPort == 0 {
|
||||
return errors.New("endpoint port is not set")
|
||||
return fmt.Errorf("%w", ErrWireguardEndpointPortNotSet)
|
||||
}
|
||||
// EndpointPort cannot be set
|
||||
case providers.Fastestvpn, providers.Nordvpn,
|
||||
providers.Protonvpn, providers.Surfshark:
|
||||
if *w.EndpointPort != 0 {
|
||||
return errors.New("endpoint port is set")
|
||||
return fmt.Errorf("%w", ErrWireguardEndpointPortSet)
|
||||
}
|
||||
case providers.Airvpn, providers.Ivpn, providers.Mullvad, providers.Windscribe:
|
||||
// EndpointPort is optional and can be 0
|
||||
@@ -85,7 +84,8 @@ func (w WireguardSelection) validate(vpnProvider string) (err error) {
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
return fmt.Errorf("endpoint port is not allowed: for VPN service provider %s: %w", vpnProvider, err)
|
||||
return fmt.Errorf("%w: for VPN service provider %s: %w",
|
||||
ErrWireguardEndpointPortNotAllowed, vpnProvider, err)
|
||||
default: // Providers not supporting Wireguard
|
||||
}
|
||||
|
||||
@@ -96,14 +96,15 @@ func (w WireguardSelection) validate(vpnProvider string) (err error) {
|
||||
// public keys are baked in
|
||||
case providers.Custom:
|
||||
if w.PublicKey == "" {
|
||||
return errors.New("public key is not set")
|
||||
return fmt.Errorf("%w", ErrWireguardPublicKeyNotSet)
|
||||
}
|
||||
default: // Providers not supporting Wireguard
|
||||
}
|
||||
if w.PublicKey != "" {
|
||||
_, err := wgtypes.ParseKey(w.PublicKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("public key is not valid: %s: %s", w.PublicKey, err)
|
||||
return fmt.Errorf("%w: %s: %s",
|
||||
ErrWireguardPublicKeyNotValid, w.PublicKey, err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"gopkg.in/ini.v1"
|
||||
)
|
||||
@@ -73,6 +74,8 @@ func parseWireguardInterfaceSection(interfaceSection *ini.Section) (
|
||||
return privateKey, addresses
|
||||
}
|
||||
|
||||
var ErrEndpointHostNotIP = errors.New("endpoint host is not an IP")
|
||||
|
||||
func parseWireguardPeerSection(peerSection *ini.Section) (
|
||||
preSharedKey, publicKey, endpointIP, endpointPort *string,
|
||||
) {
|
||||
@@ -83,7 +86,10 @@ func parseWireguardPeerSection(peerSection *ini.Section) (
|
||||
host, port, err := net.SplitHostPort(*endpoint)
|
||||
if err == nil {
|
||||
endpointIP = &host
|
||||
endpointPort = &port
|
||||
// IPv6 hosts contain colons; port is managed by the provider for those
|
||||
if !strings.Contains(host, ":") {
|
||||
endpointPort = &port
|
||||
}
|
||||
} else {
|
||||
endpointIP = endpoint
|
||||
}
|
||||
|
||||
@@ -182,8 +182,7 @@ Endpoint = 1.2.3.4:51820`,
|
||||
"ipv6_endpoint": {
|
||||
iniData: `[Peer]
|
||||
Endpoint = [2a02:bbbb:aaaa:8075::10]:51820`,
|
||||
endpointIP: ptrTo("2a02:bbbb:aaaa:8075::10"),
|
||||
endpointPort: ptrTo("51820"),
|
||||
endpointIP: ptrTo("2a02:bbbb:aaaa:8075::10"),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,6 @@
|
||||
package constants
|
||||
|
||||
const (
|
||||
// ServersData is the server information filepath.
|
||||
ServersData = "/gluetun/servers.json"
|
||||
)
|
||||
@@ -3,6 +3,7 @@ package dns
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"math/rand/v2"
|
||||
@@ -62,6 +63,8 @@ func generateRandomString(length uint) string {
|
||||
return string(b)
|
||||
}
|
||||
|
||||
var errIPLeakSessionMismatch = errors.New("ipleak.net session mismatch")
|
||||
|
||||
func triggerDNSQuery(ctx context.Context, client *http.Client, session string) (
|
||||
dnsToCount map[string]uint, err error,
|
||||
) {
|
||||
@@ -90,7 +93,7 @@ func triggerDNSQuery(ctx context.Context, client *http.Client, session string) (
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decoding response: %w", err)
|
||||
} else if data.Session != session {
|
||||
return nil, fmt.Errorf("ipleak.net session mismatch: expected %s, got %s", session, data.Session)
|
||||
return nil, fmt.Errorf("%w: expected %s, got %s", errIPLeakSessionMismatch, session, data.Session)
|
||||
}
|
||||
|
||||
return data.IP, nil
|
||||
|
||||
+12
-22
@@ -33,22 +33,9 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
||||
for {
|
||||
settings = l.GetSettings()
|
||||
var err error
|
||||
if *settings.ServerEnabled { //nolint:nestif
|
||||
runError, err = l.setupServer(ctx, settings)
|
||||
if err == nil {
|
||||
l.logger.Infof("ready and using DNS server with %s upstream resolvers", settings.UpstreamType)
|
||||
err = l.updateFiles(ctx, settings)
|
||||
if err != nil {
|
||||
l.logger.Warn("downloading block lists failed, skipping: " + err.Error())
|
||||
}
|
||||
break
|
||||
}
|
||||
} else {
|
||||
err = l.usePlainServers(settings.UpstreamPlainAddresses)
|
||||
if err == nil {
|
||||
l.logger.Infof("ready and using plain DNS resolvers: %v", settings.UpstreamPlainAddresses)
|
||||
break
|
||||
}
|
||||
runError, err = l.setupServer(ctx, settings)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
|
||||
l.signalOrSetStatus(constants.Crashed)
|
||||
@@ -59,6 +46,12 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
||||
}
|
||||
|
||||
l.backoffTime = defaultBackoffTime
|
||||
l.logger.Infof("ready and using DNS server with %s upstream resolvers", settings.UpstreamType)
|
||||
|
||||
err = l.updateFiles(ctx, settings)
|
||||
if err != nil {
|
||||
l.logger.Warn("downloading block lists failed, skipping: " + err.Error())
|
||||
}
|
||||
l.signalOrSetStatus(constants.Running)
|
||||
|
||||
l.userTrigger = false
|
||||
@@ -81,13 +74,13 @@ func (l *Loop) runWait(ctx context.Context, runError <-chan error) (exitLoop boo
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
l.stopServerIfAny()
|
||||
l.stopServer()
|
||||
// TODO revert OS and Go nameserver when exiting
|
||||
return true
|
||||
case <-l.stop:
|
||||
l.userTrigger = true
|
||||
l.logger.Info("stopping")
|
||||
l.stopServerIfAny()
|
||||
l.stopServer()
|
||||
l.stopped <- struct{}{}
|
||||
case <-l.start:
|
||||
l.userTrigger = true
|
||||
@@ -101,10 +94,7 @@ func (l *Loop) runWait(ctx context.Context, runError <-chan error) (exitLoop boo
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Loop) stopServerIfAny() {
|
||||
if l.server == nil {
|
||||
return
|
||||
}
|
||||
func (l *Loop) stopServer() {
|
||||
stopErr := l.server.Stop()
|
||||
if stopErr != nil {
|
||||
l.logger.Error("stopping server: " + stopErr.Error())
|
||||
|
||||
@@ -3,7 +3,6 @@ package dns
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/qdm12/dns/v2/pkg/middlewares/filter/update"
|
||||
"github.com/qdm12/dns/v2/pkg/nameserver"
|
||||
@@ -46,25 +45,3 @@ func (l *Loop) setupServer(ctx context.Context, settings settings.DNS) (runError
|
||||
|
||||
return runError, nil
|
||||
}
|
||||
|
||||
func (l *Loop) usePlainServers(addrPorts []netip.AddrPort) (err error) {
|
||||
nameserver.UseDNSInternally(nameserver.SettingsInternalDNS{
|
||||
AddrPort: addrPorts[0],
|
||||
})
|
||||
addresses := make([]netip.Addr, len(addrPorts))
|
||||
for i, addrPort := range addrPorts {
|
||||
const defaultDNSPort = 53
|
||||
if addrPort.Port() != defaultDNSPort {
|
||||
return fmt.Errorf("invalid DNS port: %d, must be %d", addrPort.Port(), defaultDNSPort)
|
||||
}
|
||||
addresses[i] = addrPort.Addr()
|
||||
}
|
||||
err = nameserver.UseDNSSystemWide(nameserver.SettingsSystemDNS{
|
||||
IPs: addresses,
|
||||
ResolvPath: l.resolvConf,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("using DNS system wide: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
@@ -42,7 +41,8 @@ func (c *Config) saveAndRestore(ctx context.Context) (restore func(context.Conte
|
||||
// Callers of saveAndRestoreIPv4 MUST always lock the [Config] iptablesMutex
|
||||
// before calling this function.
|
||||
func (c *Config) saveAndRestoreIPv4(ctx context.Context) (restore func(context.Context), err error) {
|
||||
data, err := saveData(ctx, c.ipTables)
|
||||
cmd := exec.CommandContext(ctx, c.ipTables+"-save") //nolint:gosec
|
||||
data, err := c.runner.Run(cmd)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("saving IPv4 iptables: %w", err)
|
||||
}
|
||||
@@ -65,13 +65,14 @@ func (c *Config) saveAndRestoreIPv6(ctx context.Context) (restore func(context.C
|
||||
return nil, nil //nolint:nilnil
|
||||
}
|
||||
|
||||
data, err := saveData(ctx, c.ip6Tables)
|
||||
cmd := exec.CommandContext(ctx, c.ip6Tables+"-save") //nolint:gosec
|
||||
data, err := c.runner.Run(cmd)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("saving IPv6 iptables: %w", err)
|
||||
}
|
||||
|
||||
restore = func(ctx context.Context) {
|
||||
cmd := exec.CommandContext(ctx, c.ip6Tables+"-restore") //nolint:gosec
|
||||
cmd = exec.CommandContext(ctx, c.ip6Tables+"-restore") //nolint:gosec
|
||||
cmd.Stdin = strings.NewReader(data)
|
||||
output, err := c.runner.Run(cmd)
|
||||
if err != nil {
|
||||
@@ -84,38 +85,3 @@ func (c *Config) saveAndRestoreIPv6(ctx context.Context) (restore func(context.C
|
||||
func makeRestoreErrorMessage(err error, output, data string) string {
|
||||
return fmt.Sprintf("%s: %s: restoring from data:\n%s", err, output, data)
|
||||
}
|
||||
|
||||
func saveData(ctx context.Context, binary string) (data string, err error) {
|
||||
cmd := exec.CommandContext(ctx, binary+"-save") //nolint:gosec
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
stderr := strings.TrimSuffix(string(exitErr.Stderr), "\n")
|
||||
if stderr != "" {
|
||||
return "", fmt.Errorf("running %s-save: %w: %s", binary, err, stderr)
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("running %s-save: %w", binary, err)
|
||||
}
|
||||
err = checkData(string(output))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("checking saved data: %w", err)
|
||||
}
|
||||
return string(output), nil
|
||||
}
|
||||
|
||||
func checkData(data string) error {
|
||||
scanner := bufio.NewScanner(strings.NewReader(data))
|
||||
i := 0
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if strings.HasPrefix(line, "[unsupported") {
|
||||
return fmt.Errorf("unsupported revision marker found in line %d: %s", i+1, line)
|
||||
}
|
||||
i++
|
||||
}
|
||||
if scanner.Err() != nil {
|
||||
return fmt.Errorf("scanning data: %w", scanner.Err())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -57,15 +57,18 @@ func Test_deleteIPTablesRule(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const iptablesBinary = "/sbin/iptables"
|
||||
errTest := errors.New("test error")
|
||||
|
||||
testCases := map[string]struct {
|
||||
instruction string
|
||||
makeRunner func(ctrl *gomock.Controller) *MockCmdRunner
|
||||
makeLogger func(ctrl *gomock.Controller) *MockLogger
|
||||
errWrapped error
|
||||
errMessage string
|
||||
}{
|
||||
"invalid_instruction": {
|
||||
instruction: "invalid",
|
||||
errWrapped: ErrIptablesCommandMalformed,
|
||||
errMessage: "parsing iptables command: parsing \"invalid\": " +
|
||||
"iptables command is malformed: flag \"invalid\" requires a value, but got none",
|
||||
},
|
||||
@@ -75,7 +78,7 @@ func Test_deleteIPTablesRule(t *testing.T) {
|
||||
runner := NewMockCmdRunner(ctrl)
|
||||
runner.EXPECT().
|
||||
Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")).
|
||||
Return("", errors.New("test error"))
|
||||
Return("", errTest)
|
||||
return runner
|
||||
},
|
||||
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
|
||||
@@ -83,6 +86,7 @@ func Test_deleteIPTablesRule(t *testing.T) {
|
||||
logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v")
|
||||
return logger
|
||||
},
|
||||
errWrapped: errTest,
|
||||
errMessage: `finding iptables chain rule line number: command failed: ` +
|
||||
`"/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v": test error`,
|
||||
},
|
||||
@@ -116,7 +120,7 @@ func Test_deleteIPTablesRule(t *testing.T) {
|
||||
"2 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:43716 redir ports 5678\n", //nolint:lll
|
||||
nil)
|
||||
runner.EXPECT().Run(newCmdMatcher(iptablesBinary, "^-t$", "^nat$",
|
||||
"^-D$", "^PREROUTING$", "^2$")).Return("details", errors.New("test error"))
|
||||
"^-D$", "^PREROUTING$", "^2$")).Return("details", errTest)
|
||||
return runner
|
||||
},
|
||||
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
|
||||
@@ -127,6 +131,7 @@ func Test_deleteIPTablesRule(t *testing.T) {
|
||||
logger.EXPECT().Debug("/sbin/iptables -t nat -D PREROUTING 2")
|
||||
return logger
|
||||
},
|
||||
errWrapped: errTest,
|
||||
errMessage: "command failed: \"/sbin/iptables -t nat -D PREROUTING 2\": test error: details",
|
||||
},
|
||||
"rule_found_delete_success": {
|
||||
@@ -172,10 +177,9 @@ func Test_deleteIPTablesRule(t *testing.T) {
|
||||
|
||||
err := deleteIPTablesRule(ctx, iptablesBinary, instruction, runner, logger)
|
||||
|
||||
if testCase.errMessage != "" {
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if testCase.errWrapped != nil {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -82,11 +82,13 @@ func (c *Config) runIP6tablesInstructionNoSave(ctx context.Context, instruction
|
||||
return nil
|
||||
}
|
||||
|
||||
var ErrPolicyNotValid = errors.New("policy is not valid")
|
||||
|
||||
func (c *Config) SetIPv6AllPolicies(ctx context.Context, policy string) error {
|
||||
switch policy {
|
||||
case "ACCEPT", "DROP":
|
||||
default:
|
||||
return fmt.Errorf("policy is not valid: %s", policy)
|
||||
return fmt.Errorf("%w: %s", ErrPolicyNotValid, policy)
|
||||
}
|
||||
return c.runIP6tablesInstructions(ctx, []string{
|
||||
"--policy INPUT " + policy,
|
||||
|
||||
@@ -2,6 +2,7 @@ package iptables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
@@ -12,8 +13,10 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
)
|
||||
|
||||
const (
|
||||
needIP6Tables = "ip6tables is required, please upgrade your kernel"
|
||||
var (
|
||||
ErrIPTablesVersionTooShort = errors.New("iptables version string is too short")
|
||||
ErrPolicyUnknown = errors.New("unknown policy")
|
||||
ErrNeedIP6Tables = errors.New("ip6tables is required, please upgrade your kernel to support it")
|
||||
)
|
||||
|
||||
func appendOrDelete(remove bool) string {
|
||||
@@ -33,7 +36,7 @@ func (c *Config) Version(ctx context.Context) (string, error) {
|
||||
words := strings.Fields(output)
|
||||
const minWords = 2
|
||||
if len(words) < minWords {
|
||||
return "", fmt.Errorf("iptables version string is too short: %s", output)
|
||||
return "", fmt.Errorf("%w: %s", ErrIPTablesVersionTooShort, output)
|
||||
}
|
||||
return "iptables " + words[1], nil
|
||||
}
|
||||
@@ -99,7 +102,7 @@ func (c *Config) SetIPv4AllPolicies(ctx context.Context, policy string) error {
|
||||
switch policy {
|
||||
case "ACCEPT", "DROP":
|
||||
default:
|
||||
return fmt.Errorf("unknown policy: %s", policy)
|
||||
return fmt.Errorf("%w: %s", ErrPolicyUnknown, policy)
|
||||
}
|
||||
return c.runIptablesInstructions(ctx, []string{
|
||||
"--policy INPUT " + policy,
|
||||
@@ -126,7 +129,7 @@ func (c *Config) AcceptInputToSubnet(ctx context.Context, intf string, destinati
|
||||
return c.runIptablesInstruction(ctx, instruction)
|
||||
}
|
||||
if c.ip6Tables == "" {
|
||||
return fmt.Errorf("accept input to subnet %s: %s", destination, needIP6Tables)
|
||||
return fmt.Errorf("accept input to subnet %s: %w", destination, ErrNeedIP6Tables)
|
||||
}
|
||||
return c.runIP6tablesInstruction(ctx, instruction)
|
||||
}
|
||||
@@ -154,7 +157,7 @@ func (c *Config) AcceptOutputTrafficToVPN(ctx context.Context,
|
||||
if connection.IP.Is4() {
|
||||
return c.runIptablesInstruction(ctx, instruction)
|
||||
} else if c.ip6Tables == "" {
|
||||
return fmt.Errorf("accept output to VPN server %s: %s", connection.IP, needIP6Tables)
|
||||
return fmt.Errorf("accept output to VPN server: %w", ErrNeedIP6Tables)
|
||||
}
|
||||
return c.runIP6tablesInstruction(ctx, instruction)
|
||||
}
|
||||
@@ -172,7 +175,7 @@ func (c *Config) AcceptOutput(ctx context.Context,
|
||||
if ip.Is4() {
|
||||
return c.runIptablesInstruction(ctx, instruction)
|
||||
} else if c.ip6Tables == "" {
|
||||
return fmt.Errorf("accept output to VPN server %s: %s", ip, needIP6Tables)
|
||||
return fmt.Errorf("accept output to VPN server: %w", ErrNeedIP6Tables)
|
||||
}
|
||||
return c.runIP6tablesInstruction(ctx, instruction)
|
||||
}
|
||||
@@ -197,7 +200,7 @@ func (c *Config) AcceptOutputFromIPToSubnet(ctx context.Context,
|
||||
if doIPv4 {
|
||||
return c.runIptablesInstruction(ctx, instruction)
|
||||
} else if c.ip6Tables == "" {
|
||||
return fmt.Errorf("accept output from %s to %s: %s", sourceIP, destinationSubnet, needIP6Tables)
|
||||
return fmt.Errorf("accept output from %s to %s: %w", sourceIP, destinationSubnet, ErrNeedIP6Tables)
|
||||
}
|
||||
return c.runIP6tablesInstruction(ctx, instruction)
|
||||
}
|
||||
@@ -347,7 +350,7 @@ func (c *Config) RunUserPostRules(ctx context.Context, filepath string) error {
|
||||
case ipv4:
|
||||
err = c.runIptablesInstructionNoSave(ctx, rule)
|
||||
case c.ip6Tables == "":
|
||||
err = fmt.Errorf("running user ip6tables rule: %s", needIP6Tables)
|
||||
err = fmt.Errorf("running user ip6tables rule: %w", ErrNeedIP6Tables)
|
||||
default: // ipv6
|
||||
err = c.runIP6tablesInstructionNoSave(ctx, rule)
|
||||
}
|
||||
|
||||
@@ -40,6 +40,8 @@ type mark struct {
|
||||
value uint
|
||||
}
|
||||
|
||||
var ErrChainListMalformed = errors.New("iptables chain list output is malformed")
|
||||
|
||||
func parseChain(iptablesOutput string) (c chain, err error) {
|
||||
// Text example:
|
||||
// Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
|
||||
@@ -61,8 +63,8 @@ func parseChain(iptablesOutput string) (c chain, err error) {
|
||||
|
||||
const minLines = 2 // chain general information line + legend line
|
||||
if len(lines) < minLines {
|
||||
return chain{}, fmt.Errorf("iptables chain list output is malformed: not enough lines to process in: %s",
|
||||
iptablesOutput)
|
||||
return chain{}, fmt.Errorf("%w: not enough lines to process in: %s",
|
||||
ErrChainListMalformed, iptablesOutput)
|
||||
}
|
||||
|
||||
c, err = parseChainGeneralDataLine(lines[0])
|
||||
@@ -75,8 +77,8 @@ func parseChain(iptablesOutput string) (c chain, err error) {
|
||||
legendLine := strings.TrimSpace(lines[1])
|
||||
legendFields := strings.Fields(legendLine)
|
||||
if !slices.Equal(expectedLegendFields, legendFields) {
|
||||
return chain{}, fmt.Errorf("iptables chain list output is malformed: legend %q is not the expected %q",
|
||||
legendLine, strings.Join(expectedLegendFields, " "))
|
||||
return chain{}, fmt.Errorf("%w: legend %q is not the expected %q",
|
||||
ErrChainListMalformed, legendLine, strings.Join(expectedLegendFields, " "))
|
||||
}
|
||||
|
||||
lines = lines[2:] // remove chain general information line and legend line
|
||||
@@ -109,8 +111,8 @@ func parseChainGeneralDataLine(line string) (base chain, err error) {
|
||||
fields := strings.Fields(line)
|
||||
const expectedNumberOfFields = 8
|
||||
if len(fields) != expectedNumberOfFields {
|
||||
return chain{}, fmt.Errorf("iptables chain list output is malformed: expected %d fields in %q",
|
||||
expectedNumberOfFields, line)
|
||||
return chain{}, fmt.Errorf("%w: expected %d fields in %q",
|
||||
ErrChainListMalformed, expectedNumberOfFields, line)
|
||||
}
|
||||
|
||||
// Sanity checks
|
||||
@@ -124,8 +126,8 @@ func parseChainGeneralDataLine(line string) (base chain, err error) {
|
||||
if fields[index] == expectedValue {
|
||||
continue
|
||||
}
|
||||
return chain{}, fmt.Errorf("iptables chain list output is malformed: expected %q for field %d in %q",
|
||||
expectedValue, index, line)
|
||||
return chain{}, fmt.Errorf("%w: expected %q for field %d in %q",
|
||||
ErrChainListMalformed, expectedValue, index, line)
|
||||
}
|
||||
|
||||
base.name = fields[1] // chain name could be custom
|
||||
@@ -150,17 +152,19 @@ func parseChainGeneralDataLine(line string) (base chain, err error) {
|
||||
return base, nil
|
||||
}
|
||||
|
||||
var ErrChainRuleMalformed = errors.New("chain rule is malformed")
|
||||
|
||||
func parseChainRuleLine(line string) (rule chainRule, err error) {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
return chainRule{}, errors.New("chain rule is malformed: empty line")
|
||||
return chainRule{}, fmt.Errorf("%w: empty line", ErrChainRuleMalformed)
|
||||
}
|
||||
|
||||
fields := strings.Fields(line)
|
||||
|
||||
const minFields = 10
|
||||
if len(fields) < minFields {
|
||||
return chainRule{}, errors.New("chain rule is malformed: not enough fields")
|
||||
return chainRule{}, fmt.Errorf("%w: not enough fields", ErrChainRuleMalformed)
|
||||
}
|
||||
|
||||
for fieldIndex, field := range fields[:minFields] {
|
||||
@@ -182,7 +186,7 @@ func parseChainRuleLine(line string) (rule chainRule, err error) {
|
||||
|
||||
func parseChainRuleField(fieldIndex int, field string, rule *chainRule) (err error) {
|
||||
if field == "" {
|
||||
return fmt.Errorf("chain rule is malformed: empty field at index %d", fieldIndex)
|
||||
return fmt.Errorf("%w: empty field at index %d", ErrChainRuleMalformed, fieldIndex)
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -274,8 +278,8 @@ func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err
|
||||
rule.redirPorts = ports
|
||||
i++
|
||||
default:
|
||||
return fmt.Errorf("chain rule is malformed: unexpected %q after redir",
|
||||
optionalFields[1])
|
||||
return fmt.Errorf("%w: unexpected %q after redir",
|
||||
ErrChainRuleMalformed, optionalFields[1])
|
||||
}
|
||||
case "ctstate":
|
||||
i++
|
||||
@@ -290,13 +294,15 @@ func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err
|
||||
rule.mark = mark
|
||||
i += consumed
|
||||
default:
|
||||
return fmt.Errorf("chain rule is malformed: unexpected optional field: %s",
|
||||
optionalFields[i])
|
||||
return fmt.Errorf("%w: unexpected optional field: %s",
|
||||
ErrChainRuleMalformed, optionalFields[i])
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var errUDPOptionalUnknown = errors.New("unknown UDP optional field")
|
||||
|
||||
func parseUDPOptional(optionalFields []string, rule *chainRule) (consumed int, err error) {
|
||||
for _, value := range optionalFields {
|
||||
if !strings.ContainsRune(value, ':') {
|
||||
@@ -317,12 +323,14 @@ func parseUDPOptional(optionalFields []string, rule *chainRule) (consumed int, e
|
||||
}
|
||||
consumed++
|
||||
default:
|
||||
return 0, fmt.Errorf("unknown UDP optional field: %s", value)
|
||||
return 0, fmt.Errorf("%w: %s", errUDPOptionalUnknown, value)
|
||||
}
|
||||
}
|
||||
return consumed, nil
|
||||
}
|
||||
|
||||
var errTCPOptionalUnknown = errors.New("unknown TCP optional field")
|
||||
|
||||
func parseTCPOptional(optionalFields []string, rule *chainRule) (consumed int, err error) {
|
||||
for _, value := range optionalFields {
|
||||
if !strings.ContainsRune(value, ':') {
|
||||
@@ -349,7 +357,7 @@ func parseTCPOptional(optionalFields []string, rule *chainRule) (consumed int, e
|
||||
}
|
||||
consumed++
|
||||
default:
|
||||
return 0, fmt.Errorf("unknown TCP optional field: %s", value)
|
||||
return 0, fmt.Errorf("%w: %s", errTCPOptionalUnknown, value)
|
||||
}
|
||||
}
|
||||
return consumed, nil
|
||||
@@ -365,13 +373,15 @@ func parseSourcePort(value string) (port uint16, err error) {
|
||||
return parsePort(value)
|
||||
}
|
||||
|
||||
var errTCPFlagsMalformed = errors.New("TCP flags are malformed")
|
||||
|
||||
func parseTCPFlags(value string) (tcpFlags, error) {
|
||||
value = strings.TrimPrefix(value, "flags:")
|
||||
fields := strings.Split(value, "/")
|
||||
const expectedFields = 2
|
||||
if len(fields) != expectedFields {
|
||||
return tcpFlags{}, fmt.Errorf("TCP flags are malformed: expected format 'flags:<mask>/<comparison>' in %q",
|
||||
value)
|
||||
return tcpFlags{}, fmt.Errorf("%w: expected format 'flags:<mask>/<comparison>' in %q",
|
||||
errTCPFlagsMalformed, value)
|
||||
}
|
||||
maskFlags := strings.Split(fields[0], ",")
|
||||
mask := make([]tcpFlag, len(maskFlags))
|
||||
@@ -412,6 +422,8 @@ func parsePortsCSV(s string) (ports []uint16, err error) {
|
||||
return ports, nil
|
||||
}
|
||||
|
||||
var errMarkValueMalformed = errors.New("mark value is malformed")
|
||||
|
||||
func parseMark(optionalFields []string) (m mark, consumed int, err error) {
|
||||
switch optionalFields[consumed] {
|
||||
case "match":
|
||||
@@ -425,36 +437,42 @@ func parseMark(optionalFields []string) (m mark, consumed int, err error) {
|
||||
const bits = 32
|
||||
value, err := strconv.ParseUint(optionalFields[consumed], base, bits)
|
||||
if err != nil {
|
||||
return mark{}, 0, fmt.Errorf("mark value is malformed: %s", optionalFields[consumed])
|
||||
return mark{}, 0, fmt.Errorf("%w: %s", errMarkValueMalformed, optionalFields[consumed])
|
||||
}
|
||||
m.value = uint(value)
|
||||
consumed++
|
||||
default:
|
||||
return mark{}, 0, fmt.Errorf("chain rule is malformed: unexpected mark mode field: %s",
|
||||
optionalFields[consumed])
|
||||
return mark{}, 0, fmt.Errorf("%w: unexpected mark mode field: %s",
|
||||
ErrChainRuleMalformed, optionalFields[consumed])
|
||||
}
|
||||
return m, consumed, nil
|
||||
}
|
||||
|
||||
var ErrLineNumberIsZero = errors.New("line number is zero")
|
||||
|
||||
func parseLineNumber(s string) (n uint16, err error) {
|
||||
const base, bitLength = 10, 16
|
||||
lineNumber, err := strconv.ParseUint(s, base, bitLength)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
} else if lineNumber == 0 {
|
||||
return 0, errors.New("line number is zero")
|
||||
return 0, fmt.Errorf("%w", ErrLineNumberIsZero)
|
||||
}
|
||||
return uint16(lineNumber), nil
|
||||
}
|
||||
|
||||
var ErrTargetUnknown = errors.New("unknown target")
|
||||
|
||||
func checkTarget(target string) (err error) {
|
||||
switch target {
|
||||
case "ACCEPT", "DROP", "REJECT", "REDIRECT":
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unknown target: %s", target)
|
||||
return fmt.Errorf("%w: %s", ErrTargetUnknown, target)
|
||||
}
|
||||
|
||||
var ErrProtocolUnknown = errors.New("unknown protocol")
|
||||
|
||||
func parseProtocol(s string) (protocol string, err error) {
|
||||
switch s {
|
||||
case "0", "all":
|
||||
@@ -465,16 +483,18 @@ func parseProtocol(s string) (protocol string, err error) {
|
||||
case "17", "udp":
|
||||
protocol = "udp"
|
||||
default:
|
||||
return "", fmt.Errorf("unknown protocol: %s", s)
|
||||
return "", fmt.Errorf("%w: %s", ErrProtocolUnknown, s)
|
||||
}
|
||||
return protocol, nil
|
||||
}
|
||||
|
||||
var ErrMetricSizeMalformed = errors.New("metric size is malformed")
|
||||
|
||||
// parseMetricSize parses a metric size string like 140K or 226M and
|
||||
// returns the raw integer matching it.
|
||||
func parseMetricSize(size string) (n uint64, err error) {
|
||||
if size == "" {
|
||||
return n, errors.New("metric size is malformed: empty string")
|
||||
return n, fmt.Errorf("%w: empty string", ErrMetricSizeMalformed)
|
||||
}
|
||||
|
||||
//nolint:mnd
|
||||
@@ -496,7 +516,7 @@ func parseMetricSize(size string) (n uint64, err error) {
|
||||
const base, bitLength = 10, 64
|
||||
n, err = strconv.ParseUint(size, base, bitLength)
|
||||
if err != nil {
|
||||
return n, fmt.Errorf("metric size is malformed: %w", err)
|
||||
return n, fmt.Errorf("%w: %w", ErrMetricSizeMalformed, err)
|
||||
}
|
||||
n *= multiplier
|
||||
return n, nil
|
||||
|
||||
@@ -13,25 +13,30 @@ func Test_parseChain(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
iptablesOutput string
|
||||
table chain
|
||||
errWrapped error
|
||||
errMessage string
|
||||
}{
|
||||
"no_output": {
|
||||
errWrapped: ErrChainListMalformed,
|
||||
errMessage: "iptables chain list output is malformed: not enough lines to process in: ",
|
||||
},
|
||||
"single_line_only": {
|
||||
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)`,
|
||||
errWrapped: ErrChainListMalformed,
|
||||
errMessage: "iptables chain list output is malformed: not enough lines to process in: " +
|
||||
"Chain INPUT (policy ACCEPT 140K packets, 226M bytes)",
|
||||
},
|
||||
"malformed_general_data_line": {
|
||||
iptablesOutput: `Chain INPUT
|
||||
num pkts bytes target prot opt in out source destination`,
|
||||
errWrapped: ErrChainListMalformed,
|
||||
errMessage: "parsing chain general data line: iptables chain list output is malformed: " +
|
||||
"expected 8 fields in \"Chain INPUT\"",
|
||||
},
|
||||
"malformed_legend": {
|
||||
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
|
||||
num pkts bytes target prot opt in out source`,
|
||||
errWrapped: ErrChainListMalformed,
|
||||
errMessage: "iptables chain list output is malformed: legend " +
|
||||
"\"num pkts bytes target prot opt in out source\" " +
|
||||
"is not the expected \"num pkts bytes target prot opt in out source destination\"",
|
||||
@@ -130,10 +135,9 @@ num pkts bytes target prot opt in out source destinati
|
||||
table, err := parseChain(testCase.iptablesOutput)
|
||||
|
||||
assert.Equal(t, testCase.table, table)
|
||||
if testCase.errMessage != "" {
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if testCase.errWrapped != nil {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -80,9 +80,11 @@ func ipPrefixesEqual(instruction, chainRule netip.Prefix) bool {
|
||||
(!instruction.IsValid() && chainRule.Bits() == 0 && chainRule.Addr().IsUnspecified())
|
||||
}
|
||||
|
||||
var ErrIptablesCommandMalformed = errors.New("iptables command is malformed")
|
||||
|
||||
func parseIptablesInstruction(s string) (instruction iptablesInstruction, err error) {
|
||||
if s == "" {
|
||||
return iptablesInstruction{}, errors.New("iptables command is malformed: empty instruction")
|
||||
return iptablesInstruction{}, fmt.Errorf("%w: empty instruction", ErrIptablesCommandMalformed)
|
||||
}
|
||||
fields := strings.Fields(s)
|
||||
|
||||
@@ -171,7 +173,7 @@ func parseInstructionFlag(fields []string, instruction *iptablesInstruction) (co
|
||||
return 0, fmt.Errorf("parsing TCP flags: %w", err)
|
||||
}
|
||||
default:
|
||||
return 0, fmt.Errorf("iptables command is malformed: unknown key %q", flag)
|
||||
return 0, fmt.Errorf("%w: unknown key %q", ErrIptablesCommandMalformed, flag)
|
||||
}
|
||||
return consumed, nil
|
||||
}
|
||||
@@ -183,15 +185,15 @@ func preCheckInstructionFields(fields []string) (consumed int, err error) {
|
||||
case "--tcp-flags": // -m can have 1 or 2 values
|
||||
const expected = 3
|
||||
if len(fields) < expected {
|
||||
return 0, fmt.Errorf("iptables command is malformed: flag %q requires at least 2 values, but got %s",
|
||||
flag, strings.Join(fields, " "))
|
||||
return 0, fmt.Errorf("%w: flag %q requires at least 2 values, but got %s",
|
||||
ErrIptablesCommandMalformed, flag, strings.Join(fields, " "))
|
||||
}
|
||||
return expected, nil
|
||||
default:
|
||||
const expected = 2
|
||||
if len(fields) < expected {
|
||||
return 0, fmt.Errorf("iptables command is malformed: flag %q requires a value, but got none",
|
||||
flag)
|
||||
return 0, fmt.Errorf("%w: flag %q requires a value, but got none",
|
||||
ErrIptablesCommandMalformed, flag)
|
||||
}
|
||||
return expected, nil
|
||||
}
|
||||
@@ -237,12 +239,12 @@ func parseMatchModule(fields []string, instruction *iptablesInstruction) (
|
||||
consumed++
|
||||
instruction.mark.invert = true
|
||||
default:
|
||||
return consumed, fmt.Errorf("iptables command is malformed: unsupported match mark with value: %s",
|
||||
fields[2])
|
||||
return consumed, fmt.Errorf("%w: unsupported match mark with value: %s",
|
||||
ErrIptablesCommandMalformed, fields[2])
|
||||
}
|
||||
default:
|
||||
return 0, fmt.Errorf("iptables command is malformed: unknown match value: %s",
|
||||
fields[consumed])
|
||||
return 0, fmt.Errorf("%w: unknown match value: %s",
|
||||
ErrIptablesCommandMalformed, fields[consumed])
|
||||
}
|
||||
return consumed, nil
|
||||
}
|
||||
|
||||
@@ -13,17 +13,21 @@ func Test_parseIptablesInstruction(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
s string
|
||||
instruction iptablesInstruction
|
||||
errWrapped error
|
||||
errMessage string
|
||||
}{
|
||||
"no_instruction": {
|
||||
errWrapped: ErrIptablesCommandMalformed,
|
||||
errMessage: "iptables command is malformed: empty instruction",
|
||||
},
|
||||
"uneven_fields": {
|
||||
s: "-A",
|
||||
errWrapped: ErrIptablesCommandMalformed,
|
||||
errMessage: "parsing \"-A\": iptables command is malformed: flag \"-A\" requires a value, but got none",
|
||||
},
|
||||
"unknown_key": {
|
||||
s: "-x something",
|
||||
errWrapped: ErrIptablesCommandMalformed,
|
||||
errMessage: "parsing \"-x something\": iptables command is malformed: unknown key \"-x\"",
|
||||
},
|
||||
"one_pair": {
|
||||
@@ -70,10 +74,9 @@ func Test_parseIptablesInstruction(t *testing.T) {
|
||||
rule, err := parseIptablesInstruction(testCase.s)
|
||||
|
||||
assert.Equal(t, testCase.instruction, rule)
|
||||
if testCase.errMessage != "" {
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if testCase.errWrapped != nil {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -10,7 +10,12 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
var ErrNotSupported = errors.New("no iptables supported found")
|
||||
var (
|
||||
ErrNetAdminMissing = errors.New("NET_ADMIN capability is missing")
|
||||
ErrTestRuleCleanup = errors.New("failed cleaning up test rule")
|
||||
ErrInputPolicyNotFound = errors.New("input policy not found")
|
||||
ErrNotSupported = errors.New("no iptables supported found")
|
||||
)
|
||||
|
||||
func checkIptablesSupport(ctx context.Context, runner CmdRunner,
|
||||
iptablesPathsToTry ...string,
|
||||
@@ -48,7 +53,7 @@ func checkIptablesSupport(ctx context.Context, runner CmdRunner,
|
||||
if allArePermissionDenied {
|
||||
// If the error is related to a denied permission for all iptables path,
|
||||
// return an error describing what to do from an end-user perspective.
|
||||
return "", fmt.Errorf("NET_ADMIN capability is missing: %s", strings.Join(allUnsupportedMessages, "; "))
|
||||
return "", fmt.Errorf("%w: %s", ErrNetAdminMissing, strings.Join(allUnsupportedMessages, "; "))
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("%w: errors encountered are: %s",
|
||||
@@ -80,7 +85,7 @@ func testIptablesPath(ctx context.Context, path string,
|
||||
output, err = runner.Run(cmd)
|
||||
if err != nil {
|
||||
// this is a critical error, we want to make sure our test rule gets removed.
|
||||
criticalErr = fmt.Errorf("failed cleaning up test rule: %s (%s)", output, err)
|
||||
criticalErr = fmt.Errorf("%w: %s (%s)", ErrTestRuleCleanup, output, err)
|
||||
return false, "", criticalErr
|
||||
}
|
||||
|
||||
@@ -103,7 +108,7 @@ func testIptablesPath(ctx context.Context, path string,
|
||||
}
|
||||
|
||||
if inputPolicy == "" {
|
||||
criticalErr = fmt.Errorf("input policy not found: in INPUT rules: %s", output)
|
||||
criticalErr = fmt.Errorf("%w: in INPUT rules: %s", ErrInputPolicyNotFound, output)
|
||||
return false, "", criticalErr
|
||||
}
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newAppendTestRuleMatcher(path string) *cmdMatcher {
|
||||
@@ -42,6 +43,7 @@ func Test_checkIptablesSupport(t *testing.T) {
|
||||
buildRunner func(ctrl *gomock.Controller) CmdRunner
|
||||
iptablesPathsToTry []string
|
||||
iptablesPath string
|
||||
errSentinel error
|
||||
errMessage string
|
||||
}{
|
||||
"critical error when checking": {
|
||||
@@ -54,6 +56,7 @@ func Test_checkIptablesSupport(t *testing.T) {
|
||||
return runner
|
||||
},
|
||||
iptablesPathsToTry: []string{"path1", "path2"},
|
||||
errSentinel: ErrTestRuleCleanup,
|
||||
errMessage: "for path1: failed cleaning up test rule: " +
|
||||
"output (exit code 4)",
|
||||
},
|
||||
@@ -83,6 +86,7 @@ func Test_checkIptablesSupport(t *testing.T) {
|
||||
return runner
|
||||
},
|
||||
iptablesPathsToTry: []string{"path1", "path2"},
|
||||
errSentinel: ErrNetAdminMissing,
|
||||
errMessage: "NET_ADMIN capability is missing: " +
|
||||
"path1: Permission denied (you must be root) more context (exit code 4); " +
|
||||
"path2: context: Permission denied (you must be root) (exit code 4)",
|
||||
@@ -97,6 +101,7 @@ func Test_checkIptablesSupport(t *testing.T) {
|
||||
return runner
|
||||
},
|
||||
iptablesPathsToTry: []string{"path1", "path2"},
|
||||
errSentinel: ErrNotSupported,
|
||||
errMessage: "no iptables supported found: " +
|
||||
"errors encountered are: " +
|
||||
"path1: output 1 (exit code 4); " +
|
||||
@@ -113,10 +118,9 @@ func Test_checkIptablesSupport(t *testing.T) {
|
||||
|
||||
iptablesPath, err := checkIptablesSupport(ctx, runner, testCase.iptablesPathsToTry...)
|
||||
|
||||
if testCase.errMessage != "" {
|
||||
require.ErrorIs(t, err, testCase.errSentinel)
|
||||
if testCase.errSentinel != nil {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, testCase.iptablesPath, iptablesPath)
|
||||
})
|
||||
@@ -135,6 +139,7 @@ func Test_testIptablesPath(t *testing.T) {
|
||||
buildRunner func(ctrl *gomock.Controller) CmdRunner
|
||||
ok bool
|
||||
unsupportedMessage string
|
||||
criticalErrWrapped error
|
||||
criticalErrMessage string
|
||||
}{
|
||||
"append test rule permission denied": {
|
||||
@@ -163,6 +168,7 @@ func Test_testIptablesPath(t *testing.T) {
|
||||
Return("some output", errDummy)
|
||||
return runner
|
||||
},
|
||||
criticalErrWrapped: ErrTestRuleCleanup,
|
||||
criticalErrMessage: "failed cleaning up test rule: some output (exit code 4)",
|
||||
},
|
||||
"list input rules permission denied": {
|
||||
@@ -196,6 +202,7 @@ func Test_testIptablesPath(t *testing.T) {
|
||||
Return("some\noutput", nil)
|
||||
return runner
|
||||
},
|
||||
criticalErrWrapped: ErrInputPolicyNotFound,
|
||||
criticalErrMessage: "input policy not found: in INPUT rules: some\noutput",
|
||||
},
|
||||
"set policy permission denied": {
|
||||
@@ -250,10 +257,9 @@ func Test_testIptablesPath(t *testing.T) {
|
||||
|
||||
assert.Equal(t, testCase.ok, ok)
|
||||
assert.Equal(t, testCase.unsupportedMessage, unsupportedMessage)
|
||||
if testCase.criticalErrMessage != "" {
|
||||
assert.ErrorIs(t, criticalErr, testCase.criticalErrWrapped)
|
||||
if testCase.criticalErrWrapped != nil {
|
||||
assert.EqualError(t, criticalErr, testCase.criticalErrMessage)
|
||||
} else {
|
||||
assert.NoError(t, criticalErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -45,10 +45,12 @@ func (f tcpFlag) String() string {
|
||||
case tcpFlagCWR:
|
||||
return "CWR"
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown TCP flag: %d", f))
|
||||
panic(fmt.Sprintf("%s: %d", errTCPFlagUnknown, f))
|
||||
}
|
||||
}
|
||||
|
||||
var errTCPFlagUnknown = errors.New("unknown TCP flag")
|
||||
|
||||
func parseTCPFlag(s string) (tcpFlag, error) {
|
||||
allFlags := []tcpFlag{
|
||||
tcpFlagFIN, tcpFlagSYN, tcpFlagRST, tcpFlagPSH,
|
||||
@@ -59,7 +61,7 @@ func parseTCPFlag(s string) (tcpFlag, error) {
|
||||
return flag, nil
|
||||
}
|
||||
}
|
||||
return 0, fmt.Errorf("unknown TCP flag: %s", s)
|
||||
return 0, fmt.Errorf("%w: %s", errTCPFlagUnknown, s)
|
||||
}
|
||||
|
||||
var ErrMarkMatchModuleMissing = errors.New("kernel is missing the mark module libxt_mark.so")
|
||||
|
||||
@@ -266,6 +266,8 @@ func makeAddressToDial(address string) (addressToDial string, err error) {
|
||||
return address, nil
|
||||
}
|
||||
|
||||
var ErrAllCheckTriesFailed = errors.New("all check tries failed")
|
||||
|
||||
func withRetries(ctx context.Context, tryTimeouts []time.Duration,
|
||||
logger Logger, checkName string, check func(ctx context.Context, try int) error,
|
||||
) error {
|
||||
@@ -295,7 +297,7 @@ func withRetries(ctx context.Context, tryTimeouts []time.Duration,
|
||||
for i, err := range errs {
|
||||
errStrings[i] = fmt.Sprintf("attempt %d (%dms): %s", i+1, err.durationMS, err.err)
|
||||
}
|
||||
return fmt.Errorf("all check tries failed:\n\t%s", strings.Join(errStrings, "\n\t"))
|
||||
return fmt.Errorf("%w:\n\t%s", ErrAllCheckTriesFailed, strings.Join(errStrings, "\n\t"))
|
||||
}
|
||||
|
||||
func (c *Checker) startupCheck(ctx context.Context) error {
|
||||
@@ -340,7 +342,7 @@ func (c *Checker) startupCheck(ctx context.Context) error {
|
||||
for i, err := range errs {
|
||||
errStrings[i] = fmt.Sprintf("parallel attempt %d/%d failed: %s", i+1, len(errs), err)
|
||||
}
|
||||
return fmt.Errorf("all check tries failed: %s", strings.Join(errStrings, ", "))
|
||||
return fmt.Errorf("%w: %s", ErrAllCheckTriesFailed, strings.Join(errStrings, ", "))
|
||||
}
|
||||
|
||||
const (
|
||||
|
||||
@@ -2,6 +2,7 @@ package healthcheck
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -67,7 +68,7 @@ func Test_makeAddressToDial(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
address string
|
||||
addressToDial string
|
||||
errMessage string
|
||||
err error
|
||||
}{
|
||||
"host without port": {
|
||||
address: "test.com",
|
||||
@@ -78,8 +79,8 @@ func Test_makeAddressToDial(t *testing.T) {
|
||||
addressToDial: "test.com:80",
|
||||
},
|
||||
"bad address": {
|
||||
address: "test.com::",
|
||||
errMessage: "splitting host and port from address: address test.com::: too many colons in address",
|
||||
address: "test.com::",
|
||||
err: fmt.Errorf("splitting host and port from address: address test.com::: too many colons in address"), //nolint:lll
|
||||
},
|
||||
}
|
||||
|
||||
@@ -90,8 +91,8 @@ func Test_makeAddressToDial(t *testing.T) {
|
||||
addressToDial, err := makeAddressToDial(testCase.address)
|
||||
|
||||
assert.Equal(t, testCase.addressToDial, addressToDial)
|
||||
if testCase.errMessage != "" {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
if testCase.err != nil {
|
||||
assert.EqualError(t, err, testCase.err.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -2,12 +2,15 @@ package healthcheck
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
var ErrHTTPStatusNotOK = errors.New("HTTP response status is not OK")
|
||||
|
||||
type Client struct {
|
||||
httpClient *http.Client
|
||||
}
|
||||
@@ -38,6 +41,6 @@ func (c *Client) Check(ctx context.Context, url string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return fmt.Errorf("HTTP response status is not OK: %d %s: %s",
|
||||
return fmt.Errorf("%w: %d %s: %s", ErrHTTPStatusNotOK,
|
||||
response.StatusCode, response.Status, string(b))
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
@@ -40,6 +41,8 @@ func concatAddrPorts(addrs [][]netip.AddrPort) []netip.AddrPort {
|
||||
return result
|
||||
}
|
||||
|
||||
var ErrLookupNoIPs = errors.New("no IPs found from DNS lookup")
|
||||
|
||||
func (c *Client) Check(ctx context.Context) error {
|
||||
dnsAddr := c.serverAddrs[c.dnsIPIndex].String()
|
||||
resolver := &net.Resolver{
|
||||
@@ -56,7 +59,7 @@ func (c *Client) Check(ctx context.Context) error {
|
||||
return fmt.Errorf("with DNS server %s: %w", dnsAddr, err)
|
||||
case len(ips) == 0:
|
||||
c.dnsIPIndex = (c.dnsIPIndex + 1) % len(c.serverAddrs)
|
||||
return fmt.Errorf("with DNS server %s: no IPs found from DNS lookup", dnsAddr)
|
||||
return fmt.Errorf("with DNS server %s: %w", dnsAddr, ErrLookupNoIPs)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -12,9 +12,11 @@ type handler struct {
|
||||
logger Logger
|
||||
}
|
||||
|
||||
var errHealthcheckNotRunYet = errors.New("healthcheck did not run yet")
|
||||
|
||||
func newHandler(logger Logger) *handler {
|
||||
return &handler{
|
||||
healthErr: errors.New("healthcheck did not run yet"),
|
||||
healthErr: errHealthcheckNotRunYet,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,6 +19,11 @@ import (
|
||||
"golang.org/x/net/ipv6"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrICMPBodyUnsupported = errors.New("ICMP body type is not supported")
|
||||
ErrICMPEchoDataMismatch = errors.New("ICMP data mismatch")
|
||||
)
|
||||
|
||||
type Echoer struct {
|
||||
buffer []byte
|
||||
randomSource io.Reader
|
||||
@@ -55,7 +60,10 @@ func (e *Echoer) Reset() {
|
||||
e.seqStart = time.Now()
|
||||
}
|
||||
|
||||
var ErrNotPermitted = errors.New("not permitted")
|
||||
var (
|
||||
ErrTimedOut = errors.New("timed out waiting for ICMP echo reply")
|
||||
ErrNotPermitted = errors.New("not permitted")
|
||||
)
|
||||
|
||||
func (e *Echoer) Echo(ctx context.Context, ip netip.Addr) (err error) {
|
||||
var ipVersion string
|
||||
@@ -106,14 +114,14 @@ func (e *Echoer) Echo(ctx context.Context, ip netip.Addr) (err error) {
|
||||
receivedData, err := receiveEchoReply(conn, e.id, e.seq, e.buffer, ipVersion, e.logger)
|
||||
if err != nil {
|
||||
if errors.Is(err, net.ErrClosed) && ctx.Err() != nil {
|
||||
return fmt.Errorf("timed out waiting for ICMP echo reply from %s", ip)
|
||||
return fmt.Errorf("%w from %s", ErrTimedOut, ip)
|
||||
}
|
||||
return fmt.Errorf("receiving ICMP echo reply from %s: %w", ip, err)
|
||||
}
|
||||
|
||||
sentData := message.Body.(*icmp.Echo).Data //nolint:forcetypeassert
|
||||
if !bytes.Equal(receivedData, sentData) {
|
||||
return fmt.Errorf("ICMP data mismatch: sent %x to %s and received %x", sentData, ip, receivedData)
|
||||
return fmt.Errorf("%w: sent %x to %s and received %x", ErrICMPEchoDataMismatch, sentData, ip, receivedData)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -208,9 +216,8 @@ func receiveEchoReply(conn net.PacketConn, id, seq int, buffer []byte, ipVersion
|
||||
message.Code, returnAddr, id, seq)
|
||||
continue
|
||||
default:
|
||||
return nil, fmt.Errorf("ICMP body type is not supported: "+
|
||||
"%T (type %d, code %d, return address %s, expected id %d and seq %d)",
|
||||
body, message.Type, message.Code, returnAddr, id, seq)
|
||||
return nil, fmt.Errorf("%w: %T (type %d, code %d, return address %s, expected id %d and seq %d)",
|
||||
ErrICMPBodyUnsupported, body, message.Type, message.Code, returnAddr, id, seq)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
//go:generate mockgen -destination=logger_mock_test.go -package $GOPACKAGE . Logger
|
||||
@@ -19,9 +20,11 @@ func Test_New(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
settings Settings
|
||||
expected *Server
|
||||
errWrapped error
|
||||
errMessage string
|
||||
}{
|
||||
"empty settings": {
|
||||
errWrapped: ErrHandlerIsNotSet,
|
||||
errMessage: "http server settings validation failed: HTTP handler cannot be left unset",
|
||||
},
|
||||
"filled settings": {
|
||||
@@ -49,10 +52,9 @@ func Test_New(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server, err := New(testCase.settings)
|
||||
if testCase.errMessage != "" {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if testCase.errWrapped != nil {
|
||||
require.EqualError(t, err, testCase.errMessage)
|
||||
}
|
||||
|
||||
if server != nil {
|
||||
|
||||
@@ -64,6 +64,14 @@ func (s *Settings) OverrideWith(other Settings) {
|
||||
s.ShutdownTimeout = gosettings.OverrideWithComparable(s.ShutdownTimeout, other.ShutdownTimeout)
|
||||
}
|
||||
|
||||
var (
|
||||
ErrHandlerIsNotSet = errors.New("HTTP handler cannot be left unset")
|
||||
ErrLoggerIsNotSet = errors.New("logger cannot be left unset")
|
||||
ErrReadHeaderTimeoutTooSmall = errors.New("read header timeout is too small")
|
||||
ErrReadTimeoutTooSmall = errors.New("read timeout is too small")
|
||||
ErrShutdownTimeoutTooSmall = errors.New("shutdown timeout is too small")
|
||||
)
|
||||
|
||||
func (s Settings) Validate() (err error) {
|
||||
err = validate.ListeningAddress(s.Address, os.Getuid())
|
||||
if err != nil {
|
||||
@@ -71,25 +79,31 @@ func (s Settings) Validate() (err error) {
|
||||
}
|
||||
|
||||
if s.Handler == nil {
|
||||
return errors.New("HTTP handler cannot be left unset")
|
||||
return fmt.Errorf("%w", ErrHandlerIsNotSet)
|
||||
}
|
||||
|
||||
if s.Logger == nil {
|
||||
return errors.New("logger cannot be left unset")
|
||||
return fmt.Errorf("%w", ErrLoggerIsNotSet)
|
||||
}
|
||||
|
||||
const minReadTimeout = time.Millisecond
|
||||
if s.ReadHeaderTimeout < minReadTimeout {
|
||||
return fmt.Errorf("read header timeout is too small: %s must be at least %s", s.ReadHeaderTimeout, minReadTimeout)
|
||||
return fmt.Errorf("%w: %s must be at least %s",
|
||||
ErrReadHeaderTimeoutTooSmall,
|
||||
s.ReadHeaderTimeout, minReadTimeout)
|
||||
}
|
||||
|
||||
if s.ReadTimeout < minReadTimeout {
|
||||
return fmt.Errorf("read timeout is too small: %s must be at least %s", s.ReadTimeout, minReadTimeout)
|
||||
return fmt.Errorf("%w: %s must be at least %s",
|
||||
ErrReadTimeoutTooSmall,
|
||||
s.ReadTimeout, minReadTimeout)
|
||||
}
|
||||
|
||||
const minShutdownTimeout = 5 * time.Millisecond
|
||||
if s.ShutdownTimeout < minShutdownTimeout {
|
||||
return fmt.Errorf("shutdown timeout is too small: %s must be at least %s", s.ShutdownTimeout, minShutdownTimeout)
|
||||
return fmt.Errorf("%w: %s must be at least %s",
|
||||
ErrShutdownTimeoutTooSmall,
|
||||
s.ShutdownTimeout, minShutdownTimeout)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gosettings/validate"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
@@ -188,26 +189,30 @@ func Test_Settings_Validate(t *testing.T) {
|
||||
|
||||
testCases := map[string]struct {
|
||||
settings Settings
|
||||
errWrapped error
|
||||
errMessage string
|
||||
}{
|
||||
"bad_address": {
|
||||
settings: Settings{
|
||||
Address: "address:notanint",
|
||||
},
|
||||
errWrapped: validate.ErrPortNotAnInteger,
|
||||
errMessage: "port value is not an integer: notanint",
|
||||
},
|
||||
"nil handler": {
|
||||
settings: Settings{
|
||||
Address: ":8000",
|
||||
},
|
||||
errMessage: "HTTP handler cannot be left unset",
|
||||
errWrapped: ErrHandlerIsNotSet,
|
||||
errMessage: ErrHandlerIsNotSet.Error(),
|
||||
},
|
||||
"nil logger": {
|
||||
settings: Settings{
|
||||
Address: ":8000",
|
||||
Handler: someHandler,
|
||||
},
|
||||
errMessage: "logger cannot be left unset",
|
||||
errWrapped: ErrLoggerIsNotSet,
|
||||
errMessage: ErrLoggerIsNotSet.Error(),
|
||||
},
|
||||
"read header timeout too small": {
|
||||
settings: Settings{
|
||||
@@ -216,6 +221,7 @@ func Test_Settings_Validate(t *testing.T) {
|
||||
Logger: someLogger,
|
||||
ReadHeaderTimeout: time.Nanosecond,
|
||||
},
|
||||
errWrapped: ErrReadHeaderTimeoutTooSmall,
|
||||
errMessage: "read header timeout is too small: 1ns must be at least 1ms",
|
||||
},
|
||||
"read timeout too small": {
|
||||
@@ -226,6 +232,7 @@ func Test_Settings_Validate(t *testing.T) {
|
||||
ReadHeaderTimeout: time.Millisecond,
|
||||
ReadTimeout: time.Nanosecond,
|
||||
},
|
||||
errWrapped: ErrReadTimeoutTooSmall,
|
||||
errMessage: "read timeout is too small: 1ns must be at least 1ms",
|
||||
},
|
||||
"shutdown timeout too small": {
|
||||
@@ -237,6 +244,7 @@ func Test_Settings_Validate(t *testing.T) {
|
||||
ReadTimeout: time.Millisecond,
|
||||
ShutdownTimeout: time.Millisecond,
|
||||
},
|
||||
errWrapped: ErrShutdownTimeoutTooSmall,
|
||||
errMessage: "shutdown timeout is too small: 1ms must be at least 5ms",
|
||||
},
|
||||
"valid settings": {
|
||||
@@ -257,10 +265,9 @@ func Test_Settings_Validate(t *testing.T) {
|
||||
|
||||
err := testCase.settings.Validate()
|
||||
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if testCase.errMessage != "" {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -2,12 +2,15 @@ package loopstate
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
)
|
||||
|
||||
var ErrInvalidStatus = errors.New("invalid status")
|
||||
|
||||
// ApplyStatus sends signals to the running loop depending on the
|
||||
// current status and status requested, such that its next status
|
||||
// matches the requested one. It is thread safe and a synchronous call
|
||||
@@ -70,7 +73,7 @@ func (s *State) ApplyStatus(ctx context.Context, status models.LoopStatus) (
|
||||
return newStatus.String(), nil
|
||||
default:
|
||||
s.statusMu.Unlock()
|
||||
return "", fmt.Errorf("invalid status: %s: it can only be one of: %s, %s",
|
||||
status, constants.Running, constants.Stopped)
|
||||
return "", fmt.Errorf("%w: %s: it can only be one of: %s, %s",
|
||||
ErrInvalidStatus, status, constants.Running, constants.Stopped)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,11 +3,19 @@ package mod
|
||||
import (
|
||||
"bufio"
|
||||
"compress/gzip"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
errModuleNameUnknown = errors.New("unknown module name")
|
||||
errKernelFeatureIsModule = errors.New("kernel feature is a module, not built-in")
|
||||
errKernelFeatureNotSet = errors.New("kernel feature not set")
|
||||
errKernelFeatureNotFound = errors.New("kernel feature not found")
|
||||
)
|
||||
|
||||
// checkProcConfig checks /proc/config.gz for a the kernel feature corresponding
|
||||
// to the given module name. If the kernel feature is found and set to "y", it returns nil.
|
||||
// If the kernel feature is found and set to "m", it returns an error indicating that the kernel
|
||||
@@ -31,7 +39,7 @@ func checkProcConfig(moduleName string) error {
|
||||
// If any group of kernel features is satisfied, then the module is considered supported.
|
||||
kernelFeatureGroups, ok := moduleNameToKernelFeatureGroups(moduleName)
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown module name: %s", moduleName)
|
||||
return fmt.Errorf("%w: %s", errModuleNameUnknown, moduleName)
|
||||
}
|
||||
groups := make([]map[string]bool, len(kernelFeatureGroups))
|
||||
for i, group := range kernelFeatureGroups {
|
||||
@@ -50,20 +58,20 @@ func checkProcConfig(moduleName string) error {
|
||||
switch {
|
||||
case ok:
|
||||
case strings.HasPrefix(line, name+"=m"):
|
||||
return fmt.Errorf("kernel feature is a module, not built-in: %s", name)
|
||||
return fmt.Errorf("%w: %s", errKernelFeatureIsModule, name)
|
||||
case strings.HasPrefix(line, name+"=y"):
|
||||
featureToOK[name] = true
|
||||
if allFeaturesOK(featureToOK) {
|
||||
return nil
|
||||
}
|
||||
case strings.HasPrefix(line, "# "+name+" is not set"):
|
||||
return fmt.Errorf("kernel feature not set: %s", name)
|
||||
return fmt.Errorf("%w: %s", errKernelFeatureNotSet, name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("kernel feature not found: for module name %s", moduleName)
|
||||
return fmt.Errorf("%w: for module name %s", errKernelFeatureNotFound, moduleName)
|
||||
}
|
||||
|
||||
func moduleNameToKernelFeatureGroups(moduleName string) (featureGroups [][]string, ok bool) {
|
||||
|
||||
@@ -181,6 +181,8 @@ func getLoadedModules(modulesInfo map[string]moduleInfo) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
var ErrModulePathNotFound = errors.New("module path not found")
|
||||
|
||||
func findModulePath(moduleName string, modulesInfo map[string]moduleInfo) (modulePath string, err error) {
|
||||
// Kernel module names can have underscores or hyphens in their names,
|
||||
// but only one or the other in one particular name.
|
||||
@@ -203,5 +205,5 @@ func findModulePath(moduleName string, modulesInfo map[string]moduleInfo) (modul
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("module path not found: for %q", moduleName)
|
||||
return "", fmt.Errorf("%w: for %q", ErrModulePathNotFound, moduleName)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package mod
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
@@ -13,10 +14,15 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrModuleInfoNotFound = errors.New("module info not found")
|
||||
ErrCircularDependency = errors.New("circular dependency")
|
||||
)
|
||||
|
||||
func initDependencies(path string, modulesInfo map[string]moduleInfo) (err error) {
|
||||
info, ok := modulesInfo[path]
|
||||
if !ok {
|
||||
return fmt.Errorf("module info not found: %s", path)
|
||||
return fmt.Errorf("%w: %s", ErrModuleInfoNotFound, path)
|
||||
}
|
||||
|
||||
switch info.state {
|
||||
@@ -24,7 +30,8 @@ func initDependencies(path string, modulesInfo map[string]moduleInfo) (err error
|
||||
case loaded, builtin:
|
||||
return nil
|
||||
case loading:
|
||||
return fmt.Errorf("circular dependency: %s is already in the loading state", path)
|
||||
return fmt.Errorf("%w: %s is already in the loading state",
|
||||
ErrCircularDependency, path)
|
||||
}
|
||||
|
||||
info.state = loading
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@@ -108,6 +109,8 @@ func (s *Servers) toMarkdown(vpnProvider string) (formatted string, err error) {
|
||||
return formatted, nil
|
||||
}
|
||||
|
||||
var ErrMarkdownHeadersNotDefined = errors.New("markdown headers not defined")
|
||||
|
||||
func getMarkdownHeaders(vpnProvider string) (headers []string, err error) {
|
||||
switch vpnProvider {
|
||||
case providers.Airvpn:
|
||||
@@ -166,6 +169,6 @@ func getMarkdownHeaders(vpnProvider string) (headers []string, err error) {
|
||||
case providers.Windscribe:
|
||||
return []string{regionHeader, cityHeader, hostnameHeader, vpnHeader}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("markdown headers not defined: for %s", vpnProvider)
|
||||
return nil, fmt.Errorf("%w: for %s", ErrMarkdownHeadersNotDefined, vpnProvider)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,10 +15,12 @@ func Test_Servers_ToMarkdown(t *testing.T) {
|
||||
provider string
|
||||
servers Servers
|
||||
formatted string
|
||||
errWrapped error
|
||||
errMessage string
|
||||
}{
|
||||
"unsupported_provider": {
|
||||
provider: "unsupported",
|
||||
errWrapped: ErrMarkdownHeadersNotDefined,
|
||||
errMessage: "getting markdown headers: markdown headers not defined: for unsupported",
|
||||
},
|
||||
providers.Cyberghost: {
|
||||
@@ -56,10 +58,9 @@ func Test_Servers_ToMarkdown(t *testing.T) {
|
||||
markdown, err := testCase.servers.toMarkdown(testCase.provider)
|
||||
|
||||
assert.Equal(t, testCase.formatted, markdown)
|
||||
if testCase.errMessage != "" {
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if testCase.errWrapped != nil {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -38,18 +38,27 @@ type Server struct {
|
||||
IPs []netip.Addr `json:"ips,omitempty"`
|
||||
}
|
||||
|
||||
var (
|
||||
ErrVPNFieldEmpty = errors.New("vpn field is empty")
|
||||
ErrHostnameFieldEmpty = errors.New("hostname field is empty")
|
||||
ErrIPsFieldEmpty = errors.New("ips field is empty")
|
||||
ErrNoNetworkProtocol = errors.New("both TCP and UDP fields are false for OpenVPN")
|
||||
ErrNetworkProtocolSet = errors.New("no network protocol should be set")
|
||||
ErrWireguardPublicKeyEmpty = errors.New("wireguard public key field is empty")
|
||||
)
|
||||
|
||||
func (s *Server) HasMinimumInformation() (err error) {
|
||||
switch {
|
||||
case s.VPN == "":
|
||||
return errors.New("vpn field is empty")
|
||||
return fmt.Errorf("%w", ErrVPNFieldEmpty)
|
||||
case len(s.IPs) == 0:
|
||||
return errors.New("ips field is empty")
|
||||
return fmt.Errorf("%w", ErrIPsFieldEmpty)
|
||||
case s.VPN == vpn.Wireguard && (s.TCP || s.UDP):
|
||||
return errors.New("no network protocol should be set")
|
||||
return fmt.Errorf("%w", ErrNetworkProtocolSet)
|
||||
case s.VPN == vpn.OpenVPN && !s.TCP && !s.UDP:
|
||||
return errors.New("both TCP and UDP fields are false for OpenVPN")
|
||||
return fmt.Errorf("%w", ErrNoNetworkProtocol)
|
||||
case s.VPN == vpn.Wireguard && s.WgPubKey == "":
|
||||
return errors.New("wireguard public key field is empty")
|
||||
return fmt.Errorf("%w", ErrWireguardPublicKeyEmpty)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package models
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"reflect"
|
||||
@@ -154,11 +155,11 @@ func (a *AllServers) Count() (count int) {
|
||||
type Servers struct {
|
||||
Version uint16 `json:"version"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
Preferred bool `json:"preferred,omitempty"`
|
||||
Filepath string `json:"filepath,omitempty"`
|
||||
Servers []Server `json:"servers,omitempty"`
|
||||
}
|
||||
|
||||
var ErrServersFormatNotSupported = errors.New("servers format not supported")
|
||||
|
||||
func (s *Servers) Format(vpnProvider, format string) (formatted string, err error) {
|
||||
switch format {
|
||||
case "markdown":
|
||||
@@ -166,7 +167,7 @@ func (s *Servers) Format(vpnProvider, format string) (formatted string, err erro
|
||||
case "json":
|
||||
return s.toJSON()
|
||||
default:
|
||||
return "", fmt.Errorf("servers format not supported: %s", format)
|
||||
return "", fmt.Errorf("%w: %s", ErrServersFormatNotSupported, format)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ func Test_AllServers_MarshalJSON(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
allServers *AllServers
|
||||
dataString string
|
||||
errWrapped error
|
||||
errMessage string
|
||||
}{
|
||||
"no provider": {
|
||||
@@ -57,18 +58,16 @@ func Test_AllServers_MarshalJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
data, err := testCase.allServers.MarshalJSON()
|
||||
if testCase.errMessage != "" {
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if err != nil {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
require.Equal(t, testCase.dataString, string(data))
|
||||
|
||||
data, err = json.Marshal(testCase.allServers)
|
||||
if testCase.errMessage != "" {
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if err != nil {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
require.Equal(t, testCase.dataString, string(data))
|
||||
|
||||
@@ -88,6 +87,7 @@ func Test_AllServers_UnmarshalJSON(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
dataString string
|
||||
allServers AllServers
|
||||
errWrapped error
|
||||
errMessage string
|
||||
}{
|
||||
"empty": {
|
||||
@@ -131,10 +131,9 @@ func Test_AllServers_UnmarshalJSON(t *testing.T) {
|
||||
|
||||
err := json.Unmarshal(data, &allServers)
|
||||
|
||||
if testCase.errMessage != "" {
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if err != nil {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, testCase.allServers, allServers)
|
||||
})
|
||||
|
||||
+33
-16
@@ -6,40 +6,48 @@ import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
var ErrRequestSizeTooSmall = errors.New("message size is too small")
|
||||
|
||||
func checkRequest(request []byte) (err error) {
|
||||
const minMessageSize = 2 // version number + operation code
|
||||
if len(request) < minMessageSize {
|
||||
return fmt.Errorf("message size is too small: need at least %d bytes and got %d byte(s)",
|
||||
minMessageSize, len(request))
|
||||
return fmt.Errorf("%w: need at least %d bytes and got %d byte(s)",
|
||||
ErrRequestSizeTooSmall, minMessageSize, len(request))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
ErrResponseSizeTooSmall = errors.New("response size is too small")
|
||||
ErrResponseSizeUnexpected = errors.New("response size is unexpected")
|
||||
ErrProtocolVersionUnknown = errors.New("protocol version is unknown")
|
||||
ErrOperationCodeUnexpected = errors.New("operation code is unexpected")
|
||||
)
|
||||
|
||||
func checkResponse(response []byte, expectedOperationCode byte,
|
||||
expectedResponseSize uint,
|
||||
) (err error) {
|
||||
const minResponseSize = 4
|
||||
if len(response) < minResponseSize {
|
||||
return fmt.Errorf("response size is too small: "+
|
||||
"need at least %d bytes and got %d byte(s)",
|
||||
minResponseSize, len(response))
|
||||
return fmt.Errorf("%w: need at least %d bytes and got %d byte(s)",
|
||||
ErrResponseSizeTooSmall, minResponseSize, len(response))
|
||||
}
|
||||
|
||||
if uint(len(response)) != expectedResponseSize {
|
||||
return fmt.Errorf("response size is unexpected: "+
|
||||
"expected %d bytes and got %d byte(s)",
|
||||
expectedResponseSize, len(response))
|
||||
return fmt.Errorf("%w: expected %d bytes and got %d byte(s)",
|
||||
ErrResponseSizeUnexpected, expectedResponseSize, len(response))
|
||||
}
|
||||
|
||||
protocolVersion := response[0]
|
||||
if protocolVersion != 0 {
|
||||
return fmt.Errorf("protocol version is unknown: %d", protocolVersion)
|
||||
return fmt.Errorf("%w: %d", ErrProtocolVersionUnknown, protocolVersion)
|
||||
}
|
||||
|
||||
operationCode := response[1]
|
||||
if operationCode != expectedOperationCode {
|
||||
return fmt.Errorf("operation code is unexpected: expected 0x%x and got 0x%x", expectedOperationCode, operationCode)
|
||||
return fmt.Errorf("%w: expected 0x%x and got 0x%x",
|
||||
ErrOperationCodeUnexpected, expectedOperationCode, operationCode)
|
||||
}
|
||||
|
||||
resultCode := binary.BigEndian.Uint16(response[2:4])
|
||||
@@ -51,6 +59,15 @@ func checkResponse(response []byte, expectedOperationCode byte,
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
ErrVersionNotSupported = errors.New("version is not supported")
|
||||
ErrNotAuthorized = errors.New("not authorized")
|
||||
ErrNetworkFailure = errors.New("network failure")
|
||||
ErrOutOfResources = errors.New("out of resources")
|
||||
ErrOperationCodeNotSupported = errors.New("operation code is not supported")
|
||||
ErrResultCodeUnknown = errors.New("result code is unknown")
|
||||
)
|
||||
|
||||
// checkResultCode checks the result code and returns an error
|
||||
// if the result code is not a success (0).
|
||||
// See https://www.ietf.org/rfc/rfc6886.html#section-3.5
|
||||
@@ -61,16 +78,16 @@ func checkResultCode(resultCode uint16) (err error) {
|
||||
case 0:
|
||||
return nil
|
||||
case 1:
|
||||
return errors.New("version is not supported")
|
||||
return fmt.Errorf("%w", ErrVersionNotSupported)
|
||||
case 2:
|
||||
return errors.New("not authorized")
|
||||
return fmt.Errorf("%w", ErrNotAuthorized)
|
||||
case 3:
|
||||
return errors.New("network failure")
|
||||
return fmt.Errorf("%w", ErrNetworkFailure)
|
||||
case 4:
|
||||
return errors.New("out of resources")
|
||||
return fmt.Errorf("%w", ErrOutOfResources)
|
||||
case 5:
|
||||
return errors.New("operation code is not supported")
|
||||
return fmt.Errorf("%w", ErrOperationCodeNotSupported)
|
||||
default:
|
||||
return fmt.Errorf("result code is unknown: %d", resultCode)
|
||||
return fmt.Errorf("%w: %d", ErrResultCodeUnknown, resultCode)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package natpmp
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -12,10 +11,12 @@ func Test_checkRequest(t *testing.T) {
|
||||
|
||||
testCases := map[string]struct {
|
||||
request []byte
|
||||
err error
|
||||
errMessage string
|
||||
}{
|
||||
"too_short": {
|
||||
request: []byte{1},
|
||||
err: ErrRequestSizeTooSmall,
|
||||
errMessage: "message size is too small: need at least 2 bytes and got 1 byte(s)",
|
||||
},
|
||||
"success": {
|
||||
@@ -29,10 +30,9 @@ func Test_checkRequest(t *testing.T) {
|
||||
|
||||
err := checkRequest(testCase.request)
|
||||
|
||||
if testCase.errMessage != "" {
|
||||
assert.ErrorIs(t, err, testCase.err)
|
||||
if testCase.err != nil {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -50,33 +50,33 @@ func Test_checkResponse(t *testing.T) {
|
||||
}{
|
||||
"too_short": {
|
||||
response: []byte{1},
|
||||
err: errors.New("response size is too small"),
|
||||
err: ErrResponseSizeTooSmall,
|
||||
errMessage: "response size is too small: need at least 4 bytes and got 1 byte(s)",
|
||||
},
|
||||
"size_mismatch": {
|
||||
response: []byte{0, 0, 0, 0},
|
||||
expectedResponseSize: 5,
|
||||
err: errors.New("response size is unexpected"),
|
||||
err: ErrResponseSizeUnexpected,
|
||||
errMessage: "response size is unexpected: expected 5 bytes and got 4 byte(s)",
|
||||
},
|
||||
"protocol_unknown": {
|
||||
response: []byte{1, 0, 0, 0},
|
||||
expectedResponseSize: 4,
|
||||
err: errors.New("protocol version is unknown"),
|
||||
err: ErrProtocolVersionUnknown,
|
||||
errMessage: "protocol version is unknown: 1",
|
||||
},
|
||||
"operation_code_unexpected": {
|
||||
response: []byte{0, 2, 0, 0},
|
||||
expectedOperationCode: 1,
|
||||
expectedResponseSize: 4,
|
||||
err: errors.New("operation code is unexpected"),
|
||||
err: ErrOperationCodeUnexpected,
|
||||
errMessage: "operation code is unexpected: expected 0x1 and got 0x2",
|
||||
},
|
||||
"result_code_failure": {
|
||||
response: []byte{0, 1, 0, 1},
|
||||
expectedOperationCode: 1,
|
||||
expectedResponseSize: 4,
|
||||
err: errors.New("version is not supported"),
|
||||
err: ErrVersionNotSupported,
|
||||
errMessage: "result code: version is not supported",
|
||||
},
|
||||
"success": {
|
||||
@@ -94,11 +94,9 @@ func Test_checkResponse(t *testing.T) {
|
||||
testCase.expectedOperationCode,
|
||||
testCase.expectedResponseSize)
|
||||
|
||||
assert.ErrorIs(t, err, testCase.err)
|
||||
if testCase.err != nil {
|
||||
assert.ErrorContains(t, err, testCase.err.Error())
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -115,32 +113,32 @@ func Test_checkResultCode(t *testing.T) {
|
||||
"success": {},
|
||||
"version_unsupported": {
|
||||
resultCode: 1,
|
||||
err: errors.New("version is not supported"),
|
||||
err: ErrVersionNotSupported,
|
||||
errMessage: "version is not supported",
|
||||
},
|
||||
"not_authorized": {
|
||||
resultCode: 2,
|
||||
err: errors.New("not authorized"),
|
||||
err: ErrNotAuthorized,
|
||||
errMessage: "not authorized",
|
||||
},
|
||||
"network_failure": {
|
||||
resultCode: 3,
|
||||
err: errors.New("network failure"),
|
||||
err: ErrNetworkFailure,
|
||||
errMessage: "network failure",
|
||||
},
|
||||
"out_of_resources": {
|
||||
resultCode: 4,
|
||||
err: errors.New("out of resources"),
|
||||
err: ErrOutOfResources,
|
||||
errMessage: "out of resources",
|
||||
},
|
||||
"unsupported_operation_code": {
|
||||
resultCode: 5,
|
||||
err: errors.New("operation code is not supported"),
|
||||
err: ErrOperationCodeNotSupported,
|
||||
errMessage: "operation code is not supported",
|
||||
},
|
||||
"unknown": {
|
||||
resultCode: 6,
|
||||
err: errors.New("result code is unknown"),
|
||||
err: ErrResultCodeUnknown,
|
||||
errMessage: "result code is unknown: 6",
|
||||
},
|
||||
}
|
||||
@@ -151,11 +149,9 @@ func Test_checkResultCode(t *testing.T) {
|
||||
|
||||
err := checkResultCode(testCase.resultCode)
|
||||
|
||||
assert.ErrorIs(t, err, testCase.err)
|
||||
if testCase.err != nil {
|
||||
assert.ErrorContains(t, err, testCase.err.Error())
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -3,11 +3,17 @@ package natpmp
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNetworkProtocolUnknown = errors.New("network protocol is unknown")
|
||||
ErrLifetimeTooLong = errors.New("lifetime is too long")
|
||||
)
|
||||
|
||||
// Add or delete a port mapping. To delete a mapping, set both the
|
||||
// requestedExternalPort and lifetime to 0.
|
||||
// See https://www.ietf.org/rfc/rfc6886.html#section-3.3
|
||||
@@ -20,9 +26,8 @@ func (c *Client) AddPortMapping(ctx context.Context, gateway netip.Addr,
|
||||
lifetimeSecondsFloat := lifetime.Seconds()
|
||||
const maxLifetimeSeconds = uint64(^uint32(0))
|
||||
if uint64(lifetimeSecondsFloat) > maxLifetimeSeconds {
|
||||
return 0, 0, 0, 0, fmt.Errorf("lifetime is too long: "+
|
||||
"%d seconds must at most %d seconds",
|
||||
uint64(lifetimeSecondsFloat), maxLifetimeSeconds)
|
||||
return 0, 0, 0, 0, fmt.Errorf("%w: %d seconds must at most %d seconds",
|
||||
ErrLifetimeTooLong, uint64(lifetimeSecondsFloat), maxLifetimeSeconds)
|
||||
}
|
||||
const messageSize = 12
|
||||
message := make([]byte, messageSize)
|
||||
@@ -33,7 +38,7 @@ func (c *Client) AddPortMapping(ctx context.Context, gateway netip.Addr,
|
||||
case "tcp":
|
||||
message[1] = 2 // operationCode 2
|
||||
default:
|
||||
return 0, 0, 0, 0, fmt.Errorf("network protocol is unknown: %s", protocol)
|
||||
return 0, 0, 0, 0, fmt.Errorf("%w: %s", ErrNetworkProtocolUnknown, protocol)
|
||||
}
|
||||
// [2:3] are reserved.
|
||||
binary.BigEndian.PutUint16(message[4:6], internalPort)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user