diff --git a/internal/portforward/service/start.go b/internal/portforward/service/start.go index d71831b9..7ee6875a 100644 --- a/internal/portforward/service/start.go +++ b/internal/portforward/service/start.go @@ -105,23 +105,21 @@ func (s *Service) onNewPorts(ctx context.Context, internalToExternalPorts map[ui return fmt.Errorf("allowing port in firewall: %w", err) } - var sourcePort, destinationPort uint16 + var destinationPort uint16 switch { case userRedirectionEnabled: // precedence over auto redirection - sourcePort = externalToInternalPorts[port] destinationPort = s.settings.ListeningPorts[i] - case port != externalToInternalPorts[port]: // auto redirection needed, source and destination ports differ - sourcePort = externalToInternalPorts[port] + case port != internalPort: // auto redirection needed, source and destination ports differ destinationPort = port default: // No redirection needed, source and destination ports are the same. continue } - err = s.portAllower.RedirectPort(ctx, s.settings.Interface, sourcePort, destinationPort) + err = s.portAllower.RedirectPort(ctx, s.settings.Interface, internalPort, destinationPort) if err != nil { return fmt.Errorf("redirecting port %d to %d in firewall: %w", - sourcePort, destinationPort, err) + internalPort, destinationPort, err) } } diff --git a/internal/provider/protonvpn/portforward.go b/internal/provider/protonvpn/portforward.go index fbf35884..841055a7 100644 --- a/internal/provider/protonvpn/portforward.go +++ b/internal/provider/protonvpn/portforward.go @@ -56,9 +56,13 @@ func (p *Provider) PortForward(ctx context.Context, objects utils.PortForwardObj // Extra ports must be non-symmetric, meaning that the internal port is // different from the external port. + const nonSymmetricPortStart = uint16(56789) + nonSymmetricPortStartMinusOne := nonSymmetricPortStart - 1 + if _, ok := p.internalToExternalPorts[nonSymmetricPortStart]; ok { + nonSymmetricPortStartMinusOne++ + } for i := uint16(1); i < objects.PortsCount; i++ { - const nonSymmetricPortStart uint16 = 56789 - 1 - internalPort := nonSymmetricPortStart + i + internalPort := nonSymmetricPortStartMinusOne + i const externalPort = 0 assignedInternalPort, assignedExternalPort, err := addPortMappingTCPUDP(ctx, client, logger, objects.Gateway, internalPort, externalPort, lifetime) @@ -75,6 +79,10 @@ func addPortMappingTCPUDP(ctx context.Context, client *natpmp.Client, logger uti gateway netip.Addr, internalPort, externalPort uint16, lifetime time.Duration, ) (assignedInternalPort, assignedExternalPort uint16, err error) { var assignedLifetime time.Duration + protocolToExternalPort := map[string]uint16{ + "tcp": 0, + "udp": 0, + } for _, protocol := range [...]string{"udp", "tcp"} { protocolStr := strings.ToUpper(protocol) _, assignedInternalPort, assignedExternalPort, assignedLifetime, err = client.AddPortMapping( @@ -82,6 +90,7 @@ func addPortMappingTCPUDP(ctx context.Context, client *natpmp.Client, logger uti if err != nil { return 0, 0, fmt.Errorf("adding %s port mapping: %w", protocolStr, err) } + protocolToExternalPort[protocol] = assignedExternalPort checkLifetime(logger, protocolStr, lifetime, assignedLifetime) if internalPort != assignedInternalPort { return 0, 0, fmt.Errorf("%s internal port requested as %d but received %d", @@ -91,6 +100,12 @@ func addPortMappingTCPUDP(ctx context.Context, client *natpmp.Client, logger uti protocolStr, externalPort, assignedExternalPort) } } + + if protocolToExternalPort["tcp"] != protocolToExternalPort["udp"] { + return 0, 0, fmt.Errorf("TCP and UDP external ports differ: %d and %d", + protocolToExternalPort["tcp"], protocolToExternalPort["udp"]) + } + return assignedInternalPort, assignedExternalPort, nil }