mirror of
https://github.com/qdm12/gluetun.git
synced 2026-06-15 16:04:08 +02:00
feat(socks5): UDP proxying (#3353)
This commit is contained in:
+437
-38
@@ -2,7 +2,10 @@ package socks5
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
@@ -96,6 +99,178 @@ func TestServerProxy(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerProxyTCPAndUDPParallel(t *testing.T) {
|
||||
t.Parallel()
|
||||
testCases := map[string]struct {
|
||||
username string
|
||||
password string
|
||||
}{
|
||||
"no_auth": {},
|
||||
"with_auth": {
|
||||
username: "user",
|
||||
password: "pass",
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backendTCPListener, err := (&net.ListenConfig{}).Listen(t.Context(), "tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
backendTCPConnChannel := make(chan net.Conn, 1)
|
||||
go func() {
|
||||
connection, err := backendTCPListener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
backendTCPConnChannel <- connection
|
||||
}()
|
||||
|
||||
backendUDPPacketConn, err := (&net.ListenConfig{}).ListenPacket(t.Context(), "udp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
server := newServer(Settings{
|
||||
Username: testCase.username,
|
||||
Password: testCase.password,
|
||||
Address: "127.0.0.1:0",
|
||||
Logger: noopLogger{},
|
||||
})
|
||||
_, err = server.Start(t.Context())
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = server.Stop()
|
||||
_ = backendTCPListener.Close()
|
||||
_ = backendUDPPacketConn.Close()
|
||||
})
|
||||
|
||||
clientTCPConn := dialSOCKS5(t, server.listeningAddress().String(),
|
||||
backendTCPListener.Addr().String(), testCase.username, testCase.password)
|
||||
defer clientTCPConn.Close()
|
||||
|
||||
backendTCPConn := <-backendTCPConnChannel
|
||||
defer backendTCPConn.Close()
|
||||
|
||||
udpControlConn, clientUDPConn := dialSOCKS5UDPAssociate(t,
|
||||
server.listeningAddress().String(), testCase.username, testCase.password)
|
||||
defer udpControlConn.Close()
|
||||
defer clientUDPConn.Close()
|
||||
|
||||
tcpErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
tcpErrCh <- runTCPProxyRoundTrip(clientTCPConn, backendTCPConn)
|
||||
}()
|
||||
|
||||
udpErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
udpErrCh <- runUDPProxyRoundTrip(t.Context(), clientUDPConn, backendUDPPacketConn)
|
||||
}()
|
||||
|
||||
err = <-tcpErrCh
|
||||
require.NoError(t, err)
|
||||
err = <-udpErrCh
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func runTCPProxyRoundTrip(clientTCPConn net.Conn, backendTCPConn net.Conn) error {
|
||||
clientMessage := []byte("hello from client")
|
||||
_, err := clientTCPConn.Write(clientMessage)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
received := make([]byte, len(clientMessage))
|
||||
_, err = io.ReadFull(backendTCPConn, received)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !bytes.Equal(clientMessage, received) {
|
||||
return errors.New("backend did not receive expected TCP payload")
|
||||
}
|
||||
|
||||
backendMessage := []byte("hello from backend")
|
||||
_, err = backendTCPConn.Write(backendMessage)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
receivedByClient := make([]byte, len(backendMessage))
|
||||
_, err = io.ReadFull(clientTCPConn, receivedByClient)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !bytes.Equal(backendMessage, receivedByClient) {
|
||||
return errors.New("client did not receive expected TCP payload")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func runUDPProxyRoundTrip(ctx context.Context, clientUDPConn *net.UDPConn, backendUDPPacketConn net.PacketConn) error {
|
||||
udpPayload := []byte("hello from udp client")
|
||||
udpRequest, err := makeSOCKS5UDPDatagram(backendUDPPacketConn.LocalAddr().String(), udpPayload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = clientUDPConn.Write(udpRequest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
deadline, hasDeadline := ctx.Deadline()
|
||||
if hasDeadline {
|
||||
err = backendUDPPacketConn.SetReadDeadline(deadline)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting read deadline on backend connection: %w", err)
|
||||
}
|
||||
}
|
||||
const bufferSize = 512
|
||||
backendReadBuffer := make([]byte, bufferSize)
|
||||
packetLength, proxyAddress, err := backendUDPPacketConn.ReadFrom(backendReadBuffer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !bytes.Equal(udpPayload, backendReadBuffer[:packetLength]) {
|
||||
return errors.New("backend did not receive expected UDP payload")
|
||||
}
|
||||
|
||||
backendUDPReply := []byte("hello from udp backend")
|
||||
_, err = backendUDPPacketConn.WriteTo(backendUDPReply, proxyAddress)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if hasDeadline {
|
||||
err = clientUDPConn.SetReadDeadline(deadline)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting read deadline on client connection: %w", err)
|
||||
}
|
||||
}
|
||||
udpResponseBuffer := make([]byte, 1024)
|
||||
responseLength, err := clientUDPConn.Read(udpResponseBuffer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
destinationAddress, udpResponsePayload, err := parseSOCKS5UDPDatagram(udpResponseBuffer[:responseLength])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !bytes.Equal(backendUDPReply, udpResponsePayload) {
|
||||
return errors.New("client did not receive expected UDP payload")
|
||||
}
|
||||
if destinationAddress != backendUDPPacketConn.LocalAddr().String() {
|
||||
return errors.New("udp response destination address mismatch")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// dialSOCKS5 performs the full SOCKS5 handshake (with optional username/password
|
||||
// subnegotiation) and returns a connected net.Conn ready for data exchange.
|
||||
func dialSOCKS5(t *testing.T, proxyAddr, targetAddr, username, password string) net.Conn {
|
||||
@@ -109,6 +284,55 @@ func dialSOCKS5(t *testing.T, proxyAddr, targetAddr, username, password string)
|
||||
conn, err := (&net.Dialer{}).DialContext(t.Context(), "tcp", proxyAddr)
|
||||
require.NoError(t, err)
|
||||
|
||||
negotiateSOCKS5(t, conn, username, password)
|
||||
|
||||
var connectRequest []byte
|
||||
if ip := net.ParseIP(host).To4(); ip != nil {
|
||||
connectRequest = []byte{socks5Version, byte(connect), 0, byte(ipv4)}
|
||||
connectRequest = append(connectRequest, ip...)
|
||||
} else {
|
||||
connectRequest = []byte{socks5Version, byte(connect), 0, byte(domainName), byte(len(host))}
|
||||
connectRequest = append(connectRequest, []byte(host)...)
|
||||
}
|
||||
connectRequest = binary.BigEndian.AppendUint16(connectRequest, uint16(targetPort)) //nolint:gosec
|
||||
_, err = conn.Write(connectRequest)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = readSOCKS5ResponseAddress(t, conn)
|
||||
require.NoError(t, err)
|
||||
|
||||
return conn
|
||||
}
|
||||
|
||||
func dialSOCKS5UDPAssociate(t *testing.T, proxyAddr, username, password string) (net.Conn, *net.UDPConn) {
|
||||
t.Helper()
|
||||
|
||||
controlConn, err := (&net.Dialer{}).DialContext(t.Context(), "tcp", proxyAddr)
|
||||
require.NoError(t, err)
|
||||
|
||||
negotiateSOCKS5(t, controlConn, username, password)
|
||||
|
||||
udpAssociateRequest := []byte{socks5Version, byte(udpAssociate), 0, byte(ipv4), 0, 0, 0, 0, 0, 0}
|
||||
_, err = controlConn.Write(udpAssociateRequest)
|
||||
require.NoError(t, err)
|
||||
|
||||
udpProxyAddress, err := readSOCKS5ResponseAddress(t, controlConn)
|
||||
require.NoError(t, err)
|
||||
|
||||
udpProxyResolvedAddress, err := net.ResolveUDPAddr("udp", udpProxyAddress)
|
||||
require.NoError(t, err)
|
||||
|
||||
udpConn, err := net.DialUDP("udp", nil, udpProxyResolvedAddress)
|
||||
require.NoError(t, err)
|
||||
|
||||
return controlConn, udpConn
|
||||
}
|
||||
|
||||
func negotiateSOCKS5(t *testing.T, conn net.Conn, username, password string) {
|
||||
t.Helper()
|
||||
|
||||
var err error
|
||||
|
||||
var method authMethod
|
||||
if username != "" || password != "" {
|
||||
method = authUsernamePassword
|
||||
@@ -138,45 +362,146 @@ func dialSOCKS5(t *testing.T, proxyAddr, targetAddr, username, password string)
|
||||
require.Equal(t, authUsernamePasswordSubNegotiation1, subnegResp[0])
|
||||
require.Equal(t, byte(0), subnegResp[1])
|
||||
}
|
||||
}
|
||||
|
||||
var connectRequest []byte
|
||||
if ip := net.ParseIP(host).To4(); ip != nil {
|
||||
connectRequest = []byte{socks5Version, byte(connect), 0, byte(ipv4)}
|
||||
connectRequest = append(connectRequest, ip...)
|
||||
} else {
|
||||
connectRequest = []byte{socks5Version, byte(connect), 0, byte(domainName), byte(len(host))}
|
||||
connectRequest = append(connectRequest, []byte(host)...)
|
||||
}
|
||||
connectRequest = binary.BigEndian.AppendUint16(connectRequest, uint16(targetPort)) //nolint:gosec
|
||||
_, err = conn.Write(connectRequest)
|
||||
require.NoError(t, err)
|
||||
func readSOCKS5ResponseAddress(t *testing.T, conn net.Conn) (address string, err error) {
|
||||
t.Helper()
|
||||
|
||||
var responseHeader [4]byte
|
||||
_, err = io.ReadFull(conn, responseHeader[:])
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, socks5Version, responseHeader[0])
|
||||
require.Equal(t, byte(succeeded), responseHeader[1])
|
||||
|
||||
// Consume BND.ADDR and BND.PORT (their values are irrelevant to the caller).
|
||||
switch addrType(responseHeader[3]) {
|
||||
case ipv4:
|
||||
var addrPort [net.IPv4len + 2]byte
|
||||
_, err = io.ReadFull(conn, addrPort[:])
|
||||
require.NoError(t, err)
|
||||
case ipv6:
|
||||
var addrPort [net.IPv6len + 2]byte
|
||||
_, err = io.ReadFull(conn, addrPort[:])
|
||||
require.NoError(t, err)
|
||||
case domainName:
|
||||
var lenBuf [1]byte
|
||||
_, err = io.ReadFull(conn, lenBuf[:])
|
||||
require.NoError(t, err)
|
||||
addrPort := make([]byte, int(lenBuf[0])+2)
|
||||
_, err = io.ReadFull(conn, addrPort)
|
||||
require.NoError(t, err)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if responseHeader[0] != socks5Version {
|
||||
return "", errors.New("version mismatch")
|
||||
}
|
||||
if responseHeader[1] != byte(succeeded) {
|
||||
return "", errors.New("request was not successful")
|
||||
}
|
||||
|
||||
return conn
|
||||
var host string
|
||||
switch addrType(responseHeader[3]) {
|
||||
case ipv4:
|
||||
addressAndPort := make([]byte, net.IPv4len+2)
|
||||
_, err = io.ReadFull(conn, addressAndPort)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
host = net.IP(addressAndPort[:net.IPv4len]).String()
|
||||
port := binary.BigEndian.Uint16(addressAndPort[net.IPv4len:])
|
||||
return net.JoinHostPort(host, strconv.FormatUint(uint64(port), 10)), nil
|
||||
case ipv6:
|
||||
addressAndPort := make([]byte, net.IPv6len+2)
|
||||
_, err = io.ReadFull(conn, addressAndPort)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
host = net.IP(addressAndPort[:net.IPv6len]).String()
|
||||
port := binary.BigEndian.Uint16(addressAndPort[net.IPv6len:])
|
||||
return net.JoinHostPort(host, strconv.FormatUint(uint64(port), 10)), nil
|
||||
case domainName:
|
||||
var lengthBuffer [1]byte
|
||||
_, err = io.ReadFull(conn, lengthBuffer[:])
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
domainAndPort := make([]byte, int(lengthBuffer[0])+2)
|
||||
_, err = io.ReadFull(conn, domainAndPort)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
host = string(domainAndPort[:len(domainAndPort)-2])
|
||||
port := binary.BigEndian.Uint16(domainAndPort[len(domainAndPort)-2:])
|
||||
return net.JoinHostPort(host, strconv.FormatUint(uint64(port), 10)), nil
|
||||
default:
|
||||
return "", errors.New("unknown address type")
|
||||
}
|
||||
}
|
||||
|
||||
func makeSOCKS5UDPDatagram(targetAddress string, payload []byte) ([]byte, error) {
|
||||
host, portString, err := net.SplitHostPort(targetAddress)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
port, err := strconv.ParseUint(portString, 10, 16)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
datagram := []byte{0, 0, 0}
|
||||
ipAddress := net.ParseIP(host)
|
||||
if ipAddress != nil {
|
||||
if ipAddress.To4() != nil {
|
||||
datagram = append(datagram, byte(ipv4))
|
||||
datagram = append(datagram, ipAddress.To4()...)
|
||||
} else {
|
||||
datagram = append(datagram, byte(ipv6))
|
||||
datagram = append(datagram, ipAddress.To16()...)
|
||||
}
|
||||
} else {
|
||||
if len(host) > 255 {
|
||||
return nil, errors.New("domain name too long")
|
||||
}
|
||||
datagram = append(datagram, byte(domainName), byte(len(host)))
|
||||
datagram = append(datagram, []byte(host)...)
|
||||
}
|
||||
datagram = binary.BigEndian.AppendUint16(datagram, uint16(port))
|
||||
datagram = append(datagram, payload...)
|
||||
|
||||
return datagram, nil
|
||||
}
|
||||
|
||||
func parseSOCKS5UDPDatagram(datagram []byte) (destinationAddress string, payload []byte, err error) {
|
||||
if len(datagram) < 4 {
|
||||
return "", nil, errors.New("datagram too short")
|
||||
}
|
||||
if datagram[0] != 0 || datagram[1] != 0 {
|
||||
return "", nil, errors.New("invalid reserved header")
|
||||
}
|
||||
if datagram[2] != 0 {
|
||||
return "", nil, errors.New("fragments are not supported")
|
||||
}
|
||||
|
||||
offset := 3
|
||||
var host string
|
||||
switch addrType(datagram[offset]) {
|
||||
case ipv4:
|
||||
offset++
|
||||
if len(datagram) < offset+net.IPv4len+2 {
|
||||
return "", nil, errors.New("datagram too short for IPv4")
|
||||
}
|
||||
host = net.IP(datagram[offset : offset+net.IPv4len]).String()
|
||||
offset += net.IPv4len
|
||||
case ipv6:
|
||||
offset++
|
||||
if len(datagram) < offset+net.IPv6len+2 {
|
||||
return "", nil, errors.New("datagram too short for IPv6")
|
||||
}
|
||||
host = net.IP(datagram[offset : offset+net.IPv6len]).String()
|
||||
offset += net.IPv6len
|
||||
case domainName:
|
||||
offset++
|
||||
if len(datagram) < offset+1 {
|
||||
return "", nil, errors.New("datagram too short for domain length")
|
||||
}
|
||||
domainLength := int(datagram[offset])
|
||||
offset++
|
||||
if len(datagram) < offset+domainLength+2 {
|
||||
return "", nil, errors.New("datagram too short for domain")
|
||||
}
|
||||
host = string(datagram[offset : offset+domainLength])
|
||||
offset += domainLength
|
||||
default:
|
||||
return "", nil, errors.New("unknown address type")
|
||||
}
|
||||
|
||||
if len(datagram) < offset+2 {
|
||||
return "", nil, errors.New("datagram too short for port")
|
||||
}
|
||||
port := binary.BigEndian.Uint16(datagram[offset : offset+2])
|
||||
offset += 2
|
||||
|
||||
return net.JoinHostPort(host, strconv.FormatUint(uint64(port), 10)), datagram[offset:], nil
|
||||
}
|
||||
|
||||
func Test_newServer(t *testing.T) {
|
||||
@@ -224,7 +549,8 @@ func Test_Server_StartStop(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
logger := NewMockLogger(ctrl)
|
||||
logger.EXPECT().Infof("SOCKS5 server listening on %s", gomock.Any())
|
||||
logger.EXPECT().Infof("SOCKS5 TCP server listening on %s", gomock.Any())
|
||||
logger.EXPECT().Infof("SOCKS5 UDP server listening on %s", gomock.Any())
|
||||
|
||||
server := newServer(Settings{
|
||||
Address: "127.0.0.1:0",
|
||||
@@ -598,10 +924,6 @@ func Test_cmdType_String(t *testing.T) {
|
||||
cmd: connect,
|
||||
expectedName: "connect",
|
||||
},
|
||||
"bind": {
|
||||
cmd: bind,
|
||||
expectedName: "bind",
|
||||
},
|
||||
"udp_associate": {
|
||||
cmd: udpAssociate,
|
||||
expectedName: "UDP associate",
|
||||
@@ -620,3 +942,80 @@ func Test_cmdType_String(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_socksConn_udpAssociationAddresses(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
routerAddress string
|
||||
expectAddressFromConn bool
|
||||
expectedAddress string
|
||||
}{
|
||||
"wildcard_router_address_uses_control_connection_local_ip": {
|
||||
routerAddress: ":0",
|
||||
expectAddressFromConn: true,
|
||||
},
|
||||
"concrete_router_address_is_kept": {
|
||||
routerAddress: "127.0.0.1:0",
|
||||
expectedAddress: "127.0.0.1",
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
router, err := newUDPRouter(t.Context(), testCase.routerAddress, noopLogger{})
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
err := router.close()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
controlListener, err := (&net.ListenConfig{}).Listen(t.Context(), "tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
err := controlListener.Close()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
acceptedConnCh := make(chan net.Conn, 1)
|
||||
go func() {
|
||||
acceptedConn, acceptErr := controlListener.Accept()
|
||||
if acceptErr != nil {
|
||||
return
|
||||
}
|
||||
acceptedConnCh <- acceptedConn
|
||||
}()
|
||||
|
||||
clientControlConn, err := (&net.Dialer{}).DialContext(t.Context(), "tcp", controlListener.Addr().String())
|
||||
require.NoError(t, err)
|
||||
defer clientControlConn.Close()
|
||||
|
||||
serverControlConn := <-acceptedConnCh
|
||||
defer serverControlConn.Close()
|
||||
|
||||
socksConnection := &socksConn{
|
||||
clientConn: clientControlConn,
|
||||
udpRouter: router,
|
||||
}
|
||||
bindAddress, bindPort, bindAddrType, err := socksConnection.udpAssociationAddresses()
|
||||
require.NoError(t, err)
|
||||
|
||||
if testCase.expectAddressFromConn {
|
||||
clientLocalHost, _, err := net.SplitHostPort(clientControlConn.LocalAddr().String())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, clientLocalHost, bindAddress)
|
||||
} else {
|
||||
assert.Equal(t, testCase.expectedAddress, bindAddress)
|
||||
}
|
||||
|
||||
_, routerPortString, err := net.SplitHostPort(router.localAddress().String())
|
||||
require.NoError(t, err)
|
||||
routerPort, err := strconv.ParseUint(routerPortString, 10, 16)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint16(routerPort), bindPort)
|
||||
assert.Equal(t, ipv4, bindAddrType)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user