mirror of
https://github.com/qdm12/gluetun.git
synced 2026-06-25 21:37:31 +02:00
context aware connectFD
This commit is contained in:
@@ -3,8 +3,11 @@
|
||||
package restrictednet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
@@ -22,8 +25,53 @@ func bindFD(fd int, address netip.AddrPort) error {
|
||||
return unix.Bind(fd, bindAddr)
|
||||
}
|
||||
|
||||
func connectFD(fd int, destination netip.AddrPort) error {
|
||||
return unix.Connect(fd, makeSockAddr(destination))
|
||||
func connectFD(ctx context.Context, fd int, destination netip.AddrPort) error {
|
||||
err := unix.Connect(fd, makeSockAddr(destination))
|
||||
switch {
|
||||
case err == nil:
|
||||
return nil
|
||||
case !errors.Is(err, unix.EINPROGRESS):
|
||||
return err
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
err = unix.Close(fd)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error closing fd: %w (%w)", err, ctx.Err())
|
||||
}
|
||||
return ctx.Err()
|
||||
default:
|
||||
wset := &unix.FdSet{}
|
||||
wset.Bits[fd/64] |= 1 << (uint(fd) % 64)
|
||||
eset := &unix.FdSet{}
|
||||
eset.Bits[fd/64] |= 1 << (uint(fd) % 64)
|
||||
const selectTimeout = 50 * time.Millisecond
|
||||
timeval := unix.NsecToTimeval(int64(selectTimeout))
|
||||
|
||||
// Wait for the FD to become writable or hit an error state
|
||||
n, err := unix.Select(fd+1, nil, wset, eset, &timeval)
|
||||
if err != nil {
|
||||
if errors.Is(err, unix.EINTR) {
|
||||
continue // Syscall interrupted, try again
|
||||
}
|
||||
return fmt.Errorf("select error: %w", err)
|
||||
} else if n == 0 {
|
||||
continue // no status change yet
|
||||
}
|
||||
|
||||
// Check if the socket encountered an error
|
||||
n, err = unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_ERROR)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getsockopt error: %w", err)
|
||||
} else if n != 0 {
|
||||
return fmt.Errorf("connect failed asynchronously: %w", unix.Errno(n))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func fdToSourceAddr(fd int) (sourceAddrPort netip.AddrPort, err error) {
|
||||
|
||||
Reference in New Issue
Block a user