diff --git a/.vscode/settings.json b/.vscode/settings.json index f7e46397..2346a691 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -3,7 +3,7 @@ // to develop this project. "files.eol": "\n", "editor.formatOnSave": true, - "go.buildTags": "linux", + "go.buildTags": "linux,integration", "go.toolsEnvVars": { "CGO_ENABLED": "0" }, diff --git a/Dockerfile b/Dockerfile index 19e50dcf..3b7d68e8 100644 --- a/Dockerfile +++ b/Dockerfile @@ -276,7 +276,7 @@ ENV VPN_SERVICE_PROVIDER=pia \ PUID=1000 \ PGID=1000 ENTRYPOINT ["/gluetun-entrypoint"] -EXPOSE 8000/tcp 8888/tcp 8388/tcp 8388/udp 1080/tcp +EXPOSE 8000/tcp 8888/tcp 8388/tcp 8388/udp 1080/tcp 1080/udp HEALTHCHECK --interval=5s --timeout=5s --start-period=10s --retries=3 CMD /gluetun-entrypoint healthcheck ARG TARGETPLATFORM RUN apk add --no-cache --update -l wget && \ diff --git a/README.md b/README.md index 04dcdbfc..8e446faa 100644 --- a/README.md +++ b/README.md @@ -73,7 +73,7 @@ Lightweight swiss-army-knife-like VPN client to multiple VPN service providers - Choose the vpn network protocol, `udp` or `tcp` - Built in firewall kill switch to allow traffic only with needed the VPN servers and LAN devices - Built in Shadowsocks proxy server (protocol based on SOCKS5 with an encryption layer, tunnels TCP+UDP) -- Built in Socks5 proxy server (tunnels TCP) - partial credits to @angelakis and @adjscent +- Built in Socks5 proxy server (tunnels TCP+UDP) - partial credits to @angelakis and @adjscent - Built in HTTP proxy (tunnels HTTP and HTTPS through TCP) - [Connect other containers to it](https://github.com/qdm12/gluetun-wiki/blob/main/setup/connect-a-container-to-gluetun.md) - [Connect LAN devices to it](https://github.com/qdm12/gluetun-wiki/blob/main/setup/connect-a-lan-device-to-gluetun.md) diff --git a/internal/socks5/constants.go b/internal/socks5/constants.go index bb185ad6..afedb555 100644 --- a/internal/socks5/constants.go +++ b/internal/socks5/constants.go @@ -43,7 +43,6 @@ type cmdType byte const ( connect cmdType = 1 - bind cmdType = 2 udpAssociate cmdType = 3 ) @@ -51,8 +50,6 @@ func (c cmdType) String() string { switch c { case connect: return "connect" - case bind: - return "bind" case udpAssociate: return "UDP associate" default: diff --git a/internal/socks5/response.go b/internal/socks5/response.go index b65ee5e9..ffa27f7c 100644 --- a/internal/socks5/response.go +++ b/internal/socks5/response.go @@ -10,7 +10,7 @@ import ( ) // See https://datatracker.ietf.org/doc/html/rfc1928#section-6 -func (c *socksConn) encodeFailedResponse(writer io.Writer, socksVersion byte, reply replyCode) { //nolint:unparam +func (c *socksConn) encodeFailedResponse(writer io.Writer, socksVersion byte, reply replyCode) { _, err := writer.Write([]byte{ socksVersion, byte(reply), diff --git a/internal/socks5/server.go b/internal/socks5/server.go index 421c9883..9a5e2202 100644 --- a/internal/socks5/server.go +++ b/internal/socks5/server.go @@ -2,6 +2,7 @@ package socks5 import ( "context" + "errors" "fmt" "net" "sync" @@ -15,12 +16,13 @@ type server struct { logger Logger // internal fields - listener net.Listener + tcpListener net.Listener + udpRouter *udpRouter listening atomic.Bool socksConnCtx context.Context //nolint:containedctx socksConnCancel context.CancelFunc - done <-chan struct{} - stopping atomic.Bool + done <-chan error + stopCh chan<- struct{} } func newServer(settings Settings) *server { @@ -39,19 +41,28 @@ func (s *server) String() string { func (s *server) Start(ctx context.Context) (runErr <-chan error, err error) { s.socksConnCtx, s.socksConnCancel = context.WithCancel(context.Background()) config := &net.ListenConfig{} - s.listener, err = config.Listen(ctx, "tcp", s.address) + s.tcpListener, err = config.Listen(ctx, "tcp", s.address) if err != nil { - return nil, fmt.Errorf("listening on %s: %w", s.address, err) + return nil, fmt.Errorf("TCP listening on %s: %w", s.address, err) + } + + s.udpRouter, err = newUDPRouter(ctx, s.address, s.logger) + if err != nil { + _ = s.tcpListener.Close() + return nil, fmt.Errorf("creating UDP router: %w", err) } s.listening.Store(true) - s.logger.Infof("SOCKS5 server listening on %s", s.listener.Addr()) + s.logger.Infof("SOCKS5 TCP server listening on %s", s.tcpListener.Addr()) + s.logger.Infof("SOCKS5 UDP server listening on %s", s.udpRouter.localAddress()) ready := make(chan struct{}) runErrCh := make(chan error) runErr = runErrCh - done := make(chan struct{}) + done := make(chan error) s.done = done - go s.runServer(ready, runErrCh, done) + stop := make(chan struct{}) + s.stopCh = stop + go s.runServer(ready, runErrCh, stop, done) select { case <-ready: case <-ctx.Done(): @@ -62,61 +73,90 @@ func (s *server) Start(ctx context.Context) (runErr <-chan error, err error) { } func (s *server) runServer(ready chan<- struct{}, - runErrCh chan<- error, done chan<- struct{}, + runErrCh chan<- error, stop <-chan struct{}, done chan<- error, ) { close(ready) defer close(done) - wg := new(sync.WaitGroup) - defer wg.Wait() - dialer := &net.Dialer{} - for { - connection, err := s.listener.Accept() - if err != nil { - if !s.stopping.Load() { - _ = s.stop() - runErrCh <- fmt.Errorf("accepting connection: %w", err) - } - return - } - wg.Add(1) - go func(ctx context.Context, connection net.Conn, - dialer *net.Dialer, wg *sync.WaitGroup, - ) { - defer wg.Done() - socksConn := &socksConn{ - dialer: dialer, - username: s.username, - password: s.password, - clientConn: connection, - logger: s.logger, - } - err := socksConn.run(ctx) + udpErrCh := make(chan error) + go func() { + udpErrCh <- s.udpRouter.run(s.socksConnCtx) + }() + + tcpErrCh := make(chan error) + go func() { + var wg sync.WaitGroup + defer wg.Wait() + + dialer := &net.Dialer{} + for { + connection, err := s.tcpListener.Accept() if err != nil { - s.logger.Infof("running socks connection: %s", err) + s.socksConnCancel() // stop ongoing TCP socks connections - no impact on UDP + tcpErrCh <- fmt.Errorf("accepting connection: %w", err) + return } - }(s.socksConnCtx, connection, dialer, wg) + wg.Go(func() { + connection := connection // capture loop variable + socksConn := &socksConn{ + dialer: dialer, + username: s.username, + password: s.password, + clientConn: connection, + udpRouter: s.udpRouter, + logger: s.logger, + } + err := socksConn.run(s.socksConnCtx) + if err != nil { + s.logger.Infof("running socks connection: %s", err) + } + }) + } + }() + + select { + case <-stop: + s.listening.Store(false) + var errs []error + err := s.tcpListener.Close() + if err != nil { + errs = append(errs, fmt.Errorf("closing TCP listener: %w", err)) + } + // stop ongoing TCP socks connections. This impacts the udpRouter run error when it is being closed. + s.socksConnCancel() + <-tcpErrCh // wait for TCP server to stop + err = s.udpRouter.close() + if err != nil { + errs = append(errs, fmt.Errorf("closing UDP router: %w", err)) + } + <-udpErrCh // wait for UDP router to stop + if len(errs) > 0 { + // Only write to the done channel if the [server.Stop] method is waiting to read from it + done <- errors.Join(errs...) + } + // If no error, the done channel is closed so the error is effectively `nil` + // Note: do NOT write an error the runError channel, since we are stopping the server gracefully. + case err := <-udpErrCh: + _ = s.tcpListener.Close() // stop accepting new TCP connections + s.socksConnCancel() // stop ongoing TCP socks connections + <-tcpErrCh // wait for TCP server to stop + runErrCh <- fmt.Errorf("running UDP router: %w", err) + case err := <-tcpErrCh: + s.socksConnCancel() + _ = s.udpRouter.close() // stop UDP router + <-udpErrCh // wait for UDP router to stop + runErrCh <- fmt.Errorf("running TCP server: %w", err) } } func (s *server) Stop() (err error) { - s.stopping.Store(true) - err = s.stop() - <-s.done // wait for run goroutine to finish - s.stopping.Store(false) - return err -} - -func (s *server) stop() error { - s.listening.Store(false) - err := s.listener.Close() - s.socksConnCancel() // stop ongoing socks connections - return err + close(s.stopCh) + return <-s.done } func (s *server) listeningAddress() net.Addr { if s.listening.Load() { - return s.listener.Addr() + return s.tcpListener.Addr() } return nil } diff --git a/internal/socks5/socks5.go b/internal/socks5/socks5.go index 3eb29ce8..c7595629 100644 --- a/internal/socks5/socks5.go +++ b/internal/socks5/socks5.go @@ -10,6 +10,7 @@ import ( "net/netip" "strconv" "strings" + "sync" ) var ( @@ -23,6 +24,7 @@ type socksConn struct { username string password string clientConn net.Conn + udpRouter *udpRouter logger Logger } @@ -109,11 +111,29 @@ func (c *socksConn) handleRequest(ctx context.Context) error { c.encodeFailedResponse(c.clientConn, socksVersion, generalServerFailure) return err } - if request.command != connect { + + switch request.command { + case connect: + err = c.handleConnectRequest(ctx, socksVersion, request) + if err != nil { + return fmt.Errorf("handling %s request: %w", request.command, err) + } + return nil + case udpAssociate: + err = c.handleUDPAssociateRequest(ctx, socksVersion, request) + if err != nil { + return fmt.Errorf("handling %s request: %w", request.command, err) + } + return nil + default: c.encodeFailedResponse(c.clientConn, socksVersion, commandNotSupported) return fmt.Errorf("command %s is not supported", request.command) } +} +func (c *socksConn) handleConnectRequest(ctx context.Context, + socksVersion byte, request request, +) error { destinationAddress := net.JoinHostPort(request.destination, fmt.Sprint(request.port)) destinationConn, err := c.dialer.DialContext(ctx, "tcp", destinationAddress) if err != nil { @@ -176,6 +196,234 @@ func (c *socksConn) handleRequest(ctx context.Context) error { } } +func (c *socksConn) handleUDPAssociateRequest(ctx context.Context, + socksVersion byte, request request, +) error { + expectedAddrPort, err := udpAssociateExpectedClientEndpoint(request) + if err != nil { + c.encodeFailedResponse(c.clientConn, socksVersion, addressTypeNotSupported) + return fmt.Errorf("deriving expected client address and port from request: %w", err) + } + + bindAddress, bindPort, bindAddrType, err := c.udpAssociationAddresses() + if err != nil { + c.encodeFailedResponse(c.clientConn, socksVersion, generalServerFailure) + return fmt.Errorf("getting udp association addresses: %w", err) + } + + association, err := c.udpRouter.registerAssociation(c.clientConn, expectedAddrPort) + if err != nil { + c.encodeFailedResponse(c.clientConn, socksVersion, generalServerFailure) + return fmt.Errorf("registering udp association: %w", err) + } + defer c.udpRouter.unregisterAssociation(association) + + err = c.encodeSuccessResponse(c.clientConn, socksVersion, succeeded, + bindAddrType, bindAddress, bindPort) + if err != nil { + c.encodeFailedResponse(c.clientConn, socksVersion, generalServerFailure) + return fmt.Errorf("writing successful %s response: %w", udpAssociate, err) + } + + associationCtx, associationCancel := context.WithCancel(ctx) + defer associationCancel() + + var wg sync.WaitGroup + + wg.Go(func() { + c.udpRouter.runAssociationHandler(associationCtx, association) + }) + + wg.Go(func() { + _, _ = io.Copy(io.Discard, c.clientConn) + associationCancel() + }) + <-associationCtx.Done() + wg.Wait() + return nil +} + +func udpAssociateExpectedClientEndpoint(request request) (expectedAddrPort netip.AddrPort, err error) { + switch request.addressType { + case ipv4, ipv6: + expectedClientAddress, parseErr := netip.ParseAddr(request.destination) + if parseErr != nil { + return netip.AddrPort{}, fmt.Errorf("parsing destination address: %w", parseErr) + } + expectedClientAddress = expectedClientAddress.Unmap() + if !expectedClientAddress.IsUnspecified() { + return netip.AddrPortFrom(expectedClientAddress, request.port), nil + } + return netip.AddrPortFrom(netip.Addr{}, request.port), nil + case domainName: + if request.destination != "" || request.port != 0 { + return netip.AddrPort{}, fmt.Errorf("domain name is not supported for UDP associate destination") + } + return netip.AddrPort{}, nil + default: + return netip.AddrPort{}, fmt.Errorf("address type %d is not supported", request.addressType) + } +} + +func (c *socksConn) udpAssociationAddresses() (bindAddress string, + bindPort uint16, bindAddrType addrType, err error, +) { + localAddress := c.udpRouter.localAddress().String() + host, portString, err := net.SplitHostPort(localAddress) + if err != nil { + return "", 0, 0, fmt.Errorf("splitting local address: %w", err) + } + port, err := strconv.ParseUint(portString, 10, 16) + if err != nil { + return "", 0, 0, fmt.Errorf("parsing local port: %w", err) + } + bindAddress = host + bindPort = uint16(port) + if isUnspecifiedIPAddress(bindAddress) { + controlLocalAddress := c.clientConn.LocalAddr().String() + controlLocalHost, _, splitErr := net.SplitHostPort(controlLocalAddress) + if splitErr != nil { + return "", 0, 0, fmt.Errorf("splitting control connection local address: %w", splitErr) + } + bindAddress = controlLocalHost + } + + ipAddress := net.ParseIP(bindAddress) + if ipAddress == nil { + bindAddrType = domainName + return bindAddress, bindPort, bindAddrType, nil + } + + if ipAddress.To4() != nil { + bindAddrType = ipv4 + } else { + bindAddrType = ipv6 + } + + return bindAddress, bindPort, bindAddrType, nil +} + +func isUnspecifiedIPAddress(address string) bool { + ipAddress, err := netip.ParseAddr(address) + if err != nil { + return false + } + return ipAddress.IsUnspecified() +} + +func decodeUDPDatagram(packet []byte) (destination string, payload []byte, err error) { + const minimumPacketLength = 4 + if len(packet) < minimumPacketLength { + return "", nil, fmt.Errorf("packet is too short: %d", len(packet)) + } + if packet[0] != 0 || packet[1] != 0 { + return "", nil, fmt.Errorf("reserved bytes are invalid: %x %x", packet[0], packet[1]) + } + if packet[2] != 0 { + return "", nil, fmt.Errorf("fragmentation is not supported") + } + + offset := 3 + addressType := addrType(packet[offset]) + offset++ + + switch addressType { + case ipv4: + const ipv4Length = 4 + if len(packet) < offset+ipv4Length+2 { + return "", nil, fmt.Errorf("packet is too short for IPv4 address") + } + var ip [ipv4Length]byte + copy(ip[:], packet[offset:offset+ipv4Length]) + destination = netip.AddrFrom4(ip).String() + offset += ipv4Length + case ipv6: + const ipv6Length = 16 + if len(packet) < offset+ipv6Length+2 { + return "", nil, fmt.Errorf("packet is too short for IPv6 address") + } + var ip [ipv6Length]byte + copy(ip[:], packet[offset:offset+ipv6Length]) + destination = netip.AddrFrom16(ip).String() + offset += ipv6Length + case domainName: + if len(packet) < offset+1 { + return "", nil, fmt.Errorf("packet is too short for domain name length") + } + domainNameLength := int(packet[offset]) + offset++ + if len(packet) < offset+domainNameLength+2 { + return "", nil, fmt.Errorf("packet is too short for domain name") + } + destination = string(packet[offset : offset+domainNameLength]) + offset += domainNameLength + default: + return "", nil, fmt.Errorf("address type is not supported: %d", addressType) + } + + port := binary.BigEndian.Uint16(packet[offset : offset+2]) + destination = net.JoinHostPort(destination, fmt.Sprint(port)) + offset += 2 + payload = packet[offset:] + + return destination, payload, nil +} + +func encodeUDPDatagramToBuffer(writer io.Writer, sourceAddrPort netip.AddrPort, + payload []byte, +) error { + address := sourceAddrPort.Addr() + if !address.IsValid() { + return errors.New("source address is not valid") + } + + err := writeUDPDatagramSourceAddress(writer, address) + if err != nil { + return fmt.Errorf("writing source address: %w", err) + } + + var portBytes [2]byte + binary.BigEndian.PutUint16(portBytes[:], sourceAddrPort.Port()) + _, err = writer.Write(portBytes[:]) + if err != nil { + return fmt.Errorf("writing destination port: %w", err) + } + + _, err = writer.Write(payload) + if err != nil { + return fmt.Errorf("writing payload: %w", err) + } + + return nil +} + +func writeUDPDatagramSourceAddress(writer io.Writer, address netip.Addr) error { + var addrType addrType + var addressBytes []byte + switch { + case address.Is4(): + addrType = ipv4 + array := address.As4() + addressBytes = array[:] + case address.Is6(): + addrType = ipv6 + array := address.As16() + addressBytes = array[:] + default: + return fmt.Errorf("address type is not supported: %v", address) + } + + _, err := writer.Write([]byte{0, 0, 0, byte(addrType)}) + if err != nil { + return fmt.Errorf("writing header: %w", err) + } + _, err = writer.Write(addressBytes) + if err != nil { + return fmt.Errorf("writing IP address: %w", err) + } + return nil +} + // See https://datatracker.ietf.org/doc/html/rfc1928#section-3 func verifyFirstNegotiation(reader io.Reader, requiredMethod authMethod) error { const headerLength = 2 // version + nMethods bytes diff --git a/internal/socks5/socks5_test.go b/internal/socks5/socks5_test.go index 5b16de16..3844805c 100644 --- a/internal/socks5/socks5_test.go +++ b/internal/socks5/socks5_test.go @@ -2,7 +2,10 @@ package socks5 import ( "bytes" + "context" "encoding/binary" + "errors" + "fmt" "io" "net" "strconv" @@ -96,6 +99,178 @@ func TestServerProxy(t *testing.T) { } } +func TestServerProxyTCPAndUDPParallel(t *testing.T) { + t.Parallel() + testCases := map[string]struct { + username string + password string + }{ + "no_auth": {}, + "with_auth": { + username: "user", + password: "pass", + }, + } + + for name, testCase := range testCases { + t.Run(name, func(t *testing.T) { + t.Parallel() + + backendTCPListener, err := (&net.ListenConfig{}).Listen(t.Context(), "tcp", "127.0.0.1:0") + require.NoError(t, err) + + backendTCPConnChannel := make(chan net.Conn, 1) + go func() { + connection, err := backendTCPListener.Accept() + if err != nil { + return + } + backendTCPConnChannel <- connection + }() + + backendUDPPacketConn, err := (&net.ListenConfig{}).ListenPacket(t.Context(), "udp", "127.0.0.1:0") + require.NoError(t, err) + + server := newServer(Settings{ + Username: testCase.username, + Password: testCase.password, + Address: "127.0.0.1:0", + Logger: noopLogger{}, + }) + _, err = server.Start(t.Context()) + require.NoError(t, err) + t.Cleanup(func() { + _ = server.Stop() + _ = backendTCPListener.Close() + _ = backendUDPPacketConn.Close() + }) + + clientTCPConn := dialSOCKS5(t, server.listeningAddress().String(), + backendTCPListener.Addr().String(), testCase.username, testCase.password) + defer clientTCPConn.Close() + + backendTCPConn := <-backendTCPConnChannel + defer backendTCPConn.Close() + + udpControlConn, clientUDPConn := dialSOCKS5UDPAssociate(t, + server.listeningAddress().String(), testCase.username, testCase.password) + defer udpControlConn.Close() + defer clientUDPConn.Close() + + tcpErrCh := make(chan error, 1) + go func() { + tcpErrCh <- runTCPProxyRoundTrip(clientTCPConn, backendTCPConn) + }() + + udpErrCh := make(chan error, 1) + go func() { + udpErrCh <- runUDPProxyRoundTrip(t.Context(), clientUDPConn, backendUDPPacketConn) + }() + + err = <-tcpErrCh + require.NoError(t, err) + err = <-udpErrCh + require.NoError(t, err) + }) + } +} + +func runTCPProxyRoundTrip(clientTCPConn net.Conn, backendTCPConn net.Conn) error { + clientMessage := []byte("hello from client") + _, err := clientTCPConn.Write(clientMessage) + if err != nil { + return err + } + + received := make([]byte, len(clientMessage)) + _, err = io.ReadFull(backendTCPConn, received) + if err != nil { + return err + } + if !bytes.Equal(clientMessage, received) { + return errors.New("backend did not receive expected TCP payload") + } + + backendMessage := []byte("hello from backend") + _, err = backendTCPConn.Write(backendMessage) + if err != nil { + return err + } + + receivedByClient := make([]byte, len(backendMessage)) + _, err = io.ReadFull(clientTCPConn, receivedByClient) + if err != nil { + return err + } + if !bytes.Equal(backendMessage, receivedByClient) { + return errors.New("client did not receive expected TCP payload") + } + + return nil +} + +func runUDPProxyRoundTrip(ctx context.Context, clientUDPConn *net.UDPConn, backendUDPPacketConn net.PacketConn) error { + udpPayload := []byte("hello from udp client") + udpRequest, err := makeSOCKS5UDPDatagram(backendUDPPacketConn.LocalAddr().String(), udpPayload) + if err != nil { + return err + } + + _, err = clientUDPConn.Write(udpRequest) + if err != nil { + return err + } + + deadline, hasDeadline := ctx.Deadline() + if hasDeadline { + err = backendUDPPacketConn.SetReadDeadline(deadline) + if err != nil { + return fmt.Errorf("setting read deadline on backend connection: %w", err) + } + } + const bufferSize = 512 + backendReadBuffer := make([]byte, bufferSize) + packetLength, proxyAddress, err := backendUDPPacketConn.ReadFrom(backendReadBuffer) + if err != nil { + return err + } + if !bytes.Equal(udpPayload, backendReadBuffer[:packetLength]) { + return errors.New("backend did not receive expected UDP payload") + } + + backendUDPReply := []byte("hello from udp backend") + _, err = backendUDPPacketConn.WriteTo(backendUDPReply, proxyAddress) + if err != nil { + return err + } + + if hasDeadline { + err = clientUDPConn.SetReadDeadline(deadline) + if err != nil { + return fmt.Errorf("setting read deadline on client connection: %w", err) + } + } + udpResponseBuffer := make([]byte, 1024) + responseLength, err := clientUDPConn.Read(udpResponseBuffer) + if err != nil { + return err + } + + destinationAddress, udpResponsePayload, err := parseSOCKS5UDPDatagram(udpResponseBuffer[:responseLength]) + if err != nil { + return err + } + + if !bytes.Equal(backendUDPReply, udpResponsePayload) { + return errors.New("client did not receive expected UDP payload") + } + if destinationAddress != backendUDPPacketConn.LocalAddr().String() { + return errors.New("udp response destination address mismatch") + } + + return nil +} + // dialSOCKS5 performs the full SOCKS5 handshake (with optional username/password // subnegotiation) and returns a connected net.Conn ready for data exchange. func dialSOCKS5(t *testing.T, proxyAddr, targetAddr, username, password string) net.Conn { @@ -109,6 +284,55 @@ func dialSOCKS5(t *testing.T, proxyAddr, targetAddr, username, password string) conn, err := (&net.Dialer{}).DialContext(t.Context(), "tcp", proxyAddr) require.NoError(t, err) + negotiateSOCKS5(t, conn, username, password) + + var connectRequest []byte + if ip := net.ParseIP(host).To4(); ip != nil { + connectRequest = []byte{socks5Version, byte(connect), 0, byte(ipv4)} + connectRequest = append(connectRequest, ip...) + } else { + connectRequest = []byte{socks5Version, byte(connect), 0, byte(domainName), byte(len(host))} + connectRequest = append(connectRequest, []byte(host)...) + } + connectRequest = binary.BigEndian.AppendUint16(connectRequest, uint16(targetPort)) //nolint:gosec + _, err = conn.Write(connectRequest) + require.NoError(t, err) + + _, err = readSOCKS5ResponseAddress(t, conn) + require.NoError(t, err) + + return conn +} + +func dialSOCKS5UDPAssociate(t *testing.T, proxyAddr, username, password string) (net.Conn, *net.UDPConn) { + t.Helper() + + controlConn, err := (&net.Dialer{}).DialContext(t.Context(), "tcp", proxyAddr) + require.NoError(t, err) + + negotiateSOCKS5(t, controlConn, username, password) + + udpAssociateRequest := []byte{socks5Version, byte(udpAssociate), 0, byte(ipv4), 0, 0, 0, 0, 0, 0} + _, err = controlConn.Write(udpAssociateRequest) + require.NoError(t, err) + + udpProxyAddress, err := readSOCKS5ResponseAddress(t, controlConn) + require.NoError(t, err) + + udpProxyResolvedAddress, err := net.ResolveUDPAddr("udp", udpProxyAddress) + require.NoError(t, err) + + udpConn, err := net.DialUDP("udp", nil, udpProxyResolvedAddress) + require.NoError(t, err) + + return controlConn, udpConn +} + +func negotiateSOCKS5(t *testing.T, conn net.Conn, username, password string) { + t.Helper() + + var err error + var method authMethod if username != "" || password != "" { method = authUsernamePassword @@ -138,45 +362,146 @@ func dialSOCKS5(t *testing.T, proxyAddr, targetAddr, username, password string) require.Equal(t, authUsernamePasswordSubNegotiation1, subnegResp[0]) require.Equal(t, byte(0), subnegResp[1]) } +} - var connectRequest []byte - if ip := net.ParseIP(host).To4(); ip != nil { - connectRequest = []byte{socks5Version, byte(connect), 0, byte(ipv4)} - connectRequest = append(connectRequest, ip...) - } else { - connectRequest = []byte{socks5Version, byte(connect), 0, byte(domainName), byte(len(host))} - connectRequest = append(connectRequest, []byte(host)...) - } - connectRequest = binary.BigEndian.AppendUint16(connectRequest, uint16(targetPort)) //nolint:gosec - _, err = conn.Write(connectRequest) - require.NoError(t, err) +func readSOCKS5ResponseAddress(t *testing.T, conn net.Conn) (address string, err error) { + t.Helper() var responseHeader [4]byte _, err = io.ReadFull(conn, responseHeader[:]) - require.NoError(t, err) - require.Equal(t, socks5Version, responseHeader[0]) - require.Equal(t, byte(succeeded), responseHeader[1]) - - // Consume BND.ADDR and BND.PORT (their values are irrelevant to the caller). - switch addrType(responseHeader[3]) { - case ipv4: - var addrPort [net.IPv4len + 2]byte - _, err = io.ReadFull(conn, addrPort[:]) - require.NoError(t, err) - case ipv6: - var addrPort [net.IPv6len + 2]byte - _, err = io.ReadFull(conn, addrPort[:]) - require.NoError(t, err) - case domainName: - var lenBuf [1]byte - _, err = io.ReadFull(conn, lenBuf[:]) - require.NoError(t, err) - addrPort := make([]byte, int(lenBuf[0])+2) - _, err = io.ReadFull(conn, addrPort) - require.NoError(t, err) + if err != nil { + return "", err + } + if responseHeader[0] != socks5Version { + return "", errors.New("version mismatch") + } + if responseHeader[1] != byte(succeeded) { + return "", errors.New("request was not successful") } - return conn + var host string + switch addrType(responseHeader[3]) { + case ipv4: + addressAndPort := make([]byte, net.IPv4len+2) + _, err = io.ReadFull(conn, addressAndPort) + if err != nil { + return "", err + } + host = net.IP(addressAndPort[:net.IPv4len]).String() + port := binary.BigEndian.Uint16(addressAndPort[net.IPv4len:]) + return net.JoinHostPort(host, strconv.FormatUint(uint64(port), 10)), nil + case ipv6: + addressAndPort := make([]byte, net.IPv6len+2) + _, err = io.ReadFull(conn, addressAndPort) + if err != nil { + return "", err + } + host = net.IP(addressAndPort[:net.IPv6len]).String() + port := binary.BigEndian.Uint16(addressAndPort[net.IPv6len:]) + return net.JoinHostPort(host, strconv.FormatUint(uint64(port), 10)), nil + case domainName: + var lengthBuffer [1]byte + _, err = io.ReadFull(conn, lengthBuffer[:]) + if err != nil { + return "", err + } + domainAndPort := make([]byte, int(lengthBuffer[0])+2) + _, err = io.ReadFull(conn, domainAndPort) + if err != nil { + return "", err + } + host = string(domainAndPort[:len(domainAndPort)-2]) + port := binary.BigEndian.Uint16(domainAndPort[len(domainAndPort)-2:]) + return net.JoinHostPort(host, strconv.FormatUint(uint64(port), 10)), nil + default: + return "", errors.New("unknown address type") + } +} + +func makeSOCKS5UDPDatagram(targetAddress string, payload []byte) ([]byte, error) { + host, portString, err := net.SplitHostPort(targetAddress) + if err != nil { + return nil, err + } + port, err := strconv.ParseUint(portString, 10, 16) + if err != nil { + return nil, err + } + + datagram := []byte{0, 0, 0} + ipAddress := net.ParseIP(host) + if ipAddress != nil { + if ipAddress.To4() != nil { + datagram = append(datagram, byte(ipv4)) + datagram = append(datagram, ipAddress.To4()...) + } else { + datagram = append(datagram, byte(ipv6)) + datagram = append(datagram, ipAddress.To16()...) + } + } else { + if len(host) > 255 { + return nil, errors.New("domain name too long") + } + datagram = append(datagram, byte(domainName), byte(len(host))) + datagram = append(datagram, []byte(host)...) + } + datagram = binary.BigEndian.AppendUint16(datagram, uint16(port)) + datagram = append(datagram, payload...) + + return datagram, nil +} + +func parseSOCKS5UDPDatagram(datagram []byte) (destinationAddress string, payload []byte, err error) { + if len(datagram) < 4 { + return "", nil, errors.New("datagram too short") + } + if datagram[0] != 0 || datagram[1] != 0 { + return "", nil, errors.New("invalid reserved header") + } + if datagram[2] != 0 { + return "", nil, errors.New("fragments are not supported") + } + + offset := 3 + var host string + switch addrType(datagram[offset]) { + case ipv4: + offset++ + if len(datagram) < offset+net.IPv4len+2 { + return "", nil, errors.New("datagram too short for IPv4") + } + host = net.IP(datagram[offset : offset+net.IPv4len]).String() + offset += net.IPv4len + case ipv6: + offset++ + if len(datagram) < offset+net.IPv6len+2 { + return "", nil, errors.New("datagram too short for IPv6") + } + host = net.IP(datagram[offset : offset+net.IPv6len]).String() + offset += net.IPv6len + case domainName: + offset++ + if len(datagram) < offset+1 { + return "", nil, errors.New("datagram too short for domain length") + } + domainLength := int(datagram[offset]) + offset++ + if len(datagram) < offset+domainLength+2 { + return "", nil, errors.New("datagram too short for domain") + } + host = string(datagram[offset : offset+domainLength]) + offset += domainLength + default: + return "", nil, errors.New("unknown address type") + } + + if len(datagram) < offset+2 { + return "", nil, errors.New("datagram too short for port") + } + port := binary.BigEndian.Uint16(datagram[offset : offset+2]) + offset += 2 + + return net.JoinHostPort(host, strconv.FormatUint(uint64(port), 10)), datagram[offset:], nil } func Test_newServer(t *testing.T) { @@ -224,7 +549,8 @@ func Test_Server_StartStop(t *testing.T) { ctrl := gomock.NewController(t) logger := NewMockLogger(ctrl) - logger.EXPECT().Infof("SOCKS5 server listening on %s", gomock.Any()) + logger.EXPECT().Infof("SOCKS5 TCP server listening on %s", gomock.Any()) + logger.EXPECT().Infof("SOCKS5 UDP server listening on %s", gomock.Any()) server := newServer(Settings{ Address: "127.0.0.1:0", @@ -598,10 +924,6 @@ func Test_cmdType_String(t *testing.T) { cmd: connect, expectedName: "connect", }, - "bind": { - cmd: bind, - expectedName: "bind", - }, "udp_associate": { cmd: udpAssociate, expectedName: "UDP associate", @@ -620,3 +942,80 @@ func Test_cmdType_String(t *testing.T) { }) } } + +func Test_socksConn_udpAssociationAddresses(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + routerAddress string + expectAddressFromConn bool + expectedAddress string + }{ + "wildcard_router_address_uses_control_connection_local_ip": { + routerAddress: ":0", + expectAddressFromConn: true, + }, + "concrete_router_address_is_kept": { + routerAddress: "127.0.0.1:0", + expectedAddress: "127.0.0.1", + }, + } + + for name, testCase := range testCases { + t.Run(name, func(t *testing.T) { + t.Parallel() + + router, err := newUDPRouter(t.Context(), testCase.routerAddress, noopLogger{}) + require.NoError(t, err) + t.Cleanup(func() { + err := router.close() + assert.NoError(t, err) + }) + + controlListener, err := (&net.ListenConfig{}).Listen(t.Context(), "tcp", "127.0.0.1:0") + require.NoError(t, err) + t.Cleanup(func() { + err := controlListener.Close() + assert.NoError(t, err) + }) + + acceptedConnCh := make(chan net.Conn, 1) + go func() { + acceptedConn, acceptErr := controlListener.Accept() + if acceptErr != nil { + return + } + acceptedConnCh <- acceptedConn + }() + + clientControlConn, err := (&net.Dialer{}).DialContext(t.Context(), "tcp", controlListener.Addr().String()) + require.NoError(t, err) + defer clientControlConn.Close() + + serverControlConn := <-acceptedConnCh + defer serverControlConn.Close() + + socksConnection := &socksConn{ + clientConn: clientControlConn, + udpRouter: router, + } + bindAddress, bindPort, bindAddrType, err := socksConnection.udpAssociationAddresses() + require.NoError(t, err) + + if testCase.expectAddressFromConn { + clientLocalHost, _, err := net.SplitHostPort(clientControlConn.LocalAddr().String()) + require.NoError(t, err) + assert.Equal(t, clientLocalHost, bindAddress) + } else { + assert.Equal(t, testCase.expectedAddress, bindAddress) + } + + _, routerPortString, err := net.SplitHostPort(router.localAddress().String()) + require.NoError(t, err) + routerPort, err := strconv.ParseUint(routerPortString, 10, 16) + require.NoError(t, err) + assert.Equal(t, uint16(routerPort), bindPort) + assert.Equal(t, ipv4, bindAddrType) + }) + } +} diff --git a/internal/socks5/udp_router.go b/internal/socks5/udp_router.go new file mode 100644 index 00000000..e4ffef4a --- /dev/null +++ b/internal/socks5/udp_router.go @@ -0,0 +1,370 @@ +package socks5 + +import ( + "bytes" + "context" + "errors" + "fmt" + "net" + "net/netip" + "sync" +) + +type udpAssociation struct { + id uint64 + clientAddrPort netip.AddrPort + expectedAddrPort netip.AddrPort + controlConnAddr netip.Addr + packetCh chan *bytes.Buffer +} + +type udpRouter struct { + logger Logger + + listener net.PacketConn + mutex sync.Mutex + bufferPool sync.Pool + nextAssociationID uint64 + clientAddrPortToAssociation map[netip.AddrPort]udpAssociation + clientIPToPendingAssociations map[netip.Addr][]udpAssociation + associationIDToClientAddrPort map[uint64]netip.AddrPort +} + +const ( + maxUDPPacketLength = 65535 + maxSOCKS5UDPDatagramOverhead = 3 + 1 + 16 + 2 + pooledUDPPacketBufferCapacity = maxUDPPacketLength + maxSOCKS5UDPDatagramOverhead +) + +func newUDPRouter(ctx context.Context, address string, logger Logger) (router *udpRouter, err error) { + config := &net.ListenConfig{} + listener, err := config.ListenPacket(ctx, "udp", address) + if err != nil { + return nil, fmt.Errorf("UDP listening: %w", err) + } + + return &udpRouter{ + logger: logger, + listener: listener, + bufferPool: sync.Pool{ + New: func() any { + return bytes.NewBuffer(make([]byte, 0, pooledUDPPacketBufferCapacity)) + }, + }, + nextAssociationID: 1, + clientAddrPortToAssociation: make(map[netip.AddrPort]udpAssociation), + clientIPToPendingAssociations: make(map[netip.Addr][]udpAssociation), + associationIDToClientAddrPort: make(map[uint64]netip.AddrPort), + }, nil +} + +func (r *udpRouter) localAddress() net.Addr { + return r.listener.LocalAddr() +} + +func (r *udpRouter) close() error { + return r.listener.Close() +} + +func (r *udpRouter) registerAssociation(controlConn net.Conn, expectedAddrPort netip.AddrPort) (udpAssociation, error) { + controlConnAddrPort, err := netip.ParseAddrPort(controlConn.RemoteAddr().String()) + if err != nil { + return udpAssociation{}, fmt.Errorf("parsing control connection address: %w", err) + } + controlConnAddr := controlConnAddrPort.Addr().Unmap() + + r.mutex.Lock() + defer r.mutex.Unlock() + + const udpPacketChannelBuffer = 2 + associationID := r.nextAssociationID + r.nextAssociationID++ + + association := udpAssociation{ + id: associationID, + expectedAddrPort: expectedAddrPort, + controlConnAddr: controlConnAddr, + packetCh: make(chan *bytes.Buffer, udpPacketChannelBuffer), + } + + if expectedAddrPort.Addr().IsValid() && expectedAddrPort.Port() != 0 { + association.clientAddrPort = expectedAddrPort + r.clientAddrPortToAssociation[association.clientAddrPort] = association + r.associationIDToClientAddrPort[association.id] = association.clientAddrPort + return association, nil + } + + pendingAssociations := r.clientIPToPendingAssociations[controlConnAddr] + pendingAssociations = append(pendingAssociations, association) + r.clientIPToPendingAssociations[controlConnAddr] = pendingAssociations + + return association, nil +} + +func (r *udpRouter) unregisterAssociation(association udpAssociation) { + r.mutex.Lock() + defer r.mutex.Unlock() + + clientAddrPort, hasClientAddress := r.associationIDToClientAddrPort[association.id] + if hasClientAddress { + delete(r.associationIDToClientAddrPort, association.id) + delete(r.clientAddrPortToAssociation, clientAddrPort) + } + + pendingAssociations := r.clientIPToPendingAssociations[association.controlConnAddr] + for i, pendingAssociation := range pendingAssociations { + if pendingAssociation.id == association.id { + pendingAssociations = append(pendingAssociations[:i], pendingAssociations[i+1:]...) + break + } + } + if len(pendingAssociations) == 0 { + delete(r.clientIPToPendingAssociations, association.controlConnAddr) + } else { + r.clientIPToPendingAssociations[association.controlConnAddr] = pendingAssociations + } +} + +func (r *udpRouter) run(ctx context.Context) error { + packetBuffer := make([]byte, maxUDPPacketLength) + + for { + packetLength, sourceAddress, err := r.listener.ReadFrom(packetBuffer) + if err != nil { + if ctx.Err() != nil && errors.Is(err, net.ErrClosed) { + return nil + } + return fmt.Errorf("reading UDP packet: %w", err) + } + + sourceAddrPort, err := netAddrToNetipAddrPort(sourceAddress) + if err != nil { + r.logger.Warnf("parsing source address: %s", err) + continue + } + buffer := r.bufferPool.Get().(*bytes.Buffer) //nolint:forcetypeassert + buffer.Reset() + _, err = buffer.Write(packetBuffer[:packetLength]) + if err != nil { + r.bufferPool.Put(buffer) + r.logger.Warnf("buffering packet: %s", err) + continue + } + err = r.routePacket(sourceAddrPort, buffer) + if err != nil { + r.logger.Warnf("failed routing UDP packet: %s", err) + } + } +} + +func (r *udpRouter) routePacket(sourceAddrPort netip.AddrPort, packet *bytes.Buffer) error { + r.mutex.Lock() + association, packetFromClient := r.findClientAssociation(sourceAddrPort) + r.mutex.Unlock() + + if !packetFromClient { + r.bufferPool.Put(packet) + return nil + } + + select { + case association.packetCh <- packet: + return nil + default: + r.bufferPool.Put(packet) + return errors.New("association packet queue full") + } +} + +func (r *udpRouter) findClientAssociation(sourceAddrPort netip.AddrPort) ( + association udpAssociation, ok bool, +) { + association, ok = r.clientAddrPortToAssociation[sourceAddrPort] + if ok { + return association, true + } + sourceAddr := sourceAddrPort.Addr() + + pendingAssociations := r.clientIPToPendingAssociations[sourceAddr] + if len(pendingAssociations) == 0 { + return udpAssociation{}, false + } + + index := -1 + for i, pendingAssociation := range pendingAssociations { + if matchesExpectedClientEndpoint(pendingAssociation, sourceAddrPort) { + association = pendingAssociation + index = i + break + } + } + if index == -1 { + return udpAssociation{}, false + } + + r.clientIPToPendingAssociations[sourceAddr] = append(pendingAssociations[:index], pendingAssociations[index+1:]...) + if len(r.clientIPToPendingAssociations[sourceAddr]) == 0 { + delete(r.clientIPToPendingAssociations, sourceAddr) + } + + association.clientAddrPort = sourceAddrPort + r.clientAddrPortToAssociation[sourceAddrPort] = association + r.associationIDToClientAddrPort[association.id] = sourceAddrPort + + return association, true +} + +func matchesExpectedClientEndpoint(association udpAssociation, sourceAddrPort netip.AddrPort) bool { + switch { + case association.expectedAddrPort.Addr().IsValid() && sourceAddrPort.Addr() != association.expectedAddrPort.Addr(): + return false + case association.expectedAddrPort.Port() != 0 && sourceAddrPort.Port() != association.expectedAddrPort.Port(): + return false + } + return true +} + +func (r *udpRouter) clientAddrPortForAssociation(associationID uint64) ( + clientAddrPort netip.AddrPort, ok bool, +) { + r.mutex.Lock() + defer r.mutex.Unlock() + + clientAddrPort, ok = r.associationIDToClientAddrPort[associationID] + return clientAddrPort, ok +} + +func (r *udpRouter) runAssociationHandler(ctx context.Context, association udpAssociation) { + config := &net.ListenConfig{} + socket, err := config.ListenPacket(ctx, "udp", ":0") + if err != nil { + r.logger.Warnf("creating per-association UDP socket: %s", err) + return + } + defer socket.Close() + + go closeSocketOnContextDone(ctx, socket) + + packetBuffer := make([]byte, maxUDPPacketLength) + + forwardDoneCh := make(chan struct{}) + go r.forwardClientPackets(ctx, socket, association.packetCh, forwardDoneCh) + + for { + packetLength, sourceAddress, err := socket.ReadFrom(packetBuffer) + if err != nil { + if ctx.Err() != nil || errors.Is(err, net.ErrClosed) { + <-forwardDoneCh + return + } + r.logger.Warnf("reading from per-association UDP socket: %s", err) + continue + } + + sourceAddrPort, err := netAddrToNetipAddrPort(sourceAddress) + if err != nil { + r.logger.Warnf("parsing source address from destination: %s", err) + continue + } + + buffer := r.bufferPool.Get().(*bytes.Buffer) //nolint:forcetypeassert + buffer.Reset() + err = encodeUDPDatagramToBuffer(buffer, sourceAddrPort, packetBuffer[:packetLength]) + if err != nil { + r.bufferPool.Put(buffer) + r.logger.Warnf("encoding response datagram: %s", err) + continue + } + + clientAddrPort, found := r.clientAddrPortForAssociation(association.id) + if !found { + r.bufferPool.Put(buffer) + r.logger.Warnf("client address not found for association id %d", association.id) + continue + } + + clientUDPAddress := &net.UDPAddr{ + IP: clientAddrPort.Addr().AsSlice(), + Port: int(clientAddrPort.Port()), + } + _, err = r.listener.WriteTo(buffer.Bytes(), clientUDPAddress) + r.bufferPool.Put(buffer) + if err != nil { + r.logger.Warnf("writing response to client: %s", err) + } + } +} + +func closeSocketOnContextDone(ctx context.Context, socket net.PacketConn) { + <-ctx.Done() + _ = socket.Close() +} + +func (r *udpRouter) forwardClientPackets(ctx context.Context, socket net.PacketConn, + packetCh <-chan *bytes.Buffer, done chan<- struct{}, +) { + defer close(done) + + for { + select { + case <-ctx.Done(): + return + case buffer, ok := <-packetCh: + if !ok { + return + } + + err := r.writeClientPacketToDestination(ctx, socket, buffer) + r.bufferPool.Put(buffer) + if err != nil { + r.logger.Warnf("forwarding client packet to destination: %s", err) + } + } + } +} + +func (r *udpRouter) writeClientPacketToDestination(ctx context.Context, + socket net.PacketConn, packet *bytes.Buffer, +) error { + destination, payload, err := decodeUDPDatagram(packet.Bytes()) + if err != nil { + return fmt.Errorf("decoding UDP datagram: %w", err) + } + + host, portStr, err := net.SplitHostPort(destination) + if err != nil { + return fmt.Errorf("splitting destination host and port: %w", err) + } + + if _, err := netip.ParseAddr(host); err != nil { // domain name + addrs, err := net.DefaultResolver.LookupHost(ctx, host) + if err != nil { + return fmt.Errorf("resolving destination host: %w", err) + } + if len(addrs) == 0 { + return fmt.Errorf("resolving destination host: no addresses found for %q", host) + } + + destination = net.JoinHostPort(addrs[0], portStr) + } + + resolvedDestinationUDPAddress, err := net.ResolveUDPAddr("udp", destination) + if err != nil { + return fmt.Errorf("resolving destination UDP address: %w", err) + } + + _, err = socket.WriteTo(payload, resolvedDestinationUDPAddress) + if err != nil && ctx.Err() == nil { + return fmt.Errorf("writing payload to destination: %w", err) + } + + return nil +} + +func netAddrToNetipAddrPort(addr net.Addr) (netip.AddrPort, error) { + addrPort, err := netip.ParseAddrPort(addr.String()) + if err != nil { + return netip.AddrPort{}, fmt.Errorf("parsing address: %w", err) + } + return netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port()), nil +} diff --git a/internal/socks5/udp_router_integration_test.go b/internal/socks5/udp_router_integration_test.go new file mode 100644 index 00000000..cecfa506 --- /dev/null +++ b/internal/socks5/udp_router_integration_test.go @@ -0,0 +1,162 @@ +//go:build integration + +package socks5 + +import ( + "bytes" + "context" + "math/rand/v2" + "net" + "net/netip" + "strconv" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_udpRouter_ResolveGithubFromCloudflareDNS(t *testing.T) { + ctx := t.Context() + var cancel context.CancelFunc + deadline, hasDeadline := ctx.Deadline() + if hasDeadline { + const deadlineBuffer = 500 * time.Millisecond + deadline = deadline.Add(-deadlineBuffer) + } else { + const defaultTimeout = 10 * time.Second + deadline = time.Now().Add(defaultTimeout) + } + ctx, cancel = context.WithDeadline(ctx, deadline) + + ctrl := gomock.NewController(t) + logger := NewMockLogger(ctrl) + + router, err := newUDPRouter(ctx, "127.0.0.1:0", logger) + require.NoError(t, err) + + routerRunErrCh := make(chan error) + go func() { + routerRunErrCh <- router.run(ctx) + }() + + t.Cleanup(func() { + cancel() + err := router.close() + assert.NoError(t, err, "closing router") + runErr := <-routerRunErrCh + assert.NoError(t, runErr) + }) + + controlListener, err := (&net.ListenConfig{}).Listen(ctx, "tcp", "127.0.0.1:0") + require.NoError(t, err) + t.Cleanup(func() { + err := controlListener.Close() + assert.NoError(t, err, "closing control listener") + }) + + acceptedConnCh := make(chan net.Conn) + go func() { + acceptedConn, acceptErr := controlListener.Accept() + assert.NoError(t, acceptErr, "accepting control connection") + if acceptErr != nil { + return + } + acceptedConnCh <- acceptedConn + }() + + clientControlConn, err := (&net.Dialer{}).DialContext(ctx, "tcp", controlListener.Addr().String()) + require.NoError(t, err) + t.Cleanup(func() { + err = clientControlConn.Close() + assert.NoError(t, err, "closing client control connection") + }) + + serverControlConn := <-acceptedConnCh + t.Cleanup(func() { + err := serverControlConn.Close() + assert.NoError(t, err, "closing server control connection") + }) + + association, err := router.registerAssociation(serverControlConn, netip.AddrPort{}) + require.NoError(t, err) + t.Cleanup(func() { + router.unregisterAssociation(association) + }) + + associationCtx, associationCancel := context.WithCancel(ctx) + handlerDoneCh := make(chan struct{}) + go func() { + router.runAssociationHandler(associationCtx, association) + close(handlerDoneCh) + }() + t.Cleanup(func() { + associationCancel() + <-handlerDoneCh + }) + + udpRouterAddress, err := net.ResolveUDPAddr("udp", router.localAddress().String()) + require.NoError(t, err) + + clientUDPConn, err := net.DialUDP("udp", nil, udpRouterAddress) + require.NoError(t, err) + t.Cleanup(func() { + err := clientUDPConn.Close() + assert.NoError(t, err, "closing client UDP connection") + }) + + queryID := uint16(rand.Uint()) + dnsRequest := &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Id: queryID, + RecursionDesired: true, + }, + Question: []dns.Question{{ + Name: dns.Fqdn("github.com"), + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }}, + } + dnsQuery, err := dnsRequest.Pack() + require.NoError(t, err) + + targetAddrPort := netip.MustParseAddrPort("1.1.1.1:53") + socksDatagramBuffer := bytes.NewBuffer(nil) + err = encodeUDPDatagramToBuffer(socksDatagramBuffer, targetAddrPort, dnsQuery) + require.NoError(t, err) + socksDatagram := socksDatagramBuffer.Bytes() + + err = clientUDPConn.SetDeadline(deadline) + require.NoError(t, err) + + _, err = clientUDPConn.Write(socksDatagram) + require.NoError(t, err) + + responseBuffer := make([]byte, maxUDPPacketLength) + responseLength, err := clientUDPConn.Read(responseBuffer) + require.NoError(t, err) + + responseDestination, responsePayload, err := decodeUDPDatagram(responseBuffer[:responseLength]) + require.NoError(t, err) + + responseHost, responsePortString, err := net.SplitHostPort(responseDestination) + require.NoError(t, err) + responsePort, err := strconv.ParseUint(responsePortString, 10, 16) + require.NoError(t, err) + assert.Equal(t, uint64(53), responsePort) + assert.NotEmpty(t, responseHost) + + dnsResponse := new(dns.Msg) + err = dnsResponse.Unpack(responsePayload) + require.NoError(t, err) + assert.Equal(t, queryID, dnsResponse.Id) + assert.True(t, dnsResponse.Response) + assert.Equal(t, dns.RcodeSuccess, dnsResponse.Rcode) + require.NotEmpty(t, dnsResponse.Question) + assert.Equal(t, dns.Fqdn("github.com"), dnsResponse.Question[0].Name) + assert.Equal(t, uint16(dns.TypeA), dnsResponse.Question[0].Qtype) + assert.NotEmpty(t, dnsResponse.Answer) + require.NoError(t, err) +}