From 820689cc238b48062fde13827d2ec4c2cf06d7aa Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Fri, 5 Jun 2026 04:46:20 +0000 Subject: [PATCH] imporatnt fix 2 --- internal/restrictednet/https.go | 135 ++++++++++++++++++++++-------- internal/restrictednet/unix.go | 64 ++++++++++++++ internal/restrictednet/windows.go | 27 ++++++ 3 files changed, 193 insertions(+), 33 deletions(-) create mode 100644 internal/restrictednet/unix.go create mode 100644 internal/restrictednet/windows.go diff --git a/internal/restrictednet/https.go b/internal/restrictednet/https.go index 767d95e2..9444ab7a 100644 --- a/internal/restrictednet/https.go +++ b/internal/restrictednet/https.go @@ -8,14 +8,18 @@ import ( "net" "net/http" "net/netip" + "os" "time" + + "github.com/jsimonetti/rtnetlink" + "github.com/qdm12/gluetun/internal/pmtud/constants" ) // OpenHTTPS opens temporary restrictive firewall output for one HTTPS destination. // The returned cleanup function must be called to remove the temporary firewall rule and close connections. func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, destinationIP netip.Addr, ) (httpClient *http.Client, cleanup func() error, err error) { - connection, sourceAddrPort, err := bindSourceConnection(ctx, destinationIP) + fd, sourceAddrPort, err := bindSourceConnection(destinationIP) if err != nil { return nil, nil, fmt.Errorf("binding source port: %w", err) } @@ -27,10 +31,18 @@ func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, desti err = c.firewall.AcceptOutputFromIPPortToIPPort(ctx, "tcp", c.outboundInterface, sourceAddrPort, destinationAddrPort, remove) if err != nil { - _ = connection.Close() + closeFD(fd) return nil, nil, fmt.Errorf("allowing output traffic through firewall: %w", err) } + connection, err := connectSourceConnection(fd, destinationAddrPort) + if err != nil { + const remove = true + _ = c.firewall.AcceptOutputFromIPPortToIPPort(ctx, "tcp", c.outboundInterface, + sourceAddrPort, destinationAddrPort, remove) + return nil, nil, fmt.Errorf("connecting source socket: %w", err) + } + httpClient = newHTTPSClient(destinationTLSName, connection) cleanup = func() error { var errs []error @@ -53,9 +65,7 @@ func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, desti return httpClient, cleanup, nil } -func newHTTPSClient(destinationTLSName string, - connection net.Conn, -) *http.Client { +func newHTTPSClient(destinationTLSName string, connection net.Conn) *http.Client { httpTransport := http.DefaultTransport.(*http.Transport).Clone() //nolint:forcetypeassert httpTransport.Proxy = nil httpTransport.MaxIdleConns = 1 @@ -65,7 +75,9 @@ func newHTTPSClient(destinationTLSName string, MinVersion: tls.VersionTLS12, ServerName: destinationTLSName, } - httpTransport.DialContext = newConnectionDialContext(connection) + httpTransport.DialContext = func(_ context.Context, _, _ string) (net.Conn, error) { + return connection, nil + } const timeout = 5 * time.Second return &http.Client{ @@ -74,35 +86,92 @@ func newHTTPSClient(destinationTLSName string, } } -func newConnectionDialContext(connection net.Conn) func(ctx context.Context, network, _ string) (net.Conn, error) { - return func(ctx context.Context, network, _ string) (net.Conn, error) { - return connection, nil - } -} - -func bindSourceConnection(ctx context.Context, destinationIP netip.Addr) ( - connection net.Conn, sourceAddr netip.AddrPort, err error, -) { - var bindAddr netip.Addr - if destinationIP.Is4() { - bindAddr = netip.AddrFrom4([4]byte{}) - } else { - bindAddr = netip.AddrFrom16([16]byte{}) - } - - const httpsPort = 443 - destinationAddrPort := netip.AddrPortFrom(destinationIP, httpsPort) - dialer := &net.Dialer{ - Timeout: time.Second, - LocalAddr: net.TCPAddrFromAddrPort(netip.AddrPortFrom(bindAddr, 0)), - } - connection, err = dialer.DialContext(ctx, "tcp", destinationAddrPort.String()) +func bindSourceConnection(destinationIP netip.Addr) (fd int, sourceAddr netip.AddrPort, err error) { + sourceIP, err := sourceIPForDestination(destinationIP) if err != nil { - return nil, netip.AddrPort{}, fmt.Errorf("binding TCP port: %w", err) + return 0, netip.AddrPort{}, fmt.Errorf("finding source IP: %w", err) } - tcpAddr := connection.LocalAddr().(*net.TCPAddr) //nolint:forcetypeassert - sourceAddr = tcpAddr.AddrPort() + family := constants.AF_INET + if sourceIP.Is6() { + family = constants.AF_INET6 + } - return connection, sourceAddr, nil + fd, err = newTCPSockStream(family) + if err != nil { + return 0, netip.AddrPort{}, fmt.Errorf("creating socket: %w", err) + } + + bindAddrPort := netip.AddrPortFrom(sourceIP, 0) + err = bindFD(fd, bindAddrPort) + if err != nil { + closeFD(fd) + return 0, netip.AddrPort{}, fmt.Errorf("binding socket: %w", err) + } + + sourceAddr, err = fdToSourceAddr(fd) + if err != nil { + closeFD(fd) + return 0, netip.AddrPort{}, fmt.Errorf("getting source address: %w", err) + } + + return fd, sourceAddr, nil +} + +func connectSourceConnection(fd int, destinationAddrPort netip.AddrPort) (connection net.Conn, err error) { + err = connectFD(fd, destinationAddrPort) + if err != nil { + closeFD(fd) + return nil, fmt.Errorf("connecting socket: %w", err) + } + + file := os.NewFile(uintptr(fd), "") + if file == nil { + closeFD(fd) + return nil, fmt.Errorf("creating socket file") + } + defer file.Close() + + connection, err = net.FileConn(file) + if err != nil { + return nil, fmt.Errorf("wrapping socket connection: %w", err) + } + + return connection, nil +} + +func sourceIPForDestination(destinationIP netip.Addr) (srcIP netip.Addr, err error) { + conn, err := rtnetlink.Dial(nil) + if err != nil { + return netip.Addr{}, err + } + defer conn.Close() + + family := uint8(constants.AF_INET) + if destinationIP.Is6() { + family = constants.AF_INET6 + } + + requestMessage := &rtnetlink.RouteMessage{ + Family: family, + Attributes: rtnetlink.RouteAttributes{ + Dst: destinationIP.AsSlice(), + }, + } + messages, err := conn.Route.Get(requestMessage) + if err != nil { + return netip.Addr{}, fmt.Errorf("getting routes to %s: %w", destinationIP, err) + } + + for _, message := range messages { + if message.Attributes.Src == nil { + continue + } + if message.Attributes.Src.To4() == nil { + return netip.AddrFrom16([16]byte(message.Attributes.Src)), nil + } + return netip.AddrFrom4([4]byte(message.Attributes.Src)), nil + } + + return netip.Addr{}, fmt.Errorf("no route to %s", destinationIP) } diff --git a/internal/restrictednet/unix.go b/internal/restrictednet/unix.go new file mode 100644 index 00000000..76895943 --- /dev/null +++ b/internal/restrictednet/unix.go @@ -0,0 +1,64 @@ +//go:build unix + +package restrictednet + +import ( + "fmt" + "net/netip" + + "golang.org/x/sys/unix" +) + +func closeFD(fd int) { + unix.Close(fd) +} + +func newTCPSockStream(family int) (fd int, err error) { + return unix.Socket(family, unix.SOCK_STREAM, unix.IPPROTO_TCP) +} + +func bindFD(fd int, address netip.AddrPort) error { + bindAddr := makeSockAddr(address) + return unix.Bind(fd, bindAddr) +} + +func connectFD(fd int, destination netip.AddrPort) error { + return unix.Connect(fd, makeSockAddr(destination)) +} + +func fdToSourceAddr(fd int) (sourceAddrPort netip.AddrPort, err error) { + sockAddr, err := unix.Getsockname(fd) + if err != nil { + return netip.AddrPort{}, fmt.Errorf("getting sockname: %w", err) + } + + sourceAddrPort, err = sockAddrToAddrPort(sockAddr) + if err != nil { + return netip.AddrPort{}, err + } + return sourceAddrPort, nil +} + +func makeSockAddr(addressPort netip.AddrPort) unix.Sockaddr { + if addressPort.Addr().Is4() { + return &unix.SockaddrInet4{ + Port: int(addressPort.Port()), + Addr: addressPort.Addr().As4(), + } + } + return &unix.SockaddrInet6{ + Port: int(addressPort.Port()), + Addr: addressPort.Addr().As16(), + } +} + +func sockAddrToAddrPort(sockAddr unix.Sockaddr) (addrPort netip.AddrPort, err error) { + switch typedSockAddr := sockAddr.(type) { + case *unix.SockaddrInet4: + return netip.AddrPortFrom(netip.AddrFrom4(typedSockAddr.Addr), uint16(typedSockAddr.Port)), nil //nolint:gosec + case *unix.SockaddrInet6: + return netip.AddrPortFrom(netip.AddrFrom16(typedSockAddr.Addr), uint16(typedSockAddr.Port)), nil //nolint:gosec + default: + return netip.AddrPort{}, fmt.Errorf("unexpected socket address type %T", typedSockAddr) + } +} diff --git a/internal/restrictednet/windows.go b/internal/restrictednet/windows.go new file mode 100644 index 00000000..e1b88453 --- /dev/null +++ b/internal/restrictednet/windows.go @@ -0,0 +1,27 @@ +//go:build windows + +package restrictednet + +import ( + "net/netip" +) + +func closeFD(fd int) { + panic("not implemented") +} + +func newTCPSockStream(family int) (fd int, err error) { + panic("not implemented") +} + +func bindFD(fd int, address netip.AddrPort) error { + panic("not implemented") +} + +func connectFD(fd int, destination netip.AddrPort) error { + panic("not implemented") +} + +func fdToSourceAddr(fd int) (sourceAddrPort netip.AddrPort, err error) { + panic("not implemented") +}