diff --git a/internal/restrictednet/client.go b/internal/restrictednet/client.go index 82091a75..fb070e8a 100644 --- a/internal/restrictednet/client.go +++ b/internal/restrictednet/client.go @@ -40,7 +40,11 @@ func New(settings Settings) *Client { } } -func (c *Client) OpenHTTPSByDomain(ctx context.Context, hostname string) ( +// OpenHTTPSByHostname opens an https connection through the firewall, +// valid for up to one second, to the hostname which in the format `host:port`. +// It first resolves the domain in hostname using DNS over HTTPS and then opens +// the restricted HTTPS connection to the resolved IP. +func (c *Client) OpenHTTPSByHostname(ctx context.Context, hostname string) ( httpClient *http.Client, cleanup func() error, err error, ) { host, portStr, err := net.SplitHostPort(hostname) diff --git a/internal/restrictednet/https.go b/internal/restrictednet/https.go index ea08c6c8..6912eff0 100644 --- a/internal/restrictednet/https.go +++ b/internal/restrictednet/https.go @@ -45,12 +45,12 @@ func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, desti cleanup = func() error { var errs []error httpClient.CloseIdleConnections() - err = connection.Close() + err := connection.Close() if err != nil && !errors.Is(err, net.ErrClosed) { errs = append(errs, fmt.Errorf("closing connection: %w", err)) } const remove = true - err := c.firewall.AcceptOutputFromIPPortToIPPort(context.Background(), "tcp", c.outboundInterface, + err = c.firewall.AcceptOutputFromIPPortToIPPort(context.Background(), "tcp", c.outboundInterface, sourceAddrPort, destinationAddrPort, remove) if err != nil { errs = append(errs, fmt.Errorf("removing output traffic rule: %w", err)) @@ -85,9 +85,17 @@ func newHTTPSClient(destinationTLSName string, dial dialFunc) *http.Client { } func makeDial(connection net.Conn, tlsName string) dialFunc { - _, destinationPort, _ := net.SplitHostPort(connection.RemoteAddr().String()) + _, destinationPort, err := net.SplitHostPort(connection.RemoteAddr().String()) + if err != nil { + panic(err) // connection remote address should always be in the form "host:port" + } expectedAddress := net.JoinHostPort(tlsName, destinationPort) + used := false return func(_ context.Context, network, address string) (net.Conn, error) { + if used { + return nil, errors.New("dial function called more than once") + } + used = true switch network { case "tcp", "tcp4", "tcp6": default: diff --git a/internal/restrictednet/https_test.go b/internal/restrictednet/https_test.go index b488f505..a977b5ff 100644 --- a/internal/restrictednet/https_test.go +++ b/internal/restrictednet/https_test.go @@ -77,7 +77,7 @@ func Test_Client_OpenHTTPS(t *testing.T) { }) firewall.EXPECT().AcceptOutputFromIPPortToIPPort( context.Background(), "tcp", "eth0", sourceMatcher, destinationAddrPort, true, - ) + ).Return(nil) const ipv6Supported = false upstreamResolvers := []provider.Provider{provider.Google()} @@ -98,10 +98,10 @@ func Test_Client_OpenHTTPS(t *testing.T) { require.NoError(t, err) response, err := httpClient.Do(request) - t.Cleanup(func() { - response.Body.Close() - }) require.NoError(t, err) + t.Cleanup(func() { + _ = response.Body.Close() + }) assert.Equal(t, http.StatusOK, response.StatusCode)