imporatnt fix 1

This commit is contained in:
Quentin McGaw
2026-06-05 04:46:16 +00:00
parent fad8c9889a
commit a9a36644ec
4 changed files with 41 additions and 34 deletions
+1 -1
View File
@@ -47,7 +47,7 @@ func (c *Client) OpenHTTPSByDomain(ctx context.Context, domain string) (
selectedIP := resolvedIPs[0]
httpClient, cleanup, err = c.OpenHTTPS(domain, selectedIP)
httpClient, cleanup, err = c.OpenHTTPS(ctx, domain, selectedIP)
if err != nil {
return nil, nil, fmt.Errorf("opening HTTPS: %w", err)
}
+19 -5
View File
@@ -2,6 +2,7 @@ package restrictednet
import (
"context"
"net"
"net/netip"
"testing"
@@ -34,15 +35,28 @@ func (m listenAddrPortMatcher) String() string {
func Test_Client_OpenHTTPS(t *testing.T) {
t.Parallel()
ctx := t.Context()
netConfig := net.ListenConfig{}
listener, err := netConfig.Listen(ctx, "tcp", "127.0.0.1:443")
require.NoError(t, err)
t.Cleanup(func() {
_ = listener.Close()
})
go func() {
connection, acceptErr := listener.Accept()
if acceptErr == nil {
_ = connection.Close()
}
}()
ctrl := gomock.NewController(t)
firewall := NewMockFirewall(ctrl)
destination := netip.MustParseAddrPort("1.2.3.4:443")
backgroundContext := context.Background()
destination := netip.MustParseAddrPort("127.0.0.1:443")
sourceMatcher := listenAddrPortMatcher{}
firewall.EXPECT().AcceptOutputFromIPPortToIPPort(
backgroundContext, "tcp", "eth0", sourceMatcher, destination, false,
ctx, "tcp", "eth0", sourceMatcher, destination, false,
).DoAndReturn(func(_ context.Context,
_, _ string, source, _ netip.AddrPort, _ bool,
) error {
@@ -50,7 +64,7 @@ func Test_Client_OpenHTTPS(t *testing.T) {
return nil
})
firewall.EXPECT().AcceptOutputFromIPPortToIPPort(
backgroundContext, "tcp", "eth0", sourceMatcher, destination, true,
ctx, "tcp", "eth0", sourceMatcher, destination, true,
)
const ipv6Supported = false
@@ -58,7 +72,7 @@ func Test_Client_OpenHTTPS(t *testing.T) {
client, err := New(firewall, "eth0", ipv6Supported, upstreamResolvers)
require.NoError(t, err)
httpClient, cleanup, err := client.OpenHTTPS("api.example.com", netip.MustParseAddr("1.2.3.4"))
httpClient, cleanup, err := client.OpenHTTPS(ctx, "api.example.com", netip.MustParseAddr("127.0.0.1"))
require.NoError(t, err)
require.NotNil(t, httpClient)
require.NotNil(t, cleanup)
+20 -27
View File
@@ -13,9 +13,9 @@ import (
// OpenHTTPS opens temporary restrictive firewall output for one HTTPS destination.
// The returned cleanup function must be called to remove the temporary firewall rule and close connections.
func (c *Client) OpenHTTPS(destinationTLSName string, destinationIP netip.Addr,
func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, destinationIP netip.Addr,
) (httpClient *http.Client, cleanup func() error, err error) {
listener, sourceAddrPort, err := bindSourcePort(destinationIP)
connection, sourceAddrPort, err := bindSourceConnection(ctx, destinationIP)
if err != nil {
return nil, nil, fmt.Errorf("binding source port: %w", err)
}
@@ -24,15 +24,14 @@ func (c *Client) OpenHTTPS(destinationTLSName string, destinationIP netip.Addr,
destinationAddrPort := netip.AddrPortFrom(destinationIP, httpsPort)
const remove = false
ctx := context.Background() // it's a quick firewall change, worth not passing a context
err = c.firewall.AcceptOutputFromIPPortToIPPort(ctx, "tcp", c.outboundInterface,
sourceAddrPort, destinationAddrPort, remove)
if err != nil {
_ = listener.Close()
_ = connection.Close()
return nil, nil, fmt.Errorf("allowing output traffic through firewall: %w", err)
}
httpClient = newHTTPSClient(destinationTLSName, destinationIP, sourceAddrPort)
httpClient = newHTTPSClient(destinationTLSName, connection)
cleanup = func() error {
var errs []error
httpClient.CloseIdleConnections()
@@ -42,9 +41,9 @@ func (c *Client) OpenHTTPS(destinationTLSName string, destinationIP netip.Addr,
if err != nil {
errs = append(errs, fmt.Errorf("removing output traffic rule: %w", err))
}
err = listener.Close()
err = connection.Close()
if err != nil {
errs = append(errs, fmt.Errorf("closing listener: %w", err))
errs = append(errs, fmt.Errorf("closing connection: %w", err))
}
if len(errs) > 0 {
return errors.Join(errs...)
@@ -55,7 +54,7 @@ func (c *Client) OpenHTTPS(destinationTLSName string, destinationIP netip.Addr,
}
func newHTTPSClient(destinationTLSName string,
destinationIP netip.Addr, sourceAddress netip.AddrPort,
connection net.Conn,
) *http.Client {
httpTransport := http.DefaultTransport.(*http.Transport).Clone() //nolint:forcetypeassert
httpTransport.Proxy = nil
@@ -66,7 +65,7 @@ func newHTTPSClient(destinationTLSName string,
MinVersion: tls.VersionTLS12,
ServerName: destinationTLSName,
}
httpTransport.DialContext = newBoundDialContext(destinationIP, sourceAddress)
httpTransport.DialContext = newConnectionDialContext(connection)
const timeout = 5 * time.Second
return &http.Client{
@@ -75,25 +74,14 @@ func newHTTPSClient(destinationTLSName string,
}
}
func newBoundDialContext(destinationAddress netip.Addr,
sourceAddress netip.AddrPort,
) func(ctx context.Context, network, _ string) (net.Conn, error) {
const httpsPort = 443
destinationAddrPort := netip.AddrPortFrom(destinationAddress, httpsPort).String()
func newConnectionDialContext(connection net.Conn) func(ctx context.Context, network, _ string) (net.Conn, error) {
return func(ctx context.Context, network, _ string) (net.Conn, error) {
const timeout = 2 * time.Second
dialer := &net.Dialer{Timeout: timeout}
dialer.LocalAddr = net.TCPAddrFromAddrPort(sourceAddress)
connection, err := dialer.DialContext(ctx, network, destinationAddrPort)
if err != nil {
return nil, fmt.Errorf("%s dialing %s: %w", network, destinationAddrPort, err)
}
return connection, nil
}
}
func bindSourcePort(destinationIP netip.Addr) (
listener net.Listener, sourceAddr netip.AddrPort, err error,
func bindSourceConnection(ctx context.Context, destinationIP netip.Addr) (
connection net.Conn, sourceAddr netip.AddrPort, err error,
) {
var bindAddr netip.Addr
if destinationIP.Is4() {
@@ -102,14 +90,19 @@ func bindSourcePort(destinationIP netip.Addr) (
bindAddr = netip.AddrFrom16([16]byte{})
}
listener, err = net.ListenTCP("tcp", net.TCPAddrFromAddrPort(
netip.AddrPortFrom(bindAddr, 0)))
const httpsPort = 443
destinationAddrPort := netip.AddrPortFrom(destinationIP, httpsPort)
dialer := &net.Dialer{
Timeout: time.Second,
LocalAddr: net.TCPAddrFromAddrPort(netip.AddrPortFrom(bindAddr, 0)),
}
connection, err = dialer.DialContext(ctx, "tcp", destinationAddrPort.String())
if err != nil {
return nil, netip.AddrPort{}, fmt.Errorf("binding TCP port: %w", err)
}
tcpAddr := listener.Addr().(*net.TCPAddr) //nolint:forcetypeassert
tcpAddr := connection.LocalAddr().(*net.TCPAddr) //nolint:forcetypeassert
sourceAddr = tcpAddr.AddrPort()
return listener, sourceAddr, nil
return connection, sourceAddr, nil
}
+1 -1
View File
@@ -106,7 +106,7 @@ func (c *Client) resolveOneQuestionType(ctx context.Context,
func (c *Client) doHQuery(ctx context.Context, queryWire []byte,
dohURL *url.URL, dohServerIP netip.Addr,
) (responseMessage *dns.Msg, err error) {
httpClient, cleanup, err := c.OpenHTTPS(dohURL.Hostname(), dohServerIP)
httpClient, cleanup, err := c.OpenHTTPS(ctx, dohURL.Hostname(), dohServerIP)
if err != nil {
return nil, fmt.Errorf("opening https connection: %w", err)
}