This commit is contained in:
Quentin McGaw
2026-06-05 03:56:25 +00:00
parent ff6e45fae0
commit aa781c6cc5
12 changed files with 599 additions and 0 deletions
+4
View File
@@ -50,6 +50,7 @@ Guidance for coding agents working in this repository.
- Prefer splitting a code line only when it triggers the `lll` linter, do not split a command or arguments list for each element - Prefer splitting a code line only when it triggers the `lll` linter, do not split a command or arguments list for each element
- Use `netip` types instead of `net` types whenever possible - Use `netip` types instead of `net` types whenever possible
- Use constants instead of variables whenever possible, especially function-local inline constants. - Use constants instead of variables whenever possible, especially function-local inline constants.
- Prefer using pure functions over methods when possible. Especially if the method does not need any fields from the receiving struct, it should be a pure function.
- Do not use `time.Sleep`, prefer using a `time.Timer` with a `select` statement also listening on a context cancelation - Do not use `time.Sleep`, prefer using a `time.Timer` with a `select` statement also listening on a context cancelation
- `panic`: - `panic`:
- should only be used when a programming error is encountered and you should NOT return errors for programming errors (such as passing nil objects) - should only be used when a programming error is encountered and you should NOT return errors for programming errors (such as passing nil objects)
@@ -127,6 +128,7 @@ The Go formatter used is gofumpt.
### Errors ### Errors
- Always prefer wrapping errors with some context with `fmt.Errorf("doing this: %w", err)` - Always prefer wrapping errors with some context with `fmt.Errorf("doing this: %w", err)`
- Use `errors.New("error message")` when creating a 'bottom' constant string error without additional context, instead of `fmt.Errorf`
- In rare cases, you can just use `return err` notably: - In rare cases, you can just use `return err` notably:
- If the function is called **recursively**, since we don't wrap the wrapping multiple times for each recursion - If the function is called **recursively**, since we don't wrap the wrapping multiple times for each recursion
- If the current function only statement is the call to another function, for example: - If the current function only statement is the call to another function, for example:
@@ -179,6 +181,8 @@ The Go formatter used is gofumpt.
- Do not use `http.DefaultClient`, use a custom `*http.Client` with a fixed timeout and share with dependency injections. - Do not use `http.DefaultClient`, use a custom `*http.Client` with a fixed timeout and share with dependency injections.
- Do not check for injected dependencies being `nil`, prefer to just panic on a nil pointer. By default it's fine to panic if a developer injects a dependency `nil`. `nil` does not mean use a default. - Do not check for injected dependencies being `nil`, prefer to just panic on a nil pointer. By default it's fine to panic if a developer injects a dependency `nil`. `nil` does not mean use a default.
- Prefer using a `switch { case ...}` statement over multiple consecutive `if` statements to have shorter code.
- Prefer using `[...]T` instead of `[]T` when the length is fixed and known at compile time, to avoid unnecessary allocations.
## Validation checklist ## Validation checklist
+2
View File
@@ -28,6 +28,8 @@ type firewallImpl interface { //nolint:interfacebloat
AcceptIpv6MulticastOutput(ctx context.Context, intf string) error AcceptIpv6MulticastOutput(ctx context.Context, intf string) error
AcceptOutput(ctx context.Context, protocol, intf string, AcceptOutput(ctx context.Context, protocol, intf string,
ip netip.Addr, port uint16, remove bool) error ip netip.Addr, port uint16, remove bool) error
AcceptOutputFromIPPortToIPPort(ctx context.Context, protocol, intf string,
source, destination netip.AddrPort, remove bool) error
AcceptOutputFromIPToSubnet(ctx context.Context, intf string, assignedIP netip.Addr, AcceptOutputFromIPToSubnet(ctx context.Context, intf string, assignedIP netip.Addr,
subnet netip.Prefix, remove bool) error subnet netip.Prefix, remove bool) error
AcceptOutputThroughInterface(ctx context.Context, intf string, remove bool) error AcceptOutputThroughInterface(ctx context.Context, intf string, remove bool) error
+23
View File
@@ -177,6 +177,29 @@ func (c *Config) AcceptOutput(ctx context.Context,
return c.runIP6tablesInstruction(ctx, instruction) return c.runIP6tablesInstruction(ctx, instruction)
} }
func (c *Config) AcceptOutputFromIPPortToIPPort(ctx context.Context,
protocol, intf string, source, destination netip.AddrPort, remove bool,
) error {
if source.Addr().BitLen() != destination.Addr().BitLen() {
return fmt.Errorf("source and destination address families do not match")
}
interfaceFlag := "-o " + intf
if intf == "*" { // all interfaces
interfaceFlag = ""
}
instruction := fmt.Sprintf("%s OUTPUT -s %s --sport %d -d %s %s -p %s -m %s --dport %d -j ACCEPT",
appendOrDelete(remove), source.Addr(), source.Port(), destination.Addr(),
interfaceFlag, protocol, protocol, destination.Port())
if destination.Addr().Is4() {
return c.runIptablesInstruction(ctx, instruction)
} else if c.ip6Tables == "" {
return fmt.Errorf("accept output from %s to %s: %s", source, destination, needIP6Tables)
}
return c.runIP6tablesInstruction(ctx, instruction)
}
// AcceptOutputFromIPToSubnet accepts outgoing traffic from sourceIP to destinationSubnet // AcceptOutputFromIPToSubnet accepts outgoing traffic from sourceIP to destinationSubnet
// on the interface intf. If intf is empty, it is set to "*" which means all interfaces. // on the interface intf. If intf is empty, it is set to "*" which means all interfaces.
// If remove is true, the rule is removed instead of added. // If remove is true, the rule is removed instead of added.
+7
View File
@@ -25,3 +25,10 @@ func (c *Config) AcceptOutput(ctx context.Context, protocol, intf string,
) error { ) error {
return c.impl.AcceptOutput(ctx, protocol, intf, ip, port, remove) return c.impl.AcceptOutput(ctx, protocol, intf, ip, port, remove)
} }
func (c *Config) AcceptOutputFromIPPortToIPPort(ctx context.Context,
protocol, intf string, source, destination netip.AddrPort, remove bool,
) error {
return c.impl.AcceptOutputFromIPPortToIPPort(ctx, protocol, intf,
source, destination, remove)
}
+56
View File
@@ -0,0 +1,56 @@
package restrictednet
import (
"context"
"fmt"
"net/http"
"github.com/qdm12/dns/v2/pkg/provider"
)
// Client is a client for making restricted network requests,
// such as opening temporary firewall rules for HTTPS connections.
// It is not meant to be high performance, although it can be used for
// multiple requests and concurrently.
type Client struct {
ipv6Supported bool
firewall Firewall
outboundInterface string
dohServers []provider.DoHServer
}
func New(firewall Firewall, defaultInterface string, ipv6Supported bool,
upstreamResolvers []provider.Provider,
) (*Client, error) {
dohServers := make([]provider.DoHServer, len(upstreamResolvers))
for i, upstreamResolver := range upstreamResolvers {
dohServers[i] = upstreamResolver.DoH
}
return &Client{
firewall: firewall,
outboundInterface: defaultInterface,
ipv6Supported: ipv6Supported,
dohServers: dohServers,
}, nil
}
func (c *Client) OpenHTTPSByDomain(ctx context.Context, domain string) (
httpClient *http.Client, cleanup func() error, err error,
) {
resolvedIPs, err := c.ResolveName(ctx, domain)
if err != nil {
return nil, nil, fmt.Errorf("resolving name: %w", err)
} else if len(resolvedIPs) == 0 {
return nil, nil, fmt.Errorf("no IP address found for name %q", domain)
}
selectedIP := resolvedIPs[0]
httpClient, cleanup, err = c.OpenHTTPS(domain, selectedIP)
if err != nil {
return nil, nil, fmt.Errorf("opening HTTPS: %w", err)
}
return httpClient, cleanup, nil
}
+68
View File
@@ -0,0 +1,68 @@
package restrictednet
import (
"context"
"net/netip"
"testing"
"github.com/golang/mock/gomock"
"github.com/qdm12/dns/v2/pkg/provider"
"github.com/stretchr/testify/require"
)
type listenAddrPortMatcher struct {
expected netip.AddrPort
}
func (m listenAddrPortMatcher) Matches(x any) bool {
ip, ok := x.(netip.AddrPort)
if !ok {
return false
}
if m.expected.IsValid() {
return ip == m.expected
}
return ip.IsValid() && ip.Addr().IsValid() && ip.Port() > 0
}
func (m listenAddrPortMatcher) String() string {
if m.expected.IsValid() {
return "is the same as " + m.expected.String()
}
return "is a valid netip.AddrPort with a valid IP and non-zero port"
}
func Test_Client_OpenHTTPS(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
firewall := NewMockFirewall(ctrl)
destination := netip.MustParseAddrPort("1.2.3.4:443")
backgroundContext := context.Background()
sourceMatcher := listenAddrPortMatcher{}
firewall.EXPECT().AcceptOutputFromIPPortToIPPort(
backgroundContext, "tcp", "eth0", sourceMatcher, destination, false,
).DoAndReturn(func(_ context.Context,
_, _ string, source, _ netip.AddrPort, _ bool,
) error {
sourceMatcher.expected = source
return nil
})
firewall.EXPECT().AcceptOutputFromIPPortToIPPort(
backgroundContext, "tcp", "eth0", sourceMatcher, destination, true,
)
const ipv6Supported = false
upstreamResolvers := []provider.Provider{provider.Google()}
client, err := New(firewall, "eth0", ipv6Supported, upstreamResolvers)
require.NoError(t, err)
httpClient, cleanup, err := client.OpenHTTPS("api.example.com", netip.MustParseAddr("1.2.3.4"))
require.NoError(t, err)
require.NotNil(t, httpClient)
require.NotNil(t, cleanup)
err = cleanup()
require.NoError(t, err)
}
+115
View File
@@ -0,0 +1,115 @@
package restrictednet
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"net/netip"
"time"
)
// OpenHTTPS opens temporary restrictive firewall output for one HTTPS destination.
// The returned cleanup function must be called to remove the temporary firewall rule and close connections.
func (c *Client) OpenHTTPS(destinationTLSName string, destinationIP netip.Addr,
) (httpClient *http.Client, cleanup func() error, err error) {
listener, sourceAddrPort, err := bindSourcePort(destinationIP)
if err != nil {
return nil, nil, fmt.Errorf("binding source port: %w", err)
}
const httpsPort = 443
destinationAddrPort := netip.AddrPortFrom(destinationIP, httpsPort)
const remove = false
ctx := context.Background() // it's a quick firewall change, worth not passing a context
err = c.firewall.AcceptOutputFromIPPortToIPPort(ctx, "tcp", c.outboundInterface,
sourceAddrPort, destinationAddrPort, remove)
if err != nil {
_ = listener.Close()
return nil, nil, fmt.Errorf("allowing output traffic through firewall: %w", err)
}
httpClient = newHTTPSClient(destinationTLSName, destinationIP, sourceAddrPort)
cleanup = func() error {
var errs []error
httpClient.CloseIdleConnections()
const remove = true
err := c.firewall.AcceptOutputFromIPPortToIPPort(ctx, "tcp", c.outboundInterface,
sourceAddrPort, destinationAddrPort, remove)
if err != nil {
errs = append(errs, fmt.Errorf("removing output traffic rule: %w", err))
}
err = listener.Close()
if err != nil {
errs = append(errs, fmt.Errorf("closing listener: %w", err))
}
if len(errs) > 0 {
return errors.Join(errs...)
}
return nil
}
return httpClient, cleanup, nil
}
func newHTTPSClient(destinationTLSName string,
destinationIP netip.Addr, sourceAddress netip.AddrPort,
) *http.Client {
httpTransport := http.DefaultTransport.(*http.Transport).Clone() //nolint:forcetypeassert
httpTransport.Proxy = nil
httpTransport.MaxIdleConns = 1
httpTransport.MaxIdleConnsPerHost = 1
httpTransport.IdleConnTimeout = time.Second
httpTransport.TLSClientConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
ServerName: destinationTLSName,
}
httpTransport.DialContext = newBoundDialContext(destinationIP, sourceAddress)
const timeout = 5 * time.Second
return &http.Client{
Timeout: timeout,
Transport: httpTransport,
}
}
func newBoundDialContext(destinationAddress netip.Addr,
sourceAddress netip.AddrPort,
) func(ctx context.Context, network, _ string) (net.Conn, error) {
const httpsPort = 443
destinationAddrPort := netip.AddrPortFrom(destinationAddress, httpsPort).String()
return func(ctx context.Context, network, _ string) (net.Conn, error) {
const timeout = 2 * time.Second
dialer := &net.Dialer{Timeout: timeout}
dialer.LocalAddr = net.TCPAddrFromAddrPort(sourceAddress)
connection, err := dialer.DialContext(ctx, network, destinationAddrPort)
if err != nil {
return nil, fmt.Errorf("%s dialing %s: %w", network, destinationAddrPort, err)
}
return connection, nil
}
}
func bindSourcePort(destinationIP netip.Addr) (
listener net.Listener, sourceAddr netip.AddrPort, err error,
) {
var bindAddr netip.Addr
if destinationIP.Is4() {
bindAddr = netip.AddrFrom4([4]byte{})
} else {
bindAddr = netip.AddrFrom16([16]byte{})
}
listener, err = net.ListenTCP("tcp", net.TCPAddrFromAddrPort(
netip.AddrPortFrom(bindAddr, 0)))
if err != nil {
return nil, netip.AddrPort{}, fmt.Errorf("binding TCP port: %w", err)
}
tcpAddr := listener.Addr().(*net.TCPAddr) //nolint:forcetypeassert
sourceAddr = tcpAddr.AddrPort()
return listener, sourceAddr, nil
}
+12
View File
@@ -0,0 +1,12 @@
package restrictednet
import (
"context"
"net/netip"
)
type Firewall interface {
AcceptOutputFromIPPortToIPPort(ctx context.Context,
protocol, intf string, source, destination netip.AddrPort, remove bool,
) error
}
@@ -0,0 +1,3 @@
package restrictednet
//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . Firewall
+50
View File
@@ -0,0 +1,50 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/qdm12/gluetun/internal/restrictednet (interfaces: Firewall)
// Package restrictednet is a generated GoMock package.
package restrictednet
import (
context "context"
netip "net/netip"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
)
// MockFirewall is a mock of Firewall interface.
type MockFirewall struct {
ctrl *gomock.Controller
recorder *MockFirewallMockRecorder
}
// MockFirewallMockRecorder is the mock recorder for MockFirewall.
type MockFirewallMockRecorder struct {
mock *MockFirewall
}
// NewMockFirewall creates a new mock instance.
func NewMockFirewall(ctrl *gomock.Controller) *MockFirewall {
mock := &MockFirewall{ctrl: ctrl}
mock.recorder = &MockFirewallMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockFirewall) EXPECT() *MockFirewallMockRecorder {
return m.recorder
}
// AcceptOutputFromIPPortToIPPort mocks base method.
func (m *MockFirewall) AcceptOutputFromIPPortToIPPort(arg0 context.Context, arg1, arg2 string, arg3, arg4 netip.AddrPort, arg5 bool) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AcceptOutputFromIPPortToIPPort", arg0, arg1, arg2, arg3, arg4, arg5)
ret0, _ := ret[0].(error)
return ret0
}
// AcceptOutputFromIPPortToIPPort indicates an expected call of AcceptOutputFromIPPortToIPPort.
func (mr *MockFirewallMockRecorder) AcceptOutputFromIPPortToIPPort(arg0, arg1, arg2, arg3, arg4, arg5 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptOutputFromIPPortToIPPort", reflect.TypeOf((*MockFirewall)(nil).AcceptOutputFromIPPortToIPPort), arg0, arg1, arg2, arg3, arg4, arg5)
}
+177
View File
@@ -0,0 +1,177 @@
package restrictednet
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net/http"
"net/netip"
"net/url"
"github.com/miekg/dns"
)
// ResolveName resolves the given host name to IP addresses using DoH servers,
// while opening temporary restrictive firewall rules for HTTPS traffic to DoH servers.
// The host must be a single well-formed domain name, without port or path.
func (c *Client) ResolveName(ctx context.Context, host string) (
resolvedAddresses []netip.Addr, err error,
) {
questionTypes := make([]uint16, 0, 2)
if c.ipv6Supported {
questionTypes = append(questionTypes, dns.TypeAAAA)
}
questionTypes = append(questionTypes, dns.TypeA)
var addresses []netip.Addr
errs := make([]error, 0, len(questionTypes))
for _, questionType := range questionTypes {
answerAddresses, err := c.resolveOneQuestionType(ctx, host, questionType)
if err != nil {
errs = append(errs, err)
continue
}
addresses = append(addresses, answerAddresses...)
}
switch {
case len(addresses) > 0:
return addresses, nil
case len(errs) == 0:
return nil, nil // no address found
default: // errors
return nil, fmt.Errorf("resolving host %q: %w", host, errors.Join(errs...))
}
}
func (c *Client) resolveOneQuestionType(ctx context.Context,
host string, questionType uint16,
) (addresses []netip.Addr, err error) {
queryMessage := &dns.Msg{}
queryMessage.SetQuestion(dns.Fqdn(host), questionType)
queryWire, err := queryMessage.Pack()
if err != nil {
return nil, fmt.Errorf("packing DNS query: %w", err)
}
// Try every DoH server and every of each of their IP until we get a non-empty
// successful response.
errs := make([]error, 0)
for _, dohServer := range c.dohServers {
dohURL, err := url.Parse(dohServer.URL)
if err != nil {
errs = append(errs,
fmt.Errorf("parsing DoH server URL %s: %w", dohServer.URL, err))
continue
}
dohServerIPs := make([]netip.Addr, 0, len(dohServer.IPv4)+len(dohServer.IPv6))
if c.ipv6Supported {
// Prefer IPv6 addresses if IPv6 is supported
dohServerIPs = append(dohServerIPs, dohServer.IPv6...)
}
dohServerIPs = append(dohServerIPs, dohServer.IPv4...)
for _, dohServerIP := range dohServerIPs {
responseMessage, err := c.doHQuery(ctx, queryWire, dohURL, dohServerIP)
switch {
case err != nil:
errs = append(errs, fmt.Errorf("querying DoH server %q at %s: %w",
dohServer.URL, dohServerIP, err))
continue
case responseMessage.Rcode != dns.RcodeSuccess:
errs = append(errs, fmt.Errorf("querying DoH server %q at %s: DNS rcode %s",
dohServer.URL, dohServerIP, dns.RcodeToString[responseMessage.Rcode]))
continue
}
addresses := answersToNetipAddrs(responseMessage)
if len(addresses) == 0 {
continue
}
return addresses, nil
}
}
if len(errs) == 0 {
return nil, nil
}
return nil, fmt.Errorf("resolving %s %s: %w",
dns.TypeToString[questionType], host, errors.Join(errs...))
}
func (c *Client) doHQuery(ctx context.Context, queryWire []byte,
dohURL *url.URL, dohServerIP netip.Addr,
) (responseMessage *dns.Msg, err error) {
httpClient, close, err := c.OpenHTTPS(dohURL.Hostname(), dohServerIP)
if err != nil {
return nil, fmt.Errorf("opening https connection: %w", err)
}
defer func() {
closeErr := close()
if err == nil && closeErr != nil {
err = fmt.Errorf("cleaning up https connection: %w", closeErr)
}
}()
requestBody := bytes.NewReader(queryWire)
request, err := http.NewRequestWithContext(ctx, http.MethodPost, dohURL.String(), requestBody)
if err != nil {
return nil, fmt.Errorf("creating request: %w", err)
}
request.Header.Set("Content-Type", "application/dns-message")
request.Header.Set("Accept", "application/dns-message")
response, err := httpClient.Do(request)
if err != nil {
return nil, err
}
responseData, err := io.ReadAll(response.Body)
if err != nil {
_ = response.Body.Close()
return nil, fmt.Errorf("reading response body: %w", err)
}
err = response.Body.Close()
if err != nil {
return nil, fmt.Errorf("closing response body: %w", err)
}
if response.StatusCode != http.StatusOK {
return nil, fmt.Errorf("response status code is %s, data: %s",
response.Status, responseData)
}
responseMessage = new(dns.Msg)
err = responseMessage.Unpack(responseData)
if err != nil {
return nil, fmt.Errorf("parsing DoH response: %w", err)
}
return responseMessage, nil
}
func answersToNetipAddrs(message *dns.Msg) (addresses []netip.Addr) {
if message == nil {
return nil
}
addresses = make([]netip.Addr, 0, len(message.Answer))
for _, answer := range message.Answer {
switch record := answer.(type) {
case *dns.A:
address, ok := netip.AddrFromSlice(record.A)
if ok {
addresses = append(addresses, address.Unmap())
}
case *dns.AAAA:
address, ok := netip.AddrFromSlice(record.AAAA)
if ok {
addresses = append(addresses, address)
}
}
}
return addresses
}
+82
View File
@@ -0,0 +1,82 @@
package restrictednet
import (
"net"
"net/netip"
"testing"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
)
func Test_answersToNetipAddrs(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
message *dns.Msg
expected []netip.Addr
errorIsNil bool
}{
"nil_message": {
message: nil,
expected: nil,
errorIsNil: true,
},
"no_answers": {
message: &dns.Msg{},
expected: []netip.Addr{},
errorIsNil: true,
},
"a_record": {
message: &dns.Msg{
Answer: []dns.RR{
&dns.A{
Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET},
A: net.IP{1, 1, 1, 1},
},
},
},
expected: []netip.Addr{netip.MustParseAddr("1.1.1.1")},
errorIsNil: true,
},
"aaaa_record": {
message: &dns.Msg{
Answer: []dns.RR{
&dns.AAAA{
Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeAAAA, Class: dns.ClassINET},
AAAA: net.IP{0x20, 0x01, 0x48, 0x60, 0x48, 0x60, 0, 0, 0, 0, 0, 0, 0, 0, 0x88, 0x88},
},
},
},
expected: []netip.Addr{netip.MustParseAddr("2001:4860:4860::8888")},
errorIsNil: true,
},
"mixed_records": {
message: &dns.Msg{
Answer: []dns.RR{
&dns.A{
Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET},
A: net.IP{1, 1, 1, 1},
},
&dns.AAAA{
Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeAAAA, Class: dns.ClassINET},
AAAA: net.IP{0x20, 0x01, 0x48, 0x60, 0x48, 0x60, 0, 0, 0, 0, 0, 0, 0, 0, 0x88, 0x88},
},
},
},
expected: []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("2001:4860:4860::8888")},
errorIsNil: true,
},
}
for testName, testCase := range testCases {
testCase := testCase
t.Run(testName, func(t *testing.T) {
t.Parallel()
addresses := answersToNetipAddrs(testCase.message)
assert.Equal(t, testCase.expected, addresses)
})
}
}