Change tests to be more integration oriented

This commit is contained in:
Quentin McGaw
2026-06-09 14:04:32 +00:00
parent dd07205b85
commit b5366b9e44
7 changed files with 128 additions and 566 deletions
+28 -25
View File
@@ -17,15 +17,13 @@ import (
// OpenHTTPS opens temporary restrictive firewall output for one HTTPS destination.
// The returned cleanup function must be called to remove the temporary firewall rule and close connections.
func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, destinationIP netip.Addr,
func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, destinationAddrPort netip.AddrPort,
) (httpClient *http.Client, cleanup func() error, err error) {
fd, sourceAddrPort, err := bindSourceConnection(destinationIP)
fd, sourceAddrPort, err := bindSourceConnection(destinationAddrPort.Addr())
if err != nil {
return nil, nil, fmt.Errorf("binding source port: %w", err)
}
destinationAddrPort := netip.AddrPortFrom(destinationIP, c.httpsPort)
const remove = false
err = c.firewall.AcceptOutputFromIPPortToIPPort(ctx, "tcp", c.outboundInterface,
sourceAddrPort, destinationAddrPort, remove)
@@ -42,7 +40,8 @@ func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, desti
return nil, nil, fmt.Errorf("connecting source socket: %w", err)
}
httpClient = newHTTPSClient(c.baseTransport, destinationTLSName, connection)
dial := makeDial(connection, destinationTLSName)
httpClient = newHTTPSClient(destinationTLSName, dial)
cleanup = func() error {
var errs []error
httpClient.CloseIdleConnections()
@@ -53,7 +52,7 @@ func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, desti
errs = append(errs, fmt.Errorf("removing output traffic rule: %w", err))
}
err = connection.Close()
if err != nil {
if err != nil && !errors.Is(err, net.ErrClosed) {
errs = append(errs, fmt.Errorf("closing connection: %w", err))
}
if len(errs) > 0 {
@@ -64,21 +63,31 @@ func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, desti
return httpClient, cleanup, nil
}
func newHTTPSClient(baseTransport *http.Transport, destinationTLSName string, connection net.Conn) *http.Client {
transport := baseTransport.Clone()
transport.Proxy = nil
transport.MaxIdleConns = 1
transport.MaxIdleConnsPerHost = 1
transport.MaxConnsPerHost = 1
transport.IdleConnTimeout = time.Second
transport.TLSClientConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
ServerName: destinationTLSName,
}
type dialFunc func(ctx context.Context, network, address string) (net.Conn, error)
func newHTTPSClient(destinationTLSName string, dial dialFunc) *http.Client {
const timeout = 5 * time.Second
transport := &http.Transport{
MaxIdleConns: 1,
MaxIdleConnsPerHost: 1,
MaxConnsPerHost: 1,
IdleConnTimeout: time.Second,
TLSClientConfig: &tls.Config{
MinVersion: tls.VersionTLS12,
ServerName: destinationTLSName,
},
DialContext: dial,
}
return &http.Client{
Timeout: timeout,
Transport: transport,
}
}
func makeDial(connection net.Conn, tlsName string) dialFunc {
_, destinationPort, _ := net.SplitHostPort(connection.RemoteAddr().String())
expectedAddress := net.JoinHostPort(destinationTLSName, destinationPort)
transport.DialContext = func(_ context.Context, network, address string) (net.Conn, error) {
expectedAddress := net.JoinHostPort(tlsName, destinationPort)
return func(_ context.Context, network, address string) (net.Conn, error) {
switch network {
case "tcp", "tcp4", "tcp6":
default:
@@ -89,12 +98,6 @@ func newHTTPSClient(baseTransport *http.Transport, destinationTLSName string, co
}
return connection, nil
}
const timeout = 5 * time.Second
return &http.Client{
Timeout: timeout,
Transport: transport,
}
}
func bindSourceConnection(destinationIP netip.Addr) (fd int, sourceAddr netip.AddrPort, err error) {