From 70d80f7473f66a5b04ec7e88bf590267d271ecd5 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Thu, 11 Jun 2026 13:06:05 +0000 Subject: [PATCH] context aware connectFD --- internal/restrictednet/https.go | 19 ++--------- internal/restrictednet/unix.go | 52 +++++++++++++++++++++++++++++-- internal/restrictednet/windows.go | 3 +- 3 files changed, 54 insertions(+), 20 deletions(-) diff --git a/internal/restrictednet/https.go b/internal/restrictednet/https.go index 209b5a9f..1ad6d891 100644 --- a/internal/restrictednet/https.go +++ b/internal/restrictednet/https.go @@ -144,24 +144,9 @@ func bindSourceConnection(destinationIP netip.Addr) (fd int, sourceAddr netip.Ad func connectSourceConnection(ctx context.Context, fd int, destinationAddrPort netip.AddrPort) ( connection net.Conn, err error, ) { - errCh := make(chan error) - go func() { - errCh <- connectFD(fd, destinationAddrPort) - }() - - select { - case err = <-errCh: - if err != nil { - closeFD(fd) - return nil, fmt.Errorf("connecting socket: %w", err) - } - case <-ctx.Done(): - err = ctx.Err() + err = connectFD(ctx, fd, destinationAddrPort) + if err != nil { closeFD(fd) - connectErr := <-errCh - if connectErr != nil { - err = fmt.Errorf("%w (%w)", connectErr, err) - } return nil, fmt.Errorf("connecting socket: %w", err) } diff --git a/internal/restrictednet/unix.go b/internal/restrictednet/unix.go index 968f8d30..387233cc 100644 --- a/internal/restrictednet/unix.go +++ b/internal/restrictednet/unix.go @@ -3,8 +3,11 @@ package restrictednet import ( + "context" + "errors" "fmt" "net/netip" + "time" "golang.org/x/sys/unix" ) @@ -22,8 +25,53 @@ func bindFD(fd int, address netip.AddrPort) error { return unix.Bind(fd, bindAddr) } -func connectFD(fd int, destination netip.AddrPort) error { - return unix.Connect(fd, makeSockAddr(destination)) +func connectFD(ctx context.Context, fd int, destination netip.AddrPort) error { + err := unix.Connect(fd, makeSockAddr(destination)) + switch { + case err == nil: + return nil + case !errors.Is(err, unix.EINPROGRESS): + return err + } + + for { + select { + case <-ctx.Done(): + err = unix.Close(fd) + if err != nil { + return fmt.Errorf("error closing fd: %w (%w)", err, ctx.Err()) + } + return ctx.Err() + default: + wset := &unix.FdSet{} + wset.Bits[fd/64] |= 1 << (uint(fd) % 64) + eset := &unix.FdSet{} + eset.Bits[fd/64] |= 1 << (uint(fd) % 64) + const selectTimeout = 50 * time.Millisecond + timeval := unix.NsecToTimeval(int64(selectTimeout)) + + // Wait for the FD to become writable or hit an error state + n, err := unix.Select(fd+1, nil, wset, eset, &timeval) + if err != nil { + if errors.Is(err, unix.EINTR) { + continue // Syscall interrupted, try again + } + return fmt.Errorf("select error: %w", err) + } else if n == 0 { + continue // no status change yet + } + + // Check if the socket encountered an error + n, err = unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_ERROR) + if err != nil { + return fmt.Errorf("getsockopt error: %w", err) + } else if n != 0 { + return fmt.Errorf("connect failed asynchronously: %w", unix.Errno(n)) + } + + return nil + } + } } func fdToSourceAddr(fd int) (sourceAddrPort netip.AddrPort, err error) { diff --git a/internal/restrictednet/windows.go b/internal/restrictednet/windows.go index e1b88453..454fc2c6 100644 --- a/internal/restrictednet/windows.go +++ b/internal/restrictednet/windows.go @@ -3,6 +3,7 @@ package restrictednet import ( + "context" "net/netip" ) @@ -18,7 +19,7 @@ func bindFD(fd int, address netip.AddrPort) error { panic("not implemented") } -func connectFD(fd int, destination netip.AddrPort) error { +func connectFD(ctx context.Context, fd int, destination netip.AddrPort) error { panic("not implemented") }