From 724cd3a15e02ada37f30e4caacdb19d75626cec3 Mon Sep 17 00:00:00 2001 From: Rubyn Angelo Stark <54439198+jagaimoworks@users.noreply.github.com> Date: Sat, 7 Mar 2026 18:06:03 +0100 Subject: [PATCH] feat(server): PUT `/v1/portforward` route to set ports forwarded (#2392) --- internal/portforward/interfaces.go | 1 + internal/portforward/loop.go | 11 +++ internal/portforward/service/service.go | 28 ++++++++ internal/portforward/service/start.go | 76 ++++++++++++-------- internal/portforward/service/stop.go | 3 - internal/server/handler.go | 4 +- internal/server/interfaces.go | 3 +- internal/server/middlewares/auth/settings.go | 1 + internal/server/portforward.go | 28 +++++++- internal/server/server.go | 4 +- 10 files changed, 119 insertions(+), 40 deletions(-) diff --git a/internal/portforward/interfaces.go b/internal/portforward/interfaces.go index 93277b48..68da0046 100644 --- a/internal/portforward/interfaces.go +++ b/internal/portforward/interfaces.go @@ -10,6 +10,7 @@ type Service interface { Start(ctx context.Context) (runError <-chan error, err error) Stop() (err error) GetPortsForwarded() (ports []uint16) + SetPortsForwarded(ctx context.Context, ports []uint16) (err error) } type Routing interface { diff --git a/internal/portforward/loop.go b/internal/portforward/loop.go index 8f943145..2e19a4be 100644 --- a/internal/portforward/loop.go +++ b/internal/portforward/loop.go @@ -2,6 +2,7 @@ package portforward import ( "context" + "errors" "fmt" "net/http" "sync" @@ -166,6 +167,16 @@ func (l *Loop) GetPortsForwarded() (ports []uint16) { return l.service.GetPortsForwarded() } +var ErrServiceNotStarted = errors.New("port forwarding service not started") + +func (l *Loop) SetPortsForwarded(ports []uint16) (err error) { + if l.service == nil { + return fmt.Errorf("%w", ErrServiceNotStarted) + } + + return l.service.SetPortsForwarded(l.runCtx, ports) +} + func ptrTo[T any](value T) *T { return &value } diff --git a/internal/portforward/service/service.go b/internal/portforward/service/service.go index 579b3739..e178c16d 100644 --- a/internal/portforward/service/service.go +++ b/internal/portforward/service/service.go @@ -2,7 +2,9 @@ package service import ( "context" + "fmt" "net/http" + "slices" "sync" ) @@ -50,3 +52,29 @@ func (s *Service) GetPortsForwarded() (ports []uint16) { copy(ports, s.ports) return ports } + +func (s *Service) SetPortsForwarded(ctx context.Context, ports []uint16) (err error) { + s.startStopMutex.Lock() + defer s.startStopMutex.Unlock() + s.portMutex.Lock() + defer s.portMutex.Unlock() + + slices.Sort(ports) + if slices.Equal(s.ports, ports) { + return nil + } + + err = s.cleanup() + if err != nil { + return fmt.Errorf("cleaning up: %w", err) + } + + err = s.onNewPorts(ctx, ports) + if err != nil { + return fmt.Errorf("handling new ports: %w", err) + } + + s.logger.Info("updated: " + portsToString(s.ports)) + + return nil +} diff --git a/internal/portforward/service/start.go b/internal/portforward/service/start.go index ff3b5248..68e54db4 100644 --- a/internal/portforward/service/start.go +++ b/internal/portforward/service/start.go @@ -3,6 +3,7 @@ package service import ( "context" "fmt" + "slices" "github.com/qdm12/gluetun/internal/netlink" "github.com/qdm12/gluetun/internal/provider/utils" @@ -47,38 +48,12 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error) return nil, fmt.Errorf("port forwarding for the first time: %w", err) } - s.logger.Info(portsToString(ports)) - - for _, port := range ports { - err = s.portAllower.SetAllowedPort(ctx, port, s.settings.Interface) - if err != nil { - return nil, fmt.Errorf("allowing port in firewall: %w", err) - } - - if s.settings.ListeningPort != 0 { - err = s.portAllower.RedirectPort(ctx, s.settings.Interface, port, s.settings.ListeningPort) - if err != nil { - return nil, fmt.Errorf("redirecting port in firewall: %w", err) - } - } - } - - err = s.writePortForwardedFile(ports) - if err != nil { - _ = s.cleanup() - return nil, fmt.Errorf("writing port file: %w", err) - } - s.portMutex.Lock() - s.ports = ports - s.portMutex.Unlock() + defer s.portMutex.Unlock() - if s.settings.UpCommand != "" { - err = runCommand(ctx, s.cmder, s.logger, s.settings.UpCommand, ports, s.settings.Interface) - if err != nil { - err = fmt.Errorf("running up command: %w", err) - s.logger.Error(err.Error()) - } + err = s.onNewPorts(ctx, ports) + if err != nil { + return nil, err } keepPortCtx, keepPortCancel := context.WithCancel(context.Background()) @@ -101,6 +76,8 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error) } s.startStopMutex.Lock() defer s.startStopMutex.Unlock() + s.portMutex.Lock() + defer s.portMutex.Unlock() _ = s.cleanup() runError <- err }(keepPortCtx, s.settings.PortForwarder, obj, readyCh, runErrorCh, keepPortDoneCh) @@ -108,3 +85,42 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error) return runErrorCh, nil } + +func (s *Service) onNewPorts(ctx context.Context, ports []uint16) (err error) { + slices.Sort(ports) + + s.logger.Info(portsToString(ports)) + + for _, port := range ports { + err = s.portAllower.SetAllowedPort(ctx, port, s.settings.Interface) + if err != nil { + return fmt.Errorf("allowing port in firewall: %w", err) + } + + if s.settings.ListeningPort != 0 { + err = s.portAllower.RedirectPort(ctx, s.settings.Interface, port, s.settings.ListeningPort) + if err != nil { + return fmt.Errorf("redirecting port in firewall: %w", err) + } + } + } + + err = s.writePortForwardedFile(ports) + if err != nil { + _ = s.cleanup() + return fmt.Errorf("writing port file: %w", err) + } + + s.ports = make([]uint16, len(ports)) + copy(s.ports, ports) + + if s.settings.UpCommand != "" { + err = runCommand(ctx, s.cmder, s.logger, s.settings.UpCommand, ports, s.settings.Interface) + if err != nil { + err = fmt.Errorf("running up command: %w", err) + s.logger.Error(err.Error()) + } + } + + return nil +} diff --git a/internal/portforward/service/stop.go b/internal/portforward/service/stop.go index 8e86dfc9..e74f904f 100644 --- a/internal/portforward/service/stop.go +++ b/internal/portforward/service/stop.go @@ -27,9 +27,6 @@ func (s *Service) Stop() (err error) { } func (s *Service) cleanup() (err error) { - s.portMutex.Lock() - defer s.portMutex.Unlock() - if s.settings.DownCommand != "" { const downTimeout = 60 * time.Second ctx, cancel := context.WithTimeout(context.Background(), downTimeout) diff --git a/internal/server/handler.go b/internal/server/handler.go index b276940f..888d7c74 100644 --- a/internal/server/handler.go +++ b/internal/server/handler.go @@ -15,7 +15,7 @@ func newHandler(ctx context.Context, logger Logger, logging bool, authSettings auth.Settings, buildInfo models.BuildInformation, vpnLooper VPNLooper, - pfGetter PortForwardedGetter, + pf PortForwarding, dnsLooper DNSLoop, updaterLooper UpdaterLooper, publicIPLooper PublicIPLoop, @@ -29,7 +29,7 @@ func newHandler(ctx context.Context, logger Logger, logging bool, dns := newDNSHandler(ctx, dnsLooper, logger) updater := newUpdaterHandler(ctx, updaterLooper, logger) publicip := newPublicIPHandler(publicIPLooper, logger) - portForward := newPortForwardHandler(ctx, pfGetter, logger) + portForward := newPortForwardHandler(ctx, pf, logger) handler.v0 = newHandlerV0(ctx, logger, vpnLooper, dnsLooper, updaterLooper) handler.v1 = newHandlerV1(logger, buildInfo, vpn, openvpn, dns, updater, publicip, portForward) diff --git a/internal/server/interfaces.go b/internal/server/interfaces.go index 5b470251..b49d2837 100644 --- a/internal/server/interfaces.go +++ b/internal/server/interfaces.go @@ -21,8 +21,9 @@ type DNSLoop interface { GetStatus() (status models.LoopStatus) } -type PortForwardedGetter interface { +type PortForwarding interface { GetPortsForwarded() (ports []uint16) + SetPortsForwarded(ports []uint16) (err error) } type PublicIPLoop interface { diff --git a/internal/server/middlewares/auth/settings.go b/internal/server/middlewares/auth/settings.go index 75b48fde..c854e790 100644 --- a/internal/server/middlewares/auth/settings.go +++ b/internal/server/middlewares/auth/settings.go @@ -156,6 +156,7 @@ var validRoutes = map[string]struct{}{ //nolint:gochecknoglobals http.MethodPut + " /v1/updater/status": {}, http.MethodGet + " /v1/publicip/ip": {}, http.MethodGet + " /v1/portforward": {}, + http.MethodPut + " /v1/portforward": {}, } func (r Role) ToLinesNode() (node *gotree.Node) { diff --git a/internal/server/portforward.go b/internal/server/portforward.go index b751efa0..d2f7b392 100644 --- a/internal/server/portforward.go +++ b/internal/server/portforward.go @@ -3,11 +3,12 @@ package server import ( "context" "encoding/json" + "fmt" "net/http" ) func newPortForwardHandler(ctx context.Context, - portForward PortForwardedGetter, warner warner, + portForward PortForwarding, warner warner, ) http.Handler { return &portForwardHandler{ ctx: ctx, @@ -18,7 +19,7 @@ func newPortForwardHandler(ctx context.Context, type portForwardHandler struct { ctx context.Context //nolint:containedctx - portForward PortForwardedGetter + portForward PortForwarding warner warner } @@ -26,6 +27,8 @@ func (h *portForwardHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: h.getPortForwarded(w) + case http.MethodPut: + h.setPortForwarded(w, r) default: errMethodNotSupported(w, r.Method) } @@ -50,3 +53,24 @@ func (h *portForwardHandler) getPortForwarded(w http.ResponseWriter) { w.WriteHeader(http.StatusInternalServerError) } } + +func (h *portForwardHandler) setPortForwarded(w http.ResponseWriter, r *http.Request) { + var data portsWrapper + + decoder := json.NewDecoder(r.Body) + err := decoder.Decode(&data) + if err != nil { + h.warner.Warn(fmt.Sprintf("failed setting forwarded ports: %s", err)) + http.Error(w, "failed setting forwarded ports", http.StatusBadRequest) + return + } + + err = h.portForward.SetPortsForwarded(data.Ports) + if err != nil { + h.warner.Warn(fmt.Sprintf("failed setting forwarded ports: %s", err)) + http.Error(w, "failed setting forwarded ports", http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusOK) +} diff --git a/internal/server/server.go b/internal/server/server.go index 55b41548..f3629678 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -14,7 +14,7 @@ import ( func New(ctx context.Context, settings settings.ControlServer, logger Logger, buildInfo models.BuildInformation, openvpnLooper VPNLooper, - pfGetter PortForwardedGetter, dnsLooper DNSLoop, + pf PortForwarding, dnsLooper DNSLoop, updaterLooper UpdaterLooper, publicIPLooper PublicIPLoop, storage Storage, ipv6Supported bool) ( server *httpserver.Server, err error, @@ -25,7 +25,7 @@ func New(ctx context.Context, settings settings.ControlServer, logger Logger, } handler, err := newHandler(ctx, logger, *settings.Log, authSettings, buildInfo, - openvpnLooper, pfGetter, dnsLooper, updaterLooper, publicIPLooper, + openvpnLooper, pf, dnsLooper, updaterLooper, publicIPLooper, storage, ipv6Supported) if err != nil { return nil, fmt.Errorf("creating handler: %w", err)