feat(server): PUT /v1/portforward route to set ports forwarded (#2392)

This commit is contained in:
Rubyn Angelo Stark
2026-03-07 18:06:03 +01:00
committed by Quentin McGaw
parent 199ad77ec9
commit 724cd3a15e
10 changed files with 119 additions and 40 deletions
+1
View File
@@ -10,6 +10,7 @@ type Service interface {
Start(ctx context.Context) (runError <-chan error, err error) Start(ctx context.Context) (runError <-chan error, err error)
Stop() (err error) Stop() (err error)
GetPortsForwarded() (ports []uint16) GetPortsForwarded() (ports []uint16)
SetPortsForwarded(ctx context.Context, ports []uint16) (err error)
} }
type Routing interface { type Routing interface {
+11
View File
@@ -2,6 +2,7 @@ package portforward
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"sync" "sync"
@@ -166,6 +167,16 @@ func (l *Loop) GetPortsForwarded() (ports []uint16) {
return l.service.GetPortsForwarded() return l.service.GetPortsForwarded()
} }
var ErrServiceNotStarted = errors.New("port forwarding service not started")
func (l *Loop) SetPortsForwarded(ports []uint16) (err error) {
if l.service == nil {
return fmt.Errorf("%w", ErrServiceNotStarted)
}
return l.service.SetPortsForwarded(l.runCtx, ports)
}
func ptrTo[T any](value T) *T { func ptrTo[T any](value T) *T {
return &value return &value
} }
+28
View File
@@ -2,7 +2,9 @@ package service
import ( import (
"context" "context"
"fmt"
"net/http" "net/http"
"slices"
"sync" "sync"
) )
@@ -50,3 +52,29 @@ func (s *Service) GetPortsForwarded() (ports []uint16) {
copy(ports, s.ports) copy(ports, s.ports)
return ports return ports
} }
func (s *Service) SetPortsForwarded(ctx context.Context, ports []uint16) (err error) {
s.startStopMutex.Lock()
defer s.startStopMutex.Unlock()
s.portMutex.Lock()
defer s.portMutex.Unlock()
slices.Sort(ports)
if slices.Equal(s.ports, ports) {
return nil
}
err = s.cleanup()
if err != nil {
return fmt.Errorf("cleaning up: %w", err)
}
err = s.onNewPorts(ctx, ports)
if err != nil {
return fmt.Errorf("handling new ports: %w", err)
}
s.logger.Info("updated: " + portsToString(s.ports))
return nil
}
+46 -30
View File
@@ -3,6 +3,7 @@ package service
import ( import (
"context" "context"
"fmt" "fmt"
"slices"
"github.com/qdm12/gluetun/internal/netlink" "github.com/qdm12/gluetun/internal/netlink"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
@@ -47,38 +48,12 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error)
return nil, fmt.Errorf("port forwarding for the first time: %w", err) return nil, fmt.Errorf("port forwarding for the first time: %w", err)
} }
s.logger.Info(portsToString(ports))
for _, port := range ports {
err = s.portAllower.SetAllowedPort(ctx, port, s.settings.Interface)
if err != nil {
return nil, fmt.Errorf("allowing port in firewall: %w", err)
}
if s.settings.ListeningPort != 0 {
err = s.portAllower.RedirectPort(ctx, s.settings.Interface, port, s.settings.ListeningPort)
if err != nil {
return nil, fmt.Errorf("redirecting port in firewall: %w", err)
}
}
}
err = s.writePortForwardedFile(ports)
if err != nil {
_ = s.cleanup()
return nil, fmt.Errorf("writing port file: %w", err)
}
s.portMutex.Lock() s.portMutex.Lock()
s.ports = ports defer s.portMutex.Unlock()
s.portMutex.Unlock()
if s.settings.UpCommand != "" { err = s.onNewPorts(ctx, ports)
err = runCommand(ctx, s.cmder, s.logger, s.settings.UpCommand, ports, s.settings.Interface) if err != nil {
if err != nil { return nil, err
err = fmt.Errorf("running up command: %w", err)
s.logger.Error(err.Error())
}
} }
keepPortCtx, keepPortCancel := context.WithCancel(context.Background()) keepPortCtx, keepPortCancel := context.WithCancel(context.Background())
@@ -101,6 +76,8 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error)
} }
s.startStopMutex.Lock() s.startStopMutex.Lock()
defer s.startStopMutex.Unlock() defer s.startStopMutex.Unlock()
s.portMutex.Lock()
defer s.portMutex.Unlock()
_ = s.cleanup() _ = s.cleanup()
runError <- err runError <- err
}(keepPortCtx, s.settings.PortForwarder, obj, readyCh, runErrorCh, keepPortDoneCh) }(keepPortCtx, s.settings.PortForwarder, obj, readyCh, runErrorCh, keepPortDoneCh)
@@ -108,3 +85,42 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error)
return runErrorCh, nil return runErrorCh, nil
} }
func (s *Service) onNewPorts(ctx context.Context, ports []uint16) (err error) {
slices.Sort(ports)
s.logger.Info(portsToString(ports))
for _, port := range ports {
err = s.portAllower.SetAllowedPort(ctx, port, s.settings.Interface)
if err != nil {
return fmt.Errorf("allowing port in firewall: %w", err)
}
if s.settings.ListeningPort != 0 {
err = s.portAllower.RedirectPort(ctx, s.settings.Interface, port, s.settings.ListeningPort)
if err != nil {
return fmt.Errorf("redirecting port in firewall: %w", err)
}
}
}
err = s.writePortForwardedFile(ports)
if err != nil {
_ = s.cleanup()
return fmt.Errorf("writing port file: %w", err)
}
s.ports = make([]uint16, len(ports))
copy(s.ports, ports)
if s.settings.UpCommand != "" {
err = runCommand(ctx, s.cmder, s.logger, s.settings.UpCommand, ports, s.settings.Interface)
if err != nil {
err = fmt.Errorf("running up command: %w", err)
s.logger.Error(err.Error())
}
}
return nil
}
-3
View File
@@ -27,9 +27,6 @@ func (s *Service) Stop() (err error) {
} }
func (s *Service) cleanup() (err error) { func (s *Service) cleanup() (err error) {
s.portMutex.Lock()
defer s.portMutex.Unlock()
if s.settings.DownCommand != "" { if s.settings.DownCommand != "" {
const downTimeout = 60 * time.Second const downTimeout = 60 * time.Second
ctx, cancel := context.WithTimeout(context.Background(), downTimeout) ctx, cancel := context.WithTimeout(context.Background(), downTimeout)
+2 -2
View File
@@ -15,7 +15,7 @@ func newHandler(ctx context.Context, logger Logger, logging bool,
authSettings auth.Settings, authSettings auth.Settings,
buildInfo models.BuildInformation, buildInfo models.BuildInformation,
vpnLooper VPNLooper, vpnLooper VPNLooper,
pfGetter PortForwardedGetter, pf PortForwarding,
dnsLooper DNSLoop, dnsLooper DNSLoop,
updaterLooper UpdaterLooper, updaterLooper UpdaterLooper,
publicIPLooper PublicIPLoop, publicIPLooper PublicIPLoop,
@@ -29,7 +29,7 @@ func newHandler(ctx context.Context, logger Logger, logging bool,
dns := newDNSHandler(ctx, dnsLooper, logger) dns := newDNSHandler(ctx, dnsLooper, logger)
updater := newUpdaterHandler(ctx, updaterLooper, logger) updater := newUpdaterHandler(ctx, updaterLooper, logger)
publicip := newPublicIPHandler(publicIPLooper, logger) publicip := newPublicIPHandler(publicIPLooper, logger)
portForward := newPortForwardHandler(ctx, pfGetter, logger) portForward := newPortForwardHandler(ctx, pf, logger)
handler.v0 = newHandlerV0(ctx, logger, vpnLooper, dnsLooper, updaterLooper) handler.v0 = newHandlerV0(ctx, logger, vpnLooper, dnsLooper, updaterLooper)
handler.v1 = newHandlerV1(logger, buildInfo, vpn, openvpn, dns, updater, publicip, portForward) handler.v1 = newHandlerV1(logger, buildInfo, vpn, openvpn, dns, updater, publicip, portForward)
+2 -1
View File
@@ -21,8 +21,9 @@ type DNSLoop interface {
GetStatus() (status models.LoopStatus) GetStatus() (status models.LoopStatus)
} }
type PortForwardedGetter interface { type PortForwarding interface {
GetPortsForwarded() (ports []uint16) GetPortsForwarded() (ports []uint16)
SetPortsForwarded(ports []uint16) (err error)
} }
type PublicIPLoop interface { type PublicIPLoop interface {
@@ -156,6 +156,7 @@ var validRoutes = map[string]struct{}{ //nolint:gochecknoglobals
http.MethodPut + " /v1/updater/status": {}, http.MethodPut + " /v1/updater/status": {},
http.MethodGet + " /v1/publicip/ip": {}, http.MethodGet + " /v1/publicip/ip": {},
http.MethodGet + " /v1/portforward": {}, http.MethodGet + " /v1/portforward": {},
http.MethodPut + " /v1/portforward": {},
} }
func (r Role) ToLinesNode() (node *gotree.Node) { func (r Role) ToLinesNode() (node *gotree.Node) {
+26 -2
View File
@@ -3,11 +3,12 @@ package server
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"net/http" "net/http"
) )
func newPortForwardHandler(ctx context.Context, func newPortForwardHandler(ctx context.Context,
portForward PortForwardedGetter, warner warner, portForward PortForwarding, warner warner,
) http.Handler { ) http.Handler {
return &portForwardHandler{ return &portForwardHandler{
ctx: ctx, ctx: ctx,
@@ -18,7 +19,7 @@ func newPortForwardHandler(ctx context.Context,
type portForwardHandler struct { type portForwardHandler struct {
ctx context.Context //nolint:containedctx ctx context.Context //nolint:containedctx
portForward PortForwardedGetter portForward PortForwarding
warner warner warner warner
} }
@@ -26,6 +27,8 @@ func (h *portForwardHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
switch r.Method { switch r.Method {
case http.MethodGet: case http.MethodGet:
h.getPortForwarded(w) h.getPortForwarded(w)
case http.MethodPut:
h.setPortForwarded(w, r)
default: default:
errMethodNotSupported(w, r.Method) errMethodNotSupported(w, r.Method)
} }
@@ -50,3 +53,24 @@ func (h *portForwardHandler) getPortForwarded(w http.ResponseWriter) {
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
} }
} }
func (h *portForwardHandler) setPortForwarded(w http.ResponseWriter, r *http.Request) {
var data portsWrapper
decoder := json.NewDecoder(r.Body)
err := decoder.Decode(&data)
if err != nil {
h.warner.Warn(fmt.Sprintf("failed setting forwarded ports: %s", err))
http.Error(w, "failed setting forwarded ports", http.StatusBadRequest)
return
}
err = h.portForward.SetPortsForwarded(data.Ports)
if err != nil {
h.warner.Warn(fmt.Sprintf("failed setting forwarded ports: %s", err))
http.Error(w, "failed setting forwarded ports", http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
}
+2 -2
View File
@@ -14,7 +14,7 @@ import (
func New(ctx context.Context, settings settings.ControlServer, logger Logger, func New(ctx context.Context, settings settings.ControlServer, logger Logger,
buildInfo models.BuildInformation, openvpnLooper VPNLooper, buildInfo models.BuildInformation, openvpnLooper VPNLooper,
pfGetter PortForwardedGetter, dnsLooper DNSLoop, pf PortForwarding, dnsLooper DNSLoop,
updaterLooper UpdaterLooper, publicIPLooper PublicIPLoop, storage Storage, updaterLooper UpdaterLooper, publicIPLooper PublicIPLoop, storage Storage,
ipv6Supported bool) ( ipv6Supported bool) (
server *httpserver.Server, err error, server *httpserver.Server, err error,
@@ -25,7 +25,7 @@ func New(ctx context.Context, settings settings.ControlServer, logger Logger,
} }
handler, err := newHandler(ctx, logger, *settings.Log, authSettings, buildInfo, handler, err := newHandler(ctx, logger, *settings.Log, authSettings, buildInfo,
openvpnLooper, pfGetter, dnsLooper, updaterLooper, publicIPLooper, openvpnLooper, pf, dnsLooper, updaterLooper, publicIPLooper,
storage, ipv6Supported) storage, ipv6Supported)
if err != nil { if err != nil {
return nil, fmt.Errorf("creating handler: %w", err) return nil, fmt.Errorf("creating handler: %w", err)