From 2d2c3713032c560a1d675b56e0bdcc22fb93633d Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Fri, 5 Jun 2026 15:25:44 +0000 Subject: [PATCH] pr review fixes --- internal/restrictednet/client.go | 16 ++++++++++------ internal/restrictednet/https.go | 5 +++-- .../{client_test.go => https_test.go} | 2 +- internal/restrictednet/resolve.go | 4 ++-- 4 files changed, 16 insertions(+), 11 deletions(-) rename internal/restrictednet/{client_test.go => https_test.go} (96%) diff --git a/internal/restrictednet/client.go b/internal/restrictednet/client.go index 9e20b939..cdcd9472 100644 --- a/internal/restrictednet/client.go +++ b/internal/restrictednet/client.go @@ -2,6 +2,7 @@ package restrictednet import ( "context" + "errors" "fmt" "net/http" @@ -48,12 +49,15 @@ func (c *Client) OpenHTTPSByDomain(ctx context.Context, domain string) ( return nil, nil, fmt.Errorf("no IP address found for name %q", domain) } - selectedIP := resolvedIPs[0] - - httpClient, cleanup, err = c.OpenHTTPS(ctx, domain, selectedIP) - if err != nil { - return nil, nil, fmt.Errorf("opening HTTPS: %w", err) + errs := make([]error, 0, len(resolvedIPs)) + for _, ip := range resolvedIPs { + httpClient, cleanup, err := c.OpenHTTPS(ctx, domain, ip) + if err != nil { + errs = append(errs, fmt.Errorf("for %s: %w", ip, err)) + continue + } + return httpClient, cleanup, nil } - return httpClient, cleanup, nil + return nil, nil, fmt.Errorf("opening HTTPS to %s: %w", domain, errors.Join(errs...)) } diff --git a/internal/restrictednet/https.go b/internal/restrictednet/https.go index f3b71a43..1bb5bb48 100644 --- a/internal/restrictednet/https.go +++ b/internal/restrictednet/https.go @@ -47,7 +47,7 @@ func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, desti var errs []error httpClient.CloseIdleConnections() const remove = true - err := c.firewall.AcceptOutputFromIPPortToIPPort(ctx, "tcp", c.outboundInterface, + err := c.firewall.AcceptOutputFromIPPortToIPPort(context.Background(), "tcp", c.outboundInterface, sourceAddrPort, destinationAddrPort, remove) if err != nil { errs = append(errs, fmt.Errorf("removing output traffic rule: %w", err)) @@ -76,7 +76,8 @@ func newHTTPSClient(destinationTLSName string, connection net.Conn) *http.Client ServerName: destinationTLSName, } - expectedAddress := net.JoinHostPort(destinationTLSName, "443") + _, destinationPort, _ := net.SplitHostPort(connection.RemoteAddr().String()) + expectedAddress := net.JoinHostPort(destinationTLSName, destinationPort) httpTransport.DialContext = func(_ context.Context, network, address string) (net.Conn, error) { switch network { case "tcp", "tcp4", "tcp6": diff --git a/internal/restrictednet/client_test.go b/internal/restrictednet/https_test.go similarity index 96% rename from internal/restrictednet/client_test.go rename to internal/restrictednet/https_test.go index 65504f62..7db81e60 100644 --- a/internal/restrictednet/client_test.go +++ b/internal/restrictednet/https_test.go @@ -65,7 +65,7 @@ func Test_Client_OpenHTTPS(t *testing.T) { return nil }) firewall.EXPECT().AcceptOutputFromIPPortToIPPort( - ctx, "tcp", "eth0", sourceMatcher, destination, true, + context.Background(), "tcp", "eth0", sourceMatcher, destination, true, ) const ipv6Supported = false diff --git a/internal/restrictednet/resolve.go b/internal/restrictednet/resolve.go index e14e5c9b..aa1c3e64 100644 --- a/internal/restrictednet/resolve.go +++ b/internal/restrictednet/resolve.go @@ -142,8 +142,8 @@ func (c *Client) doHQuery(ctx context.Context, queryWire []byte, } if response.StatusCode != http.StatusOK { - return nil, fmt.Errorf("response status code is %s, data: %s", - response.Status, responseData) + return nil, fmt.Errorf("response status code is %s (data length %d)", + response.Status, len(responseData)) } responseMessage = new(dns.Msg)