Files
gluetun/internal/restrictednet/https.go
T
Quentin McGaw a9a36644ec imporatnt fix 1
2026-06-05 04:46:16 +00:00

109 lines
3.2 KiB
Go

package restrictednet
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"net/netip"
"time"
)
// OpenHTTPS opens temporary restrictive firewall output for one HTTPS destination.
// It connects to destinationIP on TCP port 443, allows output traffic from that
// connection's source address and port to the destination on the client's
// outbound interface, and returns an HTTP client sending its requests over that
// single connection, using destinationTLSName as the TLS server name.
// The returned cleanup function must be called to remove the temporary firewall rule
// and close connections. Cleanup does not use ctx, so it still works after ctx is canceled.
func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, destinationIP netip.Addr,
) (httpClient *http.Client, cleanup func() error, err error) {
	// Use the plain IPv4 form of IPv4-mapped IPv6 addresses so the socket,
	// the firewall rule and the packets on the wire share one address family.
	destinationIP = destinationIP.Unmap()

	// NOTE(review): bindSourceConnection completes the TCP handshake before the
	// firewall rule below is added. If the firewall drops new outbound
	// connections by default, that handshake is blocked and this call fails;
	// the source port would need to be reserved and allowed before connecting.
	// Confirm against the firewall's OUTPUT policy.
	connection, sourceAddrPort, err := bindSourceConnection(ctx, destinationIP)
	if err != nil {
		return nil, nil, fmt.Errorf("binding source port: %w", err)
	}

	const httpsPort = 443
	destinationAddrPort := netip.AddrPortFrom(destinationIP, httpsPort)
	const remove = false
	err = c.firewall.AcceptOutputFromIPPortToIPPort(ctx, "tcp", c.outboundInterface,
		sourceAddrPort, destinationAddrPort, remove)
	if err != nil {
		_ = connection.Close() // best effort: the firewall error is the one to report
		return nil, nil, fmt.Errorf("allowing output traffic through firewall: %w", err)
	}

	httpClient = newHTTPSClient(destinationTLSName, connection)
	cleanup = func() error {
		// Closes the transport's idle connection, which wraps connection.
		httpClient.CloseIdleConnections()

		var errs []error
		const remove = true
		// Detached from ctx: cleanup is usually deferred and may run after ctx
		// is canceled, yet the firewall rule must still be removed.
		cleanupCtx := context.Background()
		err := c.firewall.AcceptOutputFromIPPortToIPPort(cleanupCtx, "tcp", c.outboundInterface,
			sourceAddrPort, destinationAddrPort, remove)
		if err != nil {
			errs = append(errs, fmt.Errorf("removing output traffic rule: %w", err))
		}

		// The HTTP transport may already have closed the connection (idle close
		// above, idle timeout, server-side close); that is not a failure.
		err = connection.Close()
		if err != nil && !errors.Is(err, net.ErrClosed) {
			errs = append(errs, fmt.Errorf("closing connection: %w", err))
		}

		return errors.Join(errs...) // nil when errs is empty
	}
	return httpClient, cleanup, nil
}
// newHTTPSClient returns an HTTP client whose transport obtains its connection
// from newConnectionDialContext(connection) instead of dialing the request's
// host, and negotiates TLS (1.2 minimum) on it with destinationTLSName as the
// server name.
func newHTTPSClient(destinationTLSName string,
	connection net.Conn,
) *http.Client {
	httpTransport := http.DefaultTransport.(*http.Transport).Clone() //nolint:forcetypeassert
	// Ignore proxy environment variables: requests must go straight to the
	// destination reached through connection.
	httpTransport.Proxy = nil
	// Only one underlying connection exists. MaxConnsPerHost makes concurrent
	// requests wait for it instead of asking the dial function for another
	// connection, which it cannot provide.
	httpTransport.MaxIdleConns = 1
	httpTransport.MaxIdleConnsPerHost = 1
	httpTransport.MaxConnsPerHost = 1
	// NOTE(review): the transport closes the connection once it has been idle
	// this long; no replacement can be dialed, so later requests on this client
	// are expected to fail. Presumably acceptable for short-lived use.
	httpTransport.IdleConnTimeout = time.Second
	httpTransport.TLSClientConfig = &tls.Config{
		MinVersion: tls.VersionTLS12,
		ServerName: destinationTLSName,
	}
	httpTransport.DialContext = newConnectionDialContext(connection)
	// Client.Timeout bounds each request end to end, including the TLS
	// handshake and reading the response body.
	const timeout = 5 * time.Second
	return &http.Client{
		Timeout:   timeout,
		Transport: httpTransport,
	}
}
// errConnectionAlreadyDialed is returned when the HTTP transport asks for a
// second connection: only the single pre-established connection may be used.
var errConnectionAlreadyDialed = errors.New("pre-established connection already handed out")

// newConnectionDialContext returns a DialContext function that hands out the
// given pre-established connection exactly once, ignoring the requested
// network and address. Later calls fail with errConnectionAlreadyDialed rather
// than returning the same net.Conn again, which would either be closed already
// or be shared by two HTTP connections at once. A call made with an already
// canceled context returns the context error and keeps the connection available.
func newConnectionDialContext(connection net.Conn) func(ctx context.Context, network, _ string) (net.Conn, error) {
	available := make(chan net.Conn, 1)
	available <- connection
	return func(ctx context.Context, _, _ string) (net.Conn, error) {
		if err := ctx.Err(); err != nil {
			return nil, err
		}
		select {
		case conn := <-available:
			return conn, nil
		default:
			return nil, errConnectionAlreadyDialed
		}
	}
}
// bindSourceConnection connects over TCP to destinationIP on port 443 from the
// unspecified local address of the same IP family (0.0.0.0 or ::) with a
// kernel-chosen port, using a one second connect timeout. It returns the
// connection and the local address and port the connection ended up using.
//
// NOTE(review): despite the name, this completes the TCP handshake via
// DialContext rather than only binding a source port.
func bindSourceConnection(ctx context.Context, destinationIP netip.Addr) (
	connection net.Conn, sourceAddr netip.AddrPort, err error,
) {
	wildcard := netip.IPv6Unspecified()
	if destinationIP.Is4() {
		wildcard = netip.IPv4Unspecified()
	}

	const httpsPort = 443
	remote := netip.AddrPortFrom(destinationIP, httpsPort).String()
	local := net.TCPAddrFromAddrPort(netip.AddrPortFrom(wildcard, 0))

	dialer := net.Dialer{Timeout: time.Second, LocalAddr: local}
	conn, dialErr := dialer.DialContext(ctx, "tcp", remote)
	if dialErr != nil {
		return nil, netip.AddrPort{}, fmt.Errorf("binding TCP port: %w", dialErr)
	}

	localTCP := conn.LocalAddr().(*net.TCPAddr) //nolint:forcetypeassert
	return conn, localTCP.AddrPort(), nil
}