From d5eeec6fb3798f0bc6379221f79aa5632efe1f32 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Sat, 18 Apr 2026 02:36:06 +0200 Subject: [PATCH] feat(protonvpn): support up to 5 forwarded ports (#3208) --- Dockerfile | 3 +- .../configuration/settings/portforward.go | 87 ++++++++++---- internal/portforward/loop.go | 11 +- internal/portforward/service/interfaces.go | 2 +- internal/portforward/service/service.go | 6 +- internal/portforward/service/settings.go | 41 ++++++- internal/portforward/service/start.go | 58 ++++++--- internal/portforward/service/stop.go | 4 +- .../provider/perfectprivacy/portforward.go | 9 +- .../privateinternetaccess/portforward.go | 4 +- internal/provider/privatevpn/portforward.go | 5 +- .../provider/privatevpn/portforward_test.go | 12 +- internal/provider/protonvpn/portforward.go | 113 +++++++++++------- internal/provider/protonvpn/provider.go | 2 +- internal/provider/utils/portforward.go | 2 + internal/vpn/interfaces.go | 2 +- internal/vpn/portforward.go | 2 +- 17 files changed, 254 insertions(+), 109 deletions(-) diff --git a/Dockerfile b/Dockerfile index 8e3986c9..7cbb0ccd 100644 --- a/Dockerfile +++ b/Dockerfile @@ -164,7 +164,8 @@ ENV VPN_SERVICE_PROVIDER=pia \ VPN_PORT_FORWARDING_PROVIDER= \ VPN_PORT_FORWARDING_UP_COMMAND= \ VPN_PORT_FORWARDING_DOWN_COMMAND= \ - VPN_PORT_FORWARDING_LISTENING_PORT=0 \ + VPN_PORT_FORWARDING_LISTENING_PORTS=0 \ + VPN_PORT_FORWARDING_PORTS_COUNT=1 \ VPN_PORT_FORWARDING_STATUS_FILE="/tmp/gluetun/forwarded_port" \ # PMTUD PMTUD_ICMP_ADDRESSES=1.1.1.1,8.8.8.8 \ diff --git a/internal/configuration/settings/portforward.go b/internal/configuration/settings/portforward.go index d445d657..9e7d9202 100644 --- a/internal/configuration/settings/portforward.go +++ b/internal/configuration/settings/portforward.go @@ -1,8 +1,10 @@ package settings import ( + "errors" "fmt" "path/filepath" + "slices" "github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gosettings" @@ -37,16 +39,28 @@ type PortForwarding struct { // It can be the empty string to indicate to NOT run a command. // It cannot be nil in the internal state. DownCommand *string `json:"down_command"` - // ListeningPort is the port traffic would be redirected to from the - // forwarded port. The redirection is disabled if it is set to 0, which - // is its default as well. - ListeningPort *uint16 `json:"listening_port"` + // ListeningPorts are the ports traffic would be redirected to from the + // forwarded ports. The redirection is disabled if it is the slice [0], + // which is its default as well. If set and not [0], its length must match + // the PortsCount value, such that each forwarded port is redirected to + // the corresponding listening port. + ListeningPorts []uint16 `json:"listening_port"` + // PortsCount is the number of ports to forward. It is optional for ProtonVPN + // and be between 1 and 5. For other providers, it must be set to 1 if port + // forwarding is enabled. + PortsCount uint16 `json:"ports_count"` // Username is only used for Private Internet Access port forwarding. Username string `json:"username"` // Password is only used for Private Internet Access port forwarding. Password string `json:"password"` } +var ( + ErrPortsCountTooHigh = errors.New("ports count too high") + ErrListeningPortsLen = errors.New("listening ports length must be equal to ports count") + ErrListeningPortZero = errors.New("listening port cannot be 0") +) + func (p PortForwarding) Validate(vpnProvider string) (err error) { if !*p.Enabled { return nil @@ -75,13 +89,36 @@ func (p PortForwarding) Validate(vpnProvider string) (err error) { } } - if providerSelected == providers.PrivateInternetAccess { + switch providerSelected { + case providers.PrivateInternetAccess: + const maxPortsCount = 1 switch { + case p.PortsCount > maxPortsCount: + return fmt.Errorf("%w: %d > %d", ErrPortsCountTooHigh, p.PortsCount, maxPortsCount) case p.Username == "": return fmt.Errorf("%w", ErrPortForwardingUserEmpty) case p.Password == "": return fmt.Errorf("%w", ErrPortForwardingPasswordEmpty) } + case providers.Protonvpn: + const maxPortsCount = 4 + if p.PortsCount > maxPortsCount { + return fmt.Errorf("%w: %d > %d", ErrPortsCountTooHigh, p.PortsCount, maxPortsCount) + } + default: + const maxPortsCount = 1 + if p.PortsCount > maxPortsCount { + return fmt.Errorf("%w: %d > %d", ErrPortsCountTooHigh, p.PortsCount, maxPortsCount) + } + } + + if !slices.Equal(p.ListeningPorts, []uint16{0}) { + switch { + case len(p.ListeningPorts) != int(p.PortsCount): + return fmt.Errorf("%w: %d != %d", ErrListeningPortsLen, len(p.ListeningPorts), p.PortsCount) + case slices.Contains(p.ListeningPorts, 0): + return fmt.Errorf("%w: in %v", ErrListeningPortZero, p.ListeningPorts) + } } return nil @@ -89,14 +126,14 @@ func (p PortForwarding) Validate(vpnProvider string) (err error) { func (p *PortForwarding) Copy() (copied PortForwarding) { return PortForwarding{ - Enabled: gosettings.CopyPointer(p.Enabled), - Provider: gosettings.CopyPointer(p.Provider), - Filepath: gosettings.CopyPointer(p.Filepath), - UpCommand: gosettings.CopyPointer(p.UpCommand), - DownCommand: gosettings.CopyPointer(p.DownCommand), - ListeningPort: gosettings.CopyPointer(p.ListeningPort), - Username: p.Username, - Password: p.Password, + Enabled: gosettings.CopyPointer(p.Enabled), + Provider: gosettings.CopyPointer(p.Provider), + Filepath: gosettings.CopyPointer(p.Filepath), + UpCommand: gosettings.CopyPointer(p.UpCommand), + DownCommand: gosettings.CopyPointer(p.DownCommand), + ListeningPorts: gosettings.CopySlice(p.ListeningPorts), + Username: p.Username, + Password: p.Password, } } @@ -106,7 +143,7 @@ func (p *PortForwarding) OverrideWith(other PortForwarding) { p.Filepath = gosettings.OverrideWithPointer(p.Filepath, other.Filepath) p.UpCommand = gosettings.OverrideWithPointer(p.UpCommand, other.UpCommand) p.DownCommand = gosettings.OverrideWithPointer(p.DownCommand, other.DownCommand) - p.ListeningPort = gosettings.OverrideWithPointer(p.ListeningPort, other.ListeningPort) + p.ListeningPorts = gosettings.OverrideWithSlice(p.ListeningPorts, other.ListeningPorts) p.Username = gosettings.OverrideWithComparable(p.Username, other.Username) p.Password = gosettings.OverrideWithComparable(p.Password, other.Password) } @@ -117,7 +154,8 @@ func (p *PortForwarding) setDefaults() { p.Filepath = gosettings.DefaultPointer(p.Filepath, "/tmp/gluetun/forwarded_port") p.UpCommand = gosettings.DefaultPointer(p.UpCommand, "") p.DownCommand = gosettings.DefaultPointer(p.DownCommand, "") - p.ListeningPort = gosettings.DefaultPointer(p.ListeningPort, 0) + p.ListeningPorts = gosettings.DefaultSlice(p.ListeningPorts, []uint16{0}) // disabled + p.PortsCount = gosettings.DefaultComparable(p.PortsCount, 1) } func (p PortForwarding) String() string { @@ -131,11 +169,14 @@ func (p PortForwarding) toLinesNode() (node *gotree.Node) { node = gotree.New("Automatic port forwarding settings:") - listeningPort := "disabled" - if *p.ListeningPort != 0 { - listeningPort = fmt.Sprintf("%d", *p.ListeningPort) + node.Appendf("Number of ports to be forwarded: %d", p.PortsCount) + + if !slices.Equal(p.ListeningPorts, []uint16{0}) { + redirNode := node.Appendf("Redirection for listening ports:") + for i, port := range p.ListeningPorts { + redirNode.Appendf("Port #%d -> %d", i+1, port) + } } - node.Appendf("Redirection listening port: %s", listeningPort) if *p.Provider == "" { node.Appendf("Use port forwarding code for current provider") @@ -190,7 +231,13 @@ func (p *PortForwarding) read(r *reader.Reader) (err error) { p.DownCommand = r.Get("VPN_PORT_FORWARDING_DOWN_COMMAND", reader.ForceLowercase(false)) - p.ListeningPort, err = r.Uint16Ptr("VPN_PORT_FORWARDING_LISTENING_PORT") + p.ListeningPorts, err = r.CSVUint16("VPN_PORT_FORWARDING_LISTENING_PORTS", + reader.RetroKeys("VPN_PORT_FORWARDING_LISTENING_PORT")) + if err != nil { + return err + } + + p.PortsCount, err = r.Uint16("VPN_PORT_FORWARDING_PORTS_COUNT") if err != nil { return err } diff --git a/internal/portforward/loop.go b/internal/portforward/loop.go index 2e19a4be..e866bea9 100644 --- a/internal/portforward/loop.go +++ b/internal/portforward/loop.go @@ -42,11 +42,12 @@ func NewLoop(settings settings.PortForwarding, routing Routing, settings: Settings{ VPNIsUp: ptrTo(false), Service: service.Settings{ - Enabled: settings.Enabled, - Filepath: *settings.Filepath, - UpCommand: *settings.UpCommand, - DownCommand: *settings.DownCommand, - ListeningPort: *settings.ListeningPort, + Enabled: settings.Enabled, + Filepath: *settings.Filepath, + UpCommand: *settings.UpCommand, + DownCommand: *settings.DownCommand, + ListeningPorts: settings.ListeningPorts, + PortsCount: settings.PortsCount, }, }, routing: routing, diff --git a/internal/portforward/service/interfaces.go b/internal/portforward/service/interfaces.go index f1ea7298..05d09963 100644 --- a/internal/portforward/service/interfaces.go +++ b/internal/portforward/service/interfaces.go @@ -30,7 +30,7 @@ type Logger interface { type PortForwarder interface { Name() string PortForward(ctx context.Context, objects utils.PortForwardObjects) ( - ports []uint16, err error) + internalToExternalPorts map[uint16]uint16, err error) KeepPortForward(ctx context.Context, objects utils.PortForwardObjects) (err error) } diff --git a/internal/portforward/service/service.go b/internal/portforward/service/service.go index e178c16d..06d22770 100644 --- a/internal/portforward/service/service.go +++ b/internal/portforward/service/service.go @@ -69,7 +69,11 @@ func (s *Service) SetPortsForwarded(ctx context.Context, ports []uint16) (err er return fmt.Errorf("cleaning up: %w", err) } - err = s.onNewPorts(ctx, ports) + internalToExternalPorts := make(map[uint16]uint16, len(ports)) + for _, port := range ports { + internalToExternalPorts[port] = port + } + err = s.onNewPorts(ctx, internalToExternalPorts) if err != nil { return fmt.Errorf("handling new ports: %w", err) } diff --git a/internal/portforward/service/settings.go b/internal/portforward/service/settings.go index 2847be51..5ec89ae9 100644 --- a/internal/portforward/service/settings.go +++ b/internal/portforward/service/settings.go @@ -3,6 +3,7 @@ package service import ( "errors" "fmt" + "slices" "github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gosettings" @@ -17,7 +18,8 @@ type Settings struct { Interface string // needed for PIA, PrivateVPN and ProtonVPN, tun0 for example ServerName string // needed for PIA CanPortForward bool // needed for PIA - ListeningPort uint16 + ListeningPorts []uint16 + PortsCount uint16 Username string // needed for PIA Password string // needed for PIA } @@ -31,7 +33,8 @@ func (s Settings) Copy() (copied Settings) { copied.Interface = s.Interface copied.ServerName = s.ServerName copied.CanPortForward = s.CanPortForward - copied.ListeningPort = s.ListeningPort + copied.ListeningPorts = gosettings.CopySlice(s.ListeningPorts) + copied.PortsCount = s.PortsCount copied.Username = s.Username copied.Password = s.Password return copied @@ -46,7 +49,8 @@ func (s *Settings) OverrideWith(update Settings) { s.Interface = gosettings.OverrideWithComparable(s.Interface, update.Interface) s.ServerName = gosettings.OverrideWithComparable(s.ServerName, update.ServerName) s.CanPortForward = gosettings.OverrideWithComparable(s.CanPortForward, update.CanPortForward) - s.ListeningPort = gosettings.OverrideWithComparable(s.ListeningPort, update.ListeningPort) + s.ListeningPorts = gosettings.OverrideWithSlice(s.ListeningPorts, update.ListeningPorts) + s.PortsCount = gosettings.OverrideWithComparable(s.PortsCount, update.PortsCount) s.Username = gosettings.OverrideWithComparable(s.Username, update.Username) s.Password = gosettings.OverrideWithComparable(s.Password, update.Password) } @@ -58,6 +62,10 @@ var ( ErrPasswordNotSet = errors.New("password not set") ErrFilepathNotSet = errors.New("file path not set") ErrInterfaceNotSet = errors.New("interface not set") + ErrPortsCountZero = errors.New("ports count cannot be zero") + ErrPortsCountTooHigh = errors.New("ports count too high") + ErrListeningPortsLen = errors.New("listening ports length must be equal to ports count") + ErrListeningPortZero = errors.New("listening port cannot be 0") ) func (s *Settings) Validate(forStartup bool) (err error) { @@ -78,7 +86,12 @@ func (s *Settings) Validate(forStartup bool) (err error) { return fmt.Errorf("%w", ErrPortForwarderNotSet) case s.Interface == "": return fmt.Errorf("%w", ErrInterfaceNotSet) - case s.PortForwarder.Name() == providers.PrivateInternetAccess: + case s.PortsCount == 0: + return fmt.Errorf("%w", ErrPortsCountZero) + } + + switch s.PortForwarder.Name() { + case providers.PrivateInternetAccess: switch { case s.ServerName == "": return fmt.Errorf("%w", ErrServerNameNotSet) @@ -87,6 +100,26 @@ func (s *Settings) Validate(forStartup bool) (err error) { case s.Password == "": return fmt.Errorf("%w", ErrPasswordNotSet) } + case providers.Protonvpn: + const maxPortsCount = 4 + if s.PortsCount > maxPortsCount { + return fmt.Errorf("%w: %d > %d", ErrPortsCountTooHigh, s.PortsCount, maxPortsCount) + } + default: + const maxPortsCount = 1 + if s.PortsCount > maxPortsCount { + return fmt.Errorf("%w: %d > %d", ErrPortsCountTooHigh, s.PortsCount, maxPortsCount) + } } + + if !slices.Equal(s.ListeningPorts, []uint16{0}) { + switch { + case len(s.ListeningPorts) != int(s.PortsCount): + return fmt.Errorf("%w: %d != %d", ErrListeningPortsLen, len(s.ListeningPorts), s.PortsCount) + case slices.Contains(s.ListeningPorts, 0): + return fmt.Errorf("%w: in %v", ErrListeningPortZero, s.ListeningPorts) + } + } + return nil } diff --git a/internal/portforward/service/start.go b/internal/portforward/service/start.go index 68e54db4..bef256ee 100644 --- a/internal/portforward/service/start.go +++ b/internal/portforward/service/start.go @@ -3,6 +3,7 @@ package service import ( "context" "fmt" + "maps" "slices" "github.com/qdm12/gluetun/internal/netlink" @@ -42,8 +43,9 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error) CanPortForward: s.settings.CanPortForward, Username: s.settings.Username, Password: s.settings.Password, + PortsCount: s.settings.PortsCount, } - ports, err := s.settings.PortForwarder.PortForward(ctx, obj) + internalToExternalPorts, err := s.settings.PortForwarder.PortForward(ctx, obj) if err != nil { return nil, fmt.Errorf("port forwarding for the first time: %w", err) } @@ -51,7 +53,7 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error) s.portMutex.Lock() defer s.portMutex.Unlock() - err = s.onNewPorts(ctx, ports) + err = s.onNewPorts(ctx, internalToExternalPorts) if err != nil { return nil, err } @@ -86,36 +88,60 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error) return runErrorCh, nil } -func (s *Service) onNewPorts(ctx context.Context, ports []uint16) (err error) { - slices.Sort(ports) +func (s *Service) onNewPorts(ctx context.Context, internalToExternalPorts map[uint16]uint16) (err error) { + autoRedirectionNeeded := false + externalToInternalPorts := make(map[uint16]uint16, len(internalToExternalPorts)) + for internal, external := range internalToExternalPorts { + externalToInternalPorts[external] = internal + if internal != external { + autoRedirectionNeeded = true + } + } - s.logger.Info(portsToString(ports)) + externalPorts := slices.Collect(maps.Keys(externalToInternalPorts)) + slices.Sort(externalPorts) - for _, port := range ports { - err = s.portAllower.SetAllowedPort(ctx, port, s.settings.Interface) + s.logger.Info(portsToString(externalPorts)) + + userRedirectionEnabled := !slices.Equal(s.settings.ListeningPorts, []uint16{0}) + for i, port := range externalPorts { + internalPort := externalToInternalPorts[port] + err = s.portAllower.SetAllowedPort(ctx, internalPort, s.settings.Interface) if err != nil { return fmt.Errorf("allowing port in firewall: %w", err) } - if s.settings.ListeningPort != 0 { - err = s.portAllower.RedirectPort(ctx, s.settings.Interface, port, s.settings.ListeningPort) - if err != nil { - return fmt.Errorf("redirecting port in firewall: %w", err) - } + var sourcePort, destinationPort uint16 + switch { + case userRedirectionEnabled: // precedence over auto redirection + sourcePort = externalToInternalPorts[port] + destinationPort = s.settings.ListeningPorts[i] + case autoRedirectionNeeded: + sourcePort = externalToInternalPorts[port] + destinationPort = port + default: + // No redirection needed, source and destination ports are the same. + continue + } + + err = s.portAllower.RedirectPort(ctx, s.settings.Interface, sourcePort, destinationPort) + if err != nil { + return fmt.Errorf("redirecting port %d to %d in firewall: %w", + sourcePort, destinationPort, err) } } - err = s.writePortForwardedFile(ports) + err = s.writePortForwardedFile(externalPorts) if err != nil { _ = s.cleanup() return fmt.Errorf("writing port file: %w", err) } - s.ports = make([]uint16, len(ports)) - copy(s.ports, ports) + s.ports = make([]uint16, len(internalToExternalPorts)) + copy(s.ports, externalPorts) if s.settings.UpCommand != "" { - err = runCommand(ctx, s.cmder, s.logger, s.settings.UpCommand, ports, s.settings.Interface) + err = runCommand(ctx, s.cmder, s.logger, s.settings.UpCommand, externalPorts, s.settings.Interface) if err != nil { err = fmt.Errorf("running up command: %w", err) s.logger.Error(err.Error()) diff --git a/internal/portforward/service/stop.go b/internal/portforward/service/stop.go index e74f904f..fecc03ef 100644 --- a/internal/portforward/service/stop.go +++ b/internal/portforward/service/stop.go @@ -3,6 +3,7 @@ package service import ( "context" "fmt" + "slices" "time" ) @@ -38,13 +39,14 @@ func (s *Service) cleanup() (err error) { } } + redirectionWasEnabled := !slices.Equal(s.settings.ListeningPorts, []uint16{0}) for _, port := range s.ports { err = s.portAllower.RemoveAllowedPort(context.Background(), port) if err != nil { return fmt.Errorf("blocking previous port in firewall: %w", err) } - if s.settings.ListeningPort != 0 { + if redirectionWasEnabled { ctx := context.Background() const listeningPort = 0 // 0 to clear the redirection err = s.portAllower.RedirectPort(ctx, s.settings.Interface, port, listeningPort) diff --git a/internal/provider/perfectprivacy/portforward.go b/internal/provider/perfectprivacy/portforward.go index 9716ab01..365b0fc9 100644 --- a/internal/provider/perfectprivacy/portforward.go +++ b/internal/provider/perfectprivacy/portforward.go @@ -10,12 +10,17 @@ import ( // PortForward calculates and returns the VPN server side ports forwarded. func (p *Provider) PortForward(_ context.Context, objects utils.PortForwardObjects, -) (ports []uint16, err error) { +) (internalToExternalPorts map[uint16]uint16, err error) { if !objects.InternalIP.IsValid() { panic("internal ip is not set") } - return internalIPToPorts(objects.InternalIP), nil + ports := internalIPToPorts(objects.InternalIP) + internalToExternalPorts = make(map[uint16]uint16, len(ports)) + for _, port := range ports { + internalToExternalPorts[port] = port + } + return internalToExternalPorts, nil } func (p *Provider) KeepPortForward(ctx context.Context, diff --git a/internal/provider/privateinternetaccess/portforward.go b/internal/provider/privateinternetaccess/portforward.go index ad0912ec..be2a4b5b 100644 --- a/internal/provider/privateinternetaccess/portforward.go +++ b/internal/provider/privateinternetaccess/portforward.go @@ -26,7 +26,7 @@ var ErrServerNameNotFound = errors.New("server name not found in servers") // PortForward obtains a VPN server side port forwarded from PIA. func (p *Provider) PortForward(ctx context.Context, objects utils.PortForwardObjects, -) (ports []uint16, err error) { +) (internalToExternalPorts map[uint16]uint16, err error) { switch { case objects.ServerName == "": panic("server name cannot be empty") @@ -84,7 +84,7 @@ func (p *Provider) PortForward(ctx context.Context, return nil, fmt.Errorf("binding port: %w", err) } - return []uint16{data.Port}, nil + return map[uint16]uint16{data.Port: data.Port}, nil } var ErrPortForwardedExpired = errors.New("port forwarded data expired") diff --git a/internal/provider/privatevpn/portforward.go b/internal/provider/privatevpn/portforward.go index f68c1d6a..9982dc5a 100644 --- a/internal/provider/privatevpn/portforward.go +++ b/internal/provider/privatevpn/portforward.go @@ -22,7 +22,7 @@ var ErrPortForwardedNotFound = errors.New("port forwarded not found") // PortForward obtains a VPN server side port forwarded from the PrivateVPN API. // It returns 0 if all ports are to forwarded on a dedicated server IP. func (p *Provider) PortForward(ctx context.Context, objects utils.PortForwardObjects) ( - ports []uint16, err error, + internalToExternalPorts map[uint16]uint16, err error, ) { // Define a timeout since the default client has a large timeout and we don't // want to wait too long. @@ -75,7 +75,8 @@ func (p *Provider) PortForward(ctx context.Context, objects utils.PortForwardObj if err != nil { return nil, fmt.Errorf("parsing port: %w", err) } - return []uint16{uint16(portUint64)}, nil + port := uint16(portUint64) + return map[uint16]uint16{port: port}, nil } func (p *Provider) KeepPortForward(ctx context.Context, diff --git a/internal/provider/privatevpn/portforward_test.go b/internal/provider/privatevpn/portforward_test.go index 5f476e33..bb8f69ee 100644 --- a/internal/provider/privatevpn/portforward_test.go +++ b/internal/provider/privatevpn/portforward_test.go @@ -28,10 +28,10 @@ func Test_Provider_PortForward(t *testing.T) { cancel() testCases := map[string]struct { - ctx context.Context - objects utils.PortForwardObjects - ports []uint16 - errMessage string + ctx context.Context + objects utils.PortForwardObjects + internalToExternalPorts map[uint16]uint16 + errMessage string }{ "canceled context": { ctx: canceledCtx, @@ -192,7 +192,7 @@ func Test_Provider_PortForward(t *testing.T) { }), }, }, - ports: []uint16{61527}, + internalToExternalPorts: map[uint16]uint16{61527: 61527}, }, } for name, testCase := range testCases { @@ -203,7 +203,7 @@ func Test_Provider_PortForward(t *testing.T) { ports, err := provider.PortForward(testCase.ctx, testCase.objects) - assert.Equal(t, testCase.ports, ports) + assert.Equal(t, testCase.internalToExternalPorts, ports) if testCase.errMessage != "" { assert.EqualError(t, err, testCase.errMessage) } else { diff --git a/internal/provider/protonvpn/portforward.go b/internal/provider/protonvpn/portforward.go index 68746c82..5479d65a 100644 --- a/internal/provider/protonvpn/portforward.go +++ b/internal/provider/protonvpn/portforward.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "maps" "strings" "time" @@ -13,17 +14,18 @@ import ( var ErrServerPortForwardNotSupported = errors.New("server does not support port forwarding") +const nonSymmetricPortStart uint16 = 56789 + // PortForward obtains a VPN server side port forwarded from ProtonVPN gateway. func (p *Provider) PortForward(ctx context.Context, objects utils.PortForwardObjects) ( - ports []uint16, err error, + internalToExternalPorts map[uint16]uint16, err error, ) { if !objects.CanPortForward { return nil, fmt.Errorf("%w", ErrServerPortForwardNotSupported) } client := natpmp.New() - _, externalIPv4Address, err := client.ExternalAddress(ctx, - objects.Gateway) + _, externalIPv4Address, err := client.ExternalAddress(ctx, objects.Gateway) if err != nil { switch { case strings.HasSuffix(err.Error(), "connection refused"): @@ -38,29 +40,37 @@ func (p *Provider) PortForward(ctx context.Context, objects utils.PortForwardObj logger := objects.Logger - logger.Info("gateway external IPv4 address is " + externalIPv4Address.String()) - const internalPort, externalPort = 0, 1 + logger.Debug("gateway external IPv4 address is " + externalIPv4Address.String()) + const externalPort = 0 const lifetime = 60 * time.Second - _, _, assignedUDPExternalPort, assignedLifetime, err := client.AddPortMapping(ctx, objects.Gateway, "udp", - internalPort, externalPort, lifetime) - if err != nil { - return nil, fmt.Errorf("adding UDP port mapping: %w", err) + p.internalToExternalPorts = make(map[uint16]uint16, objects.PortsCount) + for i := range objects.PortsCount { + internalPort := nonSymmetricPortStart + i + protoToInternalPort := map[string]uint16{ + "udp": 0, + "tcp": 0, + } + protoToExternalPort := maps.Clone(protoToInternalPort) + for protocol := range protoToExternalPort { + _, assignedInternalPort, assignedExternalPort, assignedLifetime, err := client.AddPortMapping( + ctx, objects.Gateway, protocol, internalPort, externalPort, lifetime) + if err != nil { + return nil, fmt.Errorf("adding %d/%d %s port mapping: %w", + i+1, objects.PortsCount, strings.ToUpper(protocol), err) + } + checkLifetime(logger, strings.ToUpper(protocol), lifetime, assignedLifetime) + checkInternalPort(logger, internalPort, assignedInternalPort) + protoToInternalPort[protocol] = assignedInternalPort + protoToExternalPort[protocol] = assignedExternalPort + } + + checkInternalPorts(logger, protoToInternalPort["udp"], protoToInternalPort["tcp"]) + checkExternalPorts(logger, protoToExternalPort["udp"], protoToExternalPort["tcp"]) + p.internalToExternalPorts[protoToInternalPort["tcp"]] = protoToExternalPort["tcp"] } - checkLifetime(logger, "UDP", lifetime, assignedLifetime) - _, _, assignedTCPExternalPort, assignedLifetime, err := client.AddPortMapping(ctx, objects.Gateway, "tcp", - internalPort, externalPort, lifetime) - if err != nil { - return nil, fmt.Errorf("adding TCP port mapping: %w", err) - } - checkLifetime(logger, "TCP", lifetime, assignedLifetime) - - checkExternalPorts(logger, assignedUDPExternalPort, assignedTCPExternalPort) - - p.portForwarded = assignedTCPExternalPort - - return []uint16{assignedTCPExternalPort}, nil + return maps.Clone(p.internalToExternalPorts), nil } func checkLifetime(logger utils.Logger, protocol string, @@ -73,6 +83,20 @@ func checkLifetime(logger utils.Logger, protocol string, } } +func checkInternalPort(logger utils.Logger, sent, received uint16) { + if sent != received { + logger.Warn(fmt.Sprintf("internal port assigned %d differs from requested internal port %d", + sent, received)) + } +} + +func checkInternalPorts(logger utils.Logger, udpPort, tcpPort uint16) { + if udpPort != tcpPort { + logger.Warn(fmt.Sprintf("UDP internal port %d differs from TCP internal port %d", + udpPort, tcpPort)) + } +} + func checkExternalPorts(logger utils.Logger, udpPort, tcpPort uint16) { if udpPort != tcpPort { logger.Warn(fmt.Sprintf("UDP external port %d differs from TCP external port %d", @@ -80,7 +104,10 @@ func checkExternalPorts(logger utils.Logger, udpPort, tcpPort uint16) { } } -var ErrExternalPortChanged = errors.New("external port changed") +var ( + ErrInternalPortChanged = errors.New("internal port changed") + ErrExternalPortChanged = errors.New("external port changed") +) func (p *Provider) KeepPortForward(ctx context.Context, objects utils.PortForwardObjects, @@ -96,32 +123,28 @@ func (p *Provider) KeepPortForward(ctx context.Context, case <-timer.C: } - objects.Logger.Debug("refreshing port forward since 45 seconds have elapsed") - networkProtocols := []string{"udp", "tcp"} - const internalPort = 0 + objects.Logger.Debug("refreshing forwarded ports since 45 seconds have elapsed") + networkProtocols := [...]string{"udp", "tcp"} const lifetime = 60 * time.Second - - for _, networkProtocol := range networkProtocols { - _, _, assignedExternalPort, assignedLiftetime, err := client.AddPortMapping(ctx, objects.Gateway, networkProtocol, - internalPort, p.portForwarded, lifetime) - if err != nil { - return fmt.Errorf("adding port mapping: %w", err) - } - - if assignedLiftetime != lifetime { - logger.Warn(fmt.Sprintf("assigned lifetime %s differs"+ - " from requested lifetime %s", - assignedLiftetime, lifetime)) - } - - if p.portForwarded != assignedExternalPort { - return fmt.Errorf("%w: %d changed to %d", - ErrExternalPortChanged, p.portForwarded, assignedExternalPort) + for internalPort, externalPort := range p.internalToExternalPorts { + for _, networkProtocol := range networkProtocols { + _, assignedInternalPort, assignedExternalPort, assignedLiftetime, err := client.AddPortMapping( + ctx, objects.Gateway, networkProtocol, internalPort, externalPort, lifetime) + if err != nil { + return fmt.Errorf("adding port mapping: %w", err) + } + checkLifetime(logger, networkProtocol, lifetime, assignedLiftetime) + if externalPort != assignedExternalPort { + return fmt.Errorf("%w: %d changed to %d", + ErrExternalPortChanged, externalPort, assignedExternalPort) + } else if internalPort != assignedInternalPort { + return fmt.Errorf("%w: %d (for external port %d) changed to %d", + ErrInternalPortChanged, internalPort, externalPort, assignedInternalPort) + } } + objects.Logger.Debug(fmt.Sprintf("port forwarded %d maintained", externalPort)) } - objects.Logger.Debug(fmt.Sprintf("port forwarded %d maintained", p.portForwarded)) - timer.Reset(refreshTimeout) } } diff --git a/internal/provider/protonvpn/provider.go b/internal/provider/protonvpn/provider.go index 3e646bc2..8af0457f 100644 --- a/internal/provider/protonvpn/provider.go +++ b/internal/provider/protonvpn/provider.go @@ -13,7 +13,7 @@ type Provider struct { storage common.Storage randSource rand.Source common.Fetcher - portForwarded uint16 + internalToExternalPorts map[uint16]uint16 } func New(storage common.Storage, randSource rand.Source, diff --git a/internal/provider/utils/portforward.go b/internal/provider/utils/portforward.go index 6320f3e1..6f03646f 100644 --- a/internal/provider/utils/portforward.go +++ b/internal/provider/utils/portforward.go @@ -25,6 +25,8 @@ type PortForwardObjects struct { Username string // Password is used by Private Internet Access for port forwarding. Password string + // PortsCount is used by ProtonVPN for port forwarding. + PortsCount uint16 } type Routing interface { diff --git a/internal/vpn/interfaces.go b/internal/vpn/interfaces.go index 57d7b0d6..92d34480 100644 --- a/internal/vpn/interfaces.go +++ b/internal/vpn/interfaces.go @@ -52,7 +52,7 @@ type Provider interface { type PortForwarder interface { Name() string PortForward(ctx context.Context, objects utils.PortForwardObjects) ( - ports []uint16, err error) + internalToExternalPorts map[uint16]uint16, err error) KeepPortForward(ctx context.Context, objects utils.PortForwardObjects) (err error) } diff --git a/internal/vpn/portforward.go b/internal/vpn/portforward.go index 5c3ffbb4..584a8225 100644 --- a/internal/vpn/portforward.go +++ b/internal/vpn/portforward.go @@ -62,7 +62,7 @@ func (n *noPortForwarder) Name() string { } func (n *noPortForwarder) PortForward(context.Context, pfutils.PortForwardObjects) ( - ports []uint16, err error, + internalToExternalPorts map[uint16]uint16, err error, ) { return nil, fmt.Errorf("%w: for %s", ErrPortForwardingNotSupported, n.providerName) }