PR feedback fixes

This commit is contained in:
Quentin McGaw
2026-06-09 21:11:15 +00:00
parent 29186feccc
commit 69b4e5c584
3 changed files with 20 additions and 8 deletions
+5 -1
View File
@@ -40,7 +40,11 @@ func New(settings Settings) *Client {
} }
} }
func (c *Client) OpenHTTPSByDomain(ctx context.Context, hostname string) ( // OpenHTTPSByHostname opens an https connection through the firewall,
// valid for up to one second, to the hostname which in the format `host:port`.
// It first resolves the domain in hostname using DNS over HTTPS and then opens
// the restricted HTTPS connection to the resolved IP.
func (c *Client) OpenHTTPSByHostname(ctx context.Context, hostname string) (
httpClient *http.Client, cleanup func() error, err error, httpClient *http.Client, cleanup func() error, err error,
) { ) {
host, portStr, err := net.SplitHostPort(hostname) host, portStr, err := net.SplitHostPort(hostname)
+11 -3
View File
@@ -45,12 +45,12 @@ func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, desti
cleanup = func() error { cleanup = func() error {
var errs []error var errs []error
httpClient.CloseIdleConnections() httpClient.CloseIdleConnections()
err = connection.Close() err := connection.Close()
if err != nil && !errors.Is(err, net.ErrClosed) { if err != nil && !errors.Is(err, net.ErrClosed) {
errs = append(errs, fmt.Errorf("closing connection: %w", err)) errs = append(errs, fmt.Errorf("closing connection: %w", err))
} }
const remove = true const remove = true
err := c.firewall.AcceptOutputFromIPPortToIPPort(context.Background(), "tcp", c.outboundInterface, err = c.firewall.AcceptOutputFromIPPortToIPPort(context.Background(), "tcp", c.outboundInterface,
sourceAddrPort, destinationAddrPort, remove) sourceAddrPort, destinationAddrPort, remove)
if err != nil { if err != nil {
errs = append(errs, fmt.Errorf("removing output traffic rule: %w", err)) errs = append(errs, fmt.Errorf("removing output traffic rule: %w", err))
@@ -85,9 +85,17 @@ func newHTTPSClient(destinationTLSName string, dial dialFunc) *http.Client {
} }
func makeDial(connection net.Conn, tlsName string) dialFunc { func makeDial(connection net.Conn, tlsName string) dialFunc {
_, destinationPort, _ := net.SplitHostPort(connection.RemoteAddr().String()) _, destinationPort, err := net.SplitHostPort(connection.RemoteAddr().String())
if err != nil {
panic(err) // connection remote address should always be in the form "host:port"
}
expectedAddress := net.JoinHostPort(tlsName, destinationPort) expectedAddress := net.JoinHostPort(tlsName, destinationPort)
used := false
return func(_ context.Context, network, address string) (net.Conn, error) { return func(_ context.Context, network, address string) (net.Conn, error) {
if used {
return nil, errors.New("dial function called more than once")
}
used = true
switch network { switch network {
case "tcp", "tcp4", "tcp6": case "tcp", "tcp4", "tcp6":
default: default:
+4 -4
View File
@@ -77,7 +77,7 @@ func Test_Client_OpenHTTPS(t *testing.T) {
}) })
firewall.EXPECT().AcceptOutputFromIPPortToIPPort( firewall.EXPECT().AcceptOutputFromIPPortToIPPort(
context.Background(), "tcp", "eth0", sourceMatcher, destinationAddrPort, true, context.Background(), "tcp", "eth0", sourceMatcher, destinationAddrPort, true,
) ).Return(nil)
const ipv6Supported = false const ipv6Supported = false
upstreamResolvers := []provider.Provider{provider.Google()} upstreamResolvers := []provider.Provider{provider.Google()}
@@ -98,10 +98,10 @@ func Test_Client_OpenHTTPS(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
response, err := httpClient.Do(request) response, err := httpClient.Do(request)
t.Cleanup(func() {
response.Body.Close()
})
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(func() {
_ = response.Body.Close()
})
assert.Equal(t, http.StatusOK, response.StatusCode) assert.Equal(t, http.StatusOK, response.StatusCode)