mirror of
https://github.com/qdm12/gluetun.git
synced 2026-06-20 02:14:08 +02:00
feat(socks5): UDP proxying (#3353)
This commit is contained in:
@@ -0,0 +1,370 @@
|
||||
package socks5
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type udpAssociation struct {
|
||||
id uint64
|
||||
clientAddrPort netip.AddrPort
|
||||
expectedAddrPort netip.AddrPort
|
||||
controlConnAddr netip.Addr
|
||||
packetCh chan *bytes.Buffer
|
||||
}
|
||||
|
||||
type udpRouter struct {
|
||||
logger Logger
|
||||
|
||||
listener net.PacketConn
|
||||
mutex sync.Mutex
|
||||
bufferPool sync.Pool
|
||||
nextAssociationID uint64
|
||||
clientAddrPortToAssociation map[netip.AddrPort]udpAssociation
|
||||
clientIPToPendingAssociations map[netip.Addr][]udpAssociation
|
||||
associationIDToClientAddrPort map[uint64]netip.AddrPort
|
||||
}
|
||||
|
||||
const (
|
||||
maxUDPPacketLength = 65535
|
||||
maxSOCKS5UDPDatagramOverhead = 3 + 1 + 16 + 2
|
||||
pooledUDPPacketBufferCapacity = maxUDPPacketLength + maxSOCKS5UDPDatagramOverhead
|
||||
)
|
||||
|
||||
func newUDPRouter(ctx context.Context, address string, logger Logger) (router *udpRouter, err error) {
|
||||
config := &net.ListenConfig{}
|
||||
listener, err := config.ListenPacket(ctx, "udp", address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("UDP listening: %w", err)
|
||||
}
|
||||
|
||||
return &udpRouter{
|
||||
logger: logger,
|
||||
listener: listener,
|
||||
bufferPool: sync.Pool{
|
||||
New: func() any {
|
||||
return bytes.NewBuffer(make([]byte, 0, pooledUDPPacketBufferCapacity))
|
||||
},
|
||||
},
|
||||
nextAssociationID: 1,
|
||||
clientAddrPortToAssociation: make(map[netip.AddrPort]udpAssociation),
|
||||
clientIPToPendingAssociations: make(map[netip.Addr][]udpAssociation),
|
||||
associationIDToClientAddrPort: make(map[uint64]netip.AddrPort),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *udpRouter) localAddress() net.Addr {
|
||||
return r.listener.LocalAddr()
|
||||
}
|
||||
|
||||
func (r *udpRouter) close() error {
|
||||
return r.listener.Close()
|
||||
}
|
||||
|
||||
func (r *udpRouter) registerAssociation(controlConn net.Conn, expectedAddrPort netip.AddrPort) (udpAssociation, error) {
|
||||
controlConnAddrPort, err := netip.ParseAddrPort(controlConn.RemoteAddr().String())
|
||||
if err != nil {
|
||||
return udpAssociation{}, fmt.Errorf("parsing control connection address: %w", err)
|
||||
}
|
||||
controlConnAddr := controlConnAddrPort.Addr().Unmap()
|
||||
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
|
||||
const udpPacketChannelBuffer = 2
|
||||
associationID := r.nextAssociationID
|
||||
r.nextAssociationID++
|
||||
|
||||
association := udpAssociation{
|
||||
id: associationID,
|
||||
expectedAddrPort: expectedAddrPort,
|
||||
controlConnAddr: controlConnAddr,
|
||||
packetCh: make(chan *bytes.Buffer, udpPacketChannelBuffer),
|
||||
}
|
||||
|
||||
if expectedAddrPort.Addr().IsValid() && expectedAddrPort.Port() != 0 {
|
||||
association.clientAddrPort = expectedAddrPort
|
||||
r.clientAddrPortToAssociation[association.clientAddrPort] = association
|
||||
r.associationIDToClientAddrPort[association.id] = association.clientAddrPort
|
||||
return association, nil
|
||||
}
|
||||
|
||||
pendingAssociations := r.clientIPToPendingAssociations[controlConnAddr]
|
||||
pendingAssociations = append(pendingAssociations, association)
|
||||
r.clientIPToPendingAssociations[controlConnAddr] = pendingAssociations
|
||||
|
||||
return association, nil
|
||||
}
|
||||
|
||||
func (r *udpRouter) unregisterAssociation(association udpAssociation) {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
|
||||
clientAddrPort, hasClientAddress := r.associationIDToClientAddrPort[association.id]
|
||||
if hasClientAddress {
|
||||
delete(r.associationIDToClientAddrPort, association.id)
|
||||
delete(r.clientAddrPortToAssociation, clientAddrPort)
|
||||
}
|
||||
|
||||
pendingAssociations := r.clientIPToPendingAssociations[association.controlConnAddr]
|
||||
for i, pendingAssociation := range pendingAssociations {
|
||||
if pendingAssociation.id == association.id {
|
||||
pendingAssociations = append(pendingAssociations[:i], pendingAssociations[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(pendingAssociations) == 0 {
|
||||
delete(r.clientIPToPendingAssociations, association.controlConnAddr)
|
||||
} else {
|
||||
r.clientIPToPendingAssociations[association.controlConnAddr] = pendingAssociations
|
||||
}
|
||||
}
|
||||
|
||||
func (r *udpRouter) run(ctx context.Context) error {
|
||||
packetBuffer := make([]byte, maxUDPPacketLength)
|
||||
|
||||
for {
|
||||
packetLength, sourceAddress, err := r.listener.ReadFrom(packetBuffer)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil && errors.Is(err, net.ErrClosed) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("reading UDP packet: %w", err)
|
||||
}
|
||||
|
||||
sourceAddrPort, err := netAddrToNetipAddrPort(sourceAddress)
|
||||
if err != nil {
|
||||
r.logger.Warnf("parsing source address: %s", err)
|
||||
continue
|
||||
}
|
||||
buffer := r.bufferPool.Get().(*bytes.Buffer) //nolint:forcetypeassert
|
||||
buffer.Reset()
|
||||
_, err = buffer.Write(packetBuffer[:packetLength])
|
||||
if err != nil {
|
||||
r.bufferPool.Put(buffer)
|
||||
r.logger.Warnf("buffering packet: %s", err)
|
||||
continue
|
||||
}
|
||||
err = r.routePacket(sourceAddrPort, buffer)
|
||||
if err != nil {
|
||||
r.logger.Warnf("failed routing UDP packet: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *udpRouter) routePacket(sourceAddrPort netip.AddrPort, packet *bytes.Buffer) error {
|
||||
r.mutex.Lock()
|
||||
association, packetFromClient := r.findClientAssociation(sourceAddrPort)
|
||||
r.mutex.Unlock()
|
||||
|
||||
if !packetFromClient {
|
||||
r.bufferPool.Put(packet)
|
||||
return nil
|
||||
}
|
||||
|
||||
select {
|
||||
case association.packetCh <- packet:
|
||||
return nil
|
||||
default:
|
||||
r.bufferPool.Put(packet)
|
||||
return errors.New("association packet queue full")
|
||||
}
|
||||
}
|
||||
|
||||
func (r *udpRouter) findClientAssociation(sourceAddrPort netip.AddrPort) (
|
||||
association udpAssociation, ok bool,
|
||||
) {
|
||||
association, ok = r.clientAddrPortToAssociation[sourceAddrPort]
|
||||
if ok {
|
||||
return association, true
|
||||
}
|
||||
sourceAddr := sourceAddrPort.Addr()
|
||||
|
||||
pendingAssociations := r.clientIPToPendingAssociations[sourceAddr]
|
||||
if len(pendingAssociations) == 0 {
|
||||
return udpAssociation{}, false
|
||||
}
|
||||
|
||||
index := -1
|
||||
for i, pendingAssociation := range pendingAssociations {
|
||||
if matchesExpectedClientEndpoint(pendingAssociation, sourceAddrPort) {
|
||||
association = pendingAssociation
|
||||
index = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if index == -1 {
|
||||
return udpAssociation{}, false
|
||||
}
|
||||
|
||||
r.clientIPToPendingAssociations[sourceAddr] = append(pendingAssociations[:index], pendingAssociations[index+1:]...)
|
||||
if len(r.clientIPToPendingAssociations[sourceAddr]) == 0 {
|
||||
delete(r.clientIPToPendingAssociations, sourceAddr)
|
||||
}
|
||||
|
||||
association.clientAddrPort = sourceAddrPort
|
||||
r.clientAddrPortToAssociation[sourceAddrPort] = association
|
||||
r.associationIDToClientAddrPort[association.id] = sourceAddrPort
|
||||
|
||||
return association, true
|
||||
}
|
||||
|
||||
func matchesExpectedClientEndpoint(association udpAssociation, sourceAddrPort netip.AddrPort) bool {
|
||||
switch {
|
||||
case association.expectedAddrPort.Addr().IsValid() && sourceAddrPort.Addr() != association.expectedAddrPort.Addr():
|
||||
return false
|
||||
case association.expectedAddrPort.Port() != 0 && sourceAddrPort.Port() != association.expectedAddrPort.Port():
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (r *udpRouter) clientAddrPortForAssociation(associationID uint64) (
|
||||
clientAddrPort netip.AddrPort, ok bool,
|
||||
) {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
|
||||
clientAddrPort, ok = r.associationIDToClientAddrPort[associationID]
|
||||
return clientAddrPort, ok
|
||||
}
|
||||
|
||||
func (r *udpRouter) runAssociationHandler(ctx context.Context, association udpAssociation) {
|
||||
config := &net.ListenConfig{}
|
||||
socket, err := config.ListenPacket(ctx, "udp", ":0")
|
||||
if err != nil {
|
||||
r.logger.Warnf("creating per-association UDP socket: %s", err)
|
||||
return
|
||||
}
|
||||
defer socket.Close()
|
||||
|
||||
go closeSocketOnContextDone(ctx, socket)
|
||||
|
||||
packetBuffer := make([]byte, maxUDPPacketLength)
|
||||
|
||||
forwardDoneCh := make(chan struct{})
|
||||
go r.forwardClientPackets(ctx, socket, association.packetCh, forwardDoneCh)
|
||||
|
||||
for {
|
||||
packetLength, sourceAddress, err := socket.ReadFrom(packetBuffer)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil || errors.Is(err, net.ErrClosed) {
|
||||
<-forwardDoneCh
|
||||
return
|
||||
}
|
||||
r.logger.Warnf("reading from per-association UDP socket: %s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
sourceAddrPort, err := netAddrToNetipAddrPort(sourceAddress)
|
||||
if err != nil {
|
||||
r.logger.Warnf("parsing source address from destination: %s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
buffer := r.bufferPool.Get().(*bytes.Buffer) //nolint:forcetypeassert
|
||||
buffer.Reset()
|
||||
err = encodeUDPDatagramToBuffer(buffer, sourceAddrPort, packetBuffer[:packetLength])
|
||||
if err != nil {
|
||||
r.bufferPool.Put(buffer)
|
||||
r.logger.Warnf("encoding response datagram: %s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
clientAddrPort, found := r.clientAddrPortForAssociation(association.id)
|
||||
if !found {
|
||||
r.bufferPool.Put(buffer)
|
||||
r.logger.Warnf("client address not found for association id %d", association.id)
|
||||
continue
|
||||
}
|
||||
|
||||
clientUDPAddress := &net.UDPAddr{
|
||||
IP: clientAddrPort.Addr().AsSlice(),
|
||||
Port: int(clientAddrPort.Port()),
|
||||
}
|
||||
_, err = r.listener.WriteTo(buffer.Bytes(), clientUDPAddress)
|
||||
r.bufferPool.Put(buffer)
|
||||
if err != nil {
|
||||
r.logger.Warnf("writing response to client: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func closeSocketOnContextDone(ctx context.Context, socket net.PacketConn) {
|
||||
<-ctx.Done()
|
||||
_ = socket.Close()
|
||||
}
|
||||
|
||||
func (r *udpRouter) forwardClientPackets(ctx context.Context, socket net.PacketConn,
|
||||
packetCh <-chan *bytes.Buffer, done chan<- struct{},
|
||||
) {
|
||||
defer close(done)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case buffer, ok := <-packetCh:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
err := r.writeClientPacketToDestination(ctx, socket, buffer)
|
||||
r.bufferPool.Put(buffer)
|
||||
if err != nil {
|
||||
r.logger.Warnf("forwarding client packet to destination: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *udpRouter) writeClientPacketToDestination(ctx context.Context,
|
||||
socket net.PacketConn, packet *bytes.Buffer,
|
||||
) error {
|
||||
destination, payload, err := decodeUDPDatagram(packet.Bytes())
|
||||
if err != nil {
|
||||
return fmt.Errorf("decoding UDP datagram: %w", err)
|
||||
}
|
||||
|
||||
host, portStr, err := net.SplitHostPort(destination)
|
||||
if err != nil {
|
||||
return fmt.Errorf("splitting destination host and port: %w", err)
|
||||
}
|
||||
|
||||
if _, err := netip.ParseAddr(host); err != nil { // domain name
|
||||
addrs, err := net.DefaultResolver.LookupHost(ctx, host)
|
||||
if err != nil {
|
||||
return fmt.Errorf("resolving destination host: %w", err)
|
||||
}
|
||||
if len(addrs) == 0 {
|
||||
return fmt.Errorf("resolving destination host: no addresses found for %q", host)
|
||||
}
|
||||
|
||||
destination = net.JoinHostPort(addrs[0], portStr)
|
||||
}
|
||||
|
||||
resolvedDestinationUDPAddress, err := net.ResolveUDPAddr("udp", destination)
|
||||
if err != nil {
|
||||
return fmt.Errorf("resolving destination UDP address: %w", err)
|
||||
}
|
||||
|
||||
_, err = socket.WriteTo(payload, resolvedDestinationUDPAddress)
|
||||
if err != nil && ctx.Err() == nil {
|
||||
return fmt.Errorf("writing payload to destination: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func netAddrToNetipAddrPort(addr net.Addr) (netip.AddrPort, error) {
|
||||
addrPort, err := netip.ParseAddrPort(addr.String())
|
||||
if err != nil {
|
||||
return netip.AddrPort{}, fmt.Errorf("parsing address: %w", err)
|
||||
}
|
||||
return netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port()), nil
|
||||
}
|
||||
Reference in New Issue
Block a user