feat(socks5): UDP proxying (#3353)

This commit is contained in:
Quentin McGaw
2026-06-11 09:32:38 -04:00
committed by GitHub
parent acab89b91a
commit 6d84462f00
10 changed files with 1311 additions and 95 deletions
+1 -1
View File
@@ -3,7 +3,7 @@
// to develop this project.
"files.eol": "\n",
"editor.formatOnSave": true,
"go.buildTags": "linux",
"go.buildTags": "linux,integration",
"go.toolsEnvVars": {
"CGO_ENABLED": "0"
},
+1 -1
View File
@@ -276,7 +276,7 @@ ENV VPN_SERVICE_PROVIDER=pia \
PUID=1000 \
PGID=1000
ENTRYPOINT ["/gluetun-entrypoint"]
EXPOSE 8000/tcp 8888/tcp 8388/tcp 8388/udp 1080/tcp
EXPOSE 8000/tcp 8888/tcp 8388/tcp 8388/udp 1080/tcp 1080/udp
HEALTHCHECK --interval=5s --timeout=5s --start-period=10s --retries=3 CMD /gluetun-entrypoint healthcheck
ARG TARGETPLATFORM
RUN apk add --no-cache --update -l wget && \
+1 -1
View File
@@ -73,7 +73,7 @@ Lightweight swiss-army-knife-like VPN client to multiple VPN service providers
- Choose the vpn network protocol, `udp` or `tcp`
- Built in firewall kill switch to allow traffic only with needed the VPN servers and LAN devices
- Built in Shadowsocks proxy server (protocol based on SOCKS5 with an encryption layer, tunnels TCP+UDP)
- Built in Socks5 proxy server (tunnels TCP) - partial credits to @angelakis and @adjscent
- Built in Socks5 proxy server (tunnels TCP+UDP) - partial credits to @angelakis and @adjscent
- Built in HTTP proxy (tunnels HTTP and HTTPS through TCP)
- [Connect other containers to it](https://github.com/qdm12/gluetun-wiki/blob/main/setup/connect-a-container-to-gluetun.md)
- [Connect LAN devices to it](https://github.com/qdm12/gluetun-wiki/blob/main/setup/connect-a-lan-device-to-gluetun.md)
-3
View File
@@ -43,7 +43,6 @@ type cmdType byte
const (
connect cmdType = 1
bind cmdType = 2
udpAssociate cmdType = 3
)
@@ -51,8 +50,6 @@ func (c cmdType) String() string {
switch c {
case connect:
return "connect"
case bind:
return "bind"
case udpAssociate:
return "UDP associate"
default:
+1 -1
View File
@@ -10,7 +10,7 @@ import (
)
// See https://datatracker.ietf.org/doc/html/rfc1928#section-6
func (c *socksConn) encodeFailedResponse(writer io.Writer, socksVersion byte, reply replyCode) { //nolint:unparam
func (c *socksConn) encodeFailedResponse(writer io.Writer, socksVersion byte, reply replyCode) {
_, err := writer.Write([]byte{
socksVersion,
byte(reply),
+89 -49
View File
@@ -2,6 +2,7 @@ package socks5
import (
"context"
"errors"
"fmt"
"net"
"sync"
@@ -15,12 +16,13 @@ type server struct {
logger Logger
// internal fields
listener net.Listener
tcpListener net.Listener
udpRouter *udpRouter
listening atomic.Bool
socksConnCtx context.Context //nolint:containedctx
socksConnCancel context.CancelFunc
done <-chan struct{}
stopping atomic.Bool
done <-chan error
stopCh chan<- struct{}
}
func newServer(settings Settings) *server {
@@ -39,19 +41,28 @@ func (s *server) String() string {
func (s *server) Start(ctx context.Context) (runErr <-chan error, err error) {
s.socksConnCtx, s.socksConnCancel = context.WithCancel(context.Background())
config := &net.ListenConfig{}
s.listener, err = config.Listen(ctx, "tcp", s.address)
s.tcpListener, err = config.Listen(ctx, "tcp", s.address)
if err != nil {
return nil, fmt.Errorf("listening on %s: %w", s.address, err)
return nil, fmt.Errorf("TCP listening on %s: %w", s.address, err)
}
s.udpRouter, err = newUDPRouter(ctx, s.address, s.logger)
if err != nil {
_ = s.tcpListener.Close()
return nil, fmt.Errorf("creating UDP router: %w", err)
}
s.listening.Store(true)
s.logger.Infof("SOCKS5 server listening on %s", s.listener.Addr())
s.logger.Infof("SOCKS5 TCP server listening on %s", s.tcpListener.Addr())
s.logger.Infof("SOCKS5 UDP server listening on %s", s.udpRouter.localAddress())
ready := make(chan struct{})
runErrCh := make(chan error)
runErr = runErrCh
done := make(chan struct{})
done := make(chan error)
s.done = done
go s.runServer(ready, runErrCh, done)
stop := make(chan struct{})
s.stopCh = stop
go s.runServer(ready, runErrCh, stop, done)
select {
case <-ready:
case <-ctx.Done():
@@ -62,61 +73,90 @@ func (s *server) Start(ctx context.Context) (runErr <-chan error, err error) {
}
func (s *server) runServer(ready chan<- struct{},
runErrCh chan<- error, done chan<- struct{},
runErrCh chan<- error, stop <-chan struct{}, done chan<- error,
) {
close(ready)
defer close(done)
wg := new(sync.WaitGroup)
defer wg.Wait()
dialer := &net.Dialer{}
for {
connection, err := s.listener.Accept()
if err != nil {
if !s.stopping.Load() {
_ = s.stop()
runErrCh <- fmt.Errorf("accepting connection: %w", err)
}
return
}
wg.Add(1)
go func(ctx context.Context, connection net.Conn,
dialer *net.Dialer, wg *sync.WaitGroup,
) {
defer wg.Done()
socksConn := &socksConn{
dialer: dialer,
username: s.username,
password: s.password,
clientConn: connection,
logger: s.logger,
}
err := socksConn.run(ctx)
udpErrCh := make(chan error)
go func() {
udpErrCh <- s.udpRouter.run(s.socksConnCtx)
}()
tcpErrCh := make(chan error)
go func() {
var wg sync.WaitGroup
defer wg.Wait()
dialer := &net.Dialer{}
for {
connection, err := s.tcpListener.Accept()
if err != nil {
s.logger.Infof("running socks connection: %s", err)
s.socksConnCancel() // stop ongoing TCP socks connections - no impact on UDP
tcpErrCh <- fmt.Errorf("accepting connection: %w", err)
return
}
}(s.socksConnCtx, connection, dialer, wg)
wg.Go(func() {
connection := connection // capture loop variable
socksConn := &socksConn{
dialer: dialer,
username: s.username,
password: s.password,
clientConn: connection,
udpRouter: s.udpRouter,
logger: s.logger,
}
err := socksConn.run(s.socksConnCtx)
if err != nil {
s.logger.Infof("running socks connection: %s", err)
}
})
}
}()
select {
case <-stop:
s.listening.Store(false)
var errs []error
err := s.tcpListener.Close()
if err != nil {
errs = append(errs, fmt.Errorf("closing TCP listener: %w", err))
}
// stop ongoing TCP socks connections. This impacts the udpRouter run error when it is being closed.
s.socksConnCancel()
<-tcpErrCh // wait for TCP server to stop
err = s.udpRouter.close()
if err != nil {
errs = append(errs, fmt.Errorf("closing UDP router: %w", err))
}
<-udpErrCh // wait for UDP router to stop
if len(errs) > 0 {
// Only write to the done channel if the [server.Stop] method is waiting to read from it
done <- errors.Join(errs...)
}
// If no error, the done channel is closed so the error is effectively `nil`
// Note: do NOT write an error the runError channel, since we are stopping the server gracefully.
case err := <-udpErrCh:
_ = s.tcpListener.Close() // stop accepting new TCP connections
s.socksConnCancel() // stop ongoing TCP socks connections
<-tcpErrCh // wait for TCP server to stop
runErrCh <- fmt.Errorf("running UDP router: %w", err)
case err := <-tcpErrCh:
s.socksConnCancel()
_ = s.udpRouter.close() // stop UDP router
<-udpErrCh // wait for UDP router to stop
runErrCh <- fmt.Errorf("running TCP server: %w", err)
}
}
func (s *server) Stop() (err error) {
s.stopping.Store(true)
err = s.stop()
<-s.done // wait for run goroutine to finish
s.stopping.Store(false)
return err
}
func (s *server) stop() error {
s.listening.Store(false)
err := s.listener.Close()
s.socksConnCancel() // stop ongoing socks connections
return err
close(s.stopCh)
return <-s.done
}
func (s *server) listeningAddress() net.Addr {
if s.listening.Load() {
return s.listener.Addr()
return s.tcpListener.Addr()
}
return nil
}
+249 -1
View File
@@ -10,6 +10,7 @@ import (
"net/netip"
"strconv"
"strings"
"sync"
)
var (
@@ -23,6 +24,7 @@ type socksConn struct {
username string
password string
clientConn net.Conn
udpRouter *udpRouter
logger Logger
}
@@ -109,11 +111,29 @@ func (c *socksConn) handleRequest(ctx context.Context) error {
c.encodeFailedResponse(c.clientConn, socksVersion, generalServerFailure)
return err
}
if request.command != connect {
switch request.command {
case connect:
err = c.handleConnectRequest(ctx, socksVersion, request)
if err != nil {
return fmt.Errorf("handling %s request: %w", request.command, err)
}
return nil
case udpAssociate:
err = c.handleUDPAssociateRequest(ctx, socksVersion, request)
if err != nil {
return fmt.Errorf("handling %s request: %w", request.command, err)
}
return nil
default:
c.encodeFailedResponse(c.clientConn, socksVersion, commandNotSupported)
return fmt.Errorf("command %s is not supported", request.command)
}
}
func (c *socksConn) handleConnectRequest(ctx context.Context,
socksVersion byte, request request,
) error {
destinationAddress := net.JoinHostPort(request.destination, fmt.Sprint(request.port))
destinationConn, err := c.dialer.DialContext(ctx, "tcp", destinationAddress)
if err != nil {
@@ -176,6 +196,234 @@ func (c *socksConn) handleRequest(ctx context.Context) error {
}
}
func (c *socksConn) handleUDPAssociateRequest(ctx context.Context,
socksVersion byte, request request,
) error {
expectedAddrPort, err := udpAssociateExpectedClientEndpoint(request)
if err != nil {
c.encodeFailedResponse(c.clientConn, socksVersion, addressTypeNotSupported)
return fmt.Errorf("deriving expected client address and port from request: %w", err)
}
bindAddress, bindPort, bindAddrType, err := c.udpAssociationAddresses()
if err != nil {
c.encodeFailedResponse(c.clientConn, socksVersion, generalServerFailure)
return fmt.Errorf("getting udp association addresses: %w", err)
}
association, err := c.udpRouter.registerAssociation(c.clientConn, expectedAddrPort)
if err != nil {
c.encodeFailedResponse(c.clientConn, socksVersion, generalServerFailure)
return fmt.Errorf("registering udp association: %w", err)
}
defer c.udpRouter.unregisterAssociation(association)
err = c.encodeSuccessResponse(c.clientConn, socksVersion, succeeded,
bindAddrType, bindAddress, bindPort)
if err != nil {
c.encodeFailedResponse(c.clientConn, socksVersion, generalServerFailure)
return fmt.Errorf("writing successful %s response: %w", udpAssociate, err)
}
associationCtx, associationCancel := context.WithCancel(ctx)
defer associationCancel()
var wg sync.WaitGroup
wg.Go(func() {
c.udpRouter.runAssociationHandler(associationCtx, association)
})
wg.Go(func() {
_, _ = io.Copy(io.Discard, c.clientConn)
associationCancel()
})
<-associationCtx.Done()
wg.Wait()
return nil
}
func udpAssociateExpectedClientEndpoint(request request) (expectedAddrPort netip.AddrPort, err error) {
switch request.addressType {
case ipv4, ipv6:
expectedClientAddress, parseErr := netip.ParseAddr(request.destination)
if parseErr != nil {
return netip.AddrPort{}, fmt.Errorf("parsing destination address: %w", parseErr)
}
expectedClientAddress = expectedClientAddress.Unmap()
if !expectedClientAddress.IsUnspecified() {
return netip.AddrPortFrom(expectedClientAddress, request.port), nil
}
return netip.AddrPortFrom(netip.Addr{}, request.port), nil
case domainName:
if request.destination != "" || request.port != 0 {
return netip.AddrPort{}, fmt.Errorf("domain name is not supported for UDP associate destination")
}
return netip.AddrPort{}, nil
default:
return netip.AddrPort{}, fmt.Errorf("address type %d is not supported", request.addressType)
}
}
func (c *socksConn) udpAssociationAddresses() (bindAddress string,
bindPort uint16, bindAddrType addrType, err error,
) {
localAddress := c.udpRouter.localAddress().String()
host, portString, err := net.SplitHostPort(localAddress)
if err != nil {
return "", 0, 0, fmt.Errorf("splitting local address: %w", err)
}
port, err := strconv.ParseUint(portString, 10, 16)
if err != nil {
return "", 0, 0, fmt.Errorf("parsing local port: %w", err)
}
bindAddress = host
bindPort = uint16(port)
if isUnspecifiedIPAddress(bindAddress) {
controlLocalAddress := c.clientConn.LocalAddr().String()
controlLocalHost, _, splitErr := net.SplitHostPort(controlLocalAddress)
if splitErr != nil {
return "", 0, 0, fmt.Errorf("splitting control connection local address: %w", splitErr)
}
bindAddress = controlLocalHost
}
ipAddress := net.ParseIP(bindAddress)
if ipAddress == nil {
bindAddrType = domainName
return bindAddress, bindPort, bindAddrType, nil
}
if ipAddress.To4() != nil {
bindAddrType = ipv4
} else {
bindAddrType = ipv6
}
return bindAddress, bindPort, bindAddrType, nil
}
func isUnspecifiedIPAddress(address string) bool {
ipAddress, err := netip.ParseAddr(address)
if err != nil {
return false
}
return ipAddress.IsUnspecified()
}
func decodeUDPDatagram(packet []byte) (destination string, payload []byte, err error) {
const minimumPacketLength = 4
if len(packet) < minimumPacketLength {
return "", nil, fmt.Errorf("packet is too short: %d", len(packet))
}
if packet[0] != 0 || packet[1] != 0 {
return "", nil, fmt.Errorf("reserved bytes are invalid: %x %x", packet[0], packet[1])
}
if packet[2] != 0 {
return "", nil, fmt.Errorf("fragmentation is not supported")
}
offset := 3
addressType := addrType(packet[offset])
offset++
switch addressType {
case ipv4:
const ipv4Length = 4
if len(packet) < offset+ipv4Length+2 {
return "", nil, fmt.Errorf("packet is too short for IPv4 address")
}
var ip [ipv4Length]byte
copy(ip[:], packet[offset:offset+ipv4Length])
destination = netip.AddrFrom4(ip).String()
offset += ipv4Length
case ipv6:
const ipv6Length = 16
if len(packet) < offset+ipv6Length+2 {
return "", nil, fmt.Errorf("packet is too short for IPv6 address")
}
var ip [ipv6Length]byte
copy(ip[:], packet[offset:offset+ipv6Length])
destination = netip.AddrFrom16(ip).String()
offset += ipv6Length
case domainName:
if len(packet) < offset+1 {
return "", nil, fmt.Errorf("packet is too short for domain name length")
}
domainNameLength := int(packet[offset])
offset++
if len(packet) < offset+domainNameLength+2 {
return "", nil, fmt.Errorf("packet is too short for domain name")
}
destination = string(packet[offset : offset+domainNameLength])
offset += domainNameLength
default:
return "", nil, fmt.Errorf("address type is not supported: %d", addressType)
}
port := binary.BigEndian.Uint16(packet[offset : offset+2])
destination = net.JoinHostPort(destination, fmt.Sprint(port))
offset += 2
payload = packet[offset:]
return destination, payload, nil
}
func encodeUDPDatagramToBuffer(writer io.Writer, sourceAddrPort netip.AddrPort,
payload []byte,
) error {
address := sourceAddrPort.Addr()
if !address.IsValid() {
return errors.New("source address is not valid")
}
err := writeUDPDatagramSourceAddress(writer, address)
if err != nil {
return fmt.Errorf("writing source address: %w", err)
}
var portBytes [2]byte
binary.BigEndian.PutUint16(portBytes[:], sourceAddrPort.Port())
_, err = writer.Write(portBytes[:])
if err != nil {
return fmt.Errorf("writing destination port: %w", err)
}
_, err = writer.Write(payload)
if err != nil {
return fmt.Errorf("writing payload: %w", err)
}
return nil
}
func writeUDPDatagramSourceAddress(writer io.Writer, address netip.Addr) error {
var addrType addrType
var addressBytes []byte
switch {
case address.Is4():
addrType = ipv4
array := address.As4()
addressBytes = array[:]
case address.Is6():
addrType = ipv6
array := address.As16()
addressBytes = array[:]
default:
return fmt.Errorf("address type is not supported: %v", address)
}
_, err := writer.Write([]byte{0, 0, 0, byte(addrType)})
if err != nil {
return fmt.Errorf("writing header: %w", err)
}
_, err = writer.Write(addressBytes)
if err != nil {
return fmt.Errorf("writing IP address: %w", err)
}
return nil
}
// See https://datatracker.ietf.org/doc/html/rfc1928#section-3
func verifyFirstNegotiation(reader io.Reader, requiredMethod authMethod) error {
const headerLength = 2 // version + nMethods bytes
+437 -38
View File
@@ -2,7 +2,10 @@ package socks5
import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"strconv"
@@ -96,6 +99,178 @@ func TestServerProxy(t *testing.T) {
}
}
func TestServerProxyTCPAndUDPParallel(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
username string
password string
}{
"no_auth": {},
"with_auth": {
username: "user",
password: "pass",
},
}
for name, testCase := range testCases {
t.Run(name, func(t *testing.T) {
t.Parallel()
backendTCPListener, err := (&net.ListenConfig{}).Listen(t.Context(), "tcp", "127.0.0.1:0")
require.NoError(t, err)
backendTCPConnChannel := make(chan net.Conn, 1)
go func() {
connection, err := backendTCPListener.Accept()
if err != nil {
return
}
backendTCPConnChannel <- connection
}()
backendUDPPacketConn, err := (&net.ListenConfig{}).ListenPacket(t.Context(), "udp", "127.0.0.1:0")
require.NoError(t, err)
server := newServer(Settings{
Username: testCase.username,
Password: testCase.password,
Address: "127.0.0.1:0",
Logger: noopLogger{},
})
_, err = server.Start(t.Context())
require.NoError(t, err)
t.Cleanup(func() {
_ = server.Stop()
_ = backendTCPListener.Close()
_ = backendUDPPacketConn.Close()
})
clientTCPConn := dialSOCKS5(t, server.listeningAddress().String(),
backendTCPListener.Addr().String(), testCase.username, testCase.password)
defer clientTCPConn.Close()
backendTCPConn := <-backendTCPConnChannel
defer backendTCPConn.Close()
udpControlConn, clientUDPConn := dialSOCKS5UDPAssociate(t,
server.listeningAddress().String(), testCase.username, testCase.password)
defer udpControlConn.Close()
defer clientUDPConn.Close()
tcpErrCh := make(chan error, 1)
go func() {
tcpErrCh <- runTCPProxyRoundTrip(clientTCPConn, backendTCPConn)
}()
udpErrCh := make(chan error, 1)
go func() {
udpErrCh <- runUDPProxyRoundTrip(t.Context(), clientUDPConn, backendUDPPacketConn)
}()
err = <-tcpErrCh
require.NoError(t, err)
err = <-udpErrCh
require.NoError(t, err)
})
}
}
func runTCPProxyRoundTrip(clientTCPConn net.Conn, backendTCPConn net.Conn) error {
clientMessage := []byte("hello from client")
_, err := clientTCPConn.Write(clientMessage)
if err != nil {
return err
}
received := make([]byte, len(clientMessage))
_, err = io.ReadFull(backendTCPConn, received)
if err != nil {
return err
}
if !bytes.Equal(clientMessage, received) {
return errors.New("backend did not receive expected TCP payload")
}
backendMessage := []byte("hello from backend")
_, err = backendTCPConn.Write(backendMessage)
if err != nil {
return err
}
receivedByClient := make([]byte, len(backendMessage))
_, err = io.ReadFull(clientTCPConn, receivedByClient)
if err != nil {
return err
}
if !bytes.Equal(backendMessage, receivedByClient) {
return errors.New("client did not receive expected TCP payload")
}
return nil
}
func runUDPProxyRoundTrip(ctx context.Context, clientUDPConn *net.UDPConn, backendUDPPacketConn net.PacketConn) error {
udpPayload := []byte("hello from udp client")
udpRequest, err := makeSOCKS5UDPDatagram(backendUDPPacketConn.LocalAddr().String(), udpPayload)
if err != nil {
return err
}
_, err = clientUDPConn.Write(udpRequest)
if err != nil {
return err
}
deadline, hasDeadline := ctx.Deadline()
if hasDeadline {
err = backendUDPPacketConn.SetReadDeadline(deadline)
if err != nil {
return fmt.Errorf("setting read deadline on backend connection: %w", err)
}
}
const bufferSize = 512
backendReadBuffer := make([]byte, bufferSize)
packetLength, proxyAddress, err := backendUDPPacketConn.ReadFrom(backendReadBuffer)
if err != nil {
return err
}
if !bytes.Equal(udpPayload, backendReadBuffer[:packetLength]) {
return errors.New("backend did not receive expected UDP payload")
}
backendUDPReply := []byte("hello from udp backend")
_, err = backendUDPPacketConn.WriteTo(backendUDPReply, proxyAddress)
if err != nil {
return err
}
if hasDeadline {
err = clientUDPConn.SetReadDeadline(deadline)
if err != nil {
return fmt.Errorf("setting read deadline on client connection: %w", err)
}
}
udpResponseBuffer := make([]byte, 1024)
responseLength, err := clientUDPConn.Read(udpResponseBuffer)
if err != nil {
return err
}
destinationAddress, udpResponsePayload, err := parseSOCKS5UDPDatagram(udpResponseBuffer[:responseLength])
if err != nil {
return err
}
if !bytes.Equal(backendUDPReply, udpResponsePayload) {
return errors.New("client did not receive expected UDP payload")
}
if destinationAddress != backendUDPPacketConn.LocalAddr().String() {
return errors.New("udp response destination address mismatch")
}
return nil
}
// dialSOCKS5 performs the full SOCKS5 handshake (with optional username/password
// subnegotiation) and returns a connected net.Conn ready for data exchange.
func dialSOCKS5(t *testing.T, proxyAddr, targetAddr, username, password string) net.Conn {
@@ -109,6 +284,55 @@ func dialSOCKS5(t *testing.T, proxyAddr, targetAddr, username, password string)
conn, err := (&net.Dialer{}).DialContext(t.Context(), "tcp", proxyAddr)
require.NoError(t, err)
negotiateSOCKS5(t, conn, username, password)
var connectRequest []byte
if ip := net.ParseIP(host).To4(); ip != nil {
connectRequest = []byte{socks5Version, byte(connect), 0, byte(ipv4)}
connectRequest = append(connectRequest, ip...)
} else {
connectRequest = []byte{socks5Version, byte(connect), 0, byte(domainName), byte(len(host))}
connectRequest = append(connectRequest, []byte(host)...)
}
connectRequest = binary.BigEndian.AppendUint16(connectRequest, uint16(targetPort)) //nolint:gosec
_, err = conn.Write(connectRequest)
require.NoError(t, err)
_, err = readSOCKS5ResponseAddress(t, conn)
require.NoError(t, err)
return conn
}
func dialSOCKS5UDPAssociate(t *testing.T, proxyAddr, username, password string) (net.Conn, *net.UDPConn) {
t.Helper()
controlConn, err := (&net.Dialer{}).DialContext(t.Context(), "tcp", proxyAddr)
require.NoError(t, err)
negotiateSOCKS5(t, controlConn, username, password)
udpAssociateRequest := []byte{socks5Version, byte(udpAssociate), 0, byte(ipv4), 0, 0, 0, 0, 0, 0}
_, err = controlConn.Write(udpAssociateRequest)
require.NoError(t, err)
udpProxyAddress, err := readSOCKS5ResponseAddress(t, controlConn)
require.NoError(t, err)
udpProxyResolvedAddress, err := net.ResolveUDPAddr("udp", udpProxyAddress)
require.NoError(t, err)
udpConn, err := net.DialUDP("udp", nil, udpProxyResolvedAddress)
require.NoError(t, err)
return controlConn, udpConn
}
func negotiateSOCKS5(t *testing.T, conn net.Conn, username, password string) {
t.Helper()
var err error
var method authMethod
if username != "" || password != "" {
method = authUsernamePassword
@@ -138,45 +362,146 @@ func dialSOCKS5(t *testing.T, proxyAddr, targetAddr, username, password string)
require.Equal(t, authUsernamePasswordSubNegotiation1, subnegResp[0])
require.Equal(t, byte(0), subnegResp[1])
}
}
var connectRequest []byte
if ip := net.ParseIP(host).To4(); ip != nil {
connectRequest = []byte{socks5Version, byte(connect), 0, byte(ipv4)}
connectRequest = append(connectRequest, ip...)
} else {
connectRequest = []byte{socks5Version, byte(connect), 0, byte(domainName), byte(len(host))}
connectRequest = append(connectRequest, []byte(host)...)
}
connectRequest = binary.BigEndian.AppendUint16(connectRequest, uint16(targetPort)) //nolint:gosec
_, err = conn.Write(connectRequest)
require.NoError(t, err)
func readSOCKS5ResponseAddress(t *testing.T, conn net.Conn) (address string, err error) {
t.Helper()
var responseHeader [4]byte
_, err = io.ReadFull(conn, responseHeader[:])
require.NoError(t, err)
require.Equal(t, socks5Version, responseHeader[0])
require.Equal(t, byte(succeeded), responseHeader[1])
// Consume BND.ADDR and BND.PORT (their values are irrelevant to the caller).
switch addrType(responseHeader[3]) {
case ipv4:
var addrPort [net.IPv4len + 2]byte
_, err = io.ReadFull(conn, addrPort[:])
require.NoError(t, err)
case ipv6:
var addrPort [net.IPv6len + 2]byte
_, err = io.ReadFull(conn, addrPort[:])
require.NoError(t, err)
case domainName:
var lenBuf [1]byte
_, err = io.ReadFull(conn, lenBuf[:])
require.NoError(t, err)
addrPort := make([]byte, int(lenBuf[0])+2)
_, err = io.ReadFull(conn, addrPort)
require.NoError(t, err)
if err != nil {
return "", err
}
if responseHeader[0] != socks5Version {
return "", errors.New("version mismatch")
}
if responseHeader[1] != byte(succeeded) {
return "", errors.New("request was not successful")
}
return conn
var host string
switch addrType(responseHeader[3]) {
case ipv4:
addressAndPort := make([]byte, net.IPv4len+2)
_, err = io.ReadFull(conn, addressAndPort)
if err != nil {
return "", err
}
host = net.IP(addressAndPort[:net.IPv4len]).String()
port := binary.BigEndian.Uint16(addressAndPort[net.IPv4len:])
return net.JoinHostPort(host, strconv.FormatUint(uint64(port), 10)), nil
case ipv6:
addressAndPort := make([]byte, net.IPv6len+2)
_, err = io.ReadFull(conn, addressAndPort)
if err != nil {
return "", err
}
host = net.IP(addressAndPort[:net.IPv6len]).String()
port := binary.BigEndian.Uint16(addressAndPort[net.IPv6len:])
return net.JoinHostPort(host, strconv.FormatUint(uint64(port), 10)), nil
case domainName:
var lengthBuffer [1]byte
_, err = io.ReadFull(conn, lengthBuffer[:])
if err != nil {
return "", err
}
domainAndPort := make([]byte, int(lengthBuffer[0])+2)
_, err = io.ReadFull(conn, domainAndPort)
if err != nil {
return "", err
}
host = string(domainAndPort[:len(domainAndPort)-2])
port := binary.BigEndian.Uint16(domainAndPort[len(domainAndPort)-2:])
return net.JoinHostPort(host, strconv.FormatUint(uint64(port), 10)), nil
default:
return "", errors.New("unknown address type")
}
}
func makeSOCKS5UDPDatagram(targetAddress string, payload []byte) ([]byte, error) {
host, portString, err := net.SplitHostPort(targetAddress)
if err != nil {
return nil, err
}
port, err := strconv.ParseUint(portString, 10, 16)
if err != nil {
return nil, err
}
datagram := []byte{0, 0, 0}
ipAddress := net.ParseIP(host)
if ipAddress != nil {
if ipAddress.To4() != nil {
datagram = append(datagram, byte(ipv4))
datagram = append(datagram, ipAddress.To4()...)
} else {
datagram = append(datagram, byte(ipv6))
datagram = append(datagram, ipAddress.To16()...)
}
} else {
if len(host) > 255 {
return nil, errors.New("domain name too long")
}
datagram = append(datagram, byte(domainName), byte(len(host)))
datagram = append(datagram, []byte(host)...)
}
datagram = binary.BigEndian.AppendUint16(datagram, uint16(port))
datagram = append(datagram, payload...)
return datagram, nil
}
func parseSOCKS5UDPDatagram(datagram []byte) (destinationAddress string, payload []byte, err error) {
if len(datagram) < 4 {
return "", nil, errors.New("datagram too short")
}
if datagram[0] != 0 || datagram[1] != 0 {
return "", nil, errors.New("invalid reserved header")
}
if datagram[2] != 0 {
return "", nil, errors.New("fragments are not supported")
}
offset := 3
var host string
switch addrType(datagram[offset]) {
case ipv4:
offset++
if len(datagram) < offset+net.IPv4len+2 {
return "", nil, errors.New("datagram too short for IPv4")
}
host = net.IP(datagram[offset : offset+net.IPv4len]).String()
offset += net.IPv4len
case ipv6:
offset++
if len(datagram) < offset+net.IPv6len+2 {
return "", nil, errors.New("datagram too short for IPv6")
}
host = net.IP(datagram[offset : offset+net.IPv6len]).String()
offset += net.IPv6len
case domainName:
offset++
if len(datagram) < offset+1 {
return "", nil, errors.New("datagram too short for domain length")
}
domainLength := int(datagram[offset])
offset++
if len(datagram) < offset+domainLength+2 {
return "", nil, errors.New("datagram too short for domain")
}
host = string(datagram[offset : offset+domainLength])
offset += domainLength
default:
return "", nil, errors.New("unknown address type")
}
if len(datagram) < offset+2 {
return "", nil, errors.New("datagram too short for port")
}
port := binary.BigEndian.Uint16(datagram[offset : offset+2])
offset += 2
return net.JoinHostPort(host, strconv.FormatUint(uint64(port), 10)), datagram[offset:], nil
}
func Test_newServer(t *testing.T) {
@@ -224,7 +549,8 @@ func Test_Server_StartStop(t *testing.T) {
ctrl := gomock.NewController(t)
logger := NewMockLogger(ctrl)
logger.EXPECT().Infof("SOCKS5 server listening on %s", gomock.Any())
logger.EXPECT().Infof("SOCKS5 TCP server listening on %s", gomock.Any())
logger.EXPECT().Infof("SOCKS5 UDP server listening on %s", gomock.Any())
server := newServer(Settings{
Address: "127.0.0.1:0",
@@ -598,10 +924,6 @@ func Test_cmdType_String(t *testing.T) {
cmd: connect,
expectedName: "connect",
},
"bind": {
cmd: bind,
expectedName: "bind",
},
"udp_associate": {
cmd: udpAssociate,
expectedName: "UDP associate",
@@ -620,3 +942,80 @@ func Test_cmdType_String(t *testing.T) {
})
}
}
func Test_socksConn_udpAssociationAddresses(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
routerAddress string
expectAddressFromConn bool
expectedAddress string
}{
"wildcard_router_address_uses_control_connection_local_ip": {
routerAddress: ":0",
expectAddressFromConn: true,
},
"concrete_router_address_is_kept": {
routerAddress: "127.0.0.1:0",
expectedAddress: "127.0.0.1",
},
}
for name, testCase := range testCases {
t.Run(name, func(t *testing.T) {
t.Parallel()
router, err := newUDPRouter(t.Context(), testCase.routerAddress, noopLogger{})
require.NoError(t, err)
t.Cleanup(func() {
err := router.close()
assert.NoError(t, err)
})
controlListener, err := (&net.ListenConfig{}).Listen(t.Context(), "tcp", "127.0.0.1:0")
require.NoError(t, err)
t.Cleanup(func() {
err := controlListener.Close()
assert.NoError(t, err)
})
acceptedConnCh := make(chan net.Conn, 1)
go func() {
acceptedConn, acceptErr := controlListener.Accept()
if acceptErr != nil {
return
}
acceptedConnCh <- acceptedConn
}()
clientControlConn, err := (&net.Dialer{}).DialContext(t.Context(), "tcp", controlListener.Addr().String())
require.NoError(t, err)
defer clientControlConn.Close()
serverControlConn := <-acceptedConnCh
defer serverControlConn.Close()
socksConnection := &socksConn{
clientConn: clientControlConn,
udpRouter: router,
}
bindAddress, bindPort, bindAddrType, err := socksConnection.udpAssociationAddresses()
require.NoError(t, err)
if testCase.expectAddressFromConn {
clientLocalHost, _, err := net.SplitHostPort(clientControlConn.LocalAddr().String())
require.NoError(t, err)
assert.Equal(t, clientLocalHost, bindAddress)
} else {
assert.Equal(t, testCase.expectedAddress, bindAddress)
}
_, routerPortString, err := net.SplitHostPort(router.localAddress().String())
require.NoError(t, err)
routerPort, err := strconv.ParseUint(routerPortString, 10, 16)
require.NoError(t, err)
assert.Equal(t, uint16(routerPort), bindPort)
assert.Equal(t, ipv4, bindAddrType)
})
}
}
+370
View File
@@ -0,0 +1,370 @@
package socks5
import (
"bytes"
"context"
"errors"
"fmt"
"net"
"net/netip"
"sync"
)
type udpAssociation struct {
id uint64
clientAddrPort netip.AddrPort
expectedAddrPort netip.AddrPort
controlConnAddr netip.Addr
packetCh chan *bytes.Buffer
}
type udpRouter struct {
logger Logger
listener net.PacketConn
mutex sync.Mutex
bufferPool sync.Pool
nextAssociationID uint64
clientAddrPortToAssociation map[netip.AddrPort]udpAssociation
clientIPToPendingAssociations map[netip.Addr][]udpAssociation
associationIDToClientAddrPort map[uint64]netip.AddrPort
}
const (
maxUDPPacketLength = 65535
maxSOCKS5UDPDatagramOverhead = 3 + 1 + 16 + 2
pooledUDPPacketBufferCapacity = maxUDPPacketLength + maxSOCKS5UDPDatagramOverhead
)
func newUDPRouter(ctx context.Context, address string, logger Logger) (router *udpRouter, err error) {
config := &net.ListenConfig{}
listener, err := config.ListenPacket(ctx, "udp", address)
if err != nil {
return nil, fmt.Errorf("UDP listening: %w", err)
}
return &udpRouter{
logger: logger,
listener: listener,
bufferPool: sync.Pool{
New: func() any {
return bytes.NewBuffer(make([]byte, 0, pooledUDPPacketBufferCapacity))
},
},
nextAssociationID: 1,
clientAddrPortToAssociation: make(map[netip.AddrPort]udpAssociation),
clientIPToPendingAssociations: make(map[netip.Addr][]udpAssociation),
associationIDToClientAddrPort: make(map[uint64]netip.AddrPort),
}, nil
}
func (r *udpRouter) localAddress() net.Addr {
return r.listener.LocalAddr()
}
func (r *udpRouter) close() error {
return r.listener.Close()
}
func (r *udpRouter) registerAssociation(controlConn net.Conn, expectedAddrPort netip.AddrPort) (udpAssociation, error) {
controlConnAddrPort, err := netip.ParseAddrPort(controlConn.RemoteAddr().String())
if err != nil {
return udpAssociation{}, fmt.Errorf("parsing control connection address: %w", err)
}
controlConnAddr := controlConnAddrPort.Addr().Unmap()
r.mutex.Lock()
defer r.mutex.Unlock()
const udpPacketChannelBuffer = 2
associationID := r.nextAssociationID
r.nextAssociationID++
association := udpAssociation{
id: associationID,
expectedAddrPort: expectedAddrPort,
controlConnAddr: controlConnAddr,
packetCh: make(chan *bytes.Buffer, udpPacketChannelBuffer),
}
if expectedAddrPort.Addr().IsValid() && expectedAddrPort.Port() != 0 {
association.clientAddrPort = expectedAddrPort
r.clientAddrPortToAssociation[association.clientAddrPort] = association
r.associationIDToClientAddrPort[association.id] = association.clientAddrPort
return association, nil
}
pendingAssociations := r.clientIPToPendingAssociations[controlConnAddr]
pendingAssociations = append(pendingAssociations, association)
r.clientIPToPendingAssociations[controlConnAddr] = pendingAssociations
return association, nil
}
func (r *udpRouter) unregisterAssociation(association udpAssociation) {
r.mutex.Lock()
defer r.mutex.Unlock()
clientAddrPort, hasClientAddress := r.associationIDToClientAddrPort[association.id]
if hasClientAddress {
delete(r.associationIDToClientAddrPort, association.id)
delete(r.clientAddrPortToAssociation, clientAddrPort)
}
pendingAssociations := r.clientIPToPendingAssociations[association.controlConnAddr]
for i, pendingAssociation := range pendingAssociations {
if pendingAssociation.id == association.id {
pendingAssociations = append(pendingAssociations[:i], pendingAssociations[i+1:]...)
break
}
}
if len(pendingAssociations) == 0 {
delete(r.clientIPToPendingAssociations, association.controlConnAddr)
} else {
r.clientIPToPendingAssociations[association.controlConnAddr] = pendingAssociations
}
}
func (r *udpRouter) run(ctx context.Context) error {
packetBuffer := make([]byte, maxUDPPacketLength)
for {
packetLength, sourceAddress, err := r.listener.ReadFrom(packetBuffer)
if err != nil {
if ctx.Err() != nil && errors.Is(err, net.ErrClosed) {
return nil
}
return fmt.Errorf("reading UDP packet: %w", err)
}
sourceAddrPort, err := netAddrToNetipAddrPort(sourceAddress)
if err != nil {
r.logger.Warnf("parsing source address: %s", err)
continue
}
buffer := r.bufferPool.Get().(*bytes.Buffer) //nolint:forcetypeassert
buffer.Reset()
_, err = buffer.Write(packetBuffer[:packetLength])
if err != nil {
r.bufferPool.Put(buffer)
r.logger.Warnf("buffering packet: %s", err)
continue
}
err = r.routePacket(sourceAddrPort, buffer)
if err != nil {
r.logger.Warnf("failed routing UDP packet: %s", err)
}
}
}
func (r *udpRouter) routePacket(sourceAddrPort netip.AddrPort, packet *bytes.Buffer) error {
r.mutex.Lock()
association, packetFromClient := r.findClientAssociation(sourceAddrPort)
r.mutex.Unlock()
if !packetFromClient {
r.bufferPool.Put(packet)
return nil
}
select {
case association.packetCh <- packet:
return nil
default:
r.bufferPool.Put(packet)
return errors.New("association packet queue full")
}
}
func (r *udpRouter) findClientAssociation(sourceAddrPort netip.AddrPort) (
association udpAssociation, ok bool,
) {
association, ok = r.clientAddrPortToAssociation[sourceAddrPort]
if ok {
return association, true
}
sourceAddr := sourceAddrPort.Addr()
pendingAssociations := r.clientIPToPendingAssociations[sourceAddr]
if len(pendingAssociations) == 0 {
return udpAssociation{}, false
}
index := -1
for i, pendingAssociation := range pendingAssociations {
if matchesExpectedClientEndpoint(pendingAssociation, sourceAddrPort) {
association = pendingAssociation
index = i
break
}
}
if index == -1 {
return udpAssociation{}, false
}
r.clientIPToPendingAssociations[sourceAddr] = append(pendingAssociations[:index], pendingAssociations[index+1:]...)
if len(r.clientIPToPendingAssociations[sourceAddr]) == 0 {
delete(r.clientIPToPendingAssociations, sourceAddr)
}
association.clientAddrPort = sourceAddrPort
r.clientAddrPortToAssociation[sourceAddrPort] = association
r.associationIDToClientAddrPort[association.id] = sourceAddrPort
return association, true
}
func matchesExpectedClientEndpoint(association udpAssociation, sourceAddrPort netip.AddrPort) bool {
switch {
case association.expectedAddrPort.Addr().IsValid() && sourceAddrPort.Addr() != association.expectedAddrPort.Addr():
return false
case association.expectedAddrPort.Port() != 0 && sourceAddrPort.Port() != association.expectedAddrPort.Port():
return false
}
return true
}
func (r *udpRouter) clientAddrPortForAssociation(associationID uint64) (
clientAddrPort netip.AddrPort, ok bool,
) {
r.mutex.Lock()
defer r.mutex.Unlock()
clientAddrPort, ok = r.associationIDToClientAddrPort[associationID]
return clientAddrPort, ok
}
func (r *udpRouter) runAssociationHandler(ctx context.Context, association udpAssociation) {
config := &net.ListenConfig{}
socket, err := config.ListenPacket(ctx, "udp", ":0")
if err != nil {
r.logger.Warnf("creating per-association UDP socket: %s", err)
return
}
defer socket.Close()
go closeSocketOnContextDone(ctx, socket)
packetBuffer := make([]byte, maxUDPPacketLength)
forwardDoneCh := make(chan struct{})
go r.forwardClientPackets(ctx, socket, association.packetCh, forwardDoneCh)
for {
packetLength, sourceAddress, err := socket.ReadFrom(packetBuffer)
if err != nil {
if ctx.Err() != nil || errors.Is(err, net.ErrClosed) {
<-forwardDoneCh
return
}
r.logger.Warnf("reading from per-association UDP socket: %s", err)
continue
}
sourceAddrPort, err := netAddrToNetipAddrPort(sourceAddress)
if err != nil {
r.logger.Warnf("parsing source address from destination: %s", err)
continue
}
buffer := r.bufferPool.Get().(*bytes.Buffer) //nolint:forcetypeassert
buffer.Reset()
err = encodeUDPDatagramToBuffer(buffer, sourceAddrPort, packetBuffer[:packetLength])
if err != nil {
r.bufferPool.Put(buffer)
r.logger.Warnf("encoding response datagram: %s", err)
continue
}
clientAddrPort, found := r.clientAddrPortForAssociation(association.id)
if !found {
r.bufferPool.Put(buffer)
r.logger.Warnf("client address not found for association id %d", association.id)
continue
}
clientUDPAddress := &net.UDPAddr{
IP: clientAddrPort.Addr().AsSlice(),
Port: int(clientAddrPort.Port()),
}
_, err = r.listener.WriteTo(buffer.Bytes(), clientUDPAddress)
r.bufferPool.Put(buffer)
if err != nil {
r.logger.Warnf("writing response to client: %s", err)
}
}
}
func closeSocketOnContextDone(ctx context.Context, socket net.PacketConn) {
<-ctx.Done()
_ = socket.Close()
}
func (r *udpRouter) forwardClientPackets(ctx context.Context, socket net.PacketConn,
packetCh <-chan *bytes.Buffer, done chan<- struct{},
) {
defer close(done)
for {
select {
case <-ctx.Done():
return
case buffer, ok := <-packetCh:
if !ok {
return
}
err := r.writeClientPacketToDestination(ctx, socket, buffer)
r.bufferPool.Put(buffer)
if err != nil {
r.logger.Warnf("forwarding client packet to destination: %s", err)
}
}
}
}
func (r *udpRouter) writeClientPacketToDestination(ctx context.Context,
socket net.PacketConn, packet *bytes.Buffer,
) error {
destination, payload, err := decodeUDPDatagram(packet.Bytes())
if err != nil {
return fmt.Errorf("decoding UDP datagram: %w", err)
}
host, portStr, err := net.SplitHostPort(destination)
if err != nil {
return fmt.Errorf("splitting destination host and port: %w", err)
}
if _, err := netip.ParseAddr(host); err != nil { // domain name
addrs, err := net.DefaultResolver.LookupHost(ctx, host)
if err != nil {
return fmt.Errorf("resolving destination host: %w", err)
}
if len(addrs) == 0 {
return fmt.Errorf("resolving destination host: no addresses found for %q", host)
}
destination = net.JoinHostPort(addrs[0], portStr)
}
resolvedDestinationUDPAddress, err := net.ResolveUDPAddr("udp", destination)
if err != nil {
return fmt.Errorf("resolving destination UDP address: %w", err)
}
_, err = socket.WriteTo(payload, resolvedDestinationUDPAddress)
if err != nil && ctx.Err() == nil {
return fmt.Errorf("writing payload to destination: %w", err)
}
return nil
}
func netAddrToNetipAddrPort(addr net.Addr) (netip.AddrPort, error) {
addrPort, err := netip.ParseAddrPort(addr.String())
if err != nil {
return netip.AddrPort{}, fmt.Errorf("parsing address: %w", err)
}
return netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port()), nil
}
@@ -0,0 +1,162 @@
//go:build integration
package socks5
import (
"bytes"
"context"
"math/rand/v2"
"net"
"net/netip"
"strconv"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_udpRouter_ResolveGithubFromCloudflareDNS(t *testing.T) {
ctx := t.Context()
var cancel context.CancelFunc
deadline, hasDeadline := ctx.Deadline()
if hasDeadline {
const deadlineBuffer = 500 * time.Millisecond
deadline = deadline.Add(-deadlineBuffer)
} else {
const defaultTimeout = 10 * time.Second
deadline = time.Now().Add(defaultTimeout)
}
ctx, cancel = context.WithDeadline(ctx, deadline)
ctrl := gomock.NewController(t)
logger := NewMockLogger(ctrl)
router, err := newUDPRouter(ctx, "127.0.0.1:0", logger)
require.NoError(t, err)
routerRunErrCh := make(chan error)
go func() {
routerRunErrCh <- router.run(ctx)
}()
t.Cleanup(func() {
cancel()
err := router.close()
assert.NoError(t, err, "closing router")
runErr := <-routerRunErrCh
assert.NoError(t, runErr)
})
controlListener, err := (&net.ListenConfig{}).Listen(ctx, "tcp", "127.0.0.1:0")
require.NoError(t, err)
t.Cleanup(func() {
err := controlListener.Close()
assert.NoError(t, err, "closing control listener")
})
acceptedConnCh := make(chan net.Conn)
go func() {
acceptedConn, acceptErr := controlListener.Accept()
assert.NoError(t, acceptErr, "accepting control connection")
if acceptErr != nil {
return
}
acceptedConnCh <- acceptedConn
}()
clientControlConn, err := (&net.Dialer{}).DialContext(ctx, "tcp", controlListener.Addr().String())
require.NoError(t, err)
t.Cleanup(func() {
err = clientControlConn.Close()
assert.NoError(t, err, "closing client control connection")
})
serverControlConn := <-acceptedConnCh
t.Cleanup(func() {
err := serverControlConn.Close()
assert.NoError(t, err, "closing server control connection")
})
association, err := router.registerAssociation(serverControlConn, netip.AddrPort{})
require.NoError(t, err)
t.Cleanup(func() {
router.unregisterAssociation(association)
})
associationCtx, associationCancel := context.WithCancel(ctx)
handlerDoneCh := make(chan struct{})
go func() {
router.runAssociationHandler(associationCtx, association)
close(handlerDoneCh)
}()
t.Cleanup(func() {
associationCancel()
<-handlerDoneCh
})
udpRouterAddress, err := net.ResolveUDPAddr("udp", router.localAddress().String())
require.NoError(t, err)
clientUDPConn, err := net.DialUDP("udp", nil, udpRouterAddress)
require.NoError(t, err)
t.Cleanup(func() {
err := clientUDPConn.Close()
assert.NoError(t, err, "closing client UDP connection")
})
queryID := uint16(rand.Uint())
dnsRequest := &dns.Msg{
MsgHdr: dns.MsgHdr{
Id: queryID,
RecursionDesired: true,
},
Question: []dns.Question{{
Name: dns.Fqdn("github.com"),
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}},
}
dnsQuery, err := dnsRequest.Pack()
require.NoError(t, err)
targetAddrPort := netip.MustParseAddrPort("1.1.1.1:53")
socksDatagramBuffer := bytes.NewBuffer(nil)
err = encodeUDPDatagramToBuffer(socksDatagramBuffer, targetAddrPort, dnsQuery)
require.NoError(t, err)
socksDatagram := socksDatagramBuffer.Bytes()
err = clientUDPConn.SetDeadline(deadline)
require.NoError(t, err)
_, err = clientUDPConn.Write(socksDatagram)
require.NoError(t, err)
responseBuffer := make([]byte, maxUDPPacketLength)
responseLength, err := clientUDPConn.Read(responseBuffer)
require.NoError(t, err)
responseDestination, responsePayload, err := decodeUDPDatagram(responseBuffer[:responseLength])
require.NoError(t, err)
responseHost, responsePortString, err := net.SplitHostPort(responseDestination)
require.NoError(t, err)
responsePort, err := strconv.ParseUint(responsePortString, 10, 16)
require.NoError(t, err)
assert.Equal(t, uint64(53), responsePort)
assert.NotEmpty(t, responseHost)
dnsResponse := new(dns.Msg)
err = dnsResponse.Unpack(responsePayload)
require.NoError(t, err)
assert.Equal(t, queryID, dnsResponse.Id)
assert.True(t, dnsResponse.Response)
assert.Equal(t, dns.RcodeSuccess, dnsResponse.Rcode)
require.NotEmpty(t, dnsResponse.Question)
assert.Equal(t, dns.Fqdn("github.com"), dnsResponse.Question[0].Name)
assert.Equal(t, uint16(dns.TypeA), dnsResponse.Question[0].Qtype)
assert.NotEmpty(t, dnsResponse.Answer)
require.NoError(t, err)
}