mirror of
https://github.com/qdm12/gluetun.git
synced 2026-06-25 21:37:31 +02:00
context aware connectFD
This commit is contained in:
@@ -144,26 +144,11 @@ func bindSourceConnection(destinationIP netip.Addr) (fd int, sourceAddr netip.Ad
|
|||||||
func connectSourceConnection(ctx context.Context, fd int, destinationAddrPort netip.AddrPort) (
|
func connectSourceConnection(ctx context.Context, fd int, destinationAddrPort netip.AddrPort) (
|
||||||
connection net.Conn, err error,
|
connection net.Conn, err error,
|
||||||
) {
|
) {
|
||||||
errCh := make(chan error)
|
err = connectFD(ctx, fd, destinationAddrPort)
|
||||||
go func() {
|
|
||||||
errCh <- connectFD(fd, destinationAddrPort)
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case err = <-errCh:
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
closeFD(fd)
|
closeFD(fd)
|
||||||
return nil, fmt.Errorf("connecting socket: %w", err)
|
return nil, fmt.Errorf("connecting socket: %w", err)
|
||||||
}
|
}
|
||||||
case <-ctx.Done():
|
|
||||||
err = ctx.Err()
|
|
||||||
closeFD(fd)
|
|
||||||
connectErr := <-errCh
|
|
||||||
if connectErr != nil {
|
|
||||||
err = fmt.Errorf("%w (%w)", connectErr, err)
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("connecting socket: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
file := os.NewFile(uintptr(fd), "")
|
file := os.NewFile(uintptr(fd), "")
|
||||||
if file == nil {
|
if file == nil {
|
||||||
|
|||||||
@@ -3,8 +3,11 @@
|
|||||||
package restrictednet
|
package restrictednet
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"time"
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
@@ -22,8 +25,53 @@ func bindFD(fd int, address netip.AddrPort) error {
|
|||||||
return unix.Bind(fd, bindAddr)
|
return unix.Bind(fd, bindAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func connectFD(fd int, destination netip.AddrPort) error {
|
func connectFD(ctx context.Context, fd int, destination netip.AddrPort) error {
|
||||||
return unix.Connect(fd, makeSockAddr(destination))
|
err := unix.Connect(fd, makeSockAddr(destination))
|
||||||
|
switch {
|
||||||
|
case err == nil:
|
||||||
|
return nil
|
||||||
|
case !errors.Is(err, unix.EINPROGRESS):
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
err = unix.Close(fd)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error closing fd: %w (%w)", err, ctx.Err())
|
||||||
|
}
|
||||||
|
return ctx.Err()
|
||||||
|
default:
|
||||||
|
wset := &unix.FdSet{}
|
||||||
|
wset.Bits[fd/64] |= 1 << (uint(fd) % 64)
|
||||||
|
eset := &unix.FdSet{}
|
||||||
|
eset.Bits[fd/64] |= 1 << (uint(fd) % 64)
|
||||||
|
const selectTimeout = 50 * time.Millisecond
|
||||||
|
timeval := unix.NsecToTimeval(int64(selectTimeout))
|
||||||
|
|
||||||
|
// Wait for the FD to become writable or hit an error state
|
||||||
|
n, err := unix.Select(fd+1, nil, wset, eset, &timeval)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, unix.EINTR) {
|
||||||
|
continue // Syscall interrupted, try again
|
||||||
|
}
|
||||||
|
return fmt.Errorf("select error: %w", err)
|
||||||
|
} else if n == 0 {
|
||||||
|
continue // no status change yet
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the socket encountered an error
|
||||||
|
n, err = unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_ERROR)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("getsockopt error: %w", err)
|
||||||
|
} else if n != 0 {
|
||||||
|
return fmt.Errorf("connect failed asynchronously: %w", unix.Errno(n))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func fdToSourceAddr(fd int) (sourceAddrPort netip.AddrPort, err error) {
|
func fdToSourceAddr(fd int) (sourceAddrPort netip.AddrPort, err error) {
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
package restrictednet
|
package restrictednet
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -18,7 +19,7 @@ func bindFD(fd int, address netip.AddrPort) error {
|
|||||||
panic("not implemented")
|
panic("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func connectFD(fd int, destination netip.AddrPort) error {
|
func connectFD(ctx context.Context, fd int, destination netip.AddrPort) error {
|
||||||
panic("not implemented")
|
panic("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user