mirror of
https://github.com/qdm12/gluetun.git
synced 2026-06-25 13:27:31 +02:00
imporatnt fix 1
This commit is contained in:
@@ -47,7 +47,7 @@ func (c *Client) OpenHTTPSByDomain(ctx context.Context, domain string) (
|
|||||||
|
|
||||||
selectedIP := resolvedIPs[0]
|
selectedIP := resolvedIPs[0]
|
||||||
|
|
||||||
httpClient, cleanup, err = c.OpenHTTPS(domain, selectedIP)
|
httpClient, cleanup, err = c.OpenHTTPS(ctx, domain, selectedIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("opening HTTPS: %w", err)
|
return nil, nil, fmt.Errorf("opening HTTPS: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package restrictednet
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@@ -34,15 +35,28 @@ func (m listenAddrPortMatcher) String() string {
|
|||||||
|
|
||||||
func Test_Client_OpenHTTPS(t *testing.T) {
|
func Test_Client_OpenHTTPS(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
ctx := t.Context()
|
||||||
|
|
||||||
|
netConfig := net.ListenConfig{}
|
||||||
|
listener, err := netConfig.Listen(ctx, "tcp", "127.0.0.1:443")
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = listener.Close()
|
||||||
|
})
|
||||||
|
go func() {
|
||||||
|
connection, acceptErr := listener.Accept()
|
||||||
|
if acceptErr == nil {
|
||||||
|
_ = connection.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
ctrl := gomock.NewController(t)
|
ctrl := gomock.NewController(t)
|
||||||
firewall := NewMockFirewall(ctrl)
|
firewall := NewMockFirewall(ctrl)
|
||||||
|
|
||||||
destination := netip.MustParseAddrPort("1.2.3.4:443")
|
destination := netip.MustParseAddrPort("127.0.0.1:443")
|
||||||
backgroundContext := context.Background()
|
|
||||||
sourceMatcher := listenAddrPortMatcher{}
|
sourceMatcher := listenAddrPortMatcher{}
|
||||||
firewall.EXPECT().AcceptOutputFromIPPortToIPPort(
|
firewall.EXPECT().AcceptOutputFromIPPortToIPPort(
|
||||||
backgroundContext, "tcp", "eth0", sourceMatcher, destination, false,
|
ctx, "tcp", "eth0", sourceMatcher, destination, false,
|
||||||
).DoAndReturn(func(_ context.Context,
|
).DoAndReturn(func(_ context.Context,
|
||||||
_, _ string, source, _ netip.AddrPort, _ bool,
|
_, _ string, source, _ netip.AddrPort, _ bool,
|
||||||
) error {
|
) error {
|
||||||
@@ -50,7 +64,7 @@ func Test_Client_OpenHTTPS(t *testing.T) {
|
|||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
firewall.EXPECT().AcceptOutputFromIPPortToIPPort(
|
firewall.EXPECT().AcceptOutputFromIPPortToIPPort(
|
||||||
backgroundContext, "tcp", "eth0", sourceMatcher, destination, true,
|
ctx, "tcp", "eth0", sourceMatcher, destination, true,
|
||||||
)
|
)
|
||||||
|
|
||||||
const ipv6Supported = false
|
const ipv6Supported = false
|
||||||
@@ -58,7 +72,7 @@ func Test_Client_OpenHTTPS(t *testing.T) {
|
|||||||
client, err := New(firewall, "eth0", ipv6Supported, upstreamResolvers)
|
client, err := New(firewall, "eth0", ipv6Supported, upstreamResolvers)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
httpClient, cleanup, err := client.OpenHTTPS("api.example.com", netip.MustParseAddr("1.2.3.4"))
|
httpClient, cleanup, err := client.OpenHTTPS(ctx, "api.example.com", netip.MustParseAddr("127.0.0.1"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, httpClient)
|
require.NotNil(t, httpClient)
|
||||||
require.NotNil(t, cleanup)
|
require.NotNil(t, cleanup)
|
||||||
|
|||||||
@@ -13,9 +13,9 @@ import (
|
|||||||
|
|
||||||
// OpenHTTPS opens temporary restrictive firewall output for one HTTPS destination.
|
// OpenHTTPS opens temporary restrictive firewall output for one HTTPS destination.
|
||||||
// The returned cleanup function must be called to remove the temporary firewall rule and close connections.
|
// The returned cleanup function must be called to remove the temporary firewall rule and close connections.
|
||||||
func (c *Client) OpenHTTPS(destinationTLSName string, destinationIP netip.Addr,
|
func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, destinationIP netip.Addr,
|
||||||
) (httpClient *http.Client, cleanup func() error, err error) {
|
) (httpClient *http.Client, cleanup func() error, err error) {
|
||||||
listener, sourceAddrPort, err := bindSourcePort(destinationIP)
|
connection, sourceAddrPort, err := bindSourceConnection(ctx, destinationIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("binding source port: %w", err)
|
return nil, nil, fmt.Errorf("binding source port: %w", err)
|
||||||
}
|
}
|
||||||
@@ -24,15 +24,14 @@ func (c *Client) OpenHTTPS(destinationTLSName string, destinationIP netip.Addr,
|
|||||||
destinationAddrPort := netip.AddrPortFrom(destinationIP, httpsPort)
|
destinationAddrPort := netip.AddrPortFrom(destinationIP, httpsPort)
|
||||||
|
|
||||||
const remove = false
|
const remove = false
|
||||||
ctx := context.Background() // it's a quick firewall change, worth not passing a context
|
|
||||||
err = c.firewall.AcceptOutputFromIPPortToIPPort(ctx, "tcp", c.outboundInterface,
|
err = c.firewall.AcceptOutputFromIPPortToIPPort(ctx, "tcp", c.outboundInterface,
|
||||||
sourceAddrPort, destinationAddrPort, remove)
|
sourceAddrPort, destinationAddrPort, remove)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = listener.Close()
|
_ = connection.Close()
|
||||||
return nil, nil, fmt.Errorf("allowing output traffic through firewall: %w", err)
|
return nil, nil, fmt.Errorf("allowing output traffic through firewall: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
httpClient = newHTTPSClient(destinationTLSName, destinationIP, sourceAddrPort)
|
httpClient = newHTTPSClient(destinationTLSName, connection)
|
||||||
cleanup = func() error {
|
cleanup = func() error {
|
||||||
var errs []error
|
var errs []error
|
||||||
httpClient.CloseIdleConnections()
|
httpClient.CloseIdleConnections()
|
||||||
@@ -42,9 +41,9 @@ func (c *Client) OpenHTTPS(destinationTLSName string, destinationIP netip.Addr,
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
errs = append(errs, fmt.Errorf("removing output traffic rule: %w", err))
|
errs = append(errs, fmt.Errorf("removing output traffic rule: %w", err))
|
||||||
}
|
}
|
||||||
err = listener.Close()
|
err = connection.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errs = append(errs, fmt.Errorf("closing listener: %w", err))
|
errs = append(errs, fmt.Errorf("closing connection: %w", err))
|
||||||
}
|
}
|
||||||
if len(errs) > 0 {
|
if len(errs) > 0 {
|
||||||
return errors.Join(errs...)
|
return errors.Join(errs...)
|
||||||
@@ -55,7 +54,7 @@ func (c *Client) OpenHTTPS(destinationTLSName string, destinationIP netip.Addr,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func newHTTPSClient(destinationTLSName string,
|
func newHTTPSClient(destinationTLSName string,
|
||||||
destinationIP netip.Addr, sourceAddress netip.AddrPort,
|
connection net.Conn,
|
||||||
) *http.Client {
|
) *http.Client {
|
||||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone() //nolint:forcetypeassert
|
httpTransport := http.DefaultTransport.(*http.Transport).Clone() //nolint:forcetypeassert
|
||||||
httpTransport.Proxy = nil
|
httpTransport.Proxy = nil
|
||||||
@@ -66,7 +65,7 @@ func newHTTPSClient(destinationTLSName string,
|
|||||||
MinVersion: tls.VersionTLS12,
|
MinVersion: tls.VersionTLS12,
|
||||||
ServerName: destinationTLSName,
|
ServerName: destinationTLSName,
|
||||||
}
|
}
|
||||||
httpTransport.DialContext = newBoundDialContext(destinationIP, sourceAddress)
|
httpTransport.DialContext = newConnectionDialContext(connection)
|
||||||
|
|
||||||
const timeout = 5 * time.Second
|
const timeout = 5 * time.Second
|
||||||
return &http.Client{
|
return &http.Client{
|
||||||
@@ -75,25 +74,14 @@ func newHTTPSClient(destinationTLSName string,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newBoundDialContext(destinationAddress netip.Addr,
|
func newConnectionDialContext(connection net.Conn) func(ctx context.Context, network, _ string) (net.Conn, error) {
|
||||||
sourceAddress netip.AddrPort,
|
|
||||||
) func(ctx context.Context, network, _ string) (net.Conn, error) {
|
|
||||||
const httpsPort = 443
|
|
||||||
destinationAddrPort := netip.AddrPortFrom(destinationAddress, httpsPort).String()
|
|
||||||
return func(ctx context.Context, network, _ string) (net.Conn, error) {
|
return func(ctx context.Context, network, _ string) (net.Conn, error) {
|
||||||
const timeout = 2 * time.Second
|
|
||||||
dialer := &net.Dialer{Timeout: timeout}
|
|
||||||
dialer.LocalAddr = net.TCPAddrFromAddrPort(sourceAddress)
|
|
||||||
connection, err := dialer.DialContext(ctx, network, destinationAddrPort)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("%s dialing %s: %w", network, destinationAddrPort, err)
|
|
||||||
}
|
|
||||||
return connection, nil
|
return connection, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func bindSourcePort(destinationIP netip.Addr) (
|
func bindSourceConnection(ctx context.Context, destinationIP netip.Addr) (
|
||||||
listener net.Listener, sourceAddr netip.AddrPort, err error,
|
connection net.Conn, sourceAddr netip.AddrPort, err error,
|
||||||
) {
|
) {
|
||||||
var bindAddr netip.Addr
|
var bindAddr netip.Addr
|
||||||
if destinationIP.Is4() {
|
if destinationIP.Is4() {
|
||||||
@@ -102,14 +90,19 @@ func bindSourcePort(destinationIP netip.Addr) (
|
|||||||
bindAddr = netip.AddrFrom16([16]byte{})
|
bindAddr = netip.AddrFrom16([16]byte{})
|
||||||
}
|
}
|
||||||
|
|
||||||
listener, err = net.ListenTCP("tcp", net.TCPAddrFromAddrPort(
|
const httpsPort = 443
|
||||||
netip.AddrPortFrom(bindAddr, 0)))
|
destinationAddrPort := netip.AddrPortFrom(destinationIP, httpsPort)
|
||||||
|
dialer := &net.Dialer{
|
||||||
|
Timeout: time.Second,
|
||||||
|
LocalAddr: net.TCPAddrFromAddrPort(netip.AddrPortFrom(bindAddr, 0)),
|
||||||
|
}
|
||||||
|
connection, err = dialer.DialContext(ctx, "tcp", destinationAddrPort.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, netip.AddrPort{}, fmt.Errorf("binding TCP port: %w", err)
|
return nil, netip.AddrPort{}, fmt.Errorf("binding TCP port: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
tcpAddr := listener.Addr().(*net.TCPAddr) //nolint:forcetypeassert
|
tcpAddr := connection.LocalAddr().(*net.TCPAddr) //nolint:forcetypeassert
|
||||||
sourceAddr = tcpAddr.AddrPort()
|
sourceAddr = tcpAddr.AddrPort()
|
||||||
|
|
||||||
return listener, sourceAddr, nil
|
return connection, sourceAddr, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -106,7 +106,7 @@ func (c *Client) resolveOneQuestionType(ctx context.Context,
|
|||||||
func (c *Client) doHQuery(ctx context.Context, queryWire []byte,
|
func (c *Client) doHQuery(ctx context.Context, queryWire []byte,
|
||||||
dohURL *url.URL, dohServerIP netip.Addr,
|
dohURL *url.URL, dohServerIP netip.Addr,
|
||||||
) (responseMessage *dns.Msg, err error) {
|
) (responseMessage *dns.Msg, err error) {
|
||||||
httpClient, cleanup, err := c.OpenHTTPS(dohURL.Hostname(), dohServerIP)
|
httpClient, cleanup, err := c.OpenHTTPS(ctx, dohURL.Hostname(), dohServerIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("opening https connection: %w", err)
|
return nil, fmt.Errorf("opening https connection: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user