diff --git a/AGENTS.md b/AGENTS.md index b7d0b3bb..fb3f7d8e 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -116,6 +116,7 @@ Mocking works with the `go.uber.org/mock` library, and the `mockgen` tool. - **Never** use `.AnyTimes()` on mocks. Always define the number of times a certain mock call should be called, with `.Times(3)` for example. - **Always** set the `.Return(...)` on the mock if the function returns something. - Avoid using **mock helpers** functions, prefer a bit of repetition than tight coupling and dependency + - Always define the gomock controller `ctrl` in the subtest and not in the parent test, or a subtest mock failing will crash all the other subtests. ### main.go diff --git a/internal/restrictednet/client.go b/internal/restrictednet/client.go index 9225a96c..7b1547a6 100644 --- a/internal/restrictednet/client.go +++ b/internal/restrictednet/client.go @@ -14,30 +14,31 @@ import ( // It is not meant to be high performance, although it can be used for // multiple requests and concurrently. type Client struct { + outboundInterface string ipv6Supported bool firewall Firewall - outboundInterface string dohServers []provider.DoHServer + baseTransport *http.Transport httpsPort uint16 } -func New(firewall Firewall, defaultInterface string, ipv6Supported bool, - upstreamResolvers []provider.Provider, -) *Client { - if len(upstreamResolvers) == 0 { - panic("no upstream resolvers provided") // programming error +func New(settings Settings) *Client { + settings.setDefaults() + if err := settings.validate(); err != nil { + panic(fmt.Sprintf("invalid settings: %v", err)) // programming error } - dohServers := make([]provider.DoHServer, len(upstreamResolvers)) - for i, upstreamResolver := range upstreamResolvers { + dohServers := make([]provider.DoHServer, len(settings.UpstreamResolvers)) + for i, upstreamResolver := range settings.UpstreamResolvers { dohServers[i] = upstreamResolver.DoH } const defaultHTTPSPort = 443 return &Client{ - firewall: firewall, - outboundInterface: defaultInterface, - ipv6Supported: ipv6Supported, + outboundInterface: settings.DefaultInterface, + ipv6Supported: *settings.IPv6Supported, + firewall: settings.Firewall, dohServers: dohServers, + baseTransport: settings.BaseTransport, httpsPort: defaultHTTPSPort, } } diff --git a/internal/restrictednet/helpers_test.go b/internal/restrictednet/helpers_test.go new file mode 100644 index 00000000..cac3fd38 --- /dev/null +++ b/internal/restrictednet/helpers_test.go @@ -0,0 +1,185 @@ +package restrictednet + +import ( + "bufio" + "bytes" + "context" + "errors" + "io" + "net" + "net/http" + "net/netip" + "net/url" + "os" + "strconv" + "sync" + "syscall" + "testing" + + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func ptrTo[T any](value T) *T { + return &value +} + +func newInterceptTransport(handler func(host string, requestBody io.Reader) (*http.Response, error)) *http.Transport { + return &http.Transport{ + DialTLSContext: func(_ context.Context, _, _ string) (net.Conn, error) { + clientConn, serverConn := net.Pipe() + go func() { + defer serverConn.Close() + + reader := bufio.NewReader(serverConn) + request, err := http.ReadRequest(reader) + if err != nil { + return + } + + response, err := handler(request.Host, request.Body) + if err != nil { + return + } + + // Read the response body and re-create it to avoid linting + // complaining that the response body must be closed. + responseData, err := io.ReadAll(response.Body) + if err != nil { + return + } + _ = response.Body.Close() + response.Body = io.NopCloser(bytes.NewReader(responseData)) + + _ = response.Write(serverConn) + }() + return clientConn, nil + }, + } +} + +func expectFirewallCallPair( + firewall *MockFirewall, + addContext context.Context, //nolint:revive + destinationIP netip.Addr, + destinationPort uint16, + addErr error, + removeErr error, +) { + destination := netip.AddrPortFrom(destinationIP, destinationPort) + sourceMatcher := listenAddrPortMatcher{} + + firewall.EXPECT().AcceptOutputFromIPPortToIPPort( + addContext, "tcp", "eth0", sourceMatcher, destination, false, + ).DoAndReturn(func( + _ context.Context, _, _ string, source, _ netip.AddrPort, _ bool, + ) error { + sourceMatcher.expected = source + return addErr + }) + + firewall.EXPECT().AcceptOutputFromIPPortToIPPort( + context.Background(), "tcp", "eth0", sourceMatcher, destination, true, + ).Return(removeErr) +} + +func urlToHostnamePort(rawURL string, port uint16) string { + parsedURL, err := url.Parse(rawURL) + if err != nil { + panic(err) // programming error in test + } + parsedURL.Host = net.JoinHostPort(parsedURL.Hostname(), strconv.FormatUint(uint64(port), 10)) + return parsedURL.String() +} + +func responseWireForQuery(t *testing.T, queryReader io.Reader, answers ...dns.RR) []byte { + t.Helper() + + queryData, err := io.ReadAll(queryReader) + require.NoError(t, err) + + query := new(dns.Msg) + err = query.Unpack(queryData) + require.NoError(t, err) + + response := new(dns.Msg) + response.SetReply(query) + response.Answer = append(response.Answer, answers...) + + wire, err := response.Pack() + require.NoError(t, err) + return wire +} + +func startTCPAccepter(t *testing.T) (port uint16) { + t.Helper() + + // Find a port available for both TCP IPv4 and TCP IPv6 + listeners := make([]net.Listener, 2) // IPv4 + IPv6 + netConfig := net.ListenConfig{} + var listenersToClose []net.Listener + for t.Context().Err() == nil { + // Find an available port for IPv4 + listeningAddress := netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), 0) + listener, err := netConfig.Listen(t.Context(), "tcp", listeningAddress.String()) + require.NoError(t, err) + listeners[0] = listener + port = uint16(listener.Addr().(*net.TCPAddr).Port) //nolint:gosec,forcetypeassert + + // Check if that port is also available for IPv6 + listeningAddress = netip.AddrPortFrom( + netip.AddrFrom16([16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}), + port, + ) + listener, err = netConfig.Listen(t.Context(), "tcp", listeningAddress.String()) + if err == nil { + listeners[1] = listener + break // success, we found a port available for both IPv4 and IPv6 + } + var opErr *net.OpError + if errors.As(err, &opErr) { + var sysErr *os.SyscallError + if errors.As(opErr.Err, &sysErr) && errors.Is(sysErr.Err, syscall.EADDRINUSE) { + // Port found for IPv4 is already in use for IPv6, try another port + // We don't close the IPv4 listener yet to make sure we don't get the same port again from the OS. + listenersToClose = append(listenersToClose, listeners[0]) + continue + } + } + } + + for _, listener := range listenersToClose { + err := listener.Close() + assert.NoError(t, err) + } + + var ready sync.WaitGroup + ready.Add(len(listeners)) + for _, listener := range listeners { + t.Cleanup(func() { + err := listener.Close() + assert.NoError(t, err) + }) + + go func() { + ready.Done() + for { + connection, err := listener.Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) && t.Context().Err() != nil { + return + } + assert.NoError(t, err) + return + } + err = connection.Close() + assert.NoError(t, err) + } + }() + } + + ready.Wait() + + return port +} diff --git a/internal/restrictednet/https.go b/internal/restrictednet/https.go index 06c378ce..d61f78d1 100644 --- a/internal/restrictednet/https.go +++ b/internal/restrictednet/https.go @@ -42,7 +42,7 @@ func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, desti return nil, nil, fmt.Errorf("connecting source socket: %w", err) } - httpClient = newHTTPSClient(destinationTLSName, connection) + httpClient = newHTTPSClient(c.baseTransport, destinationTLSName, connection) cleanup = func() error { var errs []error httpClient.CloseIdleConnections() @@ -64,21 +64,21 @@ func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, desti return httpClient, cleanup, nil } -func newHTTPSClient(destinationTLSName string, connection net.Conn) *http.Client { - httpTransport := http.DefaultTransport.(*http.Transport).Clone() //nolint:forcetypeassert - httpTransport.Proxy = nil - httpTransport.MaxIdleConns = 1 - httpTransport.MaxIdleConnsPerHost = 1 - httpTransport.MaxConnsPerHost = 1 - httpTransport.IdleConnTimeout = time.Second - httpTransport.TLSClientConfig = &tls.Config{ +func newHTTPSClient(baseTransport *http.Transport, destinationTLSName string, connection net.Conn) *http.Client { + transport := baseTransport.Clone() + transport.Proxy = nil + transport.MaxIdleConns = 1 + transport.MaxIdleConnsPerHost = 1 + transport.MaxConnsPerHost = 1 + transport.IdleConnTimeout = time.Second + transport.TLSClientConfig = &tls.Config{ MinVersion: tls.VersionTLS12, ServerName: destinationTLSName, } _, destinationPort, _ := net.SplitHostPort(connection.RemoteAddr().String()) expectedAddress := net.JoinHostPort(destinationTLSName, destinationPort) - httpTransport.DialContext = func(_ context.Context, network, address string) (net.Conn, error) { + transport.DialContext = func(_ context.Context, network, address string) (net.Conn, error) { switch network { case "tcp", "tcp4", "tcp6": default: @@ -93,7 +93,7 @@ func newHTTPSClient(destinationTLSName string, connection net.Conn) *http.Client const timeout = 5 * time.Second return &http.Client{ Timeout: timeout, - Transport: httpTransport, + Transport: transport, } } diff --git a/internal/restrictednet/https_test.go b/internal/restrictednet/https_test.go index 7db81e60..02e36fd2 100644 --- a/internal/restrictednet/https_test.go +++ b/internal/restrictednet/https_test.go @@ -70,8 +70,13 @@ func Test_Client_OpenHTTPS(t *testing.T) { const ipv6Supported = false upstreamResolvers := []provider.Provider{provider.Google()} - client, err := New(firewall, "eth0", ipv6Supported, upstreamResolvers) - require.NoError(t, err) + settings := Settings{ + Firewall: firewall, + DefaultInterface: "eth0", + IPv6Supported: ptrTo(ipv6Supported), + UpstreamResolvers: upstreamResolvers, + } + client := New(settings) client.httpsPort = listeningPort httpClient, cleanup, err := client.OpenHTTPS(ctx, "api.example.com", netip.MustParseAddr("127.0.0.1")) diff --git a/internal/restrictednet/resolve.go b/internal/restrictednet/resolve.go index aa1c3e64..2feffeb5 100644 --- a/internal/restrictednet/resolve.go +++ b/internal/restrictednet/resolve.go @@ -79,11 +79,11 @@ func (c *Client) resolveOneQuestionType(ctx context.Context, responseMessage, err := c.doHQuery(ctx, queryWire, dohURL, dohServerIP) switch { case err != nil: - errs = append(errs, fmt.Errorf("querying DoH server %q at %s: %w", + errs = append(errs, fmt.Errorf("querying DoH server %q (ip %s): %w", dohServer.URL, dohServerIP, err)) continue case responseMessage.Rcode != dns.RcodeSuccess: - errs = append(errs, fmt.Errorf("querying DoH server %q at %s: DNS rcode %s", + errs = append(errs, fmt.Errorf("querying DoH server %q (ip %s): DNS rcode %s", dohServer.URL, dohServerIP, dns.RcodeToString[responseMessage.Rcode])) continue } diff --git a/internal/restrictednet/resolve_test.go b/internal/restrictednet/resolve_test.go index 51762778..972b5ff1 100644 --- a/internal/restrictednet/resolve_test.go +++ b/internal/restrictednet/resolve_test.go @@ -1,80 +1,391 @@ package restrictednet import ( + "bytes" + "context" + "errors" + "io" "net" + "net/http" "net/netip" + "net/url" + "sync/atomic" "testing" + "github.com/golang/mock/gomock" "github.com/miekg/dns" + "github.com/qdm12/dns/v2/pkg/provider" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -func Test_answersToNetipAddrs(t *testing.T) { +func Test_Client_ResolveName(t *testing.T) { t.Parallel() testCases := map[string]struct { - message *dns.Msg - expected []netip.Addr - errorIsNil bool + ipv6Supported bool + upstreamResolvers []provider.Provider + expectedAddresses []netip.Addr + errorContains string + expectedDestIPs []netip.Addr + responder func(host string, requestBody io.Reader) (*http.Response, error) }{ - "nil_message": { - message: nil, - expected: nil, - errorIsNil: true, + "success_single_server_ipv4": { + upstreamResolvers: []provider.Provider{{ + DoH: provider.DoHServer{ + URL: "https://resolver-1.local/dns-query", + IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, + }, + }}, + expectedAddresses: []netip.Addr{netip.MustParseAddr("1.1.1.1")}, + expectedDestIPs: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, + responder: func(_ string, requestBody io.Reader) (*http.Response, error) { + wire := responseWireForQuery(t, requestBody, &dns.A{ + Hdr: dns.RR_Header{Name: "github.com.", Rrtype: dns.TypeA, Class: dns.ClassINET}, + A: net.IP{1, 1, 1, 1}, + }) + return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(wire))}, nil + }, }, - "no_answers": { - message: &dns.Msg{}, - expected: []netip.Addr{}, - errorIsNil: true, - }, - "a_record": { - message: &dns.Msg{ - Answer: []dns.RR{ - &dns.A{ - Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET}, - A: net.IP{1, 1, 1, 1}, + "fallback_between_servers": { + upstreamResolvers: []provider.Provider{ + { + DoH: provider.DoHServer{ + URL: "https://resolver-1.local/dns-query", + IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, + }, + }, + { + DoH: provider.DoHServer{ + URL: "https://resolver-2.local/dns-query", + IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, }, }, }, - expected: []netip.Addr{netip.MustParseAddr("1.1.1.1")}, - errorIsNil: true, - }, - "aaaa_record": { - message: &dns.Msg{ - Answer: []dns.RR{ - &dns.AAAA{ - Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeAAAA, Class: dns.ClassINET}, - AAAA: net.IP{0x20, 0x01, 0x48, 0x60, 0x48, 0x60, 0, 0, 0, 0, 0, 0, 0, 0, 0x88, 0x88}, - }, - }, + expectedAddresses: []netip.Addr{netip.MustParseAddr("2.2.2.2")}, + expectedDestIPs: []netip.Addr{netip.MustParseAddr("127.0.0.1"), netip.MustParseAddr("127.0.0.1")}, + responder: func(host string, requestBody io.Reader) (*http.Response, error) { + if host == "resolver-1.local" || + len(host) > len("resolver-1.local:") && host[:len("resolver-1.local:")] == "resolver-1.local:" { + return &http.Response{ + StatusCode: http.StatusBadGateway, + Status: "502 Bad Gateway", + Body: io.NopCloser(bytes.NewReader([]byte("bad gateway"))), + }, nil + } + wire := responseWireForQuery(t, requestBody, &dns.A{ + Hdr: dns.RR_Header{Name: "github.com.", Rrtype: dns.TypeA, Class: dns.ClassINET}, + A: net.IP{2, 2, 2, 2}, + }) + return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(wire))}, nil }, - expected: []netip.Addr{netip.MustParseAddr("2001:4860:4860::8888")}, - errorIsNil: true, }, - "mixed_records": { - message: &dns.Msg{ - Answer: []dns.RR{ - &dns.A{ - Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET}, - A: net.IP{1, 1, 1, 1}, - }, - &dns.AAAA{ - Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeAAAA, Class: dns.ClassINET}, - AAAA: net.IP{0x20, 0x01, 0x48, 0x60, 0x48, 0x60, 0, 0, 0, 0, 0, 0, 0, 0, 0x88, 0x88}, - }, + "fallback_between_ips": { + upstreamResolvers: []provider.Provider{{ + DoH: provider.DoHServer{ + URL: "https://resolver.local/dns-query", + IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1"), netip.MustParseAddr("127.0.0.1")}, }, + }}, + expectedAddresses: []netip.Addr{netip.MustParseAddr("1.1.1.2")}, + expectedDestIPs: []netip.Addr{netip.MustParseAddr("127.0.0.1"), netip.MustParseAddr("127.0.0.1")}, + responder: func() func(host string, requestBody io.Reader) (*http.Response, error) { + var calls atomic.Int32 + return func(_ string, requestBody io.Reader) (*http.Response, error) { + if calls.Add(1) == 1 { // first call fails + return &http.Response{ + StatusCode: http.StatusNotFound, + Status: "502 Bad Gateway", + Body: io.NopCloser(bytes.NewReader([]byte("bad gateway"))), + }, nil + } + wire := responseWireForQuery(t, requestBody, &dns.A{ + Hdr: dns.RR_Header{Name: "github.com.", Rrtype: dns.TypeA, Class: dns.ClassINET}, + A: net.IP{1, 1, 1, 2}, + }) + return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(wire))}, nil + } + }(), //nolint:bodyclose + }, + "dns_rcode_error_servfail": { + upstreamResolvers: []provider.Provider{{ + DoH: provider.DoHServer{ + URL: "https://resolver.local/dns-query", + IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, + }, + }}, + errorContains: "SERVFAIL", + expectedDestIPs: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, + responder: func(_ string, requestBody io.Reader) (*http.Response, error) { + queryWire, err := io.ReadAll(requestBody) + require.NoError(t, err) + query := new(dns.Msg) + err = query.Unpack(queryWire) + require.NoError(t, err) + response := new(dns.Msg) + response.SetReply(query) + response.Rcode = dns.RcodeServerFailure + wire, err := response.Pack() + require.NoError(t, err) + return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(wire))}, nil + }, + }, + "no_answer": { + upstreamResolvers: []provider.Provider{{ + DoH: provider.DoHServer{ + URL: "https://resolver.local/dns-query", + IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, + }, + }}, + expectedAddresses: nil, + expectedDestIPs: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, + responder: func(_ string, requestBody io.Reader) (*http.Response, error) { + wire := responseWireForQuery(t, requestBody) + return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(wire))}, nil + }, + }, + "ipv6_preference": { + ipv6Supported: true, + upstreamResolvers: []provider.Provider{{ + DoH: provider.DoHServer{ + URL: "https://resolver.local/dns-query", + IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, + IPv6: []netip.Addr{netip.MustParseAddr("::1")}, + }, + }}, + expectedAddresses: []netip.Addr{netip.MustParseAddr("2001:4860:4860::8888")}, + expectedDestIPs: []netip.Addr{ + netip.MustParseAddr("::1"), + netip.MustParseAddr("::1"), + netip.MustParseAddr("127.0.0.1"), + }, + responder: func(_ string, requestBody io.Reader) (*http.Response, error) { + queryWire, err := io.ReadAll(requestBody) + require.NoError(t, err) + query := new(dns.Msg) + err = query.Unpack(queryWire) + require.NoError(t, err) + if len(query.Question) > 0 && query.Question[0].Qtype == dns.TypeA { + wire := responseWireForQuery(t, bytes.NewReader(queryWire)) + return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(wire))}, nil + } + wire := responseWireForQuery(t, bytes.NewReader(queryWire), &dns.AAAA{ + Hdr: dns.RR_Header{Name: "github.com.", Rrtype: dns.TypeAAAA, Class: dns.ClassINET}, + AAAA: net.IP{0x20, 0x01, 0x48, 0x60, 0x48, 0x60, 0, 0, 0, 0, 0, 0, 0, 0, 0x88, 0x88}, + }) + return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(wire))}, nil + }, + }, + "all_servers_fail": { + upstreamResolvers: []provider.Provider{ + {DoH: provider.DoHServer{ + URL: "https://resolver-1.local/dns-query", + IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, + }}, + {DoH: provider.DoHServer{ + URL: "https://resolver-2.local/dns-query", + IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, + }}, + }, + errorContains: "resolving host", + expectedDestIPs: []netip.Addr{netip.MustParseAddr("127.0.0.1"), netip.MustParseAddr("127.0.0.1")}, + responder: func(_ string, _ io.Reader) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusBadGateway, + Status: "502 Bad Gateway", + Body: io.NopCloser(bytes.NewReader([]byte("bad gateway"))), + }, nil }, - expected: []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("2001:4860:4860::8888")}, - errorIsNil: true, }, } for testName, testCase := range testCases { t.Run(testName, func(t *testing.T) { t.Parallel() + ctrl := gomock.NewController(t) + firewall := NewMockFirewall(ctrl) + port := startTCPAccepter(t) + + for _, destinationIP := range testCase.expectedDestIPs { + expectFirewallCallPair(firewall, t.Context(), destinationIP, port, nil, nil) + } + + resolvers := make([]provider.Provider, len(testCase.upstreamResolvers)) + copy(resolvers, testCase.upstreamResolvers) + for i := range resolvers { + resolvers[i].DoH.URL = urlToHostnamePort(resolvers[i].DoH.URL, port) + } + + settings := Settings{ + DefaultInterface: "eth0", + IPv6Supported: ptrTo(testCase.ipv6Supported), + Firewall: firewall, + UpstreamResolvers: resolvers, + BaseTransport: newInterceptTransport(testCase.responder), + } + client := New(settings) + client.httpsPort = port + + addresses, err := client.ResolveName(t.Context(), "github.com") + assert.Equal(t, testCase.expectedAddresses, addresses) + if testCase.errorContains != "" { + require.Error(t, err) + assert.ErrorContains(t, err, testCase.errorContains) + } else { + require.NoError(t, err) + } + }) + } +} + +func Test_Client_doHQuery(t *testing.T) { + t.Parallel() + + query := new(dns.Msg) + query.SetQuestion("example.com.", dns.TypeA) + queryWire, err := query.Pack() + require.NoError(t, err) + + responseWire := responseWireForQuery(t, bytes.NewReader(queryWire), &dns.A{ + Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET}, + A: net.IP{1, 1, 1, 1}, + }) + + testCases := map[string]struct { + response *http.Response + addFirewallRuleErr error + removeFirewallRuleErr error + errorContains string + expectedIPs []netip.Addr + }{ + "success": { + response: &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(responseWire))}, + expectedIPs: []netip.Addr{netip.MustParseAddr("1.1.1.1")}, + }, + "http_status_not_ok": { + response: &http.Response{ + StatusCode: http.StatusBadGateway, + Status: "502 Bad Gateway", + Body: io.NopCloser(bytes.NewReader([]byte("bad gateway"))), + }, + errorContains: "response status code is 502 Bad Gateway", + }, + "malformed_dns_response": { + response: &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString("not-dns")), + }, + errorContains: "parsing DoH response", + }, + "cleanup_error": { + response: &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(responseWire))}, + removeFirewallRuleErr: errors.New("cleanup failed"), + errorContains: "cleaning up https connection: removing output traffic rule: cleanup failed", + }, + } + + for name, testCase := range testCases { + t.Run(name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + firewall := NewMockFirewall(ctrl) + port := startTCPAccepter(t) + + expectFirewallCallPair( + firewall, + context.Background(), + netip.MustParseAddr("127.0.0.1"), + port, + testCase.addFirewallRuleErr, + testCase.removeFirewallRuleErr, + ) + + settings := Settings{ + DefaultInterface: "eth0", + IPv6Supported: ptrTo(false), + Firewall: firewall, + UpstreamResolvers: []provider.Provider{provider.Google()}, + BaseTransport: newInterceptTransport(func(_ string, _ io.Reader) (*http.Response, error) { + return testCase.response, nil + }), + } + client := New(settings) + client.httpsPort = port + + dohURL, err := url.Parse(urlToHostnamePort("https://resolver.local/dns-query", port)) + require.NoError(t, err) + + message, err := client.doHQuery( + context.Background(), + queryWire, + dohURL, + netip.MustParseAddr("127.0.0.1"), + ) + + if testCase.errorContains != "" { + require.Error(t, err) + assert.ErrorContains(t, err, testCase.errorContains) + return + } + + require.NoError(t, err) + addresses := answersToNetipAddrs(message) + assert.Equal(t, testCase.expectedIPs, addresses) + }) + } +} + +func Test_answersToNetipAddrs(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + message *dns.Msg + expected []netip.Addr + }{ + "nil_message": {}, + "no_answers": { + message: &dns.Msg{}, + expected: []netip.Addr{}, + }, + "a_record": { + message: &dns.Msg{Answer: []dns.RR{ + &dns.A{ + Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET}, + A: net.IP{1, 1, 1, 1}, + }, + }}, + expected: []netip.Addr{netip.MustParseAddr("1.1.1.1")}, + }, + "aaaa_record": { + message: &dns.Msg{Answer: []dns.RR{ + &dns.AAAA{ + Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeAAAA, Class: dns.ClassINET}, + AAAA: net.IP{0x20, 0x01, 0x48, 0x60, 0x48, 0x60, 0, 0, 0, 0, 0, 0, 0, 0, 0x88, 0x88}, + }, + }}, + expected: []netip.Addr{netip.MustParseAddr("2001:4860:4860::8888")}, + }, + "mixed_records": { + message: &dns.Msg{Answer: []dns.RR{ + &dns.A{ + Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET}, + A: net.IP{1, 1, 1, 1}, + }, + &dns.AAAA{ + Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeAAAA, Class: dns.ClassINET}, + AAAA: net.IP{0x20, 0x01, 0x48, 0x60, 0x48, 0x60, 0, 0, 0, 0, 0, 0, 0, 0, 0x88, 0x88}, + }, + }}, + expected: []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("2001:4860:4860::8888")}, + }, + } + + for testName, testCase := range testCases { + t.Run(testName, func(t *testing.T) { + t.Parallel() addresses := answersToNetipAddrs(testCase.message) - assert.Equal(t, testCase.expected, addresses) }) } diff --git a/internal/restrictednet/settings.go b/internal/restrictednet/settings.go new file mode 100644 index 00000000..4b943b52 --- /dev/null +++ b/internal/restrictednet/settings.go @@ -0,0 +1,36 @@ +package restrictednet + +import ( + "errors" + "net/http" + + "github.com/qdm12/dns/v2/pkg/provider" +) + +type Settings struct { + DefaultInterface string + IPv6Supported *bool + Firewall Firewall + UpstreamResolvers []provider.Provider + BaseTransport *http.Transport +} + +func (s *Settings) setDefaults() { + if s.BaseTransport == nil { + s.BaseTransport = http.DefaultTransport.(*http.Transport) //nolint:forcetypeassert + } +} + +func (s *Settings) validate() error { + switch { + case s.DefaultInterface == "": + return errors.New("default interface is not set") + case s.IPv6Supported == nil: + return errors.New("IPv6 support field is not set") + case s.Firewall == nil: + return errors.New("firewall is not set") + case len(s.UpstreamResolvers) == 0: + return errors.New("no upstream resolvers provided") + } + return nil +}