imporatnt fix 1

This commit is contained in:
Quentin McGaw
2026-06-05 04:46:16 +00:00
parent fad8c9889a
commit a9a36644ec
4 changed files with 41 additions and 34 deletions
+1 -1
View File
@@ -47,7 +47,7 @@ func (c *Client) OpenHTTPSByDomain(ctx context.Context, domain string) (
selectedIP := resolvedIPs[0] selectedIP := resolvedIPs[0]
httpClient, cleanup, err = c.OpenHTTPS(domain, selectedIP) httpClient, cleanup, err = c.OpenHTTPS(ctx, domain, selectedIP)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("opening HTTPS: %w", err) return nil, nil, fmt.Errorf("opening HTTPS: %w", err)
} }
+19 -5
View File
@@ -2,6 +2,7 @@ package restrictednet
import ( import (
"context" "context"
"net"
"net/netip" "net/netip"
"testing" "testing"
@@ -34,15 +35,28 @@ func (m listenAddrPortMatcher) String() string {
func Test_Client_OpenHTTPS(t *testing.T) { func Test_Client_OpenHTTPS(t *testing.T) {
t.Parallel() t.Parallel()
ctx := t.Context()
netConfig := net.ListenConfig{}
listener, err := netConfig.Listen(ctx, "tcp", "127.0.0.1:443")
require.NoError(t, err)
t.Cleanup(func() {
_ = listener.Close()
})
go func() {
connection, acceptErr := listener.Accept()
if acceptErr == nil {
_ = connection.Close()
}
}()
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
firewall := NewMockFirewall(ctrl) firewall := NewMockFirewall(ctrl)
destination := netip.MustParseAddrPort("1.2.3.4:443") destination := netip.MustParseAddrPort("127.0.0.1:443")
backgroundContext := context.Background()
sourceMatcher := listenAddrPortMatcher{} sourceMatcher := listenAddrPortMatcher{}
firewall.EXPECT().AcceptOutputFromIPPortToIPPort( firewall.EXPECT().AcceptOutputFromIPPortToIPPort(
backgroundContext, "tcp", "eth0", sourceMatcher, destination, false, ctx, "tcp", "eth0", sourceMatcher, destination, false,
).DoAndReturn(func(_ context.Context, ).DoAndReturn(func(_ context.Context,
_, _ string, source, _ netip.AddrPort, _ bool, _, _ string, source, _ netip.AddrPort, _ bool,
) error { ) error {
@@ -50,7 +64,7 @@ func Test_Client_OpenHTTPS(t *testing.T) {
return nil return nil
}) })
firewall.EXPECT().AcceptOutputFromIPPortToIPPort( firewall.EXPECT().AcceptOutputFromIPPortToIPPort(
backgroundContext, "tcp", "eth0", sourceMatcher, destination, true, ctx, "tcp", "eth0", sourceMatcher, destination, true,
) )
const ipv6Supported = false const ipv6Supported = false
@@ -58,7 +72,7 @@ func Test_Client_OpenHTTPS(t *testing.T) {
client, err := New(firewall, "eth0", ipv6Supported, upstreamResolvers) client, err := New(firewall, "eth0", ipv6Supported, upstreamResolvers)
require.NoError(t, err) require.NoError(t, err)
httpClient, cleanup, err := client.OpenHTTPS("api.example.com", netip.MustParseAddr("1.2.3.4")) httpClient, cleanup, err := client.OpenHTTPS(ctx, "api.example.com", netip.MustParseAddr("127.0.0.1"))
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, httpClient) require.NotNil(t, httpClient)
require.NotNil(t, cleanup) require.NotNil(t, cleanup)
+20 -27
View File
@@ -13,9 +13,9 @@ import (
// OpenHTTPS opens temporary restrictive firewall output for one HTTPS destination. // OpenHTTPS opens temporary restrictive firewall output for one HTTPS destination.
// The returned cleanup function must be called to remove the temporary firewall rule and close connections. // The returned cleanup function must be called to remove the temporary firewall rule and close connections.
func (c *Client) OpenHTTPS(destinationTLSName string, destinationIP netip.Addr, func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, destinationIP netip.Addr,
) (httpClient *http.Client, cleanup func() error, err error) { ) (httpClient *http.Client, cleanup func() error, err error) {
listener, sourceAddrPort, err := bindSourcePort(destinationIP) connection, sourceAddrPort, err := bindSourceConnection(ctx, destinationIP)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("binding source port: %w", err) return nil, nil, fmt.Errorf("binding source port: %w", err)
} }
@@ -24,15 +24,14 @@ func (c *Client) OpenHTTPS(destinationTLSName string, destinationIP netip.Addr,
destinationAddrPort := netip.AddrPortFrom(destinationIP, httpsPort) destinationAddrPort := netip.AddrPortFrom(destinationIP, httpsPort)
const remove = false const remove = false
ctx := context.Background() // it's a quick firewall change, worth not passing a context
err = c.firewall.AcceptOutputFromIPPortToIPPort(ctx, "tcp", c.outboundInterface, err = c.firewall.AcceptOutputFromIPPortToIPPort(ctx, "tcp", c.outboundInterface,
sourceAddrPort, destinationAddrPort, remove) sourceAddrPort, destinationAddrPort, remove)
if err != nil { if err != nil {
_ = listener.Close() _ = connection.Close()
return nil, nil, fmt.Errorf("allowing output traffic through firewall: %w", err) return nil, nil, fmt.Errorf("allowing output traffic through firewall: %w", err)
} }
httpClient = newHTTPSClient(destinationTLSName, destinationIP, sourceAddrPort) httpClient = newHTTPSClient(destinationTLSName, connection)
cleanup = func() error { cleanup = func() error {
var errs []error var errs []error
httpClient.CloseIdleConnections() httpClient.CloseIdleConnections()
@@ -42,9 +41,9 @@ func (c *Client) OpenHTTPS(destinationTLSName string, destinationIP netip.Addr,
if err != nil { if err != nil {
errs = append(errs, fmt.Errorf("removing output traffic rule: %w", err)) errs = append(errs, fmt.Errorf("removing output traffic rule: %w", err))
} }
err = listener.Close() err = connection.Close()
if err != nil { if err != nil {
errs = append(errs, fmt.Errorf("closing listener: %w", err)) errs = append(errs, fmt.Errorf("closing connection: %w", err))
} }
if len(errs) > 0 { if len(errs) > 0 {
return errors.Join(errs...) return errors.Join(errs...)
@@ -55,7 +54,7 @@ func (c *Client) OpenHTTPS(destinationTLSName string, destinationIP netip.Addr,
} }
func newHTTPSClient(destinationTLSName string, func newHTTPSClient(destinationTLSName string,
destinationIP netip.Addr, sourceAddress netip.AddrPort, connection net.Conn,
) *http.Client { ) *http.Client {
httpTransport := http.DefaultTransport.(*http.Transport).Clone() //nolint:forcetypeassert httpTransport := http.DefaultTransport.(*http.Transport).Clone() //nolint:forcetypeassert
httpTransport.Proxy = nil httpTransport.Proxy = nil
@@ -66,7 +65,7 @@ func newHTTPSClient(destinationTLSName string,
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
ServerName: destinationTLSName, ServerName: destinationTLSName,
} }
httpTransport.DialContext = newBoundDialContext(destinationIP, sourceAddress) httpTransport.DialContext = newConnectionDialContext(connection)
const timeout = 5 * time.Second const timeout = 5 * time.Second
return &http.Client{ return &http.Client{
@@ -75,25 +74,14 @@ func newHTTPSClient(destinationTLSName string,
} }
} }
func newBoundDialContext(destinationAddress netip.Addr, func newConnectionDialContext(connection net.Conn) func(ctx context.Context, network, _ string) (net.Conn, error) {
sourceAddress netip.AddrPort,
) func(ctx context.Context, network, _ string) (net.Conn, error) {
const httpsPort = 443
destinationAddrPort := netip.AddrPortFrom(destinationAddress, httpsPort).String()
return func(ctx context.Context, network, _ string) (net.Conn, error) { return func(ctx context.Context, network, _ string) (net.Conn, error) {
const timeout = 2 * time.Second
dialer := &net.Dialer{Timeout: timeout}
dialer.LocalAddr = net.TCPAddrFromAddrPort(sourceAddress)
connection, err := dialer.DialContext(ctx, network, destinationAddrPort)
if err != nil {
return nil, fmt.Errorf("%s dialing %s: %w", network, destinationAddrPort, err)
}
return connection, nil return connection, nil
} }
} }
func bindSourcePort(destinationIP netip.Addr) ( func bindSourceConnection(ctx context.Context, destinationIP netip.Addr) (
listener net.Listener, sourceAddr netip.AddrPort, err error, connection net.Conn, sourceAddr netip.AddrPort, err error,
) { ) {
var bindAddr netip.Addr var bindAddr netip.Addr
if destinationIP.Is4() { if destinationIP.Is4() {
@@ -102,14 +90,19 @@ func bindSourcePort(destinationIP netip.Addr) (
bindAddr = netip.AddrFrom16([16]byte{}) bindAddr = netip.AddrFrom16([16]byte{})
} }
listener, err = net.ListenTCP("tcp", net.TCPAddrFromAddrPort( const httpsPort = 443
netip.AddrPortFrom(bindAddr, 0))) destinationAddrPort := netip.AddrPortFrom(destinationIP, httpsPort)
dialer := &net.Dialer{
Timeout: time.Second,
LocalAddr: net.TCPAddrFromAddrPort(netip.AddrPortFrom(bindAddr, 0)),
}
connection, err = dialer.DialContext(ctx, "tcp", destinationAddrPort.String())
if err != nil { if err != nil {
return nil, netip.AddrPort{}, fmt.Errorf("binding TCP port: %w", err) return nil, netip.AddrPort{}, fmt.Errorf("binding TCP port: %w", err)
} }
tcpAddr := listener.Addr().(*net.TCPAddr) //nolint:forcetypeassert tcpAddr := connection.LocalAddr().(*net.TCPAddr) //nolint:forcetypeassert
sourceAddr = tcpAddr.AddrPort() sourceAddr = tcpAddr.AddrPort()
return listener, sourceAddr, nil return connection, sourceAddr, nil
} }
+1 -1
View File
@@ -106,7 +106,7 @@ func (c *Client) resolveOneQuestionType(ctx context.Context,
func (c *Client) doHQuery(ctx context.Context, queryWire []byte, func (c *Client) doHQuery(ctx context.Context, queryWire []byte,
dohURL *url.URL, dohServerIP netip.Addr, dohURL *url.URL, dohServerIP netip.Addr,
) (responseMessage *dns.Msg, err error) { ) (responseMessage *dns.Msg, err error) {
httpClient, cleanup, err := c.OpenHTTPS(dohURL.Hostname(), dohServerIP) httpClient, cleanup, err := c.OpenHTTPS(ctx, dohURL.Hostname(), dohServerIP)
if err != nil { if err != nil {
return nil, fmt.Errorf("opening https connection: %w", err) return nil, fmt.Errorf("opening https connection: %w", err)
} }