diff --git a/internal/configuration/settings/portforward.go b/internal/configuration/settings/portforward.go index a08277e6..71850b75 100644 --- a/internal/configuration/settings/portforward.go +++ b/internal/configuration/settings/portforward.go @@ -95,7 +95,7 @@ func (p PortForwarding) Validate(vpnProvider string) (err error) { return errors.New("port forwarding password is empty") } case providers.Protonvpn: - const maxPortsCount = 4 + const maxPortsCount = 5 if p.PortsCount > maxPortsCount { return fmt.Errorf("ports count too high: %d > %d", p.PortsCount, maxPortsCount) } diff --git a/internal/portforward/service/settings.go b/internal/portforward/service/settings.go index df1645eb..61ded963 100644 --- a/internal/portforward/service/settings.go +++ b/internal/portforward/service/settings.go @@ -88,7 +88,7 @@ func (s *Settings) Validate(forStartup bool) (err error) { return errors.New("password not set") } case providers.Protonvpn: - const maxPortsCount = 4 + const maxPortsCount = 5 if s.PortsCount > maxPortsCount { return fmt.Errorf("ports count too high: %d > %d", s.PortsCount, maxPortsCount) } diff --git a/internal/portforward/service/start.go b/internal/portforward/service/start.go index a5932a8b..d71831b9 100644 --- a/internal/portforward/service/start.go +++ b/internal/portforward/service/start.go @@ -92,13 +92,9 @@ func (s *Service) onNewPorts(ctx context.Context, internalToExternalPorts map[ui s.logger.Info(portPairsToString(internalToExternalPorts)) externalPorts := slices.Collect(maps.Values(internalToExternalPorts)) - autoRedirectionNeeded := false externalToInternalPorts := make(map[uint16]uint16, len(internalToExternalPorts)) for internal, external := range internalToExternalPorts { externalToInternalPorts[external] = internal - if internal != external { - autoRedirectionNeeded = true - } } slices.Sort(externalPorts) userRedirectionEnabled := !slices.Equal(s.settings.ListeningPorts, []uint16{0}) @@ -114,7 +110,7 @@ func (s *Service) onNewPorts(ctx context.Context, internalToExternalPorts map[ui case userRedirectionEnabled: // precedence over auto redirection sourcePort = externalToInternalPorts[port] destinationPort = s.settings.ListeningPorts[i] - case autoRedirectionNeeded: + case port != externalToInternalPorts[port]: // auto redirection needed, source and destination ports differ sourcePort = externalToInternalPorts[port] destinationPort = port default: diff --git a/internal/provider/protonvpn/portforward.go b/internal/provider/protonvpn/portforward.go index e3d24ff6..fbf35884 100644 --- a/internal/provider/protonvpn/portforward.go +++ b/internal/provider/protonvpn/portforward.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "maps" + "net/netip" "strings" "time" @@ -12,14 +13,14 @@ import ( "github.com/qdm12/gluetun/internal/provider/utils" ) -const nonSymmetricPortStart uint16 = 56789 - // PortForward obtains a VPN server side port forwarded from ProtonVPN gateway. func (p *Provider) PortForward(ctx context.Context, objects utils.PortForwardObjects) ( internalToExternalPorts map[uint16]uint16, err error, ) { if !objects.CanPortForward { return nil, errors.New("server does not support port forwarding") + } else if objects.PortsCount == 0 { + return nil, nil //nolint:nilnil } client := natpmp.New() @@ -39,38 +40,60 @@ func (p *Provider) PortForward(ctx context.Context, objects utils.PortForwardObj logger := objects.Logger logger.Debug("gateway external IPv4 address is " + externalIPv4Address.String()) - const externalPort = 0 - const lifetime = 60 * time.Second p.internalToExternalPorts = make(map[uint16]uint16, objects.PortsCount) - for i := range objects.PortsCount { - internalPort := nonSymmetricPortStart + i - protoToInternalPort := map[string]uint16{ - "udp": 0, - "tcp": 0, - } - protoToExternalPort := maps.Clone(protoToInternalPort) - for protocol := range protoToExternalPort { - _, assignedInternalPort, assignedExternalPort, assignedLifetime, err := client.AddPortMapping( - ctx, objects.Gateway, protocol, internalPort, externalPort, lifetime) - if err != nil { - return nil, fmt.Errorf("adding %d/%d %s port mapping: %w", - i+1, objects.PortsCount, strings.ToUpper(protocol), err) - } - checkLifetime(logger, strings.ToUpper(protocol), lifetime, assignedLifetime) - checkInternalPort(logger, internalPort, assignedInternalPort) - protoToInternalPort[protocol] = assignedInternalPort - protoToExternalPort[protocol] = assignedExternalPort - } + const lifetime = 60 * time.Second - checkInternalPorts(logger, protoToInternalPort["udp"], protoToInternalPort["tcp"]) - checkExternalPorts(logger, protoToExternalPort["udp"], protoToExternalPort["tcp"]) - p.internalToExternalPorts[protoToInternalPort["tcp"]] = protoToExternalPort["tcp"] + // Only one port can be a symmetric mapping + const internalPort, externalPort = 0, 1 + _, assignedExternalPort, err := addPortMappingTCPUDP(ctx, + client, logger, objects.Gateway, internalPort, externalPort, lifetime) + // Note the returned assignedInternalPort is always 0 in this case + if err != nil { + return nil, fmt.Errorf("adding first port mapping: %w", err) + } + p.internalToExternalPorts[assignedExternalPort] = assignedExternalPort + + // Extra ports must be non-symmetric, meaning that the internal port is + // different from the external port. + for i := uint16(1); i < objects.PortsCount; i++ { + const nonSymmetricPortStart uint16 = 56789 - 1 + internalPort := nonSymmetricPortStart + i + const externalPort = 0 + assignedInternalPort, assignedExternalPort, err := addPortMappingTCPUDP(ctx, + client, logger, objects.Gateway, internalPort, externalPort, lifetime) + if err != nil { + return nil, fmt.Errorf("adding %d/%d port mapping: %w", i+1, objects.PortsCount, err) + } + p.internalToExternalPorts[assignedInternalPort] = assignedExternalPort } return maps.Clone(p.internalToExternalPorts), nil } +func addPortMappingTCPUDP(ctx context.Context, client *natpmp.Client, logger utils.Logger, + gateway netip.Addr, internalPort, externalPort uint16, lifetime time.Duration, +) (assignedInternalPort, assignedExternalPort uint16, err error) { + var assignedLifetime time.Duration + for _, protocol := range [...]string{"udp", "tcp"} { + protocolStr := strings.ToUpper(protocol) + _, assignedInternalPort, assignedExternalPort, assignedLifetime, err = client.AddPortMapping( + ctx, gateway, protocol, internalPort, externalPort, lifetime) + if err != nil { + return 0, 0, fmt.Errorf("adding %s port mapping: %w", protocolStr, err) + } + checkLifetime(logger, protocolStr, lifetime, assignedLifetime) + if internalPort != assignedInternalPort { + return 0, 0, fmt.Errorf("%s internal port requested as %d but received %d", + protocolStr, internalPort, assignedInternalPort) + } else if externalPort != 0 && externalPort != 1 && externalPort != assignedExternalPort { + return 0, 0, fmt.Errorf("%s external port requested as %d but received %d", + protocolStr, externalPort, assignedExternalPort) + } + } + return assignedInternalPort, assignedExternalPort, nil +} + func checkLifetime(logger utils.Logger, protocol string, requested, actual time.Duration, ) { @@ -81,27 +104,6 @@ func checkLifetime(logger utils.Logger, protocol string, } } -func checkInternalPort(logger utils.Logger, sent, received uint16) { - if sent != received { - logger.Warn(fmt.Sprintf("internal port assigned %d differs from requested internal port %d", - sent, received)) - } -} - -func checkInternalPorts(logger utils.Logger, udpPort, tcpPort uint16) { - if udpPort != tcpPort { - logger.Warn(fmt.Sprintf("UDP internal port %d differs from TCP internal port %d", - udpPort, tcpPort)) - } -} - -func checkExternalPorts(logger utils.Logger, udpPort, tcpPort uint16) { - if udpPort != tcpPort { - logger.Warn(fmt.Sprintf("UDP external port %d differs from TCP external port %d", - udpPort, tcpPort)) - } -} - func (p *Provider) KeepPortForward(ctx context.Context, objects utils.PortForwardObjects, ) (err error) { @@ -117,22 +119,12 @@ func (p *Provider) KeepPortForward(ctx context.Context, } objects.Logger.Debug("refreshing forwarded ports since 45 seconds have elapsed") - networkProtocols := [...]string{"udp", "tcp"} const lifetime = 60 * time.Second for internalPort, externalPort := range p.internalToExternalPorts { - for _, networkProtocol := range networkProtocols { - _, assignedInternalPort, assignedExternalPort, assignedLiftetime, err := client.AddPortMapping( - ctx, objects.Gateway, networkProtocol, internalPort, externalPort, lifetime) - if err != nil { - return fmt.Errorf("adding port mapping: %w", err) - } - checkLifetime(logger, networkProtocol, lifetime, assignedLiftetime) - if externalPort != assignedExternalPort { - return fmt.Errorf("external port changed from %d to %d", externalPort, assignedExternalPort) - } else if internalPort != assignedInternalPort { - return fmt.Errorf("internal port changed from %d (for external port %d) to %d", - internalPort, externalPort, assignedInternalPort) - } + _, _, err := addPortMappingTCPUDP(ctx, client, logger, objects.Gateway, internalPort, externalPort, lifetime) + if err != nil { + return fmt.Errorf("refreshing port mapping for internal port %d and external port %d: %w", + internalPort, externalPort, err) } objects.Logger.Debug(fmt.Sprintf("port forwarded %d maintained", externalPort)) }