mirror of
https://github.com/qdm12/gluetun.git
synced 2026-06-27 22:37:33 +02:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b589b28b8e | |||
| 52a41cb891 | |||
| 6c76273ef6 | |||
| 366062dc12 |
@@ -67,10 +67,6 @@ jobs:
|
||||
-v "$(pwd)/coverage.txt:/tmp/gobuild/coverage.txt" \
|
||||
test-container
|
||||
|
||||
- name: Run integration tests in test container
|
||||
run: |
|
||||
docker run --rm --entrypoint go test-container test -tags=integration ./internal/restrictednet
|
||||
|
||||
- name: Verify dev cross platform compatibility
|
||||
run: docker build --target xcompile .
|
||||
|
||||
|
||||
@@ -50,7 +50,6 @@ Guidance for coding agents working in this repository.
|
||||
- Prefer splitting a code line only when it triggers the `lll` linter, do not split a command or arguments list for each element
|
||||
- Use `netip` types instead of `net` types whenever possible
|
||||
- Use constants instead of variables whenever possible, especially function-local inline constants.
|
||||
- Prefer using pure functions over methods when possible. Especially if the method does not need any fields from the receiving struct, it should be a pure function.
|
||||
- Do not use `time.Sleep`, prefer using a `time.Timer` with a `select` statement also listening on a context cancelation
|
||||
- `panic`:
|
||||
- should only be used when a programming error is encountered and you should NOT return errors for programming errors (such as passing nil objects)
|
||||
@@ -116,7 +115,6 @@ Mocking works with the `go.uber.org/mock` library, and the `mockgen` tool.
|
||||
- **Never** use `.AnyTimes()` on mocks. Always define the number of times a certain mock call should be called, with `.Times(3)` for example.
|
||||
- **Always** set the `.Return(...)` on the mock if the function returns something.
|
||||
- Avoid using **mock helpers** functions, prefer a bit of repetition than tight coupling and dependency
|
||||
- Always define the gomock controller `ctrl` in the subtest and not in the parent test, or a subtest mock failing will crash all the other subtests.
|
||||
|
||||
### main.go
|
||||
|
||||
@@ -129,7 +127,6 @@ The Go formatter used is gofumpt.
|
||||
### Errors
|
||||
|
||||
- Always prefer wrapping errors with some context with `fmt.Errorf("doing this: %w", err)`
|
||||
- Use `errors.New("error message")` when creating a 'bottom' constant string error without additional context, instead of `fmt.Errorf`
|
||||
- In rare cases, you can just use `return err` notably:
|
||||
- If the function is called **recursively**, since we don't wrap the wrapping multiple times for each recursion
|
||||
- If the current function only statement is the call to another function, for example:
|
||||
@@ -182,8 +179,6 @@ The Go formatter used is gofumpt.
|
||||
|
||||
- Do not use `http.DefaultClient`, use a custom `*http.Client` with a fixed timeout and share with dependency injections.
|
||||
- Do not check for injected dependencies being `nil`, prefer to just panic on a nil pointer. By default it's fine to panic if a developer injects a dependency `nil`. `nil` does not mean use a default.
|
||||
- Prefer using a `switch { case ...}` statement over multiple consecutive `if` statements to have shorter code.
|
||||
- Prefer using `[...]T` instead of `[]T` when the length is fixed and known at compile time, to avoid unnecessary allocations.
|
||||
|
||||
## Validation checklist
|
||||
|
||||
|
||||
@@ -5,15 +5,16 @@ go 1.25.0
|
||||
require (
|
||||
github.com/ProtonMail/go-srp v0.0.7
|
||||
github.com/amnezia-vpn/amneziawg-go v0.2.16
|
||||
github.com/breml/rootcerts v0.3.4
|
||||
github.com/fatih/color v1.18.0
|
||||
github.com/breml/rootcerts v0.3.5
|
||||
github.com/fatih/color v1.19.0
|
||||
github.com/golang/mock v1.6.0
|
||||
github.com/jsimonetti/rtnetlink v1.4.2
|
||||
github.com/klauspost/compress v1.18.4
|
||||
github.com/klauspost/compress v1.18.6
|
||||
github.com/klauspost/pgzip v1.2.6
|
||||
github.com/mdlayher/genetlink v1.3.2
|
||||
github.com/mdlayher/netlink v1.9.0
|
||||
github.com/pelletier/go-toml/v2 v2.2.4
|
||||
github.com/miekg/dns v1.1.62
|
||||
github.com/pelletier/go-toml/v2 v2.4.0
|
||||
github.com/qdm12/dns/v2 v2.0.0-rc9.0.20260421173011-9de8e7fdbe3a
|
||||
github.com/qdm12/gluetun-servers v0.1.0
|
||||
github.com/qdm12/goservices v0.1.1-0.20251104135713-6bee97bd4978
|
||||
@@ -32,7 +33,7 @@ require (
|
||||
golang.org/x/text v0.37.0
|
||||
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
|
||||
gopkg.in/ini.v1 v1.67.1
|
||||
gopkg.in/ini.v1 v1.67.3
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -44,10 +45,9 @@ require (
|
||||
github.com/cronokirby/saferith v0.33.0 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/google/go-cmp v0.7.0 // indirect
|
||||
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mdlayher/socket v0.5.1 // indirect
|
||||
github.com/miekg/dns v1.1.62 // indirect
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
|
||||
@@ -10,8 +10,8 @@ github.com/amnezia-vpn/amneziawg-go v0.2.16 h1:XY6HOq/xtqH8ZXMncRWkjFs85EKdN10NL
|
||||
github.com/amnezia-vpn/amneziawg-go v0.2.16/go.mod h1:nRkPpIzjCxMW8pZKXTRkpqAQVlmFJdVOGkeQSC7wbms=
|
||||
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
||||
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
||||
github.com/breml/rootcerts v0.3.4 h1:9i7WNl/ctd9OEAOaTfLy//Wrlfxq/tRQ7v4okYFN9Ys=
|
||||
github.com/breml/rootcerts v0.3.4/go.mod h1:S/PKh+4d1HUn4HQovEB8hPJZO6pUZYrIhmXBhsegfXw=
|
||||
github.com/breml/rootcerts v0.3.5 h1:oi7YiZ25HH52+mrKyjrMkcAFfnRDUf6HO8aUDr7RlJI=
|
||||
github.com/breml/rootcerts v0.3.5/go.mod h1:S/PKh+4d1HUn4HQovEB8hPJZO6pUZYrIhmXBhsegfXw=
|
||||
github.com/bwesterb/go-ristretto v1.2.0/go.mod h1:fUIoIZaG73pV5biE2Blr2xEzDoMj7NFEuV9ekS419A0=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
@@ -25,8 +25,8 @@ github.com/cronokirby/saferith v0.33.0/go.mod h1:QKJhjoqUtBsXCAVEjw38mFqoi7DebT7
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
|
||||
github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU=
|
||||
github.com/fatih/color v1.19.0 h1:Zp3PiM21/9Ld6FzSKyL5c/BULoe/ONr9KlbYVOfG8+w=
|
||||
github.com/fatih/color v1.19.0/go.mod h1:zNk67I0ZUT1bEGsSGyCZYZNrHuTkJJB+r6Q9VuMi0LE=
|
||||
github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
|
||||
github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
|
||||
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
|
||||
@@ -35,17 +35,16 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/jsimonetti/rtnetlink v1.4.2 h1:Df9w9TZ3npHTyDn0Ev9e1uzmN2odmXd0QX+J5GTEn90=
|
||||
github.com/jsimonetti/rtnetlink v1.4.2/go.mod h1:92s6LJdE+1iOrw+F2/RO7LYI2Qd8pPpFNNUYW06gcoM=
|
||||
github.com/klauspost/compress v1.18.4 h1:RPhnKRAQ4Fh8zU2FY/6ZFDwTVTxgJ/EMydqSTzE9a2c=
|
||||
github.com/klauspost/compress v1.18.4/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
|
||||
github.com/klauspost/compress v1.18.6 h1:2jupLlAwFm95+YDR+NwD2MEfFO9d4z4Prjl1XXDjuao=
|
||||
github.com/klauspost/compress v1.18.6/go.mod h1:cwPg85FWrGar70rWktvGQj8/hthj3wpl0PGDogxkrSQ=
|
||||
github.com/klauspost/pgzip v1.2.6 h1:8RXeL5crjEUFnR2/Sn6GJNWtSQ3Dk8pq4CL3jvdDyjU=
|
||||
github.com/klauspost/pgzip v1.2.6/go.mod h1:Ch1tH69qFZu15pkjo5kYi6mth2Zzwzt50oCQKQE9RUs=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
|
||||
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
|
||||
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw=
|
||||
@@ -60,8 +59,8 @@ github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE9
|
||||
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc=
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
|
||||
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||
github.com/pelletier/go-toml/v2 v2.4.0 h1:Mwu0mAkUKbittDs3/ADDWXqMmq3EOK2VHiuCkV00Row=
|
||||
github.com/pelletier/go-toml/v2 v2.4.0/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
@@ -153,7 +152,6 @@ golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.45.0 h1:dO4czNzziLiiXplLQgBCEpCvXQ3dnkn0SdaZSYdQ+FY=
|
||||
@@ -192,8 +190,8 @@ google.golang.org/protobuf v1.35.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojt
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/ini.v1 v1.67.1 h1:tVBILHy0R6e4wkYOn3XmiITt/hEVH4TFMYvAX2Ytz6k=
|
||||
gopkg.in/ini.v1 v1.67.1/go.mod h1:x/cyOwCgZqOkJoDIJ3c1KNHMo10+nLGAhh+kn3Zizss=
|
||||
gopkg.in/ini.v1 v1.67.3 h1:iM9Lhz5MRSGhHVGGwCuzG9KO8PoirCXj/m/qTmOJJQw=
|
||||
gopkg.in/ini.v1 v1.67.3/go.mod h1:x/cyOwCgZqOkJoDIJ3c1KNHMo10+nLGAhh+kn3Zizss=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
@@ -36,6 +36,7 @@ func streamLines(done chan<- struct{}, logger Logger,
|
||||
case line, ok := <-stdout:
|
||||
if ok {
|
||||
logger.Info(line)
|
||||
break
|
||||
}
|
||||
if stderr == nil {
|
||||
return
|
||||
@@ -44,6 +45,7 @@ func streamLines(done chan<- struct{}, logger Logger,
|
||||
case line, ok := <-stderr:
|
||||
if ok {
|
||||
logger.Error(line)
|
||||
break
|
||||
}
|
||||
if stdout == nil {
|
||||
return
|
||||
|
||||
@@ -28,8 +28,6 @@ type firewallImpl interface { //nolint:interfacebloat
|
||||
AcceptIpv6MulticastOutput(ctx context.Context, intf string) error
|
||||
AcceptOutput(ctx context.Context, protocol, intf string,
|
||||
ip netip.Addr, port uint16, remove bool) error
|
||||
AcceptOutputFromIPPortToIPPort(ctx context.Context, protocol, intf string,
|
||||
source, destination netip.AddrPort, remove bool) error
|
||||
AcceptOutputFromIPToSubnet(ctx context.Context, intf string, assignedIP netip.Addr,
|
||||
subnet netip.Prefix, remove bool) error
|
||||
AcceptOutputThroughInterface(ctx context.Context, intf string, remove bool) error
|
||||
|
||||
@@ -2,7 +2,6 @@ package iptables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
@@ -178,29 +177,6 @@ func (c *Config) AcceptOutput(ctx context.Context,
|
||||
return c.runIP6tablesInstruction(ctx, instruction)
|
||||
}
|
||||
|
||||
func (c *Config) AcceptOutputFromIPPortToIPPort(ctx context.Context,
|
||||
protocol, intf string, source, destination netip.AddrPort, remove bool,
|
||||
) error {
|
||||
if source.Addr().BitLen() != destination.Addr().BitLen() {
|
||||
return errors.New("source and destination address families do not match")
|
||||
}
|
||||
|
||||
interfaceFlag := "-o " + intf
|
||||
if intf == "*" { // all interfaces
|
||||
interfaceFlag = ""
|
||||
}
|
||||
|
||||
instruction := fmt.Sprintf("%s OUTPUT %s -s %s -d %s -p %s -m %s --sport %d --dport %d -j ACCEPT",
|
||||
appendOrDelete(remove), interfaceFlag, source.Addr(), destination.Addr(),
|
||||
protocol, protocol, source.Port(), destination.Port())
|
||||
if destination.Addr().Is4() {
|
||||
return c.runIptablesInstruction(ctx, instruction)
|
||||
} else if c.ip6Tables == "" {
|
||||
return fmt.Errorf("accept output from %s to %s: %s", source, destination, needIP6Tables)
|
||||
}
|
||||
return c.runIP6tablesInstruction(ctx, instruction)
|
||||
}
|
||||
|
||||
// AcceptOutputFromIPToSubnet accepts outgoing traffic from sourceIP to destinationSubnet
|
||||
// on the interface intf. If intf is empty, it is set to "*" which means all interfaces.
|
||||
// If remove is true, the rule is removed instead of added.
|
||||
|
||||
@@ -25,10 +25,3 @@ func (c *Config) AcceptOutput(ctx context.Context, protocol, intf string,
|
||||
) error {
|
||||
return c.impl.AcceptOutput(ctx, protocol, intf, ip, port, remove)
|
||||
}
|
||||
|
||||
func (c *Config) AcceptOutputFromIPPortToIPPort(ctx context.Context,
|
||||
protocol, intf string, source, destination netip.AddrPort, remove bool,
|
||||
) error {
|
||||
return c.impl.AcceptOutputFromIPPortToIPPort(ctx, protocol, intf,
|
||||
source, destination, remove)
|
||||
}
|
||||
|
||||
@@ -1,82 +0,0 @@
|
||||
package restrictednet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
|
||||
"github.com/qdm12/dns/v2/pkg/provider"
|
||||
)
|
||||
|
||||
// Client is a client for making restricted network requests,
|
||||
// such as opening temporary firewall rules for HTTPS connections.
|
||||
// It is not meant to be high performance, although it can be used for
|
||||
// multiple requests and concurrently.
|
||||
type Client struct {
|
||||
outboundInterface string
|
||||
ipv6Supported bool
|
||||
firewall Firewall
|
||||
dohServers []provider.DoHServer
|
||||
}
|
||||
|
||||
func New(settings Settings) *Client {
|
||||
if err := settings.validate(); err != nil {
|
||||
panic(fmt.Sprintf("invalid settings: %v", err)) // programming error
|
||||
}
|
||||
dohServers := make([]provider.DoHServer, len(settings.UpstreamResolvers))
|
||||
for i, upstreamResolver := range settings.UpstreamResolvers {
|
||||
dohServers[i] = upstreamResolver.DoH
|
||||
}
|
||||
|
||||
return &Client{
|
||||
outboundInterface: settings.DefaultInterface,
|
||||
ipv6Supported: *settings.IPv6Supported,
|
||||
firewall: settings.Firewall,
|
||||
dohServers: dohServers,
|
||||
}
|
||||
}
|
||||
|
||||
// OpenHTTPSByHostname opens an https connection through the firewall,
|
||||
// to the hostname which in the format `host:port`. The returned cleanup
|
||||
// function must be called to remove the temporary firewall rule and close connections.
|
||||
// It first resolves the domain in hostname using DNS over HTTPS and then opens
|
||||
// the restricted HTTPS connection to the resolved IP.
|
||||
func (c *Client) OpenHTTPSByHostname(ctx context.Context, hostname string) (
|
||||
httpClient *http.Client, cleanup func() error, err error,
|
||||
) {
|
||||
host, portStr, err := net.SplitHostPort(hostname)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("splitting host and port: %w", err)
|
||||
}
|
||||
resolvedIPs, err := c.ResolveName(ctx, host)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("resolving name: %w", err)
|
||||
} else if len(resolvedIPs) == 0 {
|
||||
return nil, nil, fmt.Errorf("no IP address found for name %q", host)
|
||||
}
|
||||
|
||||
portUint, err := strconv.ParseUint(portStr, 10, 16)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("parsing port: %w", err)
|
||||
} else if portUint == 0 {
|
||||
return nil, nil, errors.New("destination port cannot be 0")
|
||||
}
|
||||
port := uint16(portUint)
|
||||
|
||||
errs := make([]error, 0, len(resolvedIPs))
|
||||
for _, ip := range resolvedIPs {
|
||||
addrPort := netip.AddrPortFrom(ip, port)
|
||||
httpClient, cleanup, err := c.OpenHTTPS(ctx, host, addrPort)
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("for %s: %w", ip, err))
|
||||
continue
|
||||
}
|
||||
return httpClient, cleanup, nil
|
||||
}
|
||||
|
||||
return nil, nil, fmt.Errorf("opening HTTPS to %s: %w", hostname, errors.Join(errs...))
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
//go:build integration
|
||||
|
||||
package restrictednet
|
||||
|
||||
func ptrTo[T any](value T) *T {
|
||||
return &value
|
||||
}
|
||||
@@ -1,202 +0,0 @@
|
||||
package restrictednet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/jsimonetti/rtnetlink"
|
||||
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||
)
|
||||
|
||||
// OpenHTTPS opens temporary restrictive firewall output for one HTTPS destination.
|
||||
// The returned [*http.Client] must be used sequentially only, and each request must
|
||||
// have its response body fully read/discarded and then closed.
|
||||
// The returned cleanup function must be called to remove the temporary firewall rule and close connections.
|
||||
func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, destinationAddrPort netip.AddrPort,
|
||||
) (httpClient *http.Client, cleanup func() error, err error) {
|
||||
fd, sourceAddrPort, err := bindSourceConnection(destinationAddrPort.Addr())
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("binding source port: %w", err)
|
||||
}
|
||||
|
||||
const remove = false
|
||||
err = c.firewall.AcceptOutputFromIPPortToIPPort(ctx, "tcp", c.outboundInterface,
|
||||
sourceAddrPort, destinationAddrPort, remove)
|
||||
if err != nil {
|
||||
closeFD(fd)
|
||||
return nil, nil, fmt.Errorf("allowing output traffic through firewall: %w", err)
|
||||
}
|
||||
|
||||
connection, err := connectSourceConnection(ctx, fd, destinationAddrPort)
|
||||
if err != nil {
|
||||
const remove = true
|
||||
_ = c.firewall.AcceptOutputFromIPPortToIPPort(context.Background(), "tcp", c.outboundInterface,
|
||||
sourceAddrPort, destinationAddrPort, remove)
|
||||
return nil, nil, fmt.Errorf("connecting source socket: %w", err)
|
||||
}
|
||||
|
||||
dial := makeDial(connection, destinationTLSName)
|
||||
httpClient = newHTTPSClient(destinationTLSName, dial)
|
||||
cleanup = func() error {
|
||||
var errs []error
|
||||
httpClient.CloseIdleConnections()
|
||||
err := connection.Close()
|
||||
if err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
errs = append(errs, fmt.Errorf("closing connection: %w", err))
|
||||
}
|
||||
const remove = true
|
||||
err = c.firewall.AcceptOutputFromIPPortToIPPort(context.Background(), "tcp", c.outboundInterface,
|
||||
sourceAddrPort, destinationAddrPort, remove)
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("removing output traffic rule: %w", err))
|
||||
}
|
||||
if len(errs) > 0 {
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return httpClient, cleanup, nil
|
||||
}
|
||||
|
||||
type dialFunc func(ctx context.Context, network, address string) (net.Conn, error)
|
||||
|
||||
func newHTTPSClient(destinationTLSName string, dial dialFunc) *http.Client {
|
||||
const timeout = 5 * time.Second
|
||||
transport := &http.Transport{
|
||||
MaxIdleConns: 1,
|
||||
MaxIdleConnsPerHost: 1,
|
||||
MaxConnsPerHost: 1,
|
||||
TLSClientConfig: &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
ServerName: destinationTLSName,
|
||||
},
|
||||
DialContext: dial,
|
||||
}
|
||||
return &http.Client{
|
||||
Timeout: timeout,
|
||||
Transport: transport,
|
||||
}
|
||||
}
|
||||
|
||||
func makeDial(connection net.Conn, tlsName string) dialFunc {
|
||||
_, destinationPort, err := net.SplitHostPort(connection.RemoteAddr().String())
|
||||
if err != nil {
|
||||
panic(err) // connection remote address should always be in the form "host:port"
|
||||
}
|
||||
expectedAddress := net.JoinHostPort(tlsName, destinationPort)
|
||||
used := false
|
||||
return func(_ context.Context, network, address string) (net.Conn, error) {
|
||||
if used {
|
||||
return nil, errors.New("dial function called more than once")
|
||||
}
|
||||
used = true
|
||||
switch network {
|
||||
case "tcp", "tcp4", "tcp6":
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected dial network %q", network)
|
||||
}
|
||||
if address != expectedAddress {
|
||||
return nil, fmt.Errorf("unexpected dial address %q (expected %q)", address, expectedAddress)
|
||||
}
|
||||
return connection, nil
|
||||
}
|
||||
}
|
||||
|
||||
func bindSourceConnection(destinationIP netip.Addr) (fd int, sourceAddr netip.AddrPort, err error) {
|
||||
sourceIP, err := sourceIPForDestination(destinationIP)
|
||||
if err != nil {
|
||||
return 0, netip.AddrPort{}, fmt.Errorf("finding source IP: %w", err)
|
||||
}
|
||||
|
||||
family := constants.AF_INET
|
||||
if sourceIP.Is6() {
|
||||
family = constants.AF_INET6
|
||||
}
|
||||
|
||||
fd, err = newTCPSockStream(family)
|
||||
if err != nil {
|
||||
return 0, netip.AddrPort{}, fmt.Errorf("creating socket: %w", err)
|
||||
}
|
||||
|
||||
bindAddrPort := netip.AddrPortFrom(sourceIP, 0)
|
||||
err = bindFD(fd, bindAddrPort)
|
||||
if err != nil {
|
||||
closeFD(fd)
|
||||
return 0, netip.AddrPort{}, fmt.Errorf("binding socket: %w", err)
|
||||
}
|
||||
|
||||
sourceAddr, err = fdToSourceAddr(fd)
|
||||
if err != nil {
|
||||
closeFD(fd)
|
||||
return 0, netip.AddrPort{}, fmt.Errorf("getting source address: %w", err)
|
||||
}
|
||||
|
||||
return fd, sourceAddr, nil
|
||||
}
|
||||
|
||||
func connectSourceConnection(ctx context.Context, fd int, destinationAddrPort netip.AddrPort) (
|
||||
connection net.Conn, err error,
|
||||
) {
|
||||
err = connectFD(ctx, fd, destinationAddrPort)
|
||||
if err != nil {
|
||||
closeFD(fd)
|
||||
return nil, fmt.Errorf("connecting socket: %w", err)
|
||||
}
|
||||
|
||||
file := os.NewFile(uintptr(fd), "")
|
||||
if file == nil {
|
||||
closeFD(fd)
|
||||
return nil, fmt.Errorf("creating socket file")
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
connection, err = net.FileConn(file)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("wrapping socket connection: %w", err)
|
||||
}
|
||||
|
||||
return connection, nil
|
||||
}
|
||||
|
||||
func sourceIPForDestination(destinationIP netip.Addr) (srcIP netip.Addr, err error) {
|
||||
conn, err := rtnetlink.Dial(nil)
|
||||
if err != nil {
|
||||
return netip.Addr{}, err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
family := uint8(constants.AF_INET)
|
||||
if destinationIP.Is6() {
|
||||
family = constants.AF_INET6
|
||||
}
|
||||
|
||||
requestMessage := &rtnetlink.RouteMessage{
|
||||
Family: family,
|
||||
Attributes: rtnetlink.RouteAttributes{
|
||||
Dst: destinationIP.AsSlice(),
|
||||
},
|
||||
}
|
||||
messages, err := conn.Route.Get(requestMessage)
|
||||
if err != nil {
|
||||
return netip.Addr{}, fmt.Errorf("getting routes to %s: %w", destinationIP, err)
|
||||
}
|
||||
|
||||
for _, message := range messages {
|
||||
if message.Attributes.Src == nil {
|
||||
continue
|
||||
}
|
||||
if message.Attributes.Src.To4() == nil {
|
||||
return netip.AddrFrom16([16]byte(message.Attributes.Src)), nil
|
||||
}
|
||||
return netip.AddrFrom4([4]byte(message.Attributes.Src)), nil
|
||||
}
|
||||
|
||||
return netip.Addr{}, fmt.Errorf("no route to %s", destinationIP)
|
||||
}
|
||||
@@ -1,117 +0,0 @@
|
||||
//go:build integration
|
||||
|
||||
package restrictednet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/qdm12/dns/v2/pkg/provider"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type listenAddrPortMatcher struct {
|
||||
expected netip.AddrPort
|
||||
}
|
||||
|
||||
func (m listenAddrPortMatcher) Matches(x any) bool {
|
||||
ip, ok := x.(netip.AddrPort)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if m.expected.IsValid() {
|
||||
return ip == m.expected
|
||||
}
|
||||
return ip.IsValid() && ip.Addr().IsValid() && ip.Port() > 0
|
||||
}
|
||||
|
||||
func (m listenAddrPortMatcher) String() string {
|
||||
if m.expected.IsValid() {
|
||||
return "is the same as " + m.expected.String()
|
||||
}
|
||||
return "is a valid netip.AddrPort with a valid IP and non-zero port"
|
||||
}
|
||||
|
||||
type destinationAddrPortMatcher struct {
|
||||
expected netip.AddrPort
|
||||
}
|
||||
|
||||
func (m destinationAddrPortMatcher) Matches(x any) bool {
|
||||
ip, ok := x.(netip.AddrPort)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if m.expected.IsValid() {
|
||||
return ip == m.expected
|
||||
}
|
||||
return ip.IsValid() && ip.Port() == m.expected.Port()
|
||||
}
|
||||
|
||||
func (m destinationAddrPortMatcher) String() string {
|
||||
if m.expected.IsValid() {
|
||||
return "is the same as " + m.expected.String()
|
||||
}
|
||||
return "matches the port " + fmt.Sprint(m.expected.Port())
|
||||
}
|
||||
|
||||
func Test_Client_OpenHTTPS(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := t.Context()
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
const destinationTLSName = "one.one.one.one"
|
||||
destinationAddrPort := netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443)
|
||||
|
||||
firewall := NewMockFirewall(ctrl)
|
||||
sourceMatcher := listenAddrPortMatcher{}
|
||||
firewall.EXPECT().AcceptOutputFromIPPortToIPPort(
|
||||
ctx, "tcp", "eth0", sourceMatcher, destinationAddrPort, false,
|
||||
).DoAndReturn(func(_ context.Context,
|
||||
_, _ string, source, _ netip.AddrPort, _ bool,
|
||||
) error {
|
||||
sourceMatcher.expected = source
|
||||
return nil
|
||||
})
|
||||
firewall.EXPECT().AcceptOutputFromIPPortToIPPort(
|
||||
context.Background(), "tcp", "eth0", sourceMatcher, destinationAddrPort, true,
|
||||
).Return(nil)
|
||||
|
||||
const ipv6Supported = false
|
||||
upstreamResolvers := []provider.Provider{provider.Google()}
|
||||
settings := Settings{
|
||||
Firewall: firewall,
|
||||
DefaultInterface: "eth0",
|
||||
IPv6Supported: ptrTo(ipv6Supported),
|
||||
UpstreamResolvers: upstreamResolvers,
|
||||
}
|
||||
client := New(settings)
|
||||
|
||||
httpClient, cleanup, err := client.OpenHTTPS(ctx, destinationTLSName, destinationAddrPort)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, httpClient)
|
||||
require.NotNil(t, cleanup)
|
||||
|
||||
const requests = 2
|
||||
|
||||
for range requests {
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://"+destinationTLSName, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
response, err := httpClient.Do(request)
|
||||
require.NoError(t, err)
|
||||
_, err = io.Copy(io.Discard, response.Body)
|
||||
require.NoError(t, err)
|
||||
err = response.Body.Close()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, response.StatusCode)
|
||||
}
|
||||
|
||||
err = cleanup()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
package restrictednet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
type Firewall interface {
|
||||
AcceptOutputFromIPPortToIPPort(ctx context.Context,
|
||||
protocol, intf string, source, destination netip.AddrPort, remove bool,
|
||||
) error
|
||||
}
|
||||
@@ -1,3 +0,0 @@
|
||||
package restrictednet
|
||||
|
||||
//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . Firewall
|
||||
@@ -1,50 +0,0 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/qdm12/gluetun/internal/restrictednet (interfaces: Firewall)
|
||||
|
||||
// Package restrictednet is a generated GoMock package.
|
||||
package restrictednet
|
||||
|
||||
import (
|
||||
context "context"
|
||||
netip "net/netip"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// MockFirewall is a mock of Firewall interface.
|
||||
type MockFirewall struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockFirewallMockRecorder
|
||||
}
|
||||
|
||||
// MockFirewallMockRecorder is the mock recorder for MockFirewall.
|
||||
type MockFirewallMockRecorder struct {
|
||||
mock *MockFirewall
|
||||
}
|
||||
|
||||
// NewMockFirewall creates a new mock instance.
|
||||
func NewMockFirewall(ctrl *gomock.Controller) *MockFirewall {
|
||||
mock := &MockFirewall{ctrl: ctrl}
|
||||
mock.recorder = &MockFirewallMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockFirewall) EXPECT() *MockFirewallMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// AcceptOutputFromIPPortToIPPort mocks base method.
|
||||
func (m *MockFirewall) AcceptOutputFromIPPortToIPPort(arg0 context.Context, arg1, arg2 string, arg3, arg4 netip.AddrPort, arg5 bool) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AcceptOutputFromIPPortToIPPort", arg0, arg1, arg2, arg3, arg4, arg5)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// AcceptOutputFromIPPortToIPPort indicates an expected call of AcceptOutputFromIPPortToIPPort.
|
||||
func (mr *MockFirewallMockRecorder) AcceptOutputFromIPPortToIPPort(arg0, arg1, arg2, arg3, arg4, arg5 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptOutputFromIPPortToIPPort", reflect.TypeOf((*MockFirewall)(nil).AcceptOutputFromIPPortToIPPort), arg0, arg1, arg2, arg3, arg4, arg5)
|
||||
}
|
||||
@@ -1,205 +0,0 @@
|
||||
package restrictednet
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// ResolveName resolves the given host name to IP addresses using DoH servers,
|
||||
// while opening temporary restrictive firewall rules for HTTPS traffic to DoH servers.
|
||||
// The host must be a single well-formed domain name, without port or path.
|
||||
func (c *Client) ResolveName(ctx context.Context, host string) (
|
||||
resolvedAddresses []netip.Addr, err error,
|
||||
) {
|
||||
const maxTypes = 2
|
||||
questionTypes := make([]uint16, 0, maxTypes)
|
||||
if c.ipv6Supported {
|
||||
questionTypes = append(questionTypes, dns.TypeAAAA)
|
||||
}
|
||||
questionTypes = append(questionTypes, dns.TypeA)
|
||||
|
||||
var addresses []netip.Addr
|
||||
errs := make([]error, 0, len(questionTypes))
|
||||
for _, questionType := range questionTypes {
|
||||
answerAddresses, err := c.resolveOneQuestionType(ctx, host, questionType)
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
continue
|
||||
}
|
||||
addresses = append(addresses, answerAddresses...)
|
||||
}
|
||||
|
||||
switch {
|
||||
case len(addresses) > 0:
|
||||
return addresses, nil
|
||||
case len(errs) == 0:
|
||||
return nil, nil // no address found
|
||||
default: // errors
|
||||
return nil, fmt.Errorf("resolving host %q: %w", host, errors.Join(errs...))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) resolveOneQuestionType(ctx context.Context,
|
||||
host string, questionType uint16,
|
||||
) (addresses []netip.Addr, err error) {
|
||||
queryMessage := &dns.Msg{}
|
||||
queryMessage.SetQuestion(dns.Fqdn(host), questionType)
|
||||
queryWire, err := queryMessage.Pack()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("packing DNS query: %w", err)
|
||||
}
|
||||
|
||||
// Try every DoH server and every of each of their IP until we get a non-empty
|
||||
// successful response.
|
||||
errs := make([]error, 0)
|
||||
for _, dohServer := range c.dohServers {
|
||||
dohURL, err := url.Parse(dohServer.URL)
|
||||
if err != nil {
|
||||
errs = append(errs,
|
||||
fmt.Errorf("parsing DoH server URL %s: %w", dohServer.URL, err))
|
||||
continue
|
||||
}
|
||||
|
||||
dohServerIPs := make([]netip.Addr, 0, len(dohServer.IPv4)+len(dohServer.IPv6))
|
||||
if c.ipv6Supported {
|
||||
// Prefer IPv6 addresses if IPv6 is supported
|
||||
dohServerIPs = append(dohServerIPs, dohServer.IPv6...)
|
||||
}
|
||||
dohServerIPs = append(dohServerIPs, dohServer.IPv4...)
|
||||
|
||||
for _, dohServerIP := range dohServerIPs {
|
||||
const defaultDoHPort uint16 = 443
|
||||
port := defaultDoHPort
|
||||
if portStr := dohURL.Port(); portStr != "" {
|
||||
port, err = parseDestinationPort(portStr)
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("parsing DoH server port: %w", err))
|
||||
continue
|
||||
}
|
||||
}
|
||||
dohServerAddrPort := netip.AddrPortFrom(dohServerIP, port)
|
||||
responseMessage, err := c.doHQuery(ctx, queryWire, dohURL, dohServerAddrPort)
|
||||
switch {
|
||||
case err != nil:
|
||||
errs = append(errs, fmt.Errorf("querying DoH server %q (%s): %w",
|
||||
dohServer.URL, dohServerAddrPort, err))
|
||||
continue
|
||||
case responseMessage.Rcode != dns.RcodeSuccess:
|
||||
errs = append(errs, fmt.Errorf("querying DoH server %q (%s): DNS rcode %s",
|
||||
dohServer.URL, dohServerAddrPort, dns.RcodeToString[responseMessage.Rcode]))
|
||||
continue
|
||||
}
|
||||
addresses := answersToNetipAddrs(responseMessage)
|
||||
if len(addresses) == 0 {
|
||||
continue
|
||||
}
|
||||
return addresses, nil
|
||||
}
|
||||
}
|
||||
|
||||
if len(errs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("resolving %s %s: %w",
|
||||
dns.TypeToString[questionType], host, errors.Join(errs...))
|
||||
}
|
||||
|
||||
func (c *Client) doHQuery(ctx context.Context, queryWire []byte,
|
||||
dohURL *url.URL, dohServerAddrPort netip.AddrPort,
|
||||
) (responseMessage *dns.Msg, err error) {
|
||||
httpClient, cleanup, err := c.OpenHTTPS(ctx, dohURL.Hostname(), dohServerAddrPort)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("opening https connection: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
closeErr := cleanup()
|
||||
if err == nil && closeErr != nil {
|
||||
err = fmt.Errorf("cleaning up https connection: %w", closeErr)
|
||||
}
|
||||
}()
|
||||
|
||||
requestBody := bytes.NewReader(queryWire)
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodPost, dohURL.String(), requestBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating request: %w", err)
|
||||
}
|
||||
request.Header.Set("Content-Type", "application/dns-message")
|
||||
request.Header.Set("Accept", "application/dns-message")
|
||||
|
||||
response, err := httpClient.Do(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
responseData, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
_ = response.Body.Close()
|
||||
return nil, fmt.Errorf("reading response body: %w", err)
|
||||
}
|
||||
|
||||
err = response.Body.Close()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("closing response body: %w", err)
|
||||
}
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("response status code is %s (data length %d)",
|
||||
response.Status, len(responseData))
|
||||
}
|
||||
|
||||
responseMessage = new(dns.Msg)
|
||||
err = responseMessage.Unpack(responseData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing DoH response: %w", err)
|
||||
}
|
||||
|
||||
return responseMessage, nil
|
||||
}
|
||||
|
||||
func answersToNetipAddrs(message *dns.Msg) (addresses []netip.Addr) {
|
||||
if message == nil {
|
||||
return nil
|
||||
}
|
||||
addresses = make([]netip.Addr, 0, len(message.Answer))
|
||||
for _, answer := range message.Answer {
|
||||
switch record := answer.(type) {
|
||||
case *dns.A:
|
||||
address, ok := netip.AddrFromSlice(record.A)
|
||||
if ok {
|
||||
addresses = append(addresses, address.Unmap())
|
||||
}
|
||||
case *dns.AAAA:
|
||||
address, ok := netip.AddrFromSlice(record.AAAA)
|
||||
if ok {
|
||||
addresses = append(addresses, address)
|
||||
}
|
||||
}
|
||||
}
|
||||
return addresses
|
||||
}
|
||||
|
||||
func parseDestinationPort(portStr string) (port uint16, err error) {
|
||||
portUint, err := strconv.ParseUint(portStr, 10, 16)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
const maxPortUint = 65535
|
||||
switch {
|
||||
case portUint == 0:
|
||||
return 0, errors.New("port cannot be 0")
|
||||
case portUint > maxPortUint:
|
||||
return 0, fmt.Errorf("port cannot be greater than %d", maxPortUint)
|
||||
}
|
||||
return uint16(portUint), nil
|
||||
}
|
||||
@@ -1,110 +0,0 @@
|
||||
//go:build integration
|
||||
|
||||
package restrictednet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/qdm12/dns/v2/pkg/provider"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_Client_ResolveName(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := t.Context()
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
firewall := NewMockFirewall(ctrl)
|
||||
sourceMatcher := listenAddrPortMatcher{}
|
||||
destinationMatcher := destinationAddrPortMatcher{
|
||||
expected: netip.AddrPortFrom(netip.Addr{}, 443),
|
||||
}
|
||||
|
||||
// Add rule
|
||||
firstCall := firewall.EXPECT().AcceptOutputFromIPPortToIPPort(
|
||||
ctx, "tcp", "eth0", sourceMatcher, destinationMatcher, false,
|
||||
).DoAndReturn(func(
|
||||
_ context.Context, _, _ string, source, destination netip.AddrPort, _ bool,
|
||||
) error {
|
||||
sourceMatcher.expected = source
|
||||
destinationMatcher.expected = destination
|
||||
return nil
|
||||
})
|
||||
|
||||
// Removal rule
|
||||
firewall.EXPECT().AcceptOutputFromIPPortToIPPort(
|
||||
context.Background(), "tcp", "eth0", sourceMatcher, destinationMatcher, true,
|
||||
).Return(nil).After(firstCall)
|
||||
|
||||
settings := Settings{
|
||||
DefaultInterface: "eth0",
|
||||
IPv6Supported: ptrTo(false),
|
||||
Firewall: firewall,
|
||||
UpstreamResolvers: []provider.Provider{provider.Cloudflare()},
|
||||
}
|
||||
client := New(settings)
|
||||
|
||||
addresses, err := client.ResolveName(ctx, "github.com")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, addresses)
|
||||
}
|
||||
|
||||
func Test_answersToNetipAddrs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
message *dns.Msg
|
||||
expected []netip.Addr
|
||||
}{
|
||||
"nil_message": {},
|
||||
"no_answers": {
|
||||
message: &dns.Msg{},
|
||||
expected: []netip.Addr{},
|
||||
},
|
||||
"a_record": {
|
||||
message: &dns.Msg{Answer: []dns.RR{
|
||||
&dns.A{
|
||||
Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET},
|
||||
A: net.IP{1, 1, 1, 1},
|
||||
},
|
||||
}},
|
||||
expected: []netip.Addr{netip.MustParseAddr("1.1.1.1")},
|
||||
},
|
||||
"aaaa_record": {
|
||||
message: &dns.Msg{Answer: []dns.RR{
|
||||
&dns.AAAA{
|
||||
Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeAAAA, Class: dns.ClassINET},
|
||||
AAAA: net.IP{0x20, 0x01, 0x48, 0x60, 0x48, 0x60, 0, 0, 0, 0, 0, 0, 0, 0, 0x88, 0x88},
|
||||
},
|
||||
}},
|
||||
expected: []netip.Addr{netip.MustParseAddr("2001:4860:4860::8888")},
|
||||
},
|
||||
"mixed_records": {
|
||||
message: &dns.Msg{Answer: []dns.RR{
|
||||
&dns.A{
|
||||
Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET},
|
||||
A: net.IP{1, 1, 1, 1},
|
||||
},
|
||||
&dns.AAAA{
|
||||
Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeAAAA, Class: dns.ClassINET},
|
||||
AAAA: net.IP{0x20, 0x01, 0x48, 0x60, 0x48, 0x60, 0, 0, 0, 0, 0, 0, 0, 0, 0x88, 0x88},
|
||||
},
|
||||
}},
|
||||
expected: []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("2001:4860:4860::8888")},
|
||||
},
|
||||
}
|
||||
|
||||
for testName, testCase := range testCases {
|
||||
t.Run(testName, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
addresses := answersToNetipAddrs(testCase.message)
|
||||
assert.Equal(t, testCase.expected, addresses)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,28 +0,0 @@
|
||||
package restrictednet
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/qdm12/dns/v2/pkg/provider"
|
||||
)
|
||||
|
||||
type Settings struct {
|
||||
DefaultInterface string
|
||||
IPv6Supported *bool
|
||||
Firewall Firewall
|
||||
UpstreamResolvers []provider.Provider
|
||||
}
|
||||
|
||||
func (s *Settings) validate() error {
|
||||
switch {
|
||||
case s.DefaultInterface == "":
|
||||
return errors.New("default interface is not set")
|
||||
case s.IPv6Supported == nil:
|
||||
return errors.New("IPv6 support field is not set")
|
||||
case s.Firewall == nil:
|
||||
return errors.New("firewall is not set")
|
||||
case len(s.UpstreamResolvers) == 0:
|
||||
return errors.New("no upstream resolvers provided")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,121 +0,0 @@
|
||||
//go:build !windows
|
||||
|
||||
package restrictednet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func closeFD(fd int) {
|
||||
unix.Close(fd)
|
||||
}
|
||||
|
||||
func newTCPSockStream(family int) (fd int, err error) {
|
||||
fd, err = unix.Socket(family, unix.SOCK_STREAM, unix.IPPROTO_TCP)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
err = unix.SetNonblock(fd, true)
|
||||
if err != nil {
|
||||
_ = unix.Close(fd)
|
||||
return 0, err
|
||||
}
|
||||
return fd, nil
|
||||
}
|
||||
|
||||
func bindFD(fd int, address netip.AddrPort) error {
|
||||
bindAddr := makeSockAddr(address)
|
||||
return unix.Bind(fd, bindAddr)
|
||||
}
|
||||
|
||||
func connectFD(ctx context.Context, fd int, destination netip.AddrPort) error {
|
||||
err := unix.Connect(fd, makeSockAddr(destination))
|
||||
switch {
|
||||
case err == nil:
|
||||
return nil
|
||||
case !errors.Is(err, unix.EINPROGRESS):
|
||||
return err
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
bitsIndex := fd / 64 //nolint:mnd
|
||||
if bitsIndex >= len(unix.FdSet{}.Bits) {
|
||||
return fmt.Errorf("fd %d exceeds unix.Select FdSet capacity", fd)
|
||||
}
|
||||
wset := &unix.FdSet{}
|
||||
wset.Bits[bitsIndex] |= 1 << (uint64(fd) % 64) //nolint:gosec,mnd
|
||||
eset := &unix.FdSet{}
|
||||
eset.Bits[bitsIndex] |= 1 << (uint64(fd) % 64) //nolint:gosec,mnd
|
||||
const selectTimeout = 50 * time.Millisecond
|
||||
timeval := unix.NsecToTimeval(int64(selectTimeout))
|
||||
|
||||
// Wait for the FD to become writable or hit an error state
|
||||
n, err := unix.Select(fd+1, nil, wset, eset, &timeval)
|
||||
if err != nil {
|
||||
if errors.Is(err, unix.EINTR) {
|
||||
continue // Syscall interrupted, try again
|
||||
}
|
||||
return fmt.Errorf("select error: %w", err)
|
||||
} else if n == 0 {
|
||||
continue // no status change yet
|
||||
}
|
||||
|
||||
// Check if the socket encountered an error
|
||||
n, err = unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_ERROR)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getsockopt error: %w", err)
|
||||
} else if n != 0 {
|
||||
return fmt.Errorf("connect failed asynchronously: %w", unix.Errno(n))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func fdToSourceAddr(fd int) (sourceAddrPort netip.AddrPort, err error) {
|
||||
sockAddr, err := unix.Getsockname(fd)
|
||||
if err != nil {
|
||||
return netip.AddrPort{}, fmt.Errorf("getting sockname: %w", err)
|
||||
}
|
||||
|
||||
sourceAddrPort, err = sockAddrToAddrPort(sockAddr)
|
||||
if err != nil {
|
||||
return netip.AddrPort{}, err
|
||||
}
|
||||
return sourceAddrPort, nil
|
||||
}
|
||||
|
||||
func makeSockAddr(addressPort netip.AddrPort) unix.Sockaddr {
|
||||
if addressPort.Addr().Is4() {
|
||||
return &unix.SockaddrInet4{
|
||||
Port: int(addressPort.Port()),
|
||||
Addr: addressPort.Addr().As4(),
|
||||
}
|
||||
}
|
||||
return &unix.SockaddrInet6{
|
||||
Port: int(addressPort.Port()),
|
||||
Addr: addressPort.Addr().As16(),
|
||||
}
|
||||
}
|
||||
|
||||
func sockAddrToAddrPort(sockAddr unix.Sockaddr) (addrPort netip.AddrPort, err error) {
|
||||
switch typedSockAddr := sockAddr.(type) {
|
||||
case *unix.SockaddrInet4:
|
||||
return netip.AddrPortFrom(netip.AddrFrom4(typedSockAddr.Addr), uint16(typedSockAddr.Port)), nil //nolint:gosec
|
||||
case *unix.SockaddrInet6:
|
||||
return netip.AddrPortFrom(netip.AddrFrom16(typedSockAddr.Addr), uint16(typedSockAddr.Port)), nil //nolint:gosec
|
||||
default:
|
||||
return netip.AddrPort{}, fmt.Errorf("unexpected socket address type %T", typedSockAddr)
|
||||
}
|
||||
}
|
||||
@@ -1,28 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package restrictednet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
func closeFD(fd int) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func newTCPSockStream(family int) (fd int, err error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func bindFD(fd int, address netip.AddrPort) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func connectFD(ctx context.Context, fd int, destination netip.AddrPort) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func fdToSourceAddr(fd int) (sourceAddrPort netip.AddrPort, err error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
@@ -0,0 +1,112 @@
|
||||
//go:build integration
|
||||
|
||||
package socks5
|
||||
|
||||
import (
|
||||
"math/rand/v2"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_Server_UDPResolution(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := t.Context()
|
||||
|
||||
server := newServer(Settings{
|
||||
Address: "127.0.0.1:0",
|
||||
Logger: noopLogger{},
|
||||
})
|
||||
runErr, err := server.Start(ctx)
|
||||
require.NoError(t, err, "starting SOCKS5 server")
|
||||
|
||||
const timeout = 3 * time.Second
|
||||
|
||||
// Connect to the SOCKS5 server via TCP to negotiate UDP associate
|
||||
dialer := &net.Dialer{Timeout: timeout}
|
||||
tcpConn, err := dialer.DialContext(ctx, "tcp", server.listeningAddress().String())
|
||||
require.NoError(t, err, "tcp connecting to SOCKS5 server")
|
||||
t.Cleanup(func() { tcpConn.Close() })
|
||||
|
||||
negotiateSOCKS5(t, tcpConn, "", "")
|
||||
|
||||
// UDP Associate Command: [VERSION (5), CMD (3 = UDP ASSOC), RSV (0), ATYP (1 = IPv4), ADDR (0.0.0.0), PORT (0)]
|
||||
_, err = tcpConn.Write([]byte{5, 3, 0, 1, 0, 0, 0, 0, 0, 0})
|
||||
require.NoError(t, err, "sending UDP ASSOC request")
|
||||
|
||||
relayAddressString, err := readSOCKS5ResponseAddress(t, tcpConn)
|
||||
require.NoError(t, err, "reading UDP ASSOC reply")
|
||||
relayAddress, err := net.ResolveUDPAddr("udp", relayAddressString)
|
||||
require.NoError(t, err, "resolving udp relay address")
|
||||
|
||||
// Dial the relay using IPv4 so source IP family matches the control connection.
|
||||
udpConn, err := net.DialUDP("udp4", nil, relayAddress)
|
||||
require.NoError(t, err, "dialing UDP relay")
|
||||
t.Cleanup(func() { _ = udpConn.Close() })
|
||||
|
||||
queryID := uint16(rand.Uint32()) //nolint:gosec
|
||||
dnsRequest := &dns.Msg{
|
||||
MsgHdr: dns.MsgHdr{
|
||||
Id: queryID,
|
||||
RecursionDesired: true,
|
||||
},
|
||||
Question: []dns.Question{{
|
||||
Name: dns.Fqdn("github.com"),
|
||||
Qtype: dns.TypeA,
|
||||
Qclass: dns.ClassINET,
|
||||
}},
|
||||
}
|
||||
dnsQuery, err := dnsRequest.Pack()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Encapsulate DNS payload into SOCKS5 UDP Request Header
|
||||
// [RSV (0,0), FRAG (0), ATYP (1 = IPv4), DST.ADDR (1.1.1.1), DST.PORT (53)]
|
||||
packet := append([]byte{0, 0, 0, 1, 1, 1, 1, 1, 0, 53}, dnsQuery...)
|
||||
|
||||
// Send encapsulated packet to the proxy's UDP relay address
|
||||
_, err = udpConn.Write(packet)
|
||||
require.NoError(t, err, "sending UDP packet to relay")
|
||||
|
||||
// Read response from the proxy relay
|
||||
err = udpConn.SetReadDeadline(time.Now().Add(timeout))
|
||||
require.NoError(t, err, "setting read deadline on UDP connection")
|
||||
buffer := make([]byte, 2048)
|
||||
n, err := udpConn.Read(buffer)
|
||||
require.NoError(t, err, "receiving UDP response from relay")
|
||||
const minimumHeaderSize = 10
|
||||
require.GreaterOrEqual(t, n, minimumHeaderSize, "received UDP packet too short to contain valid SOCKS5 header")
|
||||
|
||||
// Verify header layout and slice out the raw DNS response
|
||||
// Header format: RSV(2) FRAG(1) ATYP(1) DST.ADDR(variable) DST.PORT(2)
|
||||
atyp := buffer[3]
|
||||
var headerSize int
|
||||
switch atyp {
|
||||
case 1: // IPv4
|
||||
headerSize = 10
|
||||
case 3: // Domain name
|
||||
headerSize = 4 + 1 + int(buffer[4]) + 2
|
||||
case 4: // IPv6
|
||||
headerSize = 22
|
||||
default:
|
||||
t.Fatalf("Unknown ATYP in SOCKS5 UDP header: %d", atyp)
|
||||
}
|
||||
|
||||
dnsResponse := new(dns.Msg)
|
||||
err = dnsResponse.Unpack(buffer[headerSize:n])
|
||||
require.NoError(t, err, "unpacking DNS response from SOCKS5 UDP packet")
|
||||
|
||||
assert.Equal(t, queryID, dnsResponse.Id, "DNS response ID should match query ID")
|
||||
|
||||
select {
|
||||
case err := <-runErr:
|
||||
require.NoError(t, err, "SOCKS5 server run error")
|
||||
default:
|
||||
}
|
||||
|
||||
err = server.Stop()
|
||||
require.NoError(t, err, "stopping SOCKS5 server")
|
||||
}
|
||||
@@ -76,7 +76,7 @@ func (r *udpRouter) registerAssociation(controlConn net.Conn, expectedAddrPort n
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
|
||||
const udpPacketChannelBuffer = 2
|
||||
const udpPacketChannelBuffer = 64
|
||||
associationID := r.nextAssociationID
|
||||
r.nextAssociationID++
|
||||
|
||||
|
||||
Reference in New Issue
Block a user