mirror of
https://github.com/qdm12/gluetun.git
synced 2026-05-06 20:10:11 +02:00
feat(server): PUT /v1/portforward route to set ports forwarded (#2392)
This commit is contained in:
committed by
Quentin McGaw
parent
199ad77ec9
commit
724cd3a15e
@@ -10,6 +10,7 @@ type Service interface {
|
||||
Start(ctx context.Context) (runError <-chan error, err error)
|
||||
Stop() (err error)
|
||||
GetPortsForwarded() (ports []uint16)
|
||||
SetPortsForwarded(ctx context.Context, ports []uint16) (err error)
|
||||
}
|
||||
|
||||
type Routing interface {
|
||||
|
||||
@@ -2,6 +2,7 @@ package portforward
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
@@ -166,6 +167,16 @@ func (l *Loop) GetPortsForwarded() (ports []uint16) {
|
||||
return l.service.GetPortsForwarded()
|
||||
}
|
||||
|
||||
var ErrServiceNotStarted = errors.New("port forwarding service not started")
|
||||
|
||||
func (l *Loop) SetPortsForwarded(ports []uint16) (err error) {
|
||||
if l.service == nil {
|
||||
return fmt.Errorf("%w", ErrServiceNotStarted)
|
||||
}
|
||||
|
||||
return l.service.SetPortsForwarded(l.runCtx, ports)
|
||||
}
|
||||
|
||||
func ptrTo[T any](value T) *T {
|
||||
return &value
|
||||
}
|
||||
|
||||
@@ -2,7 +2,9 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"slices"
|
||||
"sync"
|
||||
)
|
||||
|
||||
@@ -50,3 +52,29 @@ func (s *Service) GetPortsForwarded() (ports []uint16) {
|
||||
copy(ports, s.ports)
|
||||
return ports
|
||||
}
|
||||
|
||||
func (s *Service) SetPortsForwarded(ctx context.Context, ports []uint16) (err error) {
|
||||
s.startStopMutex.Lock()
|
||||
defer s.startStopMutex.Unlock()
|
||||
s.portMutex.Lock()
|
||||
defer s.portMutex.Unlock()
|
||||
|
||||
slices.Sort(ports)
|
||||
if slices.Equal(s.ports, ports) {
|
||||
return nil
|
||||
}
|
||||
|
||||
err = s.cleanup()
|
||||
if err != nil {
|
||||
return fmt.Errorf("cleaning up: %w", err)
|
||||
}
|
||||
|
||||
err = s.onNewPorts(ctx, ports)
|
||||
if err != nil {
|
||||
return fmt.Errorf("handling new ports: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Info("updated: " + portsToString(s.ports))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/netlink"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
@@ -47,38 +48,12 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error)
|
||||
return nil, fmt.Errorf("port forwarding for the first time: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Info(portsToString(ports))
|
||||
|
||||
for _, port := range ports {
|
||||
err = s.portAllower.SetAllowedPort(ctx, port, s.settings.Interface)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("allowing port in firewall: %w", err)
|
||||
}
|
||||
|
||||
if s.settings.ListeningPort != 0 {
|
||||
err = s.portAllower.RedirectPort(ctx, s.settings.Interface, port, s.settings.ListeningPort)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("redirecting port in firewall: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = s.writePortForwardedFile(ports)
|
||||
if err != nil {
|
||||
_ = s.cleanup()
|
||||
return nil, fmt.Errorf("writing port file: %w", err)
|
||||
}
|
||||
|
||||
s.portMutex.Lock()
|
||||
s.ports = ports
|
||||
s.portMutex.Unlock()
|
||||
defer s.portMutex.Unlock()
|
||||
|
||||
if s.settings.UpCommand != "" {
|
||||
err = runCommand(ctx, s.cmder, s.logger, s.settings.UpCommand, ports, s.settings.Interface)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("running up command: %w", err)
|
||||
s.logger.Error(err.Error())
|
||||
}
|
||||
err = s.onNewPorts(ctx, ports)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
keepPortCtx, keepPortCancel := context.WithCancel(context.Background())
|
||||
@@ -101,6 +76,8 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error)
|
||||
}
|
||||
s.startStopMutex.Lock()
|
||||
defer s.startStopMutex.Unlock()
|
||||
s.portMutex.Lock()
|
||||
defer s.portMutex.Unlock()
|
||||
_ = s.cleanup()
|
||||
runError <- err
|
||||
}(keepPortCtx, s.settings.PortForwarder, obj, readyCh, runErrorCh, keepPortDoneCh)
|
||||
@@ -108,3 +85,42 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error)
|
||||
|
||||
return runErrorCh, nil
|
||||
}
|
||||
|
||||
func (s *Service) onNewPorts(ctx context.Context, ports []uint16) (err error) {
|
||||
slices.Sort(ports)
|
||||
|
||||
s.logger.Info(portsToString(ports))
|
||||
|
||||
for _, port := range ports {
|
||||
err = s.portAllower.SetAllowedPort(ctx, port, s.settings.Interface)
|
||||
if err != nil {
|
||||
return fmt.Errorf("allowing port in firewall: %w", err)
|
||||
}
|
||||
|
||||
if s.settings.ListeningPort != 0 {
|
||||
err = s.portAllower.RedirectPort(ctx, s.settings.Interface, port, s.settings.ListeningPort)
|
||||
if err != nil {
|
||||
return fmt.Errorf("redirecting port in firewall: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = s.writePortForwardedFile(ports)
|
||||
if err != nil {
|
||||
_ = s.cleanup()
|
||||
return fmt.Errorf("writing port file: %w", err)
|
||||
}
|
||||
|
||||
s.ports = make([]uint16, len(ports))
|
||||
copy(s.ports, ports)
|
||||
|
||||
if s.settings.UpCommand != "" {
|
||||
err = runCommand(ctx, s.cmder, s.logger, s.settings.UpCommand, ports, s.settings.Interface)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("running up command: %w", err)
|
||||
s.logger.Error(err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -27,9 +27,6 @@ func (s *Service) Stop() (err error) {
|
||||
}
|
||||
|
||||
func (s *Service) cleanup() (err error) {
|
||||
s.portMutex.Lock()
|
||||
defer s.portMutex.Unlock()
|
||||
|
||||
if s.settings.DownCommand != "" {
|
||||
const downTimeout = 60 * time.Second
|
||||
ctx, cancel := context.WithTimeout(context.Background(), downTimeout)
|
||||
|
||||
@@ -15,7 +15,7 @@ func newHandler(ctx context.Context, logger Logger, logging bool,
|
||||
authSettings auth.Settings,
|
||||
buildInfo models.BuildInformation,
|
||||
vpnLooper VPNLooper,
|
||||
pfGetter PortForwardedGetter,
|
||||
pf PortForwarding,
|
||||
dnsLooper DNSLoop,
|
||||
updaterLooper UpdaterLooper,
|
||||
publicIPLooper PublicIPLoop,
|
||||
@@ -29,7 +29,7 @@ func newHandler(ctx context.Context, logger Logger, logging bool,
|
||||
dns := newDNSHandler(ctx, dnsLooper, logger)
|
||||
updater := newUpdaterHandler(ctx, updaterLooper, logger)
|
||||
publicip := newPublicIPHandler(publicIPLooper, logger)
|
||||
portForward := newPortForwardHandler(ctx, pfGetter, logger)
|
||||
portForward := newPortForwardHandler(ctx, pf, logger)
|
||||
|
||||
handler.v0 = newHandlerV0(ctx, logger, vpnLooper, dnsLooper, updaterLooper)
|
||||
handler.v1 = newHandlerV1(logger, buildInfo, vpn, openvpn, dns, updater, publicip, portForward)
|
||||
|
||||
@@ -21,8 +21,9 @@ type DNSLoop interface {
|
||||
GetStatus() (status models.LoopStatus)
|
||||
}
|
||||
|
||||
type PortForwardedGetter interface {
|
||||
type PortForwarding interface {
|
||||
GetPortsForwarded() (ports []uint16)
|
||||
SetPortsForwarded(ports []uint16) (err error)
|
||||
}
|
||||
|
||||
type PublicIPLoop interface {
|
||||
|
||||
@@ -156,6 +156,7 @@ var validRoutes = map[string]struct{}{ //nolint:gochecknoglobals
|
||||
http.MethodPut + " /v1/updater/status": {},
|
||||
http.MethodGet + " /v1/publicip/ip": {},
|
||||
http.MethodGet + " /v1/portforward": {},
|
||||
http.MethodPut + " /v1/portforward": {},
|
||||
}
|
||||
|
||||
func (r Role) ToLinesNode() (node *gotree.Node) {
|
||||
|
||||
@@ -3,11 +3,12 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func newPortForwardHandler(ctx context.Context,
|
||||
portForward PortForwardedGetter, warner warner,
|
||||
portForward PortForwarding, warner warner,
|
||||
) http.Handler {
|
||||
return &portForwardHandler{
|
||||
ctx: ctx,
|
||||
@@ -18,7 +19,7 @@ func newPortForwardHandler(ctx context.Context,
|
||||
|
||||
type portForwardHandler struct {
|
||||
ctx context.Context //nolint:containedctx
|
||||
portForward PortForwardedGetter
|
||||
portForward PortForwarding
|
||||
warner warner
|
||||
}
|
||||
|
||||
@@ -26,6 +27,8 @@ func (h *portForwardHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
h.getPortForwarded(w)
|
||||
case http.MethodPut:
|
||||
h.setPortForwarded(w, r)
|
||||
default:
|
||||
errMethodNotSupported(w, r.Method)
|
||||
}
|
||||
@@ -50,3 +53,24 @@ func (h *portForwardHandler) getPortForwarded(w http.ResponseWriter) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *portForwardHandler) setPortForwarded(w http.ResponseWriter, r *http.Request) {
|
||||
var data portsWrapper
|
||||
|
||||
decoder := json.NewDecoder(r.Body)
|
||||
err := decoder.Decode(&data)
|
||||
if err != nil {
|
||||
h.warner.Warn(fmt.Sprintf("failed setting forwarded ports: %s", err))
|
||||
http.Error(w, "failed setting forwarded ports", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.portForward.SetPortsForwarded(data.Ports)
|
||||
if err != nil {
|
||||
h.warner.Warn(fmt.Sprintf("failed setting forwarded ports: %s", err))
|
||||
http.Error(w, "failed setting forwarded ports", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
|
||||
func New(ctx context.Context, settings settings.ControlServer, logger Logger,
|
||||
buildInfo models.BuildInformation, openvpnLooper VPNLooper,
|
||||
pfGetter PortForwardedGetter, dnsLooper DNSLoop,
|
||||
pf PortForwarding, dnsLooper DNSLoop,
|
||||
updaterLooper UpdaterLooper, publicIPLooper PublicIPLoop, storage Storage,
|
||||
ipv6Supported bool) (
|
||||
server *httpserver.Server, err error,
|
||||
@@ -25,7 +25,7 @@ func New(ctx context.Context, settings settings.ControlServer, logger Logger,
|
||||
}
|
||||
|
||||
handler, err := newHandler(ctx, logger, *settings.Log, authSettings, buildInfo,
|
||||
openvpnLooper, pfGetter, dnsLooper, updaterLooper, publicIPLooper,
|
||||
openvpnLooper, pf, dnsLooper, updaterLooper, publicIPLooper,
|
||||
storage, ipv6Supported)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating handler: %w", err)
|
||||
|
||||
Reference in New Issue
Block a user