mirror of
https://github.com/qdm12/gluetun.git
synced 2026-06-25 13:27:31 +02:00
Change tests to be more integration oriented
This commit is contained in:
@@ -4,7 +4,10 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/netip"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
"github.com/qdm12/dns/v2/pkg/provider"
|
"github.com/qdm12/dns/v2/pkg/provider"
|
||||||
)
|
)
|
||||||
@@ -18,12 +21,9 @@ type Client struct {
|
|||||||
ipv6Supported bool
|
ipv6Supported bool
|
||||||
firewall Firewall
|
firewall Firewall
|
||||||
dohServers []provider.DoHServer
|
dohServers []provider.DoHServer
|
||||||
baseTransport *http.Transport
|
|
||||||
httpsPort uint16
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(settings Settings) *Client {
|
func New(settings Settings) *Client {
|
||||||
settings.setDefaults()
|
|
||||||
if err := settings.validate(); err != nil {
|
if err := settings.validate(); err != nil {
|
||||||
panic(fmt.Sprintf("invalid settings: %v", err)) // programming error
|
panic(fmt.Sprintf("invalid settings: %v", err)) // programming error
|
||||||
}
|
}
|
||||||
@@ -32,30 +32,38 @@ func New(settings Settings) *Client {
|
|||||||
dohServers[i] = upstreamResolver.DoH
|
dohServers[i] = upstreamResolver.DoH
|
||||||
}
|
}
|
||||||
|
|
||||||
const defaultHTTPSPort = 443
|
|
||||||
return &Client{
|
return &Client{
|
||||||
outboundInterface: settings.DefaultInterface,
|
outboundInterface: settings.DefaultInterface,
|
||||||
ipv6Supported: *settings.IPv6Supported,
|
ipv6Supported: *settings.IPv6Supported,
|
||||||
firewall: settings.Firewall,
|
firewall: settings.Firewall,
|
||||||
dohServers: dohServers,
|
dohServers: dohServers,
|
||||||
baseTransport: settings.BaseTransport,
|
|
||||||
httpsPort: defaultHTTPSPort,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) OpenHTTPSByDomain(ctx context.Context, domain string) (
|
func (c *Client) OpenHTTPSByDomain(ctx context.Context, hostname string) (
|
||||||
httpClient *http.Client, cleanup func() error, err error,
|
httpClient *http.Client, cleanup func() error, err error,
|
||||||
) {
|
) {
|
||||||
resolvedIPs, err := c.ResolveName(ctx, domain)
|
host, portStr, err := net.SplitHostPort(hostname)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("splitting host and port: %w", err)
|
||||||
|
}
|
||||||
|
resolvedIPs, err := c.ResolveName(ctx, host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("resolving name: %w", err)
|
return nil, nil, fmt.Errorf("resolving name: %w", err)
|
||||||
} else if len(resolvedIPs) == 0 {
|
} else if len(resolvedIPs) == 0 {
|
||||||
return nil, nil, fmt.Errorf("no IP address found for name %q", domain)
|
return nil, nil, fmt.Errorf("no IP address found for name %q", host)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
portUint, err := strconv.ParseUint(portStr, 10, 16)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("parsing port: %w", err)
|
||||||
|
}
|
||||||
|
port := uint16(portUint)
|
||||||
|
|
||||||
errs := make([]error, 0, len(resolvedIPs))
|
errs := make([]error, 0, len(resolvedIPs))
|
||||||
for _, ip := range resolvedIPs {
|
for _, ip := range resolvedIPs {
|
||||||
httpClient, cleanup, err := c.OpenHTTPS(ctx, domain, ip)
|
addrPort := netip.AddrPortFrom(ip, port)
|
||||||
|
httpClient, cleanup, err := c.OpenHTTPS(ctx, host, addrPort)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errs = append(errs, fmt.Errorf("for %s: %w", ip, err))
|
errs = append(errs, fmt.Errorf("for %s: %w", ip, err))
|
||||||
continue
|
continue
|
||||||
@@ -63,5 +71,5 @@ func (c *Client) OpenHTTPSByDomain(ctx context.Context, domain string) (
|
|||||||
return httpClient, cleanup, nil
|
return httpClient, cleanup, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, nil, fmt.Errorf("opening HTTPS to %s: %w", domain, errors.Join(errs...))
|
return nil, nil, fmt.Errorf("opening HTTPS to %s: %w", hostname, errors.Join(errs...))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,185 +1,5 @@
|
|||||||
package restrictednet
|
package restrictednet
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"net/netip"
|
|
||||||
"net/url"
|
|
||||||
"os"
|
|
||||||
"strconv"
|
|
||||||
"sync"
|
|
||||||
"syscall"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func ptrTo[T any](value T) *T {
|
func ptrTo[T any](value T) *T {
|
||||||
return &value
|
return &value
|
||||||
}
|
}
|
||||||
|
|
||||||
func newInterceptTransport(handler func(host string, requestBody io.Reader) (*http.Response, error)) *http.Transport {
|
|
||||||
return &http.Transport{
|
|
||||||
DialTLSContext: func(_ context.Context, _, _ string) (net.Conn, error) {
|
|
||||||
clientConn, serverConn := net.Pipe()
|
|
||||||
go func() {
|
|
||||||
defer serverConn.Close()
|
|
||||||
|
|
||||||
reader := bufio.NewReader(serverConn)
|
|
||||||
request, err := http.ReadRequest(reader)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
response, err := handler(request.Host, request.Body)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read the response body and re-create it to avoid linting
|
|
||||||
// complaining that the response body must be closed.
|
|
||||||
responseData, err := io.ReadAll(response.Body)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
_ = response.Body.Close()
|
|
||||||
response.Body = io.NopCloser(bytes.NewReader(responseData))
|
|
||||||
|
|
||||||
_ = response.Write(serverConn)
|
|
||||||
}()
|
|
||||||
return clientConn, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func expectFirewallCallPair(
|
|
||||||
firewall *MockFirewall,
|
|
||||||
addContext context.Context, //nolint:revive
|
|
||||||
destinationIP netip.Addr,
|
|
||||||
destinationPort uint16,
|
|
||||||
addErr error,
|
|
||||||
removeErr error,
|
|
||||||
) {
|
|
||||||
destination := netip.AddrPortFrom(destinationIP, destinationPort)
|
|
||||||
sourceMatcher := listenAddrPortMatcher{}
|
|
||||||
|
|
||||||
firewall.EXPECT().AcceptOutputFromIPPortToIPPort(
|
|
||||||
addContext, "tcp", "eth0", sourceMatcher, destination, false,
|
|
||||||
).DoAndReturn(func(
|
|
||||||
_ context.Context, _, _ string, source, _ netip.AddrPort, _ bool,
|
|
||||||
) error {
|
|
||||||
sourceMatcher.expected = source
|
|
||||||
return addErr
|
|
||||||
})
|
|
||||||
|
|
||||||
firewall.EXPECT().AcceptOutputFromIPPortToIPPort(
|
|
||||||
context.Background(), "tcp", "eth0", sourceMatcher, destination, true,
|
|
||||||
).Return(removeErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func urlToHostnamePort(rawURL string, port uint16) string {
|
|
||||||
parsedURL, err := url.Parse(rawURL)
|
|
||||||
if err != nil {
|
|
||||||
panic(err) // programming error in test
|
|
||||||
}
|
|
||||||
parsedURL.Host = net.JoinHostPort(parsedURL.Hostname(), strconv.FormatUint(uint64(port), 10))
|
|
||||||
return parsedURL.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func responseWireForQuery(t *testing.T, queryReader io.Reader, answers ...dns.RR) []byte {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
queryData, err := io.ReadAll(queryReader)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
query := new(dns.Msg)
|
|
||||||
err = query.Unpack(queryData)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
response := new(dns.Msg)
|
|
||||||
response.SetReply(query)
|
|
||||||
response.Answer = append(response.Answer, answers...)
|
|
||||||
|
|
||||||
wire, err := response.Pack()
|
|
||||||
require.NoError(t, err)
|
|
||||||
return wire
|
|
||||||
}
|
|
||||||
|
|
||||||
func startTCPAccepter(t *testing.T) (port uint16) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
// Find a port available for both TCP IPv4 and TCP IPv6
|
|
||||||
listeners := make([]net.Listener, 2) // IPv4 + IPv6
|
|
||||||
netConfig := net.ListenConfig{}
|
|
||||||
var listenersToClose []net.Listener
|
|
||||||
for t.Context().Err() == nil {
|
|
||||||
// Find an available port for IPv4
|
|
||||||
listeningAddress := netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), 0)
|
|
||||||
listener, err := netConfig.Listen(t.Context(), "tcp", listeningAddress.String())
|
|
||||||
require.NoError(t, err)
|
|
||||||
listeners[0] = listener
|
|
||||||
port = uint16(listener.Addr().(*net.TCPAddr).Port) //nolint:gosec,forcetypeassert
|
|
||||||
|
|
||||||
// Check if that port is also available for IPv6
|
|
||||||
listeningAddress = netip.AddrPortFrom(
|
|
||||||
netip.AddrFrom16([16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}),
|
|
||||||
port,
|
|
||||||
)
|
|
||||||
listener, err = netConfig.Listen(t.Context(), "tcp", listeningAddress.String())
|
|
||||||
if err == nil {
|
|
||||||
listeners[1] = listener
|
|
||||||
break // success, we found a port available for both IPv4 and IPv6
|
|
||||||
}
|
|
||||||
var opErr *net.OpError
|
|
||||||
if errors.As(err, &opErr) {
|
|
||||||
var sysErr *os.SyscallError
|
|
||||||
if errors.As(opErr.Err, &sysErr) && errors.Is(sysErr.Err, syscall.EADDRINUSE) {
|
|
||||||
// Port found for IPv4 is already in use for IPv6, try another port
|
|
||||||
// We don't close the IPv4 listener yet to make sure we don't get the same port again from the OS.
|
|
||||||
listenersToClose = append(listenersToClose, listeners[0])
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, listener := range listenersToClose {
|
|
||||||
err := listener.Close()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var ready sync.WaitGroup
|
|
||||||
ready.Add(len(listeners))
|
|
||||||
for _, listener := range listeners {
|
|
||||||
t.Cleanup(func() {
|
|
||||||
err := listener.Close()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
})
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
ready.Done()
|
|
||||||
for {
|
|
||||||
connection, err := listener.Accept()
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, net.ErrClosed) && t.Context().Err() != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
assert.NoError(t, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err = connection.Close()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
ready.Wait()
|
|
||||||
|
|
||||||
return port
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -17,15 +17,13 @@ import (
|
|||||||
|
|
||||||
// OpenHTTPS opens temporary restrictive firewall output for one HTTPS destination.
|
// OpenHTTPS opens temporary restrictive firewall output for one HTTPS destination.
|
||||||
// The returned cleanup function must be called to remove the temporary firewall rule and close connections.
|
// The returned cleanup function must be called to remove the temporary firewall rule and close connections.
|
||||||
func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, destinationIP netip.Addr,
|
func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, destinationAddrPort netip.AddrPort,
|
||||||
) (httpClient *http.Client, cleanup func() error, err error) {
|
) (httpClient *http.Client, cleanup func() error, err error) {
|
||||||
fd, sourceAddrPort, err := bindSourceConnection(destinationIP)
|
fd, sourceAddrPort, err := bindSourceConnection(destinationAddrPort.Addr())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("binding source port: %w", err)
|
return nil, nil, fmt.Errorf("binding source port: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
destinationAddrPort := netip.AddrPortFrom(destinationIP, c.httpsPort)
|
|
||||||
|
|
||||||
const remove = false
|
const remove = false
|
||||||
err = c.firewall.AcceptOutputFromIPPortToIPPort(ctx, "tcp", c.outboundInterface,
|
err = c.firewall.AcceptOutputFromIPPortToIPPort(ctx, "tcp", c.outboundInterface,
|
||||||
sourceAddrPort, destinationAddrPort, remove)
|
sourceAddrPort, destinationAddrPort, remove)
|
||||||
@@ -42,7 +40,8 @@ func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, desti
|
|||||||
return nil, nil, fmt.Errorf("connecting source socket: %w", err)
|
return nil, nil, fmt.Errorf("connecting source socket: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
httpClient = newHTTPSClient(c.baseTransport, destinationTLSName, connection)
|
dial := makeDial(connection, destinationTLSName)
|
||||||
|
httpClient = newHTTPSClient(destinationTLSName, dial)
|
||||||
cleanup = func() error {
|
cleanup = func() error {
|
||||||
var errs []error
|
var errs []error
|
||||||
httpClient.CloseIdleConnections()
|
httpClient.CloseIdleConnections()
|
||||||
@@ -53,7 +52,7 @@ func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, desti
|
|||||||
errs = append(errs, fmt.Errorf("removing output traffic rule: %w", err))
|
errs = append(errs, fmt.Errorf("removing output traffic rule: %w", err))
|
||||||
}
|
}
|
||||||
err = connection.Close()
|
err = connection.Close()
|
||||||
if err != nil {
|
if err != nil && !errors.Is(err, net.ErrClosed) {
|
||||||
errs = append(errs, fmt.Errorf("closing connection: %w", err))
|
errs = append(errs, fmt.Errorf("closing connection: %w", err))
|
||||||
}
|
}
|
||||||
if len(errs) > 0 {
|
if len(errs) > 0 {
|
||||||
@@ -64,21 +63,31 @@ func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, desti
|
|||||||
return httpClient, cleanup, nil
|
return httpClient, cleanup, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHTTPSClient(baseTransport *http.Transport, destinationTLSName string, connection net.Conn) *http.Client {
|
type dialFunc func(ctx context.Context, network, address string) (net.Conn, error)
|
||||||
transport := baseTransport.Clone()
|
|
||||||
transport.Proxy = nil
|
|
||||||
transport.MaxIdleConns = 1
|
|
||||||
transport.MaxIdleConnsPerHost = 1
|
|
||||||
transport.MaxConnsPerHost = 1
|
|
||||||
transport.IdleConnTimeout = time.Second
|
|
||||||
transport.TLSClientConfig = &tls.Config{
|
|
||||||
MinVersion: tls.VersionTLS12,
|
|
||||||
ServerName: destinationTLSName,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
func newHTTPSClient(destinationTLSName string, dial dialFunc) *http.Client {
|
||||||
|
const timeout = 5 * time.Second
|
||||||
|
transport := &http.Transport{
|
||||||
|
MaxIdleConns: 1,
|
||||||
|
MaxIdleConnsPerHost: 1,
|
||||||
|
MaxConnsPerHost: 1,
|
||||||
|
IdleConnTimeout: time.Second,
|
||||||
|
TLSClientConfig: &tls.Config{
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
ServerName: destinationTLSName,
|
||||||
|
},
|
||||||
|
DialContext: dial,
|
||||||
|
}
|
||||||
|
return &http.Client{
|
||||||
|
Timeout: timeout,
|
||||||
|
Transport: transport,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeDial(connection net.Conn, tlsName string) dialFunc {
|
||||||
_, destinationPort, _ := net.SplitHostPort(connection.RemoteAddr().String())
|
_, destinationPort, _ := net.SplitHostPort(connection.RemoteAddr().String())
|
||||||
expectedAddress := net.JoinHostPort(destinationTLSName, destinationPort)
|
expectedAddress := net.JoinHostPort(tlsName, destinationPort)
|
||||||
transport.DialContext = func(_ context.Context, network, address string) (net.Conn, error) {
|
return func(_ context.Context, network, address string) (net.Conn, error) {
|
||||||
switch network {
|
switch network {
|
||||||
case "tcp", "tcp4", "tcp6":
|
case "tcp", "tcp4", "tcp6":
|
||||||
default:
|
default:
|
||||||
@@ -89,12 +98,6 @@ func newHTTPSClient(baseTransport *http.Transport, destinationTLSName string, co
|
|||||||
}
|
}
|
||||||
return connection, nil
|
return connection, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
const timeout = 5 * time.Second
|
|
||||||
return &http.Client{
|
|
||||||
Timeout: timeout,
|
|
||||||
Transport: transport,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func bindSourceConnection(destinationIP netip.Addr) (fd int, sourceAddr netip.AddrPort, err error) {
|
func bindSourceConnection(destinationIP netip.Addr) (fd int, sourceAddr netip.AddrPort, err error) {
|
||||||
|
|||||||
@@ -2,12 +2,14 @@ package restrictednet
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
"github.com/qdm12/dns/v2/pkg/provider"
|
"github.com/qdm12/dns/v2/pkg/provider"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -33,31 +35,40 @@ func (m listenAddrPortMatcher) String() string {
|
|||||||
return "is a valid netip.AddrPort with a valid IP and non-zero port"
|
return "is a valid netip.AddrPort with a valid IP and non-zero port"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type destinationAddrPortMatcher struct {
|
||||||
|
expected netip.AddrPort
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m destinationAddrPortMatcher) Matches(x any) bool {
|
||||||
|
ip, ok := x.(netip.AddrPort)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if m.expected.IsValid() {
|
||||||
|
return ip == m.expected
|
||||||
|
}
|
||||||
|
return ip.IsValid() && ip.Port() == m.expected.Port()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m destinationAddrPortMatcher) String() string {
|
||||||
|
if m.expected.IsValid() {
|
||||||
|
return "is the same as " + m.expected.String()
|
||||||
|
}
|
||||||
|
return "matches the port " + fmt.Sprint(m.expected.Port())
|
||||||
|
}
|
||||||
|
|
||||||
func Test_Client_OpenHTTPS(t *testing.T) {
|
func Test_Client_OpenHTTPS(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
ctx := t.Context()
|
ctx := t.Context()
|
||||||
|
|
||||||
netConfig := net.ListenConfig{}
|
|
||||||
listener, err := netConfig.Listen(ctx, "tcp", "127.0.0.1:0")
|
|
||||||
require.NoError(t, err)
|
|
||||||
t.Cleanup(func() {
|
|
||||||
_ = listener.Close()
|
|
||||||
})
|
|
||||||
listeningPort := uint16(listener.Addr().(*net.TCPAddr).Port) //nolint:gosec,forcetypeassert
|
|
||||||
go func() {
|
|
||||||
connection, acceptErr := listener.Accept()
|
|
||||||
if acceptErr == nil {
|
|
||||||
_ = connection.Close()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
ctrl := gomock.NewController(t)
|
ctrl := gomock.NewController(t)
|
||||||
firewall := NewMockFirewall(ctrl)
|
|
||||||
|
|
||||||
destination := netip.AddrPortFrom(netip.MustParseAddr("127.0.0.1"), listeningPort)
|
const destinationTLSName = "one.one.one.one"
|
||||||
|
destinationAddrPort := netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443)
|
||||||
|
|
||||||
|
firewall := NewMockFirewall(ctrl)
|
||||||
sourceMatcher := listenAddrPortMatcher{}
|
sourceMatcher := listenAddrPortMatcher{}
|
||||||
firewall.EXPECT().AcceptOutputFromIPPortToIPPort(
|
firewall.EXPECT().AcceptOutputFromIPPortToIPPort(
|
||||||
ctx, "tcp", "eth0", sourceMatcher, destination, false,
|
ctx, "tcp", "eth0", sourceMatcher, destinationAddrPort, false,
|
||||||
).DoAndReturn(func(_ context.Context,
|
).DoAndReturn(func(_ context.Context,
|
||||||
_, _ string, source, _ netip.AddrPort, _ bool,
|
_, _ string, source, _ netip.AddrPort, _ bool,
|
||||||
) error {
|
) error {
|
||||||
@@ -65,7 +76,7 @@ func Test_Client_OpenHTTPS(t *testing.T) {
|
|||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
firewall.EXPECT().AcceptOutputFromIPPortToIPPort(
|
firewall.EXPECT().AcceptOutputFromIPPortToIPPort(
|
||||||
context.Background(), "tcp", "eth0", sourceMatcher, destination, true,
|
context.Background(), "tcp", "eth0", sourceMatcher, destinationAddrPort, true,
|
||||||
)
|
)
|
||||||
|
|
||||||
const ipv6Supported = false
|
const ipv6Supported = false
|
||||||
@@ -77,13 +88,23 @@ func Test_Client_OpenHTTPS(t *testing.T) {
|
|||||||
UpstreamResolvers: upstreamResolvers,
|
UpstreamResolvers: upstreamResolvers,
|
||||||
}
|
}
|
||||||
client := New(settings)
|
client := New(settings)
|
||||||
client.httpsPort = listeningPort
|
|
||||||
|
|
||||||
httpClient, cleanup, err := client.OpenHTTPS(ctx, "api.example.com", netip.MustParseAddr("127.0.0.1"))
|
httpClient, cleanup, err := client.OpenHTTPS(ctx, destinationTLSName, destinationAddrPort)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, httpClient)
|
require.NotNil(t, httpClient)
|
||||||
require.NotNil(t, cleanup)
|
require.NotNil(t, cleanup)
|
||||||
|
|
||||||
|
request, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://"+destinationTLSName, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
response, err := httpClient.Do(request)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
response.Body.Close()
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, response.StatusCode)
|
||||||
|
|
||||||
err = cleanup()
|
err = cleanup()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -76,15 +76,17 @@ func (c *Client) resolveOneQuestionType(ctx context.Context,
|
|||||||
dohServerIPs = append(dohServerIPs, dohServer.IPv4...)
|
dohServerIPs = append(dohServerIPs, dohServer.IPv4...)
|
||||||
|
|
||||||
for _, dohServerIP := range dohServerIPs {
|
for _, dohServerIP := range dohServerIPs {
|
||||||
responseMessage, err := c.doHQuery(ctx, queryWire, dohURL, dohServerIP)
|
const defaultDoHPort = 443
|
||||||
|
dohServerAddrPort := netip.AddrPortFrom(dohServerIP, defaultDoHPort)
|
||||||
|
responseMessage, err := c.doHQuery(ctx, queryWire, dohURL, dohServerAddrPort)
|
||||||
switch {
|
switch {
|
||||||
case err != nil:
|
case err != nil:
|
||||||
errs = append(errs, fmt.Errorf("querying DoH server %q (ip %s): %w",
|
errs = append(errs, fmt.Errorf("querying DoH server %q (%s): %w",
|
||||||
dohServer.URL, dohServerIP, err))
|
dohServer.URL, dohServerAddrPort, err))
|
||||||
continue
|
continue
|
||||||
case responseMessage.Rcode != dns.RcodeSuccess:
|
case responseMessage.Rcode != dns.RcodeSuccess:
|
||||||
errs = append(errs, fmt.Errorf("querying DoH server %q (ip %s): DNS rcode %s",
|
errs = append(errs, fmt.Errorf("querying DoH server %q (%s): DNS rcode %s",
|
||||||
dohServer.URL, dohServerIP, dns.RcodeToString[responseMessage.Rcode]))
|
dohServer.URL, dohServerAddrPort, dns.RcodeToString[responseMessage.Rcode]))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
addresses := answersToNetipAddrs(responseMessage)
|
addresses := answersToNetipAddrs(responseMessage)
|
||||||
@@ -104,9 +106,9 @@ func (c *Client) resolveOneQuestionType(ctx context.Context,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) doHQuery(ctx context.Context, queryWire []byte,
|
func (c *Client) doHQuery(ctx context.Context, queryWire []byte,
|
||||||
dohURL *url.URL, dohServerIP netip.Addr,
|
dohURL *url.URL, dohServerAddrPort netip.AddrPort,
|
||||||
) (responseMessage *dns.Msg, err error) {
|
) (responseMessage *dns.Msg, err error) {
|
||||||
httpClient, cleanup, err := c.OpenHTTPS(ctx, dohURL.Hostname(), dohServerIP)
|
httpClient, cleanup, err := c.OpenHTTPS(ctx, dohURL.Hostname(), dohServerAddrPort)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("opening https connection: %w", err)
|
return nil, fmt.Errorf("opening https connection: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,15 +1,9 @@
|
|||||||
package restrictednet
|
package restrictednet
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"net/url"
|
|
||||||
"sync/atomic"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
@@ -21,320 +15,42 @@ import (
|
|||||||
|
|
||||||
func Test_Client_ResolveName(t *testing.T) {
|
func Test_Client_ResolveName(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
ctx := t.Context()
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
|
||||||
testCases := map[string]struct {
|
firewall := NewMockFirewall(ctrl)
|
||||||
ipv6Supported bool
|
sourceMatcher := listenAddrPortMatcher{}
|
||||||
upstreamResolvers []provider.Provider
|
destinationMatcher := destinationAddrPortMatcher{
|
||||||
expectedAddresses []netip.Addr
|
expected: netip.AddrPortFrom(netip.Addr{}, 443),
|
||||||
errorContains string
|
|
||||||
expectedDestIPs []netip.Addr
|
|
||||||
responder func(host string, requestBody io.Reader) (*http.Response, error)
|
|
||||||
}{
|
|
||||||
"success_single_server_ipv4": {
|
|
||||||
upstreamResolvers: []provider.Provider{{
|
|
||||||
DoH: provider.DoHServer{
|
|
||||||
URL: "https://resolver-1.local/dns-query",
|
|
||||||
IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1")},
|
|
||||||
},
|
|
||||||
}},
|
|
||||||
expectedAddresses: []netip.Addr{netip.MustParseAddr("1.1.1.1")},
|
|
||||||
expectedDestIPs: []netip.Addr{netip.MustParseAddr("127.0.0.1")},
|
|
||||||
responder: func(_ string, requestBody io.Reader) (*http.Response, error) {
|
|
||||||
wire := responseWireForQuery(t, requestBody, &dns.A{
|
|
||||||
Hdr: dns.RR_Header{Name: "github.com.", Rrtype: dns.TypeA, Class: dns.ClassINET},
|
|
||||||
A: net.IP{1, 1, 1, 1},
|
|
||||||
})
|
|
||||||
return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(wire))}, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"fallback_between_servers": {
|
|
||||||
upstreamResolvers: []provider.Provider{
|
|
||||||
{
|
|
||||||
DoH: provider.DoHServer{
|
|
||||||
URL: "https://resolver-1.local/dns-query",
|
|
||||||
IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1")},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
DoH: provider.DoHServer{
|
|
||||||
URL: "https://resolver-2.local/dns-query",
|
|
||||||
IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1")},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectedAddresses: []netip.Addr{netip.MustParseAddr("2.2.2.2")},
|
|
||||||
expectedDestIPs: []netip.Addr{netip.MustParseAddr("127.0.0.1"), netip.MustParseAddr("127.0.0.1")},
|
|
||||||
responder: func(host string, requestBody io.Reader) (*http.Response, error) {
|
|
||||||
if host == "resolver-1.local" ||
|
|
||||||
len(host) > len("resolver-1.local:") && host[:len("resolver-1.local:")] == "resolver-1.local:" {
|
|
||||||
return &http.Response{
|
|
||||||
StatusCode: http.StatusBadGateway,
|
|
||||||
Status: "502 Bad Gateway",
|
|
||||||
Body: io.NopCloser(bytes.NewReader([]byte("bad gateway"))),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
wire := responseWireForQuery(t, requestBody, &dns.A{
|
|
||||||
Hdr: dns.RR_Header{Name: "github.com.", Rrtype: dns.TypeA, Class: dns.ClassINET},
|
|
||||||
A: net.IP{2, 2, 2, 2},
|
|
||||||
})
|
|
||||||
return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(wire))}, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"fallback_between_ips": {
|
|
||||||
upstreamResolvers: []provider.Provider{{
|
|
||||||
DoH: provider.DoHServer{
|
|
||||||
URL: "https://resolver.local/dns-query",
|
|
||||||
IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1"), netip.MustParseAddr("127.0.0.1")},
|
|
||||||
},
|
|
||||||
}},
|
|
||||||
expectedAddresses: []netip.Addr{netip.MustParseAddr("1.1.1.2")},
|
|
||||||
expectedDestIPs: []netip.Addr{netip.MustParseAddr("127.0.0.1"), netip.MustParseAddr("127.0.0.1")},
|
|
||||||
responder: func() func(host string, requestBody io.Reader) (*http.Response, error) {
|
|
||||||
var calls atomic.Int32
|
|
||||||
return func(_ string, requestBody io.Reader) (*http.Response, error) {
|
|
||||||
if calls.Add(1) == 1 { // first call fails
|
|
||||||
return &http.Response{
|
|
||||||
StatusCode: http.StatusNotFound,
|
|
||||||
Status: "502 Bad Gateway",
|
|
||||||
Body: io.NopCloser(bytes.NewReader([]byte("bad gateway"))),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
wire := responseWireForQuery(t, requestBody, &dns.A{
|
|
||||||
Hdr: dns.RR_Header{Name: "github.com.", Rrtype: dns.TypeA, Class: dns.ClassINET},
|
|
||||||
A: net.IP{1, 1, 1, 2},
|
|
||||||
})
|
|
||||||
return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(wire))}, nil
|
|
||||||
}
|
|
||||||
}(), //nolint:bodyclose
|
|
||||||
},
|
|
||||||
"dns_rcode_error_servfail": {
|
|
||||||
upstreamResolvers: []provider.Provider{{
|
|
||||||
DoH: provider.DoHServer{
|
|
||||||
URL: "https://resolver.local/dns-query",
|
|
||||||
IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1")},
|
|
||||||
},
|
|
||||||
}},
|
|
||||||
errorContains: "SERVFAIL",
|
|
||||||
expectedDestIPs: []netip.Addr{netip.MustParseAddr("127.0.0.1")},
|
|
||||||
responder: func(_ string, requestBody io.Reader) (*http.Response, error) {
|
|
||||||
queryWire, err := io.ReadAll(requestBody)
|
|
||||||
require.NoError(t, err)
|
|
||||||
query := new(dns.Msg)
|
|
||||||
err = query.Unpack(queryWire)
|
|
||||||
require.NoError(t, err)
|
|
||||||
response := new(dns.Msg)
|
|
||||||
response.SetReply(query)
|
|
||||||
response.Rcode = dns.RcodeServerFailure
|
|
||||||
wire, err := response.Pack()
|
|
||||||
require.NoError(t, err)
|
|
||||||
return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(wire))}, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"no_answer": {
|
|
||||||
upstreamResolvers: []provider.Provider{{
|
|
||||||
DoH: provider.DoHServer{
|
|
||||||
URL: "https://resolver.local/dns-query",
|
|
||||||
IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1")},
|
|
||||||
},
|
|
||||||
}},
|
|
||||||
expectedAddresses: nil,
|
|
||||||
expectedDestIPs: []netip.Addr{netip.MustParseAddr("127.0.0.1")},
|
|
||||||
responder: func(_ string, requestBody io.Reader) (*http.Response, error) {
|
|
||||||
wire := responseWireForQuery(t, requestBody)
|
|
||||||
return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(wire))}, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"ipv6_preference": {
|
|
||||||
ipv6Supported: true,
|
|
||||||
upstreamResolvers: []provider.Provider{{
|
|
||||||
DoH: provider.DoHServer{
|
|
||||||
URL: "https://resolver.local/dns-query",
|
|
||||||
IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1")},
|
|
||||||
IPv6: []netip.Addr{netip.MustParseAddr("::1")},
|
|
||||||
},
|
|
||||||
}},
|
|
||||||
expectedAddresses: []netip.Addr{netip.MustParseAddr("2001:4860:4860::8888")},
|
|
||||||
expectedDestIPs: []netip.Addr{
|
|
||||||
netip.MustParseAddr("::1"),
|
|
||||||
netip.MustParseAddr("::1"),
|
|
||||||
netip.MustParseAddr("127.0.0.1"),
|
|
||||||
},
|
|
||||||
responder: func(_ string, requestBody io.Reader) (*http.Response, error) {
|
|
||||||
queryWire, err := io.ReadAll(requestBody)
|
|
||||||
require.NoError(t, err)
|
|
||||||
query := new(dns.Msg)
|
|
||||||
err = query.Unpack(queryWire)
|
|
||||||
require.NoError(t, err)
|
|
||||||
if len(query.Question) > 0 && query.Question[0].Qtype == dns.TypeA {
|
|
||||||
wire := responseWireForQuery(t, bytes.NewReader(queryWire))
|
|
||||||
return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(wire))}, nil
|
|
||||||
}
|
|
||||||
wire := responseWireForQuery(t, bytes.NewReader(queryWire), &dns.AAAA{
|
|
||||||
Hdr: dns.RR_Header{Name: "github.com.", Rrtype: dns.TypeAAAA, Class: dns.ClassINET},
|
|
||||||
AAAA: net.IP{0x20, 0x01, 0x48, 0x60, 0x48, 0x60, 0, 0, 0, 0, 0, 0, 0, 0, 0x88, 0x88},
|
|
||||||
})
|
|
||||||
return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(wire))}, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"all_servers_fail": {
|
|
||||||
upstreamResolvers: []provider.Provider{
|
|
||||||
{DoH: provider.DoHServer{
|
|
||||||
URL: "https://resolver-1.local/dns-query",
|
|
||||||
IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1")},
|
|
||||||
}},
|
|
||||||
{DoH: provider.DoHServer{
|
|
||||||
URL: "https://resolver-2.local/dns-query",
|
|
||||||
IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1")},
|
|
||||||
}},
|
|
||||||
},
|
|
||||||
errorContains: "resolving host",
|
|
||||||
expectedDestIPs: []netip.Addr{netip.MustParseAddr("127.0.0.1"), netip.MustParseAddr("127.0.0.1")},
|
|
||||||
responder: func(_ string, _ io.Reader) (*http.Response, error) {
|
|
||||||
return &http.Response{
|
|
||||||
StatusCode: http.StatusBadGateway,
|
|
||||||
Status: "502 Bad Gateway",
|
|
||||||
Body: io.NopCloser(bytes.NewReader([]byte("bad gateway"))),
|
|
||||||
}, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for testName, testCase := range testCases {
|
// Add rule
|
||||||
t.Run(testName, func(t *testing.T) {
|
firstCall := firewall.EXPECT().AcceptOutputFromIPPortToIPPort(
|
||||||
t.Parallel()
|
ctx, "tcp", "eth0", sourceMatcher, destinationMatcher, false,
|
||||||
ctrl := gomock.NewController(t)
|
).DoAndReturn(func(
|
||||||
|
_ context.Context, _, _ string, source, destination netip.AddrPort, _ bool,
|
||||||
firewall := NewMockFirewall(ctrl)
|
) error {
|
||||||
port := startTCPAccepter(t)
|
sourceMatcher.expected = source
|
||||||
|
destinationMatcher.expected = destination
|
||||||
for _, destinationIP := range testCase.expectedDestIPs {
|
return nil
|
||||||
expectFirewallCallPair(firewall, t.Context(), destinationIP, port, nil, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
resolvers := make([]provider.Provider, len(testCase.upstreamResolvers))
|
|
||||||
copy(resolvers, testCase.upstreamResolvers)
|
|
||||||
for i := range resolvers {
|
|
||||||
resolvers[i].DoH.URL = urlToHostnamePort(resolvers[i].DoH.URL, port)
|
|
||||||
}
|
|
||||||
|
|
||||||
settings := Settings{
|
|
||||||
DefaultInterface: "eth0",
|
|
||||||
IPv6Supported: ptrTo(testCase.ipv6Supported),
|
|
||||||
Firewall: firewall,
|
|
||||||
UpstreamResolvers: resolvers,
|
|
||||||
BaseTransport: newInterceptTransport(testCase.responder),
|
|
||||||
}
|
|
||||||
client := New(settings)
|
|
||||||
client.httpsPort = port
|
|
||||||
|
|
||||||
addresses, err := client.ResolveName(t.Context(), "github.com")
|
|
||||||
assert.Equal(t, testCase.expectedAddresses, addresses)
|
|
||||||
if testCase.errorContains != "" {
|
|
||||||
require.Error(t, err)
|
|
||||||
assert.ErrorContains(t, err, testCase.errorContains)
|
|
||||||
} else {
|
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_Client_doHQuery(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
query := new(dns.Msg)
|
|
||||||
query.SetQuestion("example.com.", dns.TypeA)
|
|
||||||
queryWire, err := query.Pack()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
responseWire := responseWireForQuery(t, bytes.NewReader(queryWire), &dns.A{
|
|
||||||
Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET},
|
|
||||||
A: net.IP{1, 1, 1, 1},
|
|
||||||
})
|
})
|
||||||
|
|
||||||
testCases := map[string]struct {
|
// Removal rule
|
||||||
response *http.Response
|
firewall.EXPECT().AcceptOutputFromIPPortToIPPort(
|
||||||
addFirewallRuleErr error
|
context.Background(), "tcp", "eth0", sourceMatcher, destinationMatcher, true,
|
||||||
removeFirewallRuleErr error
|
).Return(nil).After(firstCall)
|
||||||
errorContains string
|
|
||||||
expectedIPs []netip.Addr
|
settings := Settings{
|
||||||
}{
|
DefaultInterface: "eth0",
|
||||||
"success": {
|
IPv6Supported: ptrTo(false),
|
||||||
response: &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(responseWire))},
|
Firewall: firewall,
|
||||||
expectedIPs: []netip.Addr{netip.MustParseAddr("1.1.1.1")},
|
UpstreamResolvers: []provider.Provider{provider.Cloudflare()},
|
||||||
},
|
|
||||||
"http_status_not_ok": {
|
|
||||||
response: &http.Response{
|
|
||||||
StatusCode: http.StatusBadGateway,
|
|
||||||
Status: "502 Bad Gateway",
|
|
||||||
Body: io.NopCloser(bytes.NewReader([]byte("bad gateway"))),
|
|
||||||
},
|
|
||||||
errorContains: "response status code is 502 Bad Gateway",
|
|
||||||
},
|
|
||||||
"malformed_dns_response": {
|
|
||||||
response: &http.Response{
|
|
||||||
StatusCode: http.StatusOK,
|
|
||||||
Body: io.NopCloser(bytes.NewBufferString("not-dns")),
|
|
||||||
},
|
|
||||||
errorContains: "parsing DoH response",
|
|
||||||
},
|
|
||||||
"cleanup_error": {
|
|
||||||
response: &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(responseWire))},
|
|
||||||
removeFirewallRuleErr: errors.New("cleanup failed"),
|
|
||||||
errorContains: "cleaning up https connection: removing output traffic rule: cleanup failed",
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
client := New(settings)
|
||||||
|
|
||||||
for name, testCase := range testCases {
|
addresses, err := client.ResolveName(ctx, "github.com")
|
||||||
t.Run(name, func(t *testing.T) {
|
require.NoError(t, err)
|
||||||
t.Parallel()
|
assert.NotEmpty(t, addresses)
|
||||||
ctrl := gomock.NewController(t)
|
|
||||||
|
|
||||||
firewall := NewMockFirewall(ctrl)
|
|
||||||
port := startTCPAccepter(t)
|
|
||||||
|
|
||||||
expectFirewallCallPair(
|
|
||||||
firewall,
|
|
||||||
context.Background(),
|
|
||||||
netip.MustParseAddr("127.0.0.1"),
|
|
||||||
port,
|
|
||||||
testCase.addFirewallRuleErr,
|
|
||||||
testCase.removeFirewallRuleErr,
|
|
||||||
)
|
|
||||||
|
|
||||||
settings := Settings{
|
|
||||||
DefaultInterface: "eth0",
|
|
||||||
IPv6Supported: ptrTo(false),
|
|
||||||
Firewall: firewall,
|
|
||||||
UpstreamResolvers: []provider.Provider{provider.Google()},
|
|
||||||
BaseTransport: newInterceptTransport(func(_ string, _ io.Reader) (*http.Response, error) {
|
|
||||||
return testCase.response, nil
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
client := New(settings)
|
|
||||||
client.httpsPort = port
|
|
||||||
|
|
||||||
dohURL, err := url.Parse(urlToHostnamePort("https://resolver.local/dns-query", port))
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
message, err := client.doHQuery(
|
|
||||||
context.Background(),
|
|
||||||
queryWire,
|
|
||||||
dohURL,
|
|
||||||
netip.MustParseAddr("127.0.0.1"),
|
|
||||||
)
|
|
||||||
|
|
||||||
if testCase.errorContains != "" {
|
|
||||||
require.Error(t, err)
|
|
||||||
assert.ErrorContains(t, err, testCase.errorContains)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
addresses := answersToNetipAddrs(message)
|
|
||||||
assert.Equal(t, testCase.expectedIPs, addresses)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_answersToNetipAddrs(t *testing.T) {
|
func Test_answersToNetipAddrs(t *testing.T) {
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package restrictednet
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"github.com/qdm12/dns/v2/pkg/provider"
|
"github.com/qdm12/dns/v2/pkg/provider"
|
||||||
)
|
)
|
||||||
@@ -12,13 +11,6 @@ type Settings struct {
|
|||||||
IPv6Supported *bool
|
IPv6Supported *bool
|
||||||
Firewall Firewall
|
Firewall Firewall
|
||||||
UpstreamResolvers []provider.Provider
|
UpstreamResolvers []provider.Provider
|
||||||
BaseTransport *http.Transport
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Settings) setDefaults() {
|
|
||||||
if s.BaseTransport == nil {
|
|
||||||
s.BaseTransport = http.DefaultTransport.(*http.Transport) //nolint:forcetypeassert
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Settings) validate() error {
|
func (s *Settings) validate() error {
|
||||||
|
|||||||
Reference in New Issue
Block a user