context aware connectFD

This commit is contained in:
Quentin McGaw
2026-06-11 13:06:05 +00:00
parent 9af6aaff27
commit 70d80f7473
3 changed files with 54 additions and 20 deletions
+2 -17
View File
@@ -144,24 +144,9 @@ func bindSourceConnection(destinationIP netip.Addr) (fd int, sourceAddr netip.Ad
func connectSourceConnection(ctx context.Context, fd int, destinationAddrPort netip.AddrPort) ( func connectSourceConnection(ctx context.Context, fd int, destinationAddrPort netip.AddrPort) (
connection net.Conn, err error, connection net.Conn, err error,
) { ) {
errCh := make(chan error) err = connectFD(ctx, fd, destinationAddrPort)
go func() { if err != nil {
errCh <- connectFD(fd, destinationAddrPort)
}()
select {
case err = <-errCh:
if err != nil {
closeFD(fd)
return nil, fmt.Errorf("connecting socket: %w", err)
}
case <-ctx.Done():
err = ctx.Err()
closeFD(fd) closeFD(fd)
connectErr := <-errCh
if connectErr != nil {
err = fmt.Errorf("%w (%w)", connectErr, err)
}
return nil, fmt.Errorf("connecting socket: %w", err) return nil, fmt.Errorf("connecting socket: %w", err)
} }
+50 -2
View File
@@ -3,8 +3,11 @@
package restrictednet package restrictednet
import ( import (
"context"
"errors"
"fmt" "fmt"
"net/netip" "net/netip"
"time"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
@@ -22,8 +25,53 @@ func bindFD(fd int, address netip.AddrPort) error {
return unix.Bind(fd, bindAddr) return unix.Bind(fd, bindAddr)
} }
func connectFD(fd int, destination netip.AddrPort) error { func connectFD(ctx context.Context, fd int, destination netip.AddrPort) error {
return unix.Connect(fd, makeSockAddr(destination)) err := unix.Connect(fd, makeSockAddr(destination))
switch {
case err == nil:
return nil
case !errors.Is(err, unix.EINPROGRESS):
return err
}
for {
select {
case <-ctx.Done():
err = unix.Close(fd)
if err != nil {
return fmt.Errorf("error closing fd: %w (%w)", err, ctx.Err())
}
return ctx.Err()
default:
wset := &unix.FdSet{}
wset.Bits[fd/64] |= 1 << (uint(fd) % 64)
eset := &unix.FdSet{}
eset.Bits[fd/64] |= 1 << (uint(fd) % 64)
const selectTimeout = 50 * time.Millisecond
timeval := unix.NsecToTimeval(int64(selectTimeout))
// Wait for the FD to become writable or hit an error state
n, err := unix.Select(fd+1, nil, wset, eset, &timeval)
if err != nil {
if errors.Is(err, unix.EINTR) {
continue // Syscall interrupted, try again
}
return fmt.Errorf("select error: %w", err)
} else if n == 0 {
continue // no status change yet
}
// Check if the socket encountered an error
n, err = unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_ERROR)
if err != nil {
return fmt.Errorf("getsockopt error: %w", err)
} else if n != 0 {
return fmt.Errorf("connect failed asynchronously: %w", unix.Errno(n))
}
return nil
}
}
} }
func fdToSourceAddr(fd int) (sourceAddrPort netip.AddrPort, err error) { func fdToSourceAddr(fd int) (sourceAddrPort netip.AddrPort, err error) {
+2 -1
View File
@@ -3,6 +3,7 @@
package restrictednet package restrictednet
import ( import (
"context"
"net/netip" "net/netip"
) )
@@ -18,7 +19,7 @@ func bindFD(fd int, address netip.AddrPort) error {
panic("not implemented") panic("not implemented")
} }
func connectFD(fd int, destination netip.AddrPort) error { func connectFD(ctx context.Context, fd int, destination netip.AddrPort) error {
panic("not implemented") panic("not implemented")
} }