Change tests to be more integration oriented

This commit is contained in:
Quentin McGaw
2026-06-09 14:04:32 +00:00
parent dd07205b85
commit b5366b9e44
7 changed files with 128 additions and 566 deletions
+19 -11
View File
@@ -4,7 +4,10 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"net"
"net/http" "net/http"
"net/netip"
"strconv"
"github.com/qdm12/dns/v2/pkg/provider" "github.com/qdm12/dns/v2/pkg/provider"
) )
@@ -18,12 +21,9 @@ type Client struct {
ipv6Supported bool ipv6Supported bool
firewall Firewall firewall Firewall
dohServers []provider.DoHServer dohServers []provider.DoHServer
baseTransport *http.Transport
httpsPort uint16
} }
func New(settings Settings) *Client { func New(settings Settings) *Client {
settings.setDefaults()
if err := settings.validate(); err != nil { if err := settings.validate(); err != nil {
panic(fmt.Sprintf("invalid settings: %v", err)) // programming error panic(fmt.Sprintf("invalid settings: %v", err)) // programming error
} }
@@ -32,30 +32,38 @@ func New(settings Settings) *Client {
dohServers[i] = upstreamResolver.DoH dohServers[i] = upstreamResolver.DoH
} }
const defaultHTTPSPort = 443
return &Client{ return &Client{
outboundInterface: settings.DefaultInterface, outboundInterface: settings.DefaultInterface,
ipv6Supported: *settings.IPv6Supported, ipv6Supported: *settings.IPv6Supported,
firewall: settings.Firewall, firewall: settings.Firewall,
dohServers: dohServers, dohServers: dohServers,
baseTransport: settings.BaseTransport,
httpsPort: defaultHTTPSPort,
} }
} }
func (c *Client) OpenHTTPSByDomain(ctx context.Context, domain string) ( func (c *Client) OpenHTTPSByDomain(ctx context.Context, hostname string) (
httpClient *http.Client, cleanup func() error, err error, httpClient *http.Client, cleanup func() error, err error,
) { ) {
resolvedIPs, err := c.ResolveName(ctx, domain) host, portStr, err := net.SplitHostPort(hostname)
if err != nil {
return nil, nil, fmt.Errorf("splitting host and port: %w", err)
}
resolvedIPs, err := c.ResolveName(ctx, host)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("resolving name: %w", err) return nil, nil, fmt.Errorf("resolving name: %w", err)
} else if len(resolvedIPs) == 0 { } else if len(resolvedIPs) == 0 {
return nil, nil, fmt.Errorf("no IP address found for name %q", domain) return nil, nil, fmt.Errorf("no IP address found for name %q", host)
} }
portUint, err := strconv.ParseUint(portStr, 10, 16)
if err != nil {
return nil, nil, fmt.Errorf("parsing port: %w", err)
}
port := uint16(portUint)
errs := make([]error, 0, len(resolvedIPs)) errs := make([]error, 0, len(resolvedIPs))
for _, ip := range resolvedIPs { for _, ip := range resolvedIPs {
httpClient, cleanup, err := c.OpenHTTPS(ctx, domain, ip) addrPort := netip.AddrPortFrom(ip, port)
httpClient, cleanup, err := c.OpenHTTPS(ctx, host, addrPort)
if err != nil { if err != nil {
errs = append(errs, fmt.Errorf("for %s: %w", ip, err)) errs = append(errs, fmt.Errorf("for %s: %w", ip, err))
continue continue
@@ -63,5 +71,5 @@ func (c *Client) OpenHTTPSByDomain(ctx context.Context, domain string) (
return httpClient, cleanup, nil return httpClient, cleanup, nil
} }
return nil, nil, fmt.Errorf("opening HTTPS to %s: %w", domain, errors.Join(errs...)) return nil, nil, fmt.Errorf("opening HTTPS to %s: %w", hostname, errors.Join(errs...))
} }
-180
View File
@@ -1,185 +1,5 @@
package restrictednet package restrictednet
import (
"bufio"
"bytes"
"context"
"errors"
"io"
"net"
"net/http"
"net/netip"
"net/url"
"os"
"strconv"
"sync"
"syscall"
"testing"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func ptrTo[T any](value T) *T { func ptrTo[T any](value T) *T {
return &value return &value
} }
func newInterceptTransport(handler func(host string, requestBody io.Reader) (*http.Response, error)) *http.Transport {
return &http.Transport{
DialTLSContext: func(_ context.Context, _, _ string) (net.Conn, error) {
clientConn, serverConn := net.Pipe()
go func() {
defer serverConn.Close()
reader := bufio.NewReader(serverConn)
request, err := http.ReadRequest(reader)
if err != nil {
return
}
response, err := handler(request.Host, request.Body)
if err != nil {
return
}
// Read the response body and re-create it to avoid linting
// complaining that the response body must be closed.
responseData, err := io.ReadAll(response.Body)
if err != nil {
return
}
_ = response.Body.Close()
response.Body = io.NopCloser(bytes.NewReader(responseData))
_ = response.Write(serverConn)
}()
return clientConn, nil
},
}
}
func expectFirewallCallPair(
firewall *MockFirewall,
addContext context.Context, //nolint:revive
destinationIP netip.Addr,
destinationPort uint16,
addErr error,
removeErr error,
) {
destination := netip.AddrPortFrom(destinationIP, destinationPort)
sourceMatcher := listenAddrPortMatcher{}
firewall.EXPECT().AcceptOutputFromIPPortToIPPort(
addContext, "tcp", "eth0", sourceMatcher, destination, false,
).DoAndReturn(func(
_ context.Context, _, _ string, source, _ netip.AddrPort, _ bool,
) error {
sourceMatcher.expected = source
return addErr
})
firewall.EXPECT().AcceptOutputFromIPPortToIPPort(
context.Background(), "tcp", "eth0", sourceMatcher, destination, true,
).Return(removeErr)
}
func urlToHostnamePort(rawURL string, port uint16) string {
parsedURL, err := url.Parse(rawURL)
if err != nil {
panic(err) // programming error in test
}
parsedURL.Host = net.JoinHostPort(parsedURL.Hostname(), strconv.FormatUint(uint64(port), 10))
return parsedURL.String()
}
func responseWireForQuery(t *testing.T, queryReader io.Reader, answers ...dns.RR) []byte {
t.Helper()
queryData, err := io.ReadAll(queryReader)
require.NoError(t, err)
query := new(dns.Msg)
err = query.Unpack(queryData)
require.NoError(t, err)
response := new(dns.Msg)
response.SetReply(query)
response.Answer = append(response.Answer, answers...)
wire, err := response.Pack()
require.NoError(t, err)
return wire
}
func startTCPAccepter(t *testing.T) (port uint16) {
t.Helper()
// Find a port available for both TCP IPv4 and TCP IPv6
listeners := make([]net.Listener, 2) // IPv4 + IPv6
netConfig := net.ListenConfig{}
var listenersToClose []net.Listener
for t.Context().Err() == nil {
// Find an available port for IPv4
listeningAddress := netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), 0)
listener, err := netConfig.Listen(t.Context(), "tcp", listeningAddress.String())
require.NoError(t, err)
listeners[0] = listener
port = uint16(listener.Addr().(*net.TCPAddr).Port) //nolint:gosec,forcetypeassert
// Check if that port is also available for IPv6
listeningAddress = netip.AddrPortFrom(
netip.AddrFrom16([16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}),
port,
)
listener, err = netConfig.Listen(t.Context(), "tcp", listeningAddress.String())
if err == nil {
listeners[1] = listener
break // success, we found a port available for both IPv4 and IPv6
}
var opErr *net.OpError
if errors.As(err, &opErr) {
var sysErr *os.SyscallError
if errors.As(opErr.Err, &sysErr) && errors.Is(sysErr.Err, syscall.EADDRINUSE) {
// Port found for IPv4 is already in use for IPv6, try another port
// We don't close the IPv4 listener yet to make sure we don't get the same port again from the OS.
listenersToClose = append(listenersToClose, listeners[0])
continue
}
}
}
for _, listener := range listenersToClose {
err := listener.Close()
assert.NoError(t, err)
}
var ready sync.WaitGroup
ready.Add(len(listeners))
for _, listener := range listeners {
t.Cleanup(func() {
err := listener.Close()
assert.NoError(t, err)
})
go func() {
ready.Done()
for {
connection, err := listener.Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) && t.Context().Err() != nil {
return
}
assert.NoError(t, err)
return
}
err = connection.Close()
assert.NoError(t, err)
}
}()
}
ready.Wait()
return port
}
+25 -22
View File
@@ -17,15 +17,13 @@ import (
// OpenHTTPS opens temporary restrictive firewall output for one HTTPS destination. // OpenHTTPS opens temporary restrictive firewall output for one HTTPS destination.
// The returned cleanup function must be called to remove the temporary firewall rule and close connections. // The returned cleanup function must be called to remove the temporary firewall rule and close connections.
func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, destinationIP netip.Addr, func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, destinationAddrPort netip.AddrPort,
) (httpClient *http.Client, cleanup func() error, err error) { ) (httpClient *http.Client, cleanup func() error, err error) {
fd, sourceAddrPort, err := bindSourceConnection(destinationIP) fd, sourceAddrPort, err := bindSourceConnection(destinationAddrPort.Addr())
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("binding source port: %w", err) return nil, nil, fmt.Errorf("binding source port: %w", err)
} }
destinationAddrPort := netip.AddrPortFrom(destinationIP, c.httpsPort)
const remove = false const remove = false
err = c.firewall.AcceptOutputFromIPPortToIPPort(ctx, "tcp", c.outboundInterface, err = c.firewall.AcceptOutputFromIPPortToIPPort(ctx, "tcp", c.outboundInterface,
sourceAddrPort, destinationAddrPort, remove) sourceAddrPort, destinationAddrPort, remove)
@@ -42,7 +40,8 @@ func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, desti
return nil, nil, fmt.Errorf("connecting source socket: %w", err) return nil, nil, fmt.Errorf("connecting source socket: %w", err)
} }
httpClient = newHTTPSClient(c.baseTransport, destinationTLSName, connection) dial := makeDial(connection, destinationTLSName)
httpClient = newHTTPSClient(destinationTLSName, dial)
cleanup = func() error { cleanup = func() error {
var errs []error var errs []error
httpClient.CloseIdleConnections() httpClient.CloseIdleConnections()
@@ -53,7 +52,7 @@ func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, desti
errs = append(errs, fmt.Errorf("removing output traffic rule: %w", err)) errs = append(errs, fmt.Errorf("removing output traffic rule: %w", err))
} }
err = connection.Close() err = connection.Close()
if err != nil { if err != nil && !errors.Is(err, net.ErrClosed) {
errs = append(errs, fmt.Errorf("closing connection: %w", err)) errs = append(errs, fmt.Errorf("closing connection: %w", err))
} }
if len(errs) > 0 { if len(errs) > 0 {
@@ -64,21 +63,31 @@ func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, desti
return httpClient, cleanup, nil return httpClient, cleanup, nil
} }
func newHTTPSClient(baseTransport *http.Transport, destinationTLSName string, connection net.Conn) *http.Client { type dialFunc func(ctx context.Context, network, address string) (net.Conn, error)
transport := baseTransport.Clone()
transport.Proxy = nil func newHTTPSClient(destinationTLSName string, dial dialFunc) *http.Client {
transport.MaxIdleConns = 1 const timeout = 5 * time.Second
transport.MaxIdleConnsPerHost = 1 transport := &http.Transport{
transport.MaxConnsPerHost = 1 MaxIdleConns: 1,
transport.IdleConnTimeout = time.Second MaxIdleConnsPerHost: 1,
transport.TLSClientConfig = &tls.Config{ MaxConnsPerHost: 1,
IdleConnTimeout: time.Second,
TLSClientConfig: &tls.Config{
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
ServerName: destinationTLSName, ServerName: destinationTLSName,
},
DialContext: dial,
} }
return &http.Client{
Timeout: timeout,
Transport: transport,
}
}
func makeDial(connection net.Conn, tlsName string) dialFunc {
_, destinationPort, _ := net.SplitHostPort(connection.RemoteAddr().String()) _, destinationPort, _ := net.SplitHostPort(connection.RemoteAddr().String())
expectedAddress := net.JoinHostPort(destinationTLSName, destinationPort) expectedAddress := net.JoinHostPort(tlsName, destinationPort)
transport.DialContext = func(_ context.Context, network, address string) (net.Conn, error) { return func(_ context.Context, network, address string) (net.Conn, error) {
switch network { switch network {
case "tcp", "tcp4", "tcp6": case "tcp", "tcp4", "tcp6":
default: default:
@@ -89,12 +98,6 @@ func newHTTPSClient(baseTransport *http.Transport, destinationTLSName string, co
} }
return connection, nil return connection, nil
} }
const timeout = 5 * time.Second
return &http.Client{
Timeout: timeout,
Transport: transport,
}
} }
func bindSourceConnection(destinationIP netip.Addr) (fd int, sourceAddr netip.AddrPort, err error) { func bindSourceConnection(destinationIP netip.Addr) (fd int, sourceAddr netip.AddrPort, err error) {
+43 -22
View File
@@ -2,12 +2,14 @@ package restrictednet
import ( import (
"context" "context"
"net" "fmt"
"net/http"
"net/netip" "net/netip"
"testing" "testing"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/qdm12/dns/v2/pkg/provider" "github.com/qdm12/dns/v2/pkg/provider"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -33,31 +35,40 @@ func (m listenAddrPortMatcher) String() string {
return "is a valid netip.AddrPort with a valid IP and non-zero port" return "is a valid netip.AddrPort with a valid IP and non-zero port"
} }
type destinationAddrPortMatcher struct {
expected netip.AddrPort
}
func (m destinationAddrPortMatcher) Matches(x any) bool {
ip, ok := x.(netip.AddrPort)
if !ok {
return false
}
if m.expected.IsValid() {
return ip == m.expected
}
return ip.IsValid() && ip.Port() == m.expected.Port()
}
func (m destinationAddrPortMatcher) String() string {
if m.expected.IsValid() {
return "is the same as " + m.expected.String()
}
return "matches the port " + fmt.Sprint(m.expected.Port())
}
func Test_Client_OpenHTTPS(t *testing.T) { func Test_Client_OpenHTTPS(t *testing.T) {
t.Parallel() t.Parallel()
ctx := t.Context() ctx := t.Context()
netConfig := net.ListenConfig{}
listener, err := netConfig.Listen(ctx, "tcp", "127.0.0.1:0")
require.NoError(t, err)
t.Cleanup(func() {
_ = listener.Close()
})
listeningPort := uint16(listener.Addr().(*net.TCPAddr).Port) //nolint:gosec,forcetypeassert
go func() {
connection, acceptErr := listener.Accept()
if acceptErr == nil {
_ = connection.Close()
}
}()
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
firewall := NewMockFirewall(ctrl)
destination := netip.AddrPortFrom(netip.MustParseAddr("127.0.0.1"), listeningPort) const destinationTLSName = "one.one.one.one"
destinationAddrPort := netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443)
firewall := NewMockFirewall(ctrl)
sourceMatcher := listenAddrPortMatcher{} sourceMatcher := listenAddrPortMatcher{}
firewall.EXPECT().AcceptOutputFromIPPortToIPPort( firewall.EXPECT().AcceptOutputFromIPPortToIPPort(
ctx, "tcp", "eth0", sourceMatcher, destination, false, ctx, "tcp", "eth0", sourceMatcher, destinationAddrPort, false,
).DoAndReturn(func(_ context.Context, ).DoAndReturn(func(_ context.Context,
_, _ string, source, _ netip.AddrPort, _ bool, _, _ string, source, _ netip.AddrPort, _ bool,
) error { ) error {
@@ -65,7 +76,7 @@ func Test_Client_OpenHTTPS(t *testing.T) {
return nil return nil
}) })
firewall.EXPECT().AcceptOutputFromIPPortToIPPort( firewall.EXPECT().AcceptOutputFromIPPortToIPPort(
context.Background(), "tcp", "eth0", sourceMatcher, destination, true, context.Background(), "tcp", "eth0", sourceMatcher, destinationAddrPort, true,
) )
const ipv6Supported = false const ipv6Supported = false
@@ -77,13 +88,23 @@ func Test_Client_OpenHTTPS(t *testing.T) {
UpstreamResolvers: upstreamResolvers, UpstreamResolvers: upstreamResolvers,
} }
client := New(settings) client := New(settings)
client.httpsPort = listeningPort
httpClient, cleanup, err := client.OpenHTTPS(ctx, "api.example.com", netip.MustParseAddr("127.0.0.1")) httpClient, cleanup, err := client.OpenHTTPS(ctx, destinationTLSName, destinationAddrPort)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, httpClient) require.NotNil(t, httpClient)
require.NotNil(t, cleanup) require.NotNil(t, cleanup)
request, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://"+destinationTLSName, nil)
require.NoError(t, err)
response, err := httpClient.Do(request)
t.Cleanup(func() {
response.Body.Close()
})
require.NoError(t, err)
assert.Equal(t, http.StatusOK, response.StatusCode)
err = cleanup() err = cleanup()
require.NoError(t, err) require.NoError(t, err)
} }
+9 -7
View File
@@ -76,15 +76,17 @@ func (c *Client) resolveOneQuestionType(ctx context.Context,
dohServerIPs = append(dohServerIPs, dohServer.IPv4...) dohServerIPs = append(dohServerIPs, dohServer.IPv4...)
for _, dohServerIP := range dohServerIPs { for _, dohServerIP := range dohServerIPs {
responseMessage, err := c.doHQuery(ctx, queryWire, dohURL, dohServerIP) const defaultDoHPort = 443
dohServerAddrPort := netip.AddrPortFrom(dohServerIP, defaultDoHPort)
responseMessage, err := c.doHQuery(ctx, queryWire, dohURL, dohServerAddrPort)
switch { switch {
case err != nil: case err != nil:
errs = append(errs, fmt.Errorf("querying DoH server %q (ip %s): %w", errs = append(errs, fmt.Errorf("querying DoH server %q (%s): %w",
dohServer.URL, dohServerIP, err)) dohServer.URL, dohServerAddrPort, err))
continue continue
case responseMessage.Rcode != dns.RcodeSuccess: case responseMessage.Rcode != dns.RcodeSuccess:
errs = append(errs, fmt.Errorf("querying DoH server %q (ip %s): DNS rcode %s", errs = append(errs, fmt.Errorf("querying DoH server %q (%s): DNS rcode %s",
dohServer.URL, dohServerIP, dns.RcodeToString[responseMessage.Rcode])) dohServer.URL, dohServerAddrPort, dns.RcodeToString[responseMessage.Rcode]))
continue continue
} }
addresses := answersToNetipAddrs(responseMessage) addresses := answersToNetipAddrs(responseMessage)
@@ -104,9 +106,9 @@ func (c *Client) resolveOneQuestionType(ctx context.Context,
} }
func (c *Client) doHQuery(ctx context.Context, queryWire []byte, func (c *Client) doHQuery(ctx context.Context, queryWire []byte,
dohURL *url.URL, dohServerIP netip.Addr, dohURL *url.URL, dohServerAddrPort netip.AddrPort,
) (responseMessage *dns.Msg, err error) { ) (responseMessage *dns.Msg, err error) {
httpClient, cleanup, err := c.OpenHTTPS(ctx, dohURL.Hostname(), dohServerIP) httpClient, cleanup, err := c.OpenHTTPS(ctx, dohURL.Hostname(), dohServerAddrPort)
if err != nil { if err != nil {
return nil, fmt.Errorf("opening https connection: %w", err) return nil, fmt.Errorf("opening https connection: %w", err)
} }
+20 -304
View File
@@ -1,15 +1,9 @@
package restrictednet package restrictednet
import ( import (
"bytes"
"context" "context"
"errors"
"io"
"net" "net"
"net/http"
"net/netip" "net/netip"
"net/url"
"sync/atomic"
"testing" "testing"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
@@ -21,320 +15,42 @@ import (
func Test_Client_ResolveName(t *testing.T) { func Test_Client_ResolveName(t *testing.T) {
t.Parallel() t.Parallel()
ctx := t.Context()
testCases := map[string]struct {
ipv6Supported bool
upstreamResolvers []provider.Provider
expectedAddresses []netip.Addr
errorContains string
expectedDestIPs []netip.Addr
responder func(host string, requestBody io.Reader) (*http.Response, error)
}{
"success_single_server_ipv4": {
upstreamResolvers: []provider.Provider{{
DoH: provider.DoHServer{
URL: "https://resolver-1.local/dns-query",
IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1")},
},
}},
expectedAddresses: []netip.Addr{netip.MustParseAddr("1.1.1.1")},
expectedDestIPs: []netip.Addr{netip.MustParseAddr("127.0.0.1")},
responder: func(_ string, requestBody io.Reader) (*http.Response, error) {
wire := responseWireForQuery(t, requestBody, &dns.A{
Hdr: dns.RR_Header{Name: "github.com.", Rrtype: dns.TypeA, Class: dns.ClassINET},
A: net.IP{1, 1, 1, 1},
})
return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(wire))}, nil
},
},
"fallback_between_servers": {
upstreamResolvers: []provider.Provider{
{
DoH: provider.DoHServer{
URL: "https://resolver-1.local/dns-query",
IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1")},
},
},
{
DoH: provider.DoHServer{
URL: "https://resolver-2.local/dns-query",
IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1")},
},
},
},
expectedAddresses: []netip.Addr{netip.MustParseAddr("2.2.2.2")},
expectedDestIPs: []netip.Addr{netip.MustParseAddr("127.0.0.1"), netip.MustParseAddr("127.0.0.1")},
responder: func(host string, requestBody io.Reader) (*http.Response, error) {
if host == "resolver-1.local" ||
len(host) > len("resolver-1.local:") && host[:len("resolver-1.local:")] == "resolver-1.local:" {
return &http.Response{
StatusCode: http.StatusBadGateway,
Status: "502 Bad Gateway",
Body: io.NopCloser(bytes.NewReader([]byte("bad gateway"))),
}, nil
}
wire := responseWireForQuery(t, requestBody, &dns.A{
Hdr: dns.RR_Header{Name: "github.com.", Rrtype: dns.TypeA, Class: dns.ClassINET},
A: net.IP{2, 2, 2, 2},
})
return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(wire))}, nil
},
},
"fallback_between_ips": {
upstreamResolvers: []provider.Provider{{
DoH: provider.DoHServer{
URL: "https://resolver.local/dns-query",
IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1"), netip.MustParseAddr("127.0.0.1")},
},
}},
expectedAddresses: []netip.Addr{netip.MustParseAddr("1.1.1.2")},
expectedDestIPs: []netip.Addr{netip.MustParseAddr("127.0.0.1"), netip.MustParseAddr("127.0.0.1")},
responder: func() func(host string, requestBody io.Reader) (*http.Response, error) {
var calls atomic.Int32
return func(_ string, requestBody io.Reader) (*http.Response, error) {
if calls.Add(1) == 1 { // first call fails
return &http.Response{
StatusCode: http.StatusNotFound,
Status: "502 Bad Gateway",
Body: io.NopCloser(bytes.NewReader([]byte("bad gateway"))),
}, nil
}
wire := responseWireForQuery(t, requestBody, &dns.A{
Hdr: dns.RR_Header{Name: "github.com.", Rrtype: dns.TypeA, Class: dns.ClassINET},
A: net.IP{1, 1, 1, 2},
})
return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(wire))}, nil
}
}(), //nolint:bodyclose
},
"dns_rcode_error_servfail": {
upstreamResolvers: []provider.Provider{{
DoH: provider.DoHServer{
URL: "https://resolver.local/dns-query",
IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1")},
},
}},
errorContains: "SERVFAIL",
expectedDestIPs: []netip.Addr{netip.MustParseAddr("127.0.0.1")},
responder: func(_ string, requestBody io.Reader) (*http.Response, error) {
queryWire, err := io.ReadAll(requestBody)
require.NoError(t, err)
query := new(dns.Msg)
err = query.Unpack(queryWire)
require.NoError(t, err)
response := new(dns.Msg)
response.SetReply(query)
response.Rcode = dns.RcodeServerFailure
wire, err := response.Pack()
require.NoError(t, err)
return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(wire))}, nil
},
},
"no_answer": {
upstreamResolvers: []provider.Provider{{
DoH: provider.DoHServer{
URL: "https://resolver.local/dns-query",
IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1")},
},
}},
expectedAddresses: nil,
expectedDestIPs: []netip.Addr{netip.MustParseAddr("127.0.0.1")},
responder: func(_ string, requestBody io.Reader) (*http.Response, error) {
wire := responseWireForQuery(t, requestBody)
return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(wire))}, nil
},
},
"ipv6_preference": {
ipv6Supported: true,
upstreamResolvers: []provider.Provider{{
DoH: provider.DoHServer{
URL: "https://resolver.local/dns-query",
IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1")},
IPv6: []netip.Addr{netip.MustParseAddr("::1")},
},
}},
expectedAddresses: []netip.Addr{netip.MustParseAddr("2001:4860:4860::8888")},
expectedDestIPs: []netip.Addr{
netip.MustParseAddr("::1"),
netip.MustParseAddr("::1"),
netip.MustParseAddr("127.0.0.1"),
},
responder: func(_ string, requestBody io.Reader) (*http.Response, error) {
queryWire, err := io.ReadAll(requestBody)
require.NoError(t, err)
query := new(dns.Msg)
err = query.Unpack(queryWire)
require.NoError(t, err)
if len(query.Question) > 0 && query.Question[0].Qtype == dns.TypeA {
wire := responseWireForQuery(t, bytes.NewReader(queryWire))
return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(wire))}, nil
}
wire := responseWireForQuery(t, bytes.NewReader(queryWire), &dns.AAAA{
Hdr: dns.RR_Header{Name: "github.com.", Rrtype: dns.TypeAAAA, Class: dns.ClassINET},
AAAA: net.IP{0x20, 0x01, 0x48, 0x60, 0x48, 0x60, 0, 0, 0, 0, 0, 0, 0, 0, 0x88, 0x88},
})
return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(wire))}, nil
},
},
"all_servers_fail": {
upstreamResolvers: []provider.Provider{
{DoH: provider.DoHServer{
URL: "https://resolver-1.local/dns-query",
IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1")},
}},
{DoH: provider.DoHServer{
URL: "https://resolver-2.local/dns-query",
IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1")},
}},
},
errorContains: "resolving host",
expectedDestIPs: []netip.Addr{netip.MustParseAddr("127.0.0.1"), netip.MustParseAddr("127.0.0.1")},
responder: func(_ string, _ io.Reader) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusBadGateway,
Status: "502 Bad Gateway",
Body: io.NopCloser(bytes.NewReader([]byte("bad gateway"))),
}, nil
},
},
}
for testName, testCase := range testCases {
t.Run(testName, func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
firewall := NewMockFirewall(ctrl) firewall := NewMockFirewall(ctrl)
port := startTCPAccepter(t) sourceMatcher := listenAddrPortMatcher{}
destinationMatcher := destinationAddrPortMatcher{
for _, destinationIP := range testCase.expectedDestIPs { expected: netip.AddrPortFrom(netip.Addr{}, 443),
expectFirewallCallPair(firewall, t.Context(), destinationIP, port, nil, nil)
} }
resolvers := make([]provider.Provider, len(testCase.upstreamResolvers)) // Add rule
copy(resolvers, testCase.upstreamResolvers) firstCall := firewall.EXPECT().AcceptOutputFromIPPortToIPPort(
for i := range resolvers { ctx, "tcp", "eth0", sourceMatcher, destinationMatcher, false,
resolvers[i].DoH.URL = urlToHostnamePort(resolvers[i].DoH.URL, port) ).DoAndReturn(func(
} _ context.Context, _, _ string, source, destination netip.AddrPort, _ bool,
) error {
settings := Settings{ sourceMatcher.expected = source
DefaultInterface: "eth0", destinationMatcher.expected = destination
IPv6Supported: ptrTo(testCase.ipv6Supported), return nil
Firewall: firewall,
UpstreamResolvers: resolvers,
BaseTransport: newInterceptTransport(testCase.responder),
}
client := New(settings)
client.httpsPort = port
addresses, err := client.ResolveName(t.Context(), "github.com")
assert.Equal(t, testCase.expectedAddresses, addresses)
if testCase.errorContains != "" {
require.Error(t, err)
assert.ErrorContains(t, err, testCase.errorContains)
} else {
require.NoError(t, err)
}
})
}
}
func Test_Client_doHQuery(t *testing.T) {
t.Parallel()
query := new(dns.Msg)
query.SetQuestion("example.com.", dns.TypeA)
queryWire, err := query.Pack()
require.NoError(t, err)
responseWire := responseWireForQuery(t, bytes.NewReader(queryWire), &dns.A{
Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET},
A: net.IP{1, 1, 1, 1},
}) })
testCases := map[string]struct { // Removal rule
response *http.Response firewall.EXPECT().AcceptOutputFromIPPortToIPPort(
addFirewallRuleErr error context.Background(), "tcp", "eth0", sourceMatcher, destinationMatcher, true,
removeFirewallRuleErr error ).Return(nil).After(firstCall)
errorContains string
expectedIPs []netip.Addr
}{
"success": {
response: &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(responseWire))},
expectedIPs: []netip.Addr{netip.MustParseAddr("1.1.1.1")},
},
"http_status_not_ok": {
response: &http.Response{
StatusCode: http.StatusBadGateway,
Status: "502 Bad Gateway",
Body: io.NopCloser(bytes.NewReader([]byte("bad gateway"))),
},
errorContains: "response status code is 502 Bad Gateway",
},
"malformed_dns_response": {
response: &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewBufferString("not-dns")),
},
errorContains: "parsing DoH response",
},
"cleanup_error": {
response: &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(responseWire))},
removeFirewallRuleErr: errors.New("cleanup failed"),
errorContains: "cleaning up https connection: removing output traffic rule: cleanup failed",
},
}
for name, testCase := range testCases {
t.Run(name, func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
firewall := NewMockFirewall(ctrl)
port := startTCPAccepter(t)
expectFirewallCallPair(
firewall,
context.Background(),
netip.MustParseAddr("127.0.0.1"),
port,
testCase.addFirewallRuleErr,
testCase.removeFirewallRuleErr,
)
settings := Settings{ settings := Settings{
DefaultInterface: "eth0", DefaultInterface: "eth0",
IPv6Supported: ptrTo(false), IPv6Supported: ptrTo(false),
Firewall: firewall, Firewall: firewall,
UpstreamResolvers: []provider.Provider{provider.Google()}, UpstreamResolvers: []provider.Provider{provider.Cloudflare()},
BaseTransport: newInterceptTransport(func(_ string, _ io.Reader) (*http.Response, error) {
return testCase.response, nil
}),
} }
client := New(settings) client := New(settings)
client.httpsPort = port
dohURL, err := url.Parse(urlToHostnamePort("https://resolver.local/dns-query", port)) addresses, err := client.ResolveName(ctx, "github.com")
require.NoError(t, err) require.NoError(t, err)
assert.NotEmpty(t, addresses)
message, err := client.doHQuery(
context.Background(),
queryWire,
dohURL,
netip.MustParseAddr("127.0.0.1"),
)
if testCase.errorContains != "" {
require.Error(t, err)
assert.ErrorContains(t, err, testCase.errorContains)
return
}
require.NoError(t, err)
addresses := answersToNetipAddrs(message)
assert.Equal(t, testCase.expectedIPs, addresses)
})
}
} }
func Test_answersToNetipAddrs(t *testing.T) { func Test_answersToNetipAddrs(t *testing.T) {
-8
View File
@@ -2,7 +2,6 @@ package restrictednet
import ( import (
"errors" "errors"
"net/http"
"github.com/qdm12/dns/v2/pkg/provider" "github.com/qdm12/dns/v2/pkg/provider"
) )
@@ -12,13 +11,6 @@ type Settings struct {
IPv6Supported *bool IPv6Supported *bool
Firewall Firewall Firewall Firewall
UpstreamResolvers []provider.Provider UpstreamResolvers []provider.Provider
BaseTransport *http.Transport
}
func (s *Settings) setDefaults() {
if s.BaseTransport == nil {
s.BaseTransport = http.DefaultTransport.(*http.Transport) //nolint:forcetypeassert
}
} }
func (s *Settings) validate() error { func (s *Settings) validate() error {