This commit is contained in:
Quentin McGaw
2026-06-05 03:56:25 +00:00
parent ff6e45fae0
commit aa781c6cc5
12 changed files with 599 additions and 0 deletions
+56
View File
@@ -0,0 +1,56 @@
package restrictednet
import (
"context"
"fmt"
"net/http"
"github.com/qdm12/dns/v2/pkg/provider"
)
// Client is a client for making restricted network requests,
// such as opening temporary firewall rules for HTTPS connections.
// It is not meant to be high performance, although it can be used for
// multiple requests and concurrently.
type Client struct {
ipv6Supported bool
firewall Firewall
outboundInterface string
dohServers []provider.DoHServer
}
func New(firewall Firewall, defaultInterface string, ipv6Supported bool,
upstreamResolvers []provider.Provider,
) (*Client, error) {
dohServers := make([]provider.DoHServer, len(upstreamResolvers))
for i, upstreamResolver := range upstreamResolvers {
dohServers[i] = upstreamResolver.DoH
}
return &Client{
firewall: firewall,
outboundInterface: defaultInterface,
ipv6Supported: ipv6Supported,
dohServers: dohServers,
}, nil
}
func (c *Client) OpenHTTPSByDomain(ctx context.Context, domain string) (
httpClient *http.Client, cleanup func() error, err error,
) {
resolvedIPs, err := c.ResolveName(ctx, domain)
if err != nil {
return nil, nil, fmt.Errorf("resolving name: %w", err)
} else if len(resolvedIPs) == 0 {
return nil, nil, fmt.Errorf("no IP address found for name %q", domain)
}
selectedIP := resolvedIPs[0]
httpClient, cleanup, err = c.OpenHTTPS(domain, selectedIP)
if err != nil {
return nil, nil, fmt.Errorf("opening HTTPS: %w", err)
}
return httpClient, cleanup, nil
}
+68
View File
@@ -0,0 +1,68 @@
package restrictednet
import (
"context"
"net/netip"
"testing"
"github.com/golang/mock/gomock"
"github.com/qdm12/dns/v2/pkg/provider"
"github.com/stretchr/testify/require"
)
type listenAddrPortMatcher struct {
expected netip.AddrPort
}
func (m listenAddrPortMatcher) Matches(x any) bool {
ip, ok := x.(netip.AddrPort)
if !ok {
return false
}
if m.expected.IsValid() {
return ip == m.expected
}
return ip.IsValid() && ip.Addr().IsValid() && ip.Port() > 0
}
func (m listenAddrPortMatcher) String() string {
if m.expected.IsValid() {
return "is the same as " + m.expected.String()
}
return "is a valid netip.AddrPort with a valid IP and non-zero port"
}
func Test_Client_OpenHTTPS(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
firewall := NewMockFirewall(ctrl)
destination := netip.MustParseAddrPort("1.2.3.4:443")
backgroundContext := context.Background()
sourceMatcher := listenAddrPortMatcher{}
firewall.EXPECT().AcceptOutputFromIPPortToIPPort(
backgroundContext, "tcp", "eth0", sourceMatcher, destination, false,
).DoAndReturn(func(_ context.Context,
_, _ string, source, _ netip.AddrPort, _ bool,
) error {
sourceMatcher.expected = source
return nil
})
firewall.EXPECT().AcceptOutputFromIPPortToIPPort(
backgroundContext, "tcp", "eth0", sourceMatcher, destination, true,
)
const ipv6Supported = false
upstreamResolvers := []provider.Provider{provider.Google()}
client, err := New(firewall, "eth0", ipv6Supported, upstreamResolvers)
require.NoError(t, err)
httpClient, cleanup, err := client.OpenHTTPS("api.example.com", netip.MustParseAddr("1.2.3.4"))
require.NoError(t, err)
require.NotNil(t, httpClient)
require.NotNil(t, cleanup)
err = cleanup()
require.NoError(t, err)
}
+115
View File
@@ -0,0 +1,115 @@
package restrictednet
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"net/netip"
"time"
)
// OpenHTTPS opens temporary restrictive firewall output for one HTTPS destination.
// The returned cleanup function must be called to remove the temporary firewall rule and close connections.
func (c *Client) OpenHTTPS(destinationTLSName string, destinationIP netip.Addr,
) (httpClient *http.Client, cleanup func() error, err error) {
listener, sourceAddrPort, err := bindSourcePort(destinationIP)
if err != nil {
return nil, nil, fmt.Errorf("binding source port: %w", err)
}
const httpsPort = 443
destinationAddrPort := netip.AddrPortFrom(destinationIP, httpsPort)
const remove = false
ctx := context.Background() // it's a quick firewall change, worth not passing a context
err = c.firewall.AcceptOutputFromIPPortToIPPort(ctx, "tcp", c.outboundInterface,
sourceAddrPort, destinationAddrPort, remove)
if err != nil {
_ = listener.Close()
return nil, nil, fmt.Errorf("allowing output traffic through firewall: %w", err)
}
httpClient = newHTTPSClient(destinationTLSName, destinationIP, sourceAddrPort)
cleanup = func() error {
var errs []error
httpClient.CloseIdleConnections()
const remove = true
err := c.firewall.AcceptOutputFromIPPortToIPPort(ctx, "tcp", c.outboundInterface,
sourceAddrPort, destinationAddrPort, remove)
if err != nil {
errs = append(errs, fmt.Errorf("removing output traffic rule: %w", err))
}
err = listener.Close()
if err != nil {
errs = append(errs, fmt.Errorf("closing listener: %w", err))
}
if len(errs) > 0 {
return errors.Join(errs...)
}
return nil
}
return httpClient, cleanup, nil
}
func newHTTPSClient(destinationTLSName string,
destinationIP netip.Addr, sourceAddress netip.AddrPort,
) *http.Client {
httpTransport := http.DefaultTransport.(*http.Transport).Clone() //nolint:forcetypeassert
httpTransport.Proxy = nil
httpTransport.MaxIdleConns = 1
httpTransport.MaxIdleConnsPerHost = 1
httpTransport.IdleConnTimeout = time.Second
httpTransport.TLSClientConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
ServerName: destinationTLSName,
}
httpTransport.DialContext = newBoundDialContext(destinationIP, sourceAddress)
const timeout = 5 * time.Second
return &http.Client{
Timeout: timeout,
Transport: httpTransport,
}
}
func newBoundDialContext(destinationAddress netip.Addr,
sourceAddress netip.AddrPort,
) func(ctx context.Context, network, _ string) (net.Conn, error) {
const httpsPort = 443
destinationAddrPort := netip.AddrPortFrom(destinationAddress, httpsPort).String()
return func(ctx context.Context, network, _ string) (net.Conn, error) {
const timeout = 2 * time.Second
dialer := &net.Dialer{Timeout: timeout}
dialer.LocalAddr = net.TCPAddrFromAddrPort(sourceAddress)
connection, err := dialer.DialContext(ctx, network, destinationAddrPort)
if err != nil {
return nil, fmt.Errorf("%s dialing %s: %w", network, destinationAddrPort, err)
}
return connection, nil
}
}
func bindSourcePort(destinationIP netip.Addr) (
listener net.Listener, sourceAddr netip.AddrPort, err error,
) {
var bindAddr netip.Addr
if destinationIP.Is4() {
bindAddr = netip.AddrFrom4([4]byte{})
} else {
bindAddr = netip.AddrFrom16([16]byte{})
}
listener, err = net.ListenTCP("tcp", net.TCPAddrFromAddrPort(
netip.AddrPortFrom(bindAddr, 0)))
if err != nil {
return nil, netip.AddrPort{}, fmt.Errorf("binding TCP port: %w", err)
}
tcpAddr := listener.Addr().(*net.TCPAddr) //nolint:forcetypeassert
sourceAddr = tcpAddr.AddrPort()
return listener, sourceAddr, nil
}
+12
View File
@@ -0,0 +1,12 @@
package restrictednet
import (
"context"
"net/netip"
)
type Firewall interface {
AcceptOutputFromIPPortToIPPort(ctx context.Context,
protocol, intf string, source, destination netip.AddrPort, remove bool,
) error
}
@@ -0,0 +1,3 @@
package restrictednet
//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . Firewall
+50
View File
@@ -0,0 +1,50 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/qdm12/gluetun/internal/restrictednet (interfaces: Firewall)
// Package restrictednet is a generated GoMock package.
package restrictednet
import (
context "context"
netip "net/netip"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
)
// MockFirewall is a mock of Firewall interface.
type MockFirewall struct {
ctrl *gomock.Controller
recorder *MockFirewallMockRecorder
}
// MockFirewallMockRecorder is the mock recorder for MockFirewall.
type MockFirewallMockRecorder struct {
mock *MockFirewall
}
// NewMockFirewall creates a new mock instance.
func NewMockFirewall(ctrl *gomock.Controller) *MockFirewall {
mock := &MockFirewall{ctrl: ctrl}
mock.recorder = &MockFirewallMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockFirewall) EXPECT() *MockFirewallMockRecorder {
return m.recorder
}
// AcceptOutputFromIPPortToIPPort mocks base method.
func (m *MockFirewall) AcceptOutputFromIPPortToIPPort(arg0 context.Context, arg1, arg2 string, arg3, arg4 netip.AddrPort, arg5 bool) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AcceptOutputFromIPPortToIPPort", arg0, arg1, arg2, arg3, arg4, arg5)
ret0, _ := ret[0].(error)
return ret0
}
// AcceptOutputFromIPPortToIPPort indicates an expected call of AcceptOutputFromIPPortToIPPort.
func (mr *MockFirewallMockRecorder) AcceptOutputFromIPPortToIPPort(arg0, arg1, arg2, arg3, arg4, arg5 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptOutputFromIPPortToIPPort", reflect.TypeOf((*MockFirewall)(nil).AcceptOutputFromIPPortToIPPort), arg0, arg1, arg2, arg3, arg4, arg5)
}
+177
View File
@@ -0,0 +1,177 @@
package restrictednet
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net/http"
"net/netip"
"net/url"
"github.com/miekg/dns"
)
// ResolveName resolves the given host name to IP addresses using DoH servers,
// while opening temporary restrictive firewall rules for HTTPS traffic to DoH servers.
// The host must be a single well-formed domain name, without port or path.
func (c *Client) ResolveName(ctx context.Context, host string) (
resolvedAddresses []netip.Addr, err error,
) {
questionTypes := make([]uint16, 0, 2)
if c.ipv6Supported {
questionTypes = append(questionTypes, dns.TypeAAAA)
}
questionTypes = append(questionTypes, dns.TypeA)
var addresses []netip.Addr
errs := make([]error, 0, len(questionTypes))
for _, questionType := range questionTypes {
answerAddresses, err := c.resolveOneQuestionType(ctx, host, questionType)
if err != nil {
errs = append(errs, err)
continue
}
addresses = append(addresses, answerAddresses...)
}
switch {
case len(addresses) > 0:
return addresses, nil
case len(errs) == 0:
return nil, nil // no address found
default: // errors
return nil, fmt.Errorf("resolving host %q: %w", host, errors.Join(errs...))
}
}
func (c *Client) resolveOneQuestionType(ctx context.Context,
host string, questionType uint16,
) (addresses []netip.Addr, err error) {
queryMessage := &dns.Msg{}
queryMessage.SetQuestion(dns.Fqdn(host), questionType)
queryWire, err := queryMessage.Pack()
if err != nil {
return nil, fmt.Errorf("packing DNS query: %w", err)
}
// Try every DoH server and every of each of their IP until we get a non-empty
// successful response.
errs := make([]error, 0)
for _, dohServer := range c.dohServers {
dohURL, err := url.Parse(dohServer.URL)
if err != nil {
errs = append(errs,
fmt.Errorf("parsing DoH server URL %s: %w", dohServer.URL, err))
continue
}
dohServerIPs := make([]netip.Addr, 0, len(dohServer.IPv4)+len(dohServer.IPv6))
if c.ipv6Supported {
// Prefer IPv6 addresses if IPv6 is supported
dohServerIPs = append(dohServerIPs, dohServer.IPv6...)
}
dohServerIPs = append(dohServerIPs, dohServer.IPv4...)
for _, dohServerIP := range dohServerIPs {
responseMessage, err := c.doHQuery(ctx, queryWire, dohURL, dohServerIP)
switch {
case err != nil:
errs = append(errs, fmt.Errorf("querying DoH server %q at %s: %w",
dohServer.URL, dohServerIP, err))
continue
case responseMessage.Rcode != dns.RcodeSuccess:
errs = append(errs, fmt.Errorf("querying DoH server %q at %s: DNS rcode %s",
dohServer.URL, dohServerIP, dns.RcodeToString[responseMessage.Rcode]))
continue
}
addresses := answersToNetipAddrs(responseMessage)
if len(addresses) == 0 {
continue
}
return addresses, nil
}
}
if len(errs) == 0 {
return nil, nil
}
return nil, fmt.Errorf("resolving %s %s: %w",
dns.TypeToString[questionType], host, errors.Join(errs...))
}
func (c *Client) doHQuery(ctx context.Context, queryWire []byte,
dohURL *url.URL, dohServerIP netip.Addr,
) (responseMessage *dns.Msg, err error) {
httpClient, close, err := c.OpenHTTPS(dohURL.Hostname(), dohServerIP)
if err != nil {
return nil, fmt.Errorf("opening https connection: %w", err)
}
defer func() {
closeErr := close()
if err == nil && closeErr != nil {
err = fmt.Errorf("cleaning up https connection: %w", closeErr)
}
}()
requestBody := bytes.NewReader(queryWire)
request, err := http.NewRequestWithContext(ctx, http.MethodPost, dohURL.String(), requestBody)
if err != nil {
return nil, fmt.Errorf("creating request: %w", err)
}
request.Header.Set("Content-Type", "application/dns-message")
request.Header.Set("Accept", "application/dns-message")
response, err := httpClient.Do(request)
if err != nil {
return nil, err
}
responseData, err := io.ReadAll(response.Body)
if err != nil {
_ = response.Body.Close()
return nil, fmt.Errorf("reading response body: %w", err)
}
err = response.Body.Close()
if err != nil {
return nil, fmt.Errorf("closing response body: %w", err)
}
if response.StatusCode != http.StatusOK {
return nil, fmt.Errorf("response status code is %s, data: %s",
response.Status, responseData)
}
responseMessage = new(dns.Msg)
err = responseMessage.Unpack(responseData)
if err != nil {
return nil, fmt.Errorf("parsing DoH response: %w", err)
}
return responseMessage, nil
}
func answersToNetipAddrs(message *dns.Msg) (addresses []netip.Addr) {
if message == nil {
return nil
}
addresses = make([]netip.Addr, 0, len(message.Answer))
for _, answer := range message.Answer {
switch record := answer.(type) {
case *dns.A:
address, ok := netip.AddrFromSlice(record.A)
if ok {
addresses = append(addresses, address.Unmap())
}
case *dns.AAAA:
address, ok := netip.AddrFromSlice(record.AAAA)
if ok {
addresses = append(addresses, address)
}
}
}
return addresses
}
+82
View File
@@ -0,0 +1,82 @@
package restrictednet
import (
"net"
"net/netip"
"testing"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
)
func Test_answersToNetipAddrs(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
message *dns.Msg
expected []netip.Addr
errorIsNil bool
}{
"nil_message": {
message: nil,
expected: nil,
errorIsNil: true,
},
"no_answers": {
message: &dns.Msg{},
expected: []netip.Addr{},
errorIsNil: true,
},
"a_record": {
message: &dns.Msg{
Answer: []dns.RR{
&dns.A{
Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET},
A: net.IP{1, 1, 1, 1},
},
},
},
expected: []netip.Addr{netip.MustParseAddr("1.1.1.1")},
errorIsNil: true,
},
"aaaa_record": {
message: &dns.Msg{
Answer: []dns.RR{
&dns.AAAA{
Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeAAAA, Class: dns.ClassINET},
AAAA: net.IP{0x20, 0x01, 0x48, 0x60, 0x48, 0x60, 0, 0, 0, 0, 0, 0, 0, 0, 0x88, 0x88},
},
},
},
expected: []netip.Addr{netip.MustParseAddr("2001:4860:4860::8888")},
errorIsNil: true,
},
"mixed_records": {
message: &dns.Msg{
Answer: []dns.RR{
&dns.A{
Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET},
A: net.IP{1, 1, 1, 1},
},
&dns.AAAA{
Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeAAAA, Class: dns.ClassINET},
AAAA: net.IP{0x20, 0x01, 0x48, 0x60, 0x48, 0x60, 0, 0, 0, 0, 0, 0, 0, 0, 0x88, 0x88},
},
},
},
expected: []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("2001:4860:4860::8888")},
errorIsNil: true,
},
}
for testName, testCase := range testCases {
testCase := testCase
t.Run(testName, func(t *testing.T) {
t.Parallel()
addresses := answersToNetipAddrs(testCase.message)
assert.Equal(t, testCase.expected, addresses)
})
}
}