From 8da913d7c6ae9faded223bc7606d17327b830a62 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Fri, 5 Jun 2026 15:35:28 +0000 Subject: [PATCH] context aware connectSourceConnection --- internal/restrictednet/https.go | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/internal/restrictednet/https.go b/internal/restrictednet/https.go index 1bb5bb48..209e68f0 100644 --- a/internal/restrictednet/https.go +++ b/internal/restrictednet/https.go @@ -34,7 +34,7 @@ func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, desti return nil, nil, fmt.Errorf("allowing output traffic through firewall: %w", err) } - connection, err := connectSourceConnection(fd, destinationAddrPort) + connection, err := connectSourceConnection(ctx, fd, destinationAddrPort) if err != nil { const remove = true _ = c.firewall.AcceptOutputFromIPPortToIPPort(ctx, "tcp", c.outboundInterface, @@ -129,10 +129,27 @@ func bindSourceConnection(destinationIP netip.Addr) (fd int, sourceAddr netip.Ad return fd, sourceAddr, nil } -func connectSourceConnection(fd int, destinationAddrPort netip.AddrPort) (connection net.Conn, err error) { - err = connectFD(fd, destinationAddrPort) - if err != nil { +func connectSourceConnection(ctx context.Context, fd int, destinationAddrPort netip.AddrPort) ( + connection net.Conn, err error, +) { + errCh := make(chan error) + go func() { + errCh <- connectFD(fd, destinationAddrPort) + }() + + select { + case err = <-errCh: + if err != nil { + closeFD(fd) + return nil, fmt.Errorf("connecting socket: %w", err) + } + case <-ctx.Done(): + err = ctx.Err() closeFD(fd) + connectErr := <-errCh + if connectErr != nil { + err = fmt.Errorf("%w (%w)", connectErr, err) + } return nil, fmt.Errorf("connecting socket: %w", err) }