mirror of
https://github.com/qdm12/gluetun.git
synced 2026-06-25 13:27:31 +02:00
initial
This commit is contained in:
@@ -50,6 +50,7 @@ Guidance for coding agents working in this repository.
|
||||
- Prefer splitting a code line only when it triggers the `lll` linter, do not split a command or arguments list for each element
|
||||
- Use `netip` types instead of `net` types whenever possible
|
||||
- Use constants instead of variables whenever possible, especially function-local inline constants.
|
||||
- Prefer using pure functions over methods when possible. Especially if the method does not need any fields from the receiving struct, it should be a pure function.
|
||||
- Do not use `time.Sleep`, prefer using a `time.Timer` with a `select` statement also listening on a context cancelation
|
||||
- `panic`:
|
||||
- should only be used when a programming error is encountered and you should NOT return errors for programming errors (such as passing nil objects)
|
||||
@@ -127,6 +128,7 @@ The Go formatter used is gofumpt.
|
||||
### Errors
|
||||
|
||||
- Always prefer wrapping errors with some context with `fmt.Errorf("doing this: %w", err)`
|
||||
- Use `errors.New("error message")` when creating a 'bottom' constant string error without additional context, instead of `fmt.Errorf`
|
||||
- In rare cases, you can just use `return err` notably:
|
||||
- If the function is called **recursively**, since we don't wrap the wrapping multiple times for each recursion
|
||||
- If the current function only statement is the call to another function, for example:
|
||||
@@ -179,6 +181,8 @@ The Go formatter used is gofumpt.
|
||||
|
||||
- Do not use `http.DefaultClient`, use a custom `*http.Client` with a fixed timeout and share with dependency injections.
|
||||
- Do not check for injected dependencies being `nil`, prefer to just panic on a nil pointer. By default it's fine to panic if a developer injects a dependency `nil`. `nil` does not mean use a default.
|
||||
- Prefer using a `switch { case ...}` statement over multiple consecutive `if` statements to have shorter code.
|
||||
- Prefer using `[...]T` instead of `[]T` when the length is fixed and known at compile time, to avoid unnecessary allocations.
|
||||
|
||||
## Validation checklist
|
||||
|
||||
|
||||
@@ -28,6 +28,8 @@ type firewallImpl interface { //nolint:interfacebloat
|
||||
AcceptIpv6MulticastOutput(ctx context.Context, intf string) error
|
||||
AcceptOutput(ctx context.Context, protocol, intf string,
|
||||
ip netip.Addr, port uint16, remove bool) error
|
||||
AcceptOutputFromIPPortToIPPort(ctx context.Context, protocol, intf string,
|
||||
source, destination netip.AddrPort, remove bool) error
|
||||
AcceptOutputFromIPToSubnet(ctx context.Context, intf string, assignedIP netip.Addr,
|
||||
subnet netip.Prefix, remove bool) error
|
||||
AcceptOutputThroughInterface(ctx context.Context, intf string, remove bool) error
|
||||
|
||||
@@ -177,6 +177,29 @@ func (c *Config) AcceptOutput(ctx context.Context,
|
||||
return c.runIP6tablesInstruction(ctx, instruction)
|
||||
}
|
||||
|
||||
func (c *Config) AcceptOutputFromIPPortToIPPort(ctx context.Context,
|
||||
protocol, intf string, source, destination netip.AddrPort, remove bool,
|
||||
) error {
|
||||
if source.Addr().BitLen() != destination.Addr().BitLen() {
|
||||
return fmt.Errorf("source and destination address families do not match")
|
||||
}
|
||||
|
||||
interfaceFlag := "-o " + intf
|
||||
if intf == "*" { // all interfaces
|
||||
interfaceFlag = ""
|
||||
}
|
||||
|
||||
instruction := fmt.Sprintf("%s OUTPUT -s %s --sport %d -d %s %s -p %s -m %s --dport %d -j ACCEPT",
|
||||
appendOrDelete(remove), source.Addr(), source.Port(), destination.Addr(),
|
||||
interfaceFlag, protocol, protocol, destination.Port())
|
||||
if destination.Addr().Is4() {
|
||||
return c.runIptablesInstruction(ctx, instruction)
|
||||
} else if c.ip6Tables == "" {
|
||||
return fmt.Errorf("accept output from %s to %s: %s", source, destination, needIP6Tables)
|
||||
}
|
||||
return c.runIP6tablesInstruction(ctx, instruction)
|
||||
}
|
||||
|
||||
// AcceptOutputFromIPToSubnet accepts outgoing traffic from sourceIP to destinationSubnet
|
||||
// on the interface intf. If intf is empty, it is set to "*" which means all interfaces.
|
||||
// If remove is true, the rule is removed instead of added.
|
||||
|
||||
@@ -25,3 +25,10 @@ func (c *Config) AcceptOutput(ctx context.Context, protocol, intf string,
|
||||
) error {
|
||||
return c.impl.AcceptOutput(ctx, protocol, intf, ip, port, remove)
|
||||
}
|
||||
|
||||
func (c *Config) AcceptOutputFromIPPortToIPPort(ctx context.Context,
|
||||
protocol, intf string, source, destination netip.AddrPort, remove bool,
|
||||
) error {
|
||||
return c.impl.AcceptOutputFromIPPortToIPPort(ctx, protocol, intf,
|
||||
source, destination, remove)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
package restrictednet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/qdm12/dns/v2/pkg/provider"
|
||||
)
|
||||
|
||||
// Client is a client for making restricted network requests,
|
||||
// such as opening temporary firewall rules for HTTPS connections.
|
||||
// It is not meant to be high performance, although it can be used for
|
||||
// multiple requests and concurrently.
|
||||
type Client struct {
|
||||
ipv6Supported bool
|
||||
firewall Firewall
|
||||
outboundInterface string
|
||||
dohServers []provider.DoHServer
|
||||
}
|
||||
|
||||
func New(firewall Firewall, defaultInterface string, ipv6Supported bool,
|
||||
upstreamResolvers []provider.Provider,
|
||||
) (*Client, error) {
|
||||
dohServers := make([]provider.DoHServer, len(upstreamResolvers))
|
||||
for i, upstreamResolver := range upstreamResolvers {
|
||||
dohServers[i] = upstreamResolver.DoH
|
||||
}
|
||||
|
||||
return &Client{
|
||||
firewall: firewall,
|
||||
outboundInterface: defaultInterface,
|
||||
ipv6Supported: ipv6Supported,
|
||||
dohServers: dohServers,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Client) OpenHTTPSByDomain(ctx context.Context, domain string) (
|
||||
httpClient *http.Client, cleanup func() error, err error,
|
||||
) {
|
||||
resolvedIPs, err := c.ResolveName(ctx, domain)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("resolving name: %w", err)
|
||||
} else if len(resolvedIPs) == 0 {
|
||||
return nil, nil, fmt.Errorf("no IP address found for name %q", domain)
|
||||
}
|
||||
|
||||
selectedIP := resolvedIPs[0]
|
||||
|
||||
httpClient, cleanup, err = c.OpenHTTPS(domain, selectedIP)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("opening HTTPS: %w", err)
|
||||
}
|
||||
|
||||
return httpClient, cleanup, nil
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
package restrictednet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/qdm12/dns/v2/pkg/provider"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type listenAddrPortMatcher struct {
|
||||
expected netip.AddrPort
|
||||
}
|
||||
|
||||
func (m listenAddrPortMatcher) Matches(x any) bool {
|
||||
ip, ok := x.(netip.AddrPort)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if m.expected.IsValid() {
|
||||
return ip == m.expected
|
||||
}
|
||||
return ip.IsValid() && ip.Addr().IsValid() && ip.Port() > 0
|
||||
}
|
||||
|
||||
func (m listenAddrPortMatcher) String() string {
|
||||
if m.expected.IsValid() {
|
||||
return "is the same as " + m.expected.String()
|
||||
}
|
||||
return "is a valid netip.AddrPort with a valid IP and non-zero port"
|
||||
}
|
||||
|
||||
func Test_Client_OpenHTTPS(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
firewall := NewMockFirewall(ctrl)
|
||||
|
||||
destination := netip.MustParseAddrPort("1.2.3.4:443")
|
||||
backgroundContext := context.Background()
|
||||
sourceMatcher := listenAddrPortMatcher{}
|
||||
firewall.EXPECT().AcceptOutputFromIPPortToIPPort(
|
||||
backgroundContext, "tcp", "eth0", sourceMatcher, destination, false,
|
||||
).DoAndReturn(func(_ context.Context,
|
||||
_, _ string, source, _ netip.AddrPort, _ bool,
|
||||
) error {
|
||||
sourceMatcher.expected = source
|
||||
return nil
|
||||
})
|
||||
firewall.EXPECT().AcceptOutputFromIPPortToIPPort(
|
||||
backgroundContext, "tcp", "eth0", sourceMatcher, destination, true,
|
||||
)
|
||||
|
||||
const ipv6Supported = false
|
||||
upstreamResolvers := []provider.Provider{provider.Google()}
|
||||
client, err := New(firewall, "eth0", ipv6Supported, upstreamResolvers)
|
||||
require.NoError(t, err)
|
||||
|
||||
httpClient, cleanup, err := client.OpenHTTPS("api.example.com", netip.MustParseAddr("1.2.3.4"))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, httpClient)
|
||||
require.NotNil(t, cleanup)
|
||||
|
||||
err = cleanup()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
@@ -0,0 +1,115 @@
|
||||
package restrictednet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"time"
|
||||
)
|
||||
|
||||
// OpenHTTPS opens temporary restrictive firewall output for one HTTPS destination.
|
||||
// The returned cleanup function must be called to remove the temporary firewall rule and close connections.
|
||||
func (c *Client) OpenHTTPS(destinationTLSName string, destinationIP netip.Addr,
|
||||
) (httpClient *http.Client, cleanup func() error, err error) {
|
||||
listener, sourceAddrPort, err := bindSourcePort(destinationIP)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("binding source port: %w", err)
|
||||
}
|
||||
|
||||
const httpsPort = 443
|
||||
destinationAddrPort := netip.AddrPortFrom(destinationIP, httpsPort)
|
||||
|
||||
const remove = false
|
||||
ctx := context.Background() // it's a quick firewall change, worth not passing a context
|
||||
err = c.firewall.AcceptOutputFromIPPortToIPPort(ctx, "tcp", c.outboundInterface,
|
||||
sourceAddrPort, destinationAddrPort, remove)
|
||||
if err != nil {
|
||||
_ = listener.Close()
|
||||
return nil, nil, fmt.Errorf("allowing output traffic through firewall: %w", err)
|
||||
}
|
||||
|
||||
httpClient = newHTTPSClient(destinationTLSName, destinationIP, sourceAddrPort)
|
||||
cleanup = func() error {
|
||||
var errs []error
|
||||
httpClient.CloseIdleConnections()
|
||||
const remove = true
|
||||
err := c.firewall.AcceptOutputFromIPPortToIPPort(ctx, "tcp", c.outboundInterface,
|
||||
sourceAddrPort, destinationAddrPort, remove)
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("removing output traffic rule: %w", err))
|
||||
}
|
||||
err = listener.Close()
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("closing listener: %w", err))
|
||||
}
|
||||
if len(errs) > 0 {
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return httpClient, cleanup, nil
|
||||
}
|
||||
|
||||
func newHTTPSClient(destinationTLSName string,
|
||||
destinationIP netip.Addr, sourceAddress netip.AddrPort,
|
||||
) *http.Client {
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone() //nolint:forcetypeassert
|
||||
httpTransport.Proxy = nil
|
||||
httpTransport.MaxIdleConns = 1
|
||||
httpTransport.MaxIdleConnsPerHost = 1
|
||||
httpTransport.IdleConnTimeout = time.Second
|
||||
httpTransport.TLSClientConfig = &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
ServerName: destinationTLSName,
|
||||
}
|
||||
httpTransport.DialContext = newBoundDialContext(destinationIP, sourceAddress)
|
||||
|
||||
const timeout = 5 * time.Second
|
||||
return &http.Client{
|
||||
Timeout: timeout,
|
||||
Transport: httpTransport,
|
||||
}
|
||||
}
|
||||
|
||||
func newBoundDialContext(destinationAddress netip.Addr,
|
||||
sourceAddress netip.AddrPort,
|
||||
) func(ctx context.Context, network, _ string) (net.Conn, error) {
|
||||
const httpsPort = 443
|
||||
destinationAddrPort := netip.AddrPortFrom(destinationAddress, httpsPort).String()
|
||||
return func(ctx context.Context, network, _ string) (net.Conn, error) {
|
||||
const timeout = 2 * time.Second
|
||||
dialer := &net.Dialer{Timeout: timeout}
|
||||
dialer.LocalAddr = net.TCPAddrFromAddrPort(sourceAddress)
|
||||
connection, err := dialer.DialContext(ctx, network, destinationAddrPort)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s dialing %s: %w", network, destinationAddrPort, err)
|
||||
}
|
||||
return connection, nil
|
||||
}
|
||||
}
|
||||
|
||||
func bindSourcePort(destinationIP netip.Addr) (
|
||||
listener net.Listener, sourceAddr netip.AddrPort, err error,
|
||||
) {
|
||||
var bindAddr netip.Addr
|
||||
if destinationIP.Is4() {
|
||||
bindAddr = netip.AddrFrom4([4]byte{})
|
||||
} else {
|
||||
bindAddr = netip.AddrFrom16([16]byte{})
|
||||
}
|
||||
|
||||
listener, err = net.ListenTCP("tcp", net.TCPAddrFromAddrPort(
|
||||
netip.AddrPortFrom(bindAddr, 0)))
|
||||
if err != nil {
|
||||
return nil, netip.AddrPort{}, fmt.Errorf("binding TCP port: %w", err)
|
||||
}
|
||||
|
||||
tcpAddr := listener.Addr().(*net.TCPAddr) //nolint:forcetypeassert
|
||||
sourceAddr = tcpAddr.AddrPort()
|
||||
|
||||
return listener, sourceAddr, nil
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
package restrictednet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
type Firewall interface {
|
||||
AcceptOutputFromIPPortToIPPort(ctx context.Context,
|
||||
protocol, intf string, source, destination netip.AddrPort, remove bool,
|
||||
) error
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
package restrictednet
|
||||
|
||||
//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . Firewall
|
||||
@@ -0,0 +1,50 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/qdm12/gluetun/internal/restrictednet (interfaces: Firewall)
|
||||
|
||||
// Package restrictednet is a generated GoMock package.
|
||||
package restrictednet
|
||||
|
||||
import (
|
||||
context "context"
|
||||
netip "net/netip"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// MockFirewall is a mock of Firewall interface.
|
||||
type MockFirewall struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockFirewallMockRecorder
|
||||
}
|
||||
|
||||
// MockFirewallMockRecorder is the mock recorder for MockFirewall.
|
||||
type MockFirewallMockRecorder struct {
|
||||
mock *MockFirewall
|
||||
}
|
||||
|
||||
// NewMockFirewall creates a new mock instance.
|
||||
func NewMockFirewall(ctrl *gomock.Controller) *MockFirewall {
|
||||
mock := &MockFirewall{ctrl: ctrl}
|
||||
mock.recorder = &MockFirewallMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockFirewall) EXPECT() *MockFirewallMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// AcceptOutputFromIPPortToIPPort mocks base method.
|
||||
func (m *MockFirewall) AcceptOutputFromIPPortToIPPort(arg0 context.Context, arg1, arg2 string, arg3, arg4 netip.AddrPort, arg5 bool) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AcceptOutputFromIPPortToIPPort", arg0, arg1, arg2, arg3, arg4, arg5)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// AcceptOutputFromIPPortToIPPort indicates an expected call of AcceptOutputFromIPPortToIPPort.
|
||||
func (mr *MockFirewallMockRecorder) AcceptOutputFromIPPortToIPPort(arg0, arg1, arg2, arg3, arg4, arg5 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptOutputFromIPPortToIPPort", reflect.TypeOf((*MockFirewall)(nil).AcceptOutputFromIPPortToIPPort), arg0, arg1, arg2, arg3, arg4, arg5)
|
||||
}
|
||||
@@ -0,0 +1,177 @@
|
||||
package restrictednet
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// ResolveName resolves the given host name to IP addresses using DoH servers,
|
||||
// while opening temporary restrictive firewall rules for HTTPS traffic to DoH servers.
|
||||
// The host must be a single well-formed domain name, without port or path.
|
||||
func (c *Client) ResolveName(ctx context.Context, host string) (
|
||||
resolvedAddresses []netip.Addr, err error,
|
||||
) {
|
||||
questionTypes := make([]uint16, 0, 2)
|
||||
if c.ipv6Supported {
|
||||
questionTypes = append(questionTypes, dns.TypeAAAA)
|
||||
}
|
||||
questionTypes = append(questionTypes, dns.TypeA)
|
||||
|
||||
var addresses []netip.Addr
|
||||
errs := make([]error, 0, len(questionTypes))
|
||||
for _, questionType := range questionTypes {
|
||||
answerAddresses, err := c.resolveOneQuestionType(ctx, host, questionType)
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
continue
|
||||
}
|
||||
addresses = append(addresses, answerAddresses...)
|
||||
}
|
||||
|
||||
switch {
|
||||
case len(addresses) > 0:
|
||||
return addresses, nil
|
||||
case len(errs) == 0:
|
||||
return nil, nil // no address found
|
||||
default: // errors
|
||||
return nil, fmt.Errorf("resolving host %q: %w", host, errors.Join(errs...))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) resolveOneQuestionType(ctx context.Context,
|
||||
host string, questionType uint16,
|
||||
) (addresses []netip.Addr, err error) {
|
||||
queryMessage := &dns.Msg{}
|
||||
queryMessage.SetQuestion(dns.Fqdn(host), questionType)
|
||||
queryWire, err := queryMessage.Pack()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("packing DNS query: %w", err)
|
||||
}
|
||||
|
||||
// Try every DoH server and every of each of their IP until we get a non-empty
|
||||
// successful response.
|
||||
errs := make([]error, 0)
|
||||
for _, dohServer := range c.dohServers {
|
||||
dohURL, err := url.Parse(dohServer.URL)
|
||||
if err != nil {
|
||||
errs = append(errs,
|
||||
fmt.Errorf("parsing DoH server URL %s: %w", dohServer.URL, err))
|
||||
continue
|
||||
}
|
||||
|
||||
dohServerIPs := make([]netip.Addr, 0, len(dohServer.IPv4)+len(dohServer.IPv6))
|
||||
if c.ipv6Supported {
|
||||
// Prefer IPv6 addresses if IPv6 is supported
|
||||
dohServerIPs = append(dohServerIPs, dohServer.IPv6...)
|
||||
}
|
||||
dohServerIPs = append(dohServerIPs, dohServer.IPv4...)
|
||||
|
||||
for _, dohServerIP := range dohServerIPs {
|
||||
responseMessage, err := c.doHQuery(ctx, queryWire, dohURL, dohServerIP)
|
||||
switch {
|
||||
case err != nil:
|
||||
errs = append(errs, fmt.Errorf("querying DoH server %q at %s: %w",
|
||||
dohServer.URL, dohServerIP, err))
|
||||
continue
|
||||
case responseMessage.Rcode != dns.RcodeSuccess:
|
||||
errs = append(errs, fmt.Errorf("querying DoH server %q at %s: DNS rcode %s",
|
||||
dohServer.URL, dohServerIP, dns.RcodeToString[responseMessage.Rcode]))
|
||||
continue
|
||||
}
|
||||
addresses := answersToNetipAddrs(responseMessage)
|
||||
if len(addresses) == 0 {
|
||||
continue
|
||||
}
|
||||
return addresses, nil
|
||||
}
|
||||
}
|
||||
|
||||
if len(errs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("resolving %s %s: %w",
|
||||
dns.TypeToString[questionType], host, errors.Join(errs...))
|
||||
}
|
||||
|
||||
func (c *Client) doHQuery(ctx context.Context, queryWire []byte,
|
||||
dohURL *url.URL, dohServerIP netip.Addr,
|
||||
) (responseMessage *dns.Msg, err error) {
|
||||
httpClient, close, err := c.OpenHTTPS(dohURL.Hostname(), dohServerIP)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("opening https connection: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
closeErr := close()
|
||||
if err == nil && closeErr != nil {
|
||||
err = fmt.Errorf("cleaning up https connection: %w", closeErr)
|
||||
}
|
||||
}()
|
||||
|
||||
requestBody := bytes.NewReader(queryWire)
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodPost, dohURL.String(), requestBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating request: %w", err)
|
||||
}
|
||||
request.Header.Set("Content-Type", "application/dns-message")
|
||||
request.Header.Set("Accept", "application/dns-message")
|
||||
|
||||
response, err := httpClient.Do(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
responseData, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
_ = response.Body.Close()
|
||||
return nil, fmt.Errorf("reading response body: %w", err)
|
||||
}
|
||||
|
||||
err = response.Body.Close()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("closing response body: %w", err)
|
||||
}
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("response status code is %s, data: %s",
|
||||
response.Status, responseData)
|
||||
}
|
||||
|
||||
responseMessage = new(dns.Msg)
|
||||
err = responseMessage.Unpack(responseData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing DoH response: %w", err)
|
||||
}
|
||||
|
||||
return responseMessage, nil
|
||||
}
|
||||
|
||||
func answersToNetipAddrs(message *dns.Msg) (addresses []netip.Addr) {
|
||||
if message == nil {
|
||||
return nil
|
||||
}
|
||||
addresses = make([]netip.Addr, 0, len(message.Answer))
|
||||
for _, answer := range message.Answer {
|
||||
switch record := answer.(type) {
|
||||
case *dns.A:
|
||||
address, ok := netip.AddrFromSlice(record.A)
|
||||
if ok {
|
||||
addresses = append(addresses, address.Unmap())
|
||||
}
|
||||
case *dns.AAAA:
|
||||
address, ok := netip.AddrFromSlice(record.AAAA)
|
||||
if ok {
|
||||
addresses = append(addresses, address)
|
||||
}
|
||||
}
|
||||
}
|
||||
return addresses
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
package restrictednet
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_answersToNetipAddrs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
message *dns.Msg
|
||||
expected []netip.Addr
|
||||
errorIsNil bool
|
||||
}{
|
||||
"nil_message": {
|
||||
message: nil,
|
||||
expected: nil,
|
||||
errorIsNil: true,
|
||||
},
|
||||
"no_answers": {
|
||||
message: &dns.Msg{},
|
||||
expected: []netip.Addr{},
|
||||
errorIsNil: true,
|
||||
},
|
||||
"a_record": {
|
||||
message: &dns.Msg{
|
||||
Answer: []dns.RR{
|
||||
&dns.A{
|
||||
Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET},
|
||||
A: net.IP{1, 1, 1, 1},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: []netip.Addr{netip.MustParseAddr("1.1.1.1")},
|
||||
errorIsNil: true,
|
||||
},
|
||||
"aaaa_record": {
|
||||
message: &dns.Msg{
|
||||
Answer: []dns.RR{
|
||||
&dns.AAAA{
|
||||
Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeAAAA, Class: dns.ClassINET},
|
||||
AAAA: net.IP{0x20, 0x01, 0x48, 0x60, 0x48, 0x60, 0, 0, 0, 0, 0, 0, 0, 0, 0x88, 0x88},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: []netip.Addr{netip.MustParseAddr("2001:4860:4860::8888")},
|
||||
errorIsNil: true,
|
||||
},
|
||||
"mixed_records": {
|
||||
message: &dns.Msg{
|
||||
Answer: []dns.RR{
|
||||
&dns.A{
|
||||
Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET},
|
||||
A: net.IP{1, 1, 1, 1},
|
||||
},
|
||||
&dns.AAAA{
|
||||
Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeAAAA, Class: dns.ClassINET},
|
||||
AAAA: net.IP{0x20, 0x01, 0x48, 0x60, 0x48, 0x60, 0, 0, 0, 0, 0, 0, 0, 0, 0x88, 0x88},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("2001:4860:4860::8888")},
|
||||
errorIsNil: true,
|
||||
},
|
||||
}
|
||||
|
||||
for testName, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(testName, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
addresses := answersToNetipAddrs(testCase.message)
|
||||
|
||||
assert.Equal(t, testCase.expected, addresses)
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user