mirror of
https://github.com/qdm12/gluetun.git
synced 2026-06-27 22:37:33 +02:00
Compare commits
20 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 106a4fdf58 | |||
| f6b2612923 | |||
| 08dfd73367 | |||
| b44c671217 | |||
| 70d80f7473 | |||
| 9af6aaff27 | |||
| d28744e06d | |||
| 69b4e5c584 | |||
| 29186feccc | |||
| b5366b9e44 | |||
| dd07205b85 | |||
| e2256dd1b2 | |||
| 8da913d7c6 | |||
| 2d2c371303 | |||
| b48ba8cb0a | |||
| c18c54c3b7 | |||
| 820689cc23 | |||
| a9a36644ec | |||
| fad8c9889a | |||
| aa781c6cc5 |
@@ -67,6 +67,10 @@ jobs:
|
||||
-v "$(pwd)/coverage.txt:/tmp/gobuild/coverage.txt" \
|
||||
test-container
|
||||
|
||||
- name: Run integration tests in test container
|
||||
run: |
|
||||
docker run --rm --entrypoint go test-container test -tags=integration ./internal/restrictednet
|
||||
|
||||
- name: Verify dev cross platform compatibility
|
||||
run: docker build --target xcompile .
|
||||
|
||||
|
||||
@@ -50,6 +50,7 @@ Guidance for coding agents working in this repository.
|
||||
- Prefer splitting a code line only when it triggers the `lll` linter, do not split a command or arguments list for each element
|
||||
- Use `netip` types instead of `net` types whenever possible
|
||||
- Use constants instead of variables whenever possible, especially function-local inline constants.
|
||||
- Prefer using pure functions over methods when possible. Especially if the method does not need any fields from the receiving struct, it should be a pure function.
|
||||
- Do not use `time.Sleep`, prefer using a `time.Timer` with a `select` statement also listening on a context cancelation
|
||||
- `panic`:
|
||||
- should only be used when a programming error is encountered and you should NOT return errors for programming errors (such as passing nil objects)
|
||||
@@ -115,6 +116,7 @@ Mocking works with the `go.uber.org/mock` library, and the `mockgen` tool.
|
||||
- **Never** use `.AnyTimes()` on mocks. Always define the number of times a certain mock call should be called, with `.Times(3)` for example.
|
||||
- **Always** set the `.Return(...)` on the mock if the function returns something.
|
||||
- Avoid using **mock helpers** functions, prefer a bit of repetition than tight coupling and dependency
|
||||
- Always define the gomock controller `ctrl` in the subtest and not in the parent test, or a subtest mock failing will crash all the other subtests.
|
||||
|
||||
### main.go
|
||||
|
||||
@@ -127,6 +129,7 @@ The Go formatter used is gofumpt.
|
||||
### Errors
|
||||
|
||||
- Always prefer wrapping errors with some context with `fmt.Errorf("doing this: %w", err)`
|
||||
- Use `errors.New("error message")` when creating a 'bottom' constant string error without additional context, instead of `fmt.Errorf`
|
||||
- In rare cases, you can just use `return err` notably:
|
||||
- If the function is called **recursively**, since we don't wrap the wrapping multiple times for each recursion
|
||||
- If the current function only statement is the call to another function, for example:
|
||||
@@ -179,6 +182,8 @@ The Go formatter used is gofumpt.
|
||||
|
||||
- Do not use `http.DefaultClient`, use a custom `*http.Client` with a fixed timeout and share with dependency injections.
|
||||
- Do not check for injected dependencies being `nil`, prefer to just panic on a nil pointer. By default it's fine to panic if a developer injects a dependency `nil`. `nil` does not mean use a default.
|
||||
- Prefer using a `switch { case ...}` statement over multiple consecutive `if` statements to have shorter code.
|
||||
- Prefer using `[...]T` instead of `[]T` when the length is fixed and known at compile time, to avoid unnecessary allocations.
|
||||
|
||||
## Validation checklist
|
||||
|
||||
|
||||
@@ -28,6 +28,8 @@ type firewallImpl interface { //nolint:interfacebloat
|
||||
AcceptIpv6MulticastOutput(ctx context.Context, intf string) error
|
||||
AcceptOutput(ctx context.Context, protocol, intf string,
|
||||
ip netip.Addr, port uint16, remove bool) error
|
||||
AcceptOutputFromIPPortToIPPort(ctx context.Context, protocol, intf string,
|
||||
source, destination netip.AddrPort, remove bool) error
|
||||
AcceptOutputFromIPToSubnet(ctx context.Context, intf string, assignedIP netip.Addr,
|
||||
subnet netip.Prefix, remove bool) error
|
||||
AcceptOutputThroughInterface(ctx context.Context, intf string, remove bool) error
|
||||
|
||||
@@ -2,6 +2,7 @@ package iptables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
@@ -177,6 +178,29 @@ func (c *Config) AcceptOutput(ctx context.Context,
|
||||
return c.runIP6tablesInstruction(ctx, instruction)
|
||||
}
|
||||
|
||||
func (c *Config) AcceptOutputFromIPPortToIPPort(ctx context.Context,
|
||||
protocol, intf string, source, destination netip.AddrPort, remove bool,
|
||||
) error {
|
||||
if source.Addr().BitLen() != destination.Addr().BitLen() {
|
||||
return errors.New("source and destination address families do not match")
|
||||
}
|
||||
|
||||
interfaceFlag := "-o " + intf
|
||||
if intf == "*" { // all interfaces
|
||||
interfaceFlag = ""
|
||||
}
|
||||
|
||||
instruction := fmt.Sprintf("%s OUTPUT %s -s %s -d %s -p %s -m %s --sport %d --dport %d -j ACCEPT",
|
||||
appendOrDelete(remove), interfaceFlag, source.Addr(), destination.Addr(),
|
||||
protocol, protocol, source.Port(), destination.Port())
|
||||
if destination.Addr().Is4() {
|
||||
return c.runIptablesInstruction(ctx, instruction)
|
||||
} else if c.ip6Tables == "" {
|
||||
return fmt.Errorf("accept output from %s to %s: %s", source, destination, needIP6Tables)
|
||||
}
|
||||
return c.runIP6tablesInstruction(ctx, instruction)
|
||||
}
|
||||
|
||||
// AcceptOutputFromIPToSubnet accepts outgoing traffic from sourceIP to destinationSubnet
|
||||
// on the interface intf. If intf is empty, it is set to "*" which means all interfaces.
|
||||
// If remove is true, the rule is removed instead of added.
|
||||
|
||||
@@ -25,3 +25,10 @@ func (c *Config) AcceptOutput(ctx context.Context, protocol, intf string,
|
||||
) error {
|
||||
return c.impl.AcceptOutput(ctx, protocol, intf, ip, port, remove)
|
||||
}
|
||||
|
||||
func (c *Config) AcceptOutputFromIPPortToIPPort(ctx context.Context,
|
||||
protocol, intf string, source, destination netip.AddrPort, remove bool,
|
||||
) error {
|
||||
return c.impl.AcceptOutputFromIPPortToIPPort(ctx, protocol, intf,
|
||||
source, destination, remove)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,82 @@
|
||||
package restrictednet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
|
||||
"github.com/qdm12/dns/v2/pkg/provider"
|
||||
)
|
||||
|
||||
// Client is a client for making restricted network requests,
|
||||
// such as opening temporary firewall rules for HTTPS connections.
|
||||
// It is not meant to be high performance, although it can be used for
|
||||
// multiple requests and concurrently.
|
||||
type Client struct {
|
||||
outboundInterface string
|
||||
ipv6Supported bool
|
||||
firewall Firewall
|
||||
dohServers []provider.DoHServer
|
||||
}
|
||||
|
||||
func New(settings Settings) *Client {
|
||||
if err := settings.validate(); err != nil {
|
||||
panic(fmt.Sprintf("invalid settings: %v", err)) // programming error
|
||||
}
|
||||
dohServers := make([]provider.DoHServer, len(settings.UpstreamResolvers))
|
||||
for i, upstreamResolver := range settings.UpstreamResolvers {
|
||||
dohServers[i] = upstreamResolver.DoH
|
||||
}
|
||||
|
||||
return &Client{
|
||||
outboundInterface: settings.DefaultInterface,
|
||||
ipv6Supported: *settings.IPv6Supported,
|
||||
firewall: settings.Firewall,
|
||||
dohServers: dohServers,
|
||||
}
|
||||
}
|
||||
|
||||
// OpenHTTPSByHostname opens an https connection through the firewall,
|
||||
// to the hostname which in the format `host:port`. The returned cleanup
|
||||
// function must be called to remove the temporary firewall rule and close connections.
|
||||
// It first resolves the domain in hostname using DNS over HTTPS and then opens
|
||||
// the restricted HTTPS connection to the resolved IP.
|
||||
func (c *Client) OpenHTTPSByHostname(ctx context.Context, hostname string) (
|
||||
httpClient *http.Client, cleanup func() error, err error,
|
||||
) {
|
||||
host, portStr, err := net.SplitHostPort(hostname)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("splitting host and port: %w", err)
|
||||
}
|
||||
resolvedIPs, err := c.ResolveName(ctx, host)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("resolving name: %w", err)
|
||||
} else if len(resolvedIPs) == 0 {
|
||||
return nil, nil, fmt.Errorf("no IP address found for name %q", host)
|
||||
}
|
||||
|
||||
portUint, err := strconv.ParseUint(portStr, 10, 16)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("parsing port: %w", err)
|
||||
} else if portUint == 0 {
|
||||
return nil, nil, errors.New("destination port cannot be 0")
|
||||
}
|
||||
port := uint16(portUint)
|
||||
|
||||
errs := make([]error, 0, len(resolvedIPs))
|
||||
for _, ip := range resolvedIPs {
|
||||
addrPort := netip.AddrPortFrom(ip, port)
|
||||
httpClient, cleanup, err := c.OpenHTTPS(ctx, host, addrPort)
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("for %s: %w", ip, err))
|
||||
continue
|
||||
}
|
||||
return httpClient, cleanup, nil
|
||||
}
|
||||
|
||||
return nil, nil, fmt.Errorf("opening HTTPS to %s: %w", hostname, errors.Join(errs...))
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
//go:build integration
|
||||
|
||||
package restrictednet
|
||||
|
||||
func ptrTo[T any](value T) *T {
|
||||
return &value
|
||||
}
|
||||
@@ -0,0 +1,202 @@
|
||||
package restrictednet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/jsimonetti/rtnetlink"
|
||||
"github.com/qdm12/gluetun/internal/pmtud/constants"
|
||||
)
|
||||
|
||||
// OpenHTTPS opens temporary restrictive firewall output for one HTTPS destination.
|
||||
// The returned [*http.Client] must be used sequentially only, and each request must
|
||||
// have its response body fully read/discarded and then closed.
|
||||
// The returned cleanup function must be called to remove the temporary firewall rule and close connections.
|
||||
func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, destinationAddrPort netip.AddrPort,
|
||||
) (httpClient *http.Client, cleanup func() error, err error) {
|
||||
fd, sourceAddrPort, err := bindSourceConnection(destinationAddrPort.Addr())
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("binding source port: %w", err)
|
||||
}
|
||||
|
||||
const remove = false
|
||||
err = c.firewall.AcceptOutputFromIPPortToIPPort(ctx, "tcp", c.outboundInterface,
|
||||
sourceAddrPort, destinationAddrPort, remove)
|
||||
if err != nil {
|
||||
closeFD(fd)
|
||||
return nil, nil, fmt.Errorf("allowing output traffic through firewall: %w", err)
|
||||
}
|
||||
|
||||
connection, err := connectSourceConnection(ctx, fd, destinationAddrPort)
|
||||
if err != nil {
|
||||
const remove = true
|
||||
_ = c.firewall.AcceptOutputFromIPPortToIPPort(context.Background(), "tcp", c.outboundInterface,
|
||||
sourceAddrPort, destinationAddrPort, remove)
|
||||
return nil, nil, fmt.Errorf("connecting source socket: %w", err)
|
||||
}
|
||||
|
||||
dial := makeDial(connection, destinationTLSName)
|
||||
httpClient = newHTTPSClient(destinationTLSName, dial)
|
||||
cleanup = func() error {
|
||||
var errs []error
|
||||
httpClient.CloseIdleConnections()
|
||||
err := connection.Close()
|
||||
if err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
errs = append(errs, fmt.Errorf("closing connection: %w", err))
|
||||
}
|
||||
const remove = true
|
||||
err = c.firewall.AcceptOutputFromIPPortToIPPort(context.Background(), "tcp", c.outboundInterface,
|
||||
sourceAddrPort, destinationAddrPort, remove)
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("removing output traffic rule: %w", err))
|
||||
}
|
||||
if len(errs) > 0 {
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return httpClient, cleanup, nil
|
||||
}
|
||||
|
||||
type dialFunc func(ctx context.Context, network, address string) (net.Conn, error)
|
||||
|
||||
func newHTTPSClient(destinationTLSName string, dial dialFunc) *http.Client {
|
||||
const timeout = 5 * time.Second
|
||||
transport := &http.Transport{
|
||||
MaxIdleConns: 1,
|
||||
MaxIdleConnsPerHost: 1,
|
||||
MaxConnsPerHost: 1,
|
||||
TLSClientConfig: &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
ServerName: destinationTLSName,
|
||||
},
|
||||
DialContext: dial,
|
||||
}
|
||||
return &http.Client{
|
||||
Timeout: timeout,
|
||||
Transport: transport,
|
||||
}
|
||||
}
|
||||
|
||||
func makeDial(connection net.Conn, tlsName string) dialFunc {
|
||||
_, destinationPort, err := net.SplitHostPort(connection.RemoteAddr().String())
|
||||
if err != nil {
|
||||
panic(err) // connection remote address should always be in the form "host:port"
|
||||
}
|
||||
expectedAddress := net.JoinHostPort(tlsName, destinationPort)
|
||||
used := false
|
||||
return func(_ context.Context, network, address string) (net.Conn, error) {
|
||||
if used {
|
||||
return nil, errors.New("dial function called more than once")
|
||||
}
|
||||
used = true
|
||||
switch network {
|
||||
case "tcp", "tcp4", "tcp6":
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected dial network %q", network)
|
||||
}
|
||||
if address != expectedAddress {
|
||||
return nil, fmt.Errorf("unexpected dial address %q (expected %q)", address, expectedAddress)
|
||||
}
|
||||
return connection, nil
|
||||
}
|
||||
}
|
||||
|
||||
func bindSourceConnection(destinationIP netip.Addr) (fd int, sourceAddr netip.AddrPort, err error) {
|
||||
sourceIP, err := sourceIPForDestination(destinationIP)
|
||||
if err != nil {
|
||||
return 0, netip.AddrPort{}, fmt.Errorf("finding source IP: %w", err)
|
||||
}
|
||||
|
||||
family := constants.AF_INET
|
||||
if sourceIP.Is6() {
|
||||
family = constants.AF_INET6
|
||||
}
|
||||
|
||||
fd, err = newTCPSockStream(family)
|
||||
if err != nil {
|
||||
return 0, netip.AddrPort{}, fmt.Errorf("creating socket: %w", err)
|
||||
}
|
||||
|
||||
bindAddrPort := netip.AddrPortFrom(sourceIP, 0)
|
||||
err = bindFD(fd, bindAddrPort)
|
||||
if err != nil {
|
||||
closeFD(fd)
|
||||
return 0, netip.AddrPort{}, fmt.Errorf("binding socket: %w", err)
|
||||
}
|
||||
|
||||
sourceAddr, err = fdToSourceAddr(fd)
|
||||
if err != nil {
|
||||
closeFD(fd)
|
||||
return 0, netip.AddrPort{}, fmt.Errorf("getting source address: %w", err)
|
||||
}
|
||||
|
||||
return fd, sourceAddr, nil
|
||||
}
|
||||
|
||||
func connectSourceConnection(ctx context.Context, fd int, destinationAddrPort netip.AddrPort) (
|
||||
connection net.Conn, err error,
|
||||
) {
|
||||
err = connectFD(ctx, fd, destinationAddrPort)
|
||||
if err != nil {
|
||||
closeFD(fd)
|
||||
return nil, fmt.Errorf("connecting socket: %w", err)
|
||||
}
|
||||
|
||||
file := os.NewFile(uintptr(fd), "")
|
||||
if file == nil {
|
||||
closeFD(fd)
|
||||
return nil, fmt.Errorf("creating socket file")
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
connection, err = net.FileConn(file)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("wrapping socket connection: %w", err)
|
||||
}
|
||||
|
||||
return connection, nil
|
||||
}
|
||||
|
||||
func sourceIPForDestination(destinationIP netip.Addr) (srcIP netip.Addr, err error) {
|
||||
conn, err := rtnetlink.Dial(nil)
|
||||
if err != nil {
|
||||
return netip.Addr{}, err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
family := uint8(constants.AF_INET)
|
||||
if destinationIP.Is6() {
|
||||
family = constants.AF_INET6
|
||||
}
|
||||
|
||||
requestMessage := &rtnetlink.RouteMessage{
|
||||
Family: family,
|
||||
Attributes: rtnetlink.RouteAttributes{
|
||||
Dst: destinationIP.AsSlice(),
|
||||
},
|
||||
}
|
||||
messages, err := conn.Route.Get(requestMessage)
|
||||
if err != nil {
|
||||
return netip.Addr{}, fmt.Errorf("getting routes to %s: %w", destinationIP, err)
|
||||
}
|
||||
|
||||
for _, message := range messages {
|
||||
if message.Attributes.Src == nil {
|
||||
continue
|
||||
}
|
||||
if message.Attributes.Src.To4() == nil {
|
||||
return netip.AddrFrom16([16]byte(message.Attributes.Src)), nil
|
||||
}
|
||||
return netip.AddrFrom4([4]byte(message.Attributes.Src)), nil
|
||||
}
|
||||
|
||||
return netip.Addr{}, fmt.Errorf("no route to %s", destinationIP)
|
||||
}
|
||||
@@ -0,0 +1,117 @@
|
||||
//go:build integration
|
||||
|
||||
package restrictednet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/qdm12/dns/v2/pkg/provider"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type listenAddrPortMatcher struct {
|
||||
expected netip.AddrPort
|
||||
}
|
||||
|
||||
func (m listenAddrPortMatcher) Matches(x any) bool {
|
||||
ip, ok := x.(netip.AddrPort)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if m.expected.IsValid() {
|
||||
return ip == m.expected
|
||||
}
|
||||
return ip.IsValid() && ip.Addr().IsValid() && ip.Port() > 0
|
||||
}
|
||||
|
||||
func (m listenAddrPortMatcher) String() string {
|
||||
if m.expected.IsValid() {
|
||||
return "is the same as " + m.expected.String()
|
||||
}
|
||||
return "is a valid netip.AddrPort with a valid IP and non-zero port"
|
||||
}
|
||||
|
||||
type destinationAddrPortMatcher struct {
|
||||
expected netip.AddrPort
|
||||
}
|
||||
|
||||
func (m destinationAddrPortMatcher) Matches(x any) bool {
|
||||
ip, ok := x.(netip.AddrPort)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if m.expected.IsValid() {
|
||||
return ip == m.expected
|
||||
}
|
||||
return ip.IsValid() && ip.Port() == m.expected.Port()
|
||||
}
|
||||
|
||||
func (m destinationAddrPortMatcher) String() string {
|
||||
if m.expected.IsValid() {
|
||||
return "is the same as " + m.expected.String()
|
||||
}
|
||||
return "matches the port " + fmt.Sprint(m.expected.Port())
|
||||
}
|
||||
|
||||
func Test_Client_OpenHTTPS(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := t.Context()
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
const destinationTLSName = "one.one.one.one"
|
||||
destinationAddrPort := netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443)
|
||||
|
||||
firewall := NewMockFirewall(ctrl)
|
||||
sourceMatcher := listenAddrPortMatcher{}
|
||||
firewall.EXPECT().AcceptOutputFromIPPortToIPPort(
|
||||
ctx, "tcp", "eth0", sourceMatcher, destinationAddrPort, false,
|
||||
).DoAndReturn(func(_ context.Context,
|
||||
_, _ string, source, _ netip.AddrPort, _ bool,
|
||||
) error {
|
||||
sourceMatcher.expected = source
|
||||
return nil
|
||||
})
|
||||
firewall.EXPECT().AcceptOutputFromIPPortToIPPort(
|
||||
context.Background(), "tcp", "eth0", sourceMatcher, destinationAddrPort, true,
|
||||
).Return(nil)
|
||||
|
||||
const ipv6Supported = false
|
||||
upstreamResolvers := []provider.Provider{provider.Google()}
|
||||
settings := Settings{
|
||||
Firewall: firewall,
|
||||
DefaultInterface: "eth0",
|
||||
IPv6Supported: ptrTo(ipv6Supported),
|
||||
UpstreamResolvers: upstreamResolvers,
|
||||
}
|
||||
client := New(settings)
|
||||
|
||||
httpClient, cleanup, err := client.OpenHTTPS(ctx, destinationTLSName, destinationAddrPort)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, httpClient)
|
||||
require.NotNil(t, cleanup)
|
||||
|
||||
const requests = 2
|
||||
|
||||
for range requests {
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://"+destinationTLSName, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
response, err := httpClient.Do(request)
|
||||
require.NoError(t, err)
|
||||
_, err = io.Copy(io.Discard, response.Body)
|
||||
require.NoError(t, err)
|
||||
err = response.Body.Close()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, response.StatusCode)
|
||||
}
|
||||
|
||||
err = cleanup()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
package restrictednet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
type Firewall interface {
|
||||
AcceptOutputFromIPPortToIPPort(ctx context.Context,
|
||||
protocol, intf string, source, destination netip.AddrPort, remove bool,
|
||||
) error
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
package restrictednet
|
||||
|
||||
//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . Firewall
|
||||
@@ -0,0 +1,50 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/qdm12/gluetun/internal/restrictednet (interfaces: Firewall)
|
||||
|
||||
// Package restrictednet is a generated GoMock package.
|
||||
package restrictednet
|
||||
|
||||
import (
|
||||
context "context"
|
||||
netip "net/netip"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// MockFirewall is a mock of Firewall interface.
|
||||
type MockFirewall struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockFirewallMockRecorder
|
||||
}
|
||||
|
||||
// MockFirewallMockRecorder is the mock recorder for MockFirewall.
|
||||
type MockFirewallMockRecorder struct {
|
||||
mock *MockFirewall
|
||||
}
|
||||
|
||||
// NewMockFirewall creates a new mock instance.
|
||||
func NewMockFirewall(ctrl *gomock.Controller) *MockFirewall {
|
||||
mock := &MockFirewall{ctrl: ctrl}
|
||||
mock.recorder = &MockFirewallMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockFirewall) EXPECT() *MockFirewallMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// AcceptOutputFromIPPortToIPPort mocks base method.
|
||||
func (m *MockFirewall) AcceptOutputFromIPPortToIPPort(arg0 context.Context, arg1, arg2 string, arg3, arg4 netip.AddrPort, arg5 bool) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AcceptOutputFromIPPortToIPPort", arg0, arg1, arg2, arg3, arg4, arg5)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// AcceptOutputFromIPPortToIPPort indicates an expected call of AcceptOutputFromIPPortToIPPort.
|
||||
func (mr *MockFirewallMockRecorder) AcceptOutputFromIPPortToIPPort(arg0, arg1, arg2, arg3, arg4, arg5 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptOutputFromIPPortToIPPort", reflect.TypeOf((*MockFirewall)(nil).AcceptOutputFromIPPortToIPPort), arg0, arg1, arg2, arg3, arg4, arg5)
|
||||
}
|
||||
@@ -0,0 +1,205 @@
|
||||
package restrictednet
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// ResolveName resolves the given host name to IP addresses using DoH servers,
|
||||
// while opening temporary restrictive firewall rules for HTTPS traffic to DoH servers.
|
||||
// The host must be a single well-formed domain name, without port or path.
|
||||
func (c *Client) ResolveName(ctx context.Context, host string) (
|
||||
resolvedAddresses []netip.Addr, err error,
|
||||
) {
|
||||
const maxTypes = 2
|
||||
questionTypes := make([]uint16, 0, maxTypes)
|
||||
if c.ipv6Supported {
|
||||
questionTypes = append(questionTypes, dns.TypeAAAA)
|
||||
}
|
||||
questionTypes = append(questionTypes, dns.TypeA)
|
||||
|
||||
var addresses []netip.Addr
|
||||
errs := make([]error, 0, len(questionTypes))
|
||||
for _, questionType := range questionTypes {
|
||||
answerAddresses, err := c.resolveOneQuestionType(ctx, host, questionType)
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
continue
|
||||
}
|
||||
addresses = append(addresses, answerAddresses...)
|
||||
}
|
||||
|
||||
switch {
|
||||
case len(addresses) > 0:
|
||||
return addresses, nil
|
||||
case len(errs) == 0:
|
||||
return nil, nil // no address found
|
||||
default: // errors
|
||||
return nil, fmt.Errorf("resolving host %q: %w", host, errors.Join(errs...))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) resolveOneQuestionType(ctx context.Context,
|
||||
host string, questionType uint16,
|
||||
) (addresses []netip.Addr, err error) {
|
||||
queryMessage := &dns.Msg{}
|
||||
queryMessage.SetQuestion(dns.Fqdn(host), questionType)
|
||||
queryWire, err := queryMessage.Pack()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("packing DNS query: %w", err)
|
||||
}
|
||||
|
||||
// Try every DoH server and every of each of their IP until we get a non-empty
|
||||
// successful response.
|
||||
errs := make([]error, 0)
|
||||
for _, dohServer := range c.dohServers {
|
||||
dohURL, err := url.Parse(dohServer.URL)
|
||||
if err != nil {
|
||||
errs = append(errs,
|
||||
fmt.Errorf("parsing DoH server URL %s: %w", dohServer.URL, err))
|
||||
continue
|
||||
}
|
||||
|
||||
dohServerIPs := make([]netip.Addr, 0, len(dohServer.IPv4)+len(dohServer.IPv6))
|
||||
if c.ipv6Supported {
|
||||
// Prefer IPv6 addresses if IPv6 is supported
|
||||
dohServerIPs = append(dohServerIPs, dohServer.IPv6...)
|
||||
}
|
||||
dohServerIPs = append(dohServerIPs, dohServer.IPv4...)
|
||||
|
||||
for _, dohServerIP := range dohServerIPs {
|
||||
const defaultDoHPort uint16 = 443
|
||||
port := defaultDoHPort
|
||||
if portStr := dohURL.Port(); portStr != "" {
|
||||
port, err = parseDestinationPort(portStr)
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("parsing DoH server port: %w", err))
|
||||
continue
|
||||
}
|
||||
}
|
||||
dohServerAddrPort := netip.AddrPortFrom(dohServerIP, port)
|
||||
responseMessage, err := c.doHQuery(ctx, queryWire, dohURL, dohServerAddrPort)
|
||||
switch {
|
||||
case err != nil:
|
||||
errs = append(errs, fmt.Errorf("querying DoH server %q (%s): %w",
|
||||
dohServer.URL, dohServerAddrPort, err))
|
||||
continue
|
||||
case responseMessage.Rcode != dns.RcodeSuccess:
|
||||
errs = append(errs, fmt.Errorf("querying DoH server %q (%s): DNS rcode %s",
|
||||
dohServer.URL, dohServerAddrPort, dns.RcodeToString[responseMessage.Rcode]))
|
||||
continue
|
||||
}
|
||||
addresses := answersToNetipAddrs(responseMessage)
|
||||
if len(addresses) == 0 {
|
||||
continue
|
||||
}
|
||||
return addresses, nil
|
||||
}
|
||||
}
|
||||
|
||||
if len(errs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("resolving %s %s: %w",
|
||||
dns.TypeToString[questionType], host, errors.Join(errs...))
|
||||
}
|
||||
|
||||
func (c *Client) doHQuery(ctx context.Context, queryWire []byte,
|
||||
dohURL *url.URL, dohServerAddrPort netip.AddrPort,
|
||||
) (responseMessage *dns.Msg, err error) {
|
||||
httpClient, cleanup, err := c.OpenHTTPS(ctx, dohURL.Hostname(), dohServerAddrPort)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("opening https connection: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
closeErr := cleanup()
|
||||
if err == nil && closeErr != nil {
|
||||
err = fmt.Errorf("cleaning up https connection: %w", closeErr)
|
||||
}
|
||||
}()
|
||||
|
||||
requestBody := bytes.NewReader(queryWire)
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodPost, dohURL.String(), requestBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating request: %w", err)
|
||||
}
|
||||
request.Header.Set("Content-Type", "application/dns-message")
|
||||
request.Header.Set("Accept", "application/dns-message")
|
||||
|
||||
response, err := httpClient.Do(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
responseData, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
_ = response.Body.Close()
|
||||
return nil, fmt.Errorf("reading response body: %w", err)
|
||||
}
|
||||
|
||||
err = response.Body.Close()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("closing response body: %w", err)
|
||||
}
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("response status code is %s (data length %d)",
|
||||
response.Status, len(responseData))
|
||||
}
|
||||
|
||||
responseMessage = new(dns.Msg)
|
||||
err = responseMessage.Unpack(responseData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing DoH response: %w", err)
|
||||
}
|
||||
|
||||
return responseMessage, nil
|
||||
}
|
||||
|
||||
func answersToNetipAddrs(message *dns.Msg) (addresses []netip.Addr) {
|
||||
if message == nil {
|
||||
return nil
|
||||
}
|
||||
addresses = make([]netip.Addr, 0, len(message.Answer))
|
||||
for _, answer := range message.Answer {
|
||||
switch record := answer.(type) {
|
||||
case *dns.A:
|
||||
address, ok := netip.AddrFromSlice(record.A)
|
||||
if ok {
|
||||
addresses = append(addresses, address.Unmap())
|
||||
}
|
||||
case *dns.AAAA:
|
||||
address, ok := netip.AddrFromSlice(record.AAAA)
|
||||
if ok {
|
||||
addresses = append(addresses, address)
|
||||
}
|
||||
}
|
||||
}
|
||||
return addresses
|
||||
}
|
||||
|
||||
func parseDestinationPort(portStr string) (port uint16, err error) {
|
||||
portUint, err := strconv.ParseUint(portStr, 10, 16)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
const maxPortUint = 65535
|
||||
switch {
|
||||
case portUint == 0:
|
||||
return 0, errors.New("port cannot be 0")
|
||||
case portUint > maxPortUint:
|
||||
return 0, fmt.Errorf("port cannot be greater than %d", maxPortUint)
|
||||
}
|
||||
return uint16(portUint), nil
|
||||
}
|
||||
@@ -0,0 +1,110 @@
|
||||
//go:build integration
|
||||
|
||||
package restrictednet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/qdm12/dns/v2/pkg/provider"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_Client_ResolveName(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := t.Context()
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
firewall := NewMockFirewall(ctrl)
|
||||
sourceMatcher := listenAddrPortMatcher{}
|
||||
destinationMatcher := destinationAddrPortMatcher{
|
||||
expected: netip.AddrPortFrom(netip.Addr{}, 443),
|
||||
}
|
||||
|
||||
// Add rule
|
||||
firstCall := firewall.EXPECT().AcceptOutputFromIPPortToIPPort(
|
||||
ctx, "tcp", "eth0", sourceMatcher, destinationMatcher, false,
|
||||
).DoAndReturn(func(
|
||||
_ context.Context, _, _ string, source, destination netip.AddrPort, _ bool,
|
||||
) error {
|
||||
sourceMatcher.expected = source
|
||||
destinationMatcher.expected = destination
|
||||
return nil
|
||||
})
|
||||
|
||||
// Removal rule
|
||||
firewall.EXPECT().AcceptOutputFromIPPortToIPPort(
|
||||
context.Background(), "tcp", "eth0", sourceMatcher, destinationMatcher, true,
|
||||
).Return(nil).After(firstCall)
|
||||
|
||||
settings := Settings{
|
||||
DefaultInterface: "eth0",
|
||||
IPv6Supported: ptrTo(false),
|
||||
Firewall: firewall,
|
||||
UpstreamResolvers: []provider.Provider{provider.Cloudflare()},
|
||||
}
|
||||
client := New(settings)
|
||||
|
||||
addresses, err := client.ResolveName(ctx, "github.com")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, addresses)
|
||||
}
|
||||
|
||||
func Test_answersToNetipAddrs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
message *dns.Msg
|
||||
expected []netip.Addr
|
||||
}{
|
||||
"nil_message": {},
|
||||
"no_answers": {
|
||||
message: &dns.Msg{},
|
||||
expected: []netip.Addr{},
|
||||
},
|
||||
"a_record": {
|
||||
message: &dns.Msg{Answer: []dns.RR{
|
||||
&dns.A{
|
||||
Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET},
|
||||
A: net.IP{1, 1, 1, 1},
|
||||
},
|
||||
}},
|
||||
expected: []netip.Addr{netip.MustParseAddr("1.1.1.1")},
|
||||
},
|
||||
"aaaa_record": {
|
||||
message: &dns.Msg{Answer: []dns.RR{
|
||||
&dns.AAAA{
|
||||
Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeAAAA, Class: dns.ClassINET},
|
||||
AAAA: net.IP{0x20, 0x01, 0x48, 0x60, 0x48, 0x60, 0, 0, 0, 0, 0, 0, 0, 0, 0x88, 0x88},
|
||||
},
|
||||
}},
|
||||
expected: []netip.Addr{netip.MustParseAddr("2001:4860:4860::8888")},
|
||||
},
|
||||
"mixed_records": {
|
||||
message: &dns.Msg{Answer: []dns.RR{
|
||||
&dns.A{
|
||||
Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET},
|
||||
A: net.IP{1, 1, 1, 1},
|
||||
},
|
||||
&dns.AAAA{
|
||||
Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeAAAA, Class: dns.ClassINET},
|
||||
AAAA: net.IP{0x20, 0x01, 0x48, 0x60, 0x48, 0x60, 0, 0, 0, 0, 0, 0, 0, 0, 0x88, 0x88},
|
||||
},
|
||||
}},
|
||||
expected: []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("2001:4860:4860::8888")},
|
||||
},
|
||||
}
|
||||
|
||||
for testName, testCase := range testCases {
|
||||
t.Run(testName, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
addresses := answersToNetipAddrs(testCase.message)
|
||||
assert.Equal(t, testCase.expected, addresses)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,28 @@
|
||||
package restrictednet
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/qdm12/dns/v2/pkg/provider"
|
||||
)
|
||||
|
||||
type Settings struct {
|
||||
DefaultInterface string
|
||||
IPv6Supported *bool
|
||||
Firewall Firewall
|
||||
UpstreamResolvers []provider.Provider
|
||||
}
|
||||
|
||||
func (s *Settings) validate() error {
|
||||
switch {
|
||||
case s.DefaultInterface == "":
|
||||
return errors.New("default interface is not set")
|
||||
case s.IPv6Supported == nil:
|
||||
return errors.New("IPv6 support field is not set")
|
||||
case s.Firewall == nil:
|
||||
return errors.New("firewall is not set")
|
||||
case len(s.UpstreamResolvers) == 0:
|
||||
return errors.New("no upstream resolvers provided")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,121 @@
|
||||
//go:build !windows
|
||||
|
||||
package restrictednet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func closeFD(fd int) {
|
||||
unix.Close(fd)
|
||||
}
|
||||
|
||||
func newTCPSockStream(family int) (fd int, err error) {
|
||||
fd, err = unix.Socket(family, unix.SOCK_STREAM, unix.IPPROTO_TCP)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
err = unix.SetNonblock(fd, true)
|
||||
if err != nil {
|
||||
_ = unix.Close(fd)
|
||||
return 0, err
|
||||
}
|
||||
return fd, nil
|
||||
}
|
||||
|
||||
func bindFD(fd int, address netip.AddrPort) error {
|
||||
bindAddr := makeSockAddr(address)
|
||||
return unix.Bind(fd, bindAddr)
|
||||
}
|
||||
|
||||
func connectFD(ctx context.Context, fd int, destination netip.AddrPort) error {
|
||||
err := unix.Connect(fd, makeSockAddr(destination))
|
||||
switch {
|
||||
case err == nil:
|
||||
return nil
|
||||
case !errors.Is(err, unix.EINPROGRESS):
|
||||
return err
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
bitsIndex := fd / 64 //nolint:mnd
|
||||
if bitsIndex >= len(unix.FdSet{}.Bits) {
|
||||
return fmt.Errorf("fd %d exceeds unix.Select FdSet capacity", fd)
|
||||
}
|
||||
wset := &unix.FdSet{}
|
||||
wset.Bits[bitsIndex] |= 1 << (uint64(fd) % 64) //nolint:gosec,mnd
|
||||
eset := &unix.FdSet{}
|
||||
eset.Bits[bitsIndex] |= 1 << (uint64(fd) % 64) //nolint:gosec,mnd
|
||||
const selectTimeout = 50 * time.Millisecond
|
||||
timeval := unix.NsecToTimeval(int64(selectTimeout))
|
||||
|
||||
// Wait for the FD to become writable or hit an error state
|
||||
n, err := unix.Select(fd+1, nil, wset, eset, &timeval)
|
||||
if err != nil {
|
||||
if errors.Is(err, unix.EINTR) {
|
||||
continue // Syscall interrupted, try again
|
||||
}
|
||||
return fmt.Errorf("select error: %w", err)
|
||||
} else if n == 0 {
|
||||
continue // no status change yet
|
||||
}
|
||||
|
||||
// Check if the socket encountered an error
|
||||
n, err = unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_ERROR)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getsockopt error: %w", err)
|
||||
} else if n != 0 {
|
||||
return fmt.Errorf("connect failed asynchronously: %w", unix.Errno(n))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func fdToSourceAddr(fd int) (sourceAddrPort netip.AddrPort, err error) {
|
||||
sockAddr, err := unix.Getsockname(fd)
|
||||
if err != nil {
|
||||
return netip.AddrPort{}, fmt.Errorf("getting sockname: %w", err)
|
||||
}
|
||||
|
||||
sourceAddrPort, err = sockAddrToAddrPort(sockAddr)
|
||||
if err != nil {
|
||||
return netip.AddrPort{}, err
|
||||
}
|
||||
return sourceAddrPort, nil
|
||||
}
|
||||
|
||||
func makeSockAddr(addressPort netip.AddrPort) unix.Sockaddr {
|
||||
if addressPort.Addr().Is4() {
|
||||
return &unix.SockaddrInet4{
|
||||
Port: int(addressPort.Port()),
|
||||
Addr: addressPort.Addr().As4(),
|
||||
}
|
||||
}
|
||||
return &unix.SockaddrInet6{
|
||||
Port: int(addressPort.Port()),
|
||||
Addr: addressPort.Addr().As16(),
|
||||
}
|
||||
}
|
||||
|
||||
func sockAddrToAddrPort(sockAddr unix.Sockaddr) (addrPort netip.AddrPort, err error) {
|
||||
switch typedSockAddr := sockAddr.(type) {
|
||||
case *unix.SockaddrInet4:
|
||||
return netip.AddrPortFrom(netip.AddrFrom4(typedSockAddr.Addr), uint16(typedSockAddr.Port)), nil //nolint:gosec
|
||||
case *unix.SockaddrInet6:
|
||||
return netip.AddrPortFrom(netip.AddrFrom16(typedSockAddr.Addr), uint16(typedSockAddr.Port)), nil //nolint:gosec
|
||||
default:
|
||||
return netip.AddrPort{}, fmt.Errorf("unexpected socket address type %T", typedSockAddr)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,28 @@
|
||||
//go:build windows
|
||||
|
||||
package restrictednet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
func closeFD(fd int) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func newTCPSockStream(family int) (fd int, err error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func bindFD(fd int, address netip.AddrPort) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func connectFD(ctx context.Context, fd int, destination netip.AddrPort) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func fdToSourceAddr(fd int) (sourceAddrPort netip.AddrPort, err error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
Reference in New Issue
Block a user