feat(protonvpn): use symmetric port forwarding for first port then asymmetric for next ports (#3345)

This commit is contained in:
Quentin McGaw
2026-05-24 16:47:58 -04:00
committed by GitHub
parent 6f5f518d1d
commit 2e20e2df66
4 changed files with 56 additions and 68 deletions
@@ -95,7 +95,7 @@ func (p PortForwarding) Validate(vpnProvider string) (err error) {
return errors.New("port forwarding password is empty") return errors.New("port forwarding password is empty")
} }
case providers.Protonvpn: case providers.Protonvpn:
const maxPortsCount = 4 const maxPortsCount = 5
if p.PortsCount > maxPortsCount { if p.PortsCount > maxPortsCount {
return fmt.Errorf("ports count too high: %d > %d", p.PortsCount, maxPortsCount) return fmt.Errorf("ports count too high: %d > %d", p.PortsCount, maxPortsCount)
} }
+1 -1
View File
@@ -88,7 +88,7 @@ func (s *Settings) Validate(forStartup bool) (err error) {
return errors.New("password not set") return errors.New("password not set")
} }
case providers.Protonvpn: case providers.Protonvpn:
const maxPortsCount = 4 const maxPortsCount = 5
if s.PortsCount > maxPortsCount { if s.PortsCount > maxPortsCount {
return fmt.Errorf("ports count too high: %d > %d", s.PortsCount, maxPortsCount) return fmt.Errorf("ports count too high: %d > %d", s.PortsCount, maxPortsCount)
} }
+1 -5
View File
@@ -92,13 +92,9 @@ func (s *Service) onNewPorts(ctx context.Context, internalToExternalPorts map[ui
s.logger.Info(portPairsToString(internalToExternalPorts)) s.logger.Info(portPairsToString(internalToExternalPorts))
externalPorts := slices.Collect(maps.Values(internalToExternalPorts)) externalPorts := slices.Collect(maps.Values(internalToExternalPorts))
autoRedirectionNeeded := false
externalToInternalPorts := make(map[uint16]uint16, len(internalToExternalPorts)) externalToInternalPorts := make(map[uint16]uint16, len(internalToExternalPorts))
for internal, external := range internalToExternalPorts { for internal, external := range internalToExternalPorts {
externalToInternalPorts[external] = internal externalToInternalPorts[external] = internal
if internal != external {
autoRedirectionNeeded = true
}
} }
slices.Sort(externalPorts) slices.Sort(externalPorts)
userRedirectionEnabled := !slices.Equal(s.settings.ListeningPorts, []uint16{0}) userRedirectionEnabled := !slices.Equal(s.settings.ListeningPorts, []uint16{0})
@@ -114,7 +110,7 @@ func (s *Service) onNewPorts(ctx context.Context, internalToExternalPorts map[ui
case userRedirectionEnabled: // precedence over auto redirection case userRedirectionEnabled: // precedence over auto redirection
sourcePort = externalToInternalPorts[port] sourcePort = externalToInternalPorts[port]
destinationPort = s.settings.ListeningPorts[i] destinationPort = s.settings.ListeningPorts[i]
case autoRedirectionNeeded: case port != externalToInternalPorts[port]: // auto redirection needed, source and destination ports differ
sourcePort = externalToInternalPorts[port] sourcePort = externalToInternalPorts[port]
destinationPort = port destinationPort = port
default: default:
+53 -61
View File
@@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"maps" "maps"
"net/netip"
"strings" "strings"
"time" "time"
@@ -12,14 +13,14 @@ import (
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
) )
const nonSymmetricPortStart uint16 = 56789
// PortForward obtains a VPN server side port forwarded from ProtonVPN gateway. // PortForward obtains a VPN server side port forwarded from ProtonVPN gateway.
func (p *Provider) PortForward(ctx context.Context, objects utils.PortForwardObjects) ( func (p *Provider) PortForward(ctx context.Context, objects utils.PortForwardObjects) (
internalToExternalPorts map[uint16]uint16, err error, internalToExternalPorts map[uint16]uint16, err error,
) { ) {
if !objects.CanPortForward { if !objects.CanPortForward {
return nil, errors.New("server does not support port forwarding") return nil, errors.New("server does not support port forwarding")
} else if objects.PortsCount == 0 {
return nil, nil //nolint:nilnil
} }
client := natpmp.New() client := natpmp.New()
@@ -39,38 +40,60 @@ func (p *Provider) PortForward(ctx context.Context, objects utils.PortForwardObj
logger := objects.Logger logger := objects.Logger
logger.Debug("gateway external IPv4 address is " + externalIPv4Address.String()) logger.Debug("gateway external IPv4 address is " + externalIPv4Address.String())
const externalPort = 0
const lifetime = 60 * time.Second
p.internalToExternalPorts = make(map[uint16]uint16, objects.PortsCount) p.internalToExternalPorts = make(map[uint16]uint16, objects.PortsCount)
for i := range objects.PortsCount { const lifetime = 60 * time.Second
internalPort := nonSymmetricPortStart + i
protoToInternalPort := map[string]uint16{
"udp": 0,
"tcp": 0,
}
protoToExternalPort := maps.Clone(protoToInternalPort)
for protocol := range protoToExternalPort {
_, assignedInternalPort, assignedExternalPort, assignedLifetime, err := client.AddPortMapping(
ctx, objects.Gateway, protocol, internalPort, externalPort, lifetime)
if err != nil {
return nil, fmt.Errorf("adding %d/%d %s port mapping: %w",
i+1, objects.PortsCount, strings.ToUpper(protocol), err)
}
checkLifetime(logger, strings.ToUpper(protocol), lifetime, assignedLifetime)
checkInternalPort(logger, internalPort, assignedInternalPort)
protoToInternalPort[protocol] = assignedInternalPort
protoToExternalPort[protocol] = assignedExternalPort
}
checkInternalPorts(logger, protoToInternalPort["udp"], protoToInternalPort["tcp"]) // Only one port can be a symmetric mapping
checkExternalPorts(logger, protoToExternalPort["udp"], protoToExternalPort["tcp"]) const internalPort, externalPort = 0, 1
p.internalToExternalPorts[protoToInternalPort["tcp"]] = protoToExternalPort["tcp"] _, assignedExternalPort, err := addPortMappingTCPUDP(ctx,
client, logger, objects.Gateway, internalPort, externalPort, lifetime)
// Note the returned assignedInternalPort is always 0 in this case
if err != nil {
return nil, fmt.Errorf("adding first port mapping: %w", err)
}
p.internalToExternalPorts[assignedExternalPort] = assignedExternalPort
// Extra ports must be non-symmetric, meaning that the internal port is
// different from the external port.
for i := uint16(1); i < objects.PortsCount; i++ {
const nonSymmetricPortStart uint16 = 56789 - 1
internalPort := nonSymmetricPortStart + i
const externalPort = 0
assignedInternalPort, assignedExternalPort, err := addPortMappingTCPUDP(ctx,
client, logger, objects.Gateway, internalPort, externalPort, lifetime)
if err != nil {
return nil, fmt.Errorf("adding %d/%d port mapping: %w", i+1, objects.PortsCount, err)
}
p.internalToExternalPorts[assignedInternalPort] = assignedExternalPort
} }
return maps.Clone(p.internalToExternalPorts), nil return maps.Clone(p.internalToExternalPorts), nil
} }
func addPortMappingTCPUDP(ctx context.Context, client *natpmp.Client, logger utils.Logger,
gateway netip.Addr, internalPort, externalPort uint16, lifetime time.Duration,
) (assignedInternalPort, assignedExternalPort uint16, err error) {
var assignedLifetime time.Duration
for _, protocol := range [...]string{"udp", "tcp"} {
protocolStr := strings.ToUpper(protocol)
_, assignedInternalPort, assignedExternalPort, assignedLifetime, err = client.AddPortMapping(
ctx, gateway, protocol, internalPort, externalPort, lifetime)
if err != nil {
return 0, 0, fmt.Errorf("adding %s port mapping: %w", protocolStr, err)
}
checkLifetime(logger, protocolStr, lifetime, assignedLifetime)
if internalPort != assignedInternalPort {
return 0, 0, fmt.Errorf("%s internal port requested as %d but received %d",
protocolStr, internalPort, assignedInternalPort)
} else if externalPort != 0 && externalPort != 1 && externalPort != assignedExternalPort {
return 0, 0, fmt.Errorf("%s external port requested as %d but received %d",
protocolStr, externalPort, assignedExternalPort)
}
}
return assignedInternalPort, assignedExternalPort, nil
}
func checkLifetime(logger utils.Logger, protocol string, func checkLifetime(logger utils.Logger, protocol string,
requested, actual time.Duration, requested, actual time.Duration,
) { ) {
@@ -81,27 +104,6 @@ func checkLifetime(logger utils.Logger, protocol string,
} }
} }
func checkInternalPort(logger utils.Logger, sent, received uint16) {
if sent != received {
logger.Warn(fmt.Sprintf("internal port assigned %d differs from requested internal port %d",
sent, received))
}
}
func checkInternalPorts(logger utils.Logger, udpPort, tcpPort uint16) {
if udpPort != tcpPort {
logger.Warn(fmt.Sprintf("UDP internal port %d differs from TCP internal port %d",
udpPort, tcpPort))
}
}
func checkExternalPorts(logger utils.Logger, udpPort, tcpPort uint16) {
if udpPort != tcpPort {
logger.Warn(fmt.Sprintf("UDP external port %d differs from TCP external port %d",
udpPort, tcpPort))
}
}
func (p *Provider) KeepPortForward(ctx context.Context, func (p *Provider) KeepPortForward(ctx context.Context,
objects utils.PortForwardObjects, objects utils.PortForwardObjects,
) (err error) { ) (err error) {
@@ -117,22 +119,12 @@ func (p *Provider) KeepPortForward(ctx context.Context,
} }
objects.Logger.Debug("refreshing forwarded ports since 45 seconds have elapsed") objects.Logger.Debug("refreshing forwarded ports since 45 seconds have elapsed")
networkProtocols := [...]string{"udp", "tcp"}
const lifetime = 60 * time.Second const lifetime = 60 * time.Second
for internalPort, externalPort := range p.internalToExternalPorts { for internalPort, externalPort := range p.internalToExternalPorts {
for _, networkProtocol := range networkProtocols { _, _, err := addPortMappingTCPUDP(ctx, client, logger, objects.Gateway, internalPort, externalPort, lifetime)
_, assignedInternalPort, assignedExternalPort, assignedLiftetime, err := client.AddPortMapping( if err != nil {
ctx, objects.Gateway, networkProtocol, internalPort, externalPort, lifetime) return fmt.Errorf("refreshing port mapping for internal port %d and external port %d: %w",
if err != nil { internalPort, externalPort, err)
return fmt.Errorf("adding port mapping: %w", err)
}
checkLifetime(logger, networkProtocol, lifetime, assignedLiftetime)
if externalPort != assignedExternalPort {
return fmt.Errorf("external port changed from %d to %d", externalPort, assignedExternalPort)
} else if internalPort != assignedInternalPort {
return fmt.Errorf("internal port changed from %d (for external port %d) to %d",
internalPort, externalPort, assignedInternalPort)
}
} }
objects.Logger.Debug(fmt.Sprintf("port forwarded %d maintained", externalPort)) objects.Logger.Debug(fmt.Sprintf("port forwarded %d maintained", externalPort))
} }