feat(socks5): UDP proxying (#3353)

This commit is contained in:
Quentin McGaw
2026-06-11 09:32:38 -04:00
committed by GitHub
parent acab89b91a
commit 6d84462f00
10 changed files with 1311 additions and 95 deletions
+249 -1
View File
@@ -10,6 +10,7 @@ import (
"net/netip"
"strconv"
"strings"
"sync"
)
var (
@@ -23,6 +24,7 @@ type socksConn struct {
username string
password string
clientConn net.Conn
udpRouter *udpRouter
logger Logger
}
@@ -109,11 +111,29 @@ func (c *socksConn) handleRequest(ctx context.Context) error {
c.encodeFailedResponse(c.clientConn, socksVersion, generalServerFailure)
return err
}
if request.command != connect {
switch request.command {
case connect:
err = c.handleConnectRequest(ctx, socksVersion, request)
if err != nil {
return fmt.Errorf("handling %s request: %w", request.command, err)
}
return nil
case udpAssociate:
err = c.handleUDPAssociateRequest(ctx, socksVersion, request)
if err != nil {
return fmt.Errorf("handling %s request: %w", request.command, err)
}
return nil
default:
c.encodeFailedResponse(c.clientConn, socksVersion, commandNotSupported)
return fmt.Errorf("command %s is not supported", request.command)
}
}
func (c *socksConn) handleConnectRequest(ctx context.Context,
socksVersion byte, request request,
) error {
destinationAddress := net.JoinHostPort(request.destination, fmt.Sprint(request.port))
destinationConn, err := c.dialer.DialContext(ctx, "tcp", destinationAddress)
if err != nil {
@@ -176,6 +196,234 @@ func (c *socksConn) handleRequest(ctx context.Context) error {
}
}
func (c *socksConn) handleUDPAssociateRequest(ctx context.Context,
socksVersion byte, request request,
) error {
expectedAddrPort, err := udpAssociateExpectedClientEndpoint(request)
if err != nil {
c.encodeFailedResponse(c.clientConn, socksVersion, addressTypeNotSupported)
return fmt.Errorf("deriving expected client address and port from request: %w", err)
}
bindAddress, bindPort, bindAddrType, err := c.udpAssociationAddresses()
if err != nil {
c.encodeFailedResponse(c.clientConn, socksVersion, generalServerFailure)
return fmt.Errorf("getting udp association addresses: %w", err)
}
association, err := c.udpRouter.registerAssociation(c.clientConn, expectedAddrPort)
if err != nil {
c.encodeFailedResponse(c.clientConn, socksVersion, generalServerFailure)
return fmt.Errorf("registering udp association: %w", err)
}
defer c.udpRouter.unregisterAssociation(association)
err = c.encodeSuccessResponse(c.clientConn, socksVersion, succeeded,
bindAddrType, bindAddress, bindPort)
if err != nil {
c.encodeFailedResponse(c.clientConn, socksVersion, generalServerFailure)
return fmt.Errorf("writing successful %s response: %w", udpAssociate, err)
}
associationCtx, associationCancel := context.WithCancel(ctx)
defer associationCancel()
var wg sync.WaitGroup
wg.Go(func() {
c.udpRouter.runAssociationHandler(associationCtx, association)
})
wg.Go(func() {
_, _ = io.Copy(io.Discard, c.clientConn)
associationCancel()
})
<-associationCtx.Done()
wg.Wait()
return nil
}
func udpAssociateExpectedClientEndpoint(request request) (expectedAddrPort netip.AddrPort, err error) {
switch request.addressType {
case ipv4, ipv6:
expectedClientAddress, parseErr := netip.ParseAddr(request.destination)
if parseErr != nil {
return netip.AddrPort{}, fmt.Errorf("parsing destination address: %w", parseErr)
}
expectedClientAddress = expectedClientAddress.Unmap()
if !expectedClientAddress.IsUnspecified() {
return netip.AddrPortFrom(expectedClientAddress, request.port), nil
}
return netip.AddrPortFrom(netip.Addr{}, request.port), nil
case domainName:
if request.destination != "" || request.port != 0 {
return netip.AddrPort{}, fmt.Errorf("domain name is not supported for UDP associate destination")
}
return netip.AddrPort{}, nil
default:
return netip.AddrPort{}, fmt.Errorf("address type %d is not supported", request.addressType)
}
}
func (c *socksConn) udpAssociationAddresses() (bindAddress string,
bindPort uint16, bindAddrType addrType, err error,
) {
localAddress := c.udpRouter.localAddress().String()
host, portString, err := net.SplitHostPort(localAddress)
if err != nil {
return "", 0, 0, fmt.Errorf("splitting local address: %w", err)
}
port, err := strconv.ParseUint(portString, 10, 16)
if err != nil {
return "", 0, 0, fmt.Errorf("parsing local port: %w", err)
}
bindAddress = host
bindPort = uint16(port)
if isUnspecifiedIPAddress(bindAddress) {
controlLocalAddress := c.clientConn.LocalAddr().String()
controlLocalHost, _, splitErr := net.SplitHostPort(controlLocalAddress)
if splitErr != nil {
return "", 0, 0, fmt.Errorf("splitting control connection local address: %w", splitErr)
}
bindAddress = controlLocalHost
}
ipAddress := net.ParseIP(bindAddress)
if ipAddress == nil {
bindAddrType = domainName
return bindAddress, bindPort, bindAddrType, nil
}
if ipAddress.To4() != nil {
bindAddrType = ipv4
} else {
bindAddrType = ipv6
}
return bindAddress, bindPort, bindAddrType, nil
}
func isUnspecifiedIPAddress(address string) bool {
ipAddress, err := netip.ParseAddr(address)
if err != nil {
return false
}
return ipAddress.IsUnspecified()
}
func decodeUDPDatagram(packet []byte) (destination string, payload []byte, err error) {
const minimumPacketLength = 4
if len(packet) < minimumPacketLength {
return "", nil, fmt.Errorf("packet is too short: %d", len(packet))
}
if packet[0] != 0 || packet[1] != 0 {
return "", nil, fmt.Errorf("reserved bytes are invalid: %x %x", packet[0], packet[1])
}
if packet[2] != 0 {
return "", nil, fmt.Errorf("fragmentation is not supported")
}
offset := 3
addressType := addrType(packet[offset])
offset++
switch addressType {
case ipv4:
const ipv4Length = 4
if len(packet) < offset+ipv4Length+2 {
return "", nil, fmt.Errorf("packet is too short for IPv4 address")
}
var ip [ipv4Length]byte
copy(ip[:], packet[offset:offset+ipv4Length])
destination = netip.AddrFrom4(ip).String()
offset += ipv4Length
case ipv6:
const ipv6Length = 16
if len(packet) < offset+ipv6Length+2 {
return "", nil, fmt.Errorf("packet is too short for IPv6 address")
}
var ip [ipv6Length]byte
copy(ip[:], packet[offset:offset+ipv6Length])
destination = netip.AddrFrom16(ip).String()
offset += ipv6Length
case domainName:
if len(packet) < offset+1 {
return "", nil, fmt.Errorf("packet is too short for domain name length")
}
domainNameLength := int(packet[offset])
offset++
if len(packet) < offset+domainNameLength+2 {
return "", nil, fmt.Errorf("packet is too short for domain name")
}
destination = string(packet[offset : offset+domainNameLength])
offset += domainNameLength
default:
return "", nil, fmt.Errorf("address type is not supported: %d", addressType)
}
port := binary.BigEndian.Uint16(packet[offset : offset+2])
destination = net.JoinHostPort(destination, fmt.Sprint(port))
offset += 2
payload = packet[offset:]
return destination, payload, nil
}
func encodeUDPDatagramToBuffer(writer io.Writer, sourceAddrPort netip.AddrPort,
payload []byte,
) error {
address := sourceAddrPort.Addr()
if !address.IsValid() {
return errors.New("source address is not valid")
}
err := writeUDPDatagramSourceAddress(writer, address)
if err != nil {
return fmt.Errorf("writing source address: %w", err)
}
var portBytes [2]byte
binary.BigEndian.PutUint16(portBytes[:], sourceAddrPort.Port())
_, err = writer.Write(portBytes[:])
if err != nil {
return fmt.Errorf("writing destination port: %w", err)
}
_, err = writer.Write(payload)
if err != nil {
return fmt.Errorf("writing payload: %w", err)
}
return nil
}
func writeUDPDatagramSourceAddress(writer io.Writer, address netip.Addr) error {
var addrType addrType
var addressBytes []byte
switch {
case address.Is4():
addrType = ipv4
array := address.As4()
addressBytes = array[:]
case address.Is6():
addrType = ipv6
array := address.As16()
addressBytes = array[:]
default:
return fmt.Errorf("address type is not supported: %v", address)
}
_, err := writer.Write([]byte{0, 0, 0, byte(addrType)})
if err != nil {
return fmt.Errorf("writing header: %w", err)
}
_, err = writer.Write(addressBytes)
if err != nil {
return fmt.Errorf("writing IP address: %w", err)
}
return nil
}
// See https://datatracker.ietf.org/doc/html/rfc1928#section-3
func verifyFirstNegotiation(reader io.Reader, requiredMethod authMethod) error {
const headerLength = 2 // version + nMethods bytes