Compare commits

..

44 Commits

Author SHA1 Message Date
Quentin McGaw 106a4fdf58 Merge branch 'master' into restrictednet 2026-06-11 14:33:35 +00:00
Quentin McGaw 8abb05567c hotfix(command): fix unit test 2026-06-11 14:06:26 +00:00
Quentin McGaw f6b2612923 Merge branch 'master' into restrictednet 2026-06-11 14:01:08 +00:00
Quentin McGaw 08dfd73367 pr review feedback 2026-06-11 14:01:05 +00:00
Quentin McGaw a53a0267e4 hotfix(socks5): support domain name udp association 2026-06-11 13:50:50 +00:00
Quentin McGaw 4e986c8af7 chore(socks5): fix lint errors on integration test 2026-06-11 13:37:58 +00:00
Quentin McGaw b44c671217 lint fix 2026-06-11 13:36:08 +00:00
Quentin McGaw 6d84462f00 feat(socks5): UDP proxying (#3353) 2026-06-11 15:32:38 +02:00
Quentin McGaw acab89b91a fix(command): wait for all stdout and stderr streams to complete correctly 2026-06-11 13:30:59 +00:00
Quentin McGaw 48c1f2bf6a chore(lint): run linter on integration tests 2026-06-11 13:29:57 +00:00
Quentin McGaw 70d80f7473 context aware connectFD 2026-06-11 13:06:05 +00:00
Quentin McGaw 9af6aaff27 PR feedback 2026-06-11 01:17:55 +00:00
Quentin McGaw d28744e06d pr review changes 2026-06-11 00:16:32 +00:00
Quentin McGaw 69b4e5c584 PR feedback fixes 2026-06-09 21:11:15 +00:00
Quentin McGaw 29186feccc Fix ordering in cleanup function 2026-06-09 14:07:05 +00:00
Quentin McGaw b5366b9e44 Change tests to be more integration oriented 2026-06-09 14:05:30 +00:00
Quentin McGaw dd07205b85 add tests 2026-06-09 12:47:13 +00:00
Quentin McGaw e2256dd1b2 moare fixes 2026-06-05 15:52:51 +00:00
Quentin McGaw c599e7fd2c chore(ci): disabe workflow concurrency by workflow-[pr|ref] 2026-06-05 15:50:01 +00:00
Quentin McGaw 8da913d7c6 context aware connectSourceConnection 2026-06-05 15:35:28 +00:00
Quentin McGaw 2d2c371303 pr review fixes 2026-06-05 15:25:44 +00:00
Quentin McGaw b48ba8cb0a review feedback 2026-06-05 05:01:18 +00:00
Quentin McGaw c18c54c3b7 Fix test to use a random port and not 443 2026-06-05 04:58:47 +00:00
Quentin McGaw 820689cc23 imporatnt fix 2 2026-06-05 04:46:20 +00:00
Quentin McGaw a9a36644ec imporatnt fix 1 2026-06-05 04:46:16 +00:00
Quentin McGaw fad8c9889a Minor fixes 2026-06-05 04:21:53 +00:00
Quentin McGaw aa781c6cc5 initial 2026-06-05 03:56:25 +00:00
Quentin McGaw ff6e45fae0 chore(ci): disable PIA end to end testing due to expired credentials 2026-06-04 16:52:53 +00:00
ligistx 17f24343d6 fix(providers/custom): use proto tcp-client instead of proto tcp (#3350) 2026-05-25 18:07:35 +02:00
Quentin McGaw ebbc630b31 chore(storage): remove servers.json in favor of just code at runtime 2026-05-24 22:22:41 +00:00
Quentin McGaw 39ac8b3432 hotfix(updater): use DoH for all updating operations, not just resolving server hostnames 2026-05-24 21:46:22 +00:00
Quentin McGaw f65ee3dcb1 hotfix(github): fix dependabot config (AI at it again) 2026-05-24 21:22:18 +00:00
dependabot[bot] 7e8d81b161 Chore(deps): Bump golang.org/x/net from 0.51.0 to 0.55.0 (#3338) 2026-05-24 23:09:52 +02:00
Quentin McGaw 21e868c89c hotfix(protonvpn): small port forwarding fixes for edge cases 2026-05-24 21:08:56 +00:00
Quentin McGaw 2e20e2df66 feat(protonvpn): use symmetric port forwarding for first port then asymmetric for next ports (#3345) 2026-05-24 22:47:58 +02:00
Quentin McGaw 6f5f518d1d chore(github): finer grain schedules for dependency checking
- default to weekly instead of daily
- check gluetun-servers daily
- check some Go modules only quartely since they are not important
2026-05-24 20:34:57 +00:00
Quentin McGaw 1998e0d04f chore(deps): remove direct dependency on golang.org/x/exp 2026-05-24 20:28:54 +00:00
Quentin McGaw 14f30bc641 docs(maintenance): clear up some finished items 2026-05-24 20:18:27 +00:00
Quentin McGaw f89e55b8ff chore(storage): remove outdated servers.json CI and documentation 2026-05-24 20:18:07 +00:00
Quentin McGaw 7ad6af0947 docs(github): remove servers.json checkbox from PR template 2026-05-24 20:13:07 +00:00
Quentin McGaw d3e089ccd7 hotfix(firewall/iptables): filter out DOCKER* chains from nat table when saving/restoring 2026-05-23 21:44:22 +00:00
Quentin McGaw 3eebbf65a8 hotfix(firewall/iptables): only restore firewall if IPv6 port redirection failed but NAT is supported 2026-05-23 21:26:08 +00:00
Quentin McGaw a1ef736b0f hotfix(portforwarding): disallow setting ports when running port forwarding code 2026-05-23 13:20:20 +00:00
Quentin McGaw 46edfe49e3 fix(portforwarding): handle empty ports without panicing 2026-05-23 13:19:37 +00:00
81 changed files with 2747 additions and 1470 deletions
-1
View File
@@ -56,7 +56,6 @@ body:
- IVPN
- Mullvad
- NordVPN
- OVPN
- Privado
- Private Internet Access
- PrivateVPN
+31 -5
View File
@@ -4,12 +4,38 @@ updates:
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "daily"
interval: "weekly"
- package-ecosystem: docker
directory: /
schedule:
interval: "daily"
- package-ecosystem: gomod
directory: /
interval: "weekly"
- # Servers data dependency that should be updated as soon as
# possible when a new version is released, to have the latest
# servers available
package-ecosystem: "gomod"
directory: "/"
schedule:
interval: "daily"
interval: "weekly"
ignore:
- # In particular avoid amneziawg-go which have v1.x.y versions available
# on the Go modules proxy, but are not in the Github repository tags
# and are not the latest releases either. Most likely a mistake from the
# maintainers, which is persisted on the Go proxy.
dependency-name: "github.com/amnezia-vpn/amneziawg-go"
versions: ["1.x"]
groups:
low-importance:
patterns:
- "github.com/breml/rootcerts"
- "github.com/fatih/color"
- "github.com/golang/mock"
- "github.com/klauspost/compress"
- "github.com/klauspost/pgzip"
- "github.com/pelletier/go-toml/v2"
- "github.com/qdm12/goshutdown"
- "github.com/qdm12/gosplash"
- "github.com/qdm12/gotree"
- "github.com/qdm12/log"
- "github.com/stretchr/testify"
- "github.com/ulikunitz/xz"
- "gopkg.in/ini.v1"
-2
View File
@@ -64,8 +64,6 @@
color: "cfe8d4"
- name: "☁️ NordVPN"
color: "cfe8d4"
- name: "☁️ OVPN"
color: "cfe8d4"
- name: "☁️ Perfect Privacy"
color: "cfe8d4"
- name: "☁️ PIA"
-1
View File
@@ -8,5 +8,4 @@
# Assertions
* [ ] I am aware that we do not accept manual changes to the servers.json file <!-- If this is your goal, please consult https://github.com/qdm12/gluetun-wiki/blob/main/setup/servers.md#update-using-the-command-line -->
* [ ] I am aware that any changes to settings should be reflected in the [wiki](https://github.com/qdm12/gluetun-wiki/)
+20 -10
View File
@@ -28,6 +28,10 @@ on:
- go.mod
- go.sum
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true
jobs:
verify:
runs-on: ubuntu-latest
@@ -44,7 +48,6 @@ jobs:
locale: "US"
level: error
exclude: |
./internal/storage/servers.json
./.golangci.yml
*.md
@@ -64,6 +67,10 @@ jobs:
-v "$(pwd)/coverage.txt:/tmp/gobuild/coverage.txt" \
test-container
- name: Run integration tests in test container
run: |
docker run --rm --entrypoint go test-container test -tags=integration ./internal/restrictednet
- name: Verify dev cross platform compatibility
run: docker build --target xcompile .
@@ -98,7 +105,7 @@ jobs:
github.event_name == 'release' ||
(github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository && github.actor != 'dependabot[bot]')
)
needs: [ verify ]
needs: [verify]
runs-on: ubuntu-latest
environment: secrets
steps:
@@ -120,7 +127,8 @@ jobs:
- name: Run Gluetun container with ProtonVPN Wireguard and port forwarding
configuration
run: echo -e "${{ secrets.PROTONVPN_WIREGUARD_PRIVATE_KEY }}" | ./ci/runner
run:
echo -e "${{ secrets.PROTONVPN_WIREGUARD_PRIVATE_KEY }}" | ./ci/runner
protonvpn-wireguard-port-forwarding
- name: Run Gluetun container with ProtonVPN OpenVPN and port forwarding
@@ -129,11 +137,12 @@ jobs:
secrets.PROTONVPN_OPENVPN_PASSWORD }}" | ./ci/runner
protonvpn-openvpn-port-forwarding
- name: Run Gluetun container with Private Internet Access OpenVPN and port
forwarding configuration
run: echo -e "${{ secrets.PRIVATEINTERNETACCESS_OPENVPN_USER }}\n${{
secrets.PRIVATEINTERNETACCESS_OPENVPN_PASSWORD }}" | ./ci/runner
private-internet-access-openvpn-port-forwarding
# - name:
# Run Gluetun container with Private Internet Access OpenVPN and port
# forwarding configuration
# run: echo -e "${{ secrets.PRIVATEINTERNETACCESS_OPENVPN_USER }}\n${{
# secrets.PRIVATEINTERNETACCESS_OPENVPN_PASSWORD }}" | ./ci/runner
# private-internet-access-openvpn-port-forwarding
- name: Run Gluetun container with AirVPN Wireguard configuration
run: echo -e "${{ secrets.AIRVPN_WIREGUARD_PRIVATE_KEY }}\n${{
@@ -141,7 +150,8 @@ jobs:
secrets.AIRVPN_WIREGUARD_ADDRESSES }}" | ./ci/runner airvpn-wireguard
- name: Run Gluetun container with AirVPN OpenVPN configuration
run: echo -e "${{ secrets.AIRVPN_OPENVPN_KEY }}\n${{ secrets.AIRVPN_OPENVPN_CERT
run:
echo -e "${{ secrets.AIRVPN_OPENVPN_KEY }}\n${{ secrets.AIRVPN_OPENVPN_CERT
}}" | ./ci/runner airvpn-openvpn
codeql:
@@ -169,7 +179,7 @@ jobs:
github.event_name == 'release' ||
(github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository && github.actor != 'dependabot[bot]')
)
needs: [ verify, verify-private, codeql ]
needs: [verify, verify-private, codeql]
permissions:
actions: read
contents: read
+4
View File
@@ -11,6 +11,10 @@ on:
- "**.md"
- .github/workflows/markdown.yml
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true
jobs:
markdown:
runs-on: ubuntu-latest
-98
View File
@@ -1,98 +0,0 @@
name: Update servers list
on:
workflow_dispatch:
inputs:
provider:
description: "VPN Provider to update"
required: true
default: "all"
type: choice
options:
- all
- airvpn
- cyberghost
- expressvpn
- fastestvpn
- giganews
- hidemyass
- ipvanish
- ivpn
- mullvad
- nordvpn
- perfect privacy
- privado
- private internet access
- privatevpn
- protonvpn
- purevpn
- slickvpn
- surfshark
- torguard
- vpnsecure
- vpn unlimited
- vyprvpn
- windscribe
schedule:
- cron: "11 3 1 */2 *" # Run at 03:11 on the 1st of every 2nd month
jobs:
update-servers-list:
if: github.repository == 'passteque/gluetun'
runs-on: ubuntu-latest
permissions:
actions: read
contents: write
pull-requests: write
steps:
- uses: actions/checkout@v6
- uses: actions/setup-go@v6
with:
go-version-file: go.mod
- name: Update servers list
run: |
SELECTED_PROVIDER="${{ github.event.inputs.provider || 'all' }}"
if [ "$SELECTED_PROVIDER" = "all" ]; then
FLAGS="-all"
else
FLAGS="-providers $SELECTED_PROVIDER"
fi
go run ./cmd/gluetun/main.go update $FLAGS \
-maintainer \
-proton-email "${{ secrets.PROTON_EMAIL }}" \
-proton-password "${{ secrets.PROTON_PASSWORD }}"
- name: Check for changes
run: |
if git diff --exit-code internal/storage/servers.json >/dev/null; then
echo "Error: internal/storage/servers.json was not modified."
exit 1
fi
- name: Check no other file changes
run: |
if ! git diff --exit-code --quiet ':!internal/storage/servers.json'; then
echo "Error: Unexpected changes detected in files other than servers.json"
git status --short
exit 1
fi
- name: Create Pull Request
id: createpr
uses: peter-evans/create-pull-request@v8
with:
branch-suffix: timestamp
branch: bot/update-servers-list
base: master
delete-branch: true
title: "feat(providers/${{ github.event.inputs.provider || 'all' }}): servers data update"
body: |
This PR was automatically generated by the [Update servers list](${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}) workflow run.
# - name: Merge Pull Request
# env:
# GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
# run: |
# gh pr merge ${{ steps.createpr.outputs.pull-request-number }} --auto -m -d
+4
View File
@@ -12,6 +12,10 @@ formatters:
- builtin$
- examples$
run:
build-tags:
- integration
linters:
settings:
misspell:
+1 -1
View File
@@ -3,7 +3,7 @@
// to develop this project.
"files.eol": "\n",
"editor.formatOnSave": true,
"go.buildTags": "linux",
"go.buildTags": "linux,integration",
"go.toolsEnvVars": {
"CGO_ENABLED": "0"
},
+5
View File
@@ -50,6 +50,7 @@ Guidance for coding agents working in this repository.
- Prefer splitting a code line only when it triggers the `lll` linter, do not split a command or arguments list for each element
- Use `netip` types instead of `net` types whenever possible
- Use constants instead of variables whenever possible, especially function-local inline constants.
- Prefer using pure functions over methods when possible. Especially if the method does not need any fields from the receiving struct, it should be a pure function.
- Do not use `time.Sleep`, prefer using a `time.Timer` with a `select` statement also listening on a context cancelation
- `panic`:
- should only be used when a programming error is encountered and you should NOT return errors for programming errors (such as passing nil objects)
@@ -115,6 +116,7 @@ Mocking works with the `go.uber.org/mock` library, and the `mockgen` tool.
- **Never** use `.AnyTimes()` on mocks. Always define the number of times a certain mock call should be called, with `.Times(3)` for example.
- **Always** set the `.Return(...)` on the mock if the function returns something.
- Avoid using **mock helpers** functions, prefer a bit of repetition than tight coupling and dependency
- Always define the gomock controller `ctrl` in the subtest and not in the parent test, or a subtest mock failing will crash all the other subtests.
### main.go
@@ -127,6 +129,7 @@ The Go formatter used is gofumpt.
### Errors
- Always prefer wrapping errors with some context with `fmt.Errorf("doing this: %w", err)`
- Use `errors.New("error message")` when creating a 'bottom' constant string error without additional context, instead of `fmt.Errorf`
- In rare cases, you can just use `return err` notably:
- If the function is called **recursively**, since we don't wrap the wrapping multiple times for each recursion
- If the current function only statement is the call to another function, for example:
@@ -179,6 +182,8 @@ The Go formatter used is gofumpt.
- Do not use `http.DefaultClient`, use a custom `*http.Client` with a fixed timeout and share with dependency injections.
- Do not check for injected dependencies being `nil`, prefer to just panic on a nil pointer. By default it's fine to panic if a developer injects a dependency `nil`. `nil` does not mean use a default.
- Prefer using a `switch { case ...}` statement over multiple consecutive `if` statements to have shorter code.
- Prefer using `[...]T` instead of `[]T` when the length is fixed and known at compile time, to avoid unnecessary allocations.
## Validation checklist
+2 -4
View File
@@ -186,14 +186,12 @@ ENV VPN_SERVICE_PROVIDER=pia \
# # ProtonVPN only:
SECURE_CORE_ONLY= \
TOR_ONLY= \
# # Surfshark and ovpn only:
# # Surfshark only:
MULTIHOP_ONLY= \
# # VPN Secure only:
PREMIUM_ONLY= \
# # PIA and ProtonVPN only:
PORT_FORWARD_ONLY= \
# # Ovpn only:
SERVER_DEDICATED=no \
# Firewall
FIREWALL_ENABLED_DISABLING_IT_SHOOTS_YOU_IN_YOUR_FOOT=on \
FIREWALL_VPN_INPUT_PORTS= \
@@ -278,7 +276,7 @@ ENV VPN_SERVICE_PROVIDER=pia \
PUID=1000 \
PGID=1000
ENTRYPOINT ["/gluetun-entrypoint"]
EXPOSE 8000/tcp 8888/tcp 8388/tcp 8388/udp 1080/tcp
EXPOSE 8000/tcp 8888/tcp 8388/tcp 8388/udp 1080/tcp 1080/udp
HEALTHCHECK --interval=5s --timeout=5s --start-period=10s --retries=3 CMD /gluetun-entrypoint healthcheck
ARG TARGETPLATFORM
RUN apk add --no-cache --update -l wget && \
+3 -3
View File
@@ -60,10 +60,10 @@ Lightweight swiss-army-knife-like VPN client to multiple VPN service providers
## Features
- Based on Alpine 3.23 for a small Docker image of 43.1MB
- Supports: **AirVPN**, **Cyberghost**, **ExpressVPN**, **FastestVPN**, **Giganews**, **HideMyAss**, **IPVanish**, **IVPN**, **Mullvad** (Wireguard only), **NordVPN**, **Ovpn**, **Perfect Privacy**, **Privado**, **Private Internet Access**, **PrivateVPN**, **ProtonVPN**, **PureVPN**, **SlickVPN**, **Surfshark**, **TorGuard**, **VPNSecure.me**, **VPNUnlimited**, **Vyprvpn**, **Windscribe** servers
- Supports: **AirVPN**, **Cyberghost**, **ExpressVPN**, **FastestVPN**, **Giganews**, **HideMyAss**, **IPVanish**, **IVPN**, **Mullvad** (Wireguard only), **NordVPN**, **Perfect Privacy**, **Privado**, **Private Internet Access**, **PrivateVPN**, **ProtonVPN**, **PureVPN**, **SlickVPN**, **Surfshark**, **TorGuard**, **VPNSecure.me**, **VPNUnlimited**, **Vyprvpn**, **Windscribe** servers
- Supports OpenVPN for all providers listed
- Supports Wireguard both kernelspace and userspace
- For **AirVPN**, **FastestVPN**, **Ivpn**, **Mullvad**, **NordVPN**, **Ovpn**, **Perfect privacy**, **ProtonVPN**, **Surfshark** and **Windscribe**
- For **AirVPN**, **FastestVPN**, **Ivpn**, **Mullvad**, **NordVPN**, **Perfect privacy**, **ProtonVPN**, **Surfshark** and **Windscribe**
- For **Cyberghost**, **Private Internet Access**, **PrivateVPN**, **PureVPN**, **Torguard**, **VPN Unlimited** and **VyprVPN** using [the custom provider](https://github.com/qdm12/gluetun-wiki/blob/main/setup/providers/custom.md)
- For custom Wireguard configurations using [the custom provider](https://github.com/qdm12/gluetun-wiki/blob/main/setup/providers/custom.md)
- More in progress, see [#134](https://github.com/passteque/gluetun/issues/134)
@@ -73,7 +73,7 @@ Lightweight swiss-army-knife-like VPN client to multiple VPN service providers
- Choose the vpn network protocol, `udp` or `tcp`
- Built in firewall kill switch to allow traffic only with needed the VPN servers and LAN devices
- Built in Shadowsocks proxy server (protocol based on SOCKS5 with an encryption layer, tunnels TCP+UDP)
- Built in Socks5 proxy server (tunnels TCP) - partial credits to @angelakis and @adjscent
- Built in Socks5 proxy server (tunnels TCP+UDP) - partial credits to @angelakis and @adjscent
- Built in HTTP proxy (tunnels HTTP and HTTPS through TCP)
- [Connect other containers to it](https://github.com/qdm12/gluetun-wiki/blob/main/setup/connect-a-container-to-gluetun.md)
- [Connect LAN devices to it](https://github.com/qdm12/gluetun-wiki/blob/main/setup/connect-a-lan-device-to-gluetun.md)
+8 -8
View File
@@ -15,7 +15,7 @@ require (
github.com/mdlayher/netlink v1.9.0
github.com/pelletier/go-toml/v2 v2.2.4
github.com/qdm12/dns/v2 v2.0.0-rc9.0.20260421173011-9de8e7fdbe3a
github.com/qdm12/gluetun-servers v0.1.1-0.20260522005421-14277e92ce82
github.com/qdm12/gluetun-servers v0.1.0
github.com/qdm12/goservices v0.1.1-0.20251104135713-6bee97bd4978
github.com/qdm12/gosettings v0.4.4
github.com/qdm12/goshutdown v0.3.0
@@ -27,10 +27,9 @@ require (
github.com/ti-mo/netfilter v0.5.3
github.com/ulikunitz/xz v0.5.15
github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c
golang.org/x/net v0.51.0
golang.org/x/sys v0.42.0
golang.org/x/text v0.35.0
golang.org/x/net v0.55.0
golang.org/x/sys v0.45.0
golang.org/x/text v0.37.0
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
gopkg.in/ini.v1 v1.67.1
@@ -57,10 +56,11 @@ require (
github.com/prometheus/common v0.60.1 // indirect
github.com/prometheus/procfs v0.15.1 // indirect
github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect
golang.org/x/crypto v0.48.0 // indirect
golang.org/x/mod v0.33.0 // indirect
golang.org/x/crypto v0.51.0 // indirect
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c // indirect
golang.org/x/mod v0.35.0 // indirect
golang.org/x/sync v0.20.0 // indirect
golang.org/x/tools v0.42.0 // indirect
golang.org/x/tools v0.44.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
google.golang.org/protobuf v1.35.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
+14 -14
View File
@@ -76,8 +76,8 @@ github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0leargg
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
github.com/qdm12/dns/v2 v2.0.0-rc9.0.20260421173011-9de8e7fdbe3a h1:TE157yPQmAbVruH0MWCQzs0vTT/6t96DkoWUXd6PVuc=
github.com/qdm12/dns/v2 v2.0.0-rc9.0.20260421173011-9de8e7fdbe3a/go.mod h1:98foWgXJZ+g8gJIuO+fdO+oWpFei5WShMFTeN4Im2lE=
github.com/qdm12/gluetun-servers v0.1.1-0.20260522005421-14277e92ce82 h1:tE44IEW7o9yPQaO8HBeoO9RxtTTxqhboIypegrQlVt8=
github.com/qdm12/gluetun-servers v0.1.1-0.20260522005421-14277e92ce82/go.mod h1:acttuyHyoFDu6GTbf3kAV+QXeiX8oJeh0MBic67/9z8=
github.com/qdm12/gluetun-servers v0.1.0 h1:w9JLghKZwI0Gzpp9p5rNANgEYUUZ1dxdxsG6NKIojaY=
github.com/qdm12/gluetun-servers v0.1.0/go.mod h1:acttuyHyoFDu6GTbf3kAV+QXeiX8oJeh0MBic67/9z8=
github.com/qdm12/goservices v0.1.1-0.20251104135713-6bee97bd4978 h1:TRGpCU1l0lNwtogEUSs5U+RFceYxkAJUmrGabno7J5c=
github.com/qdm12/goservices v0.1.1-0.20251104135713-6bee97bd4978/go.mod h1:D1Po4CRQLYjccnAR2JsVlN1sBMgQrcNLONbvyuzcdTg=
github.com/qdm12/gosettings v0.4.4 h1:SM6tOZDf6k8qbjWU8KWyBF4mWIixfsKCfh9DGRLHlj4=
@@ -120,15 +120,15 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U
golang.org/x/crypto v0.0.0-20200302210943-78000ba7a073/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU=
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
golang.org/x/crypto v0.51.0 h1:IBPXwPfKxY7cWQZ38ZCIRPI50YLeevDLlLnyC5wRGTI=
golang.org/x/crypto v0.51.0/go.mod h1:8AdwkbraGNABw2kOX6YFPs3WM22XqI4EXEd8g+x7Oc8=
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c h1:7dEasQXItcW1xKJ2+gg5VOiBnqWrJc+rq0DPKyvvdbY=
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c/go.mod h1:NQtJDoLvd6faHhE7m4T/1IY708gDefGGjR/iUW8yQQ8=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8=
golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w=
golang.org/x/mod v0.35.0 h1:Ww1D637e6Pg+Zb2KrWfHQUnH2dQRLBQyAtpr/haaJeM=
golang.org/x/mod v0.35.0/go.mod h1:+GwiRhIInF8wPm+4AoT6L0FA1QWAad3OMdTRx4tFYlU=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
@@ -136,8 +136,8 @@ golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96b
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc=
golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo=
golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y=
golang.org/x/net v0.55.0 h1:bcvxaJn3e1U6InsFWt1JUq1aSjnRxLzT2rtD2KfkDF8=
golang.org/x/net v0.55.0/go.mod h1:L5U2KuzuOe1lY7Z+aWVIKK6qEeJXnXV9yzGA+WCHJww=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -156,8 +156,8 @@ golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/sys v0.45.0 h1:dO4czNzziLiiXplLQgBCEpCvXQ3dnkn0SdaZSYdQ+FY=
golang.org/x/sys v0.45.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
@@ -167,8 +167,8 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc=
golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38=
golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY=
golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
@@ -176,8 +176,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn
golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k=
golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0=
golang.org/x/tools v0.44.0 h1:UP4ajHPIcuMjT1GqzDWRlalUEoY+uzoZKnhOjbIPD2c=
golang.org/x/tools v0.44.0/go.mod h1:KA0AfVErSdxRZIsOVipbv3rQhVXTnlU6UhKxHd1seDI=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+5
View File
@@ -5,6 +5,7 @@ import (
"errors"
"flag"
"fmt"
"net"
"net/http"
"slices"
"strings"
@@ -104,6 +105,10 @@ func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) e
if err != nil {
return fmt.Errorf("creating DoH dialer: %w", err)
}
net.DefaultResolver = &net.Resolver{
PreferGo: true,
Dial: dnsDialer.Dial,
}
const clientTimeout = 10 * time.Second
httpClient := &http.Client{Timeout: clientTimeout}
+10 -4
View File
@@ -9,8 +9,9 @@ import (
)
// Start launches a command and streams stdout and stderr to channels.
// All the channels returned are ready only and won't be closed
// if the command fails later.
// stdoutLines and stderrLines channels will be closed when there is no more
// output to read, in order for the caller to catch all lines even after the
// command has finished. The waitError channel returned will never be closed.
func (c *Cmder) Start(cmd *exec.Cmd) (
stdoutLines, stderrLines <-chan string,
waitError <-chan error, startErr error,
@@ -38,6 +39,7 @@ func start(cmd execCmd) (stdoutLines, stderrLines <-chan string,
if err != nil {
_ = stdout.Close()
<-stdoutDone
close(stdoutLinesCh)
return nil, nil, nil, err
}
go streamToChannel(stderrReady, stderrDone, stderr, stderrLinesCh)
@@ -45,9 +47,11 @@ func start(cmd execCmd) (stdoutLines, stderrLines <-chan string,
err = cmd.Start()
if err != nil {
_ = stdout.Close()
_ = stderr.Close()
<-stdoutDone
close(stdoutLinesCh)
_ = stderr.Close()
<-stderrDone
close(stderrLinesCh)
return nil, nil, nil, err
}
@@ -55,8 +59,10 @@ func start(cmd execCmd) (stdoutLines, stderrLines <-chan string,
go func() {
err := cmd.Wait()
<-stdoutDone
<-stderrDone
close(stdoutLinesCh)
_ = stdout.Close()
<-stderrDone
close(stderrLinesCh)
_ = stderr.Close()
waitErrorCh <- err
}()
+42 -24
View File
@@ -89,30 +89,48 @@ func Test_start(t *testing.T) {
require.NoError(t, err)
var stdoutIndex, stderrIndex int
done := false
for !done {
select {
case line := <-stdoutLines:
assert.Equal(t, testCase.stdout[stdoutIndex], line)
stdoutIndex++
case line := <-stderrLines:
assert.Equal(t, testCase.stderr[stderrIndex], line)
stderrIndex++
case err := <-waitError:
if testCase.waitErr != nil {
require.Error(t, err)
assert.Equal(t, testCase.waitErr.Error(), err.Error())
} else {
assert.NoError(t, err)
}
done = true
}
}
assert.Equal(t, len(testCase.stdout), stdoutIndex)
assert.Equal(t, len(testCase.stderr), stderrIndex)
collectAndCheckChannels(t, stdoutLines, stderrLines, waitError,
testCase.stdout, testCase.stderr, testCase.waitErr)
})
}
}
func collectAndCheckChannels(t *testing.T, stdoutLines, stderrLines <-chan string,
waitError <-chan error, expectedStdout, expectedStderr []string, expectedWaitErr error,
) {
t.Helper()
stdoutIndex := 0
stderrIndex := 0
done := false
for !done {
select {
case line, ok := <-stdoutLines:
if !ok {
stdoutLines = nil
continue
}
assert.Equal(t, expectedStdout[stdoutIndex], line)
stdoutIndex++
case line, ok := <-stderrLines:
if !ok {
stderrLines = nil
continue
}
assert.Equal(t, expectedStderr[stderrIndex], line)
stderrIndex++
case err := <-waitError:
if expectedWaitErr != nil {
require.Error(t, err)
assert.Equal(t, expectedWaitErr.Error(), err.Error())
} else {
assert.NoError(t, err)
}
done = true
}
}
assert.Equal(t, len(expectedStdout), stdoutIndex)
assert.Equal(t, len(expectedStderr), stderrIndex)
}
+19 -13
View File
@@ -18,31 +18,37 @@ func (c *Cmder) RunAndLog(ctx context.Context, command string, logger Logger) (e
return err
}
streamCtx, streamCancel := context.WithCancel(context.Background())
streamDone := make(chan struct{})
go streamLines(streamCtx, streamDone, logger, stdout, stderr)
go streamLines(streamDone, logger, stdout, stderr)
err = <-waitError
streamCancel()
<-streamDone
return err
}
func streamLines(ctx context.Context, done chan<- struct{},
logger Logger, stdout, stderr <-chan string,
func streamLines(done chan<- struct{}, logger Logger,
stdout, stderr <-chan string,
) {
defer close(done)
var line string
for {
select {
case <-ctx.Done():
return
case line = <-stdout:
logger.Info(line)
case line = <-stderr:
logger.Error(line)
case line, ok := <-stdout:
if ok {
logger.Info(line)
}
if stderr == nil {
return
}
stdout = nil
case line, ok := <-stderr:
if ok {
logger.Error(line)
}
if stdout == nil {
return
}
stderr = nil
}
}
}
@@ -1,10 +1,10 @@
package settings
import (
"maps"
"slices"
"github.com/qdm12/gosettings/reader"
"golang.org/x/exp/maps"
)
func readObsolete(r *reader.Reader) (warnings []string) {
@@ -17,7 +17,7 @@ func readObsolete(r *reader.Reader) (warnings []string) {
"DNS_KEEP_NAMESERVER": "DNS_KEEP_NAMESERVER is obsolete because you should use the built-in server which now " +
"forwards local names to private DNS resolvers found in /etc/resolv.conf at container start",
}
sortedKeys := maps.Keys(keyToMessage)
sortedKeys := slices.Collect(maps.Keys(keyToMessage))
slices.Sort(sortedKeys)
warnings = make([]string, 0, len(keyToMessage))
for _, key := range sortedKeys {
@@ -70,7 +70,7 @@ func (o OpenVPNSelection) validate(vpnProvider string) (err error) {
switch vpnProvider {
// no restriction on port
case providers.Custom, providers.Cyberghost, providers.HideMyAss,
providers.Ovpn, providers.Privatevpn, providers.Torguard:
providers.Privatevpn, providers.Torguard:
// no custom port allowed
case providers.Expressvpn, providers.Fastestvpn,
providers.Giganews, providers.Ipvanish,
@@ -95,7 +95,7 @@ func (p PortForwarding) Validate(vpnProvider string) (err error) {
return errors.New("port forwarding password is empty")
}
case providers.Protonvpn:
const maxPortsCount = 4
const maxPortsCount = 5
if p.PortsCount > maxPortsCount {
return fmt.Errorf("ports count too high: %d > %d", p.PortsCount, maxPortsCount)
}
@@ -49,7 +49,6 @@ func (p *Provider) validate(vpnType string, filterChoicesGetter FilterChoicesGet
providers.Ivpn,
providers.Mullvad,
providers.Nordvpn,
providers.Ovpn,
providers.Protonvpn,
providers.Surfshark,
providers.Windscribe,
@@ -63,9 +63,6 @@ type ServerSelection struct {
// TorOnly is true if VPN servers without tor should
// be filtered. This is used with ProtonVPN.
TorOnly *bool `json:"tor_only"`
// Dedicated is true if dedicated VPN servers should be chosen only.
// This is used with OVPN.
Dedicated *bool `json:"dedicated"`
// OpenVPN contains settings to select OpenVPN servers
// and the final connection.
OpenVPN OpenVPNSelection `json:"openvpn"`
@@ -275,8 +272,6 @@ func validateFeatureFilters(settings ServerSelection, vpnServiceProvider string)
return errors.New("secure core only filter is not supported")
case *settings.TorOnly && vpnServiceProvider != providers.Protonvpn:
return errors.New("tor only filter is not supported")
case *settings.Dedicated && vpnServiceProvider != providers.Ovpn:
return errors.New("dedicated filter is not supported")
default:
return nil
}
@@ -301,7 +296,6 @@ func (ss *ServerSelection) copy() (copied ServerSelection) {
TorOnly: gosettings.CopyPointer(ss.TorOnly),
PortForwardOnly: gosettings.CopyPointer(ss.PortForwardOnly),
MultiHopOnly: gosettings.CopyPointer(ss.MultiHopOnly),
Dedicated: gosettings.CopyPointer(ss.Dedicated),
OpenVPN: ss.OpenVPN.copy(),
Wireguard: ss.Wireguard.copy(),
}
@@ -325,7 +319,6 @@ func (ss *ServerSelection) overrideWith(other ServerSelection) {
ss.TorOnly = gosettings.OverrideWithPointer(ss.TorOnly, other.TorOnly)
ss.MultiHopOnly = gosettings.OverrideWithPointer(ss.MultiHopOnly, other.MultiHopOnly)
ss.PortForwardOnly = gosettings.OverrideWithPointer(ss.PortForwardOnly, other.PortForwardOnly)
ss.Dedicated = gosettings.OverrideWithPointer(ss.Dedicated, other.Dedicated)
ss.OpenVPN.overrideWith(other.OpenVPN)
ss.Wireguard.overrideWith(other.Wireguard)
}
@@ -342,7 +335,6 @@ func (ss *ServerSelection) setDefaults(vpnProvider string, portForwardingEnabled
defaultPortForwardOnly := portForwardingEnabled &&
helpers.IsOneOf(vpnProvider, providers.PrivateInternetAccess, providers.Protonvpn)
ss.PortForwardOnly = gosettings.DefaultPointer(ss.PortForwardOnly, defaultPortForwardOnly)
ss.Dedicated = gosettings.DefaultPointer(ss.Dedicated, false)
ss.OpenVPN.setDefaults(vpnProvider)
ss.Wireguard.setDefaults()
}
@@ -418,10 +410,6 @@ func (ss ServerSelection) toLinesNode() (node *gotree.Node) {
node.Appendf("Multi-hop only servers: yes")
}
if *ss.Dedicated {
node.Appendf("Dedicated servers: yes")
}
if *ss.PortForwardOnly {
node.Appendf("Port forwarding only servers: yes")
}
@@ -513,12 +501,6 @@ func (ss *ServerSelection) read(r *reader.Reader,
return err
}
// Ovpn only
ss.Dedicated, err = r.BoolPtr("SERVER_DEDICATED")
if err != nil {
return err
}
err = ss.OpenVPN.read(r)
if err != nil {
return err
@@ -5,7 +5,6 @@ import (
"fmt"
"net/netip"
"github.com/qdm12/gluetun/internal/configuration/settings/helpers"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gosettings"
"github.com/qdm12/gosettings/reader"
@@ -23,7 +22,7 @@ type WireguardSelection struct {
// It can never be the zero value in the internal state.
EndpointIP netip.Addr `json:"endpoint_ip"`
// EndpointPort is a the server port to use for the VPN server.
// It is optional for VPN providers IVPN, Mullvad, Ovpn, Surfshark
// It is optional for VPN providers IVPN, Mullvad, Surfshark
// and Windscribe, and compulsory for the others.
// When optional, it can be set to 0 to indicate not use
// a custom endpoint port. It cannot be nil in the internal
@@ -41,9 +40,8 @@ func (w WireguardSelection) validate(vpnProvider string) (err error) {
// Validate EndpointIP
switch vpnProvider {
case providers.Airvpn, providers.Fastestvpn, providers.Ivpn,
providers.Mullvad, providers.Nordvpn, providers.Ovpn,
providers.Protonvpn, providers.Surfshark,
providers.Windscribe:
providers.Mullvad, providers.Nordvpn, providers.Protonvpn,
providers.Surfshark, providers.Windscribe:
// endpoint IP addresses are baked in
case providers.Custom:
if !w.EndpointIP.IsValid() || w.EndpointIP.IsUnspecified() {
@@ -65,16 +63,12 @@ func (w WireguardSelection) validate(vpnProvider string) (err error) {
if *w.EndpointPort != 0 {
return errors.New("endpoint port is set")
}
case providers.Airvpn, providers.Ivpn, providers.Mullvad,
providers.Ovpn, providers.Windscribe:
case providers.Airvpn, providers.Ivpn, providers.Mullvad, providers.Windscribe:
// EndpointPort is optional and can be 0
if *w.EndpointPort == 0 {
break // no custom endpoint port set
}
if helpers.IsOneOf(vpnProvider,
providers.Mullvad,
providers.Ovpn,
) {
if vpnProvider == providers.Mullvad {
break // no restriction on custom endpoint port value
}
var allowed []uint16
@@ -98,7 +92,7 @@ func (w WireguardSelection) validate(vpnProvider string) (err error) {
// Validate PublicKey
switch vpnProvider {
case providers.Fastestvpn, providers.Ivpn, providers.Mullvad,
providers.Ovpn, providers.Surfshark, providers.Windscribe:
providers.Surfshark, providers.Windscribe:
// public keys are baked in
case providers.Custom:
if w.PublicKey == "" {
@@ -15,7 +15,6 @@ const (
Ivpn = "ivpn"
Mullvad = "mullvad"
Nordvpn = "nordvpn"
Ovpn = "ovpn"
Perfectprivacy = "perfect privacy"
Privado = "privado"
PrivateInternetAccess = "private internet access"
@@ -44,7 +43,6 @@ func All() []string {
Ivpn,
Mullvad,
Nordvpn,
Ovpn,
Perfectprivacy,
Privado,
PrivateInternetAccess,
+2
View File
@@ -28,6 +28,8 @@ type firewallImpl interface { //nolint:interfacebloat
AcceptIpv6MulticastOutput(ctx context.Context, intf string) error
AcceptOutput(ctx context.Context, protocol, intf string,
ip netip.Addr, port uint16, remove bool) error
AcceptOutputFromIPPortToIPPort(ctx context.Context, protocol, intf string,
source, destination netip.AddrPort, remove bool) error
AcceptOutputFromIPToSubnet(ctx context.Context, intf string, assignedIP netip.Addr,
subnet netip.Prefix, remove bool) error
AcceptOutputThroughInterface(ctx context.Context, intf string, remove bool) error
+16 -18
View File
@@ -1,7 +1,6 @@
package iptables
import (
"bufio"
"context"
"fmt"
"os/exec"
@@ -97,25 +96,24 @@ func saveData(ctx context.Context, binary string) (data string, err error) {
}
return "", fmt.Errorf("running %s-save: %w", binary, err)
}
err = checkData(string(output))
if err != nil {
return "", fmt.Errorf("checking saved data: %w", err)
}
return string(output), nil
return filterData(output)
}
func checkData(data string) error {
scanner := bufio.NewScanner(strings.NewReader(data))
i := 0
for scanner.Scan() {
line := scanner.Text()
if strings.HasPrefix(line, "[unsupported") {
return fmt.Errorf("unsupported revision marker found in line %d: %s", i+1, line)
func filterData(cmdOutput []byte) (filtered string, err error) {
lines := strings.Split(string(cmdOutput), "\n")
filteredLines := make([]string, 0, len(lines))
for _, line := range lines {
switch {
case strings.HasPrefix(line, ":DOCKER_OUTPUT"),
strings.HasPrefix(line, ":DOCKER_POSTROUTING"),
strings.HasPrefix(line, "-A DOCKER_OUTPUT"),
strings.HasPrefix(line, "-A DOCKER_POSTROUTING"):
// Do not touch (aka save and restore) NAT rules added by Docker
continue
case strings.Contains(line, "[unsupported revision]"):
return "", fmt.Errorf("mismatch container iptables-save and kernel: %s", line)
}
i++
filteredLines = append(filteredLines, line)
}
if scanner.Err() != nil {
return fmt.Errorf("scanning data: %w", scanner.Err())
}
return nil
return strings.Join(filteredLines, "\n"), nil
}
+25 -1
View File
@@ -2,6 +2,7 @@ package iptables
import (
"context"
"errors"
"fmt"
"io"
"net/netip"
@@ -177,6 +178,29 @@ func (c *Config) AcceptOutput(ctx context.Context,
return c.runIP6tablesInstruction(ctx, instruction)
}
func (c *Config) AcceptOutputFromIPPortToIPPort(ctx context.Context,
protocol, intf string, source, destination netip.AddrPort, remove bool,
) error {
if source.Addr().BitLen() != destination.Addr().BitLen() {
return errors.New("source and destination address families do not match")
}
interfaceFlag := "-o " + intf
if intf == "*" { // all interfaces
interfaceFlag = ""
}
instruction := fmt.Sprintf("%s OUTPUT %s -s %s -d %s -p %s -m %s --sport %d --dport %d -j ACCEPT",
appendOrDelete(remove), interfaceFlag, source.Addr(), destination.Addr(),
protocol, protocol, source.Port(), destination.Port())
if destination.Addr().Is4() {
return c.runIptablesInstruction(ctx, instruction)
} else if c.ip6Tables == "" {
return fmt.Errorf("accept output from %s to %s: %s", source, destination, needIP6Tables)
}
return c.runIP6tablesInstruction(ctx, instruction)
}
// AcceptOutputFromIPToSubnet accepts outgoing traffic from sourceIP to destinationSubnet
// on the interface intf. If intf is empty, it is set to "*" which means all interfaces.
// If remove is true, the rule is removed instead of added.
@@ -278,7 +302,6 @@ func (c *Config) RedirectPort(ctx context.Context, intf string,
appendOrDelete(remove), interfaceFlag, destinationPort),
})
if err != nil {
restore(ctx) // just in case
errMessage := err.Error()
if strings.Contains(errMessage, "can't initialize ip6tables table `nat': Table does not exist") {
if !remove {
@@ -286,6 +309,7 @@ func (c *Config) RedirectPort(ctx context.Context, intf string,
}
return nil
}
restore(ctx)
return fmt.Errorf("redirecting IPv6 source port %d to destination port %d on interface %s: %w",
sourcePort, destinationPort, intf, err)
}
+7
View File
@@ -25,3 +25,10 @@ func (c *Config) AcceptOutput(ctx context.Context, protocol, intf string,
) error {
return c.impl.AcceptOutput(ctx, protocol, intf, ip, port, remove)
}
func (c *Config) AcceptOutputFromIPPortToIPPort(ctx context.Context,
protocol, intf string, source, destination netip.AddrPort, remove bool,
) error {
return c.impl.AcceptOutputFromIPPortToIPPort(ctx, protocol, intf,
source, destination, remove)
}
-3
View File
@@ -34,11 +34,8 @@ type Server struct {
SecureCore bool `json:"secure_core,omitempty"`
Tor bool `json:"tor,omitempty"`
PortForward bool `json:"port_forward,omitempty"`
Dedicated bool `json:"dedicated,omitempty"`
Keep bool `json:"keep,omitempty"`
IPs []netip.Addr `json:"ips,omitempty"`
PortsTCP []uint16 `json:"ports_tcp,omitempty"`
PortsUDP []uint16 `json:"ports_udp,omitempty"`
}
func (s *Server) HasMinimumInformation() (err error) {
+1 -4
View File
@@ -29,19 +29,16 @@ func (r *Runner) Run(ctx context.Context, errCh chan<- error, ready chan<- struc
return
}
streamCtx, streamCancel := context.WithCancel(context.Background())
streamDone := make(chan struct{})
go streamLines(streamCtx, streamDone, r.logger,
go streamLines(streamDone, r.logger,
stdoutLines, stderrLines, ready)
select {
case <-ctx.Done():
<-waitError
streamCancel()
<-streamDone
errCh <- ctx.Err()
case err := <-waitError:
streamCancel()
<-streamDone
errCh <- err
}
+20 -9
View File
@@ -1,26 +1,37 @@
package openvpn
import (
"context"
"strings"
)
func streamLines(ctx context.Context, done chan<- struct{},
func streamLines(done chan<- struct{},
logger Logger, stdout, stderr <-chan string,
tunnelReady chan<- struct{},
) {
defer close(done)
var line string
for {
var line string
var ok bool
errLine := false
select {
case <-ctx.Done():
return
case line = <-stdout:
case line = <-stderr:
errLine = true
case line, ok = <-stdout:
if ok {
break
}
if stderr == nil {
return
}
stdout = nil
case line, ok = <-stderr:
if ok {
errLine = true
break
}
if stdout == nil {
return
}
stderr = nil
}
line, level := processLogLine(line)
if line == "" {
+5 -1
View File
@@ -15,7 +15,11 @@ func runCommand(ctx context.Context, cmder Cmder, logger Logger,
}
portsString := strings.Join(portStrings, ",")
commandString := strings.ReplaceAll(commandTemplate, "{{PORTS}}", portsString)
commandString = strings.ReplaceAll(commandString, "{{PORT}}", portStrings[0])
var firstPort string
if len(portStrings) > 0 {
firstPort = portStrings[0]
}
commandString = strings.ReplaceAll(commandString, "{{PORT}}", firstPort)
commandString = strings.ReplaceAll(commandString, "{{VPN_INTERFACE}}", vpnInterface)
return cmder.RunAndLog(ctx, commandString, logger)
}
+5 -2
View File
@@ -2,6 +2,7 @@ package service
import (
"context"
"errors"
"fmt"
"net/http"
"slices"
@@ -59,6 +60,10 @@ func (s *Service) SetPortsForwarded(ctx context.Context, ports []uint16) (err er
s.portMutex.Lock()
defer s.portMutex.Unlock()
if s.settings.PortForwarder != nil {
return errors.New("setting port forwarded at runtime is not supported with internally running port forwarding code")
}
slices.Sort(ports)
if slices.Equal(s.ports, ports) {
return nil
@@ -78,7 +83,5 @@ func (s *Service) SetPortsForwarded(ctx context.Context, ports []uint16) (err er
return fmt.Errorf("handling new ports: %w", err)
}
s.logger.Info("updated: " + portsToString(s.ports))
return nil
}
+1 -1
View File
@@ -88,7 +88,7 @@ func (s *Settings) Validate(forStartup bool) (err error) {
return errors.New("password not set")
}
case providers.Protonvpn:
const maxPortsCount = 4
const maxPortsCount = 5
if s.PortsCount > maxPortsCount {
return fmt.Errorf("ports count too high: %d > %d", s.PortsCount, maxPortsCount)
}
+4 -10
View File
@@ -92,13 +92,9 @@ func (s *Service) onNewPorts(ctx context.Context, internalToExternalPorts map[ui
s.logger.Info(portPairsToString(internalToExternalPorts))
externalPorts := slices.Collect(maps.Values(internalToExternalPorts))
autoRedirectionNeeded := false
externalToInternalPorts := make(map[uint16]uint16, len(internalToExternalPorts))
for internal, external := range internalToExternalPorts {
externalToInternalPorts[external] = internal
if internal != external {
autoRedirectionNeeded = true
}
}
slices.Sort(externalPorts)
userRedirectionEnabled := !slices.Equal(s.settings.ListeningPorts, []uint16{0})
@@ -109,23 +105,21 @@ func (s *Service) onNewPorts(ctx context.Context, internalToExternalPorts map[ui
return fmt.Errorf("allowing port in firewall: %w", err)
}
var sourcePort, destinationPort uint16
var destinationPort uint16
switch {
case userRedirectionEnabled: // precedence over auto redirection
sourcePort = externalToInternalPorts[port]
destinationPort = s.settings.ListeningPorts[i]
case autoRedirectionNeeded:
sourcePort = externalToInternalPorts[port]
case port != internalPort: // auto redirection needed, source and destination ports differ
destinationPort = port
default:
// No redirection needed, source and destination ports are the same.
continue
}
err = s.portAllower.RedirectPort(ctx, s.settings.Interface, sourcePort, destinationPort)
err = s.portAllower.RedirectPort(ctx, s.settings.Interface, internalPort, destinationPort)
if err != nil {
return fmt.Errorf("redirecting port %d to %d in firewall: %w",
sourcePort, destinationPort, err)
internalPort, destinationPort, err)
}
}
+6 -1
View File
@@ -6,6 +6,7 @@ import (
"strings"
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/constants/openvpn"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils"
@@ -65,7 +66,11 @@ func modifyConfig(lines []string, connection models.Connection,
}
// Add values
modified = append(modified, "proto "+connection.Protocol)
protocol := connection.Protocol
if protocol == constants.TCP {
protocol = "tcp-client"
}
modified = append(modified, "proto "+protocol)
modified = append(modified, fmt.Sprintf("remote %s %d", connection.IP, connection.Port))
modified = append(modified, "dev "+settings.Interface)
modified = append(modified, "mute-replay-warnings")
-15
View File
@@ -1,15 +0,0 @@
package ovpn
import (
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils"
)
func (p *Provider) GetConnection(selection settings.ServerSelection, ipv6Supported bool) (
connection models.Connection, err error,
) {
defaults := utils.NewConnectionDefaults(443, 1194, 9929) //nolint:mnd
return utils.GetConnection(p.Name(),
p.storage, selection, defaults, ipv6Supported, p.connPicker)
}
-126
View File
@@ -1,126 +0,0 @@
package ovpn
import (
"errors"
"net/http"
"net/netip"
"testing"
"github.com/golang/mock/gomock"
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/constants/vpn"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/common"
"github.com/stretchr/testify/assert"
)
func Test_Provider_GetConnection(t *testing.T) {
t.Parallel()
const provider = providers.Ovpn
errTest := errors.New("test error")
testCases := map[string]struct {
filteredServers []models.Server
storageErr error
selection settings.ServerSelection
ipv6Supported bool
connection models.Connection
errWrapped error
errMessage string
}{
"error": {
storageErr: errTest,
errWrapped: errTest,
errMessage: "filtering servers: test error",
},
"default_openvpn_tcp_port": {
filteredServers: []models.Server{
{IPs: []netip.Addr{netip.AddrFrom4([4]byte{1, 1, 1, 1})}},
},
selection: settings.ServerSelection{
OpenVPN: settings.OpenVPNSelection{
Protocol: constants.TCP,
},
}.WithDefaults(provider),
connection: models.Connection{
Type: vpn.OpenVPN,
IP: netip.AddrFrom4([4]byte{1, 1, 1, 1}),
Port: 443,
Protocol: constants.TCP,
},
},
"default_openvpn_udp_port": {
filteredServers: []models.Server{
{IPs: []netip.Addr{netip.AddrFrom4([4]byte{1, 1, 1, 1})}},
},
selection: settings.ServerSelection{
OpenVPN: settings.OpenVPNSelection{
Protocol: constants.UDP,
},
}.WithDefaults(provider),
connection: models.Connection{
Type: vpn.OpenVPN,
IP: netip.AddrFrom4([4]byte{1, 1, 1, 1}),
Port: 1194,
Protocol: constants.UDP,
},
},
"default_wireguard_port": {
filteredServers: []models.Server{
{IPs: []netip.Addr{netip.AddrFrom4([4]byte{1, 1, 1, 1})}, WgPubKey: "x"},
},
selection: settings.ServerSelection{
VPN: vpn.Wireguard,
}.WithDefaults(provider),
connection: models.Connection{
Type: vpn.Wireguard,
IP: netip.AddrFrom4([4]byte{1, 1, 1, 1}),
Port: 9929,
Protocol: constants.UDP,
PubKey: "x",
},
},
"default_multihop_port": {
filteredServers: []models.Server{
{IPs: []netip.Addr{netip.AddrFrom4([4]byte{1, 1, 1, 1})}, WgPubKey: "x", PortsUDP: []uint16{30044}},
},
selection: settings.ServerSelection{
VPN: vpn.Wireguard,
}.WithDefaults(provider),
connection: models.Connection{
Type: vpn.Wireguard,
IP: netip.AddrFrom4([4]byte{1, 1, 1, 1}),
Port: 30044,
Protocol: constants.UDP,
PubKey: "x",
},
},
}
for name, testCase := range testCases {
t.Run(name, func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
storage := common.NewMockStorage(ctrl)
storage.EXPECT().FilterServers(provider, testCase.selection).
Return(testCase.filteredServers, testCase.storageErr)
client := (*http.Client)(nil)
provider := New(storage, client)
connection, err := provider.GetConnection(testCase.selection, testCase.ipv6Supported)
assert.ErrorIs(t, err, testCase.errWrapped)
if testCase.errWrapped != nil {
assert.EqualError(t, err, testCase.errMessage)
}
assert.Equal(t, testCase.connection, connection)
})
}
}
-38
View File
@@ -1,38 +0,0 @@
package ovpn
import (
"strings"
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/openvpn"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils"
)
func (p *Provider) OpenVPNConfig(connection models.Connection,
settings settings.OpenVPN, ipv6Supported bool,
) (lines []string) {
providerSettings := utils.OpenVPNProviderSettings{
AuthUserPass: true,
RemoteCertTLS: true,
Ciphers: []string{
openvpn.AES256gcm,
openvpn.AES256cbc,
openvpn.AES128gcm,
openvpn.Chacha20Poly1305,
},
CAs: []string{
"MIIEfTCCA2WgAwIBAgIJAK2aIWqpLj1/MA0GCSqGSIb3DQEBBQUAMIGFMQswCQYDVQQGEwJTRTESMBAGA1UECBMJU3RvY2tob2xtMRIwEAYDVQQHEwlTdG9ja2hvbG0xHDAaBgNVBAsTE0Zpcm1hIERhdmlkIFdpYmVyZ2gxEzARBgNVBAMTCm92cG4uc2UgY2ExGzAZBgkqhkiG9w0BCQEWDGluZm9Ab3Zwbi5zZTAeFw0xNDA4MTcxODIxMjlaFw0zNDA4MTIxODIxMjlaMIGFMQswCQYDVQQGEwJTRTESMBAGA1UECBMJU3RvY2tob2xtMRIwEAYDVQQHEwlTdG9ja2hvbG0xHDAaBgNVBAsTE0Zpcm1hIERhdmlkIFdpYmVyZ2gxEzARBgNVBAMTCm92cG4uc2UgY2ExGzAZBgkqhkiG9w0BCQEWDGluZm9Ab3Zwbi5zZTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAMR+aP4GTuZwurZuOA2NYzMfqKyZi/TJcLEPlGTB/b4CWA9bTd8f0pHPrDAZsXIEayxxB58BIFNDNiybnbO15JN/QwlsqmA+aZX6mCSkScs/rRwasM6LDo8iGx+KmYEqAgzziONGbCMnlO+OaarXte7LhZ9X6Z/bryu4xq/i1v3raak13kXsrogtu4iDzxqJE/QhbNOi0yhCdlm5RYQjmlKGdPB9pNTgcakVI4HcngRYMzBlrGin0YkvWCdpx5FrDNeld7BSWrJMNYyvd+buaid0Fu1T9/P/Srj/8AiabKoaDyiGFbZdTnGfK+04lWRvwAmvazpqbUt5Omw634jJDuMCAwEAAaOB7TCB6jAdBgNVHQ4EFgQUEvJcHHcTiDtu7bAyZw+xaqg+xdIwgboGA1UdIwSBsjCBr4AUEvJcHHcTiDtu7bAyZw+xaqg+xdKhgYukgYgwgYUxCzAJBgNVBAYTAlNFMRIwEAYDVQQIEwlTdG9ja2hvbG0xEjAQBgNVBAcTCVN0b2NraG9sbTEcMBoGA1UECxMTRmlybWEgRGF2aWQgV2liZXJnaDETMBEGA1UEAxMKb3Zwbi5zZSBjYTEbMBkGCSqGSIb3DQEJARYMaW5mb0BvdnBuLnNlggkArZohaqkuPX8wDAYDVR0TBAUwAwEB/zANBgkqhkiG9w0BAQUFAAOCAQEAJmID6OyBJbV7ayPPgquojF+FICuDdOfGVKP828cyISxcbVA04VpD0QLYVb0k9pFUx0NbgX2SvRTiFhP7LcyS1HV9s+XLCb2WItPPsrdRTwtqU2n3TlCEzWA3WOcOCtT6JSkv1eelmx1JnP0gYJrDvDvRYBFctwWhtE0bineSQkZwN6980zkknADLAiHpeZSu/AMx7CGTwA6SmoFvpNBmHXDcfe/9ZqbbYfUfyPNe+0JbMrcv1elKi+6wlEkHFaEBphiZwGEbOX1CjUMcQFgW/cIp3n50Eiyx6ktuqimhyb59P4Nw8gqH452tTtE4MM/brA5y0Q0WFBRBojfZIbGWWQ==", //nolint:lll
},
TLSAuth: "81782767e4d59c4464cc5d1896f1cf6015017d53ac62e2e3b94b889e00b2c69ddc01944fe1c6d895b4d80540502eb71910b8d785c9efa9e3182343532adffe1cfbb7bb6eae39c502da2748edf0fb89b8a20b0a1085cc1f06135037881bc0c4ad8f2c0f4f72d2ab466fb54af3d8264c5fddeb0f21aa0ca41863678f5fc4c44de4ca0926b36dfddc42c6f2fabd1694bdc8215b2d223b9c21dc6734c2c778093187afb8c33403b228b9af68b540c284f6d183bcc88bd41d47bd717996e499ce1cbbfa768a9723c19c58314c4d19cfed82e543ee92e73d38ad26d4fbec231c0f9f3b30773a5c87792e9bc7c34e8d7611002ebedd044e48a0f1f96527bfdcc940aa09", //nolint:lll
KeyDirection: "1",
}
if strings.HasSuffix(connection.Hostname, "singapore.ovpn.com") {
providerSettings.TLSCrypt = providerSettings.TLSAuth
providerSettings.TLSAuth = ""
providerSettings.KeyDirection = ""
}
return utils.OpenVPNConfig(providerSettings, connection, settings, ipv6Supported)
}
-28
View File
@@ -1,28 +0,0 @@
package ovpn
import (
"net/http"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/provider/common"
"github.com/qdm12/gluetun/internal/provider/ovpn/updater"
"github.com/qdm12/gluetun/internal/provider/utils"
)
type Provider struct {
storage common.Storage
connPicker *utils.ConnectionPicker
common.Fetcher
}
func New(storage common.Storage, client *http.Client) *Provider {
return &Provider{
storage: storage,
connPicker: utils.NewConnectionPicker(),
Fetcher: updater.New(client),
}
}
func (p *Provider) Name() string {
return providers.Ovpn
}
-153
View File
@@ -1,153 +0,0 @@
package updater
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/netip"
"strings"
)
type apiData struct {
Success bool `json:"success"`
DataCenters []apiDataCenter `json:"datacenters"`
}
type apiDataCenter struct {
City string `json:"city"`
CountryName string `json:"country_name"`
Servers []apiServer `json:"servers"`
}
type apiServer struct {
IP netip.Addr `json:"ip"`
Ptr string `json:"ptr"` // hostname
Online bool `json:"online"`
// PublicKey is for the Standard Shared Entry Point
PublicKey string `json:"public_key"`
// PublicKeyIPv4 is for the Public / Dedicated IP Entry Point
PublicKeyIPv4 string `json:"public_key_ipv4"`
WireguardPorts []uint16 `json:"wireguard_ports"`
MultiHopOpenvpnPort uint16 `json:"multihop_openvpn_port"`
MultiHopWireguardPort uint16 `json:"multihop_wireguard_port"`
}
func fetchAPI(ctx context.Context, client *http.Client) (
data apiData, err error,
) {
const url = "https://www.ovpn.com/v2/api/client/entry"
request, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return data, err
}
response, err := client.Do(request)
if err != nil {
return data, err
}
if response.StatusCode != http.StatusOK {
_ = response.Body.Close()
return data, fmt.Errorf("HTTP response status code is not OK: %d %s",
response.StatusCode, response.Status)
}
decoder := json.NewDecoder(response.Body)
err = decoder.Decode(&data)
if err != nil {
_ = response.Body.Close()
return data, fmt.Errorf("decoding response body: %w", err)
}
err = response.Body.Close()
if err != nil {
return data, fmt.Errorf("closing response body: %w", err)
}
return data, nil
}
func (a *apiDataCenter) validate() (err error) {
conditionalErrors := []conditionalError{
{err: "city is not set", condition: a.City == ""},
{err: "country name is not set", condition: a.CountryName == ""},
{err: "servers array is not set", condition: len(a.Servers) == 0},
}
err = collectErrors(conditionalErrors)
if err != nil {
var dataCenterSetFields []string
if a.CountryName != "" {
dataCenterSetFields = append(dataCenterSetFields, a.CountryName)
}
if a.City != "" {
dataCenterSetFields = append(dataCenterSetFields, a.City)
}
if len(dataCenterSetFields) == 0 {
return err
}
return fmt.Errorf("data center %s: %w",
strings.Join(dataCenterSetFields, ", "), err)
}
for i, server := range a.Servers {
err = server.validate()
if err != nil {
return fmt.Errorf("datacenter %s, %s: server %d of %d: %w",
a.CountryName, a.City, i+1, len(a.Servers), err)
}
}
return nil
}
func (a *apiServer) validate() (err error) {
const defaultWireguardPort = 9929
conditionalErrors := []conditionalError{
{err: "ip address is not set", condition: !a.IP.IsValid()},
{err: "hostname field is not set", condition: a.Ptr == ""},
{err: "public key field is not set", condition: a.PublicKey == ""},
{err: "public key IPv4 field is not set", condition: a.PublicKeyIPv4 == ""},
{err: "wireguard ports array is not set", condition: len(a.WireguardPorts) == 0},
{
err: "wireguard port is not the default 9929",
condition: len(a.WireguardPorts) != 1 || a.WireguardPorts[0] != defaultWireguardPort,
},
{err: "multihop OpenVPN port is not set", condition: a.MultiHopOpenvpnPort == 0},
{err: "multihop WireGuard port is not set", condition: a.MultiHopWireguardPort == 0},
}
err = collectErrors(conditionalErrors)
switch {
case err == nil:
return nil
case a.Ptr != "":
return fmt.Errorf("server %s: %w", a.Ptr, err)
case a.IP.IsValid():
return fmt.Errorf("server %s: %w", a.IP.String(), err)
default:
return err
}
}
type conditionalError struct {
err string
condition bool
}
func collectErrors(conditionalErrors []conditionalError) (err error) {
errs := make([]string, 0, len(conditionalErrors))
for _, conditionalError := range conditionalErrors {
if !conditionalError.condition {
continue
}
errs = append(errs, conditionalError.err)
}
if len(errs) == 0 {
return nil
}
return errors.New(strings.Join(errs, "; "))
}
-118
View File
@@ -1,118 +0,0 @@
package updater
import (
"context"
"errors"
"io"
"net/http"
"net/netip"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_fetchAPI(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
responseStatus int
responseBody io.ReadCloser
data apiData
err error
}{
"http response status not ok": {
responseStatus: http.StatusNoContent,
err: errors.New("HTTP response status code is not OK: 204 No Content"),
},
"nil body": {
responseStatus: http.StatusOK,
err: errors.New("decoding response body: EOF"),
},
"no server": {
responseStatus: http.StatusOK,
responseBody: io.NopCloser(strings.NewReader(`{}`)),
},
"success": {
responseStatus: http.StatusOK,
responseBody: io.NopCloser(strings.NewReader(`{
"success": true,
"datacenters": [
{
"slug": "vienna",
"city": "Vienna",
"country": "AT",
"country_name": "Austria",
"pools": [
"pool-1.prd.at.vienna.ovpn.com"
],
"ping_address": "37.120.212.227",
"servers": [
{
"ip": "37.120.212.227",
"ptr": "vpn44.prd.vienna.ovpn.com",
"name": "VPN44 - Vienna",
"online": true,
"load": 8,
"public_key": "r83LIc0Q2F8s3dY9x5y17Yz8wTADJc7giW1t5eSmoXc=",
"public_key_ipv4": "wFbSRyjSXBmkjJodlqz7DoYn3WNDPYFUIXyIUS2QU2A=",
"wireguard_ports": [
9929
],
"multihop_openvpn_port": 20044,
"multihop_wireguard_port": 30044
}
]
}
]
}`)),
data: apiData{
Success: true,
DataCenters: []apiDataCenter{
{CountryName: "Austria", City: "Vienna", Servers: []apiServer{
{
IP: netip.MustParseAddr("37.120.212.227"),
Ptr: "vpn44.prd.vienna.ovpn.com",
Online: true,
PublicKey: "r83LIc0Q2F8s3dY9x5y17Yz8wTADJc7giW1t5eSmoXc=",
PublicKeyIPv4: "wFbSRyjSXBmkjJodlqz7DoYn3WNDPYFUIXyIUS2QU2A=",
WireguardPorts: []uint16{9929},
MultiHopOpenvpnPort: 20044,
MultiHopWireguardPort: 30044,
},
}},
},
},
},
}
for name, testCase := range testCases {
t.Run(name, func(t *testing.T) {
t.Parallel()
ctx := context.Background()
client := &http.Client{
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
assert.Equal(t, http.MethodGet, r.Method)
assert.Equal(t, r.URL.String(), "https://www.ovpn.com/v2/api/client/entry")
return &http.Response{
StatusCode: testCase.responseStatus,
Status: http.StatusText(testCase.responseStatus),
Body: testCase.responseBody,
}, nil
}),
}
data, err := fetchAPI(ctx, client)
assert.Equal(t, testCase.data, data)
if testCase.err != nil {
require.Error(t, err)
assert.Equal(t, testCase.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
})
}
}
@@ -1,9 +0,0 @@
package updater
import "net/http"
type roundTripFunc func(r *http.Request) (*http.Response, error)
func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) {
return f(r)
}
-82
View File
@@ -1,82 +0,0 @@
package updater
import (
"context"
"errors"
"fmt"
"net/netip"
"sort"
"github.com/qdm12/gluetun/internal/constants/vpn"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/common"
)
func (u *Updater) FetchServers(ctx context.Context, minServers int) (
servers []models.Server, err error,
) {
data, err := fetchAPI(ctx, u.client)
if err != nil {
return nil, fmt.Errorf("fetching API: %w", err)
} else if !data.Success {
return nil, errors.New("response success field is false")
}
for dataCenterIndex, dataCenter := range data.DataCenters {
err = dataCenter.validate()
if err != nil {
return nil, fmt.Errorf("validating data center %d of %d: %w",
dataCenterIndex+1, len(data.DataCenters), err)
}
for _, apiServer := range dataCenter.Servers {
if !apiServer.Online {
continue
}
baseServer := models.Server{
Country: dataCenter.CountryName,
City: dataCenter.City,
Hostname: apiServer.Ptr,
IPs: []netip.Addr{apiServer.IP},
}
openVPNServer := baseServer
openVPNServer.VPN = vpn.OpenVPN
openVPNServer.TCP = true
openVPNServer.UDP = true
multiHopOpenVPNServer := openVPNServer
multiHopOpenVPNServer.MultiHop = true
multiHopOpenVPNServer.PortsTCP = []uint16{apiServer.MultiHopOpenvpnPort}
multiHopOpenVPNServer.PortsUDP = []uint16{apiServer.MultiHopOpenvpnPort}
servers = append(servers, openVPNServer, multiHopOpenVPNServer)
wireguardServer := baseServer
wireguardServer.VPN = vpn.Wireguard
wireguardServer.WgPubKey = apiServer.PublicKey
multiHopWireguardServer := wireguardServer
multiHopWireguardServer.MultiHop = true
multiHopWireguardServer.PortsUDP = []uint16{apiServer.MultiHopWireguardPort}
dedicatedWireguardServer := wireguardServer
dedicatedWireguardServer.WgPubKey = apiServer.PublicKeyIPv4
dedicatedWireguardServer.Dedicated = true
dedicatedMultiHopWireguardServer := multiHopWireguardServer
dedicatedMultiHopWireguardServer.WgPubKey = apiServer.PublicKeyIPv4
dedicatedMultiHopWireguardServer.Dedicated = true
servers = append(servers,
wireguardServer,
multiHopWireguardServer,
dedicatedWireguardServer,
dedicatedMultiHopWireguardServer,
)
}
}
if len(servers) < minServers {
return nil, fmt.Errorf("%w: %d and expected at least %d",
common.ErrNotEnoughServers, len(servers), minServers)
}
sort.Sort(models.SortableServers(servers))
return servers, nil
}
@@ -1,228 +0,0 @@
package updater
import (
"context"
"io"
"net/http"
"net/netip"
"strings"
"testing"
"github.com/qdm12/gluetun/internal/constants/vpn"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/common"
"github.com/stretchr/testify/assert"
)
func Test_Updater_FetchServers(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
// Inputs
minServers int
// From API
responseStatus int
responseBody string
// Output
servers []models.Server
errWrapped error
errMessage string
}{
"http_response_error": {
responseStatus: http.StatusNoContent,
errMessage: "fetching API: HTTP response status code is not OK: 204 No Content",
},
"success_field_false": {
responseStatus: http.StatusOK,
responseBody: `{"success": false}`,
errMessage: "response success field is false",
},
"validation_failed": {
responseStatus: http.StatusOK,
responseBody: `{
"success": true,
"datacenters": [
{
"city": "Vienna",
"servers": [
{}
]
}
]
}`,
errMessage: "validating data center 1 of 1: data center Vienna: country name is not set",
},
"not_enough_servers": {
minServers: 7,
responseStatus: http.StatusOK,
responseBody: `{
"success": true,
"datacenters": [
{
"city": "Vienna",
"country_name": "Austria",
"servers": [
{
"ip": "37.120.212.227",
"ptr": "vpn44.prd.vienna.ovpn.com",
"online": true,
"public_key": "r83LIc0Q2F8s3dY9x5y17Yz8wTADJc7giW1t5eSmoXc=",
"public_key_ipv4": "wFbSRyjSXBmkjJodlqz7DoYn3WNDPYFUIXyIUS2QU2A=",
"wireguard_ports": [9929],
"multihop_openvpn_port": 20044,
"multihop_wireguard_port": 30044
}
]
}
]
}`,
errWrapped: common.ErrNotEnoughServers,
// Wireguard + dedicated Wireguard + Wireguard multi-hop +
// dedicated Wireguard multi-hop + OpenVPN + OpenVPN multi-hop
errMessage: "not enough servers found: 6 and expected at least 7",
},
"success": {
minServers: 4,
responseBody: `{
"success": true,
"datacenters": [
{
"slug": "vienna",
"city": "Vienna",
"country": "AT",
"country_name": "Austria",
"pools": [
"pool-1.prd.at.vienna.ovpn.com"
],
"ping_address": "37.120.212.227",
"servers": [
{
"ip": "37.120.212.227",
"ptr": "vpn44.prd.vienna.ovpn.com",
"name": "VPN44 - Vienna",
"online": true,
"load": 8,
"public_key": "r83LIc0Q2F8s3dY9x5y17Yz8wTADJc7giW1t5eSmoXc=",
"public_key_ipv4": "wFbSRyjSXBmkjJodlqz7DoYn3WNDPYFUIXyIUS2QU2A=",
"wireguard_ports": [
9929
],
"multihop_openvpn_port": 20044,
"multihop_wireguard_port": 30044
},
{
"ip": "37.120.212.228",
"ptr": "vpn45.prd.vienna.ovpn.com",
"online": false,
"public_key": "r93LIc0Q2F8s3dY9x5y17Yz8wTADJc7giW1t5eSmoXc=",
"public_key_ipv4": "wGbSRyjSXBmkjJodlqz7DoYn3WNDPYFUIXyIUS2QU2A=",
"wireguard_ports": [9929],
"multihop_openvpn_port": 20045,
"multihop_wireguard_port": 30045
}
]
}
]
}`,
responseStatus: http.StatusOK,
servers: []models.Server{
{
Country: "Austria",
City: "Vienna",
Hostname: "vpn44.prd.vienna.ovpn.com",
IPs: []netip.Addr{netip.MustParseAddr("37.120.212.227")},
VPN: vpn.OpenVPN,
UDP: true,
TCP: true,
},
{
Country: "Austria",
City: "Vienna",
Hostname: "vpn44.prd.vienna.ovpn.com",
IPs: []netip.Addr{netip.MustParseAddr("37.120.212.227")},
VPN: vpn.OpenVPN,
UDP: true,
TCP: true,
MultiHop: true,
PortsTCP: []uint16{20044},
PortsUDP: []uint16{20044},
},
{
Country: "Austria",
City: "Vienna",
Hostname: "vpn44.prd.vienna.ovpn.com",
IPs: []netip.Addr{netip.MustParseAddr("37.120.212.227")},
VPN: vpn.Wireguard,
WgPubKey: "r83LIc0Q2F8s3dY9x5y17Yz8wTADJc7giW1t5eSmoXc=",
},
{
Country: "Austria",
City: "Vienna",
Hostname: "vpn44.prd.vienna.ovpn.com",
IPs: []netip.Addr{netip.MustParseAddr("37.120.212.227")},
VPN: vpn.Wireguard,
WgPubKey: "r83LIc0Q2F8s3dY9x5y17Yz8wTADJc7giW1t5eSmoXc=",
MultiHop: true,
PortsUDP: []uint16{30044},
},
{
Country: "Austria",
City: "Vienna",
Hostname: "vpn44.prd.vienna.ovpn.com",
IPs: []netip.Addr{netip.MustParseAddr("37.120.212.227")},
VPN: vpn.Wireguard,
WgPubKey: "wFbSRyjSXBmkjJodlqz7DoYn3WNDPYFUIXyIUS2QU2A=",
Dedicated: true,
},
{
Country: "Austria",
City: "Vienna",
Hostname: "vpn44.prd.vienna.ovpn.com",
IPs: []netip.Addr{netip.MustParseAddr("37.120.212.227")},
VPN: vpn.Wireguard,
WgPubKey: "wFbSRyjSXBmkjJodlqz7DoYn3WNDPYFUIXyIUS2QU2A=",
MultiHop: true,
Dedicated: true,
PortsUDP: []uint16{30044},
},
},
},
}
for name, testCase := range testCases {
t.Run(name, func(t *testing.T) {
t.Parallel()
ctx := context.Background()
client := &http.Client{
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
assert.Equal(t, http.MethodGet, r.Method)
assert.Equal(t, r.URL.String(), "https://www.ovpn.com/v2/api/client/entry")
return &http.Response{
StatusCode: testCase.responseStatus,
Status: http.StatusText(testCase.responseStatus),
Body: io.NopCloser(strings.NewReader(testCase.responseBody)),
}, nil
}),
}
updater := &Updater{
client: client,
}
servers, err := updater.FetchServers(ctx, testCase.minServers)
assert.Equal(t, testCase.servers, servers)
if testCase.errMessage == "" {
assert.NoError(t, err)
} else {
assert.Contains(t, err.Error(), testCase.errMessage)
}
if testCase.errWrapped != nil {
assert.ErrorIs(t, err, testCase.errWrapped)
}
})
}
}
-15
View File
@@ -1,15 +0,0 @@
package updater
import (
"net/http"
)
type Updater struct {
client *http.Client
}
func New(client *http.Client) *Updater {
return &Updater{
client: client,
}
}
+68 -61
View File
@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"maps"
"net/netip"
"strings"
"time"
@@ -12,14 +13,14 @@ import (
"github.com/qdm12/gluetun/internal/provider/utils"
)
const nonSymmetricPortStart uint16 = 56789
// PortForward obtains a VPN server side port forwarded from ProtonVPN gateway.
func (p *Provider) PortForward(ctx context.Context, objects utils.PortForwardObjects) (
internalToExternalPorts map[uint16]uint16, err error,
) {
if !objects.CanPortForward {
return nil, errors.New("server does not support port forwarding")
} else if objects.PortsCount == 0 {
return nil, nil //nolint:nilnil
}
client := natpmp.New()
@@ -39,38 +40,75 @@ func (p *Provider) PortForward(ctx context.Context, objects utils.PortForwardObj
logger := objects.Logger
logger.Debug("gateway external IPv4 address is " + externalIPv4Address.String())
const externalPort = 0
const lifetime = 60 * time.Second
p.internalToExternalPorts = make(map[uint16]uint16, objects.PortsCount)
for i := range objects.PortsCount {
internalPort := nonSymmetricPortStart + i
protoToInternalPort := map[string]uint16{
"udp": 0,
"tcp": 0,
}
protoToExternalPort := maps.Clone(protoToInternalPort)
for protocol := range protoToExternalPort {
_, assignedInternalPort, assignedExternalPort, assignedLifetime, err := client.AddPortMapping(
ctx, objects.Gateway, protocol, internalPort, externalPort, lifetime)
if err != nil {
return nil, fmt.Errorf("adding %d/%d %s port mapping: %w",
i+1, objects.PortsCount, strings.ToUpper(protocol), err)
}
checkLifetime(logger, strings.ToUpper(protocol), lifetime, assignedLifetime)
checkInternalPort(logger, internalPort, assignedInternalPort)
protoToInternalPort[protocol] = assignedInternalPort
protoToExternalPort[protocol] = assignedExternalPort
}
const lifetime = 60 * time.Second
checkInternalPorts(logger, protoToInternalPort["udp"], protoToInternalPort["tcp"])
checkExternalPorts(logger, protoToExternalPort["udp"], protoToExternalPort["tcp"])
p.internalToExternalPorts[protoToInternalPort["tcp"]] = protoToExternalPort["tcp"]
// Only one port can be a symmetric mapping
const internalPort, externalPort = 0, 1
_, assignedExternalPort, err := addPortMappingTCPUDP(ctx,
client, logger, objects.Gateway, internalPort, externalPort, lifetime)
// Note the returned assignedInternalPort is always 0 in this case
if err != nil {
return nil, fmt.Errorf("adding first port mapping: %w", err)
}
p.internalToExternalPorts[assignedExternalPort] = assignedExternalPort
// Extra ports must be non-symmetric, meaning that the internal port is
// different from the external port.
const nonSymmetricPortStart = uint16(56789)
nonSymmetricPortStartMinusOne := nonSymmetricPortStart - 1
if _, ok := p.internalToExternalPorts[nonSymmetricPortStart]; ok {
nonSymmetricPortStartMinusOne++
}
for i := uint16(1); i < objects.PortsCount; i++ {
internalPort := nonSymmetricPortStartMinusOne + i
const externalPort = 0
assignedInternalPort, assignedExternalPort, err := addPortMappingTCPUDP(ctx,
client, logger, objects.Gateway, internalPort, externalPort, lifetime)
if err != nil {
return nil, fmt.Errorf("adding %d/%d port mapping: %w", i+1, objects.PortsCount, err)
}
p.internalToExternalPorts[assignedInternalPort] = assignedExternalPort
}
return maps.Clone(p.internalToExternalPorts), nil
}
func addPortMappingTCPUDP(ctx context.Context, client *natpmp.Client, logger utils.Logger,
gateway netip.Addr, internalPort, externalPort uint16, lifetime time.Duration,
) (assignedInternalPort, assignedExternalPort uint16, err error) {
var assignedLifetime time.Duration
protocolToExternalPort := map[string]uint16{
"tcp": 0,
"udp": 0,
}
for _, protocol := range [...]string{"udp", "tcp"} {
protocolStr := strings.ToUpper(protocol)
_, assignedInternalPort, assignedExternalPort, assignedLifetime, err = client.AddPortMapping(
ctx, gateway, protocol, internalPort, externalPort, lifetime)
if err != nil {
return 0, 0, fmt.Errorf("adding %s port mapping: %w", protocolStr, err)
}
protocolToExternalPort[protocol] = assignedExternalPort
checkLifetime(logger, protocolStr, lifetime, assignedLifetime)
if internalPort != assignedInternalPort {
return 0, 0, fmt.Errorf("%s internal port requested as %d but received %d",
protocolStr, internalPort, assignedInternalPort)
} else if externalPort != 0 && externalPort != 1 && externalPort != assignedExternalPort {
return 0, 0, fmt.Errorf("%s external port requested as %d but received %d",
protocolStr, externalPort, assignedExternalPort)
}
}
if protocolToExternalPort["tcp"] != protocolToExternalPort["udp"] {
return 0, 0, fmt.Errorf("TCP and UDP external ports differ: %d and %d",
protocolToExternalPort["tcp"], protocolToExternalPort["udp"])
}
return assignedInternalPort, assignedExternalPort, nil
}
func checkLifetime(logger utils.Logger, protocol string,
requested, actual time.Duration,
) {
@@ -81,27 +119,6 @@ func checkLifetime(logger utils.Logger, protocol string,
}
}
func checkInternalPort(logger utils.Logger, sent, received uint16) {
if sent != received {
logger.Warn(fmt.Sprintf("internal port assigned %d differs from requested internal port %d",
sent, received))
}
}
func checkInternalPorts(logger utils.Logger, udpPort, tcpPort uint16) {
if udpPort != tcpPort {
logger.Warn(fmt.Sprintf("UDP internal port %d differs from TCP internal port %d",
udpPort, tcpPort))
}
}
func checkExternalPorts(logger utils.Logger, udpPort, tcpPort uint16) {
if udpPort != tcpPort {
logger.Warn(fmt.Sprintf("UDP external port %d differs from TCP external port %d",
udpPort, tcpPort))
}
}
func (p *Provider) KeepPortForward(ctx context.Context,
objects utils.PortForwardObjects,
) (err error) {
@@ -117,22 +134,12 @@ func (p *Provider) KeepPortForward(ctx context.Context,
}
objects.Logger.Debug("refreshing forwarded ports since 45 seconds have elapsed")
networkProtocols := [...]string{"udp", "tcp"}
const lifetime = 60 * time.Second
for internalPort, externalPort := range p.internalToExternalPorts {
for _, networkProtocol := range networkProtocols {
_, assignedInternalPort, assignedExternalPort, assignedLiftetime, err := client.AddPortMapping(
ctx, objects.Gateway, networkProtocol, internalPort, externalPort, lifetime)
if err != nil {
return fmt.Errorf("adding port mapping: %w", err)
}
checkLifetime(logger, networkProtocol, lifetime, assignedLiftetime)
if externalPort != assignedExternalPort {
return fmt.Errorf("external port changed from %d to %d", externalPort, assignedExternalPort)
} else if internalPort != assignedInternalPort {
return fmt.Errorf("internal port changed from %d (for external port %d) to %d",
internalPort, externalPort, assignedInternalPort)
}
_, _, err := addPortMappingTCPUDP(ctx, client, logger, objects.Gateway, internalPort, externalPort, lifetime)
if err != nil {
return fmt.Errorf("refreshing port mapping for internal port %d and external port %d: %w",
internalPort, externalPort, err)
}
objects.Logger.Debug(fmt.Sprintf("port forwarded %d maintained", externalPort))
}
-2
View File
@@ -20,7 +20,6 @@ import (
"github.com/qdm12/gluetun/internal/provider/ivpn"
"github.com/qdm12/gluetun/internal/provider/mullvad"
"github.com/qdm12/gluetun/internal/provider/nordvpn"
"github.com/qdm12/gluetun/internal/provider/ovpn"
"github.com/qdm12/gluetun/internal/provider/perfectprivacy"
"github.com/qdm12/gluetun/internal/provider/privado"
"github.com/qdm12/gluetun/internal/provider/privateinternetaccess"
@@ -68,7 +67,6 @@ func NewProviders(storage Storage, timeNow func() time.Time,
providers.Ivpn: ivpn.New(storage, client, updaterWarner, parallelResolver),
providers.Mullvad: mullvad.New(storage, client),
providers.Nordvpn: nordvpn.New(storage, client, updaterWarner),
providers.Ovpn: ovpn.New(storage, client),
providers.Perfectprivacy: perfectprivacy.New(storage, unzipper, updaterWarner),
providers.Privado: privado.New(storage, client, updaterWarner),
providers.PrivateInternetAccess: privateinternetaccess.New(storage, timeNow, client),
+2 -3
View File
@@ -52,6 +52,8 @@ func GetConnection(provider string,
})
protocol := getProtocol(selection)
port := getPort(selection, defaults.OpenVPNTCPPort,
defaults.OpenVPNUDPPort, defaults.WireguardPort)
connections := make([]models.Connection, 0, len(servers))
for _, server := range servers {
@@ -67,9 +69,6 @@ func GetConnection(provider string,
hostname = server.OvpnX509
}
port := getPort(selection, server, defaults.OpenVPNTCPPort,
defaults.OpenVPNUDPPort, defaults.WireguardPort)
connection := models.Connection{
Type: selection.VPN,
IP: ip,
+1 -16
View File
@@ -6,44 +6,29 @@ import (
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/constants/vpn"
"github.com/qdm12/gluetun/internal/models"
)
func getPort(selection settings.ServerSelection, server models.Server,
func getPort(selection settings.ServerSelection,
defaultOpenVPNTCP, defaultOpenVPNUDP, defaultWireguard uint16,
) (port uint16) {
switch selection.VPN {
case vpn.Wireguard:
customPort := *selection.Wireguard.EndpointPort
if customPort > 0 {
// Note: servers filtering ensures the custom port is within the
// server ports defined if any is set.
return customPort
}
if len(server.PortsUDP) > 0 {
defaultWireguard = server.PortsUDP[0]
}
checkDefined("Wireguard", defaultWireguard)
return defaultWireguard
default: // OpenVPN
customPort := *selection.OpenVPN.CustomPort
if customPort > 0 {
// Note: servers filtering ensures the custom port is within the
// server ports defined if any is set.
return customPort
}
if selection.OpenVPN.Protocol == constants.TCP {
if len(server.PortsTCP) > 0 {
defaultOpenVPNTCP = server.PortsTCP[0]
}
checkDefined("OpenVPN TCP", defaultOpenVPNTCP)
return defaultOpenVPNTCP
}
if len(server.PortsUDP) > 0 {
defaultOpenVPNUDP = server.PortsUDP[0]
}
checkDefined("OpenVPN UDP", defaultOpenVPNUDP)
return defaultOpenVPNUDP
}
-45
View File
@@ -6,7 +6,6 @@ import (
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/constants/vpn"
"github.com/qdm12/gluetun/internal/models"
"github.com/stretchr/testify/assert"
)
@@ -23,7 +22,6 @@ func Test_GetPort(t *testing.T) {
testCases := map[string]struct {
selection settings.ServerSelection
server models.Server
defaultOpenVPNTCP uint16
defaultOpenVPNUDP uint16
defaultWireguard uint16
@@ -50,20 +48,6 @@ func Test_GetPort(t *testing.T) {
defaultWireguard: defaultWireguard,
port: defaultOpenVPNUDP,
},
"OpenVPN_server_port_udp": {
selection: settings.ServerSelection{
VPN: vpn.OpenVPN,
OpenVPN: settings.OpenVPNSelection{
CustomPort: uint16Ptr(0),
Protocol: constants.UDP,
},
},
server: models.Server{
PortsUDP: []uint16{1234},
},
defaultOpenVPNUDP: defaultOpenVPNUDP,
port: 1234,
},
"OpenVPN UDP no default port defined": {
selection: settings.ServerSelection{
VPN: vpn.OpenVPN,
@@ -104,20 +88,6 @@ func Test_GetPort(t *testing.T) {
},
port: 1234,
},
"OpenVPN_server_port_tcp": {
selection: settings.ServerSelection{
VPN: vpn.OpenVPN,
OpenVPN: settings.OpenVPNSelection{
CustomPort: uint16Ptr(0),
Protocol: constants.TCP,
},
},
server: models.Server{
PortsTCP: []uint16{1234},
},
defaultOpenVPNTCP: defaultOpenVPNTCP,
port: 1234,
},
"Wireguard": {
selection: settings.ServerSelection{
VPN: vpn.Wireguard,
@@ -135,19 +105,6 @@ func Test_GetPort(t *testing.T) {
defaultWireguard: defaultWireguard,
port: 1234,
},
"Wireguard_server_port": {
selection: settings.ServerSelection{
VPN: vpn.Wireguard,
Wireguard: settings.WireguardSelection{
EndpointPort: uint16Ptr(0),
},
},
server: models.Server{
PortsUDP: []uint16{1234},
},
defaultWireguard: defaultWireguard,
port: 1234,
},
"Wireguard no default port defined": {
selection: settings.ServerSelection{
VPN: vpn.Wireguard,
@@ -163,7 +120,6 @@ func Test_GetPort(t *testing.T) {
if testCase.panics != "" {
assert.PanicsWithValue(t, testCase.panics, func() {
_ = getPort(testCase.selection,
testCase.server,
testCase.defaultOpenVPNTCP,
testCase.defaultOpenVPNUDP,
testCase.defaultWireguard)
@@ -172,7 +128,6 @@ func Test_GetPort(t *testing.T) {
}
port := getPort(testCase.selection,
testCase.server,
testCase.defaultOpenVPNTCP,
testCase.defaultOpenVPNUDP,
testCase.defaultWireguard)
+82
View File
@@ -0,0 +1,82 @@
package restrictednet
import (
"context"
"errors"
"fmt"
"net"
"net/http"
"net/netip"
"strconv"
"github.com/qdm12/dns/v2/pkg/provider"
)
// Client is a client for making restricted network requests,
// such as opening temporary firewall rules for HTTPS connections.
// It is not meant to be high performance, although it can be used for
// multiple requests and concurrently.
type Client struct {
outboundInterface string
ipv6Supported bool
firewall Firewall
dohServers []provider.DoHServer
}
func New(settings Settings) *Client {
if err := settings.validate(); err != nil {
panic(fmt.Sprintf("invalid settings: %v", err)) // programming error
}
dohServers := make([]provider.DoHServer, len(settings.UpstreamResolvers))
for i, upstreamResolver := range settings.UpstreamResolvers {
dohServers[i] = upstreamResolver.DoH
}
return &Client{
outboundInterface: settings.DefaultInterface,
ipv6Supported: *settings.IPv6Supported,
firewall: settings.Firewall,
dohServers: dohServers,
}
}
// OpenHTTPSByHostname opens an https connection through the firewall,
// to the hostname which in the format `host:port`. The returned cleanup
// function must be called to remove the temporary firewall rule and close connections.
// It first resolves the domain in hostname using DNS over HTTPS and then opens
// the restricted HTTPS connection to the resolved IP.
func (c *Client) OpenHTTPSByHostname(ctx context.Context, hostname string) (
httpClient *http.Client, cleanup func() error, err error,
) {
host, portStr, err := net.SplitHostPort(hostname)
if err != nil {
return nil, nil, fmt.Errorf("splitting host and port: %w", err)
}
resolvedIPs, err := c.ResolveName(ctx, host)
if err != nil {
return nil, nil, fmt.Errorf("resolving name: %w", err)
} else if len(resolvedIPs) == 0 {
return nil, nil, fmt.Errorf("no IP address found for name %q", host)
}
portUint, err := strconv.ParseUint(portStr, 10, 16)
if err != nil {
return nil, nil, fmt.Errorf("parsing port: %w", err)
} else if portUint == 0 {
return nil, nil, errors.New("destination port cannot be 0")
}
port := uint16(portUint)
errs := make([]error, 0, len(resolvedIPs))
for _, ip := range resolvedIPs {
addrPort := netip.AddrPortFrom(ip, port)
httpClient, cleanup, err := c.OpenHTTPS(ctx, host, addrPort)
if err != nil {
errs = append(errs, fmt.Errorf("for %s: %w", ip, err))
continue
}
return httpClient, cleanup, nil
}
return nil, nil, fmt.Errorf("opening HTTPS to %s: %w", hostname, errors.Join(errs...))
}
+7
View File
@@ -0,0 +1,7 @@
//go:build integration
package restrictednet
func ptrTo[T any](value T) *T {
return &value
}
+202
View File
@@ -0,0 +1,202 @@
package restrictednet
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"net/netip"
"os"
"time"
"github.com/jsimonetti/rtnetlink"
"github.com/qdm12/gluetun/internal/pmtud/constants"
)
// OpenHTTPS opens temporary restrictive firewall output for one HTTPS destination.
// The returned [*http.Client] must be used sequentially only, and each request must
// have its response body fully read/discarded and then closed.
// The returned cleanup function must be called to remove the temporary firewall rule and close connections.
func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, destinationAddrPort netip.AddrPort,
) (httpClient *http.Client, cleanup func() error, err error) {
fd, sourceAddrPort, err := bindSourceConnection(destinationAddrPort.Addr())
if err != nil {
return nil, nil, fmt.Errorf("binding source port: %w", err)
}
const remove = false
err = c.firewall.AcceptOutputFromIPPortToIPPort(ctx, "tcp", c.outboundInterface,
sourceAddrPort, destinationAddrPort, remove)
if err != nil {
closeFD(fd)
return nil, nil, fmt.Errorf("allowing output traffic through firewall: %w", err)
}
connection, err := connectSourceConnection(ctx, fd, destinationAddrPort)
if err != nil {
const remove = true
_ = c.firewall.AcceptOutputFromIPPortToIPPort(context.Background(), "tcp", c.outboundInterface,
sourceAddrPort, destinationAddrPort, remove)
return nil, nil, fmt.Errorf("connecting source socket: %w", err)
}
dial := makeDial(connection, destinationTLSName)
httpClient = newHTTPSClient(destinationTLSName, dial)
cleanup = func() error {
var errs []error
httpClient.CloseIdleConnections()
err := connection.Close()
if err != nil && !errors.Is(err, net.ErrClosed) {
errs = append(errs, fmt.Errorf("closing connection: %w", err))
}
const remove = true
err = c.firewall.AcceptOutputFromIPPortToIPPort(context.Background(), "tcp", c.outboundInterface,
sourceAddrPort, destinationAddrPort, remove)
if err != nil {
errs = append(errs, fmt.Errorf("removing output traffic rule: %w", err))
}
if len(errs) > 0 {
return errors.Join(errs...)
}
return nil
}
return httpClient, cleanup, nil
}
type dialFunc func(ctx context.Context, network, address string) (net.Conn, error)
func newHTTPSClient(destinationTLSName string, dial dialFunc) *http.Client {
const timeout = 5 * time.Second
transport := &http.Transport{
MaxIdleConns: 1,
MaxIdleConnsPerHost: 1,
MaxConnsPerHost: 1,
TLSClientConfig: &tls.Config{
MinVersion: tls.VersionTLS12,
ServerName: destinationTLSName,
},
DialContext: dial,
}
return &http.Client{
Timeout: timeout,
Transport: transport,
}
}
func makeDial(connection net.Conn, tlsName string) dialFunc {
_, destinationPort, err := net.SplitHostPort(connection.RemoteAddr().String())
if err != nil {
panic(err) // connection remote address should always be in the form "host:port"
}
expectedAddress := net.JoinHostPort(tlsName, destinationPort)
used := false
return func(_ context.Context, network, address string) (net.Conn, error) {
if used {
return nil, errors.New("dial function called more than once")
}
used = true
switch network {
case "tcp", "tcp4", "tcp6":
default:
return nil, fmt.Errorf("unexpected dial network %q", network)
}
if address != expectedAddress {
return nil, fmt.Errorf("unexpected dial address %q (expected %q)", address, expectedAddress)
}
return connection, nil
}
}
func bindSourceConnection(destinationIP netip.Addr) (fd int, sourceAddr netip.AddrPort, err error) {
sourceIP, err := sourceIPForDestination(destinationIP)
if err != nil {
return 0, netip.AddrPort{}, fmt.Errorf("finding source IP: %w", err)
}
family := constants.AF_INET
if sourceIP.Is6() {
family = constants.AF_INET6
}
fd, err = newTCPSockStream(family)
if err != nil {
return 0, netip.AddrPort{}, fmt.Errorf("creating socket: %w", err)
}
bindAddrPort := netip.AddrPortFrom(sourceIP, 0)
err = bindFD(fd, bindAddrPort)
if err != nil {
closeFD(fd)
return 0, netip.AddrPort{}, fmt.Errorf("binding socket: %w", err)
}
sourceAddr, err = fdToSourceAddr(fd)
if err != nil {
closeFD(fd)
return 0, netip.AddrPort{}, fmt.Errorf("getting source address: %w", err)
}
return fd, sourceAddr, nil
}
func connectSourceConnection(ctx context.Context, fd int, destinationAddrPort netip.AddrPort) (
connection net.Conn, err error,
) {
err = connectFD(ctx, fd, destinationAddrPort)
if err != nil {
closeFD(fd)
return nil, fmt.Errorf("connecting socket: %w", err)
}
file := os.NewFile(uintptr(fd), "")
if file == nil {
closeFD(fd)
return nil, fmt.Errorf("creating socket file")
}
defer file.Close()
connection, err = net.FileConn(file)
if err != nil {
return nil, fmt.Errorf("wrapping socket connection: %w", err)
}
return connection, nil
}
func sourceIPForDestination(destinationIP netip.Addr) (srcIP netip.Addr, err error) {
conn, err := rtnetlink.Dial(nil)
if err != nil {
return netip.Addr{}, err
}
defer conn.Close()
family := uint8(constants.AF_INET)
if destinationIP.Is6() {
family = constants.AF_INET6
}
requestMessage := &rtnetlink.RouteMessage{
Family: family,
Attributes: rtnetlink.RouteAttributes{
Dst: destinationIP.AsSlice(),
},
}
messages, err := conn.Route.Get(requestMessage)
if err != nil {
return netip.Addr{}, fmt.Errorf("getting routes to %s: %w", destinationIP, err)
}
for _, message := range messages {
if message.Attributes.Src == nil {
continue
}
if message.Attributes.Src.To4() == nil {
return netip.AddrFrom16([16]byte(message.Attributes.Src)), nil
}
return netip.AddrFrom4([4]byte(message.Attributes.Src)), nil
}
return netip.Addr{}, fmt.Errorf("no route to %s", destinationIP)
}
@@ -0,0 +1,117 @@
//go:build integration
package restrictednet
import (
"context"
"fmt"
"io"
"net/http"
"net/netip"
"testing"
"github.com/golang/mock/gomock"
"github.com/qdm12/dns/v2/pkg/provider"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type listenAddrPortMatcher struct {
expected netip.AddrPort
}
func (m listenAddrPortMatcher) Matches(x any) bool {
ip, ok := x.(netip.AddrPort)
if !ok {
return false
}
if m.expected.IsValid() {
return ip == m.expected
}
return ip.IsValid() && ip.Addr().IsValid() && ip.Port() > 0
}
func (m listenAddrPortMatcher) String() string {
if m.expected.IsValid() {
return "is the same as " + m.expected.String()
}
return "is a valid netip.AddrPort with a valid IP and non-zero port"
}
type destinationAddrPortMatcher struct {
expected netip.AddrPort
}
func (m destinationAddrPortMatcher) Matches(x any) bool {
ip, ok := x.(netip.AddrPort)
if !ok {
return false
}
if m.expected.IsValid() {
return ip == m.expected
}
return ip.IsValid() && ip.Port() == m.expected.Port()
}
func (m destinationAddrPortMatcher) String() string {
if m.expected.IsValid() {
return "is the same as " + m.expected.String()
}
return "matches the port " + fmt.Sprint(m.expected.Port())
}
func Test_Client_OpenHTTPS(t *testing.T) {
t.Parallel()
ctx := t.Context()
ctrl := gomock.NewController(t)
const destinationTLSName = "one.one.one.one"
destinationAddrPort := netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443)
firewall := NewMockFirewall(ctrl)
sourceMatcher := listenAddrPortMatcher{}
firewall.EXPECT().AcceptOutputFromIPPortToIPPort(
ctx, "tcp", "eth0", sourceMatcher, destinationAddrPort, false,
).DoAndReturn(func(_ context.Context,
_, _ string, source, _ netip.AddrPort, _ bool,
) error {
sourceMatcher.expected = source
return nil
})
firewall.EXPECT().AcceptOutputFromIPPortToIPPort(
context.Background(), "tcp", "eth0", sourceMatcher, destinationAddrPort, true,
).Return(nil)
const ipv6Supported = false
upstreamResolvers := []provider.Provider{provider.Google()}
settings := Settings{
Firewall: firewall,
DefaultInterface: "eth0",
IPv6Supported: ptrTo(ipv6Supported),
UpstreamResolvers: upstreamResolvers,
}
client := New(settings)
httpClient, cleanup, err := client.OpenHTTPS(ctx, destinationTLSName, destinationAddrPort)
require.NoError(t, err)
require.NotNil(t, httpClient)
require.NotNil(t, cleanup)
const requests = 2
for range requests {
request, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://"+destinationTLSName, nil)
require.NoError(t, err)
response, err := httpClient.Do(request)
require.NoError(t, err)
_, err = io.Copy(io.Discard, response.Body)
require.NoError(t, err)
err = response.Body.Close()
require.NoError(t, err)
assert.Equal(t, http.StatusOK, response.StatusCode)
}
err = cleanup()
require.NoError(t, err)
}
+12
View File
@@ -0,0 +1,12 @@
package restrictednet
import (
"context"
"net/netip"
)
type Firewall interface {
AcceptOutputFromIPPortToIPPort(ctx context.Context,
protocol, intf string, source, destination netip.AddrPort, remove bool,
) error
}
@@ -0,0 +1,3 @@
package restrictednet
//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . Firewall
+50
View File
@@ -0,0 +1,50 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/qdm12/gluetun/internal/restrictednet (interfaces: Firewall)
// Package restrictednet is a generated GoMock package.
package restrictednet
import (
context "context"
netip "net/netip"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
)
// MockFirewall is a mock of Firewall interface.
type MockFirewall struct {
ctrl *gomock.Controller
recorder *MockFirewallMockRecorder
}
// MockFirewallMockRecorder is the mock recorder for MockFirewall.
type MockFirewallMockRecorder struct {
mock *MockFirewall
}
// NewMockFirewall creates a new mock instance.
func NewMockFirewall(ctrl *gomock.Controller) *MockFirewall {
mock := &MockFirewall{ctrl: ctrl}
mock.recorder = &MockFirewallMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockFirewall) EXPECT() *MockFirewallMockRecorder {
return m.recorder
}
// AcceptOutputFromIPPortToIPPort mocks base method.
func (m *MockFirewall) AcceptOutputFromIPPortToIPPort(arg0 context.Context, arg1, arg2 string, arg3, arg4 netip.AddrPort, arg5 bool) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AcceptOutputFromIPPortToIPPort", arg0, arg1, arg2, arg3, arg4, arg5)
ret0, _ := ret[0].(error)
return ret0
}
// AcceptOutputFromIPPortToIPPort indicates an expected call of AcceptOutputFromIPPortToIPPort.
func (mr *MockFirewallMockRecorder) AcceptOutputFromIPPortToIPPort(arg0, arg1, arg2, arg3, arg4, arg5 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptOutputFromIPPortToIPPort", reflect.TypeOf((*MockFirewall)(nil).AcceptOutputFromIPPortToIPPort), arg0, arg1, arg2, arg3, arg4, arg5)
}
+205
View File
@@ -0,0 +1,205 @@
package restrictednet
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net/http"
"net/netip"
"net/url"
"strconv"
"github.com/miekg/dns"
)
// ResolveName resolves the given host name to IP addresses using DoH servers,
// while opening temporary restrictive firewall rules for HTTPS traffic to DoH servers.
// The host must be a single well-formed domain name, without port or path.
func (c *Client) ResolveName(ctx context.Context, host string) (
resolvedAddresses []netip.Addr, err error,
) {
const maxTypes = 2
questionTypes := make([]uint16, 0, maxTypes)
if c.ipv6Supported {
questionTypes = append(questionTypes, dns.TypeAAAA)
}
questionTypes = append(questionTypes, dns.TypeA)
var addresses []netip.Addr
errs := make([]error, 0, len(questionTypes))
for _, questionType := range questionTypes {
answerAddresses, err := c.resolveOneQuestionType(ctx, host, questionType)
if err != nil {
errs = append(errs, err)
continue
}
addresses = append(addresses, answerAddresses...)
}
switch {
case len(addresses) > 0:
return addresses, nil
case len(errs) == 0:
return nil, nil // no address found
default: // errors
return nil, fmt.Errorf("resolving host %q: %w", host, errors.Join(errs...))
}
}
func (c *Client) resolveOneQuestionType(ctx context.Context,
host string, questionType uint16,
) (addresses []netip.Addr, err error) {
queryMessage := &dns.Msg{}
queryMessage.SetQuestion(dns.Fqdn(host), questionType)
queryWire, err := queryMessage.Pack()
if err != nil {
return nil, fmt.Errorf("packing DNS query: %w", err)
}
// Try every DoH server and every of each of their IP until we get a non-empty
// successful response.
errs := make([]error, 0)
for _, dohServer := range c.dohServers {
dohURL, err := url.Parse(dohServer.URL)
if err != nil {
errs = append(errs,
fmt.Errorf("parsing DoH server URL %s: %w", dohServer.URL, err))
continue
}
dohServerIPs := make([]netip.Addr, 0, len(dohServer.IPv4)+len(dohServer.IPv6))
if c.ipv6Supported {
// Prefer IPv6 addresses if IPv6 is supported
dohServerIPs = append(dohServerIPs, dohServer.IPv6...)
}
dohServerIPs = append(dohServerIPs, dohServer.IPv4...)
for _, dohServerIP := range dohServerIPs {
const defaultDoHPort uint16 = 443
port := defaultDoHPort
if portStr := dohURL.Port(); portStr != "" {
port, err = parseDestinationPort(portStr)
if err != nil {
errs = append(errs, fmt.Errorf("parsing DoH server port: %w", err))
continue
}
}
dohServerAddrPort := netip.AddrPortFrom(dohServerIP, port)
responseMessage, err := c.doHQuery(ctx, queryWire, dohURL, dohServerAddrPort)
switch {
case err != nil:
errs = append(errs, fmt.Errorf("querying DoH server %q (%s): %w",
dohServer.URL, dohServerAddrPort, err))
continue
case responseMessage.Rcode != dns.RcodeSuccess:
errs = append(errs, fmt.Errorf("querying DoH server %q (%s): DNS rcode %s",
dohServer.URL, dohServerAddrPort, dns.RcodeToString[responseMessage.Rcode]))
continue
}
addresses := answersToNetipAddrs(responseMessage)
if len(addresses) == 0 {
continue
}
return addresses, nil
}
}
if len(errs) == 0 {
return nil, nil
}
return nil, fmt.Errorf("resolving %s %s: %w",
dns.TypeToString[questionType], host, errors.Join(errs...))
}
func (c *Client) doHQuery(ctx context.Context, queryWire []byte,
dohURL *url.URL, dohServerAddrPort netip.AddrPort,
) (responseMessage *dns.Msg, err error) {
httpClient, cleanup, err := c.OpenHTTPS(ctx, dohURL.Hostname(), dohServerAddrPort)
if err != nil {
return nil, fmt.Errorf("opening https connection: %w", err)
}
defer func() {
closeErr := cleanup()
if err == nil && closeErr != nil {
err = fmt.Errorf("cleaning up https connection: %w", closeErr)
}
}()
requestBody := bytes.NewReader(queryWire)
request, err := http.NewRequestWithContext(ctx, http.MethodPost, dohURL.String(), requestBody)
if err != nil {
return nil, fmt.Errorf("creating request: %w", err)
}
request.Header.Set("Content-Type", "application/dns-message")
request.Header.Set("Accept", "application/dns-message")
response, err := httpClient.Do(request)
if err != nil {
return nil, err
}
responseData, err := io.ReadAll(response.Body)
if err != nil {
_ = response.Body.Close()
return nil, fmt.Errorf("reading response body: %w", err)
}
err = response.Body.Close()
if err != nil {
return nil, fmt.Errorf("closing response body: %w", err)
}
if response.StatusCode != http.StatusOK {
return nil, fmt.Errorf("response status code is %s (data length %d)",
response.Status, len(responseData))
}
responseMessage = new(dns.Msg)
err = responseMessage.Unpack(responseData)
if err != nil {
return nil, fmt.Errorf("parsing DoH response: %w", err)
}
return responseMessage, nil
}
func answersToNetipAddrs(message *dns.Msg) (addresses []netip.Addr) {
if message == nil {
return nil
}
addresses = make([]netip.Addr, 0, len(message.Answer))
for _, answer := range message.Answer {
switch record := answer.(type) {
case *dns.A:
address, ok := netip.AddrFromSlice(record.A)
if ok {
addresses = append(addresses, address.Unmap())
}
case *dns.AAAA:
address, ok := netip.AddrFromSlice(record.AAAA)
if ok {
addresses = append(addresses, address)
}
}
}
return addresses
}
func parseDestinationPort(portStr string) (port uint16, err error) {
portUint, err := strconv.ParseUint(portStr, 10, 16)
if err != nil {
return 0, err
}
const maxPortUint = 65535
switch {
case portUint == 0:
return 0, errors.New("port cannot be 0")
case portUint > maxPortUint:
return 0, fmt.Errorf("port cannot be greater than %d", maxPortUint)
}
return uint16(portUint), nil
}
@@ -0,0 +1,110 @@
//go:build integration
package restrictednet
import (
"context"
"net"
"net/netip"
"testing"
"github.com/golang/mock/gomock"
"github.com/miekg/dns"
"github.com/qdm12/dns/v2/pkg/provider"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_Client_ResolveName(t *testing.T) {
t.Parallel()
ctx := t.Context()
ctrl := gomock.NewController(t)
firewall := NewMockFirewall(ctrl)
sourceMatcher := listenAddrPortMatcher{}
destinationMatcher := destinationAddrPortMatcher{
expected: netip.AddrPortFrom(netip.Addr{}, 443),
}
// Add rule
firstCall := firewall.EXPECT().AcceptOutputFromIPPortToIPPort(
ctx, "tcp", "eth0", sourceMatcher, destinationMatcher, false,
).DoAndReturn(func(
_ context.Context, _, _ string, source, destination netip.AddrPort, _ bool,
) error {
sourceMatcher.expected = source
destinationMatcher.expected = destination
return nil
})
// Removal rule
firewall.EXPECT().AcceptOutputFromIPPortToIPPort(
context.Background(), "tcp", "eth0", sourceMatcher, destinationMatcher, true,
).Return(nil).After(firstCall)
settings := Settings{
DefaultInterface: "eth0",
IPv6Supported: ptrTo(false),
Firewall: firewall,
UpstreamResolvers: []provider.Provider{provider.Cloudflare()},
}
client := New(settings)
addresses, err := client.ResolveName(ctx, "github.com")
require.NoError(t, err)
assert.NotEmpty(t, addresses)
}
func Test_answersToNetipAddrs(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
message *dns.Msg
expected []netip.Addr
}{
"nil_message": {},
"no_answers": {
message: &dns.Msg{},
expected: []netip.Addr{},
},
"a_record": {
message: &dns.Msg{Answer: []dns.RR{
&dns.A{
Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET},
A: net.IP{1, 1, 1, 1},
},
}},
expected: []netip.Addr{netip.MustParseAddr("1.1.1.1")},
},
"aaaa_record": {
message: &dns.Msg{Answer: []dns.RR{
&dns.AAAA{
Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeAAAA, Class: dns.ClassINET},
AAAA: net.IP{0x20, 0x01, 0x48, 0x60, 0x48, 0x60, 0, 0, 0, 0, 0, 0, 0, 0, 0x88, 0x88},
},
}},
expected: []netip.Addr{netip.MustParseAddr("2001:4860:4860::8888")},
},
"mixed_records": {
message: &dns.Msg{Answer: []dns.RR{
&dns.A{
Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET},
A: net.IP{1, 1, 1, 1},
},
&dns.AAAA{
Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeAAAA, Class: dns.ClassINET},
AAAA: net.IP{0x20, 0x01, 0x48, 0x60, 0x48, 0x60, 0, 0, 0, 0, 0, 0, 0, 0, 0x88, 0x88},
},
}},
expected: []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("2001:4860:4860::8888")},
},
}
for testName, testCase := range testCases {
t.Run(testName, func(t *testing.T) {
t.Parallel()
addresses := answersToNetipAddrs(testCase.message)
assert.Equal(t, testCase.expected, addresses)
})
}
}
+28
View File
@@ -0,0 +1,28 @@
package restrictednet
import (
"errors"
"github.com/qdm12/dns/v2/pkg/provider"
)
type Settings struct {
DefaultInterface string
IPv6Supported *bool
Firewall Firewall
UpstreamResolvers []provider.Provider
}
func (s *Settings) validate() error {
switch {
case s.DefaultInterface == "":
return errors.New("default interface is not set")
case s.IPv6Supported == nil:
return errors.New("IPv6 support field is not set")
case s.Firewall == nil:
return errors.New("firewall is not set")
case len(s.UpstreamResolvers) == 0:
return errors.New("no upstream resolvers provided")
}
return nil
}
+121
View File
@@ -0,0 +1,121 @@
//go:build !windows
package restrictednet
import (
"context"
"errors"
"fmt"
"net/netip"
"time"
"golang.org/x/sys/unix"
)
func closeFD(fd int) {
unix.Close(fd)
}
func newTCPSockStream(family int) (fd int, err error) {
fd, err = unix.Socket(family, unix.SOCK_STREAM, unix.IPPROTO_TCP)
if err != nil {
return 0, err
}
err = unix.SetNonblock(fd, true)
if err != nil {
_ = unix.Close(fd)
return 0, err
}
return fd, nil
}
func bindFD(fd int, address netip.AddrPort) error {
bindAddr := makeSockAddr(address)
return unix.Bind(fd, bindAddr)
}
func connectFD(ctx context.Context, fd int, destination netip.AddrPort) error {
err := unix.Connect(fd, makeSockAddr(destination))
switch {
case err == nil:
return nil
case !errors.Is(err, unix.EINPROGRESS):
return err
}
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
bitsIndex := fd / 64 //nolint:mnd
if bitsIndex >= len(unix.FdSet{}.Bits) {
return fmt.Errorf("fd %d exceeds unix.Select FdSet capacity", fd)
}
wset := &unix.FdSet{}
wset.Bits[bitsIndex] |= 1 << (uint64(fd) % 64) //nolint:gosec,mnd
eset := &unix.FdSet{}
eset.Bits[bitsIndex] |= 1 << (uint64(fd) % 64) //nolint:gosec,mnd
const selectTimeout = 50 * time.Millisecond
timeval := unix.NsecToTimeval(int64(selectTimeout))
// Wait for the FD to become writable or hit an error state
n, err := unix.Select(fd+1, nil, wset, eset, &timeval)
if err != nil {
if errors.Is(err, unix.EINTR) {
continue // Syscall interrupted, try again
}
return fmt.Errorf("select error: %w", err)
} else if n == 0 {
continue // no status change yet
}
// Check if the socket encountered an error
n, err = unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_ERROR)
if err != nil {
return fmt.Errorf("getsockopt error: %w", err)
} else if n != 0 {
return fmt.Errorf("connect failed asynchronously: %w", unix.Errno(n))
}
return nil
}
}
}
func fdToSourceAddr(fd int) (sourceAddrPort netip.AddrPort, err error) {
sockAddr, err := unix.Getsockname(fd)
if err != nil {
return netip.AddrPort{}, fmt.Errorf("getting sockname: %w", err)
}
sourceAddrPort, err = sockAddrToAddrPort(sockAddr)
if err != nil {
return netip.AddrPort{}, err
}
return sourceAddrPort, nil
}
func makeSockAddr(addressPort netip.AddrPort) unix.Sockaddr {
if addressPort.Addr().Is4() {
return &unix.SockaddrInet4{
Port: int(addressPort.Port()),
Addr: addressPort.Addr().As4(),
}
}
return &unix.SockaddrInet6{
Port: int(addressPort.Port()),
Addr: addressPort.Addr().As16(),
}
}
func sockAddrToAddrPort(sockAddr unix.Sockaddr) (addrPort netip.AddrPort, err error) {
switch typedSockAddr := sockAddr.(type) {
case *unix.SockaddrInet4:
return netip.AddrPortFrom(netip.AddrFrom4(typedSockAddr.Addr), uint16(typedSockAddr.Port)), nil //nolint:gosec
case *unix.SockaddrInet6:
return netip.AddrPortFrom(netip.AddrFrom16(typedSockAddr.Addr), uint16(typedSockAddr.Port)), nil //nolint:gosec
default:
return netip.AddrPort{}, fmt.Errorf("unexpected socket address type %T", typedSockAddr)
}
}
+28
View File
@@ -0,0 +1,28 @@
//go:build windows
package restrictednet
import (
"context"
"net/netip"
)
func closeFD(fd int) {
panic("not implemented")
}
func newTCPSockStream(family int) (fd int, err error) {
panic("not implemented")
}
func bindFD(fd int, address netip.AddrPort) error {
panic("not implemented")
}
func connectFD(ctx context.Context, fd int, destination netip.AddrPort) error {
panic("not implemented")
}
func fdToSourceAddr(fd int) (sourceAddrPort netip.AddrPort, err error) {
panic("not implemented")
}
-3
View File
@@ -43,7 +43,6 @@ type cmdType byte
const (
connect cmdType = 1
bind cmdType = 2
udpAssociate cmdType = 3
)
@@ -51,8 +50,6 @@ func (c cmdType) String() string {
switch c {
case connect:
return "connect"
case bind:
return "bind"
case udpAssociate:
return "UDP associate"
default:
+1 -1
View File
@@ -10,7 +10,7 @@ import (
)
// See https://datatracker.ietf.org/doc/html/rfc1928#section-6
func (c *socksConn) encodeFailedResponse(writer io.Writer, socksVersion byte, reply replyCode) { //nolint:unparam
func (c *socksConn) encodeFailedResponse(writer io.Writer, socksVersion byte, reply replyCode) {
_, err := writer.Write([]byte{
socksVersion,
byte(reply),
+89 -49
View File
@@ -2,6 +2,7 @@ package socks5
import (
"context"
"errors"
"fmt"
"net"
"sync"
@@ -15,12 +16,13 @@ type server struct {
logger Logger
// internal fields
listener net.Listener
tcpListener net.Listener
udpRouter *udpRouter
listening atomic.Bool
socksConnCtx context.Context //nolint:containedctx
socksConnCancel context.CancelFunc
done <-chan struct{}
stopping atomic.Bool
done <-chan error
stopCh chan<- struct{}
}
func newServer(settings Settings) *server {
@@ -39,19 +41,28 @@ func (s *server) String() string {
func (s *server) Start(ctx context.Context) (runErr <-chan error, err error) {
s.socksConnCtx, s.socksConnCancel = context.WithCancel(context.Background())
config := &net.ListenConfig{}
s.listener, err = config.Listen(ctx, "tcp", s.address)
s.tcpListener, err = config.Listen(ctx, "tcp", s.address)
if err != nil {
return nil, fmt.Errorf("listening on %s: %w", s.address, err)
return nil, fmt.Errorf("TCP listening on %s: %w", s.address, err)
}
s.udpRouter, err = newUDPRouter(ctx, s.address, s.logger)
if err != nil {
_ = s.tcpListener.Close()
return nil, fmt.Errorf("creating UDP router: %w", err)
}
s.listening.Store(true)
s.logger.Infof("SOCKS5 server listening on %s", s.listener.Addr())
s.logger.Infof("SOCKS5 TCP server listening on %s", s.tcpListener.Addr())
s.logger.Infof("SOCKS5 UDP server listening on %s", s.udpRouter.localAddress())
ready := make(chan struct{})
runErrCh := make(chan error)
runErr = runErrCh
done := make(chan struct{})
done := make(chan error)
s.done = done
go s.runServer(ready, runErrCh, done)
stop := make(chan struct{})
s.stopCh = stop
go s.runServer(ready, runErrCh, stop, done)
select {
case <-ready:
case <-ctx.Done():
@@ -62,61 +73,90 @@ func (s *server) Start(ctx context.Context) (runErr <-chan error, err error) {
}
func (s *server) runServer(ready chan<- struct{},
runErrCh chan<- error, done chan<- struct{},
runErrCh chan<- error, stop <-chan struct{}, done chan<- error,
) {
close(ready)
defer close(done)
wg := new(sync.WaitGroup)
defer wg.Wait()
dialer := &net.Dialer{}
for {
connection, err := s.listener.Accept()
if err != nil {
if !s.stopping.Load() {
_ = s.stop()
runErrCh <- fmt.Errorf("accepting connection: %w", err)
}
return
}
wg.Add(1)
go func(ctx context.Context, connection net.Conn,
dialer *net.Dialer, wg *sync.WaitGroup,
) {
defer wg.Done()
socksConn := &socksConn{
dialer: dialer,
username: s.username,
password: s.password,
clientConn: connection,
logger: s.logger,
}
err := socksConn.run(ctx)
udpErrCh := make(chan error)
go func() {
udpErrCh <- s.udpRouter.run(s.socksConnCtx)
}()
tcpErrCh := make(chan error)
go func() {
var wg sync.WaitGroup
defer wg.Wait()
dialer := &net.Dialer{}
for {
connection, err := s.tcpListener.Accept()
if err != nil {
s.logger.Infof("running socks connection: %s", err)
s.socksConnCancel() // stop ongoing TCP socks connections - no impact on UDP
tcpErrCh <- fmt.Errorf("accepting connection: %w", err)
return
}
}(s.socksConnCtx, connection, dialer, wg)
wg.Go(func() {
connection := connection // capture loop variable
socksConn := &socksConn{
dialer: dialer,
username: s.username,
password: s.password,
clientConn: connection,
udpRouter: s.udpRouter,
logger: s.logger,
}
err := socksConn.run(s.socksConnCtx)
if err != nil {
s.logger.Infof("running socks connection: %s", err)
}
})
}
}()
select {
case <-stop:
s.listening.Store(false)
var errs []error
err := s.tcpListener.Close()
if err != nil {
errs = append(errs, fmt.Errorf("closing TCP listener: %w", err))
}
// stop ongoing TCP socks connections. This impacts the udpRouter run error when it is being closed.
s.socksConnCancel()
<-tcpErrCh // wait for TCP server to stop
err = s.udpRouter.close()
if err != nil {
errs = append(errs, fmt.Errorf("closing UDP router: %w", err))
}
<-udpErrCh // wait for UDP router to stop
if len(errs) > 0 {
// Only write to the done channel if the [server.Stop] method is waiting to read from it
done <- errors.Join(errs...)
}
// If no error, the done channel is closed so the error is effectively `nil`
// Note: do NOT write an error the runError channel, since we are stopping the server gracefully.
case err := <-udpErrCh:
_ = s.tcpListener.Close() // stop accepting new TCP connections
s.socksConnCancel() // stop ongoing TCP socks connections
<-tcpErrCh // wait for TCP server to stop
runErrCh <- fmt.Errorf("running UDP router: %w", err)
case err := <-tcpErrCh:
s.socksConnCancel()
_ = s.udpRouter.close() // stop UDP router
<-udpErrCh // wait for UDP router to stop
runErrCh <- fmt.Errorf("running TCP server: %w", err)
}
}
func (s *server) Stop() (err error) {
s.stopping.Store(true)
err = s.stop()
<-s.done // wait for run goroutine to finish
s.stopping.Store(false)
return err
}
func (s *server) stop() error {
s.listening.Store(false)
err := s.listener.Close()
s.socksConnCancel() // stop ongoing socks connections
return err
close(s.stopCh)
return <-s.done
}
func (s *server) listeningAddress() net.Addr {
if s.listening.Load() {
return s.listener.Addr()
return s.tcpListener.Addr()
}
return nil
}
+249 -1
View File
@@ -10,6 +10,7 @@ import (
"net/netip"
"strconv"
"strings"
"sync"
)
var (
@@ -23,6 +24,7 @@ type socksConn struct {
username string
password string
clientConn net.Conn
udpRouter *udpRouter
logger Logger
}
@@ -109,11 +111,29 @@ func (c *socksConn) handleRequest(ctx context.Context) error {
c.encodeFailedResponse(c.clientConn, socksVersion, generalServerFailure)
return err
}
if request.command != connect {
switch request.command {
case connect:
err = c.handleConnectRequest(ctx, socksVersion, request)
if err != nil {
return fmt.Errorf("handling %s request: %w", request.command, err)
}
return nil
case udpAssociate:
err = c.handleUDPAssociateRequest(ctx, socksVersion, request)
if err != nil {
return fmt.Errorf("handling %s request: %w", request.command, err)
}
return nil
default:
c.encodeFailedResponse(c.clientConn, socksVersion, commandNotSupported)
return fmt.Errorf("command %s is not supported", request.command)
}
}
func (c *socksConn) handleConnectRequest(ctx context.Context,
socksVersion byte, request request,
) error {
destinationAddress := net.JoinHostPort(request.destination, fmt.Sprint(request.port))
destinationConn, err := c.dialer.DialContext(ctx, "tcp", destinationAddress)
if err != nil {
@@ -176,6 +196,234 @@ func (c *socksConn) handleRequest(ctx context.Context) error {
}
}
func (c *socksConn) handleUDPAssociateRequest(ctx context.Context,
socksVersion byte, request request,
) error {
expectedAddrPort, err := udpAssociateExpectedClientEndpoint(request)
if err != nil {
c.encodeFailedResponse(c.clientConn, socksVersion, addressTypeNotSupported)
return fmt.Errorf("deriving expected client address and port from request: %w", err)
}
bindAddress, bindPort, bindAddrType, err := c.udpAssociationAddresses()
if err != nil {
c.encodeFailedResponse(c.clientConn, socksVersion, generalServerFailure)
return fmt.Errorf("getting udp association addresses: %w", err)
}
association, err := c.udpRouter.registerAssociation(c.clientConn, expectedAddrPort)
if err != nil {
c.encodeFailedResponse(c.clientConn, socksVersion, generalServerFailure)
return fmt.Errorf("registering udp association: %w", err)
}
defer c.udpRouter.unregisterAssociation(association)
err = c.encodeSuccessResponse(c.clientConn, socksVersion, succeeded,
bindAddrType, bindAddress, bindPort)
if err != nil {
c.encodeFailedResponse(c.clientConn, socksVersion, generalServerFailure)
return fmt.Errorf("writing successful %s response: %w", udpAssociate, err)
}
associationCtx, associationCancel := context.WithCancel(ctx)
defer associationCancel()
var wg sync.WaitGroup
wg.Go(func() {
c.udpRouter.runAssociationHandler(associationCtx, association)
})
wg.Go(func() {
_, _ = io.Copy(io.Discard, c.clientConn)
associationCancel()
})
<-associationCtx.Done()
wg.Wait()
return nil
}
func udpAssociateExpectedClientEndpoint(request request) (expectedAddrPort netip.AddrPort, err error) {
switch request.addressType {
case ipv4, ipv6:
expectedClientAddress, parseErr := netip.ParseAddr(request.destination)
if parseErr != nil {
return netip.AddrPort{}, fmt.Errorf("parsing destination address: %w", parseErr)
}
expectedClientAddress = expectedClientAddress.Unmap()
if !expectedClientAddress.IsUnspecified() {
return netip.AddrPortFrom(expectedClientAddress, request.port), nil
}
return netip.AddrPortFrom(netip.Addr{}, request.port), nil
case domainName:
// For UDP associate, client endpoint matching is based on observed UDP source
// address/port. A hostname is not directly matchable at this stage, so we
// ignore the domain name request destination entirely.
return netip.AddrPortFrom(netip.Addr{}, request.port), nil
default:
return netip.AddrPort{}, fmt.Errorf("address type %d is not supported", request.addressType)
}
}
func (c *socksConn) udpAssociationAddresses() (bindAddress string,
bindPort uint16, bindAddrType addrType, err error,
) {
localAddress := c.udpRouter.localAddress().String()
host, portString, err := net.SplitHostPort(localAddress)
if err != nil {
return "", 0, 0, fmt.Errorf("splitting local address: %w", err)
}
port, err := strconv.ParseUint(portString, 10, 16)
if err != nil {
return "", 0, 0, fmt.Errorf("parsing local port: %w", err)
}
bindAddress = host
bindPort = uint16(port)
if isUnspecifiedIPAddress(bindAddress) {
controlLocalAddress := c.clientConn.LocalAddr().String()
controlLocalHost, _, splitErr := net.SplitHostPort(controlLocalAddress)
if splitErr != nil {
return "", 0, 0, fmt.Errorf("splitting control connection local address: %w", splitErr)
}
bindAddress = controlLocalHost
}
ipAddress := net.ParseIP(bindAddress)
if ipAddress == nil {
bindAddrType = domainName
return bindAddress, bindPort, bindAddrType, nil
}
if ipAddress.To4() != nil {
bindAddrType = ipv4
} else {
bindAddrType = ipv6
}
return bindAddress, bindPort, bindAddrType, nil
}
func isUnspecifiedIPAddress(address string) bool {
ipAddress, err := netip.ParseAddr(address)
if err != nil {
return false
}
return ipAddress.IsUnspecified()
}
func decodeUDPDatagram(packet []byte) (destination string, payload []byte, err error) {
const minimumPacketLength = 4
if len(packet) < minimumPacketLength {
return "", nil, fmt.Errorf("packet is too short: %d", len(packet))
}
if packet[0] != 0 || packet[1] != 0 {
return "", nil, fmt.Errorf("reserved bytes are invalid: %x %x", packet[0], packet[1])
}
if packet[2] != 0 {
return "", nil, fmt.Errorf("fragmentation is not supported")
}
offset := 3
addressType := addrType(packet[offset])
offset++
switch addressType {
case ipv4:
const ipv4Length = 4
if len(packet) < offset+ipv4Length+2 {
return "", nil, fmt.Errorf("packet is too short for IPv4 address")
}
var ip [ipv4Length]byte
copy(ip[:], packet[offset:offset+ipv4Length])
destination = netip.AddrFrom4(ip).String()
offset += ipv4Length
case ipv6:
const ipv6Length = 16
if len(packet) < offset+ipv6Length+2 {
return "", nil, fmt.Errorf("packet is too short for IPv6 address")
}
var ip [ipv6Length]byte
copy(ip[:], packet[offset:offset+ipv6Length])
destination = netip.AddrFrom16(ip).String()
offset += ipv6Length
case domainName:
if len(packet) < offset+1 {
return "", nil, fmt.Errorf("packet is too short for domain name length")
}
domainNameLength := int(packet[offset])
offset++
if len(packet) < offset+domainNameLength+2 {
return "", nil, fmt.Errorf("packet is too short for domain name")
}
destination = string(packet[offset : offset+domainNameLength])
offset += domainNameLength
default:
return "", nil, fmt.Errorf("address type is not supported: %d", addressType)
}
port := binary.BigEndian.Uint16(packet[offset : offset+2])
destination = net.JoinHostPort(destination, fmt.Sprint(port))
offset += 2
payload = packet[offset:]
return destination, payload, nil
}
func encodeUDPDatagramToBuffer(writer io.Writer, sourceAddrPort netip.AddrPort,
payload []byte,
) error {
address := sourceAddrPort.Addr()
if !address.IsValid() {
return errors.New("source address is not valid")
}
err := writeUDPDatagramSourceAddress(writer, address)
if err != nil {
return fmt.Errorf("writing source address: %w", err)
}
var portBytes [2]byte
binary.BigEndian.PutUint16(portBytes[:], sourceAddrPort.Port())
_, err = writer.Write(portBytes[:])
if err != nil {
return fmt.Errorf("writing destination port: %w", err)
}
_, err = writer.Write(payload)
if err != nil {
return fmt.Errorf("writing payload: %w", err)
}
return nil
}
func writeUDPDatagramSourceAddress(writer io.Writer, address netip.Addr) error {
var addrType addrType
var addressBytes []byte
switch {
case address.Is4():
addrType = ipv4
array := address.As4()
addressBytes = array[:]
case address.Is6():
addrType = ipv6
array := address.As16()
addressBytes = array[:]
default:
return fmt.Errorf("address type is not supported: %v", address)
}
_, err := writer.Write([]byte{0, 0, 0, byte(addrType)})
if err != nil {
return fmt.Errorf("writing header: %w", err)
}
_, err = writer.Write(addressBytes)
if err != nil {
return fmt.Errorf("writing IP address: %w", err)
}
return nil
}
// See https://datatracker.ietf.org/doc/html/rfc1928#section-3
func verifyFirstNegotiation(reader io.Reader, requiredMethod authMethod) error {
const headerLength = 2 // version + nMethods bytes
+502 -38
View File
@@ -2,9 +2,13 @@ package socks5
import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"net/netip"
"strconv"
"strings"
"testing"
@@ -96,6 +100,178 @@ func TestServerProxy(t *testing.T) {
}
}
func TestServerProxyTCPAndUDPParallel(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
username string
password string
}{
"no_auth": {},
"with_auth": {
username: "user",
password: "pass",
},
}
for name, testCase := range testCases {
t.Run(name, func(t *testing.T) {
t.Parallel()
backendTCPListener, err := (&net.ListenConfig{}).Listen(t.Context(), "tcp", "127.0.0.1:0")
require.NoError(t, err)
backendTCPConnChannel := make(chan net.Conn, 1)
go func() {
connection, err := backendTCPListener.Accept()
if err != nil {
return
}
backendTCPConnChannel <- connection
}()
backendUDPPacketConn, err := (&net.ListenConfig{}).ListenPacket(t.Context(), "udp", "127.0.0.1:0")
require.NoError(t, err)
server := newServer(Settings{
Username: testCase.username,
Password: testCase.password,
Address: "127.0.0.1:0",
Logger: noopLogger{},
})
_, err = server.Start(t.Context())
require.NoError(t, err)
t.Cleanup(func() {
_ = server.Stop()
_ = backendTCPListener.Close()
_ = backendUDPPacketConn.Close()
})
clientTCPConn := dialSOCKS5(t, server.listeningAddress().String(),
backendTCPListener.Addr().String(), testCase.username, testCase.password)
defer clientTCPConn.Close()
backendTCPConn := <-backendTCPConnChannel
defer backendTCPConn.Close()
udpControlConn, clientUDPConn := dialSOCKS5UDPAssociate(t,
server.listeningAddress().String(), testCase.username, testCase.password)
defer udpControlConn.Close()
defer clientUDPConn.Close()
tcpErrCh := make(chan error, 1)
go func() {
tcpErrCh <- runTCPProxyRoundTrip(clientTCPConn, backendTCPConn)
}()
udpErrCh := make(chan error, 1)
go func() {
udpErrCh <- runUDPProxyRoundTrip(t.Context(), clientUDPConn, backendUDPPacketConn)
}()
err = <-tcpErrCh
require.NoError(t, err)
err = <-udpErrCh
require.NoError(t, err)
})
}
}
func runTCPProxyRoundTrip(clientTCPConn net.Conn, backendTCPConn net.Conn) error {
clientMessage := []byte("hello from client")
_, err := clientTCPConn.Write(clientMessage)
if err != nil {
return err
}
received := make([]byte, len(clientMessage))
_, err = io.ReadFull(backendTCPConn, received)
if err != nil {
return err
}
if !bytes.Equal(clientMessage, received) {
return errors.New("backend did not receive expected TCP payload")
}
backendMessage := []byte("hello from backend")
_, err = backendTCPConn.Write(backendMessage)
if err != nil {
return err
}
receivedByClient := make([]byte, len(backendMessage))
_, err = io.ReadFull(clientTCPConn, receivedByClient)
if err != nil {
return err
}
if !bytes.Equal(backendMessage, receivedByClient) {
return errors.New("client did not receive expected TCP payload")
}
return nil
}
func runUDPProxyRoundTrip(ctx context.Context, clientUDPConn *net.UDPConn, backendUDPPacketConn net.PacketConn) error {
udpPayload := []byte("hello from udp client")
udpRequest, err := makeSOCKS5UDPDatagram(backendUDPPacketConn.LocalAddr().String(), udpPayload)
if err != nil {
return err
}
_, err = clientUDPConn.Write(udpRequest)
if err != nil {
return err
}
deadline, hasDeadline := ctx.Deadline()
if hasDeadline {
err = backendUDPPacketConn.SetReadDeadline(deadline)
if err != nil {
return fmt.Errorf("setting read deadline on backend connection: %w", err)
}
}
const bufferSize = 512
backendReadBuffer := make([]byte, bufferSize)
packetLength, proxyAddress, err := backendUDPPacketConn.ReadFrom(backendReadBuffer)
if err != nil {
return err
}
if !bytes.Equal(udpPayload, backendReadBuffer[:packetLength]) {
return errors.New("backend did not receive expected UDP payload")
}
backendUDPReply := []byte("hello from udp backend")
_, err = backendUDPPacketConn.WriteTo(backendUDPReply, proxyAddress)
if err != nil {
return err
}
if hasDeadline {
err = clientUDPConn.SetReadDeadline(deadline)
if err != nil {
return fmt.Errorf("setting read deadline on client connection: %w", err)
}
}
udpResponseBuffer := make([]byte, 1024)
responseLength, err := clientUDPConn.Read(udpResponseBuffer)
if err != nil {
return err
}
destinationAddress, udpResponsePayload, err := parseSOCKS5UDPDatagram(udpResponseBuffer[:responseLength])
if err != nil {
return err
}
if !bytes.Equal(backendUDPReply, udpResponsePayload) {
return errors.New("client did not receive expected UDP payload")
}
if destinationAddress != backendUDPPacketConn.LocalAddr().String() {
return errors.New("udp response destination address mismatch")
}
return nil
}
// dialSOCKS5 performs the full SOCKS5 handshake (with optional username/password
// subnegotiation) and returns a connected net.Conn ready for data exchange.
func dialSOCKS5(t *testing.T, proxyAddr, targetAddr, username, password string) net.Conn {
@@ -109,6 +285,55 @@ func dialSOCKS5(t *testing.T, proxyAddr, targetAddr, username, password string)
conn, err := (&net.Dialer{}).DialContext(t.Context(), "tcp", proxyAddr)
require.NoError(t, err)
negotiateSOCKS5(t, conn, username, password)
var connectRequest []byte
if ip := net.ParseIP(host).To4(); ip != nil {
connectRequest = []byte{socks5Version, byte(connect), 0, byte(ipv4)}
connectRequest = append(connectRequest, ip...)
} else {
connectRequest = []byte{socks5Version, byte(connect), 0, byte(domainName), byte(len(host))}
connectRequest = append(connectRequest, []byte(host)...)
}
connectRequest = binary.BigEndian.AppendUint16(connectRequest, uint16(targetPort)) //nolint:gosec
_, err = conn.Write(connectRequest)
require.NoError(t, err)
_, err = readSOCKS5ResponseAddress(t, conn)
require.NoError(t, err)
return conn
}
func dialSOCKS5UDPAssociate(t *testing.T, proxyAddr, username, password string) (net.Conn, *net.UDPConn) {
t.Helper()
controlConn, err := (&net.Dialer{}).DialContext(t.Context(), "tcp", proxyAddr)
require.NoError(t, err)
negotiateSOCKS5(t, controlConn, username, password)
udpAssociateRequest := []byte{socks5Version, byte(udpAssociate), 0, byte(ipv4), 0, 0, 0, 0, 0, 0}
_, err = controlConn.Write(udpAssociateRequest)
require.NoError(t, err)
udpProxyAddress, err := readSOCKS5ResponseAddress(t, controlConn)
require.NoError(t, err)
udpProxyResolvedAddress, err := net.ResolveUDPAddr("udp", udpProxyAddress)
require.NoError(t, err)
udpConn, err := net.DialUDP("udp", nil, udpProxyResolvedAddress)
require.NoError(t, err)
return controlConn, udpConn
}
func negotiateSOCKS5(t *testing.T, conn net.Conn, username, password string) {
t.Helper()
var err error
var method authMethod
if username != "" || password != "" {
method = authUsernamePassword
@@ -138,45 +363,146 @@ func dialSOCKS5(t *testing.T, proxyAddr, targetAddr, username, password string)
require.Equal(t, authUsernamePasswordSubNegotiation1, subnegResp[0])
require.Equal(t, byte(0), subnegResp[1])
}
}
var connectRequest []byte
if ip := net.ParseIP(host).To4(); ip != nil {
connectRequest = []byte{socks5Version, byte(connect), 0, byte(ipv4)}
connectRequest = append(connectRequest, ip...)
} else {
connectRequest = []byte{socks5Version, byte(connect), 0, byte(domainName), byte(len(host))}
connectRequest = append(connectRequest, []byte(host)...)
}
connectRequest = binary.BigEndian.AppendUint16(connectRequest, uint16(targetPort)) //nolint:gosec
_, err = conn.Write(connectRequest)
require.NoError(t, err)
func readSOCKS5ResponseAddress(t *testing.T, conn net.Conn) (address string, err error) {
t.Helper()
var responseHeader [4]byte
_, err = io.ReadFull(conn, responseHeader[:])
require.NoError(t, err)
require.Equal(t, socks5Version, responseHeader[0])
require.Equal(t, byte(succeeded), responseHeader[1])
// Consume BND.ADDR and BND.PORT (their values are irrelevant to the caller).
switch addrType(responseHeader[3]) {
case ipv4:
var addrPort [net.IPv4len + 2]byte
_, err = io.ReadFull(conn, addrPort[:])
require.NoError(t, err)
case ipv6:
var addrPort [net.IPv6len + 2]byte
_, err = io.ReadFull(conn, addrPort[:])
require.NoError(t, err)
case domainName:
var lenBuf [1]byte
_, err = io.ReadFull(conn, lenBuf[:])
require.NoError(t, err)
addrPort := make([]byte, int(lenBuf[0])+2)
_, err = io.ReadFull(conn, addrPort)
require.NoError(t, err)
if err != nil {
return "", err
}
if responseHeader[0] != socks5Version {
return "", errors.New("version mismatch")
}
if responseHeader[1] != byte(succeeded) {
return "", errors.New("request was not successful")
}
return conn
var host string
switch addrType(responseHeader[3]) {
case ipv4:
addressAndPort := make([]byte, net.IPv4len+2)
_, err = io.ReadFull(conn, addressAndPort)
if err != nil {
return "", err
}
host = net.IP(addressAndPort[:net.IPv4len]).String()
port := binary.BigEndian.Uint16(addressAndPort[net.IPv4len:])
return net.JoinHostPort(host, strconv.FormatUint(uint64(port), 10)), nil
case ipv6:
addressAndPort := make([]byte, net.IPv6len+2)
_, err = io.ReadFull(conn, addressAndPort)
if err != nil {
return "", err
}
host = net.IP(addressAndPort[:net.IPv6len]).String()
port := binary.BigEndian.Uint16(addressAndPort[net.IPv6len:])
return net.JoinHostPort(host, strconv.FormatUint(uint64(port), 10)), nil
case domainName:
var lengthBuffer [1]byte
_, err = io.ReadFull(conn, lengthBuffer[:])
if err != nil {
return "", err
}
domainAndPort := make([]byte, int(lengthBuffer[0])+2)
_, err = io.ReadFull(conn, domainAndPort)
if err != nil {
return "", err
}
host = string(domainAndPort[:len(domainAndPort)-2])
port := binary.BigEndian.Uint16(domainAndPort[len(domainAndPort)-2:])
return net.JoinHostPort(host, strconv.FormatUint(uint64(port), 10)), nil
default:
return "", errors.New("unknown address type")
}
}
func makeSOCKS5UDPDatagram(targetAddress string, payload []byte) ([]byte, error) {
host, portString, err := net.SplitHostPort(targetAddress)
if err != nil {
return nil, err
}
port, err := strconv.ParseUint(portString, 10, 16)
if err != nil {
return nil, err
}
datagram := []byte{0, 0, 0}
ipAddress := net.ParseIP(host)
if ipAddress != nil {
if ipAddress.To4() != nil {
datagram = append(datagram, byte(ipv4))
datagram = append(datagram, ipAddress.To4()...)
} else {
datagram = append(datagram, byte(ipv6))
datagram = append(datagram, ipAddress.To16()...)
}
} else {
if len(host) > 255 {
return nil, errors.New("domain name too long")
}
datagram = append(datagram, byte(domainName), byte(len(host)))
datagram = append(datagram, []byte(host)...)
}
datagram = binary.BigEndian.AppendUint16(datagram, uint16(port))
datagram = append(datagram, payload...)
return datagram, nil
}
func parseSOCKS5UDPDatagram(datagram []byte) (destinationAddress string, payload []byte, err error) {
if len(datagram) < 4 {
return "", nil, errors.New("datagram too short")
}
if datagram[0] != 0 || datagram[1] != 0 {
return "", nil, errors.New("invalid reserved header")
}
if datagram[2] != 0 {
return "", nil, errors.New("fragments are not supported")
}
offset := 3
var host string
switch addrType(datagram[offset]) {
case ipv4:
offset++
if len(datagram) < offset+net.IPv4len+2 {
return "", nil, errors.New("datagram too short for IPv4")
}
host = net.IP(datagram[offset : offset+net.IPv4len]).String()
offset += net.IPv4len
case ipv6:
offset++
if len(datagram) < offset+net.IPv6len+2 {
return "", nil, errors.New("datagram too short for IPv6")
}
host = net.IP(datagram[offset : offset+net.IPv6len]).String()
offset += net.IPv6len
case domainName:
offset++
if len(datagram) < offset+1 {
return "", nil, errors.New("datagram too short for domain length")
}
domainLength := int(datagram[offset])
offset++
if len(datagram) < offset+domainLength+2 {
return "", nil, errors.New("datagram too short for domain")
}
host = string(datagram[offset : offset+domainLength])
offset += domainLength
default:
return "", nil, errors.New("unknown address type")
}
if len(datagram) < offset+2 {
return "", nil, errors.New("datagram too short for port")
}
port := binary.BigEndian.Uint16(datagram[offset : offset+2])
offset += 2
return net.JoinHostPort(host, strconv.FormatUint(uint64(port), 10)), datagram[offset:], nil
}
func Test_newServer(t *testing.T) {
@@ -224,7 +550,8 @@ func Test_Server_StartStop(t *testing.T) {
ctrl := gomock.NewController(t)
logger := NewMockLogger(ctrl)
logger.EXPECT().Infof("SOCKS5 server listening on %s", gomock.Any())
logger.EXPECT().Infof("SOCKS5 TCP server listening on %s", gomock.Any())
logger.EXPECT().Infof("SOCKS5 UDP server listening on %s", gomock.Any())
server := newServer(Settings{
Address: "127.0.0.1:0",
@@ -377,6 +704,70 @@ func Test_decodeRequest(t *testing.T) {
}
}
func Test_udpAssociateExpectedClientEndpoint(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
request request
expected netip.AddrPort
expectedErr string
}{
"ipv4_endpoint": {
request: request{
addressType: ipv4,
destination: "192.0.2.10",
port: 5555,
},
expected: netip.MustParseAddrPort("192.0.2.10:5555"),
},
"ipv4_unspecified_address": {
request: request{
addressType: ipv4,
destination: "0.0.0.0",
port: 6000,
},
expected: netip.AddrPortFrom(netip.Addr{}, 6000),
},
"domain_name_with_port": {
request: request{
addressType: domainName,
destination: "client.example",
port: 7000,
},
expected: netip.AddrPortFrom(netip.Addr{}, 7000),
},
"domain_name_without_port": {
request: request{
addressType: domainName,
destination: "client.example",
},
expected: netip.AddrPort{},
},
"unsupported_address_type": {
request: request{
addressType: 255,
},
expectedErr: "address type 255 is not supported",
},
}
for name, testCase := range testCases {
t.Run(name, func(t *testing.T) {
t.Parallel()
result, err := udpAssociateExpectedClientEndpoint(testCase.request)
if testCase.expectedErr != "" {
assert.ErrorContains(t, err, testCase.expectedErr)
return
}
assert.NoError(t, err)
assert.Equal(t, testCase.expected, result)
})
}
}
func Test_verifyFirstNegotiation(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
@@ -598,10 +989,6 @@ func Test_cmdType_String(t *testing.T) {
cmd: connect,
expectedName: "connect",
},
"bind": {
cmd: bind,
expectedName: "bind",
},
"udp_associate": {
cmd: udpAssociate,
expectedName: "UDP associate",
@@ -620,3 +1007,80 @@ func Test_cmdType_String(t *testing.T) {
})
}
}
func Test_socksConn_udpAssociationAddresses(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
routerAddress string
expectAddressFromConn bool
expectedAddress string
}{
"wildcard_router_address_uses_control_connection_local_ip": {
routerAddress: ":0",
expectAddressFromConn: true,
},
"concrete_router_address_is_kept": {
routerAddress: "127.0.0.1:0",
expectedAddress: "127.0.0.1",
},
}
for name, testCase := range testCases {
t.Run(name, func(t *testing.T) {
t.Parallel()
router, err := newUDPRouter(t.Context(), testCase.routerAddress, noopLogger{})
require.NoError(t, err)
t.Cleanup(func() {
err := router.close()
assert.NoError(t, err)
})
controlListener, err := (&net.ListenConfig{}).Listen(t.Context(), "tcp", "127.0.0.1:0")
require.NoError(t, err)
t.Cleanup(func() {
err := controlListener.Close()
assert.NoError(t, err)
})
acceptedConnCh := make(chan net.Conn, 1)
go func() {
acceptedConn, acceptErr := controlListener.Accept()
if acceptErr != nil {
return
}
acceptedConnCh <- acceptedConn
}()
clientControlConn, err := (&net.Dialer{}).DialContext(t.Context(), "tcp", controlListener.Addr().String())
require.NoError(t, err)
defer clientControlConn.Close()
serverControlConn := <-acceptedConnCh
defer serverControlConn.Close()
socksConnection := &socksConn{
clientConn: clientControlConn,
udpRouter: router,
}
bindAddress, bindPort, bindAddrType, err := socksConnection.udpAssociationAddresses()
require.NoError(t, err)
if testCase.expectAddressFromConn {
clientLocalHost, _, err := net.SplitHostPort(clientControlConn.LocalAddr().String())
require.NoError(t, err)
assert.Equal(t, clientLocalHost, bindAddress)
} else {
assert.Equal(t, testCase.expectedAddress, bindAddress)
}
_, routerPortString, err := net.SplitHostPort(router.localAddress().String())
require.NoError(t, err)
routerPort, err := strconv.ParseUint(routerPortString, 10, 16)
require.NoError(t, err)
assert.Equal(t, uint16(routerPort), bindPort)
assert.Equal(t, ipv4, bindAddrType)
})
}
}
+370
View File
@@ -0,0 +1,370 @@
package socks5
import (
"bytes"
"context"
"errors"
"fmt"
"net"
"net/netip"
"sync"
)
type udpAssociation struct {
id uint64
clientAddrPort netip.AddrPort
expectedAddrPort netip.AddrPort
controlConnAddr netip.Addr
packetCh chan *bytes.Buffer
}
type udpRouter struct {
logger Logger
listener net.PacketConn
mutex sync.Mutex
bufferPool sync.Pool
nextAssociationID uint64
clientAddrPortToAssociation map[netip.AddrPort]udpAssociation
clientIPToPendingAssociations map[netip.Addr][]udpAssociation
associationIDToClientAddrPort map[uint64]netip.AddrPort
}
const (
maxUDPPacketLength = 65535
maxSOCKS5UDPDatagramOverhead = 3 + 1 + 16 + 2
pooledUDPPacketBufferCapacity = maxUDPPacketLength + maxSOCKS5UDPDatagramOverhead
)
func newUDPRouter(ctx context.Context, address string, logger Logger) (router *udpRouter, err error) {
config := &net.ListenConfig{}
listener, err := config.ListenPacket(ctx, "udp", address)
if err != nil {
return nil, fmt.Errorf("UDP listening: %w", err)
}
return &udpRouter{
logger: logger,
listener: listener,
bufferPool: sync.Pool{
New: func() any {
return bytes.NewBuffer(make([]byte, 0, pooledUDPPacketBufferCapacity))
},
},
nextAssociationID: 1,
clientAddrPortToAssociation: make(map[netip.AddrPort]udpAssociation),
clientIPToPendingAssociations: make(map[netip.Addr][]udpAssociation),
associationIDToClientAddrPort: make(map[uint64]netip.AddrPort),
}, nil
}
func (r *udpRouter) localAddress() net.Addr {
return r.listener.LocalAddr()
}
func (r *udpRouter) close() error {
return r.listener.Close()
}
func (r *udpRouter) registerAssociation(controlConn net.Conn, expectedAddrPort netip.AddrPort) (udpAssociation, error) {
controlConnAddrPort, err := netip.ParseAddrPort(controlConn.RemoteAddr().String())
if err != nil {
return udpAssociation{}, fmt.Errorf("parsing control connection address: %w", err)
}
controlConnAddr := controlConnAddrPort.Addr().Unmap()
r.mutex.Lock()
defer r.mutex.Unlock()
const udpPacketChannelBuffer = 2
associationID := r.nextAssociationID
r.nextAssociationID++
association := udpAssociation{
id: associationID,
expectedAddrPort: expectedAddrPort,
controlConnAddr: controlConnAddr,
packetCh: make(chan *bytes.Buffer, udpPacketChannelBuffer),
}
if expectedAddrPort.Addr().IsValid() && expectedAddrPort.Port() != 0 {
association.clientAddrPort = expectedAddrPort
r.clientAddrPortToAssociation[association.clientAddrPort] = association
r.associationIDToClientAddrPort[association.id] = association.clientAddrPort
return association, nil
}
pendingAssociations := r.clientIPToPendingAssociations[controlConnAddr]
pendingAssociations = append(pendingAssociations, association)
r.clientIPToPendingAssociations[controlConnAddr] = pendingAssociations
return association, nil
}
func (r *udpRouter) unregisterAssociation(association udpAssociation) {
r.mutex.Lock()
defer r.mutex.Unlock()
clientAddrPort, hasClientAddress := r.associationIDToClientAddrPort[association.id]
if hasClientAddress {
delete(r.associationIDToClientAddrPort, association.id)
delete(r.clientAddrPortToAssociation, clientAddrPort)
}
pendingAssociations := r.clientIPToPendingAssociations[association.controlConnAddr]
for i, pendingAssociation := range pendingAssociations {
if pendingAssociation.id == association.id {
pendingAssociations = append(pendingAssociations[:i], pendingAssociations[i+1:]...)
break
}
}
if len(pendingAssociations) == 0 {
delete(r.clientIPToPendingAssociations, association.controlConnAddr)
} else {
r.clientIPToPendingAssociations[association.controlConnAddr] = pendingAssociations
}
}
func (r *udpRouter) run(ctx context.Context) error {
packetBuffer := make([]byte, maxUDPPacketLength)
for {
packetLength, sourceAddress, err := r.listener.ReadFrom(packetBuffer)
if err != nil {
if ctx.Err() != nil && errors.Is(err, net.ErrClosed) {
return nil
}
return fmt.Errorf("reading UDP packet: %w", err)
}
sourceAddrPort, err := netAddrToNetipAddrPort(sourceAddress)
if err != nil {
r.logger.Warnf("parsing source address: %s", err)
continue
}
buffer := r.bufferPool.Get().(*bytes.Buffer) //nolint:forcetypeassert
buffer.Reset()
_, err = buffer.Write(packetBuffer[:packetLength])
if err != nil {
r.bufferPool.Put(buffer)
r.logger.Warnf("buffering packet: %s", err)
continue
}
err = r.routePacket(sourceAddrPort, buffer)
if err != nil {
r.logger.Warnf("failed routing UDP packet: %s", err)
}
}
}
func (r *udpRouter) routePacket(sourceAddrPort netip.AddrPort, packet *bytes.Buffer) error {
r.mutex.Lock()
association, packetFromClient := r.findClientAssociation(sourceAddrPort)
r.mutex.Unlock()
if !packetFromClient {
r.bufferPool.Put(packet)
return nil
}
select {
case association.packetCh <- packet:
return nil
default:
r.bufferPool.Put(packet)
return errors.New("association packet queue full")
}
}
func (r *udpRouter) findClientAssociation(sourceAddrPort netip.AddrPort) (
association udpAssociation, ok bool,
) {
association, ok = r.clientAddrPortToAssociation[sourceAddrPort]
if ok {
return association, true
}
sourceAddr := sourceAddrPort.Addr()
pendingAssociations := r.clientIPToPendingAssociations[sourceAddr]
if len(pendingAssociations) == 0 {
return udpAssociation{}, false
}
index := -1
for i, pendingAssociation := range pendingAssociations {
if matchesExpectedClientEndpoint(pendingAssociation, sourceAddrPort) {
association = pendingAssociation
index = i
break
}
}
if index == -1 {
return udpAssociation{}, false
}
r.clientIPToPendingAssociations[sourceAddr] = append(pendingAssociations[:index], pendingAssociations[index+1:]...)
if len(r.clientIPToPendingAssociations[sourceAddr]) == 0 {
delete(r.clientIPToPendingAssociations, sourceAddr)
}
association.clientAddrPort = sourceAddrPort
r.clientAddrPortToAssociation[sourceAddrPort] = association
r.associationIDToClientAddrPort[association.id] = sourceAddrPort
return association, true
}
func matchesExpectedClientEndpoint(association udpAssociation, sourceAddrPort netip.AddrPort) bool {
switch {
case association.expectedAddrPort.Addr().IsValid() && sourceAddrPort.Addr() != association.expectedAddrPort.Addr():
return false
case association.expectedAddrPort.Port() != 0 && sourceAddrPort.Port() != association.expectedAddrPort.Port():
return false
}
return true
}
func (r *udpRouter) clientAddrPortForAssociation(associationID uint64) (
clientAddrPort netip.AddrPort, ok bool,
) {
r.mutex.Lock()
defer r.mutex.Unlock()
clientAddrPort, ok = r.associationIDToClientAddrPort[associationID]
return clientAddrPort, ok
}
func (r *udpRouter) runAssociationHandler(ctx context.Context, association udpAssociation) {
config := &net.ListenConfig{}
socket, err := config.ListenPacket(ctx, "udp", ":0")
if err != nil {
r.logger.Warnf("creating per-association UDP socket: %s", err)
return
}
defer socket.Close()
go closeSocketOnContextDone(ctx, socket)
packetBuffer := make([]byte, maxUDPPacketLength)
forwardDoneCh := make(chan struct{})
go r.forwardClientPackets(ctx, socket, association.packetCh, forwardDoneCh)
for {
packetLength, sourceAddress, err := socket.ReadFrom(packetBuffer)
if err != nil {
if ctx.Err() != nil || errors.Is(err, net.ErrClosed) {
<-forwardDoneCh
return
}
r.logger.Warnf("reading from per-association UDP socket: %s", err)
continue
}
sourceAddrPort, err := netAddrToNetipAddrPort(sourceAddress)
if err != nil {
r.logger.Warnf("parsing source address from destination: %s", err)
continue
}
buffer := r.bufferPool.Get().(*bytes.Buffer) //nolint:forcetypeassert
buffer.Reset()
err = encodeUDPDatagramToBuffer(buffer, sourceAddrPort, packetBuffer[:packetLength])
if err != nil {
r.bufferPool.Put(buffer)
r.logger.Warnf("encoding response datagram: %s", err)
continue
}
clientAddrPort, found := r.clientAddrPortForAssociation(association.id)
if !found {
r.bufferPool.Put(buffer)
r.logger.Warnf("client address not found for association id %d", association.id)
continue
}
clientUDPAddress := &net.UDPAddr{
IP: clientAddrPort.Addr().AsSlice(),
Port: int(clientAddrPort.Port()),
}
_, err = r.listener.WriteTo(buffer.Bytes(), clientUDPAddress)
r.bufferPool.Put(buffer)
if err != nil {
r.logger.Warnf("writing response to client: %s", err)
}
}
}
func closeSocketOnContextDone(ctx context.Context, socket net.PacketConn) {
<-ctx.Done()
_ = socket.Close()
}
func (r *udpRouter) forwardClientPackets(ctx context.Context, socket net.PacketConn,
packetCh <-chan *bytes.Buffer, done chan<- struct{},
) {
defer close(done)
for {
select {
case <-ctx.Done():
return
case buffer, ok := <-packetCh:
if !ok {
return
}
err := r.writeClientPacketToDestination(ctx, socket, buffer)
r.bufferPool.Put(buffer)
if err != nil {
r.logger.Warnf("forwarding client packet to destination: %s", err)
}
}
}
}
func (r *udpRouter) writeClientPacketToDestination(ctx context.Context,
socket net.PacketConn, packet *bytes.Buffer,
) error {
destination, payload, err := decodeUDPDatagram(packet.Bytes())
if err != nil {
return fmt.Errorf("decoding UDP datagram: %w", err)
}
host, portStr, err := net.SplitHostPort(destination)
if err != nil {
return fmt.Errorf("splitting destination host and port: %w", err)
}
if _, err := netip.ParseAddr(host); err != nil { // domain name
addrs, err := net.DefaultResolver.LookupHost(ctx, host)
if err != nil {
return fmt.Errorf("resolving destination host: %w", err)
}
if len(addrs) == 0 {
return fmt.Errorf("resolving destination host: no addresses found for %q", host)
}
destination = net.JoinHostPort(addrs[0], portStr)
}
resolvedDestinationUDPAddress, err := net.ResolveUDPAddr("udp", destination)
if err != nil {
return fmt.Errorf("resolving destination UDP address: %w", err)
}
_, err = socket.WriteTo(payload, resolvedDestinationUDPAddress)
if err != nil && ctx.Err() == nil {
return fmt.Errorf("writing payload to destination: %w", err)
}
return nil
}
func netAddrToNetipAddrPort(addr net.Addr) (netip.AddrPort, error) {
addrPort, err := netip.ParseAddrPort(addr.String())
if err != nil {
return netip.AddrPort{}, fmt.Errorf("parsing address: %w", err)
}
return netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port()), nil
}
@@ -0,0 +1,164 @@
//go:build integration
package socks5
import (
"bytes"
"context"
"math/rand/v2"
"net"
"net/netip"
"strconv"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_udpRouter_ResolveGithubFromCloudflareDNS(t *testing.T) {
t.Parallel()
ctx := t.Context()
var cancel context.CancelFunc
deadline, hasDeadline := ctx.Deadline()
if hasDeadline {
const deadlineBuffer = 500 * time.Millisecond
deadline = deadline.Add(-deadlineBuffer)
} else {
const defaultTimeout = 10 * time.Second
deadline = time.Now().Add(defaultTimeout)
}
ctx, cancel = context.WithDeadline(ctx, deadline)
ctrl := gomock.NewController(t)
logger := NewMockLogger(ctrl)
router, err := newUDPRouter(ctx, "127.0.0.1:0", logger)
require.NoError(t, err)
routerRunErrCh := make(chan error)
go func() {
routerRunErrCh <- router.run(ctx)
}()
t.Cleanup(func() {
cancel()
err := router.close()
assert.NoError(t, err, "closing router")
runErr := <-routerRunErrCh
assert.NoError(t, runErr)
})
controlListener, err := (&net.ListenConfig{}).Listen(ctx, "tcp", "127.0.0.1:0")
require.NoError(t, err)
t.Cleanup(func() {
err := controlListener.Close()
assert.NoError(t, err, "closing control listener")
})
acceptedConnCh := make(chan net.Conn)
go func() {
acceptedConn, acceptErr := controlListener.Accept()
assert.NoError(t, acceptErr, "accepting control connection")
if acceptErr != nil {
return
}
acceptedConnCh <- acceptedConn
}()
clientControlConn, err := (&net.Dialer{}).DialContext(ctx, "tcp", controlListener.Addr().String())
require.NoError(t, err)
t.Cleanup(func() {
err = clientControlConn.Close()
assert.NoError(t, err, "closing client control connection")
})
serverControlConn := <-acceptedConnCh
t.Cleanup(func() {
err := serverControlConn.Close()
assert.NoError(t, err, "closing server control connection")
})
association, err := router.registerAssociation(serverControlConn, netip.AddrPort{})
require.NoError(t, err)
t.Cleanup(func() {
router.unregisterAssociation(association)
})
associationCtx, associationCancel := context.WithCancel(ctx)
handlerDoneCh := make(chan struct{})
go func() {
router.runAssociationHandler(associationCtx, association)
close(handlerDoneCh)
}()
t.Cleanup(func() {
associationCancel()
<-handlerDoneCh
})
udpRouterAddress, err := net.ResolveUDPAddr("udp", router.localAddress().String())
require.NoError(t, err)
clientUDPConn, err := net.DialUDP("udp", nil, udpRouterAddress)
require.NoError(t, err)
t.Cleanup(func() {
err := clientUDPConn.Close()
assert.NoError(t, err, "closing client UDP connection")
})
queryID := uint16(rand.Uint32()) //nolint:gosec
dnsRequest := &dns.Msg{
MsgHdr: dns.MsgHdr{
Id: queryID,
RecursionDesired: true,
},
Question: []dns.Question{{
Name: dns.Fqdn("github.com"),
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}},
}
dnsQuery, err := dnsRequest.Pack()
require.NoError(t, err)
targetAddrPort := netip.MustParseAddrPort("1.1.1.1:53")
socksDatagramBuffer := bytes.NewBuffer(nil)
err = encodeUDPDatagramToBuffer(socksDatagramBuffer, targetAddrPort, dnsQuery)
require.NoError(t, err)
socksDatagram := socksDatagramBuffer.Bytes()
err = clientUDPConn.SetDeadline(deadline)
require.NoError(t, err)
_, err = clientUDPConn.Write(socksDatagram)
require.NoError(t, err)
responseBuffer := make([]byte, maxUDPPacketLength)
responseLength, err := clientUDPConn.Read(responseBuffer)
require.NoError(t, err)
responseDestination, responsePayload, err := decodeUDPDatagram(responseBuffer[:responseLength])
require.NoError(t, err)
responseHost, responsePortString, err := net.SplitHostPort(responseDestination)
require.NoError(t, err)
responsePort, err := strconv.ParseUint(responsePortString, 10, 16)
require.NoError(t, err)
assert.Equal(t, uint64(53), responsePort)
assert.NotEmpty(t, responseHost)
dnsResponse := new(dns.Msg)
err = dnsResponse.Unpack(responsePayload)
require.NoError(t, err)
assert.Equal(t, queryID, dnsResponse.Id)
assert.True(t, dnsResponse.Response)
assert.Equal(t, dns.RcodeSuccess, dnsResponse.Rcode)
require.NotEmpty(t, dnsResponse.Question)
assert.Equal(t, dns.Fqdn("github.com"), dnsResponse.Question[0].Name)
assert.Equal(t, dns.TypeA, dnsResponse.Question[0].Qtype)
assert.NotEmpty(t, dnsResponse.Answer)
require.NoError(t, err)
}
+12 -4
View File
@@ -1,15 +1,23 @@
package storage
import (
"slices"
"net/netip"
"github.com/qdm12/gluetun/internal/models"
)
func copyServer(server models.Server) (serverCopy models.Server) {
serverCopy = server
serverCopy.IPs = slices.Clone(server.IPs)
serverCopy.PortsTCP = slices.Clone(server.PortsTCP)
serverCopy.PortsUDP = slices.Clone(server.PortsUDP)
serverCopy.IPs = copyIPs(server.IPs)
return serverCopy
}
func copyIPs(toCopy []netip.Addr) (copied []netip.Addr) {
if toCopy == nil {
return nil
}
copied = make([]netip.Addr, len(toCopy))
copy(copied, toCopy)
return copied
}
+39 -5
View File
@@ -21,9 +21,43 @@ func Test_copyServer(t *testing.T) {
assert.Equal(t, server, serverCopy)
// Check for mutation
serverCopy.IPs[0] = netip.AddrFrom4([4]byte{9, 9, 9, 9})
serverCopy.PortsTCP = []uint16{80}
serverCopy.PortsUDP = []uint16{53}
assert.NotEqual(t, server.IPs, serverCopy.IPs)
assert.NotEqual(t, server.PortsTCP, serverCopy.PortsTCP)
assert.NotEqual(t, server.PortsUDP, serverCopy.PortsUDP)
assert.NotEqual(t, server, serverCopy)
}
func Test_copyIPs(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
toCopy []netip.Addr
copied []netip.Addr
}{
"nil": {},
"empty": {
toCopy: []netip.Addr{},
copied: []netip.Addr{},
},
"single IP": {
toCopy: []netip.Addr{netip.AddrFrom4([4]byte{1, 1, 1, 1})},
copied: []netip.Addr{netip.AddrFrom4([4]byte{1, 1, 1, 1})},
},
"two IPs": {
toCopy: []netip.Addr{netip.AddrFrom4([4]byte{1, 1, 1, 1}), netip.AddrFrom4([4]byte{2, 2, 2, 2})},
copied: []netip.Addr{netip.AddrFrom4([4]byte{1, 1, 1, 1}), netip.AddrFrom4([4]byte{2, 2, 2, 2})},
},
}
for name, testCase := range testCases {
t.Run(name, func(t *testing.T) {
t.Parallel()
copied := copyIPs(testCase.toCopy)
assert.Equal(t, testCase.copied, copied)
if len(copied) > 0 {
testCase.toCopy[0] = netip.AddrFrom4([4]byte{9, 9, 9, 9})
assert.NotEqual(t, testCase.toCopy[0], testCase.copied[0])
}
})
}
}
-33
View File
@@ -3,7 +3,6 @@ package storage
import (
"errors"
"fmt"
"slices"
"strings"
"github.com/qdm12/gluetun/internal/configuration/settings"
@@ -49,7 +48,6 @@ func (s *Storage) FilterServers(provider string, selection settings.ServerSelect
return servers, nil
}
//nolint:gocognit,gocyclo
func filterServer(server models.Server,
selection settings.ServerSelection,
) (filtered bool) {
@@ -92,11 +90,6 @@ func filterServer(server models.Server,
return true
}
if (*selection.Dedicated && !server.Dedicated) ||
(!*selection.Dedicated && server.Dedicated) {
return false
}
if filterByPossibilities(server.Country, selection.Countries) {
return true
}
@@ -129,14 +122,6 @@ func filterServer(server models.Server,
return true
}
serverPorts := server.PortsUDP
if server.VPN == vpn.OpenVPN && server.TCP {
serverPorts = server.PortsTCP
}
if filterByPorts(selection, serverPorts) {
return true
}
// TODO filter port forward server for PIA
return false
@@ -180,21 +165,3 @@ func filterByProtocol(selection settings.ServerSelection,
return (wantTCP && !serverTCP) || (wantUDP && !serverUDP)
}
}
func filterByPorts(selection settings.ServerSelection,
serverPorts []uint16,
) (filtered bool) {
if len(serverPorts) == 0 {
return false
}
customPort := *selection.OpenVPN.CustomPort
if selection.VPN == vpn.Wireguard {
customPort = *selection.Wireguard.EndpointPort
}
if customPort == 0 {
return false
}
return !slices.Contains(serverPorts, customPort)
}
+1 -10
View File
@@ -14,7 +14,7 @@ func commaJoin(slice []string) string {
return strings.Join(slice, ", ")
}
func noServerFoundError(selection settings.ServerSelection) (err error) { //nolint:gocyclo
func noServerFoundError(selection settings.ServerSelection) (err error) {
var messageParts []string
messageParts = append(messageParts, "VPN "+selection.VPN)
@@ -155,15 +155,6 @@ func noServerFoundError(selection settings.ServerSelection) (err error) { //noli
"target ip address "+targetIP.String())
}
customPort := *selection.OpenVPN.CustomPort
if selection.VPN == vpn.Wireguard {
customPort = *selection.Wireguard.EndpointPort
}
if customPort > 0 {
messageParts = append(messageParts,
fmt.Sprintf("%s endpoint port %d", selection.VPN, customPort))
}
message := "for " + strings.Join(messageParts, "; ")
return fmt.Errorf("no server found: %s", message)
+11 -23
View File
@@ -1,39 +1,26 @@
package storage
import (
"embed"
"encoding/json"
"fmt"
"path"
"path/filepath"
serversmodule "github.com/qdm12/gluetun-servers/pkg/servers"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
)
//go:embed servers.json
var allServersEmbedFS embed.FS
func parseHardcodedServers() (allServers models.AllServers) {
f, err := allServersEmbedFS.Open("servers.json")
if err != nil {
panic(err)
}
defer f.Close() // no-op
decoder := json.NewDecoder(f)
err = decoder.Decode(&allServers)
if err != nil {
panic("decoding servers.json: " + err.Error())
}
allProviders := providers.All()
for provider, metadata := range allServers.ProviderToServers {
if metadata.Filepath == "" {
panic(fmt.Sprintf("embedded manifest file servers.json should have the filepath field set for %s", provider))
}
filename := path.Base(metadata.Filepath)
const version = 1
allServers.ProviderToServers = make(map[string]models.Servers, len(allProviders))
allServers.Version = version
for _, provider := range allProviders {
filename := provider + ".json"
providerFile, err := serversmodule.Files.Open(filename)
if err != nil {
const rootURL = "https://github.com/qdm12/gluetun-servers/blob/main/pkg/servers"
panic(fmt.Sprintf("reading embedded provider file defined at %s/%s: %s", rootURL, filename, err))
panic(fmt.Sprintf("reading embedded provider file %s for %s: %s", filename, provider, err))
}
defer providerFile.Close() // no-op
@@ -48,7 +35,8 @@ func parseHardcodedServers() (allServers models.AllServers) {
filename, provider))
}
providerServers.Filepath = metadata.Filepath // inherit filepath from servers.json
const serversPath = "/gluetun/servers/"
providerServers.Filepath = filepath.Join(serversPath, filename)
allServers.ProviderToServers[provider] = providerServers
}
+1 -4
View File
@@ -33,10 +33,7 @@ func Test_parseHardcodedServers(t *testing.T) {
func Test_parseHardcodedServers_filepathsAndEmbeddedProviderFiles(t *testing.T) {
t.Parallel()
var hardcodedServers models.AllServers
require.NotPanics(t, func() {
hardcodedServers = parseHardcodedServers()
})
hardcodedServers := parseHardcodedServers()
allProviders := providers.All()
for _, provider := range allProviders {
+1 -2
View File
@@ -3,6 +3,5 @@ package storage
import "fmt"
func panicOnProviderMissingHardcoded(provider string) {
panic(fmt.Sprintf("provider %s not found in hardcoded servers map; "+
"did you add the provider key in the embedded servers.json?", provider))
panic(fmt.Sprintf("provider %s not found in hardcoded servers map", provider))
}
+1 -2
View File
@@ -152,8 +152,7 @@ func Test_extractServersFromBytes(t *testing.T) {
allProviders[0]: 1,
// Missing provider allProviders[1]
}
expectedPanicValue := fmt.Sprintf("provider %s not found in hardcoded servers map; "+
"did you add the provider key in the embedded servers.json?", allProviders[1])
expectedPanicValue := fmt.Sprintf("provider %s not found in hardcoded servers map", allProviders[1])
assert.PanicsWithValue(t, expectedPanicValue, func() {
_, _ = s.extractServersFromBytes(b, hardcodedVersions)
})
-75
View File
@@ -1,75 +0,0 @@
{
"version": 1,
"airvpn": {
"filepath": "/gluetun/servers/airvpn.json"
},
"cyberghost": {
"filepath": "/gluetun/servers/cyberghost.json"
},
"expressvpn": {
"filepath": "/gluetun/servers/expressvpn.json"
},
"fastestvpn": {
"filepath": "/gluetun/servers/fastestvpn.json"
},
"giganews": {
"filepath": "/gluetun/servers/giganews.json"
},
"hidemyass": {
"filepath": "/gluetun/servers/hidemyass.json"
},
"ipvanish": {
"filepath": "/gluetun/servers/ipvanish.json"
},
"ivpn": {
"filepath": "/gluetun/servers/ivpn.json"
},
"mullvad": {
"filepath": "/gluetun/servers/mullvad.json"
},
"nordvpn": {
"filepath": "/gluetun/servers/nordvpn.json"
},
"ovpn": {
"filepath": "/gluetun/servers/ovpn.json"
},
"perfect privacy": {
"filepath": "/gluetun/servers/perfect privacy.json"
},
"privado": {
"filepath": "/gluetun/servers/privado.json"
},
"private internet access": {
"filepath": "/gluetun/servers/private internet access.json"
},
"privatevpn": {
"filepath": "/gluetun/servers/privatevpn.json"
},
"protonvpn": {
"filepath": "/gluetun/servers/protonvpn.json"
},
"purevpn": {
"filepath": "/gluetun/servers/purevpn.json"
},
"slickvpn": {
"filepath": "/gluetun/servers/slickvpn.json"
},
"surfshark": {
"filepath": "/gluetun/servers/surfshark.json"
},
"torguard": {
"filepath": "/gluetun/servers/torguard.json"
},
"vpn unlimited": {
"filepath": "/gluetun/servers/vpn unlimited.json"
},
"vpnsecure": {
"filepath": "/gluetun/servers/vpnsecure.json"
},
"vyprvpn": {
"filepath": "/gluetun/servers/vyprvpn.json"
},
"windscribe": {
"filepath": "/gluetun/servers/windscribe.json"
}
}
-6
View File
@@ -1,19 +1,13 @@
# Maintenance
- Change `Run` methods to `Start`+`Stop`, returning channels rather than injecting them
- Go 1.18
- gofumpt
- Use netip
- Split servers.json
- Common slice of Wireguard providers in config settings
- DNS block lists as LFS and built in image
- Add HTTP server v3 as json rpc
- Use `github.com/qdm12/ddns-updater/pkg/publicip`
- Windows and Darwin development support
## Features
- Authentication with the control server
- Get announcement from Github file
- Support multiple connections in custom ovpn
- Automate IPv6 detection for OpenVPN