From b5366b9e440cef2a15ff06d61dcba95d9e1bd7e7 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Tue, 9 Jun 2026 14:04:32 +0000 Subject: [PATCH] Change tests to be more integration oriented --- internal/restrictednet/client.go | 30 ++- internal/restrictednet/helpers_test.go | 180 ------------- internal/restrictednet/https.go | 53 ++-- internal/restrictednet/https_test.go | 65 +++-- internal/restrictednet/resolve.go | 16 +- internal/restrictednet/resolve_test.go | 342 +++---------------------- internal/restrictednet/settings.go | 8 - 7 files changed, 128 insertions(+), 566 deletions(-) diff --git a/internal/restrictednet/client.go b/internal/restrictednet/client.go index 7b1547a6..82091a75 100644 --- a/internal/restrictednet/client.go +++ b/internal/restrictednet/client.go @@ -4,7 +4,10 @@ import ( "context" "errors" "fmt" + "net" "net/http" + "net/netip" + "strconv" "github.com/qdm12/dns/v2/pkg/provider" ) @@ -18,12 +21,9 @@ type Client struct { ipv6Supported bool firewall Firewall dohServers []provider.DoHServer - baseTransport *http.Transport - httpsPort uint16 } func New(settings Settings) *Client { - settings.setDefaults() if err := settings.validate(); err != nil { panic(fmt.Sprintf("invalid settings: %v", err)) // programming error } @@ -32,30 +32,38 @@ func New(settings Settings) *Client { dohServers[i] = upstreamResolver.DoH } - const defaultHTTPSPort = 443 return &Client{ outboundInterface: settings.DefaultInterface, ipv6Supported: *settings.IPv6Supported, firewall: settings.Firewall, dohServers: dohServers, - baseTransport: settings.BaseTransport, - httpsPort: defaultHTTPSPort, } } -func (c *Client) OpenHTTPSByDomain(ctx context.Context, domain string) ( +func (c *Client) OpenHTTPSByDomain(ctx context.Context, hostname string) ( httpClient *http.Client, cleanup func() error, err error, ) { - resolvedIPs, err := c.ResolveName(ctx, domain) + host, portStr, err := net.SplitHostPort(hostname) + if err != nil { + return nil, nil, fmt.Errorf("splitting host and port: %w", err) + } + resolvedIPs, err := c.ResolveName(ctx, host) if err != nil { return nil, nil, fmt.Errorf("resolving name: %w", err) } else if len(resolvedIPs) == 0 { - return nil, nil, fmt.Errorf("no IP address found for name %q", domain) + return nil, nil, fmt.Errorf("no IP address found for name %q", host) } + portUint, err := strconv.ParseUint(portStr, 10, 16) + if err != nil { + return nil, nil, fmt.Errorf("parsing port: %w", err) + } + port := uint16(portUint) + errs := make([]error, 0, len(resolvedIPs)) for _, ip := range resolvedIPs { - httpClient, cleanup, err := c.OpenHTTPS(ctx, domain, ip) + addrPort := netip.AddrPortFrom(ip, port) + httpClient, cleanup, err := c.OpenHTTPS(ctx, host, addrPort) if err != nil { errs = append(errs, fmt.Errorf("for %s: %w", ip, err)) continue @@ -63,5 +71,5 @@ func (c *Client) OpenHTTPSByDomain(ctx context.Context, domain string) ( return httpClient, cleanup, nil } - return nil, nil, fmt.Errorf("opening HTTPS to %s: %w", domain, errors.Join(errs...)) + return nil, nil, fmt.Errorf("opening HTTPS to %s: %w", hostname, errors.Join(errs...)) } diff --git a/internal/restrictednet/helpers_test.go b/internal/restrictednet/helpers_test.go index cac3fd38..54070c32 100644 --- a/internal/restrictednet/helpers_test.go +++ b/internal/restrictednet/helpers_test.go @@ -1,185 +1,5 @@ package restrictednet -import ( - "bufio" - "bytes" - "context" - "errors" - "io" - "net" - "net/http" - "net/netip" - "net/url" - "os" - "strconv" - "sync" - "syscall" - "testing" - - "github.com/miekg/dns" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - func ptrTo[T any](value T) *T { return &value } - -func newInterceptTransport(handler func(host string, requestBody io.Reader) (*http.Response, error)) *http.Transport { - return &http.Transport{ - DialTLSContext: func(_ context.Context, _, _ string) (net.Conn, error) { - clientConn, serverConn := net.Pipe() - go func() { - defer serverConn.Close() - - reader := bufio.NewReader(serverConn) - request, err := http.ReadRequest(reader) - if err != nil { - return - } - - response, err := handler(request.Host, request.Body) - if err != nil { - return - } - - // Read the response body and re-create it to avoid linting - // complaining that the response body must be closed. - responseData, err := io.ReadAll(response.Body) - if err != nil { - return - } - _ = response.Body.Close() - response.Body = io.NopCloser(bytes.NewReader(responseData)) - - _ = response.Write(serverConn) - }() - return clientConn, nil - }, - } -} - -func expectFirewallCallPair( - firewall *MockFirewall, - addContext context.Context, //nolint:revive - destinationIP netip.Addr, - destinationPort uint16, - addErr error, - removeErr error, -) { - destination := netip.AddrPortFrom(destinationIP, destinationPort) - sourceMatcher := listenAddrPortMatcher{} - - firewall.EXPECT().AcceptOutputFromIPPortToIPPort( - addContext, "tcp", "eth0", sourceMatcher, destination, false, - ).DoAndReturn(func( - _ context.Context, _, _ string, source, _ netip.AddrPort, _ bool, - ) error { - sourceMatcher.expected = source - return addErr - }) - - firewall.EXPECT().AcceptOutputFromIPPortToIPPort( - context.Background(), "tcp", "eth0", sourceMatcher, destination, true, - ).Return(removeErr) -} - -func urlToHostnamePort(rawURL string, port uint16) string { - parsedURL, err := url.Parse(rawURL) - if err != nil { - panic(err) // programming error in test - } - parsedURL.Host = net.JoinHostPort(parsedURL.Hostname(), strconv.FormatUint(uint64(port), 10)) - return parsedURL.String() -} - -func responseWireForQuery(t *testing.T, queryReader io.Reader, answers ...dns.RR) []byte { - t.Helper() - - queryData, err := io.ReadAll(queryReader) - require.NoError(t, err) - - query := new(dns.Msg) - err = query.Unpack(queryData) - require.NoError(t, err) - - response := new(dns.Msg) - response.SetReply(query) - response.Answer = append(response.Answer, answers...) - - wire, err := response.Pack() - require.NoError(t, err) - return wire -} - -func startTCPAccepter(t *testing.T) (port uint16) { - t.Helper() - - // Find a port available for both TCP IPv4 and TCP IPv6 - listeners := make([]net.Listener, 2) // IPv4 + IPv6 - netConfig := net.ListenConfig{} - var listenersToClose []net.Listener - for t.Context().Err() == nil { - // Find an available port for IPv4 - listeningAddress := netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), 0) - listener, err := netConfig.Listen(t.Context(), "tcp", listeningAddress.String()) - require.NoError(t, err) - listeners[0] = listener - port = uint16(listener.Addr().(*net.TCPAddr).Port) //nolint:gosec,forcetypeassert - - // Check if that port is also available for IPv6 - listeningAddress = netip.AddrPortFrom( - netip.AddrFrom16([16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}), - port, - ) - listener, err = netConfig.Listen(t.Context(), "tcp", listeningAddress.String()) - if err == nil { - listeners[1] = listener - break // success, we found a port available for both IPv4 and IPv6 - } - var opErr *net.OpError - if errors.As(err, &opErr) { - var sysErr *os.SyscallError - if errors.As(opErr.Err, &sysErr) && errors.Is(sysErr.Err, syscall.EADDRINUSE) { - // Port found for IPv4 is already in use for IPv6, try another port - // We don't close the IPv4 listener yet to make sure we don't get the same port again from the OS. - listenersToClose = append(listenersToClose, listeners[0]) - continue - } - } - } - - for _, listener := range listenersToClose { - err := listener.Close() - assert.NoError(t, err) - } - - var ready sync.WaitGroup - ready.Add(len(listeners)) - for _, listener := range listeners { - t.Cleanup(func() { - err := listener.Close() - assert.NoError(t, err) - }) - - go func() { - ready.Done() - for { - connection, err := listener.Accept() - if err != nil { - if errors.Is(err, net.ErrClosed) && t.Context().Err() != nil { - return - } - assert.NoError(t, err) - return - } - err = connection.Close() - assert.NoError(t, err) - } - }() - } - - ready.Wait() - - return port -} diff --git a/internal/restrictednet/https.go b/internal/restrictednet/https.go index d61f78d1..08ae7350 100644 --- a/internal/restrictednet/https.go +++ b/internal/restrictednet/https.go @@ -17,15 +17,13 @@ import ( // OpenHTTPS opens temporary restrictive firewall output for one HTTPS destination. // The returned cleanup function must be called to remove the temporary firewall rule and close connections. -func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, destinationIP netip.Addr, +func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, destinationAddrPort netip.AddrPort, ) (httpClient *http.Client, cleanup func() error, err error) { - fd, sourceAddrPort, err := bindSourceConnection(destinationIP) + fd, sourceAddrPort, err := bindSourceConnection(destinationAddrPort.Addr()) if err != nil { return nil, nil, fmt.Errorf("binding source port: %w", err) } - destinationAddrPort := netip.AddrPortFrom(destinationIP, c.httpsPort) - const remove = false err = c.firewall.AcceptOutputFromIPPortToIPPort(ctx, "tcp", c.outboundInterface, sourceAddrPort, destinationAddrPort, remove) @@ -42,7 +40,8 @@ func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, desti return nil, nil, fmt.Errorf("connecting source socket: %w", err) } - httpClient = newHTTPSClient(c.baseTransport, destinationTLSName, connection) + dial := makeDial(connection, destinationTLSName) + httpClient = newHTTPSClient(destinationTLSName, dial) cleanup = func() error { var errs []error httpClient.CloseIdleConnections() @@ -53,7 +52,7 @@ func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, desti errs = append(errs, fmt.Errorf("removing output traffic rule: %w", err)) } err = connection.Close() - if err != nil { + if err != nil && !errors.Is(err, net.ErrClosed) { errs = append(errs, fmt.Errorf("closing connection: %w", err)) } if len(errs) > 0 { @@ -64,21 +63,31 @@ func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, desti return httpClient, cleanup, nil } -func newHTTPSClient(baseTransport *http.Transport, destinationTLSName string, connection net.Conn) *http.Client { - transport := baseTransport.Clone() - transport.Proxy = nil - transport.MaxIdleConns = 1 - transport.MaxIdleConnsPerHost = 1 - transport.MaxConnsPerHost = 1 - transport.IdleConnTimeout = time.Second - transport.TLSClientConfig = &tls.Config{ - MinVersion: tls.VersionTLS12, - ServerName: destinationTLSName, - } +type dialFunc func(ctx context.Context, network, address string) (net.Conn, error) +func newHTTPSClient(destinationTLSName string, dial dialFunc) *http.Client { + const timeout = 5 * time.Second + transport := &http.Transport{ + MaxIdleConns: 1, + MaxIdleConnsPerHost: 1, + MaxConnsPerHost: 1, + IdleConnTimeout: time.Second, + TLSClientConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + ServerName: destinationTLSName, + }, + DialContext: dial, + } + return &http.Client{ + Timeout: timeout, + Transport: transport, + } +} + +func makeDial(connection net.Conn, tlsName string) dialFunc { _, destinationPort, _ := net.SplitHostPort(connection.RemoteAddr().String()) - expectedAddress := net.JoinHostPort(destinationTLSName, destinationPort) - transport.DialContext = func(_ context.Context, network, address string) (net.Conn, error) { + expectedAddress := net.JoinHostPort(tlsName, destinationPort) + return func(_ context.Context, network, address string) (net.Conn, error) { switch network { case "tcp", "tcp4", "tcp6": default: @@ -89,12 +98,6 @@ func newHTTPSClient(baseTransport *http.Transport, destinationTLSName string, co } return connection, nil } - - const timeout = 5 * time.Second - return &http.Client{ - Timeout: timeout, - Transport: transport, - } } func bindSourceConnection(destinationIP netip.Addr) (fd int, sourceAddr netip.AddrPort, err error) { diff --git a/internal/restrictednet/https_test.go b/internal/restrictednet/https_test.go index 02e36fd2..b488f505 100644 --- a/internal/restrictednet/https_test.go +++ b/internal/restrictednet/https_test.go @@ -2,12 +2,14 @@ package restrictednet import ( "context" - "net" + "fmt" + "net/http" "net/netip" "testing" "github.com/golang/mock/gomock" "github.com/qdm12/dns/v2/pkg/provider" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -33,31 +35,40 @@ func (m listenAddrPortMatcher) String() string { return "is a valid netip.AddrPort with a valid IP and non-zero port" } +type destinationAddrPortMatcher struct { + expected netip.AddrPort +} + +func (m destinationAddrPortMatcher) Matches(x any) bool { + ip, ok := x.(netip.AddrPort) + if !ok { + return false + } + if m.expected.IsValid() { + return ip == m.expected + } + return ip.IsValid() && ip.Port() == m.expected.Port() +} + +func (m destinationAddrPortMatcher) String() string { + if m.expected.IsValid() { + return "is the same as " + m.expected.String() + } + return "matches the port " + fmt.Sprint(m.expected.Port()) +} + func Test_Client_OpenHTTPS(t *testing.T) { t.Parallel() ctx := t.Context() - - netConfig := net.ListenConfig{} - listener, err := netConfig.Listen(ctx, "tcp", "127.0.0.1:0") - require.NoError(t, err) - t.Cleanup(func() { - _ = listener.Close() - }) - listeningPort := uint16(listener.Addr().(*net.TCPAddr).Port) //nolint:gosec,forcetypeassert - go func() { - connection, acceptErr := listener.Accept() - if acceptErr == nil { - _ = connection.Close() - } - }() - ctrl := gomock.NewController(t) - firewall := NewMockFirewall(ctrl) - destination := netip.AddrPortFrom(netip.MustParseAddr("127.0.0.1"), listeningPort) + const destinationTLSName = "one.one.one.one" + destinationAddrPort := netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443) + + firewall := NewMockFirewall(ctrl) sourceMatcher := listenAddrPortMatcher{} firewall.EXPECT().AcceptOutputFromIPPortToIPPort( - ctx, "tcp", "eth0", sourceMatcher, destination, false, + ctx, "tcp", "eth0", sourceMatcher, destinationAddrPort, false, ).DoAndReturn(func(_ context.Context, _, _ string, source, _ netip.AddrPort, _ bool, ) error { @@ -65,7 +76,7 @@ func Test_Client_OpenHTTPS(t *testing.T) { return nil }) firewall.EXPECT().AcceptOutputFromIPPortToIPPort( - context.Background(), "tcp", "eth0", sourceMatcher, destination, true, + context.Background(), "tcp", "eth0", sourceMatcher, destinationAddrPort, true, ) const ipv6Supported = false @@ -77,13 +88,23 @@ func Test_Client_OpenHTTPS(t *testing.T) { UpstreamResolvers: upstreamResolvers, } client := New(settings) - client.httpsPort = listeningPort - httpClient, cleanup, err := client.OpenHTTPS(ctx, "api.example.com", netip.MustParseAddr("127.0.0.1")) + httpClient, cleanup, err := client.OpenHTTPS(ctx, destinationTLSName, destinationAddrPort) require.NoError(t, err) require.NotNil(t, httpClient) require.NotNil(t, cleanup) + request, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://"+destinationTLSName, nil) + require.NoError(t, err) + + response, err := httpClient.Do(request) + t.Cleanup(func() { + response.Body.Close() + }) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, response.StatusCode) + err = cleanup() require.NoError(t, err) } diff --git a/internal/restrictednet/resolve.go b/internal/restrictednet/resolve.go index 2feffeb5..8a15c39a 100644 --- a/internal/restrictednet/resolve.go +++ b/internal/restrictednet/resolve.go @@ -76,15 +76,17 @@ func (c *Client) resolveOneQuestionType(ctx context.Context, dohServerIPs = append(dohServerIPs, dohServer.IPv4...) for _, dohServerIP := range dohServerIPs { - responseMessage, err := c.doHQuery(ctx, queryWire, dohURL, dohServerIP) + const defaultDoHPort = 443 + dohServerAddrPort := netip.AddrPortFrom(dohServerIP, defaultDoHPort) + responseMessage, err := c.doHQuery(ctx, queryWire, dohURL, dohServerAddrPort) switch { case err != nil: - errs = append(errs, fmt.Errorf("querying DoH server %q (ip %s): %w", - dohServer.URL, dohServerIP, err)) + errs = append(errs, fmt.Errorf("querying DoH server %q (%s): %w", + dohServer.URL, dohServerAddrPort, err)) continue case responseMessage.Rcode != dns.RcodeSuccess: - errs = append(errs, fmt.Errorf("querying DoH server %q (ip %s): DNS rcode %s", - dohServer.URL, dohServerIP, dns.RcodeToString[responseMessage.Rcode])) + errs = append(errs, fmt.Errorf("querying DoH server %q (%s): DNS rcode %s", + dohServer.URL, dohServerAddrPort, dns.RcodeToString[responseMessage.Rcode])) continue } addresses := answersToNetipAddrs(responseMessage) @@ -104,9 +106,9 @@ func (c *Client) resolveOneQuestionType(ctx context.Context, } func (c *Client) doHQuery(ctx context.Context, queryWire []byte, - dohURL *url.URL, dohServerIP netip.Addr, + dohURL *url.URL, dohServerAddrPort netip.AddrPort, ) (responseMessage *dns.Msg, err error) { - httpClient, cleanup, err := c.OpenHTTPS(ctx, dohURL.Hostname(), dohServerIP) + httpClient, cleanup, err := c.OpenHTTPS(ctx, dohURL.Hostname(), dohServerAddrPort) if err != nil { return nil, fmt.Errorf("opening https connection: %w", err) } diff --git a/internal/restrictednet/resolve_test.go b/internal/restrictednet/resolve_test.go index 972b5ff1..3ef4b847 100644 --- a/internal/restrictednet/resolve_test.go +++ b/internal/restrictednet/resolve_test.go @@ -1,15 +1,9 @@ package restrictednet import ( - "bytes" "context" - "errors" - "io" "net" - "net/http" "net/netip" - "net/url" - "sync/atomic" "testing" "github.com/golang/mock/gomock" @@ -21,320 +15,42 @@ import ( func Test_Client_ResolveName(t *testing.T) { t.Parallel() + ctx := t.Context() + ctrl := gomock.NewController(t) - testCases := map[string]struct { - ipv6Supported bool - upstreamResolvers []provider.Provider - expectedAddresses []netip.Addr - errorContains string - expectedDestIPs []netip.Addr - responder func(host string, requestBody io.Reader) (*http.Response, error) - }{ - "success_single_server_ipv4": { - upstreamResolvers: []provider.Provider{{ - DoH: provider.DoHServer{ - URL: "https://resolver-1.local/dns-query", - IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, - }, - }}, - expectedAddresses: []netip.Addr{netip.MustParseAddr("1.1.1.1")}, - expectedDestIPs: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, - responder: func(_ string, requestBody io.Reader) (*http.Response, error) { - wire := responseWireForQuery(t, requestBody, &dns.A{ - Hdr: dns.RR_Header{Name: "github.com.", Rrtype: dns.TypeA, Class: dns.ClassINET}, - A: net.IP{1, 1, 1, 1}, - }) - return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(wire))}, nil - }, - }, - "fallback_between_servers": { - upstreamResolvers: []provider.Provider{ - { - DoH: provider.DoHServer{ - URL: "https://resolver-1.local/dns-query", - IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, - }, - }, - { - DoH: provider.DoHServer{ - URL: "https://resolver-2.local/dns-query", - IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, - }, - }, - }, - expectedAddresses: []netip.Addr{netip.MustParseAddr("2.2.2.2")}, - expectedDestIPs: []netip.Addr{netip.MustParseAddr("127.0.0.1"), netip.MustParseAddr("127.0.0.1")}, - responder: func(host string, requestBody io.Reader) (*http.Response, error) { - if host == "resolver-1.local" || - len(host) > len("resolver-1.local:") && host[:len("resolver-1.local:")] == "resolver-1.local:" { - return &http.Response{ - StatusCode: http.StatusBadGateway, - Status: "502 Bad Gateway", - Body: io.NopCloser(bytes.NewReader([]byte("bad gateway"))), - }, nil - } - wire := responseWireForQuery(t, requestBody, &dns.A{ - Hdr: dns.RR_Header{Name: "github.com.", Rrtype: dns.TypeA, Class: dns.ClassINET}, - A: net.IP{2, 2, 2, 2}, - }) - return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(wire))}, nil - }, - }, - "fallback_between_ips": { - upstreamResolvers: []provider.Provider{{ - DoH: provider.DoHServer{ - URL: "https://resolver.local/dns-query", - IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1"), netip.MustParseAddr("127.0.0.1")}, - }, - }}, - expectedAddresses: []netip.Addr{netip.MustParseAddr("1.1.1.2")}, - expectedDestIPs: []netip.Addr{netip.MustParseAddr("127.0.0.1"), netip.MustParseAddr("127.0.0.1")}, - responder: func() func(host string, requestBody io.Reader) (*http.Response, error) { - var calls atomic.Int32 - return func(_ string, requestBody io.Reader) (*http.Response, error) { - if calls.Add(1) == 1 { // first call fails - return &http.Response{ - StatusCode: http.StatusNotFound, - Status: "502 Bad Gateway", - Body: io.NopCloser(bytes.NewReader([]byte("bad gateway"))), - }, nil - } - wire := responseWireForQuery(t, requestBody, &dns.A{ - Hdr: dns.RR_Header{Name: "github.com.", Rrtype: dns.TypeA, Class: dns.ClassINET}, - A: net.IP{1, 1, 1, 2}, - }) - return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(wire))}, nil - } - }(), //nolint:bodyclose - }, - "dns_rcode_error_servfail": { - upstreamResolvers: []provider.Provider{{ - DoH: provider.DoHServer{ - URL: "https://resolver.local/dns-query", - IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, - }, - }}, - errorContains: "SERVFAIL", - expectedDestIPs: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, - responder: func(_ string, requestBody io.Reader) (*http.Response, error) { - queryWire, err := io.ReadAll(requestBody) - require.NoError(t, err) - query := new(dns.Msg) - err = query.Unpack(queryWire) - require.NoError(t, err) - response := new(dns.Msg) - response.SetReply(query) - response.Rcode = dns.RcodeServerFailure - wire, err := response.Pack() - require.NoError(t, err) - return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(wire))}, nil - }, - }, - "no_answer": { - upstreamResolvers: []provider.Provider{{ - DoH: provider.DoHServer{ - URL: "https://resolver.local/dns-query", - IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, - }, - }}, - expectedAddresses: nil, - expectedDestIPs: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, - responder: func(_ string, requestBody io.Reader) (*http.Response, error) { - wire := responseWireForQuery(t, requestBody) - return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(wire))}, nil - }, - }, - "ipv6_preference": { - ipv6Supported: true, - upstreamResolvers: []provider.Provider{{ - DoH: provider.DoHServer{ - URL: "https://resolver.local/dns-query", - IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, - IPv6: []netip.Addr{netip.MustParseAddr("::1")}, - }, - }}, - expectedAddresses: []netip.Addr{netip.MustParseAddr("2001:4860:4860::8888")}, - expectedDestIPs: []netip.Addr{ - netip.MustParseAddr("::1"), - netip.MustParseAddr("::1"), - netip.MustParseAddr("127.0.0.1"), - }, - responder: func(_ string, requestBody io.Reader) (*http.Response, error) { - queryWire, err := io.ReadAll(requestBody) - require.NoError(t, err) - query := new(dns.Msg) - err = query.Unpack(queryWire) - require.NoError(t, err) - if len(query.Question) > 0 && query.Question[0].Qtype == dns.TypeA { - wire := responseWireForQuery(t, bytes.NewReader(queryWire)) - return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(wire))}, nil - } - wire := responseWireForQuery(t, bytes.NewReader(queryWire), &dns.AAAA{ - Hdr: dns.RR_Header{Name: "github.com.", Rrtype: dns.TypeAAAA, Class: dns.ClassINET}, - AAAA: net.IP{0x20, 0x01, 0x48, 0x60, 0x48, 0x60, 0, 0, 0, 0, 0, 0, 0, 0, 0x88, 0x88}, - }) - return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(wire))}, nil - }, - }, - "all_servers_fail": { - upstreamResolvers: []provider.Provider{ - {DoH: provider.DoHServer{ - URL: "https://resolver-1.local/dns-query", - IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, - }}, - {DoH: provider.DoHServer{ - URL: "https://resolver-2.local/dns-query", - IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, - }}, - }, - errorContains: "resolving host", - expectedDestIPs: []netip.Addr{netip.MustParseAddr("127.0.0.1"), netip.MustParseAddr("127.0.0.1")}, - responder: func(_ string, _ io.Reader) (*http.Response, error) { - return &http.Response{ - StatusCode: http.StatusBadGateway, - Status: "502 Bad Gateway", - Body: io.NopCloser(bytes.NewReader([]byte("bad gateway"))), - }, nil - }, - }, + firewall := NewMockFirewall(ctrl) + sourceMatcher := listenAddrPortMatcher{} + destinationMatcher := destinationAddrPortMatcher{ + expected: netip.AddrPortFrom(netip.Addr{}, 443), } - for testName, testCase := range testCases { - t.Run(testName, func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - - firewall := NewMockFirewall(ctrl) - port := startTCPAccepter(t) - - for _, destinationIP := range testCase.expectedDestIPs { - expectFirewallCallPair(firewall, t.Context(), destinationIP, port, nil, nil) - } - - resolvers := make([]provider.Provider, len(testCase.upstreamResolvers)) - copy(resolvers, testCase.upstreamResolvers) - for i := range resolvers { - resolvers[i].DoH.URL = urlToHostnamePort(resolvers[i].DoH.URL, port) - } - - settings := Settings{ - DefaultInterface: "eth0", - IPv6Supported: ptrTo(testCase.ipv6Supported), - Firewall: firewall, - UpstreamResolvers: resolvers, - BaseTransport: newInterceptTransport(testCase.responder), - } - client := New(settings) - client.httpsPort = port - - addresses, err := client.ResolveName(t.Context(), "github.com") - assert.Equal(t, testCase.expectedAddresses, addresses) - if testCase.errorContains != "" { - require.Error(t, err) - assert.ErrorContains(t, err, testCase.errorContains) - } else { - require.NoError(t, err) - } - }) - } -} - -func Test_Client_doHQuery(t *testing.T) { - t.Parallel() - - query := new(dns.Msg) - query.SetQuestion("example.com.", dns.TypeA) - queryWire, err := query.Pack() - require.NoError(t, err) - - responseWire := responseWireForQuery(t, bytes.NewReader(queryWire), &dns.A{ - Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET}, - A: net.IP{1, 1, 1, 1}, + // Add rule + firstCall := firewall.EXPECT().AcceptOutputFromIPPortToIPPort( + ctx, "tcp", "eth0", sourceMatcher, destinationMatcher, false, + ).DoAndReturn(func( + _ context.Context, _, _ string, source, destination netip.AddrPort, _ bool, + ) error { + sourceMatcher.expected = source + destinationMatcher.expected = destination + return nil }) - testCases := map[string]struct { - response *http.Response - addFirewallRuleErr error - removeFirewallRuleErr error - errorContains string - expectedIPs []netip.Addr - }{ - "success": { - response: &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(responseWire))}, - expectedIPs: []netip.Addr{netip.MustParseAddr("1.1.1.1")}, - }, - "http_status_not_ok": { - response: &http.Response{ - StatusCode: http.StatusBadGateway, - Status: "502 Bad Gateway", - Body: io.NopCloser(bytes.NewReader([]byte("bad gateway"))), - }, - errorContains: "response status code is 502 Bad Gateway", - }, - "malformed_dns_response": { - response: &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewBufferString("not-dns")), - }, - errorContains: "parsing DoH response", - }, - "cleanup_error": { - response: &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(responseWire))}, - removeFirewallRuleErr: errors.New("cleanup failed"), - errorContains: "cleaning up https connection: removing output traffic rule: cleanup failed", - }, + // Removal rule + firewall.EXPECT().AcceptOutputFromIPPortToIPPort( + context.Background(), "tcp", "eth0", sourceMatcher, destinationMatcher, true, + ).Return(nil).After(firstCall) + + settings := Settings{ + DefaultInterface: "eth0", + IPv6Supported: ptrTo(false), + Firewall: firewall, + UpstreamResolvers: []provider.Provider{provider.Cloudflare()}, } + client := New(settings) - for name, testCase := range testCases { - t.Run(name, func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - - firewall := NewMockFirewall(ctrl) - port := startTCPAccepter(t) - - expectFirewallCallPair( - firewall, - context.Background(), - netip.MustParseAddr("127.0.0.1"), - port, - testCase.addFirewallRuleErr, - testCase.removeFirewallRuleErr, - ) - - settings := Settings{ - DefaultInterface: "eth0", - IPv6Supported: ptrTo(false), - Firewall: firewall, - UpstreamResolvers: []provider.Provider{provider.Google()}, - BaseTransport: newInterceptTransport(func(_ string, _ io.Reader) (*http.Response, error) { - return testCase.response, nil - }), - } - client := New(settings) - client.httpsPort = port - - dohURL, err := url.Parse(urlToHostnamePort("https://resolver.local/dns-query", port)) - require.NoError(t, err) - - message, err := client.doHQuery( - context.Background(), - queryWire, - dohURL, - netip.MustParseAddr("127.0.0.1"), - ) - - if testCase.errorContains != "" { - require.Error(t, err) - assert.ErrorContains(t, err, testCase.errorContains) - return - } - - require.NoError(t, err) - addresses := answersToNetipAddrs(message) - assert.Equal(t, testCase.expectedIPs, addresses) - }) - } + addresses, err := client.ResolveName(ctx, "github.com") + require.NoError(t, err) + assert.NotEmpty(t, addresses) } func Test_answersToNetipAddrs(t *testing.T) { diff --git a/internal/restrictednet/settings.go b/internal/restrictednet/settings.go index 4b943b52..52c678c3 100644 --- a/internal/restrictednet/settings.go +++ b/internal/restrictednet/settings.go @@ -2,7 +2,6 @@ package restrictednet import ( "errors" - "net/http" "github.com/qdm12/dns/v2/pkg/provider" ) @@ -12,13 +11,6 @@ type Settings struct { IPv6Supported *bool Firewall Firewall UpstreamResolvers []provider.Provider - BaseTransport *http.Transport -} - -func (s *Settings) setDefaults() { - if s.BaseTransport == nil { - s.BaseTransport = http.DefaultTransport.(*http.Transport) //nolint:forcetypeassert - } } func (s *Settings) validate() error {