mirror of
https://github.com/qdm12/gluetun.git
synced 2026-06-15 16:04:08 +02:00
feat(socks5): UDP proxying (#3353)
This commit is contained in:
+249
-1
@@ -10,6 +10,7 @@ import (
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -23,6 +24,7 @@ type socksConn struct {
|
||||
username string
|
||||
password string
|
||||
clientConn net.Conn
|
||||
udpRouter *udpRouter
|
||||
logger Logger
|
||||
}
|
||||
|
||||
@@ -109,11 +111,29 @@ func (c *socksConn) handleRequest(ctx context.Context) error {
|
||||
c.encodeFailedResponse(c.clientConn, socksVersion, generalServerFailure)
|
||||
return err
|
||||
}
|
||||
if request.command != connect {
|
||||
|
||||
switch request.command {
|
||||
case connect:
|
||||
err = c.handleConnectRequest(ctx, socksVersion, request)
|
||||
if err != nil {
|
||||
return fmt.Errorf("handling %s request: %w", request.command, err)
|
||||
}
|
||||
return nil
|
||||
case udpAssociate:
|
||||
err = c.handleUDPAssociateRequest(ctx, socksVersion, request)
|
||||
if err != nil {
|
||||
return fmt.Errorf("handling %s request: %w", request.command, err)
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
c.encodeFailedResponse(c.clientConn, socksVersion, commandNotSupported)
|
||||
return fmt.Errorf("command %s is not supported", request.command)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *socksConn) handleConnectRequest(ctx context.Context,
|
||||
socksVersion byte, request request,
|
||||
) error {
|
||||
destinationAddress := net.JoinHostPort(request.destination, fmt.Sprint(request.port))
|
||||
destinationConn, err := c.dialer.DialContext(ctx, "tcp", destinationAddress)
|
||||
if err != nil {
|
||||
@@ -176,6 +196,234 @@ func (c *socksConn) handleRequest(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *socksConn) handleUDPAssociateRequest(ctx context.Context,
|
||||
socksVersion byte, request request,
|
||||
) error {
|
||||
expectedAddrPort, err := udpAssociateExpectedClientEndpoint(request)
|
||||
if err != nil {
|
||||
c.encodeFailedResponse(c.clientConn, socksVersion, addressTypeNotSupported)
|
||||
return fmt.Errorf("deriving expected client address and port from request: %w", err)
|
||||
}
|
||||
|
||||
bindAddress, bindPort, bindAddrType, err := c.udpAssociationAddresses()
|
||||
if err != nil {
|
||||
c.encodeFailedResponse(c.clientConn, socksVersion, generalServerFailure)
|
||||
return fmt.Errorf("getting udp association addresses: %w", err)
|
||||
}
|
||||
|
||||
association, err := c.udpRouter.registerAssociation(c.clientConn, expectedAddrPort)
|
||||
if err != nil {
|
||||
c.encodeFailedResponse(c.clientConn, socksVersion, generalServerFailure)
|
||||
return fmt.Errorf("registering udp association: %w", err)
|
||||
}
|
||||
defer c.udpRouter.unregisterAssociation(association)
|
||||
|
||||
err = c.encodeSuccessResponse(c.clientConn, socksVersion, succeeded,
|
||||
bindAddrType, bindAddress, bindPort)
|
||||
if err != nil {
|
||||
c.encodeFailedResponse(c.clientConn, socksVersion, generalServerFailure)
|
||||
return fmt.Errorf("writing successful %s response: %w", udpAssociate, err)
|
||||
}
|
||||
|
||||
associationCtx, associationCancel := context.WithCancel(ctx)
|
||||
defer associationCancel()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
wg.Go(func() {
|
||||
c.udpRouter.runAssociationHandler(associationCtx, association)
|
||||
})
|
||||
|
||||
wg.Go(func() {
|
||||
_, _ = io.Copy(io.Discard, c.clientConn)
|
||||
associationCancel()
|
||||
})
|
||||
<-associationCtx.Done()
|
||||
wg.Wait()
|
||||
return nil
|
||||
}
|
||||
|
||||
func udpAssociateExpectedClientEndpoint(request request) (expectedAddrPort netip.AddrPort, err error) {
|
||||
switch request.addressType {
|
||||
case ipv4, ipv6:
|
||||
expectedClientAddress, parseErr := netip.ParseAddr(request.destination)
|
||||
if parseErr != nil {
|
||||
return netip.AddrPort{}, fmt.Errorf("parsing destination address: %w", parseErr)
|
||||
}
|
||||
expectedClientAddress = expectedClientAddress.Unmap()
|
||||
if !expectedClientAddress.IsUnspecified() {
|
||||
return netip.AddrPortFrom(expectedClientAddress, request.port), nil
|
||||
}
|
||||
return netip.AddrPortFrom(netip.Addr{}, request.port), nil
|
||||
case domainName:
|
||||
if request.destination != "" || request.port != 0 {
|
||||
return netip.AddrPort{}, fmt.Errorf("domain name is not supported for UDP associate destination")
|
||||
}
|
||||
return netip.AddrPort{}, nil
|
||||
default:
|
||||
return netip.AddrPort{}, fmt.Errorf("address type %d is not supported", request.addressType)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *socksConn) udpAssociationAddresses() (bindAddress string,
|
||||
bindPort uint16, bindAddrType addrType, err error,
|
||||
) {
|
||||
localAddress := c.udpRouter.localAddress().String()
|
||||
host, portString, err := net.SplitHostPort(localAddress)
|
||||
if err != nil {
|
||||
return "", 0, 0, fmt.Errorf("splitting local address: %w", err)
|
||||
}
|
||||
port, err := strconv.ParseUint(portString, 10, 16)
|
||||
if err != nil {
|
||||
return "", 0, 0, fmt.Errorf("parsing local port: %w", err)
|
||||
}
|
||||
bindAddress = host
|
||||
bindPort = uint16(port)
|
||||
if isUnspecifiedIPAddress(bindAddress) {
|
||||
controlLocalAddress := c.clientConn.LocalAddr().String()
|
||||
controlLocalHost, _, splitErr := net.SplitHostPort(controlLocalAddress)
|
||||
if splitErr != nil {
|
||||
return "", 0, 0, fmt.Errorf("splitting control connection local address: %w", splitErr)
|
||||
}
|
||||
bindAddress = controlLocalHost
|
||||
}
|
||||
|
||||
ipAddress := net.ParseIP(bindAddress)
|
||||
if ipAddress == nil {
|
||||
bindAddrType = domainName
|
||||
return bindAddress, bindPort, bindAddrType, nil
|
||||
}
|
||||
|
||||
if ipAddress.To4() != nil {
|
||||
bindAddrType = ipv4
|
||||
} else {
|
||||
bindAddrType = ipv6
|
||||
}
|
||||
|
||||
return bindAddress, bindPort, bindAddrType, nil
|
||||
}
|
||||
|
||||
func isUnspecifiedIPAddress(address string) bool {
|
||||
ipAddress, err := netip.ParseAddr(address)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return ipAddress.IsUnspecified()
|
||||
}
|
||||
|
||||
func decodeUDPDatagram(packet []byte) (destination string, payload []byte, err error) {
|
||||
const minimumPacketLength = 4
|
||||
if len(packet) < minimumPacketLength {
|
||||
return "", nil, fmt.Errorf("packet is too short: %d", len(packet))
|
||||
}
|
||||
if packet[0] != 0 || packet[1] != 0 {
|
||||
return "", nil, fmt.Errorf("reserved bytes are invalid: %x %x", packet[0], packet[1])
|
||||
}
|
||||
if packet[2] != 0 {
|
||||
return "", nil, fmt.Errorf("fragmentation is not supported")
|
||||
}
|
||||
|
||||
offset := 3
|
||||
addressType := addrType(packet[offset])
|
||||
offset++
|
||||
|
||||
switch addressType {
|
||||
case ipv4:
|
||||
const ipv4Length = 4
|
||||
if len(packet) < offset+ipv4Length+2 {
|
||||
return "", nil, fmt.Errorf("packet is too short for IPv4 address")
|
||||
}
|
||||
var ip [ipv4Length]byte
|
||||
copy(ip[:], packet[offset:offset+ipv4Length])
|
||||
destination = netip.AddrFrom4(ip).String()
|
||||
offset += ipv4Length
|
||||
case ipv6:
|
||||
const ipv6Length = 16
|
||||
if len(packet) < offset+ipv6Length+2 {
|
||||
return "", nil, fmt.Errorf("packet is too short for IPv6 address")
|
||||
}
|
||||
var ip [ipv6Length]byte
|
||||
copy(ip[:], packet[offset:offset+ipv6Length])
|
||||
destination = netip.AddrFrom16(ip).String()
|
||||
offset += ipv6Length
|
||||
case domainName:
|
||||
if len(packet) < offset+1 {
|
||||
return "", nil, fmt.Errorf("packet is too short for domain name length")
|
||||
}
|
||||
domainNameLength := int(packet[offset])
|
||||
offset++
|
||||
if len(packet) < offset+domainNameLength+2 {
|
||||
return "", nil, fmt.Errorf("packet is too short for domain name")
|
||||
}
|
||||
destination = string(packet[offset : offset+domainNameLength])
|
||||
offset += domainNameLength
|
||||
default:
|
||||
return "", nil, fmt.Errorf("address type is not supported: %d", addressType)
|
||||
}
|
||||
|
||||
port := binary.BigEndian.Uint16(packet[offset : offset+2])
|
||||
destination = net.JoinHostPort(destination, fmt.Sprint(port))
|
||||
offset += 2
|
||||
payload = packet[offset:]
|
||||
|
||||
return destination, payload, nil
|
||||
}
|
||||
|
||||
func encodeUDPDatagramToBuffer(writer io.Writer, sourceAddrPort netip.AddrPort,
|
||||
payload []byte,
|
||||
) error {
|
||||
address := sourceAddrPort.Addr()
|
||||
if !address.IsValid() {
|
||||
return errors.New("source address is not valid")
|
||||
}
|
||||
|
||||
err := writeUDPDatagramSourceAddress(writer, address)
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing source address: %w", err)
|
||||
}
|
||||
|
||||
var portBytes [2]byte
|
||||
binary.BigEndian.PutUint16(portBytes[:], sourceAddrPort.Port())
|
||||
_, err = writer.Write(portBytes[:])
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing destination port: %w", err)
|
||||
}
|
||||
|
||||
_, err = writer.Write(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing payload: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeUDPDatagramSourceAddress(writer io.Writer, address netip.Addr) error {
|
||||
var addrType addrType
|
||||
var addressBytes []byte
|
||||
switch {
|
||||
case address.Is4():
|
||||
addrType = ipv4
|
||||
array := address.As4()
|
||||
addressBytes = array[:]
|
||||
case address.Is6():
|
||||
addrType = ipv6
|
||||
array := address.As16()
|
||||
addressBytes = array[:]
|
||||
default:
|
||||
return fmt.Errorf("address type is not supported: %v", address)
|
||||
}
|
||||
|
||||
_, err := writer.Write([]byte{0, 0, 0, byte(addrType)})
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing header: %w", err)
|
||||
}
|
||||
_, err = writer.Write(addressBytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing IP address: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// See https://datatracker.ietf.org/doc/html/rfc1928#section-3
|
||||
func verifyFirstNegotiation(reader io.Reader, requiredMethod authMethod) error {
|
||||
const headerLength = 2 // version + nMethods bytes
|
||||
|
||||
Reference in New Issue
Block a user