From a9a36644ecdf3c2ffef0e1a31a1559e9a9b1c41e Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Fri, 5 Jun 2026 04:46:16 +0000 Subject: [PATCH] imporatnt fix 1 --- internal/restrictednet/client.go | 2 +- internal/restrictednet/client_test.go | 24 +++++++++++--- internal/restrictednet/https.go | 47 ++++++++++++--------------- internal/restrictednet/resolve.go | 2 +- 4 files changed, 41 insertions(+), 34 deletions(-) diff --git a/internal/restrictednet/client.go b/internal/restrictednet/client.go index d8812a3f..292f3e3d 100644 --- a/internal/restrictednet/client.go +++ b/internal/restrictednet/client.go @@ -47,7 +47,7 @@ func (c *Client) OpenHTTPSByDomain(ctx context.Context, domain string) ( selectedIP := resolvedIPs[0] - httpClient, cleanup, err = c.OpenHTTPS(domain, selectedIP) + httpClient, cleanup, err = c.OpenHTTPS(ctx, domain, selectedIP) if err != nil { return nil, nil, fmt.Errorf("opening HTTPS: %w", err) } diff --git a/internal/restrictednet/client_test.go b/internal/restrictednet/client_test.go index b3f5ba8d..ff10e822 100644 --- a/internal/restrictednet/client_test.go +++ b/internal/restrictednet/client_test.go @@ -2,6 +2,7 @@ package restrictednet import ( "context" + "net" "net/netip" "testing" @@ -34,15 +35,28 @@ func (m listenAddrPortMatcher) String() string { func Test_Client_OpenHTTPS(t *testing.T) { t.Parallel() + ctx := t.Context() + + netConfig := net.ListenConfig{} + listener, err := netConfig.Listen(ctx, "tcp", "127.0.0.1:443") + require.NoError(t, err) + t.Cleanup(func() { + _ = listener.Close() + }) + go func() { + connection, acceptErr := listener.Accept() + if acceptErr == nil { + _ = connection.Close() + } + }() ctrl := gomock.NewController(t) firewall := NewMockFirewall(ctrl) - destination := netip.MustParseAddrPort("1.2.3.4:443") - backgroundContext := context.Background() + destination := netip.MustParseAddrPort("127.0.0.1:443") sourceMatcher := listenAddrPortMatcher{} firewall.EXPECT().AcceptOutputFromIPPortToIPPort( - backgroundContext, "tcp", "eth0", sourceMatcher, destination, false, + ctx, "tcp", "eth0", sourceMatcher, destination, false, ).DoAndReturn(func(_ context.Context, _, _ string, source, _ netip.AddrPort, _ bool, ) error { @@ -50,7 +64,7 @@ func Test_Client_OpenHTTPS(t *testing.T) { return nil }) firewall.EXPECT().AcceptOutputFromIPPortToIPPort( - backgroundContext, "tcp", "eth0", sourceMatcher, destination, true, + ctx, "tcp", "eth0", sourceMatcher, destination, true, ) const ipv6Supported = false @@ -58,7 +72,7 @@ func Test_Client_OpenHTTPS(t *testing.T) { client, err := New(firewall, "eth0", ipv6Supported, upstreamResolvers) require.NoError(t, err) - httpClient, cleanup, err := client.OpenHTTPS("api.example.com", netip.MustParseAddr("1.2.3.4")) + httpClient, cleanup, err := client.OpenHTTPS(ctx, "api.example.com", netip.MustParseAddr("127.0.0.1")) require.NoError(t, err) require.NotNil(t, httpClient) require.NotNil(t, cleanup) diff --git a/internal/restrictednet/https.go b/internal/restrictednet/https.go index 462f69c2..767d95e2 100644 --- a/internal/restrictednet/https.go +++ b/internal/restrictednet/https.go @@ -13,9 +13,9 @@ import ( // OpenHTTPS opens temporary restrictive firewall output for one HTTPS destination. // The returned cleanup function must be called to remove the temporary firewall rule and close connections. -func (c *Client) OpenHTTPS(destinationTLSName string, destinationIP netip.Addr, +func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, destinationIP netip.Addr, ) (httpClient *http.Client, cleanup func() error, err error) { - listener, sourceAddrPort, err := bindSourcePort(destinationIP) + connection, sourceAddrPort, err := bindSourceConnection(ctx, destinationIP) if err != nil { return nil, nil, fmt.Errorf("binding source port: %w", err) } @@ -24,15 +24,14 @@ func (c *Client) OpenHTTPS(destinationTLSName string, destinationIP netip.Addr, destinationAddrPort := netip.AddrPortFrom(destinationIP, httpsPort) const remove = false - ctx := context.Background() // it's a quick firewall change, worth not passing a context err = c.firewall.AcceptOutputFromIPPortToIPPort(ctx, "tcp", c.outboundInterface, sourceAddrPort, destinationAddrPort, remove) if err != nil { - _ = listener.Close() + _ = connection.Close() return nil, nil, fmt.Errorf("allowing output traffic through firewall: %w", err) } - httpClient = newHTTPSClient(destinationTLSName, destinationIP, sourceAddrPort) + httpClient = newHTTPSClient(destinationTLSName, connection) cleanup = func() error { var errs []error httpClient.CloseIdleConnections() @@ -42,9 +41,9 @@ func (c *Client) OpenHTTPS(destinationTLSName string, destinationIP netip.Addr, if err != nil { errs = append(errs, fmt.Errorf("removing output traffic rule: %w", err)) } - err = listener.Close() + err = connection.Close() if err != nil { - errs = append(errs, fmt.Errorf("closing listener: %w", err)) + errs = append(errs, fmt.Errorf("closing connection: %w", err)) } if len(errs) > 0 { return errors.Join(errs...) @@ -55,7 +54,7 @@ func (c *Client) OpenHTTPS(destinationTLSName string, destinationIP netip.Addr, } func newHTTPSClient(destinationTLSName string, - destinationIP netip.Addr, sourceAddress netip.AddrPort, + connection net.Conn, ) *http.Client { httpTransport := http.DefaultTransport.(*http.Transport).Clone() //nolint:forcetypeassert httpTransport.Proxy = nil @@ -66,7 +65,7 @@ func newHTTPSClient(destinationTLSName string, MinVersion: tls.VersionTLS12, ServerName: destinationTLSName, } - httpTransport.DialContext = newBoundDialContext(destinationIP, sourceAddress) + httpTransport.DialContext = newConnectionDialContext(connection) const timeout = 5 * time.Second return &http.Client{ @@ -75,25 +74,14 @@ func newHTTPSClient(destinationTLSName string, } } -func newBoundDialContext(destinationAddress netip.Addr, - sourceAddress netip.AddrPort, -) func(ctx context.Context, network, _ string) (net.Conn, error) { - const httpsPort = 443 - destinationAddrPort := netip.AddrPortFrom(destinationAddress, httpsPort).String() +func newConnectionDialContext(connection net.Conn) func(ctx context.Context, network, _ string) (net.Conn, error) { return func(ctx context.Context, network, _ string) (net.Conn, error) { - const timeout = 2 * time.Second - dialer := &net.Dialer{Timeout: timeout} - dialer.LocalAddr = net.TCPAddrFromAddrPort(sourceAddress) - connection, err := dialer.DialContext(ctx, network, destinationAddrPort) - if err != nil { - return nil, fmt.Errorf("%s dialing %s: %w", network, destinationAddrPort, err) - } return connection, nil } } -func bindSourcePort(destinationIP netip.Addr) ( - listener net.Listener, sourceAddr netip.AddrPort, err error, +func bindSourceConnection(ctx context.Context, destinationIP netip.Addr) ( + connection net.Conn, sourceAddr netip.AddrPort, err error, ) { var bindAddr netip.Addr if destinationIP.Is4() { @@ -102,14 +90,19 @@ func bindSourcePort(destinationIP netip.Addr) ( bindAddr = netip.AddrFrom16([16]byte{}) } - listener, err = net.ListenTCP("tcp", net.TCPAddrFromAddrPort( - netip.AddrPortFrom(bindAddr, 0))) + const httpsPort = 443 + destinationAddrPort := netip.AddrPortFrom(destinationIP, httpsPort) + dialer := &net.Dialer{ + Timeout: time.Second, + LocalAddr: net.TCPAddrFromAddrPort(netip.AddrPortFrom(bindAddr, 0)), + } + connection, err = dialer.DialContext(ctx, "tcp", destinationAddrPort.String()) if err != nil { return nil, netip.AddrPort{}, fmt.Errorf("binding TCP port: %w", err) } - tcpAddr := listener.Addr().(*net.TCPAddr) //nolint:forcetypeassert + tcpAddr := connection.LocalAddr().(*net.TCPAddr) //nolint:forcetypeassert sourceAddr = tcpAddr.AddrPort() - return listener, sourceAddr, nil + return connection, sourceAddr, nil } diff --git a/internal/restrictednet/resolve.go b/internal/restrictednet/resolve.go index b5b789c7..e14e5c9b 100644 --- a/internal/restrictednet/resolve.go +++ b/internal/restrictednet/resolve.go @@ -106,7 +106,7 @@ func (c *Client) resolveOneQuestionType(ctx context.Context, func (c *Client) doHQuery(ctx context.Context, queryWire []byte, dohURL *url.URL, dohServerIP netip.Addr, ) (responseMessage *dns.Msg, err error) { - httpClient, cleanup, err := c.OpenHTTPS(dohURL.Hostname(), dohServerIP) + httpClient, cleanup, err := c.OpenHTTPS(ctx, dohURL.Hostname(), dohServerIP) if err != nil { return nil, fmt.Errorf("opening https connection: %w", err) }