pr review fixes

This commit is contained in:
Quentin McGaw
2026-06-05 15:25:44 +00:00
parent b48ba8cb0a
commit 2d2c371303
4 changed files with 16 additions and 11 deletions
+9 -5
View File
@@ -2,6 +2,7 @@ package restrictednet
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"net/http" "net/http"
@@ -48,12 +49,15 @@ func (c *Client) OpenHTTPSByDomain(ctx context.Context, domain string) (
return nil, nil, fmt.Errorf("no IP address found for name %q", domain) return nil, nil, fmt.Errorf("no IP address found for name %q", domain)
} }
selectedIP := resolvedIPs[0] errs := make([]error, 0, len(resolvedIPs))
for _, ip := range resolvedIPs {
httpClient, cleanup, err = c.OpenHTTPS(ctx, domain, selectedIP) httpClient, cleanup, err := c.OpenHTTPS(ctx, domain, ip)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("opening HTTPS: %w", err) errs = append(errs, fmt.Errorf("for %s: %w", ip, err))
continue
}
return httpClient, cleanup, nil
} }
return httpClient, cleanup, nil return nil, nil, fmt.Errorf("opening HTTPS to %s: %w", domain, errors.Join(errs...))
} }
+3 -2
View File
@@ -47,7 +47,7 @@ func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, desti
var errs []error var errs []error
httpClient.CloseIdleConnections() httpClient.CloseIdleConnections()
const remove = true const remove = true
err := c.firewall.AcceptOutputFromIPPortToIPPort(ctx, "tcp", c.outboundInterface, err := c.firewall.AcceptOutputFromIPPortToIPPort(context.Background(), "tcp", c.outboundInterface,
sourceAddrPort, destinationAddrPort, remove) sourceAddrPort, destinationAddrPort, remove)
if err != nil { if err != nil {
errs = append(errs, fmt.Errorf("removing output traffic rule: %w", err)) errs = append(errs, fmt.Errorf("removing output traffic rule: %w", err))
@@ -76,7 +76,8 @@ func newHTTPSClient(destinationTLSName string, connection net.Conn) *http.Client
ServerName: destinationTLSName, ServerName: destinationTLSName,
} }
expectedAddress := net.JoinHostPort(destinationTLSName, "443") _, destinationPort, _ := net.SplitHostPort(connection.RemoteAddr().String())
expectedAddress := net.JoinHostPort(destinationTLSName, destinationPort)
httpTransport.DialContext = func(_ context.Context, network, address string) (net.Conn, error) { httpTransport.DialContext = func(_ context.Context, network, address string) (net.Conn, error) {
switch network { switch network {
case "tcp", "tcp4", "tcp6": case "tcp", "tcp4", "tcp6":
@@ -65,7 +65,7 @@ func Test_Client_OpenHTTPS(t *testing.T) {
return nil return nil
}) })
firewall.EXPECT().AcceptOutputFromIPPortToIPPort( firewall.EXPECT().AcceptOutputFromIPPortToIPPort(
ctx, "tcp", "eth0", sourceMatcher, destination, true, context.Background(), "tcp", "eth0", sourceMatcher, destination, true,
) )
const ipv6Supported = false const ipv6Supported = false
+2 -2
View File
@@ -142,8 +142,8 @@ func (c *Client) doHQuery(ctx context.Context, queryWire []byte,
} }
if response.StatusCode != http.StatusOK { if response.StatusCode != http.StatusOK {
return nil, fmt.Errorf("response status code is %s, data: %s", return nil, fmt.Errorf("response status code is %s (data length %d)",
response.Status, responseData) response.Status, len(responseData))
} }
responseMessage = new(dns.Msg) responseMessage = new(dns.Msg)