mirror of
https://github.com/qdm12/gluetun.git
synced 2026-06-25 21:37:31 +02:00
initial
This commit is contained in:
@@ -50,6 +50,7 @@ Guidance for coding agents working in this repository.
|
|||||||
- Prefer splitting a code line only when it triggers the `lll` linter, do not split a command or arguments list for each element
|
- Prefer splitting a code line only when it triggers the `lll` linter, do not split a command or arguments list for each element
|
||||||
- Use `netip` types instead of `net` types whenever possible
|
- Use `netip` types instead of `net` types whenever possible
|
||||||
- Use constants instead of variables whenever possible, especially function-local inline constants.
|
- Use constants instead of variables whenever possible, especially function-local inline constants.
|
||||||
|
- Prefer using pure functions over methods when possible. Especially if the method does not need any fields from the receiving struct, it should be a pure function.
|
||||||
- Do not use `time.Sleep`, prefer using a `time.Timer` with a `select` statement also listening on a context cancelation
|
- Do not use `time.Sleep`, prefer using a `time.Timer` with a `select` statement also listening on a context cancelation
|
||||||
- `panic`:
|
- `panic`:
|
||||||
- should only be used when a programming error is encountered and you should NOT return errors for programming errors (such as passing nil objects)
|
- should only be used when a programming error is encountered and you should NOT return errors for programming errors (such as passing nil objects)
|
||||||
@@ -127,6 +128,7 @@ The Go formatter used is gofumpt.
|
|||||||
### Errors
|
### Errors
|
||||||
|
|
||||||
- Always prefer wrapping errors with some context with `fmt.Errorf("doing this: %w", err)`
|
- Always prefer wrapping errors with some context with `fmt.Errorf("doing this: %w", err)`
|
||||||
|
- Use `errors.New("error message")` when creating a 'bottom' constant string error without additional context, instead of `fmt.Errorf`
|
||||||
- In rare cases, you can just use `return err` notably:
|
- In rare cases, you can just use `return err` notably:
|
||||||
- If the function is called **recursively**, since we don't wrap the wrapping multiple times for each recursion
|
- If the function is called **recursively**, since we don't wrap the wrapping multiple times for each recursion
|
||||||
- If the current function only statement is the call to another function, for example:
|
- If the current function only statement is the call to another function, for example:
|
||||||
@@ -179,6 +181,8 @@ The Go formatter used is gofumpt.
|
|||||||
|
|
||||||
- Do not use `http.DefaultClient`, use a custom `*http.Client` with a fixed timeout and share with dependency injections.
|
- Do not use `http.DefaultClient`, use a custom `*http.Client` with a fixed timeout and share with dependency injections.
|
||||||
- Do not check for injected dependencies being `nil`, prefer to just panic on a nil pointer. By default it's fine to panic if a developer injects a dependency `nil`. `nil` does not mean use a default.
|
- Do not check for injected dependencies being `nil`, prefer to just panic on a nil pointer. By default it's fine to panic if a developer injects a dependency `nil`. `nil` does not mean use a default.
|
||||||
|
- Prefer using a `switch { case ...}` statement over multiple consecutive `if` statements to have shorter code.
|
||||||
|
- Prefer using `[...]T` instead of `[]T` when the length is fixed and known at compile time, to avoid unnecessary allocations.
|
||||||
|
|
||||||
## Validation checklist
|
## Validation checklist
|
||||||
|
|
||||||
|
|||||||
@@ -28,6 +28,8 @@ type firewallImpl interface { //nolint:interfacebloat
|
|||||||
AcceptIpv6MulticastOutput(ctx context.Context, intf string) error
|
AcceptIpv6MulticastOutput(ctx context.Context, intf string) error
|
||||||
AcceptOutput(ctx context.Context, protocol, intf string,
|
AcceptOutput(ctx context.Context, protocol, intf string,
|
||||||
ip netip.Addr, port uint16, remove bool) error
|
ip netip.Addr, port uint16, remove bool) error
|
||||||
|
AcceptOutputFromIPPortToIPPort(ctx context.Context, protocol, intf string,
|
||||||
|
source, destination netip.AddrPort, remove bool) error
|
||||||
AcceptOutputFromIPToSubnet(ctx context.Context, intf string, assignedIP netip.Addr,
|
AcceptOutputFromIPToSubnet(ctx context.Context, intf string, assignedIP netip.Addr,
|
||||||
subnet netip.Prefix, remove bool) error
|
subnet netip.Prefix, remove bool) error
|
||||||
AcceptOutputThroughInterface(ctx context.Context, intf string, remove bool) error
|
AcceptOutputThroughInterface(ctx context.Context, intf string, remove bool) error
|
||||||
|
|||||||
@@ -177,6 +177,29 @@ func (c *Config) AcceptOutput(ctx context.Context,
|
|||||||
return c.runIP6tablesInstruction(ctx, instruction)
|
return c.runIP6tablesInstruction(ctx, instruction)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Config) AcceptOutputFromIPPortToIPPort(ctx context.Context,
|
||||||
|
protocol, intf string, source, destination netip.AddrPort, remove bool,
|
||||||
|
) error {
|
||||||
|
if source.Addr().BitLen() != destination.Addr().BitLen() {
|
||||||
|
return fmt.Errorf("source and destination address families do not match")
|
||||||
|
}
|
||||||
|
|
||||||
|
interfaceFlag := "-o " + intf
|
||||||
|
if intf == "*" { // all interfaces
|
||||||
|
interfaceFlag = ""
|
||||||
|
}
|
||||||
|
|
||||||
|
instruction := fmt.Sprintf("%s OUTPUT -s %s --sport %d -d %s %s -p %s -m %s --dport %d -j ACCEPT",
|
||||||
|
appendOrDelete(remove), source.Addr(), source.Port(), destination.Addr(),
|
||||||
|
interfaceFlag, protocol, protocol, destination.Port())
|
||||||
|
if destination.Addr().Is4() {
|
||||||
|
return c.runIptablesInstruction(ctx, instruction)
|
||||||
|
} else if c.ip6Tables == "" {
|
||||||
|
return fmt.Errorf("accept output from %s to %s: %s", source, destination, needIP6Tables)
|
||||||
|
}
|
||||||
|
return c.runIP6tablesInstruction(ctx, instruction)
|
||||||
|
}
|
||||||
|
|
||||||
// AcceptOutputFromIPToSubnet accepts outgoing traffic from sourceIP to destinationSubnet
|
// AcceptOutputFromIPToSubnet accepts outgoing traffic from sourceIP to destinationSubnet
|
||||||
// on the interface intf. If intf is empty, it is set to "*" which means all interfaces.
|
// on the interface intf. If intf is empty, it is set to "*" which means all interfaces.
|
||||||
// If remove is true, the rule is removed instead of added.
|
// If remove is true, the rule is removed instead of added.
|
||||||
|
|||||||
@@ -25,3 +25,10 @@ func (c *Config) AcceptOutput(ctx context.Context, protocol, intf string,
|
|||||||
) error {
|
) error {
|
||||||
return c.impl.AcceptOutput(ctx, protocol, intf, ip, port, remove)
|
return c.impl.AcceptOutput(ctx, protocol, intf, ip, port, remove)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Config) AcceptOutputFromIPPortToIPPort(ctx context.Context,
|
||||||
|
protocol, intf string, source, destination netip.AddrPort, remove bool,
|
||||||
|
) error {
|
||||||
|
return c.impl.AcceptOutputFromIPPortToIPPort(ctx, protocol, intf,
|
||||||
|
source, destination, remove)
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,56 @@
|
|||||||
|
package restrictednet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/qdm12/dns/v2/pkg/provider"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Client is a client for making restricted network requests,
|
||||||
|
// such as opening temporary firewall rules for HTTPS connections.
|
||||||
|
// It is not meant to be high performance, although it can be used for
|
||||||
|
// multiple requests and concurrently.
|
||||||
|
type Client struct {
|
||||||
|
ipv6Supported bool
|
||||||
|
firewall Firewall
|
||||||
|
outboundInterface string
|
||||||
|
dohServers []provider.DoHServer
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(firewall Firewall, defaultInterface string, ipv6Supported bool,
|
||||||
|
upstreamResolvers []provider.Provider,
|
||||||
|
) (*Client, error) {
|
||||||
|
dohServers := make([]provider.DoHServer, len(upstreamResolvers))
|
||||||
|
for i, upstreamResolver := range upstreamResolvers {
|
||||||
|
dohServers[i] = upstreamResolver.DoH
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Client{
|
||||||
|
firewall: firewall,
|
||||||
|
outboundInterface: defaultInterface,
|
||||||
|
ipv6Supported: ipv6Supported,
|
||||||
|
dohServers: dohServers,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) OpenHTTPSByDomain(ctx context.Context, domain string) (
|
||||||
|
httpClient *http.Client, cleanup func() error, err error,
|
||||||
|
) {
|
||||||
|
resolvedIPs, err := c.ResolveName(ctx, domain)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("resolving name: %w", err)
|
||||||
|
} else if len(resolvedIPs) == 0 {
|
||||||
|
return nil, nil, fmt.Errorf("no IP address found for name %q", domain)
|
||||||
|
}
|
||||||
|
|
||||||
|
selectedIP := resolvedIPs[0]
|
||||||
|
|
||||||
|
httpClient, cleanup, err = c.OpenHTTPS(domain, selectedIP)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("opening HTTPS: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return httpClient, cleanup, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,68 @@
|
|||||||
|
package restrictednet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
|
"github.com/qdm12/dns/v2/pkg/provider"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type listenAddrPortMatcher struct {
|
||||||
|
expected netip.AddrPort
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m listenAddrPortMatcher) Matches(x any) bool {
|
||||||
|
ip, ok := x.(netip.AddrPort)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if m.expected.IsValid() {
|
||||||
|
return ip == m.expected
|
||||||
|
}
|
||||||
|
return ip.IsValid() && ip.Addr().IsValid() && ip.Port() > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m listenAddrPortMatcher) String() string {
|
||||||
|
if m.expected.IsValid() {
|
||||||
|
return "is the same as " + m.expected.String()
|
||||||
|
}
|
||||||
|
return "is a valid netip.AddrPort with a valid IP and non-zero port"
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_Client_OpenHTTPS(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
firewall := NewMockFirewall(ctrl)
|
||||||
|
|
||||||
|
destination := netip.MustParseAddrPort("1.2.3.4:443")
|
||||||
|
backgroundContext := context.Background()
|
||||||
|
sourceMatcher := listenAddrPortMatcher{}
|
||||||
|
firewall.EXPECT().AcceptOutputFromIPPortToIPPort(
|
||||||
|
backgroundContext, "tcp", "eth0", sourceMatcher, destination, false,
|
||||||
|
).DoAndReturn(func(_ context.Context,
|
||||||
|
_, _ string, source, _ netip.AddrPort, _ bool,
|
||||||
|
) error {
|
||||||
|
sourceMatcher.expected = source
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
firewall.EXPECT().AcceptOutputFromIPPortToIPPort(
|
||||||
|
backgroundContext, "tcp", "eth0", sourceMatcher, destination, true,
|
||||||
|
)
|
||||||
|
|
||||||
|
const ipv6Supported = false
|
||||||
|
upstreamResolvers := []provider.Provider{provider.Google()}
|
||||||
|
client, err := New(firewall, "eth0", ipv6Supported, upstreamResolvers)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
httpClient, cleanup, err := client.OpenHTTPS("api.example.com", netip.MustParseAddr("1.2.3.4"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, httpClient)
|
||||||
|
require.NotNil(t, cleanup)
|
||||||
|
|
||||||
|
err = cleanup()
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
@@ -0,0 +1,115 @@
|
|||||||
|
package restrictednet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/netip"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OpenHTTPS opens temporary restrictive firewall output for one HTTPS destination.
|
||||||
|
// The returned cleanup function must be called to remove the temporary firewall rule and close connections.
|
||||||
|
func (c *Client) OpenHTTPS(destinationTLSName string, destinationIP netip.Addr,
|
||||||
|
) (httpClient *http.Client, cleanup func() error, err error) {
|
||||||
|
listener, sourceAddrPort, err := bindSourcePort(destinationIP)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("binding source port: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
const httpsPort = 443
|
||||||
|
destinationAddrPort := netip.AddrPortFrom(destinationIP, httpsPort)
|
||||||
|
|
||||||
|
const remove = false
|
||||||
|
ctx := context.Background() // it's a quick firewall change, worth not passing a context
|
||||||
|
err = c.firewall.AcceptOutputFromIPPortToIPPort(ctx, "tcp", c.outboundInterface,
|
||||||
|
sourceAddrPort, destinationAddrPort, remove)
|
||||||
|
if err != nil {
|
||||||
|
_ = listener.Close()
|
||||||
|
return nil, nil, fmt.Errorf("allowing output traffic through firewall: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
httpClient = newHTTPSClient(destinationTLSName, destinationIP, sourceAddrPort)
|
||||||
|
cleanup = func() error {
|
||||||
|
var errs []error
|
||||||
|
httpClient.CloseIdleConnections()
|
||||||
|
const remove = true
|
||||||
|
err := c.firewall.AcceptOutputFromIPPortToIPPort(ctx, "tcp", c.outboundInterface,
|
||||||
|
sourceAddrPort, destinationAddrPort, remove)
|
||||||
|
if err != nil {
|
||||||
|
errs = append(errs, fmt.Errorf("removing output traffic rule: %w", err))
|
||||||
|
}
|
||||||
|
err = listener.Close()
|
||||||
|
if err != nil {
|
||||||
|
errs = append(errs, fmt.Errorf("closing listener: %w", err))
|
||||||
|
}
|
||||||
|
if len(errs) > 0 {
|
||||||
|
return errors.Join(errs...)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return httpClient, cleanup, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHTTPSClient(destinationTLSName string,
|
||||||
|
destinationIP netip.Addr, sourceAddress netip.AddrPort,
|
||||||
|
) *http.Client {
|
||||||
|
httpTransport := http.DefaultTransport.(*http.Transport).Clone() //nolint:forcetypeassert
|
||||||
|
httpTransport.Proxy = nil
|
||||||
|
httpTransport.MaxIdleConns = 1
|
||||||
|
httpTransport.MaxIdleConnsPerHost = 1
|
||||||
|
httpTransport.IdleConnTimeout = time.Second
|
||||||
|
httpTransport.TLSClientConfig = &tls.Config{
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
ServerName: destinationTLSName,
|
||||||
|
}
|
||||||
|
httpTransport.DialContext = newBoundDialContext(destinationIP, sourceAddress)
|
||||||
|
|
||||||
|
const timeout = 5 * time.Second
|
||||||
|
return &http.Client{
|
||||||
|
Timeout: timeout,
|
||||||
|
Transport: httpTransport,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newBoundDialContext(destinationAddress netip.Addr,
|
||||||
|
sourceAddress netip.AddrPort,
|
||||||
|
) func(ctx context.Context, network, _ string) (net.Conn, error) {
|
||||||
|
const httpsPort = 443
|
||||||
|
destinationAddrPort := netip.AddrPortFrom(destinationAddress, httpsPort).String()
|
||||||
|
return func(ctx context.Context, network, _ string) (net.Conn, error) {
|
||||||
|
const timeout = 2 * time.Second
|
||||||
|
dialer := &net.Dialer{Timeout: timeout}
|
||||||
|
dialer.LocalAddr = net.TCPAddrFromAddrPort(sourceAddress)
|
||||||
|
connection, err := dialer.DialContext(ctx, network, destinationAddrPort)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("%s dialing %s: %w", network, destinationAddrPort, err)
|
||||||
|
}
|
||||||
|
return connection, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func bindSourcePort(destinationIP netip.Addr) (
|
||||||
|
listener net.Listener, sourceAddr netip.AddrPort, err error,
|
||||||
|
) {
|
||||||
|
var bindAddr netip.Addr
|
||||||
|
if destinationIP.Is4() {
|
||||||
|
bindAddr = netip.AddrFrom4([4]byte{})
|
||||||
|
} else {
|
||||||
|
bindAddr = netip.AddrFrom16([16]byte{})
|
||||||
|
}
|
||||||
|
|
||||||
|
listener, err = net.ListenTCP("tcp", net.TCPAddrFromAddrPort(
|
||||||
|
netip.AddrPortFrom(bindAddr, 0)))
|
||||||
|
if err != nil {
|
||||||
|
return nil, netip.AddrPort{}, fmt.Errorf("binding TCP port: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tcpAddr := listener.Addr().(*net.TCPAddr) //nolint:forcetypeassert
|
||||||
|
sourceAddr = tcpAddr.AddrPort()
|
||||||
|
|
||||||
|
return listener, sourceAddr, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
package restrictednet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/netip"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Firewall interface {
|
||||||
|
AcceptOutputFromIPPortToIPPort(ctx context.Context,
|
||||||
|
protocol, intf string, source, destination netip.AddrPort, remove bool,
|
||||||
|
) error
|
||||||
|
}
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
package restrictednet
|
||||||
|
|
||||||
|
//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . Firewall
|
||||||
@@ -0,0 +1,50 @@
|
|||||||
|
// Code generated by MockGen. DO NOT EDIT.
|
||||||
|
// Source: github.com/qdm12/gluetun/internal/restrictednet (interfaces: Firewall)
|
||||||
|
|
||||||
|
// Package restrictednet is a generated GoMock package.
|
||||||
|
package restrictednet
|
||||||
|
|
||||||
|
import (
|
||||||
|
context "context"
|
||||||
|
netip "net/netip"
|
||||||
|
reflect "reflect"
|
||||||
|
|
||||||
|
gomock "github.com/golang/mock/gomock"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockFirewall is a mock of Firewall interface.
|
||||||
|
type MockFirewall struct {
|
||||||
|
ctrl *gomock.Controller
|
||||||
|
recorder *MockFirewallMockRecorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockFirewallMockRecorder is the mock recorder for MockFirewall.
|
||||||
|
type MockFirewallMockRecorder struct {
|
||||||
|
mock *MockFirewall
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockFirewall creates a new mock instance.
|
||||||
|
func NewMockFirewall(ctrl *gomock.Controller) *MockFirewall {
|
||||||
|
mock := &MockFirewall{ctrl: ctrl}
|
||||||
|
mock.recorder = &MockFirewallMockRecorder{mock}
|
||||||
|
return mock
|
||||||
|
}
|
||||||
|
|
||||||
|
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||||
|
func (m *MockFirewall) EXPECT() *MockFirewallMockRecorder {
|
||||||
|
return m.recorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// AcceptOutputFromIPPortToIPPort mocks base method.
|
||||||
|
func (m *MockFirewall) AcceptOutputFromIPPortToIPPort(arg0 context.Context, arg1, arg2 string, arg3, arg4 netip.AddrPort, arg5 bool) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "AcceptOutputFromIPPortToIPPort", arg0, arg1, arg2, arg3, arg4, arg5)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// AcceptOutputFromIPPortToIPPort indicates an expected call of AcceptOutputFromIPPortToIPPort.
|
||||||
|
func (mr *MockFirewallMockRecorder) AcceptOutputFromIPPortToIPPort(arg0, arg1, arg2, arg3, arg4, arg5 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptOutputFromIPPortToIPPort", reflect.TypeOf((*MockFirewall)(nil).AcceptOutputFromIPPortToIPPort), arg0, arg1, arg2, arg3, arg4, arg5)
|
||||||
|
}
|
||||||
@@ -0,0 +1,177 @@
|
|||||||
|
package restrictednet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/netip"
|
||||||
|
"net/url"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ResolveName resolves the given host name to IP addresses using DoH servers,
|
||||||
|
// while opening temporary restrictive firewall rules for HTTPS traffic to DoH servers.
|
||||||
|
// The host must be a single well-formed domain name, without port or path.
|
||||||
|
func (c *Client) ResolveName(ctx context.Context, host string) (
|
||||||
|
resolvedAddresses []netip.Addr, err error,
|
||||||
|
) {
|
||||||
|
questionTypes := make([]uint16, 0, 2)
|
||||||
|
if c.ipv6Supported {
|
||||||
|
questionTypes = append(questionTypes, dns.TypeAAAA)
|
||||||
|
}
|
||||||
|
questionTypes = append(questionTypes, dns.TypeA)
|
||||||
|
|
||||||
|
var addresses []netip.Addr
|
||||||
|
errs := make([]error, 0, len(questionTypes))
|
||||||
|
for _, questionType := range questionTypes {
|
||||||
|
answerAddresses, err := c.resolveOneQuestionType(ctx, host, questionType)
|
||||||
|
if err != nil {
|
||||||
|
errs = append(errs, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
addresses = append(addresses, answerAddresses...)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case len(addresses) > 0:
|
||||||
|
return addresses, nil
|
||||||
|
case len(errs) == 0:
|
||||||
|
return nil, nil // no address found
|
||||||
|
default: // errors
|
||||||
|
return nil, fmt.Errorf("resolving host %q: %w", host, errors.Join(errs...))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) resolveOneQuestionType(ctx context.Context,
|
||||||
|
host string, questionType uint16,
|
||||||
|
) (addresses []netip.Addr, err error) {
|
||||||
|
queryMessage := &dns.Msg{}
|
||||||
|
queryMessage.SetQuestion(dns.Fqdn(host), questionType)
|
||||||
|
queryWire, err := queryMessage.Pack()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("packing DNS query: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try every DoH server and every of each of their IP until we get a non-empty
|
||||||
|
// successful response.
|
||||||
|
errs := make([]error, 0)
|
||||||
|
for _, dohServer := range c.dohServers {
|
||||||
|
dohURL, err := url.Parse(dohServer.URL)
|
||||||
|
if err != nil {
|
||||||
|
errs = append(errs,
|
||||||
|
fmt.Errorf("parsing DoH server URL %s: %w", dohServer.URL, err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
dohServerIPs := make([]netip.Addr, 0, len(dohServer.IPv4)+len(dohServer.IPv6))
|
||||||
|
if c.ipv6Supported {
|
||||||
|
// Prefer IPv6 addresses if IPv6 is supported
|
||||||
|
dohServerIPs = append(dohServerIPs, dohServer.IPv6...)
|
||||||
|
}
|
||||||
|
dohServerIPs = append(dohServerIPs, dohServer.IPv4...)
|
||||||
|
|
||||||
|
for _, dohServerIP := range dohServerIPs {
|
||||||
|
responseMessage, err := c.doHQuery(ctx, queryWire, dohURL, dohServerIP)
|
||||||
|
switch {
|
||||||
|
case err != nil:
|
||||||
|
errs = append(errs, fmt.Errorf("querying DoH server %q at %s: %w",
|
||||||
|
dohServer.URL, dohServerIP, err))
|
||||||
|
continue
|
||||||
|
case responseMessage.Rcode != dns.RcodeSuccess:
|
||||||
|
errs = append(errs, fmt.Errorf("querying DoH server %q at %s: DNS rcode %s",
|
||||||
|
dohServer.URL, dohServerIP, dns.RcodeToString[responseMessage.Rcode]))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
addresses := answersToNetipAddrs(responseMessage)
|
||||||
|
if len(addresses) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return addresses, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(errs) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("resolving %s %s: %w",
|
||||||
|
dns.TypeToString[questionType], host, errors.Join(errs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) doHQuery(ctx context.Context, queryWire []byte,
|
||||||
|
dohURL *url.URL, dohServerIP netip.Addr,
|
||||||
|
) (responseMessage *dns.Msg, err error) {
|
||||||
|
httpClient, close, err := c.OpenHTTPS(dohURL.Hostname(), dohServerIP)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("opening https connection: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
closeErr := close()
|
||||||
|
if err == nil && closeErr != nil {
|
||||||
|
err = fmt.Errorf("cleaning up https connection: %w", closeErr)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
requestBody := bytes.NewReader(queryWire)
|
||||||
|
request, err := http.NewRequestWithContext(ctx, http.MethodPost, dohURL.String(), requestBody)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("creating request: %w", err)
|
||||||
|
}
|
||||||
|
request.Header.Set("Content-Type", "application/dns-message")
|
||||||
|
request.Header.Set("Accept", "application/dns-message")
|
||||||
|
|
||||||
|
response, err := httpClient.Do(request)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
responseData, err := io.ReadAll(response.Body)
|
||||||
|
if err != nil {
|
||||||
|
_ = response.Body.Close()
|
||||||
|
return nil, fmt.Errorf("reading response body: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = response.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("closing response body: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if response.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("response status code is %s, data: %s",
|
||||||
|
response.Status, responseData)
|
||||||
|
}
|
||||||
|
|
||||||
|
responseMessage = new(dns.Msg)
|
||||||
|
err = responseMessage.Unpack(responseData)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parsing DoH response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return responseMessage, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func answersToNetipAddrs(message *dns.Msg) (addresses []netip.Addr) {
|
||||||
|
if message == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
addresses = make([]netip.Addr, 0, len(message.Answer))
|
||||||
|
for _, answer := range message.Answer {
|
||||||
|
switch record := answer.(type) {
|
||||||
|
case *dns.A:
|
||||||
|
address, ok := netip.AddrFromSlice(record.A)
|
||||||
|
if ok {
|
||||||
|
addresses = append(addresses, address.Unmap())
|
||||||
|
}
|
||||||
|
case *dns.AAAA:
|
||||||
|
address, ok := netip.AddrFromSlice(record.AAAA)
|
||||||
|
if ok {
|
||||||
|
addresses = append(addresses, address)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return addresses
|
||||||
|
}
|
||||||
@@ -0,0 +1,82 @@
|
|||||||
|
package restrictednet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_answersToNetipAddrs(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
testCases := map[string]struct {
|
||||||
|
message *dns.Msg
|
||||||
|
expected []netip.Addr
|
||||||
|
errorIsNil bool
|
||||||
|
}{
|
||||||
|
"nil_message": {
|
||||||
|
message: nil,
|
||||||
|
expected: nil,
|
||||||
|
errorIsNil: true,
|
||||||
|
},
|
||||||
|
"no_answers": {
|
||||||
|
message: &dns.Msg{},
|
||||||
|
expected: []netip.Addr{},
|
||||||
|
errorIsNil: true,
|
||||||
|
},
|
||||||
|
"a_record": {
|
||||||
|
message: &dns.Msg{
|
||||||
|
Answer: []dns.RR{
|
||||||
|
&dns.A{
|
||||||
|
Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET},
|
||||||
|
A: net.IP{1, 1, 1, 1},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: []netip.Addr{netip.MustParseAddr("1.1.1.1")},
|
||||||
|
errorIsNil: true,
|
||||||
|
},
|
||||||
|
"aaaa_record": {
|
||||||
|
message: &dns.Msg{
|
||||||
|
Answer: []dns.RR{
|
||||||
|
&dns.AAAA{
|
||||||
|
Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeAAAA, Class: dns.ClassINET},
|
||||||
|
AAAA: net.IP{0x20, 0x01, 0x48, 0x60, 0x48, 0x60, 0, 0, 0, 0, 0, 0, 0, 0, 0x88, 0x88},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: []netip.Addr{netip.MustParseAddr("2001:4860:4860::8888")},
|
||||||
|
errorIsNil: true,
|
||||||
|
},
|
||||||
|
"mixed_records": {
|
||||||
|
message: &dns.Msg{
|
||||||
|
Answer: []dns.RR{
|
||||||
|
&dns.A{
|
||||||
|
Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET},
|
||||||
|
A: net.IP{1, 1, 1, 1},
|
||||||
|
},
|
||||||
|
&dns.AAAA{
|
||||||
|
Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeAAAA, Class: dns.ClassINET},
|
||||||
|
AAAA: net.IP{0x20, 0x01, 0x48, 0x60, 0x48, 0x60, 0, 0, 0, 0, 0, 0, 0, 0, 0x88, 0x88},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("2001:4860:4860::8888")},
|
||||||
|
errorIsNil: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for testName, testCase := range testCases {
|
||||||
|
testCase := testCase
|
||||||
|
t.Run(testName, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
addresses := answersToNetipAddrs(testCase.message)
|
||||||
|
|
||||||
|
assert.Equal(t, testCase.expected, addresses)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user