Compare commits

..

1 Commits

Author SHA1 Message Date
dependabot[bot] c747743d1f Chore(deps): Bump github.com/mdlayher/netlink from 1.9.0 to 1.11.2
Bumps [github.com/mdlayher/netlink](https://github.com/mdlayher/netlink) from 1.9.0 to 1.11.2.
- [Release notes](https://github.com/mdlayher/netlink/releases)
- [Changelog](https://github.com/mdlayher/netlink/blob/main/CHANGELOG.md)
- [Commits](https://github.com/mdlayher/netlink/compare/v1.9.0...v1.11.2)

---
updated-dependencies:
- dependency-name: github.com/mdlayher/netlink
  dependency-version: 1.11.2
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-05-24 21:12:04 +00:00
43 changed files with 297 additions and 2542 deletions
+41 -16
View File
@@ -14,6 +14,33 @@ updates:
# servers available
package-ecosystem: "gomod"
directory: "/"
schedule:
interval: "daily"
allow:
- dependency-name: "github.com/qdm12/gluetun-servers"
- # non important dependencies that do not need to be updated.
package-ecosystem: gomod
directory: /
schedule:
interval: "quarterly"
allow:
- dependency-name: "github.com/breml/rootcerts"
- dependency-name: "github.com/fatih/color"
- dependency-name: "github.com/golang/mock"
- dependency-name: "github.com/klauspost/compress"
- dependency-name: "github.com/klauspost/pgzip"
- dependency-name: "github.com/pelletier/go-toml/v2"
- dependency-name: "github.com/qdm12/goshutdown"
- dependency-name: "github.com/qdm12/gosplash"
- dependency-name: "github.com/qdm12/gotree"
- dependency-name: "github.com/qdm12/log"
- dependency-name: "github.com/stretchr/testify"
- dependency-name: "github.com/ulikunitz/xz"
- dependency-name: "gopkg.in/ini.v1"
- # The rest of Go modules are important and should be checked every week,
# instead of daily, to give a bit of time to avoid supply chain attacks.
package-ecosystem: gomod
directory: /
schedule:
interval: "weekly"
ignore:
@@ -23,19 +50,17 @@ updates:
# maintainers, which is persisted on the Go proxy.
dependency-name: "github.com/amnezia-vpn/amneziawg-go"
versions: ["1.x"]
groups:
low-importance:
patterns:
- "github.com/breml/rootcerts"
- "github.com/fatih/color"
- "github.com/golang/mock"
- "github.com/klauspost/compress"
- "github.com/klauspost/pgzip"
- "github.com/pelletier/go-toml/v2"
- "github.com/qdm12/goshutdown"
- "github.com/qdm12/gosplash"
- "github.com/qdm12/gotree"
- "github.com/qdm12/log"
- "github.com/stretchr/testify"
- "github.com/ulikunitz/xz"
- "gopkg.in/ini.v1"
- dependency-name: "github.com/qdm12/gluetun-servers"
- dependency-name: "github.com/breml/rootcerts"
- dependency-name: "github.com/fatih/color"
- dependency-name: "github.com/golang/mock"
- dependency-name: "github.com/klauspost/compress"
- dependency-name: "github.com/klauspost/pgzip"
- dependency-name: "github.com/pelletier/go-toml/v2"
- dependency-name: "github.com/qdm12/goshutdown"
- dependency-name: "github.com/qdm12/gosplash"
- dependency-name: "github.com/qdm12/gotree"
- dependency-name: "github.com/qdm12/log"
- dependency-name: "github.com/stretchr/testify"
- dependency-name: "github.com/ulikunitz/xz"
- dependency-name: "gopkg.in/ini.v1"
+10 -20
View File
@@ -28,10 +28,6 @@ on:
- go.mod
- go.sum
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true
jobs:
verify:
runs-on: ubuntu-latest
@@ -48,6 +44,7 @@ jobs:
locale: "US"
level: error
exclude: |
./internal/storage/servers.json
./.golangci.yml
*.md
@@ -67,10 +64,6 @@ jobs:
-v "$(pwd)/coverage.txt:/tmp/gobuild/coverage.txt" \
test-container
- name: Run integration tests in test container
run: |
docker run --rm --entrypoint go test-container test -tags=integration ./internal/restrictednet
- name: Verify dev cross platform compatibility
run: docker build --target xcompile .
@@ -105,7 +98,7 @@ jobs:
github.event_name == 'release' ||
(github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository && github.actor != 'dependabot[bot]')
)
needs: [verify]
needs: [ verify ]
runs-on: ubuntu-latest
environment: secrets
steps:
@@ -127,8 +120,7 @@ jobs:
- name: Run Gluetun container with ProtonVPN Wireguard and port forwarding
configuration
run:
echo -e "${{ secrets.PROTONVPN_WIREGUARD_PRIVATE_KEY }}" | ./ci/runner
run: echo -e "${{ secrets.PROTONVPN_WIREGUARD_PRIVATE_KEY }}" | ./ci/runner
protonvpn-wireguard-port-forwarding
- name: Run Gluetun container with ProtonVPN OpenVPN and port forwarding
@@ -137,12 +129,11 @@ jobs:
secrets.PROTONVPN_OPENVPN_PASSWORD }}" | ./ci/runner
protonvpn-openvpn-port-forwarding
# - name:
# Run Gluetun container with Private Internet Access OpenVPN and port
# forwarding configuration
# run: echo -e "${{ secrets.PRIVATEINTERNETACCESS_OPENVPN_USER }}\n${{
# secrets.PRIVATEINTERNETACCESS_OPENVPN_PASSWORD }}" | ./ci/runner
# private-internet-access-openvpn-port-forwarding
- name: Run Gluetun container with Private Internet Access OpenVPN and port
forwarding configuration
run: echo -e "${{ secrets.PRIVATEINTERNETACCESS_OPENVPN_USER }}\n${{
secrets.PRIVATEINTERNETACCESS_OPENVPN_PASSWORD }}" | ./ci/runner
private-internet-access-openvpn-port-forwarding
- name: Run Gluetun container with AirVPN Wireguard configuration
run: echo -e "${{ secrets.AIRVPN_WIREGUARD_PRIVATE_KEY }}\n${{
@@ -150,8 +141,7 @@ jobs:
secrets.AIRVPN_WIREGUARD_ADDRESSES }}" | ./ci/runner airvpn-wireguard
- name: Run Gluetun container with AirVPN OpenVPN configuration
run:
echo -e "${{ secrets.AIRVPN_OPENVPN_KEY }}\n${{ secrets.AIRVPN_OPENVPN_CERT
run: echo -e "${{ secrets.AIRVPN_OPENVPN_KEY }}\n${{ secrets.AIRVPN_OPENVPN_CERT
}}" | ./ci/runner airvpn-openvpn
codeql:
@@ -179,7 +169,7 @@ jobs:
github.event_name == 'release' ||
(github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository && github.actor != 'dependabot[bot]')
)
needs: [verify, verify-private, codeql]
needs: [ verify, verify-private, codeql ]
permissions:
actions: read
contents: read
-4
View File
@@ -11,10 +11,6 @@ on:
- "**.md"
- .github/workflows/markdown.yml
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true
jobs:
markdown:
runs-on: ubuntu-latest
-4
View File
@@ -12,10 +12,6 @@ formatters:
- builtin$
- examples$
run:
build-tags:
- integration
linters:
settings:
misspell:
+1 -1
View File
@@ -3,7 +3,7 @@
// to develop this project.
"files.eol": "\n",
"editor.formatOnSave": true,
"go.buildTags": "linux,integration",
"go.buildTags": "linux",
"go.toolsEnvVars": {
"CGO_ENABLED": "0"
},
-5
View File
@@ -50,7 +50,6 @@ Guidance for coding agents working in this repository.
- Prefer splitting a code line only when it triggers the `lll` linter, do not split a command or arguments list for each element
- Use `netip` types instead of `net` types whenever possible
- Use constants instead of variables whenever possible, especially function-local inline constants.
- Prefer using pure functions over methods when possible. Especially if the method does not need any fields from the receiving struct, it should be a pure function.
- Do not use `time.Sleep`, prefer using a `time.Timer` with a `select` statement also listening on a context cancelation
- `panic`:
- should only be used when a programming error is encountered and you should NOT return errors for programming errors (such as passing nil objects)
@@ -116,7 +115,6 @@ Mocking works with the `go.uber.org/mock` library, and the `mockgen` tool.
- **Never** use `.AnyTimes()` on mocks. Always define the number of times a certain mock call should be called, with `.Times(3)` for example.
- **Always** set the `.Return(...)` on the mock if the function returns something.
- Avoid using **mock helpers** functions, prefer a bit of repetition than tight coupling and dependency
- Always define the gomock controller `ctrl` in the subtest and not in the parent test, or a subtest mock failing will crash all the other subtests.
### main.go
@@ -129,7 +127,6 @@ The Go formatter used is gofumpt.
### Errors
- Always prefer wrapping errors with some context with `fmt.Errorf("doing this: %w", err)`
- Use `errors.New("error message")` when creating a 'bottom' constant string error without additional context, instead of `fmt.Errorf`
- In rare cases, you can just use `return err` notably:
- If the function is called **recursively**, since we don't wrap the wrapping multiple times for each recursion
- If the current function only statement is the call to another function, for example:
@@ -182,8 +179,6 @@ The Go formatter used is gofumpt.
- Do not use `http.DefaultClient`, use a custom `*http.Client` with a fixed timeout and share with dependency injections.
- Do not check for injected dependencies being `nil`, prefer to just panic on a nil pointer. By default it's fine to panic if a developer injects a dependency `nil`. `nil` does not mean use a default.
- Prefer using a `switch { case ...}` statement over multiple consecutive `if` statements to have shorter code.
- Prefer using `[...]T` instead of `[]T` when the length is fixed and known at compile time, to avoid unnecessary allocations.
## Validation checklist
+1 -1
View File
@@ -276,7 +276,7 @@ ENV VPN_SERVICE_PROVIDER=pia \
PUID=1000 \
PGID=1000
ENTRYPOINT ["/gluetun-entrypoint"]
EXPOSE 8000/tcp 8888/tcp 8388/tcp 8388/udp 1080/tcp 1080/udp
EXPOSE 8000/tcp 8888/tcp 8388/tcp 8388/udp 1080/tcp
HEALTHCHECK --interval=5s --timeout=5s --start-period=10s --retries=3 CMD /gluetun-entrypoint healthcheck
ARG TARGETPLATFORM
RUN apk add --no-cache --update -l wget && \
+1 -1
View File
@@ -73,7 +73,7 @@ Lightweight swiss-army-knife-like VPN client to multiple VPN service providers
- Choose the vpn network protocol, `udp` or `tcp`
- Built in firewall kill switch to allow traffic only with needed the VPN servers and LAN devices
- Built in Shadowsocks proxy server (protocol based on SOCKS5 with an encryption layer, tunnels TCP+UDP)
- Built in Socks5 proxy server (tunnels TCP+UDP) - partial credits to @angelakis and @adjscent
- Built in Socks5 proxy server (tunnels TCP) - partial credits to @angelakis and @adjscent
- Built in HTTP proxy (tunnels HTTP and HTTPS through TCP)
- [Connect other containers to it](https://github.com/qdm12/gluetun-wiki/blob/main/setup/connect-a-container-to-gluetun.md)
- [Connect LAN devices to it](https://github.com/qdm12/gluetun-wiki/blob/main/setup/connect-a-lan-device-to-gluetun.md)
+2 -2
View File
@@ -12,7 +12,7 @@ require (
github.com/klauspost/compress v1.18.4
github.com/klauspost/pgzip v1.2.6
github.com/mdlayher/genetlink v1.3.2
github.com/mdlayher/netlink v1.9.0
github.com/mdlayher/netlink v1.11.2
github.com/pelletier/go-toml/v2 v2.2.4
github.com/qdm12/dns/v2 v2.0.0-rc9.0.20260421173011-9de8e7fdbe3a
github.com/qdm12/gluetun-servers v0.1.0
@@ -46,7 +46,7 @@ require (
github.com/google/go-cmp v0.7.0 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mdlayher/socket v0.5.1 // indirect
github.com/mdlayher/socket v0.6.0 // indirect
github.com/miekg/dns v1.1.62 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/pkg/errors v0.9.1 // indirect
+4 -4
View File
@@ -50,10 +50,10 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw=
github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o=
github.com/mdlayher/netlink v1.9.0 h1:G8+GLq2x3v4D4MVIqDdNUhTUC7TKiCy/6MDkmItfKco=
github.com/mdlayher/netlink v1.9.0/go.mod h1:YBnl5BXsCoRuwBjKKlZ+aYmEoq0r12FDA/3JC+94KDg=
github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos=
github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ=
github.com/mdlayher/netlink v1.11.2 h1:HKh2jqe+omdSWcQ88nrT7INE61B0NXfiSPFdgL4YbNI=
github.com/mdlayher/netlink v1.11.2/go.mod h1:uT2Yc/QLaZubzDpZIBi9d4GoeLwtp3x1AMeqSRrK2sA=
github.com/mdlayher/socket v0.6.0 h1:ScZPaAGyO1icQnbFrhPM8mnXyMu9qukC1K4ZoM2IQKU=
github.com/mdlayher/socket v0.6.0/go.mod h1:q7vozUAnxSqnjHc12Fik5yUKIzfZ8ITCfMkhOtE9z18=
github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ=
github.com/miekg/dns v1.1.62/go.mod h1:mvDlcItzm+br7MToIKqkglaGhlFMHJ9DTNNWONWXbNQ=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
-5
View File
@@ -5,7 +5,6 @@ import (
"errors"
"flag"
"fmt"
"net"
"net/http"
"slices"
"strings"
@@ -105,10 +104,6 @@ func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) e
if err != nil {
return fmt.Errorf("creating DoH dialer: %w", err)
}
net.DefaultResolver = &net.Resolver{
PreferGo: true,
Dial: dnsDialer.Dial,
}
const clientTimeout = 10 * time.Second
httpClient := &http.Client{Timeout: clientTimeout}
+4 -10
View File
@@ -9,9 +9,8 @@ import (
)
// Start launches a command and streams stdout and stderr to channels.
// stdoutLines and stderrLines channels will be closed when there is no more
// output to read, in order for the caller to catch all lines even after the
// command has finished. The waitError channel returned will never be closed.
// All the channels returned are ready only and won't be closed
// if the command fails later.
func (c *Cmder) Start(cmd *exec.Cmd) (
stdoutLines, stderrLines <-chan string,
waitError <-chan error, startErr error,
@@ -39,7 +38,6 @@ func start(cmd execCmd) (stdoutLines, stderrLines <-chan string,
if err != nil {
_ = stdout.Close()
<-stdoutDone
close(stdoutLinesCh)
return nil, nil, nil, err
}
go streamToChannel(stderrReady, stderrDone, stderr, stderrLinesCh)
@@ -47,11 +45,9 @@ func start(cmd execCmd) (stdoutLines, stderrLines <-chan string,
err = cmd.Start()
if err != nil {
_ = stdout.Close()
<-stdoutDone
close(stdoutLinesCh)
_ = stderr.Close()
<-stdoutDone
<-stderrDone
close(stderrLinesCh)
return nil, nil, nil, err
}
@@ -59,10 +55,8 @@ func start(cmd execCmd) (stdoutLines, stderrLines <-chan string,
go func() {
err := cmd.Wait()
<-stdoutDone
close(stdoutLinesCh)
_ = stdout.Close()
<-stderrDone
close(stderrLinesCh)
_ = stdout.Close()
_ = stderr.Close()
waitErrorCh <- err
}()
+24 -42
View File
@@ -89,48 +89,30 @@ func Test_start(t *testing.T) {
require.NoError(t, err)
collectAndCheckChannels(t, stdoutLines, stderrLines, waitError,
testCase.stdout, testCase.stderr, testCase.waitErr)
var stdoutIndex, stderrIndex int
done := false
for !done {
select {
case line := <-stdoutLines:
assert.Equal(t, testCase.stdout[stdoutIndex], line)
stdoutIndex++
case line := <-stderrLines:
assert.Equal(t, testCase.stderr[stderrIndex], line)
stderrIndex++
case err := <-waitError:
if testCase.waitErr != nil {
require.Error(t, err)
assert.Equal(t, testCase.waitErr.Error(), err.Error())
} else {
assert.NoError(t, err)
}
done = true
}
}
assert.Equal(t, len(testCase.stdout), stdoutIndex)
assert.Equal(t, len(testCase.stderr), stderrIndex)
})
}
}
func collectAndCheckChannels(t *testing.T, stdoutLines, stderrLines <-chan string,
waitError <-chan error, expectedStdout, expectedStderr []string, expectedWaitErr error,
) {
t.Helper()
stdoutIndex := 0
stderrIndex := 0
done := false
for !done {
select {
case line, ok := <-stdoutLines:
if !ok {
stdoutLines = nil
continue
}
assert.Equal(t, expectedStdout[stdoutIndex], line)
stdoutIndex++
case line, ok := <-stderrLines:
if !ok {
stderrLines = nil
continue
}
assert.Equal(t, expectedStderr[stderrIndex], line)
stderrIndex++
case err := <-waitError:
if expectedWaitErr != nil {
require.Error(t, err)
assert.Equal(t, expectedWaitErr.Error(), err.Error())
} else {
assert.NoError(t, err)
}
done = true
}
}
assert.Equal(t, len(expectedStdout), stdoutIndex)
assert.Equal(t, len(expectedStderr), stderrIndex)
}
+13 -19
View File
@@ -18,37 +18,31 @@ func (c *Cmder) RunAndLog(ctx context.Context, command string, logger Logger) (e
return err
}
streamCtx, streamCancel := context.WithCancel(context.Background())
streamDone := make(chan struct{})
go streamLines(streamDone, logger, stdout, stderr)
go streamLines(streamCtx, streamDone, logger, stdout, stderr)
err = <-waitError
streamCancel()
<-streamDone
return err
}
func streamLines(done chan<- struct{}, logger Logger,
stdout, stderr <-chan string,
func streamLines(ctx context.Context, done chan<- struct{},
logger Logger, stdout, stderr <-chan string,
) {
defer close(done)
var line string
for {
select {
case line, ok := <-stdout:
if ok {
logger.Info(line)
}
if stderr == nil {
return
}
stdout = nil
case line, ok := <-stderr:
if ok {
logger.Error(line)
}
if stdout == nil {
return
}
stderr = nil
case <-ctx.Done():
return
case line = <-stdout:
logger.Info(line)
case line = <-stderr:
logger.Error(line)
}
}
}
-2
View File
@@ -28,8 +28,6 @@ type firewallImpl interface { //nolint:interfacebloat
AcceptIpv6MulticastOutput(ctx context.Context, intf string) error
AcceptOutput(ctx context.Context, protocol, intf string,
ip netip.Addr, port uint16, remove bool) error
AcceptOutputFromIPPortToIPPort(ctx context.Context, protocol, intf string,
source, destination netip.AddrPort, remove bool) error
AcceptOutputFromIPToSubnet(ctx context.Context, intf string, assignedIP netip.Addr,
subnet netip.Prefix, remove bool) error
AcceptOutputThroughInterface(ctx context.Context, intf string, remove bool) error
-24
View File
@@ -2,7 +2,6 @@ package iptables
import (
"context"
"errors"
"fmt"
"io"
"net/netip"
@@ -178,29 +177,6 @@ func (c *Config) AcceptOutput(ctx context.Context,
return c.runIP6tablesInstruction(ctx, instruction)
}
func (c *Config) AcceptOutputFromIPPortToIPPort(ctx context.Context,
protocol, intf string, source, destination netip.AddrPort, remove bool,
) error {
if source.Addr().BitLen() != destination.Addr().BitLen() {
return errors.New("source and destination address families do not match")
}
interfaceFlag := "-o " + intf
if intf == "*" { // all interfaces
interfaceFlag = ""
}
instruction := fmt.Sprintf("%s OUTPUT %s -s %s -d %s -p %s -m %s --sport %d --dport %d -j ACCEPT",
appendOrDelete(remove), interfaceFlag, source.Addr(), destination.Addr(),
protocol, protocol, source.Port(), destination.Port())
if destination.Addr().Is4() {
return c.runIptablesInstruction(ctx, instruction)
} else if c.ip6Tables == "" {
return fmt.Errorf("accept output from %s to %s: %s", source, destination, needIP6Tables)
}
return c.runIP6tablesInstruction(ctx, instruction)
}
// AcceptOutputFromIPToSubnet accepts outgoing traffic from sourceIP to destinationSubnet
// on the interface intf. If intf is empty, it is set to "*" which means all interfaces.
// If remove is true, the rule is removed instead of added.
-7
View File
@@ -25,10 +25,3 @@ func (c *Config) AcceptOutput(ctx context.Context, protocol, intf string,
) error {
return c.impl.AcceptOutput(ctx, protocol, intf, ip, port, remove)
}
func (c *Config) AcceptOutputFromIPPortToIPPort(ctx context.Context,
protocol, intf string, source, destination netip.AddrPort, remove bool,
) error {
return c.impl.AcceptOutputFromIPPortToIPPort(ctx, protocol, intf,
source, destination, remove)
}
+4 -1
View File
@@ -29,16 +29,19 @@ func (r *Runner) Run(ctx context.Context, errCh chan<- error, ready chan<- struc
return
}
streamCtx, streamCancel := context.WithCancel(context.Background())
streamDone := make(chan struct{})
go streamLines(streamDone, r.logger,
go streamLines(streamCtx, streamDone, r.logger,
stdoutLines, stderrLines, ready)
select {
case <-ctx.Done():
<-waitError
streamCancel()
<-streamDone
errCh <- ctx.Err()
case err := <-waitError:
streamCancel()
<-streamDone
errCh <- err
}
+9 -20
View File
@@ -1,37 +1,26 @@
package openvpn
import (
"context"
"strings"
)
func streamLines(done chan<- struct{},
func streamLines(ctx context.Context, done chan<- struct{},
logger Logger, stdout, stderr <-chan string,
tunnelReady chan<- struct{},
) {
defer close(done)
var line string
for {
var line string
var ok bool
errLine := false
select {
case line, ok = <-stdout:
if ok {
break
}
if stderr == nil {
return
}
stdout = nil
case line, ok = <-stderr:
if ok {
errLine = true
break
}
if stdout == nil {
return
}
stderr = nil
case <-ctx.Done():
return
case line = <-stdout:
case line = <-stderr:
errLine = true
}
line, level := processLogLine(line)
if line == "" {
+1 -6
View File
@@ -6,7 +6,6 @@ import (
"strings"
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/constants/openvpn"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils"
@@ -66,11 +65,7 @@ func modifyConfig(lines []string, connection models.Connection,
}
// Add values
protocol := connection.Protocol
if protocol == constants.TCP {
protocol = "tcp-client"
}
modified = append(modified, "proto "+protocol)
modified = append(modified, "proto "+connection.Protocol)
modified = append(modified, fmt.Sprintf("remote %s %d", connection.IP, connection.Port))
modified = append(modified, "dev "+settings.Interface)
modified = append(modified, "mute-replay-warnings")
-82
View File
@@ -1,82 +0,0 @@
package restrictednet
import (
"context"
"errors"
"fmt"
"net"
"net/http"
"net/netip"
"strconv"
"github.com/qdm12/dns/v2/pkg/provider"
)
// Client is a client for making restricted network requests,
// such as opening temporary firewall rules for HTTPS connections.
// It is not meant to be high performance, although it can be used for
// multiple requests and concurrently.
type Client struct {
outboundInterface string
ipv6Supported bool
firewall Firewall
dohServers []provider.DoHServer
}
func New(settings Settings) *Client {
if err := settings.validate(); err != nil {
panic(fmt.Sprintf("invalid settings: %v", err)) // programming error
}
dohServers := make([]provider.DoHServer, len(settings.UpstreamResolvers))
for i, upstreamResolver := range settings.UpstreamResolvers {
dohServers[i] = upstreamResolver.DoH
}
return &Client{
outboundInterface: settings.DefaultInterface,
ipv6Supported: *settings.IPv6Supported,
firewall: settings.Firewall,
dohServers: dohServers,
}
}
// OpenHTTPSByHostname opens an https connection through the firewall,
// to the hostname which in the format `host:port`. The returned cleanup
// function must be called to remove the temporary firewall rule and close connections.
// It first resolves the domain in hostname using DNS over HTTPS and then opens
// the restricted HTTPS connection to the resolved IP.
func (c *Client) OpenHTTPSByHostname(ctx context.Context, hostname string) (
httpClient *http.Client, cleanup func() error, err error,
) {
host, portStr, err := net.SplitHostPort(hostname)
if err != nil {
return nil, nil, fmt.Errorf("splitting host and port: %w", err)
}
resolvedIPs, err := c.ResolveName(ctx, host)
if err != nil {
return nil, nil, fmt.Errorf("resolving name: %w", err)
} else if len(resolvedIPs) == 0 {
return nil, nil, fmt.Errorf("no IP address found for name %q", host)
}
portUint, err := strconv.ParseUint(portStr, 10, 16)
if err != nil {
return nil, nil, fmt.Errorf("parsing port: %w", err)
} else if portUint == 0 {
return nil, nil, errors.New("destination port cannot be 0")
}
port := uint16(portUint)
errs := make([]error, 0, len(resolvedIPs))
for _, ip := range resolvedIPs {
addrPort := netip.AddrPortFrom(ip, port)
httpClient, cleanup, err := c.OpenHTTPS(ctx, host, addrPort)
if err != nil {
errs = append(errs, fmt.Errorf("for %s: %w", ip, err))
continue
}
return httpClient, cleanup, nil
}
return nil, nil, fmt.Errorf("opening HTTPS to %s: %w", hostname, errors.Join(errs...))
}
-7
View File
@@ -1,7 +0,0 @@
//go:build integration
package restrictednet
func ptrTo[T any](value T) *T {
return &value
}
-202
View File
@@ -1,202 +0,0 @@
package restrictednet
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"net/netip"
"os"
"time"
"github.com/jsimonetti/rtnetlink"
"github.com/qdm12/gluetun/internal/pmtud/constants"
)
// OpenHTTPS opens temporary restrictive firewall output for one HTTPS destination.
// The returned [*http.Client] must be used sequentially only, and each request must
// have its response body fully read/discarded and then closed.
// The returned cleanup function must be called to remove the temporary firewall rule and close connections.
func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, destinationAddrPort netip.AddrPort,
) (httpClient *http.Client, cleanup func() error, err error) {
fd, sourceAddrPort, err := bindSourceConnection(destinationAddrPort.Addr())
if err != nil {
return nil, nil, fmt.Errorf("binding source port: %w", err)
}
const remove = false
err = c.firewall.AcceptOutputFromIPPortToIPPort(ctx, "tcp", c.outboundInterface,
sourceAddrPort, destinationAddrPort, remove)
if err != nil {
closeFD(fd)
return nil, nil, fmt.Errorf("allowing output traffic through firewall: %w", err)
}
connection, err := connectSourceConnection(ctx, fd, destinationAddrPort)
if err != nil {
const remove = true
_ = c.firewall.AcceptOutputFromIPPortToIPPort(context.Background(), "tcp", c.outboundInterface,
sourceAddrPort, destinationAddrPort, remove)
return nil, nil, fmt.Errorf("connecting source socket: %w", err)
}
dial := makeDial(connection, destinationTLSName)
httpClient = newHTTPSClient(destinationTLSName, dial)
cleanup = func() error {
var errs []error
httpClient.CloseIdleConnections()
err := connection.Close()
if err != nil && !errors.Is(err, net.ErrClosed) {
errs = append(errs, fmt.Errorf("closing connection: %w", err))
}
const remove = true
err = c.firewall.AcceptOutputFromIPPortToIPPort(context.Background(), "tcp", c.outboundInterface,
sourceAddrPort, destinationAddrPort, remove)
if err != nil {
errs = append(errs, fmt.Errorf("removing output traffic rule: %w", err))
}
if len(errs) > 0 {
return errors.Join(errs...)
}
return nil
}
return httpClient, cleanup, nil
}
type dialFunc func(ctx context.Context, network, address string) (net.Conn, error)
func newHTTPSClient(destinationTLSName string, dial dialFunc) *http.Client {
const timeout = 5 * time.Second
transport := &http.Transport{
MaxIdleConns: 1,
MaxIdleConnsPerHost: 1,
MaxConnsPerHost: 1,
TLSClientConfig: &tls.Config{
MinVersion: tls.VersionTLS12,
ServerName: destinationTLSName,
},
DialContext: dial,
}
return &http.Client{
Timeout: timeout,
Transport: transport,
}
}
func makeDial(connection net.Conn, tlsName string) dialFunc {
_, destinationPort, err := net.SplitHostPort(connection.RemoteAddr().String())
if err != nil {
panic(err) // connection remote address should always be in the form "host:port"
}
expectedAddress := net.JoinHostPort(tlsName, destinationPort)
used := false
return func(_ context.Context, network, address string) (net.Conn, error) {
if used {
return nil, errors.New("dial function called more than once")
}
used = true
switch network {
case "tcp", "tcp4", "tcp6":
default:
return nil, fmt.Errorf("unexpected dial network %q", network)
}
if address != expectedAddress {
return nil, fmt.Errorf("unexpected dial address %q (expected %q)", address, expectedAddress)
}
return connection, nil
}
}
func bindSourceConnection(destinationIP netip.Addr) (fd int, sourceAddr netip.AddrPort, err error) {
sourceIP, err := sourceIPForDestination(destinationIP)
if err != nil {
return 0, netip.AddrPort{}, fmt.Errorf("finding source IP: %w", err)
}
family := constants.AF_INET
if sourceIP.Is6() {
family = constants.AF_INET6
}
fd, err = newTCPSockStream(family)
if err != nil {
return 0, netip.AddrPort{}, fmt.Errorf("creating socket: %w", err)
}
bindAddrPort := netip.AddrPortFrom(sourceIP, 0)
err = bindFD(fd, bindAddrPort)
if err != nil {
closeFD(fd)
return 0, netip.AddrPort{}, fmt.Errorf("binding socket: %w", err)
}
sourceAddr, err = fdToSourceAddr(fd)
if err != nil {
closeFD(fd)
return 0, netip.AddrPort{}, fmt.Errorf("getting source address: %w", err)
}
return fd, sourceAddr, nil
}
func connectSourceConnection(ctx context.Context, fd int, destinationAddrPort netip.AddrPort) (
connection net.Conn, err error,
) {
err = connectFD(ctx, fd, destinationAddrPort)
if err != nil {
closeFD(fd)
return nil, fmt.Errorf("connecting socket: %w", err)
}
file := os.NewFile(uintptr(fd), "")
if file == nil {
closeFD(fd)
return nil, fmt.Errorf("creating socket file")
}
defer file.Close()
connection, err = net.FileConn(file)
if err != nil {
return nil, fmt.Errorf("wrapping socket connection: %w", err)
}
return connection, nil
}
func sourceIPForDestination(destinationIP netip.Addr) (srcIP netip.Addr, err error) {
conn, err := rtnetlink.Dial(nil)
if err != nil {
return netip.Addr{}, err
}
defer conn.Close()
family := uint8(constants.AF_INET)
if destinationIP.Is6() {
family = constants.AF_INET6
}
requestMessage := &rtnetlink.RouteMessage{
Family: family,
Attributes: rtnetlink.RouteAttributes{
Dst: destinationIP.AsSlice(),
},
}
messages, err := conn.Route.Get(requestMessage)
if err != nil {
return netip.Addr{}, fmt.Errorf("getting routes to %s: %w", destinationIP, err)
}
for _, message := range messages {
if message.Attributes.Src == nil {
continue
}
if message.Attributes.Src.To4() == nil {
return netip.AddrFrom16([16]byte(message.Attributes.Src)), nil
}
return netip.AddrFrom4([4]byte(message.Attributes.Src)), nil
}
return netip.Addr{}, fmt.Errorf("no route to %s", destinationIP)
}
@@ -1,117 +0,0 @@
//go:build integration
package restrictednet
import (
"context"
"fmt"
"io"
"net/http"
"net/netip"
"testing"
"github.com/golang/mock/gomock"
"github.com/qdm12/dns/v2/pkg/provider"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type listenAddrPortMatcher struct {
expected netip.AddrPort
}
func (m listenAddrPortMatcher) Matches(x any) bool {
ip, ok := x.(netip.AddrPort)
if !ok {
return false
}
if m.expected.IsValid() {
return ip == m.expected
}
return ip.IsValid() && ip.Addr().IsValid() && ip.Port() > 0
}
func (m listenAddrPortMatcher) String() string {
if m.expected.IsValid() {
return "is the same as " + m.expected.String()
}
return "is a valid netip.AddrPort with a valid IP and non-zero port"
}
type destinationAddrPortMatcher struct {
expected netip.AddrPort
}
func (m destinationAddrPortMatcher) Matches(x any) bool {
ip, ok := x.(netip.AddrPort)
if !ok {
return false
}
if m.expected.IsValid() {
return ip == m.expected
}
return ip.IsValid() && ip.Port() == m.expected.Port()
}
func (m destinationAddrPortMatcher) String() string {
if m.expected.IsValid() {
return "is the same as " + m.expected.String()
}
return "matches the port " + fmt.Sprint(m.expected.Port())
}
func Test_Client_OpenHTTPS(t *testing.T) {
t.Parallel()
ctx := t.Context()
ctrl := gomock.NewController(t)
const destinationTLSName = "one.one.one.one"
destinationAddrPort := netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443)
firewall := NewMockFirewall(ctrl)
sourceMatcher := listenAddrPortMatcher{}
firewall.EXPECT().AcceptOutputFromIPPortToIPPort(
ctx, "tcp", "eth0", sourceMatcher, destinationAddrPort, false,
).DoAndReturn(func(_ context.Context,
_, _ string, source, _ netip.AddrPort, _ bool,
) error {
sourceMatcher.expected = source
return nil
})
firewall.EXPECT().AcceptOutputFromIPPortToIPPort(
context.Background(), "tcp", "eth0", sourceMatcher, destinationAddrPort, true,
).Return(nil)
const ipv6Supported = false
upstreamResolvers := []provider.Provider{provider.Google()}
settings := Settings{
Firewall: firewall,
DefaultInterface: "eth0",
IPv6Supported: ptrTo(ipv6Supported),
UpstreamResolvers: upstreamResolvers,
}
client := New(settings)
httpClient, cleanup, err := client.OpenHTTPS(ctx, destinationTLSName, destinationAddrPort)
require.NoError(t, err)
require.NotNil(t, httpClient)
require.NotNil(t, cleanup)
const requests = 2
for range requests {
request, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://"+destinationTLSName, nil)
require.NoError(t, err)
response, err := httpClient.Do(request)
require.NoError(t, err)
_, err = io.Copy(io.Discard, response.Body)
require.NoError(t, err)
err = response.Body.Close()
require.NoError(t, err)
assert.Equal(t, http.StatusOK, response.StatusCode)
}
err = cleanup()
require.NoError(t, err)
}
-12
View File
@@ -1,12 +0,0 @@
package restrictednet
import (
"context"
"net/netip"
)
type Firewall interface {
AcceptOutputFromIPPortToIPPort(ctx context.Context,
protocol, intf string, source, destination netip.AddrPort, remove bool,
) error
}
@@ -1,3 +0,0 @@
package restrictednet
//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . Firewall
-50
View File
@@ -1,50 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/qdm12/gluetun/internal/restrictednet (interfaces: Firewall)
// Package restrictednet is a generated GoMock package.
package restrictednet
import (
context "context"
netip "net/netip"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
)
// MockFirewall is a mock of Firewall interface.
type MockFirewall struct {
ctrl *gomock.Controller
recorder *MockFirewallMockRecorder
}
// MockFirewallMockRecorder is the mock recorder for MockFirewall.
type MockFirewallMockRecorder struct {
mock *MockFirewall
}
// NewMockFirewall creates a new mock instance.
func NewMockFirewall(ctrl *gomock.Controller) *MockFirewall {
mock := &MockFirewall{ctrl: ctrl}
mock.recorder = &MockFirewallMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockFirewall) EXPECT() *MockFirewallMockRecorder {
return m.recorder
}
// AcceptOutputFromIPPortToIPPort mocks base method.
func (m *MockFirewall) AcceptOutputFromIPPortToIPPort(arg0 context.Context, arg1, arg2 string, arg3, arg4 netip.AddrPort, arg5 bool) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AcceptOutputFromIPPortToIPPort", arg0, arg1, arg2, arg3, arg4, arg5)
ret0, _ := ret[0].(error)
return ret0
}
// AcceptOutputFromIPPortToIPPort indicates an expected call of AcceptOutputFromIPPortToIPPort.
func (mr *MockFirewallMockRecorder) AcceptOutputFromIPPortToIPPort(arg0, arg1, arg2, arg3, arg4, arg5 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptOutputFromIPPortToIPPort", reflect.TypeOf((*MockFirewall)(nil).AcceptOutputFromIPPortToIPPort), arg0, arg1, arg2, arg3, arg4, arg5)
}
-205
View File
@@ -1,205 +0,0 @@
package restrictednet
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net/http"
"net/netip"
"net/url"
"strconv"
"github.com/miekg/dns"
)
// ResolveName resolves the given host name to IP addresses using DoH servers,
// while opening temporary restrictive firewall rules for HTTPS traffic to DoH servers.
// The host must be a single well-formed domain name, without port or path.
func (c *Client) ResolveName(ctx context.Context, host string) (
resolvedAddresses []netip.Addr, err error,
) {
const maxTypes = 2
questionTypes := make([]uint16, 0, maxTypes)
if c.ipv6Supported {
questionTypes = append(questionTypes, dns.TypeAAAA)
}
questionTypes = append(questionTypes, dns.TypeA)
var addresses []netip.Addr
errs := make([]error, 0, len(questionTypes))
for _, questionType := range questionTypes {
answerAddresses, err := c.resolveOneQuestionType(ctx, host, questionType)
if err != nil {
errs = append(errs, err)
continue
}
addresses = append(addresses, answerAddresses...)
}
switch {
case len(addresses) > 0:
return addresses, nil
case len(errs) == 0:
return nil, nil // no address found
default: // errors
return nil, fmt.Errorf("resolving host %q: %w", host, errors.Join(errs...))
}
}
func (c *Client) resolveOneQuestionType(ctx context.Context,
host string, questionType uint16,
) (addresses []netip.Addr, err error) {
queryMessage := &dns.Msg{}
queryMessage.SetQuestion(dns.Fqdn(host), questionType)
queryWire, err := queryMessage.Pack()
if err != nil {
return nil, fmt.Errorf("packing DNS query: %w", err)
}
// Try every DoH server and every of each of their IP until we get a non-empty
// successful response.
errs := make([]error, 0)
for _, dohServer := range c.dohServers {
dohURL, err := url.Parse(dohServer.URL)
if err != nil {
errs = append(errs,
fmt.Errorf("parsing DoH server URL %s: %w", dohServer.URL, err))
continue
}
dohServerIPs := make([]netip.Addr, 0, len(dohServer.IPv4)+len(dohServer.IPv6))
if c.ipv6Supported {
// Prefer IPv6 addresses if IPv6 is supported
dohServerIPs = append(dohServerIPs, dohServer.IPv6...)
}
dohServerIPs = append(dohServerIPs, dohServer.IPv4...)
for _, dohServerIP := range dohServerIPs {
const defaultDoHPort uint16 = 443
port := defaultDoHPort
if portStr := dohURL.Port(); portStr != "" {
port, err = parseDestinationPort(portStr)
if err != nil {
errs = append(errs, fmt.Errorf("parsing DoH server port: %w", err))
continue
}
}
dohServerAddrPort := netip.AddrPortFrom(dohServerIP, port)
responseMessage, err := c.doHQuery(ctx, queryWire, dohURL, dohServerAddrPort)
switch {
case err != nil:
errs = append(errs, fmt.Errorf("querying DoH server %q (%s): %w",
dohServer.URL, dohServerAddrPort, err))
continue
case responseMessage.Rcode != dns.RcodeSuccess:
errs = append(errs, fmt.Errorf("querying DoH server %q (%s): DNS rcode %s",
dohServer.URL, dohServerAddrPort, dns.RcodeToString[responseMessage.Rcode]))
continue
}
addresses := answersToNetipAddrs(responseMessage)
if len(addresses) == 0 {
continue
}
return addresses, nil
}
}
if len(errs) == 0 {
return nil, nil
}
return nil, fmt.Errorf("resolving %s %s: %w",
dns.TypeToString[questionType], host, errors.Join(errs...))
}
func (c *Client) doHQuery(ctx context.Context, queryWire []byte,
dohURL *url.URL, dohServerAddrPort netip.AddrPort,
) (responseMessage *dns.Msg, err error) {
httpClient, cleanup, err := c.OpenHTTPS(ctx, dohURL.Hostname(), dohServerAddrPort)
if err != nil {
return nil, fmt.Errorf("opening https connection: %w", err)
}
defer func() {
closeErr := cleanup()
if err == nil && closeErr != nil {
err = fmt.Errorf("cleaning up https connection: %w", closeErr)
}
}()
requestBody := bytes.NewReader(queryWire)
request, err := http.NewRequestWithContext(ctx, http.MethodPost, dohURL.String(), requestBody)
if err != nil {
return nil, fmt.Errorf("creating request: %w", err)
}
request.Header.Set("Content-Type", "application/dns-message")
request.Header.Set("Accept", "application/dns-message")
response, err := httpClient.Do(request)
if err != nil {
return nil, err
}
responseData, err := io.ReadAll(response.Body)
if err != nil {
_ = response.Body.Close()
return nil, fmt.Errorf("reading response body: %w", err)
}
err = response.Body.Close()
if err != nil {
return nil, fmt.Errorf("closing response body: %w", err)
}
if response.StatusCode != http.StatusOK {
return nil, fmt.Errorf("response status code is %s (data length %d)",
response.Status, len(responseData))
}
responseMessage = new(dns.Msg)
err = responseMessage.Unpack(responseData)
if err != nil {
return nil, fmt.Errorf("parsing DoH response: %w", err)
}
return responseMessage, nil
}
func answersToNetipAddrs(message *dns.Msg) (addresses []netip.Addr) {
if message == nil {
return nil
}
addresses = make([]netip.Addr, 0, len(message.Answer))
for _, answer := range message.Answer {
switch record := answer.(type) {
case *dns.A:
address, ok := netip.AddrFromSlice(record.A)
if ok {
addresses = append(addresses, address.Unmap())
}
case *dns.AAAA:
address, ok := netip.AddrFromSlice(record.AAAA)
if ok {
addresses = append(addresses, address)
}
}
}
return addresses
}
func parseDestinationPort(portStr string) (port uint16, err error) {
portUint, err := strconv.ParseUint(portStr, 10, 16)
if err != nil {
return 0, err
}
const maxPortUint = 65535
switch {
case portUint == 0:
return 0, errors.New("port cannot be 0")
case portUint > maxPortUint:
return 0, fmt.Errorf("port cannot be greater than %d", maxPortUint)
}
return uint16(portUint), nil
}
@@ -1,110 +0,0 @@
//go:build integration
package restrictednet
import (
"context"
"net"
"net/netip"
"testing"
"github.com/golang/mock/gomock"
"github.com/miekg/dns"
"github.com/qdm12/dns/v2/pkg/provider"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_Client_ResolveName(t *testing.T) {
t.Parallel()
ctx := t.Context()
ctrl := gomock.NewController(t)
firewall := NewMockFirewall(ctrl)
sourceMatcher := listenAddrPortMatcher{}
destinationMatcher := destinationAddrPortMatcher{
expected: netip.AddrPortFrom(netip.Addr{}, 443),
}
// Add rule
firstCall := firewall.EXPECT().AcceptOutputFromIPPortToIPPort(
ctx, "tcp", "eth0", sourceMatcher, destinationMatcher, false,
).DoAndReturn(func(
_ context.Context, _, _ string, source, destination netip.AddrPort, _ bool,
) error {
sourceMatcher.expected = source
destinationMatcher.expected = destination
return nil
})
// Removal rule
firewall.EXPECT().AcceptOutputFromIPPortToIPPort(
context.Background(), "tcp", "eth0", sourceMatcher, destinationMatcher, true,
).Return(nil).After(firstCall)
settings := Settings{
DefaultInterface: "eth0",
IPv6Supported: ptrTo(false),
Firewall: firewall,
UpstreamResolvers: []provider.Provider{provider.Cloudflare()},
}
client := New(settings)
addresses, err := client.ResolveName(ctx, "github.com")
require.NoError(t, err)
assert.NotEmpty(t, addresses)
}
func Test_answersToNetipAddrs(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
message *dns.Msg
expected []netip.Addr
}{
"nil_message": {},
"no_answers": {
message: &dns.Msg{},
expected: []netip.Addr{},
},
"a_record": {
message: &dns.Msg{Answer: []dns.RR{
&dns.A{
Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET},
A: net.IP{1, 1, 1, 1},
},
}},
expected: []netip.Addr{netip.MustParseAddr("1.1.1.1")},
},
"aaaa_record": {
message: &dns.Msg{Answer: []dns.RR{
&dns.AAAA{
Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeAAAA, Class: dns.ClassINET},
AAAA: net.IP{0x20, 0x01, 0x48, 0x60, 0x48, 0x60, 0, 0, 0, 0, 0, 0, 0, 0, 0x88, 0x88},
},
}},
expected: []netip.Addr{netip.MustParseAddr("2001:4860:4860::8888")},
},
"mixed_records": {
message: &dns.Msg{Answer: []dns.RR{
&dns.A{
Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET},
A: net.IP{1, 1, 1, 1},
},
&dns.AAAA{
Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeAAAA, Class: dns.ClassINET},
AAAA: net.IP{0x20, 0x01, 0x48, 0x60, 0x48, 0x60, 0, 0, 0, 0, 0, 0, 0, 0, 0x88, 0x88},
},
}},
expected: []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("2001:4860:4860::8888")},
},
}
for testName, testCase := range testCases {
t.Run(testName, func(t *testing.T) {
t.Parallel()
addresses := answersToNetipAddrs(testCase.message)
assert.Equal(t, testCase.expected, addresses)
})
}
}
-28
View File
@@ -1,28 +0,0 @@
package restrictednet
import (
"errors"
"github.com/qdm12/dns/v2/pkg/provider"
)
type Settings struct {
DefaultInterface string
IPv6Supported *bool
Firewall Firewall
UpstreamResolvers []provider.Provider
}
func (s *Settings) validate() error {
switch {
case s.DefaultInterface == "":
return errors.New("default interface is not set")
case s.IPv6Supported == nil:
return errors.New("IPv6 support field is not set")
case s.Firewall == nil:
return errors.New("firewall is not set")
case len(s.UpstreamResolvers) == 0:
return errors.New("no upstream resolvers provided")
}
return nil
}
-121
View File
@@ -1,121 +0,0 @@
//go:build !windows
package restrictednet
import (
"context"
"errors"
"fmt"
"net/netip"
"time"
"golang.org/x/sys/unix"
)
func closeFD(fd int) {
unix.Close(fd)
}
func newTCPSockStream(family int) (fd int, err error) {
fd, err = unix.Socket(family, unix.SOCK_STREAM, unix.IPPROTO_TCP)
if err != nil {
return 0, err
}
err = unix.SetNonblock(fd, true)
if err != nil {
_ = unix.Close(fd)
return 0, err
}
return fd, nil
}
func bindFD(fd int, address netip.AddrPort) error {
bindAddr := makeSockAddr(address)
return unix.Bind(fd, bindAddr)
}
func connectFD(ctx context.Context, fd int, destination netip.AddrPort) error {
err := unix.Connect(fd, makeSockAddr(destination))
switch {
case err == nil:
return nil
case !errors.Is(err, unix.EINPROGRESS):
return err
}
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
bitsIndex := fd / 64 //nolint:mnd
if bitsIndex >= len(unix.FdSet{}.Bits) {
return fmt.Errorf("fd %d exceeds unix.Select FdSet capacity", fd)
}
wset := &unix.FdSet{}
wset.Bits[bitsIndex] |= 1 << (uint64(fd) % 64) //nolint:gosec,mnd
eset := &unix.FdSet{}
eset.Bits[bitsIndex] |= 1 << (uint64(fd) % 64) //nolint:gosec,mnd
const selectTimeout = 50 * time.Millisecond
timeval := unix.NsecToTimeval(int64(selectTimeout))
// Wait for the FD to become writable or hit an error state
n, err := unix.Select(fd+1, nil, wset, eset, &timeval)
if err != nil {
if errors.Is(err, unix.EINTR) {
continue // Syscall interrupted, try again
}
return fmt.Errorf("select error: %w", err)
} else if n == 0 {
continue // no status change yet
}
// Check if the socket encountered an error
n, err = unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_ERROR)
if err != nil {
return fmt.Errorf("getsockopt error: %w", err)
} else if n != 0 {
return fmt.Errorf("connect failed asynchronously: %w", unix.Errno(n))
}
return nil
}
}
}
func fdToSourceAddr(fd int) (sourceAddrPort netip.AddrPort, err error) {
sockAddr, err := unix.Getsockname(fd)
if err != nil {
return netip.AddrPort{}, fmt.Errorf("getting sockname: %w", err)
}
sourceAddrPort, err = sockAddrToAddrPort(sockAddr)
if err != nil {
return netip.AddrPort{}, err
}
return sourceAddrPort, nil
}
func makeSockAddr(addressPort netip.AddrPort) unix.Sockaddr {
if addressPort.Addr().Is4() {
return &unix.SockaddrInet4{
Port: int(addressPort.Port()),
Addr: addressPort.Addr().As4(),
}
}
return &unix.SockaddrInet6{
Port: int(addressPort.Port()),
Addr: addressPort.Addr().As16(),
}
}
func sockAddrToAddrPort(sockAddr unix.Sockaddr) (addrPort netip.AddrPort, err error) {
switch typedSockAddr := sockAddr.(type) {
case *unix.SockaddrInet4:
return netip.AddrPortFrom(netip.AddrFrom4(typedSockAddr.Addr), uint16(typedSockAddr.Port)), nil //nolint:gosec
case *unix.SockaddrInet6:
return netip.AddrPortFrom(netip.AddrFrom16(typedSockAddr.Addr), uint16(typedSockAddr.Port)), nil //nolint:gosec
default:
return netip.AddrPort{}, fmt.Errorf("unexpected socket address type %T", typedSockAddr)
}
}
-28
View File
@@ -1,28 +0,0 @@
//go:build windows
package restrictednet
import (
"context"
"net/netip"
)
func closeFD(fd int) {
panic("not implemented")
}
func newTCPSockStream(family int) (fd int, err error) {
panic("not implemented")
}
func bindFD(fd int, address netip.AddrPort) error {
panic("not implemented")
}
func connectFD(ctx context.Context, fd int, destination netip.AddrPort) error {
panic("not implemented")
}
func fdToSourceAddr(fd int) (sourceAddrPort netip.AddrPort, err error) {
panic("not implemented")
}
+3
View File
@@ -43,6 +43,7 @@ type cmdType byte
const (
connect cmdType = 1
bind cmdType = 2
udpAssociate cmdType = 3
)
@@ -50,6 +51,8 @@ func (c cmdType) String() string {
switch c {
case connect:
return "connect"
case bind:
return "bind"
case udpAssociate:
return "UDP associate"
default:
+1 -1
View File
@@ -10,7 +10,7 @@ import (
)
// See https://datatracker.ietf.org/doc/html/rfc1928#section-6
func (c *socksConn) encodeFailedResponse(writer io.Writer, socksVersion byte, reply replyCode) {
func (c *socksConn) encodeFailedResponse(writer io.Writer, socksVersion byte, reply replyCode) { //nolint:unparam
_, err := writer.Write([]byte{
socksVersion,
byte(reply),
+49 -89
View File
@@ -2,7 +2,6 @@ package socks5
import (
"context"
"errors"
"fmt"
"net"
"sync"
@@ -16,13 +15,12 @@ type server struct {
logger Logger
// internal fields
tcpListener net.Listener
udpRouter *udpRouter
listener net.Listener
listening atomic.Bool
socksConnCtx context.Context //nolint:containedctx
socksConnCancel context.CancelFunc
done <-chan error
stopCh chan<- struct{}
done <-chan struct{}
stopping atomic.Bool
}
func newServer(settings Settings) *server {
@@ -41,28 +39,19 @@ func (s *server) String() string {
func (s *server) Start(ctx context.Context) (runErr <-chan error, err error) {
s.socksConnCtx, s.socksConnCancel = context.WithCancel(context.Background())
config := &net.ListenConfig{}
s.tcpListener, err = config.Listen(ctx, "tcp", s.address)
s.listener, err = config.Listen(ctx, "tcp", s.address)
if err != nil {
return nil, fmt.Errorf("TCP listening on %s: %w", s.address, err)
}
s.udpRouter, err = newUDPRouter(ctx, s.address, s.logger)
if err != nil {
_ = s.tcpListener.Close()
return nil, fmt.Errorf("creating UDP router: %w", err)
return nil, fmt.Errorf("listening on %s: %w", s.address, err)
}
s.listening.Store(true)
s.logger.Infof("SOCKS5 TCP server listening on %s", s.tcpListener.Addr())
s.logger.Infof("SOCKS5 UDP server listening on %s", s.udpRouter.localAddress())
s.logger.Infof("SOCKS5 server listening on %s", s.listener.Addr())
ready := make(chan struct{})
runErrCh := make(chan error)
runErr = runErrCh
done := make(chan error)
done := make(chan struct{})
s.done = done
stop := make(chan struct{})
s.stopCh = stop
go s.runServer(ready, runErrCh, stop, done)
go s.runServer(ready, runErrCh, done)
select {
case <-ready:
case <-ctx.Done():
@@ -73,90 +62,61 @@ func (s *server) Start(ctx context.Context) (runErr <-chan error, err error) {
}
func (s *server) runServer(ready chan<- struct{},
runErrCh chan<- error, stop <-chan struct{}, done chan<- error,
runErrCh chan<- error, done chan<- struct{},
) {
close(ready)
defer close(done)
wg := new(sync.WaitGroup)
defer wg.Wait()
udpErrCh := make(chan error)
go func() {
udpErrCh <- s.udpRouter.run(s.socksConnCtx)
}()
tcpErrCh := make(chan error)
go func() {
var wg sync.WaitGroup
defer wg.Wait()
dialer := &net.Dialer{}
for {
connection, err := s.tcpListener.Accept()
if err != nil {
s.socksConnCancel() // stop ongoing TCP socks connections - no impact on UDP
tcpErrCh <- fmt.Errorf("accepting connection: %w", err)
return
dialer := &net.Dialer{}
for {
connection, err := s.listener.Accept()
if err != nil {
if !s.stopping.Load() {
_ = s.stop()
runErrCh <- fmt.Errorf("accepting connection: %w", err)
}
wg.Go(func() {
connection := connection // capture loop variable
socksConn := &socksConn{
dialer: dialer,
username: s.username,
password: s.password,
clientConn: connection,
udpRouter: s.udpRouter,
logger: s.logger,
}
err := socksConn.run(s.socksConnCtx)
if err != nil {
s.logger.Infof("running socks connection: %s", err)
}
})
return
}
}()
select {
case <-stop:
s.listening.Store(false)
var errs []error
err := s.tcpListener.Close()
if err != nil {
errs = append(errs, fmt.Errorf("closing TCP listener: %w", err))
}
// stop ongoing TCP socks connections. This impacts the udpRouter run error when it is being closed.
s.socksConnCancel()
<-tcpErrCh // wait for TCP server to stop
err = s.udpRouter.close()
if err != nil {
errs = append(errs, fmt.Errorf("closing UDP router: %w", err))
}
<-udpErrCh // wait for UDP router to stop
if len(errs) > 0 {
// Only write to the done channel if the [server.Stop] method is waiting to read from it
done <- errors.Join(errs...)
}
// If no error, the done channel is closed so the error is effectively `nil`
// Note: do NOT write an error the runError channel, since we are stopping the server gracefully.
case err := <-udpErrCh:
_ = s.tcpListener.Close() // stop accepting new TCP connections
s.socksConnCancel() // stop ongoing TCP socks connections
<-tcpErrCh // wait for TCP server to stop
runErrCh <- fmt.Errorf("running UDP router: %w", err)
case err := <-tcpErrCh:
s.socksConnCancel()
_ = s.udpRouter.close() // stop UDP router
<-udpErrCh // wait for UDP router to stop
runErrCh <- fmt.Errorf("running TCP server: %w", err)
wg.Add(1)
go func(ctx context.Context, connection net.Conn,
dialer *net.Dialer, wg *sync.WaitGroup,
) {
defer wg.Done()
socksConn := &socksConn{
dialer: dialer,
username: s.username,
password: s.password,
clientConn: connection,
logger: s.logger,
}
err := socksConn.run(ctx)
if err != nil {
s.logger.Infof("running socks connection: %s", err)
}
}(s.socksConnCtx, connection, dialer, wg)
}
}
func (s *server) Stop() (err error) {
close(s.stopCh)
return <-s.done
s.stopping.Store(true)
err = s.stop()
<-s.done // wait for run goroutine to finish
s.stopping.Store(false)
return err
}
func (s *server) stop() error {
s.listening.Store(false)
err := s.listener.Close()
s.socksConnCancel() // stop ongoing socks connections
return err
}
func (s *server) listeningAddress() net.Addr {
if s.listening.Load() {
return s.tcpListener.Addr()
return s.listener.Addr()
}
return nil
}
+1 -249
View File
@@ -10,7 +10,6 @@ import (
"net/netip"
"strconv"
"strings"
"sync"
)
var (
@@ -24,7 +23,6 @@ type socksConn struct {
username string
password string
clientConn net.Conn
udpRouter *udpRouter
logger Logger
}
@@ -111,29 +109,11 @@ func (c *socksConn) handleRequest(ctx context.Context) error {
c.encodeFailedResponse(c.clientConn, socksVersion, generalServerFailure)
return err
}
switch request.command {
case connect:
err = c.handleConnectRequest(ctx, socksVersion, request)
if err != nil {
return fmt.Errorf("handling %s request: %w", request.command, err)
}
return nil
case udpAssociate:
err = c.handleUDPAssociateRequest(ctx, socksVersion, request)
if err != nil {
return fmt.Errorf("handling %s request: %w", request.command, err)
}
return nil
default:
if request.command != connect {
c.encodeFailedResponse(c.clientConn, socksVersion, commandNotSupported)
return fmt.Errorf("command %s is not supported", request.command)
}
}
func (c *socksConn) handleConnectRequest(ctx context.Context,
socksVersion byte, request request,
) error {
destinationAddress := net.JoinHostPort(request.destination, fmt.Sprint(request.port))
destinationConn, err := c.dialer.DialContext(ctx, "tcp", destinationAddress)
if err != nil {
@@ -196,234 +176,6 @@ func (c *socksConn) handleConnectRequest(ctx context.Context,
}
}
func (c *socksConn) handleUDPAssociateRequest(ctx context.Context,
socksVersion byte, request request,
) error {
expectedAddrPort, err := udpAssociateExpectedClientEndpoint(request)
if err != nil {
c.encodeFailedResponse(c.clientConn, socksVersion, addressTypeNotSupported)
return fmt.Errorf("deriving expected client address and port from request: %w", err)
}
bindAddress, bindPort, bindAddrType, err := c.udpAssociationAddresses()
if err != nil {
c.encodeFailedResponse(c.clientConn, socksVersion, generalServerFailure)
return fmt.Errorf("getting udp association addresses: %w", err)
}
association, err := c.udpRouter.registerAssociation(c.clientConn, expectedAddrPort)
if err != nil {
c.encodeFailedResponse(c.clientConn, socksVersion, generalServerFailure)
return fmt.Errorf("registering udp association: %w", err)
}
defer c.udpRouter.unregisterAssociation(association)
err = c.encodeSuccessResponse(c.clientConn, socksVersion, succeeded,
bindAddrType, bindAddress, bindPort)
if err != nil {
c.encodeFailedResponse(c.clientConn, socksVersion, generalServerFailure)
return fmt.Errorf("writing successful %s response: %w", udpAssociate, err)
}
associationCtx, associationCancel := context.WithCancel(ctx)
defer associationCancel()
var wg sync.WaitGroup
wg.Go(func() {
c.udpRouter.runAssociationHandler(associationCtx, association)
})
wg.Go(func() {
_, _ = io.Copy(io.Discard, c.clientConn)
associationCancel()
})
<-associationCtx.Done()
wg.Wait()
return nil
}
func udpAssociateExpectedClientEndpoint(request request) (expectedAddrPort netip.AddrPort, err error) {
switch request.addressType {
case ipv4, ipv6:
expectedClientAddress, parseErr := netip.ParseAddr(request.destination)
if parseErr != nil {
return netip.AddrPort{}, fmt.Errorf("parsing destination address: %w", parseErr)
}
expectedClientAddress = expectedClientAddress.Unmap()
if !expectedClientAddress.IsUnspecified() {
return netip.AddrPortFrom(expectedClientAddress, request.port), nil
}
return netip.AddrPortFrom(netip.Addr{}, request.port), nil
case domainName:
// For UDP associate, client endpoint matching is based on observed UDP source
// address/port. A hostname is not directly matchable at this stage, so we
// ignore the domain name request destination entirely.
return netip.AddrPortFrom(netip.Addr{}, request.port), nil
default:
return netip.AddrPort{}, fmt.Errorf("address type %d is not supported", request.addressType)
}
}
func (c *socksConn) udpAssociationAddresses() (bindAddress string,
bindPort uint16, bindAddrType addrType, err error,
) {
localAddress := c.udpRouter.localAddress().String()
host, portString, err := net.SplitHostPort(localAddress)
if err != nil {
return "", 0, 0, fmt.Errorf("splitting local address: %w", err)
}
port, err := strconv.ParseUint(portString, 10, 16)
if err != nil {
return "", 0, 0, fmt.Errorf("parsing local port: %w", err)
}
bindAddress = host
bindPort = uint16(port)
if isUnspecifiedIPAddress(bindAddress) {
controlLocalAddress := c.clientConn.LocalAddr().String()
controlLocalHost, _, splitErr := net.SplitHostPort(controlLocalAddress)
if splitErr != nil {
return "", 0, 0, fmt.Errorf("splitting control connection local address: %w", splitErr)
}
bindAddress = controlLocalHost
}
ipAddress := net.ParseIP(bindAddress)
if ipAddress == nil {
bindAddrType = domainName
return bindAddress, bindPort, bindAddrType, nil
}
if ipAddress.To4() != nil {
bindAddrType = ipv4
} else {
bindAddrType = ipv6
}
return bindAddress, bindPort, bindAddrType, nil
}
func isUnspecifiedIPAddress(address string) bool {
ipAddress, err := netip.ParseAddr(address)
if err != nil {
return false
}
return ipAddress.IsUnspecified()
}
func decodeUDPDatagram(packet []byte) (destination string, payload []byte, err error) {
const minimumPacketLength = 4
if len(packet) < minimumPacketLength {
return "", nil, fmt.Errorf("packet is too short: %d", len(packet))
}
if packet[0] != 0 || packet[1] != 0 {
return "", nil, fmt.Errorf("reserved bytes are invalid: %x %x", packet[0], packet[1])
}
if packet[2] != 0 {
return "", nil, fmt.Errorf("fragmentation is not supported")
}
offset := 3
addressType := addrType(packet[offset])
offset++
switch addressType {
case ipv4:
const ipv4Length = 4
if len(packet) < offset+ipv4Length+2 {
return "", nil, fmt.Errorf("packet is too short for IPv4 address")
}
var ip [ipv4Length]byte
copy(ip[:], packet[offset:offset+ipv4Length])
destination = netip.AddrFrom4(ip).String()
offset += ipv4Length
case ipv6:
const ipv6Length = 16
if len(packet) < offset+ipv6Length+2 {
return "", nil, fmt.Errorf("packet is too short for IPv6 address")
}
var ip [ipv6Length]byte
copy(ip[:], packet[offset:offset+ipv6Length])
destination = netip.AddrFrom16(ip).String()
offset += ipv6Length
case domainName:
if len(packet) < offset+1 {
return "", nil, fmt.Errorf("packet is too short for domain name length")
}
domainNameLength := int(packet[offset])
offset++
if len(packet) < offset+domainNameLength+2 {
return "", nil, fmt.Errorf("packet is too short for domain name")
}
destination = string(packet[offset : offset+domainNameLength])
offset += domainNameLength
default:
return "", nil, fmt.Errorf("address type is not supported: %d", addressType)
}
port := binary.BigEndian.Uint16(packet[offset : offset+2])
destination = net.JoinHostPort(destination, fmt.Sprint(port))
offset += 2
payload = packet[offset:]
return destination, payload, nil
}
func encodeUDPDatagramToBuffer(writer io.Writer, sourceAddrPort netip.AddrPort,
payload []byte,
) error {
address := sourceAddrPort.Addr()
if !address.IsValid() {
return errors.New("source address is not valid")
}
err := writeUDPDatagramSourceAddress(writer, address)
if err != nil {
return fmt.Errorf("writing source address: %w", err)
}
var portBytes [2]byte
binary.BigEndian.PutUint16(portBytes[:], sourceAddrPort.Port())
_, err = writer.Write(portBytes[:])
if err != nil {
return fmt.Errorf("writing destination port: %w", err)
}
_, err = writer.Write(payload)
if err != nil {
return fmt.Errorf("writing payload: %w", err)
}
return nil
}
func writeUDPDatagramSourceAddress(writer io.Writer, address netip.Addr) error {
var addrType addrType
var addressBytes []byte
switch {
case address.Is4():
addrType = ipv4
array := address.As4()
addressBytes = array[:]
case address.Is6():
addrType = ipv6
array := address.As16()
addressBytes = array[:]
default:
return fmt.Errorf("address type is not supported: %v", address)
}
_, err := writer.Write([]byte{0, 0, 0, byte(addrType)})
if err != nil {
return fmt.Errorf("writing header: %w", err)
}
_, err = writer.Write(addressBytes)
if err != nil {
return fmt.Errorf("writing IP address: %w", err)
}
return nil
}
// See https://datatracker.ietf.org/doc/html/rfc1928#section-3
func verifyFirstNegotiation(reader io.Reader, requiredMethod authMethod) error {
const headerLength = 2 // version + nMethods bytes
+33 -497
View File
@@ -2,13 +2,9 @@ package socks5
import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"net/netip"
"strconv"
"strings"
"testing"
@@ -100,178 +96,6 @@ func TestServerProxy(t *testing.T) {
}
}
func TestServerProxyTCPAndUDPParallel(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
username string
password string
}{
"no_auth": {},
"with_auth": {
username: "user",
password: "pass",
},
}
for name, testCase := range testCases {
t.Run(name, func(t *testing.T) {
t.Parallel()
backendTCPListener, err := (&net.ListenConfig{}).Listen(t.Context(), "tcp", "127.0.0.1:0")
require.NoError(t, err)
backendTCPConnChannel := make(chan net.Conn, 1)
go func() {
connection, err := backendTCPListener.Accept()
if err != nil {
return
}
backendTCPConnChannel <- connection
}()
backendUDPPacketConn, err := (&net.ListenConfig{}).ListenPacket(t.Context(), "udp", "127.0.0.1:0")
require.NoError(t, err)
server := newServer(Settings{
Username: testCase.username,
Password: testCase.password,
Address: "127.0.0.1:0",
Logger: noopLogger{},
})
_, err = server.Start(t.Context())
require.NoError(t, err)
t.Cleanup(func() {
_ = server.Stop()
_ = backendTCPListener.Close()
_ = backendUDPPacketConn.Close()
})
clientTCPConn := dialSOCKS5(t, server.listeningAddress().String(),
backendTCPListener.Addr().String(), testCase.username, testCase.password)
defer clientTCPConn.Close()
backendTCPConn := <-backendTCPConnChannel
defer backendTCPConn.Close()
udpControlConn, clientUDPConn := dialSOCKS5UDPAssociate(t,
server.listeningAddress().String(), testCase.username, testCase.password)
defer udpControlConn.Close()
defer clientUDPConn.Close()
tcpErrCh := make(chan error, 1)
go func() {
tcpErrCh <- runTCPProxyRoundTrip(clientTCPConn, backendTCPConn)
}()
udpErrCh := make(chan error, 1)
go func() {
udpErrCh <- runUDPProxyRoundTrip(t.Context(), clientUDPConn, backendUDPPacketConn)
}()
err = <-tcpErrCh
require.NoError(t, err)
err = <-udpErrCh
require.NoError(t, err)
})
}
}
func runTCPProxyRoundTrip(clientTCPConn net.Conn, backendTCPConn net.Conn) error {
clientMessage := []byte("hello from client")
_, err := clientTCPConn.Write(clientMessage)
if err != nil {
return err
}
received := make([]byte, len(clientMessage))
_, err = io.ReadFull(backendTCPConn, received)
if err != nil {
return err
}
if !bytes.Equal(clientMessage, received) {
return errors.New("backend did not receive expected TCP payload")
}
backendMessage := []byte("hello from backend")
_, err = backendTCPConn.Write(backendMessage)
if err != nil {
return err
}
receivedByClient := make([]byte, len(backendMessage))
_, err = io.ReadFull(clientTCPConn, receivedByClient)
if err != nil {
return err
}
if !bytes.Equal(backendMessage, receivedByClient) {
return errors.New("client did not receive expected TCP payload")
}
return nil
}
func runUDPProxyRoundTrip(ctx context.Context, clientUDPConn *net.UDPConn, backendUDPPacketConn net.PacketConn) error {
udpPayload := []byte("hello from udp client")
udpRequest, err := makeSOCKS5UDPDatagram(backendUDPPacketConn.LocalAddr().String(), udpPayload)
if err != nil {
return err
}
_, err = clientUDPConn.Write(udpRequest)
if err != nil {
return err
}
deadline, hasDeadline := ctx.Deadline()
if hasDeadline {
err = backendUDPPacketConn.SetReadDeadline(deadline)
if err != nil {
return fmt.Errorf("setting read deadline on backend connection: %w", err)
}
}
const bufferSize = 512
backendReadBuffer := make([]byte, bufferSize)
packetLength, proxyAddress, err := backendUDPPacketConn.ReadFrom(backendReadBuffer)
if err != nil {
return err
}
if !bytes.Equal(udpPayload, backendReadBuffer[:packetLength]) {
return errors.New("backend did not receive expected UDP payload")
}
backendUDPReply := []byte("hello from udp backend")
_, err = backendUDPPacketConn.WriteTo(backendUDPReply, proxyAddress)
if err != nil {
return err
}
if hasDeadline {
err = clientUDPConn.SetReadDeadline(deadline)
if err != nil {
return fmt.Errorf("setting read deadline on client connection: %w", err)
}
}
udpResponseBuffer := make([]byte, 1024)
responseLength, err := clientUDPConn.Read(udpResponseBuffer)
if err != nil {
return err
}
destinationAddress, udpResponsePayload, err := parseSOCKS5UDPDatagram(udpResponseBuffer[:responseLength])
if err != nil {
return err
}
if !bytes.Equal(backendUDPReply, udpResponsePayload) {
return errors.New("client did not receive expected UDP payload")
}
if destinationAddress != backendUDPPacketConn.LocalAddr().String() {
return errors.New("udp response destination address mismatch")
}
return nil
}
// dialSOCKS5 performs the full SOCKS5 handshake (with optional username/password
// subnegotiation) and returns a connected net.Conn ready for data exchange.
func dialSOCKS5(t *testing.T, proxyAddr, targetAddr, username, password string) net.Conn {
@@ -285,55 +109,6 @@ func dialSOCKS5(t *testing.T, proxyAddr, targetAddr, username, password string)
conn, err := (&net.Dialer{}).DialContext(t.Context(), "tcp", proxyAddr)
require.NoError(t, err)
negotiateSOCKS5(t, conn, username, password)
var connectRequest []byte
if ip := net.ParseIP(host).To4(); ip != nil {
connectRequest = []byte{socks5Version, byte(connect), 0, byte(ipv4)}
connectRequest = append(connectRequest, ip...)
} else {
connectRequest = []byte{socks5Version, byte(connect), 0, byte(domainName), byte(len(host))}
connectRequest = append(connectRequest, []byte(host)...)
}
connectRequest = binary.BigEndian.AppendUint16(connectRequest, uint16(targetPort)) //nolint:gosec
_, err = conn.Write(connectRequest)
require.NoError(t, err)
_, err = readSOCKS5ResponseAddress(t, conn)
require.NoError(t, err)
return conn
}
func dialSOCKS5UDPAssociate(t *testing.T, proxyAddr, username, password string) (net.Conn, *net.UDPConn) {
t.Helper()
controlConn, err := (&net.Dialer{}).DialContext(t.Context(), "tcp", proxyAddr)
require.NoError(t, err)
negotiateSOCKS5(t, controlConn, username, password)
udpAssociateRequest := []byte{socks5Version, byte(udpAssociate), 0, byte(ipv4), 0, 0, 0, 0, 0, 0}
_, err = controlConn.Write(udpAssociateRequest)
require.NoError(t, err)
udpProxyAddress, err := readSOCKS5ResponseAddress(t, controlConn)
require.NoError(t, err)
udpProxyResolvedAddress, err := net.ResolveUDPAddr("udp", udpProxyAddress)
require.NoError(t, err)
udpConn, err := net.DialUDP("udp", nil, udpProxyResolvedAddress)
require.NoError(t, err)
return controlConn, udpConn
}
func negotiateSOCKS5(t *testing.T, conn net.Conn, username, password string) {
t.Helper()
var err error
var method authMethod
if username != "" || password != "" {
method = authUsernamePassword
@@ -363,146 +138,45 @@ func negotiateSOCKS5(t *testing.T, conn net.Conn, username, password string) {
require.Equal(t, authUsernamePasswordSubNegotiation1, subnegResp[0])
require.Equal(t, byte(0), subnegResp[1])
}
}
func readSOCKS5ResponseAddress(t *testing.T, conn net.Conn) (address string, err error) {
t.Helper()
var connectRequest []byte
if ip := net.ParseIP(host).To4(); ip != nil {
connectRequest = []byte{socks5Version, byte(connect), 0, byte(ipv4)}
connectRequest = append(connectRequest, ip...)
} else {
connectRequest = []byte{socks5Version, byte(connect), 0, byte(domainName), byte(len(host))}
connectRequest = append(connectRequest, []byte(host)...)
}
connectRequest = binary.BigEndian.AppendUint16(connectRequest, uint16(targetPort)) //nolint:gosec
_, err = conn.Write(connectRequest)
require.NoError(t, err)
var responseHeader [4]byte
_, err = io.ReadFull(conn, responseHeader[:])
if err != nil {
return "", err
}
if responseHeader[0] != socks5Version {
return "", errors.New("version mismatch")
}
if responseHeader[1] != byte(succeeded) {
return "", errors.New("request was not successful")
}
require.NoError(t, err)
require.Equal(t, socks5Version, responseHeader[0])
require.Equal(t, byte(succeeded), responseHeader[1])
var host string
// Consume BND.ADDR and BND.PORT (their values are irrelevant to the caller).
switch addrType(responseHeader[3]) {
case ipv4:
addressAndPort := make([]byte, net.IPv4len+2)
_, err = io.ReadFull(conn, addressAndPort)
if err != nil {
return "", err
}
host = net.IP(addressAndPort[:net.IPv4len]).String()
port := binary.BigEndian.Uint16(addressAndPort[net.IPv4len:])
return net.JoinHostPort(host, strconv.FormatUint(uint64(port), 10)), nil
var addrPort [net.IPv4len + 2]byte
_, err = io.ReadFull(conn, addrPort[:])
require.NoError(t, err)
case ipv6:
addressAndPort := make([]byte, net.IPv6len+2)
_, err = io.ReadFull(conn, addressAndPort)
if err != nil {
return "", err
}
host = net.IP(addressAndPort[:net.IPv6len]).String()
port := binary.BigEndian.Uint16(addressAndPort[net.IPv6len:])
return net.JoinHostPort(host, strconv.FormatUint(uint64(port), 10)), nil
var addrPort [net.IPv6len + 2]byte
_, err = io.ReadFull(conn, addrPort[:])
require.NoError(t, err)
case domainName:
var lengthBuffer [1]byte
_, err = io.ReadFull(conn, lengthBuffer[:])
if err != nil {
return "", err
}
domainAndPort := make([]byte, int(lengthBuffer[0])+2)
_, err = io.ReadFull(conn, domainAndPort)
if err != nil {
return "", err
}
host = string(domainAndPort[:len(domainAndPort)-2])
port := binary.BigEndian.Uint16(domainAndPort[len(domainAndPort)-2:])
return net.JoinHostPort(host, strconv.FormatUint(uint64(port), 10)), nil
default:
return "", errors.New("unknown address type")
}
}
func makeSOCKS5UDPDatagram(targetAddress string, payload []byte) ([]byte, error) {
host, portString, err := net.SplitHostPort(targetAddress)
if err != nil {
return nil, err
}
port, err := strconv.ParseUint(portString, 10, 16)
if err != nil {
return nil, err
var lenBuf [1]byte
_, err = io.ReadFull(conn, lenBuf[:])
require.NoError(t, err)
addrPort := make([]byte, int(lenBuf[0])+2)
_, err = io.ReadFull(conn, addrPort)
require.NoError(t, err)
}
datagram := []byte{0, 0, 0}
ipAddress := net.ParseIP(host)
if ipAddress != nil {
if ipAddress.To4() != nil {
datagram = append(datagram, byte(ipv4))
datagram = append(datagram, ipAddress.To4()...)
} else {
datagram = append(datagram, byte(ipv6))
datagram = append(datagram, ipAddress.To16()...)
}
} else {
if len(host) > 255 {
return nil, errors.New("domain name too long")
}
datagram = append(datagram, byte(domainName), byte(len(host)))
datagram = append(datagram, []byte(host)...)
}
datagram = binary.BigEndian.AppendUint16(datagram, uint16(port))
datagram = append(datagram, payload...)
return datagram, nil
}
func parseSOCKS5UDPDatagram(datagram []byte) (destinationAddress string, payload []byte, err error) {
if len(datagram) < 4 {
return "", nil, errors.New("datagram too short")
}
if datagram[0] != 0 || datagram[1] != 0 {
return "", nil, errors.New("invalid reserved header")
}
if datagram[2] != 0 {
return "", nil, errors.New("fragments are not supported")
}
offset := 3
var host string
switch addrType(datagram[offset]) {
case ipv4:
offset++
if len(datagram) < offset+net.IPv4len+2 {
return "", nil, errors.New("datagram too short for IPv4")
}
host = net.IP(datagram[offset : offset+net.IPv4len]).String()
offset += net.IPv4len
case ipv6:
offset++
if len(datagram) < offset+net.IPv6len+2 {
return "", nil, errors.New("datagram too short for IPv6")
}
host = net.IP(datagram[offset : offset+net.IPv6len]).String()
offset += net.IPv6len
case domainName:
offset++
if len(datagram) < offset+1 {
return "", nil, errors.New("datagram too short for domain length")
}
domainLength := int(datagram[offset])
offset++
if len(datagram) < offset+domainLength+2 {
return "", nil, errors.New("datagram too short for domain")
}
host = string(datagram[offset : offset+domainLength])
offset += domainLength
default:
return "", nil, errors.New("unknown address type")
}
if len(datagram) < offset+2 {
return "", nil, errors.New("datagram too short for port")
}
port := binary.BigEndian.Uint16(datagram[offset : offset+2])
offset += 2
return net.JoinHostPort(host, strconv.FormatUint(uint64(port), 10)), datagram[offset:], nil
return conn
}
func Test_newServer(t *testing.T) {
@@ -550,8 +224,7 @@ func Test_Server_StartStop(t *testing.T) {
ctrl := gomock.NewController(t)
logger := NewMockLogger(ctrl)
logger.EXPECT().Infof("SOCKS5 TCP server listening on %s", gomock.Any())
logger.EXPECT().Infof("SOCKS5 UDP server listening on %s", gomock.Any())
logger.EXPECT().Infof("SOCKS5 server listening on %s", gomock.Any())
server := newServer(Settings{
Address: "127.0.0.1:0",
@@ -704,70 +377,6 @@ func Test_decodeRequest(t *testing.T) {
}
}
func Test_udpAssociateExpectedClientEndpoint(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
request request
expected netip.AddrPort
expectedErr string
}{
"ipv4_endpoint": {
request: request{
addressType: ipv4,
destination: "192.0.2.10",
port: 5555,
},
expected: netip.MustParseAddrPort("192.0.2.10:5555"),
},
"ipv4_unspecified_address": {
request: request{
addressType: ipv4,
destination: "0.0.0.0",
port: 6000,
},
expected: netip.AddrPortFrom(netip.Addr{}, 6000),
},
"domain_name_with_port": {
request: request{
addressType: domainName,
destination: "client.example",
port: 7000,
},
expected: netip.AddrPortFrom(netip.Addr{}, 7000),
},
"domain_name_without_port": {
request: request{
addressType: domainName,
destination: "client.example",
},
expected: netip.AddrPort{},
},
"unsupported_address_type": {
request: request{
addressType: 255,
},
expectedErr: "address type 255 is not supported",
},
}
for name, testCase := range testCases {
t.Run(name, func(t *testing.T) {
t.Parallel()
result, err := udpAssociateExpectedClientEndpoint(testCase.request)
if testCase.expectedErr != "" {
assert.ErrorContains(t, err, testCase.expectedErr)
return
}
assert.NoError(t, err)
assert.Equal(t, testCase.expected, result)
})
}
}
func Test_verifyFirstNegotiation(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
@@ -989,6 +598,10 @@ func Test_cmdType_String(t *testing.T) {
cmd: connect,
expectedName: "connect",
},
"bind": {
cmd: bind,
expectedName: "bind",
},
"udp_associate": {
cmd: udpAssociate,
expectedName: "UDP associate",
@@ -1007,80 +620,3 @@ func Test_cmdType_String(t *testing.T) {
})
}
}
func Test_socksConn_udpAssociationAddresses(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
routerAddress string
expectAddressFromConn bool
expectedAddress string
}{
"wildcard_router_address_uses_control_connection_local_ip": {
routerAddress: ":0",
expectAddressFromConn: true,
},
"concrete_router_address_is_kept": {
routerAddress: "127.0.0.1:0",
expectedAddress: "127.0.0.1",
},
}
for name, testCase := range testCases {
t.Run(name, func(t *testing.T) {
t.Parallel()
router, err := newUDPRouter(t.Context(), testCase.routerAddress, noopLogger{})
require.NoError(t, err)
t.Cleanup(func() {
err := router.close()
assert.NoError(t, err)
})
controlListener, err := (&net.ListenConfig{}).Listen(t.Context(), "tcp", "127.0.0.1:0")
require.NoError(t, err)
t.Cleanup(func() {
err := controlListener.Close()
assert.NoError(t, err)
})
acceptedConnCh := make(chan net.Conn, 1)
go func() {
acceptedConn, acceptErr := controlListener.Accept()
if acceptErr != nil {
return
}
acceptedConnCh <- acceptedConn
}()
clientControlConn, err := (&net.Dialer{}).DialContext(t.Context(), "tcp", controlListener.Addr().String())
require.NoError(t, err)
defer clientControlConn.Close()
serverControlConn := <-acceptedConnCh
defer serverControlConn.Close()
socksConnection := &socksConn{
clientConn: clientControlConn,
udpRouter: router,
}
bindAddress, bindPort, bindAddrType, err := socksConnection.udpAssociationAddresses()
require.NoError(t, err)
if testCase.expectAddressFromConn {
clientLocalHost, _, err := net.SplitHostPort(clientControlConn.LocalAddr().String())
require.NoError(t, err)
assert.Equal(t, clientLocalHost, bindAddress)
} else {
assert.Equal(t, testCase.expectedAddress, bindAddress)
}
_, routerPortString, err := net.SplitHostPort(router.localAddress().String())
require.NoError(t, err)
routerPort, err := strconv.ParseUint(routerPortString, 10, 16)
require.NoError(t, err)
assert.Equal(t, uint16(routerPort), bindPort)
assert.Equal(t, ipv4, bindAddrType)
})
}
}
-370
View File
@@ -1,370 +0,0 @@
package socks5
import (
"bytes"
"context"
"errors"
"fmt"
"net"
"net/netip"
"sync"
)
type udpAssociation struct {
id uint64
clientAddrPort netip.AddrPort
expectedAddrPort netip.AddrPort
controlConnAddr netip.Addr
packetCh chan *bytes.Buffer
}
type udpRouter struct {
logger Logger
listener net.PacketConn
mutex sync.Mutex
bufferPool sync.Pool
nextAssociationID uint64
clientAddrPortToAssociation map[netip.AddrPort]udpAssociation
clientIPToPendingAssociations map[netip.Addr][]udpAssociation
associationIDToClientAddrPort map[uint64]netip.AddrPort
}
const (
maxUDPPacketLength = 65535
maxSOCKS5UDPDatagramOverhead = 3 + 1 + 16 + 2
pooledUDPPacketBufferCapacity = maxUDPPacketLength + maxSOCKS5UDPDatagramOverhead
)
func newUDPRouter(ctx context.Context, address string, logger Logger) (router *udpRouter, err error) {
config := &net.ListenConfig{}
listener, err := config.ListenPacket(ctx, "udp", address)
if err != nil {
return nil, fmt.Errorf("UDP listening: %w", err)
}
return &udpRouter{
logger: logger,
listener: listener,
bufferPool: sync.Pool{
New: func() any {
return bytes.NewBuffer(make([]byte, 0, pooledUDPPacketBufferCapacity))
},
},
nextAssociationID: 1,
clientAddrPortToAssociation: make(map[netip.AddrPort]udpAssociation),
clientIPToPendingAssociations: make(map[netip.Addr][]udpAssociation),
associationIDToClientAddrPort: make(map[uint64]netip.AddrPort),
}, nil
}
func (r *udpRouter) localAddress() net.Addr {
return r.listener.LocalAddr()
}
func (r *udpRouter) close() error {
return r.listener.Close()
}
func (r *udpRouter) registerAssociation(controlConn net.Conn, expectedAddrPort netip.AddrPort) (udpAssociation, error) {
controlConnAddrPort, err := netip.ParseAddrPort(controlConn.RemoteAddr().String())
if err != nil {
return udpAssociation{}, fmt.Errorf("parsing control connection address: %w", err)
}
controlConnAddr := controlConnAddrPort.Addr().Unmap()
r.mutex.Lock()
defer r.mutex.Unlock()
const udpPacketChannelBuffer = 2
associationID := r.nextAssociationID
r.nextAssociationID++
association := udpAssociation{
id: associationID,
expectedAddrPort: expectedAddrPort,
controlConnAddr: controlConnAddr,
packetCh: make(chan *bytes.Buffer, udpPacketChannelBuffer),
}
if expectedAddrPort.Addr().IsValid() && expectedAddrPort.Port() != 0 {
association.clientAddrPort = expectedAddrPort
r.clientAddrPortToAssociation[association.clientAddrPort] = association
r.associationIDToClientAddrPort[association.id] = association.clientAddrPort
return association, nil
}
pendingAssociations := r.clientIPToPendingAssociations[controlConnAddr]
pendingAssociations = append(pendingAssociations, association)
r.clientIPToPendingAssociations[controlConnAddr] = pendingAssociations
return association, nil
}
func (r *udpRouter) unregisterAssociation(association udpAssociation) {
r.mutex.Lock()
defer r.mutex.Unlock()
clientAddrPort, hasClientAddress := r.associationIDToClientAddrPort[association.id]
if hasClientAddress {
delete(r.associationIDToClientAddrPort, association.id)
delete(r.clientAddrPortToAssociation, clientAddrPort)
}
pendingAssociations := r.clientIPToPendingAssociations[association.controlConnAddr]
for i, pendingAssociation := range pendingAssociations {
if pendingAssociation.id == association.id {
pendingAssociations = append(pendingAssociations[:i], pendingAssociations[i+1:]...)
break
}
}
if len(pendingAssociations) == 0 {
delete(r.clientIPToPendingAssociations, association.controlConnAddr)
} else {
r.clientIPToPendingAssociations[association.controlConnAddr] = pendingAssociations
}
}
func (r *udpRouter) run(ctx context.Context) error {
packetBuffer := make([]byte, maxUDPPacketLength)
for {
packetLength, sourceAddress, err := r.listener.ReadFrom(packetBuffer)
if err != nil {
if ctx.Err() != nil && errors.Is(err, net.ErrClosed) {
return nil
}
return fmt.Errorf("reading UDP packet: %w", err)
}
sourceAddrPort, err := netAddrToNetipAddrPort(sourceAddress)
if err != nil {
r.logger.Warnf("parsing source address: %s", err)
continue
}
buffer := r.bufferPool.Get().(*bytes.Buffer) //nolint:forcetypeassert
buffer.Reset()
_, err = buffer.Write(packetBuffer[:packetLength])
if err != nil {
r.bufferPool.Put(buffer)
r.logger.Warnf("buffering packet: %s", err)
continue
}
err = r.routePacket(sourceAddrPort, buffer)
if err != nil {
r.logger.Warnf("failed routing UDP packet: %s", err)
}
}
}
func (r *udpRouter) routePacket(sourceAddrPort netip.AddrPort, packet *bytes.Buffer) error {
r.mutex.Lock()
association, packetFromClient := r.findClientAssociation(sourceAddrPort)
r.mutex.Unlock()
if !packetFromClient {
r.bufferPool.Put(packet)
return nil
}
select {
case association.packetCh <- packet:
return nil
default:
r.bufferPool.Put(packet)
return errors.New("association packet queue full")
}
}
func (r *udpRouter) findClientAssociation(sourceAddrPort netip.AddrPort) (
association udpAssociation, ok bool,
) {
association, ok = r.clientAddrPortToAssociation[sourceAddrPort]
if ok {
return association, true
}
sourceAddr := sourceAddrPort.Addr()
pendingAssociations := r.clientIPToPendingAssociations[sourceAddr]
if len(pendingAssociations) == 0 {
return udpAssociation{}, false
}
index := -1
for i, pendingAssociation := range pendingAssociations {
if matchesExpectedClientEndpoint(pendingAssociation, sourceAddrPort) {
association = pendingAssociation
index = i
break
}
}
if index == -1 {
return udpAssociation{}, false
}
r.clientIPToPendingAssociations[sourceAddr] = append(pendingAssociations[:index], pendingAssociations[index+1:]...)
if len(r.clientIPToPendingAssociations[sourceAddr]) == 0 {
delete(r.clientIPToPendingAssociations, sourceAddr)
}
association.clientAddrPort = sourceAddrPort
r.clientAddrPortToAssociation[sourceAddrPort] = association
r.associationIDToClientAddrPort[association.id] = sourceAddrPort
return association, true
}
func matchesExpectedClientEndpoint(association udpAssociation, sourceAddrPort netip.AddrPort) bool {
switch {
case association.expectedAddrPort.Addr().IsValid() && sourceAddrPort.Addr() != association.expectedAddrPort.Addr():
return false
case association.expectedAddrPort.Port() != 0 && sourceAddrPort.Port() != association.expectedAddrPort.Port():
return false
}
return true
}
func (r *udpRouter) clientAddrPortForAssociation(associationID uint64) (
clientAddrPort netip.AddrPort, ok bool,
) {
r.mutex.Lock()
defer r.mutex.Unlock()
clientAddrPort, ok = r.associationIDToClientAddrPort[associationID]
return clientAddrPort, ok
}
func (r *udpRouter) runAssociationHandler(ctx context.Context, association udpAssociation) {
config := &net.ListenConfig{}
socket, err := config.ListenPacket(ctx, "udp", ":0")
if err != nil {
r.logger.Warnf("creating per-association UDP socket: %s", err)
return
}
defer socket.Close()
go closeSocketOnContextDone(ctx, socket)
packetBuffer := make([]byte, maxUDPPacketLength)
forwardDoneCh := make(chan struct{})
go r.forwardClientPackets(ctx, socket, association.packetCh, forwardDoneCh)
for {
packetLength, sourceAddress, err := socket.ReadFrom(packetBuffer)
if err != nil {
if ctx.Err() != nil || errors.Is(err, net.ErrClosed) {
<-forwardDoneCh
return
}
r.logger.Warnf("reading from per-association UDP socket: %s", err)
continue
}
sourceAddrPort, err := netAddrToNetipAddrPort(sourceAddress)
if err != nil {
r.logger.Warnf("parsing source address from destination: %s", err)
continue
}
buffer := r.bufferPool.Get().(*bytes.Buffer) //nolint:forcetypeassert
buffer.Reset()
err = encodeUDPDatagramToBuffer(buffer, sourceAddrPort, packetBuffer[:packetLength])
if err != nil {
r.bufferPool.Put(buffer)
r.logger.Warnf("encoding response datagram: %s", err)
continue
}
clientAddrPort, found := r.clientAddrPortForAssociation(association.id)
if !found {
r.bufferPool.Put(buffer)
r.logger.Warnf("client address not found for association id %d", association.id)
continue
}
clientUDPAddress := &net.UDPAddr{
IP: clientAddrPort.Addr().AsSlice(),
Port: int(clientAddrPort.Port()),
}
_, err = r.listener.WriteTo(buffer.Bytes(), clientUDPAddress)
r.bufferPool.Put(buffer)
if err != nil {
r.logger.Warnf("writing response to client: %s", err)
}
}
}
func closeSocketOnContextDone(ctx context.Context, socket net.PacketConn) {
<-ctx.Done()
_ = socket.Close()
}
func (r *udpRouter) forwardClientPackets(ctx context.Context, socket net.PacketConn,
packetCh <-chan *bytes.Buffer, done chan<- struct{},
) {
defer close(done)
for {
select {
case <-ctx.Done():
return
case buffer, ok := <-packetCh:
if !ok {
return
}
err := r.writeClientPacketToDestination(ctx, socket, buffer)
r.bufferPool.Put(buffer)
if err != nil {
r.logger.Warnf("forwarding client packet to destination: %s", err)
}
}
}
}
func (r *udpRouter) writeClientPacketToDestination(ctx context.Context,
socket net.PacketConn, packet *bytes.Buffer,
) error {
destination, payload, err := decodeUDPDatagram(packet.Bytes())
if err != nil {
return fmt.Errorf("decoding UDP datagram: %w", err)
}
host, portStr, err := net.SplitHostPort(destination)
if err != nil {
return fmt.Errorf("splitting destination host and port: %w", err)
}
if _, err := netip.ParseAddr(host); err != nil { // domain name
addrs, err := net.DefaultResolver.LookupHost(ctx, host)
if err != nil {
return fmt.Errorf("resolving destination host: %w", err)
}
if len(addrs) == 0 {
return fmt.Errorf("resolving destination host: no addresses found for %q", host)
}
destination = net.JoinHostPort(addrs[0], portStr)
}
resolvedDestinationUDPAddress, err := net.ResolveUDPAddr("udp", destination)
if err != nil {
return fmt.Errorf("resolving destination UDP address: %w", err)
}
_, err = socket.WriteTo(payload, resolvedDestinationUDPAddress)
if err != nil && ctx.Err() == nil {
return fmt.Errorf("writing payload to destination: %w", err)
}
return nil
}
func netAddrToNetipAddrPort(addr net.Addr) (netip.AddrPort, error) {
addrPort, err := netip.ParseAddrPort(addr.String())
if err != nil {
return netip.AddrPort{}, fmt.Errorf("parsing address: %w", err)
}
return netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port()), nil
}
@@ -1,164 +0,0 @@
//go:build integration
package socks5
import (
"bytes"
"context"
"math/rand/v2"
"net"
"net/netip"
"strconv"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_udpRouter_ResolveGithubFromCloudflareDNS(t *testing.T) {
t.Parallel()
ctx := t.Context()
var cancel context.CancelFunc
deadline, hasDeadline := ctx.Deadline()
if hasDeadline {
const deadlineBuffer = 500 * time.Millisecond
deadline = deadline.Add(-deadlineBuffer)
} else {
const defaultTimeout = 10 * time.Second
deadline = time.Now().Add(defaultTimeout)
}
ctx, cancel = context.WithDeadline(ctx, deadline)
ctrl := gomock.NewController(t)
logger := NewMockLogger(ctrl)
router, err := newUDPRouter(ctx, "127.0.0.1:0", logger)
require.NoError(t, err)
routerRunErrCh := make(chan error)
go func() {
routerRunErrCh <- router.run(ctx)
}()
t.Cleanup(func() {
cancel()
err := router.close()
assert.NoError(t, err, "closing router")
runErr := <-routerRunErrCh
assert.NoError(t, runErr)
})
controlListener, err := (&net.ListenConfig{}).Listen(ctx, "tcp", "127.0.0.1:0")
require.NoError(t, err)
t.Cleanup(func() {
err := controlListener.Close()
assert.NoError(t, err, "closing control listener")
})
acceptedConnCh := make(chan net.Conn)
go func() {
acceptedConn, acceptErr := controlListener.Accept()
assert.NoError(t, acceptErr, "accepting control connection")
if acceptErr != nil {
return
}
acceptedConnCh <- acceptedConn
}()
clientControlConn, err := (&net.Dialer{}).DialContext(ctx, "tcp", controlListener.Addr().String())
require.NoError(t, err)
t.Cleanup(func() {
err = clientControlConn.Close()
assert.NoError(t, err, "closing client control connection")
})
serverControlConn := <-acceptedConnCh
t.Cleanup(func() {
err := serverControlConn.Close()
assert.NoError(t, err, "closing server control connection")
})
association, err := router.registerAssociation(serverControlConn, netip.AddrPort{})
require.NoError(t, err)
t.Cleanup(func() {
router.unregisterAssociation(association)
})
associationCtx, associationCancel := context.WithCancel(ctx)
handlerDoneCh := make(chan struct{})
go func() {
router.runAssociationHandler(associationCtx, association)
close(handlerDoneCh)
}()
t.Cleanup(func() {
associationCancel()
<-handlerDoneCh
})
udpRouterAddress, err := net.ResolveUDPAddr("udp", router.localAddress().String())
require.NoError(t, err)
clientUDPConn, err := net.DialUDP("udp", nil, udpRouterAddress)
require.NoError(t, err)
t.Cleanup(func() {
err := clientUDPConn.Close()
assert.NoError(t, err, "closing client UDP connection")
})
queryID := uint16(rand.Uint32()) //nolint:gosec
dnsRequest := &dns.Msg{
MsgHdr: dns.MsgHdr{
Id: queryID,
RecursionDesired: true,
},
Question: []dns.Question{{
Name: dns.Fqdn("github.com"),
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}},
}
dnsQuery, err := dnsRequest.Pack()
require.NoError(t, err)
targetAddrPort := netip.MustParseAddrPort("1.1.1.1:53")
socksDatagramBuffer := bytes.NewBuffer(nil)
err = encodeUDPDatagramToBuffer(socksDatagramBuffer, targetAddrPort, dnsQuery)
require.NoError(t, err)
socksDatagram := socksDatagramBuffer.Bytes()
err = clientUDPConn.SetDeadline(deadline)
require.NoError(t, err)
_, err = clientUDPConn.Write(socksDatagram)
require.NoError(t, err)
responseBuffer := make([]byte, maxUDPPacketLength)
responseLength, err := clientUDPConn.Read(responseBuffer)
require.NoError(t, err)
responseDestination, responsePayload, err := decodeUDPDatagram(responseBuffer[:responseLength])
require.NoError(t, err)
responseHost, responsePortString, err := net.SplitHostPort(responseDestination)
require.NoError(t, err)
responsePort, err := strconv.ParseUint(responsePortString, 10, 16)
require.NoError(t, err)
assert.Equal(t, uint64(53), responsePort)
assert.NotEmpty(t, responseHost)
dnsResponse := new(dns.Msg)
err = dnsResponse.Unpack(responsePayload)
require.NoError(t, err)
assert.Equal(t, queryID, dnsResponse.Id)
assert.True(t, dnsResponse.Response)
assert.Equal(t, dns.RcodeSuccess, dnsResponse.Rcode)
require.NotEmpty(t, dnsResponse.Question)
assert.Equal(t, dns.Fqdn("github.com"), dnsResponse.Question[0].Name)
assert.Equal(t, dns.TypeA, dnsResponse.Question[0].Qtype)
assert.NotEmpty(t, dnsResponse.Answer)
require.NoError(t, err)
}
+19 -11
View File
@@ -1,23 +1,32 @@
package storage
import (
"embed"
"encoding/json"
"fmt"
"path/filepath"
"path"
serversmodule "github.com/qdm12/gluetun-servers/pkg/servers"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
)
func parseHardcodedServers() (allServers models.AllServers) {
allProviders := providers.All()
//go:embed servers.json
var allServersEmbedFS embed.FS
const version = 1
allServers.ProviderToServers = make(map[string]models.Servers, len(allProviders))
allServers.Version = version
for _, provider := range allProviders {
filename := provider + ".json"
func parseHardcodedServers() (allServers models.AllServers) {
f, err := allServersEmbedFS.Open("servers.json")
if err != nil {
panic(err)
}
defer f.Close() // no-op
decoder := json.NewDecoder(f)
err = decoder.Decode(&allServers)
if err != nil {
panic("decoding servers.json: " + err.Error())
}
for provider, metadata := range allServers.ProviderToServers {
filename := path.Base(metadata.Filepath)
providerFile, err := serversmodule.Files.Open(filename)
if err != nil {
panic(fmt.Sprintf("reading embedded provider file %s for %s: %s", filename, provider, err))
@@ -35,8 +44,7 @@ func parseHardcodedServers() (allServers models.AllServers) {
filename, provider))
}
const serversPath = "/gluetun/servers/"
providerServers.Filepath = filepath.Join(serversPath, filename)
providerServers.Filepath = metadata.Filepath // inherit filepath from servers.json
allServers.ProviderToServers[provider] = providerServers
}
+2 -1
View File
@@ -3,5 +3,6 @@ package storage
import "fmt"
func panicOnProviderMissingHardcoded(provider string) {
panic(fmt.Sprintf("provider %s not found in hardcoded servers map", provider))
panic(fmt.Sprintf("provider %s not found in hardcoded servers map; "+
"did you add the provider key in the embedded servers.json?", provider))
}
+2 -1
View File
@@ -152,7 +152,8 @@ func Test_extractServersFromBytes(t *testing.T) {
allProviders[0]: 1,
// Missing provider allProviders[1]
}
expectedPanicValue := fmt.Sprintf("provider %s not found in hardcoded servers map", allProviders[1])
expectedPanicValue := fmt.Sprintf("provider %s not found in hardcoded servers map; "+
"did you add the provider key in the embedded servers.json?", allProviders[1])
assert.PanicsWithValue(t, expectedPanicValue, func() {
_, _ = s.extractServersFromBytes(b, hardcodedVersions)
})
+72
View File
@@ -0,0 +1,72 @@
{
"version": 1,
"airvpn": {
"filepath": "/gluetun/servers/airvpn.json"
},
"cyberghost": {
"filepath": "/gluetun/servers/cyberghost.json"
},
"expressvpn": {
"filepath": "/gluetun/servers/expressvpn.json"
},
"fastestvpn": {
"filepath": "/gluetun/servers/fastestvpn.json"
},
"giganews": {
"filepath": "/gluetun/servers/giganews.json"
},
"hidemyass": {
"filepath": "/gluetun/servers/hidemyass.json"
},
"ipvanish": {
"filepath": "/gluetun/servers/ipvanish.json"
},
"ivpn": {
"filepath": "/gluetun/servers/ivpn.json"
},
"mullvad": {
"filepath": "/gluetun/servers/mullvad.json"
},
"nordvpn": {
"filepath": "/gluetun/servers/nordvpn.json"
},
"perfect privacy": {
"filepath": "/gluetun/servers/perfect privacy.json"
},
"privado": {
"filepath": "/gluetun/servers/privado.json"
},
"private internet access": {
"filepath": "/gluetun/servers/private internet access.json"
},
"privatevpn": {
"filepath": "/gluetun/servers/privatevpn.json"
},
"protonvpn": {
"filepath": "/gluetun/servers/protonvpn.json"
},
"purevpn": {
"filepath": "/gluetun/servers/purevpn.json"
},
"slickvpn": {
"filepath": "/gluetun/servers/slickvpn.json"
},
"surfshark": {
"filepath": "/gluetun/servers/surfshark.json"
},
"torguard": {
"filepath": "/gluetun/servers/torguard.json"
},
"vpn unlimited": {
"filepath": "/gluetun/servers/vpn unlimited.json"
},
"vpnsecure": {
"filepath": "/gluetun/servers/vpnsecure.json"
},
"vyprvpn": {
"filepath": "/gluetun/servers/vyprvpn.json"
},
"windscribe": {
"filepath": "/gluetun/servers/windscribe.json"
}
}