package socks5 import ( "bytes" "context" "errors" "fmt" "net" "net/netip" "sync" ) type udpAssociation struct { id uint64 clientAddrPort netip.AddrPort expectedAddrPort netip.AddrPort controlConnAddr netip.Addr packetCh chan *bytes.Buffer } type udpRouter struct { logger Logger listener net.PacketConn mutex sync.Mutex bufferPool sync.Pool nextAssociationID uint64 clientAddrPortToAssociation map[netip.AddrPort]udpAssociation clientIPToPendingAssociations map[netip.Addr][]udpAssociation associationIDToClientAddrPort map[uint64]netip.AddrPort } const ( maxUDPPacketLength = 65535 maxSOCKS5UDPDatagramOverhead = 3 + 1 + 16 + 2 pooledUDPPacketBufferCapacity = maxUDPPacketLength + maxSOCKS5UDPDatagramOverhead ) func newUDPRouter(ctx context.Context, address string, logger Logger) (router *udpRouter, err error) { config := &net.ListenConfig{} listener, err := config.ListenPacket(ctx, "udp", address) if err != nil { return nil, fmt.Errorf("UDP listening: %w", err) } return &udpRouter{ logger: logger, listener: listener, bufferPool: sync.Pool{ New: func() any { return bytes.NewBuffer(make([]byte, 0, pooledUDPPacketBufferCapacity)) }, }, nextAssociationID: 1, clientAddrPortToAssociation: make(map[netip.AddrPort]udpAssociation), clientIPToPendingAssociations: make(map[netip.Addr][]udpAssociation), associationIDToClientAddrPort: make(map[uint64]netip.AddrPort), }, nil } func (r *udpRouter) localAddress() net.Addr { return r.listener.LocalAddr() } func (r *udpRouter) close() error { return r.listener.Close() } func (r *udpRouter) registerAssociation(controlConn net.Conn, expectedAddrPort netip.AddrPort) (udpAssociation, error) { controlConnAddrPort, err := netip.ParseAddrPort(controlConn.RemoteAddr().String()) if err != nil { return udpAssociation{}, fmt.Errorf("parsing control connection address: %w", err) } controlConnAddr := controlConnAddrPort.Addr().Unmap() r.mutex.Lock() defer r.mutex.Unlock() const udpPacketChannelBuffer = 2 associationID := r.nextAssociationID r.nextAssociationID++ association := udpAssociation{ id: associationID, expectedAddrPort: expectedAddrPort, controlConnAddr: controlConnAddr, packetCh: make(chan *bytes.Buffer, udpPacketChannelBuffer), } if expectedAddrPort.Addr().IsValid() && expectedAddrPort.Port() != 0 { association.clientAddrPort = expectedAddrPort r.clientAddrPortToAssociation[association.clientAddrPort] = association r.associationIDToClientAddrPort[association.id] = association.clientAddrPort return association, nil } pendingAssociations := r.clientIPToPendingAssociations[controlConnAddr] pendingAssociations = append(pendingAssociations, association) r.clientIPToPendingAssociations[controlConnAddr] = pendingAssociations return association, nil } func (r *udpRouter) unregisterAssociation(association udpAssociation) { r.mutex.Lock() defer r.mutex.Unlock() clientAddrPort, hasClientAddress := r.associationIDToClientAddrPort[association.id] if hasClientAddress { delete(r.associationIDToClientAddrPort, association.id) delete(r.clientAddrPortToAssociation, clientAddrPort) } pendingAssociations := r.clientIPToPendingAssociations[association.controlConnAddr] for i, pendingAssociation := range pendingAssociations { if pendingAssociation.id == association.id { pendingAssociations = append(pendingAssociations[:i], pendingAssociations[i+1:]...) break } } if len(pendingAssociations) == 0 { delete(r.clientIPToPendingAssociations, association.controlConnAddr) } else { r.clientIPToPendingAssociations[association.controlConnAddr] = pendingAssociations } } func (r *udpRouter) run(ctx context.Context) error { packetBuffer := make([]byte, maxUDPPacketLength) for { packetLength, sourceAddress, err := r.listener.ReadFrom(packetBuffer) if err != nil { if ctx.Err() != nil && errors.Is(err, net.ErrClosed) { return nil } return fmt.Errorf("reading UDP packet: %w", err) } sourceAddrPort, err := netAddrToNetipAddrPort(sourceAddress) if err != nil { r.logger.Warnf("parsing source address: %s", err) continue } buffer := r.bufferPool.Get().(*bytes.Buffer) //nolint:forcetypeassert buffer.Reset() _, err = buffer.Write(packetBuffer[:packetLength]) if err != nil { r.bufferPool.Put(buffer) r.logger.Warnf("buffering packet: %s", err) continue } err = r.routePacket(sourceAddrPort, buffer) if err != nil { r.logger.Warnf("failed routing UDP packet: %s", err) } } } func (r *udpRouter) routePacket(sourceAddrPort netip.AddrPort, packet *bytes.Buffer) error { r.mutex.Lock() association, packetFromClient := r.findClientAssociation(sourceAddrPort) r.mutex.Unlock() if !packetFromClient { r.bufferPool.Put(packet) return nil } select { case association.packetCh <- packet: return nil default: r.bufferPool.Put(packet) return errors.New("association packet queue full") } } func (r *udpRouter) findClientAssociation(sourceAddrPort netip.AddrPort) ( association udpAssociation, ok bool, ) { association, ok = r.clientAddrPortToAssociation[sourceAddrPort] if ok { return association, true } sourceAddr := sourceAddrPort.Addr() pendingAssociations := r.clientIPToPendingAssociations[sourceAddr] if len(pendingAssociations) == 0 { return udpAssociation{}, false } index := -1 for i, pendingAssociation := range pendingAssociations { if matchesExpectedClientEndpoint(pendingAssociation, sourceAddrPort) { association = pendingAssociation index = i break } } if index == -1 { return udpAssociation{}, false } r.clientIPToPendingAssociations[sourceAddr] = append(pendingAssociations[:index], pendingAssociations[index+1:]...) if len(r.clientIPToPendingAssociations[sourceAddr]) == 0 { delete(r.clientIPToPendingAssociations, sourceAddr) } association.clientAddrPort = sourceAddrPort r.clientAddrPortToAssociation[sourceAddrPort] = association r.associationIDToClientAddrPort[association.id] = sourceAddrPort return association, true } func matchesExpectedClientEndpoint(association udpAssociation, sourceAddrPort netip.AddrPort) bool { switch { case association.expectedAddrPort.Addr().IsValid() && sourceAddrPort.Addr() != association.expectedAddrPort.Addr(): return false case association.expectedAddrPort.Port() != 0 && sourceAddrPort.Port() != association.expectedAddrPort.Port(): return false } return true } func (r *udpRouter) clientAddrPortForAssociation(associationID uint64) ( clientAddrPort netip.AddrPort, ok bool, ) { r.mutex.Lock() defer r.mutex.Unlock() clientAddrPort, ok = r.associationIDToClientAddrPort[associationID] return clientAddrPort, ok } func (r *udpRouter) runAssociationHandler(ctx context.Context, association udpAssociation) { config := &net.ListenConfig{} socket, err := config.ListenPacket(ctx, "udp", ":0") if err != nil { r.logger.Warnf("creating per-association UDP socket: %s", err) return } defer socket.Close() go closeSocketOnContextDone(ctx, socket) packetBuffer := make([]byte, maxUDPPacketLength) forwardDoneCh := make(chan struct{}) go r.forwardClientPackets(ctx, socket, association.packetCh, forwardDoneCh) for { packetLength, sourceAddress, err := socket.ReadFrom(packetBuffer) if err != nil { if ctx.Err() != nil || errors.Is(err, net.ErrClosed) { <-forwardDoneCh return } r.logger.Warnf("reading from per-association UDP socket: %s", err) continue } sourceAddrPort, err := netAddrToNetipAddrPort(sourceAddress) if err != nil { r.logger.Warnf("parsing source address from destination: %s", err) continue } buffer := r.bufferPool.Get().(*bytes.Buffer) //nolint:forcetypeassert buffer.Reset() err = encodeUDPDatagramToBuffer(buffer, sourceAddrPort, packetBuffer[:packetLength]) if err != nil { r.bufferPool.Put(buffer) r.logger.Warnf("encoding response datagram: %s", err) continue } clientAddrPort, found := r.clientAddrPortForAssociation(association.id) if !found { r.bufferPool.Put(buffer) r.logger.Warnf("client address not found for association id %d", association.id) continue } clientUDPAddress := &net.UDPAddr{ IP: clientAddrPort.Addr().AsSlice(), Port: int(clientAddrPort.Port()), } _, err = r.listener.WriteTo(buffer.Bytes(), clientUDPAddress) r.bufferPool.Put(buffer) if err != nil { r.logger.Warnf("writing response to client: %s", err) } } } func closeSocketOnContextDone(ctx context.Context, socket net.PacketConn) { <-ctx.Done() _ = socket.Close() } func (r *udpRouter) forwardClientPackets(ctx context.Context, socket net.PacketConn, packetCh <-chan *bytes.Buffer, done chan<- struct{}, ) { defer close(done) for { select { case <-ctx.Done(): return case buffer, ok := <-packetCh: if !ok { return } err := r.writeClientPacketToDestination(ctx, socket, buffer) r.bufferPool.Put(buffer) if err != nil { r.logger.Warnf("forwarding client packet to destination: %s", err) } } } } func (r *udpRouter) writeClientPacketToDestination(ctx context.Context, socket net.PacketConn, packet *bytes.Buffer, ) error { destination, payload, err := decodeUDPDatagram(packet.Bytes()) if err != nil { return fmt.Errorf("decoding UDP datagram: %w", err) } host, portStr, err := net.SplitHostPort(destination) if err != nil { return fmt.Errorf("splitting destination host and port: %w", err) } if _, err := netip.ParseAddr(host); err != nil { // domain name addrs, err := net.DefaultResolver.LookupHost(ctx, host) if err != nil { return fmt.Errorf("resolving destination host: %w", err) } if len(addrs) == 0 { return fmt.Errorf("resolving destination host: no addresses found for %q", host) } destination = net.JoinHostPort(addrs[0], portStr) } resolvedDestinationUDPAddress, err := net.ResolveUDPAddr("udp", destination) if err != nil { return fmt.Errorf("resolving destination UDP address: %w", err) } _, err = socket.WriteTo(payload, resolvedDestinationUDPAddress) if err != nil && ctx.Err() == nil { return fmt.Errorf("writing payload to destination: %w", err) } return nil } func netAddrToNetipAddrPort(addr net.Addr) (netip.AddrPort, error) { addrPort, err := netip.ParseAddrPort(addr.String()) if err != nil { return netip.AddrPort{}, fmt.Errorf("parsing address: %w", err) } return netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port()), nil }