This commit is contained in:
Quentin McGaw
2026-06-05 03:56:25 +00:00
parent ff6e45fae0
commit aa781c6cc5
12 changed files with 599 additions and 0 deletions
+115
View File
@@ -0,0 +1,115 @@
package restrictednet
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"net/netip"
"time"
)
// OpenHTTPS opens temporary restrictive firewall output for one HTTPS destination.
// The returned cleanup function must be called to remove the temporary firewall rule and close connections.
func (c *Client) OpenHTTPS(destinationTLSName string, destinationIP netip.Addr,
) (httpClient *http.Client, cleanup func() error, err error) {
listener, sourceAddrPort, err := bindSourcePort(destinationIP)
if err != nil {
return nil, nil, fmt.Errorf("binding source port: %w", err)
}
const httpsPort = 443
destinationAddrPort := netip.AddrPortFrom(destinationIP, httpsPort)
const remove = false
ctx := context.Background() // it's a quick firewall change, worth not passing a context
err = c.firewall.AcceptOutputFromIPPortToIPPort(ctx, "tcp", c.outboundInterface,
sourceAddrPort, destinationAddrPort, remove)
if err != nil {
_ = listener.Close()
return nil, nil, fmt.Errorf("allowing output traffic through firewall: %w", err)
}
httpClient = newHTTPSClient(destinationTLSName, destinationIP, sourceAddrPort)
cleanup = func() error {
var errs []error
httpClient.CloseIdleConnections()
const remove = true
err := c.firewall.AcceptOutputFromIPPortToIPPort(ctx, "tcp", c.outboundInterface,
sourceAddrPort, destinationAddrPort, remove)
if err != nil {
errs = append(errs, fmt.Errorf("removing output traffic rule: %w", err))
}
err = listener.Close()
if err != nil {
errs = append(errs, fmt.Errorf("closing listener: %w", err))
}
if len(errs) > 0 {
return errors.Join(errs...)
}
return nil
}
return httpClient, cleanup, nil
}
func newHTTPSClient(destinationTLSName string,
destinationIP netip.Addr, sourceAddress netip.AddrPort,
) *http.Client {
httpTransport := http.DefaultTransport.(*http.Transport).Clone() //nolint:forcetypeassert
httpTransport.Proxy = nil
httpTransport.MaxIdleConns = 1
httpTransport.MaxIdleConnsPerHost = 1
httpTransport.IdleConnTimeout = time.Second
httpTransport.TLSClientConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
ServerName: destinationTLSName,
}
httpTransport.DialContext = newBoundDialContext(destinationIP, sourceAddress)
const timeout = 5 * time.Second
return &http.Client{
Timeout: timeout,
Transport: httpTransport,
}
}
func newBoundDialContext(destinationAddress netip.Addr,
sourceAddress netip.AddrPort,
) func(ctx context.Context, network, _ string) (net.Conn, error) {
const httpsPort = 443
destinationAddrPort := netip.AddrPortFrom(destinationAddress, httpsPort).String()
return func(ctx context.Context, network, _ string) (net.Conn, error) {
const timeout = 2 * time.Second
dialer := &net.Dialer{Timeout: timeout}
dialer.LocalAddr = net.TCPAddrFromAddrPort(sourceAddress)
connection, err := dialer.DialContext(ctx, network, destinationAddrPort)
if err != nil {
return nil, fmt.Errorf("%s dialing %s: %w", network, destinationAddrPort, err)
}
return connection, nil
}
}
func bindSourcePort(destinationIP netip.Addr) (
listener net.Listener, sourceAddr netip.AddrPort, err error,
) {
var bindAddr netip.Addr
if destinationIP.Is4() {
bindAddr = netip.AddrFrom4([4]byte{})
} else {
bindAddr = netip.AddrFrom16([16]byte{})
}
listener, err = net.ListenTCP("tcp", net.TCPAddrFromAddrPort(
netip.AddrPortFrom(bindAddr, 0)))
if err != nil {
return nil, netip.AddrPort{}, fmt.Errorf("binding TCP port: %w", err)
}
tcpAddr := listener.Addr().(*net.TCPAddr) //nolint:forcetypeassert
sourceAddr = tcpAddr.AddrPort()
return listener, sourceAddr, nil
}