context aware connectFD

This commit is contained in:
Quentin McGaw
2026-06-11 13:06:05 +00:00
parent 9af6aaff27
commit 70d80f7473
3 changed files with 54 additions and 20 deletions
+1 -16
View File
@@ -144,26 +144,11 @@ func bindSourceConnection(destinationIP netip.Addr) (fd int, sourceAddr netip.Ad
func connectSourceConnection(ctx context.Context, fd int, destinationAddrPort netip.AddrPort) (
connection net.Conn, err error,
) {
errCh := make(chan error)
go func() {
errCh <- connectFD(fd, destinationAddrPort)
}()
select {
case err = <-errCh:
err = connectFD(ctx, fd, destinationAddrPort)
if err != nil {
closeFD(fd)
return nil, fmt.Errorf("connecting socket: %w", err)
}
case <-ctx.Done():
err = ctx.Err()
closeFD(fd)
connectErr := <-errCh
if connectErr != nil {
err = fmt.Errorf("%w (%w)", connectErr, err)
}
return nil, fmt.Errorf("connecting socket: %w", err)
}
file := os.NewFile(uintptr(fd), "")
if file == nil {
+50 -2
View File
@@ -3,8 +3,11 @@
package restrictednet
import (
"context"
"errors"
"fmt"
"net/netip"
"time"
"golang.org/x/sys/unix"
)
@@ -22,8 +25,53 @@ func bindFD(fd int, address netip.AddrPort) error {
return unix.Bind(fd, bindAddr)
}
func connectFD(fd int, destination netip.AddrPort) error {
return unix.Connect(fd, makeSockAddr(destination))
func connectFD(ctx context.Context, fd int, destination netip.AddrPort) error {
err := unix.Connect(fd, makeSockAddr(destination))
switch {
case err == nil:
return nil
case !errors.Is(err, unix.EINPROGRESS):
return err
}
for {
select {
case <-ctx.Done():
err = unix.Close(fd)
if err != nil {
return fmt.Errorf("error closing fd: %w (%w)", err, ctx.Err())
}
return ctx.Err()
default:
wset := &unix.FdSet{}
wset.Bits[fd/64] |= 1 << (uint(fd) % 64)
eset := &unix.FdSet{}
eset.Bits[fd/64] |= 1 << (uint(fd) % 64)
const selectTimeout = 50 * time.Millisecond
timeval := unix.NsecToTimeval(int64(selectTimeout))
// Wait for the FD to become writable or hit an error state
n, err := unix.Select(fd+1, nil, wset, eset, &timeval)
if err != nil {
if errors.Is(err, unix.EINTR) {
continue // Syscall interrupted, try again
}
return fmt.Errorf("select error: %w", err)
} else if n == 0 {
continue // no status change yet
}
// Check if the socket encountered an error
n, err = unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_ERROR)
if err != nil {
return fmt.Errorf("getsockopt error: %w", err)
} else if n != 0 {
return fmt.Errorf("connect failed asynchronously: %w", unix.Errno(n))
}
return nil
}
}
}
func fdToSourceAddr(fd int) (sourceAddrPort netip.AddrPort, err error) {
+2 -1
View File
@@ -3,6 +3,7 @@
package restrictednet
import (
"context"
"net/netip"
)
@@ -18,7 +19,7 @@ func bindFD(fd int, address netip.AddrPort) error {
panic("not implemented")
}
func connectFD(fd int, destination netip.AddrPort) error {
func connectFD(ctx context.Context, fd int, destination netip.AddrPort) error {
panic("not implemented")
}