mirror of
https://github.com/qdm12/gluetun.git
synced 2026-05-07 04:20:12 +02:00
feat(protonvpn): support up to 5 forwarded ports (#3208)
This commit is contained in:
@@ -42,11 +42,12 @@ func NewLoop(settings settings.PortForwarding, routing Routing,
|
||||
settings: Settings{
|
||||
VPNIsUp: ptrTo(false),
|
||||
Service: service.Settings{
|
||||
Enabled: settings.Enabled,
|
||||
Filepath: *settings.Filepath,
|
||||
UpCommand: *settings.UpCommand,
|
||||
DownCommand: *settings.DownCommand,
|
||||
ListeningPort: *settings.ListeningPort,
|
||||
Enabled: settings.Enabled,
|
||||
Filepath: *settings.Filepath,
|
||||
UpCommand: *settings.UpCommand,
|
||||
DownCommand: *settings.DownCommand,
|
||||
ListeningPorts: settings.ListeningPorts,
|
||||
PortsCount: settings.PortsCount,
|
||||
},
|
||||
},
|
||||
routing: routing,
|
||||
|
||||
@@ -30,7 +30,7 @@ type Logger interface {
|
||||
type PortForwarder interface {
|
||||
Name() string
|
||||
PortForward(ctx context.Context, objects utils.PortForwardObjects) (
|
||||
ports []uint16, err error)
|
||||
internalToExternalPorts map[uint16]uint16, err error)
|
||||
KeepPortForward(ctx context.Context, objects utils.PortForwardObjects) (err error)
|
||||
}
|
||||
|
||||
|
||||
@@ -69,7 +69,11 @@ func (s *Service) SetPortsForwarded(ctx context.Context, ports []uint16) (err er
|
||||
return fmt.Errorf("cleaning up: %w", err)
|
||||
}
|
||||
|
||||
err = s.onNewPorts(ctx, ports)
|
||||
internalToExternalPorts := make(map[uint16]uint16, len(ports))
|
||||
for _, port := range ports {
|
||||
internalToExternalPorts[port] = port
|
||||
}
|
||||
err = s.onNewPorts(ctx, internalToExternalPorts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("handling new ports: %w", err)
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package service
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gosettings"
|
||||
@@ -17,7 +18,8 @@ type Settings struct {
|
||||
Interface string // needed for PIA, PrivateVPN and ProtonVPN, tun0 for example
|
||||
ServerName string // needed for PIA
|
||||
CanPortForward bool // needed for PIA
|
||||
ListeningPort uint16
|
||||
ListeningPorts []uint16
|
||||
PortsCount uint16
|
||||
Username string // needed for PIA
|
||||
Password string // needed for PIA
|
||||
}
|
||||
@@ -31,7 +33,8 @@ func (s Settings) Copy() (copied Settings) {
|
||||
copied.Interface = s.Interface
|
||||
copied.ServerName = s.ServerName
|
||||
copied.CanPortForward = s.CanPortForward
|
||||
copied.ListeningPort = s.ListeningPort
|
||||
copied.ListeningPorts = gosettings.CopySlice(s.ListeningPorts)
|
||||
copied.PortsCount = s.PortsCount
|
||||
copied.Username = s.Username
|
||||
copied.Password = s.Password
|
||||
return copied
|
||||
@@ -46,7 +49,8 @@ func (s *Settings) OverrideWith(update Settings) {
|
||||
s.Interface = gosettings.OverrideWithComparable(s.Interface, update.Interface)
|
||||
s.ServerName = gosettings.OverrideWithComparable(s.ServerName, update.ServerName)
|
||||
s.CanPortForward = gosettings.OverrideWithComparable(s.CanPortForward, update.CanPortForward)
|
||||
s.ListeningPort = gosettings.OverrideWithComparable(s.ListeningPort, update.ListeningPort)
|
||||
s.ListeningPorts = gosettings.OverrideWithSlice(s.ListeningPorts, update.ListeningPorts)
|
||||
s.PortsCount = gosettings.OverrideWithComparable(s.PortsCount, update.PortsCount)
|
||||
s.Username = gosettings.OverrideWithComparable(s.Username, update.Username)
|
||||
s.Password = gosettings.OverrideWithComparable(s.Password, update.Password)
|
||||
}
|
||||
@@ -58,6 +62,10 @@ var (
|
||||
ErrPasswordNotSet = errors.New("password not set")
|
||||
ErrFilepathNotSet = errors.New("file path not set")
|
||||
ErrInterfaceNotSet = errors.New("interface not set")
|
||||
ErrPortsCountZero = errors.New("ports count cannot be zero")
|
||||
ErrPortsCountTooHigh = errors.New("ports count too high")
|
||||
ErrListeningPortsLen = errors.New("listening ports length must be equal to ports count")
|
||||
ErrListeningPortZero = errors.New("listening port cannot be 0")
|
||||
)
|
||||
|
||||
func (s *Settings) Validate(forStartup bool) (err error) {
|
||||
@@ -78,7 +86,12 @@ func (s *Settings) Validate(forStartup bool) (err error) {
|
||||
return fmt.Errorf("%w", ErrPortForwarderNotSet)
|
||||
case s.Interface == "":
|
||||
return fmt.Errorf("%w", ErrInterfaceNotSet)
|
||||
case s.PortForwarder.Name() == providers.PrivateInternetAccess:
|
||||
case s.PortsCount == 0:
|
||||
return fmt.Errorf("%w", ErrPortsCountZero)
|
||||
}
|
||||
|
||||
switch s.PortForwarder.Name() {
|
||||
case providers.PrivateInternetAccess:
|
||||
switch {
|
||||
case s.ServerName == "":
|
||||
return fmt.Errorf("%w", ErrServerNameNotSet)
|
||||
@@ -87,6 +100,26 @@ func (s *Settings) Validate(forStartup bool) (err error) {
|
||||
case s.Password == "":
|
||||
return fmt.Errorf("%w", ErrPasswordNotSet)
|
||||
}
|
||||
case providers.Protonvpn:
|
||||
const maxPortsCount = 4
|
||||
if s.PortsCount > maxPortsCount {
|
||||
return fmt.Errorf("%w: %d > %d", ErrPortsCountTooHigh, s.PortsCount, maxPortsCount)
|
||||
}
|
||||
default:
|
||||
const maxPortsCount = 1
|
||||
if s.PortsCount > maxPortsCount {
|
||||
return fmt.Errorf("%w: %d > %d", ErrPortsCountTooHigh, s.PortsCount, maxPortsCount)
|
||||
}
|
||||
}
|
||||
|
||||
if !slices.Equal(s.ListeningPorts, []uint16{0}) {
|
||||
switch {
|
||||
case len(s.ListeningPorts) != int(s.PortsCount):
|
||||
return fmt.Errorf("%w: %d != %d", ErrListeningPortsLen, len(s.ListeningPorts), s.PortsCount)
|
||||
case slices.Contains(s.ListeningPorts, 0):
|
||||
return fmt.Errorf("%w: in %v", ErrListeningPortZero, s.ListeningPorts)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"maps"
|
||||
"slices"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/netlink"
|
||||
@@ -42,8 +43,9 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error)
|
||||
CanPortForward: s.settings.CanPortForward,
|
||||
Username: s.settings.Username,
|
||||
Password: s.settings.Password,
|
||||
PortsCount: s.settings.PortsCount,
|
||||
}
|
||||
ports, err := s.settings.PortForwarder.PortForward(ctx, obj)
|
||||
internalToExternalPorts, err := s.settings.PortForwarder.PortForward(ctx, obj)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("port forwarding for the first time: %w", err)
|
||||
}
|
||||
@@ -51,7 +53,7 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error)
|
||||
s.portMutex.Lock()
|
||||
defer s.portMutex.Unlock()
|
||||
|
||||
err = s.onNewPorts(ctx, ports)
|
||||
err = s.onNewPorts(ctx, internalToExternalPorts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -86,36 +88,60 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error)
|
||||
return runErrorCh, nil
|
||||
}
|
||||
|
||||
func (s *Service) onNewPorts(ctx context.Context, ports []uint16) (err error) {
|
||||
slices.Sort(ports)
|
||||
func (s *Service) onNewPorts(ctx context.Context, internalToExternalPorts map[uint16]uint16) (err error) {
|
||||
autoRedirectionNeeded := false
|
||||
externalToInternalPorts := make(map[uint16]uint16, len(internalToExternalPorts))
|
||||
for internal, external := range internalToExternalPorts {
|
||||
externalToInternalPorts[external] = internal
|
||||
if internal != external {
|
||||
autoRedirectionNeeded = true
|
||||
}
|
||||
}
|
||||
|
||||
s.logger.Info(portsToString(ports))
|
||||
externalPorts := slices.Collect(maps.Keys(externalToInternalPorts))
|
||||
slices.Sort(externalPorts)
|
||||
|
||||
for _, port := range ports {
|
||||
err = s.portAllower.SetAllowedPort(ctx, port, s.settings.Interface)
|
||||
s.logger.Info(portsToString(externalPorts))
|
||||
|
||||
userRedirectionEnabled := !slices.Equal(s.settings.ListeningPorts, []uint16{0})
|
||||
for i, port := range externalPorts {
|
||||
internalPort := externalToInternalPorts[port]
|
||||
err = s.portAllower.SetAllowedPort(ctx, internalPort, s.settings.Interface)
|
||||
if err != nil {
|
||||
return fmt.Errorf("allowing port in firewall: %w", err)
|
||||
}
|
||||
|
||||
if s.settings.ListeningPort != 0 {
|
||||
err = s.portAllower.RedirectPort(ctx, s.settings.Interface, port, s.settings.ListeningPort)
|
||||
if err != nil {
|
||||
return fmt.Errorf("redirecting port in firewall: %w", err)
|
||||
}
|
||||
var sourcePort, destinationPort uint16
|
||||
switch {
|
||||
case userRedirectionEnabled: // precedence over auto redirection
|
||||
sourcePort = externalToInternalPorts[port]
|
||||
destinationPort = s.settings.ListeningPorts[i]
|
||||
case autoRedirectionNeeded:
|
||||
sourcePort = externalToInternalPorts[port]
|
||||
destinationPort = port
|
||||
default:
|
||||
// No redirection needed, source and destination ports are the same.
|
||||
continue
|
||||
}
|
||||
|
||||
err = s.portAllower.RedirectPort(ctx, s.settings.Interface, sourcePort, destinationPort)
|
||||
if err != nil {
|
||||
return fmt.Errorf("redirecting port %d to %d in firewall: %w",
|
||||
sourcePort, destinationPort, err)
|
||||
}
|
||||
}
|
||||
|
||||
err = s.writePortForwardedFile(ports)
|
||||
err = s.writePortForwardedFile(externalPorts)
|
||||
if err != nil {
|
||||
_ = s.cleanup()
|
||||
return fmt.Errorf("writing port file: %w", err)
|
||||
}
|
||||
|
||||
s.ports = make([]uint16, len(ports))
|
||||
copy(s.ports, ports)
|
||||
s.ports = make([]uint16, len(internalToExternalPorts))
|
||||
copy(s.ports, externalPorts)
|
||||
|
||||
if s.settings.UpCommand != "" {
|
||||
err = runCommand(ctx, s.cmder, s.logger, s.settings.UpCommand, ports, s.settings.Interface)
|
||||
err = runCommand(ctx, s.cmder, s.logger, s.settings.UpCommand, externalPorts, s.settings.Interface)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("running up command: %w", err)
|
||||
s.logger.Error(err.Error())
|
||||
|
||||
@@ -3,6 +3,7 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -38,13 +39,14 @@ func (s *Service) cleanup() (err error) {
|
||||
}
|
||||
}
|
||||
|
||||
redirectionWasEnabled := !slices.Equal(s.settings.ListeningPorts, []uint16{0})
|
||||
for _, port := range s.ports {
|
||||
err = s.portAllower.RemoveAllowedPort(context.Background(), port)
|
||||
if err != nil {
|
||||
return fmt.Errorf("blocking previous port in firewall: %w", err)
|
||||
}
|
||||
|
||||
if s.settings.ListeningPort != 0 {
|
||||
if redirectionWasEnabled {
|
||||
ctx := context.Background()
|
||||
const listeningPort = 0 // 0 to clear the redirection
|
||||
err = s.portAllower.RedirectPort(ctx, s.settings.Interface, port, listeningPort)
|
||||
|
||||
Reference in New Issue
Block a user