mirror of
https://github.com/qdm12/gluetun.git
synced 2026-05-06 20:10:11 +02:00
@@ -3,6 +3,7 @@ package auth
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"slices"
|
||||||
)
|
)
|
||||||
|
|
||||||
func New(settings Settings, debugLogger DebugLogger) (
|
func New(settings Settings, debugLogger DebugLogger) (
|
||||||
@@ -30,6 +31,17 @@ type authHandler struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *authHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
|
func (h *authHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
|
||||||
|
methods, ok := validRoutes[request.URL.Path]
|
||||||
|
if !ok {
|
||||||
|
h.logger.Debugf("url path %s is not a valid route", request.URL.Path)
|
||||||
|
http.Error(writer, http.StatusText(http.StatusNotFound), http.StatusNotFound)
|
||||||
|
return
|
||||||
|
} else if !slices.Contains(methods, request.Method) {
|
||||||
|
h.logger.Debugf("method %s is not valid for url path %s", request.Method, request.URL.Path)
|
||||||
|
http.Error(writer, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
route := request.Method + " " + request.URL.Path
|
route := request.Method + " " + request.URL.Path
|
||||||
roles := h.routeToRoles[route]
|
roles := h.routeToRoles[route]
|
||||||
if len(roles) == 0 {
|
if len(roles) == 0 {
|
||||||
|
|||||||
@@ -32,28 +32,28 @@ func Test_authHandler_ServeHTTP(t *testing.T) {
|
|||||||
},
|
},
|
||||||
makeLogger: func(ctrl *gomock.Controller) *MockDebugLogger {
|
makeLogger: func(ctrl *gomock.Controller) *MockDebugLogger {
|
||||||
logger := NewMockDebugLogger(ctrl)
|
logger := NewMockDebugLogger(ctrl)
|
||||||
logger.EXPECT().Debugf("no authentication role defined for route %s", "GET /b")
|
logger.EXPECT().Debugf("url path %s is not a valid route", "/b")
|
||||||
return logger
|
return logger
|
||||||
},
|
},
|
||||||
requestMethod: http.MethodGet,
|
requestMethod: http.MethodGet,
|
||||||
requestPath: "/b",
|
requestPath: "/b",
|
||||||
statusCode: http.StatusUnauthorized,
|
statusCode: http.StatusNotFound,
|
||||||
responseBody: "Unauthorized\n",
|
responseBody: "Not Found\n",
|
||||||
},
|
},
|
||||||
"authorized_none": {
|
"authorized_none": {
|
||||||
settings: Settings{
|
settings: Settings{
|
||||||
Roles: []Role{
|
Roles: []Role{
|
||||||
{Name: "role1", Auth: AuthNone, Routes: []string{"GET /a"}},
|
{Name: "role1", Auth: AuthNone, Routes: []string{"GET /v1/portforward"}},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
makeLogger: func(ctrl *gomock.Controller) *MockDebugLogger {
|
makeLogger: func(ctrl *gomock.Controller) *MockDebugLogger {
|
||||||
logger := NewMockDebugLogger(ctrl)
|
logger := NewMockDebugLogger(ctrl)
|
||||||
logger.EXPECT().Debugf("access to route %s authorized for role %s",
|
logger.EXPECT().Debugf("access to route %s authorized for role %s",
|
||||||
"GET /a", "role1")
|
"GET /v1/portforward", "role1")
|
||||||
return logger
|
return logger
|
||||||
},
|
},
|
||||||
requestMethod: http.MethodGet,
|
requestMethod: http.MethodGet,
|
||||||
requestPath: "/a",
|
requestPath: "/v1/portforward",
|
||||||
statusCode: http.StatusOK,
|
statusCode: http.StatusOK,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"slices"
|
"slices"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/qdm12/gosettings"
|
"github.com/qdm12/gosettings"
|
||||||
"github.com/qdm12/gosettings/validate"
|
"github.com/qdm12/gosettings/validate"
|
||||||
@@ -38,22 +39,27 @@ func (s *Settings) SetDefaultRole(jsonRole string) error {
|
|||||||
return fmt.Errorf("validating default role: %w", err)
|
return fmt.Errorf("validating default role: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
authenticatedRoutes := make(map[string]struct{}, len(validRoutes))
|
maxRoutes := countValidRoutes()
|
||||||
|
|
||||||
|
authenticatedRoutes := make(map[string]struct{}, maxRoutes)
|
||||||
for _, role := range s.Roles {
|
for _, role := range s.Roles {
|
||||||
for _, route := range role.Routes {
|
for _, route := range role.Routes {
|
||||||
authenticatedRoutes[route] = struct{}{}
|
authenticatedRoutes[route] = struct{}{}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(authenticatedRoutes) == len(validRoutes) {
|
if len(authenticatedRoutes) == maxRoutes {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
unauthenticatedRoutes := make([]string, 0, len(validRoutes))
|
var unauthenticatedRoutes []string
|
||||||
for route := range validRoutes {
|
for urlPath, methods := range validRoutes {
|
||||||
_, authenticated := authenticatedRoutes[route]
|
for _, method := range methods {
|
||||||
if !authenticated {
|
route := method + " " + urlPath
|
||||||
unauthenticatedRoutes = append(unauthenticatedRoutes, route)
|
_, authenticated := authenticatedRoutes[route]
|
||||||
|
if !authenticated {
|
||||||
|
unauthenticatedRoutes = append(unauthenticatedRoutes, route)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -101,11 +107,12 @@ type Role struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrMethodNotSupported = errors.New("authentication method not supported")
|
ErrMethodNotSupported = errors.New("authentication method not supported")
|
||||||
ErrAPIKeyEmpty = errors.New("api key is empty")
|
ErrAPIKeyEmpty = errors.New("api key is empty")
|
||||||
ErrBasicUsernameEmpty = errors.New("username is empty")
|
ErrBasicUsernameEmpty = errors.New("username is empty")
|
||||||
ErrBasicPasswordEmpty = errors.New("password is empty")
|
ErrBasicPasswordEmpty = errors.New("password is empty")
|
||||||
ErrRouteNotSupported = errors.New("route not supported by the control server")
|
ErrRoutePathNotSupported = errors.New("route path not supported by the control server")
|
||||||
|
ErrRouteMethodNotSupported = errors.New("route method not supported for the path")
|
||||||
)
|
)
|
||||||
|
|
||||||
func (r Role) Validate() (err error) {
|
func (r Role) Validate() (err error) {
|
||||||
@@ -124,39 +131,47 @@ func (r Role) Validate() (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for i, route := range r.Routes {
|
for i, route := range r.Routes {
|
||||||
_, ok := validRoutes[route]
|
const maxRouteFields = 2
|
||||||
|
parts := strings.SplitN(route, " ", maxRouteFields)
|
||||||
|
method, path := parts[0], parts[1]
|
||||||
|
methods, ok := validRoutes[path]
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("route %d of %d: %w: %s",
|
return fmt.Errorf("route %d of %d: %w: %s",
|
||||||
i+1, len(r.Routes), ErrRouteNotSupported, route)
|
i+1, len(r.Routes), ErrRoutePathNotSupported, path)
|
||||||
|
} else if !slices.Contains(methods, method) {
|
||||||
|
return fmt.Errorf("route %d of %d: %w: %s for path %s",
|
||||||
|
i+1, len(r.Routes), ErrRouteMethodNotSupported, method, path)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// validRoutes maps URL paths to allowed HTTP methods.
|
||||||
// WARNING: do not mutate programmatically.
|
// WARNING: do not mutate programmatically.
|
||||||
var validRoutes = map[string]struct{}{ //nolint:gochecknoglobals
|
var validRoutes = map[string][]string{ //nolint:gochecknoglobals
|
||||||
http.MethodGet + " /openvpn/actions/restart": {},
|
"/openvpn/actions/restart": {http.MethodGet},
|
||||||
http.MethodGet + " /openvpn/portforwarded": {},
|
"/openvpn/portforwarded": {http.MethodGet},
|
||||||
http.MethodGet + " /openvpn/settings": {},
|
"/openvpn/settings": {http.MethodGet},
|
||||||
http.MethodGet + " /unbound/actions/restart": {},
|
"/unbound/actions/restart": {http.MethodGet},
|
||||||
http.MethodGet + " /updater/restart": {},
|
"/updater/restart": {http.MethodGet},
|
||||||
http.MethodGet + " /v1/version": {},
|
"/v1/version": {http.MethodGet},
|
||||||
http.MethodGet + " /v1/vpn/status": {},
|
"/v1/vpn/status": {http.MethodGet, http.MethodPut},
|
||||||
http.MethodPut + " /v1/vpn/status": {},
|
"/v1/vpn/settings": {http.MethodGet, http.MethodPut},
|
||||||
http.MethodGet + " /v1/vpn/settings": {},
|
"/v1/openvpn/status": {http.MethodGet, http.MethodPut},
|
||||||
http.MethodPut + " /v1/vpn/settings": {},
|
"/v1/openvpn/portforwarded": {http.MethodGet},
|
||||||
http.MethodGet + " /v1/openvpn/status": {},
|
"/v1/openvpn/settings": {http.MethodGet},
|
||||||
http.MethodPut + " /v1/openvpn/status": {},
|
"/v1/dns/status": {http.MethodGet, http.MethodPut},
|
||||||
http.MethodGet + " /v1/openvpn/portforwarded": {},
|
"/v1/updater/status": {http.MethodGet, http.MethodPut},
|
||||||
http.MethodGet + " /v1/openvpn/settings": {},
|
"/v1/publicip/ip": {http.MethodGet},
|
||||||
http.MethodGet + " /v1/dns/status": {},
|
"/v1/portforward": {http.MethodGet, http.MethodPut},
|
||||||
http.MethodPut + " /v1/dns/status": {},
|
}
|
||||||
http.MethodGet + " /v1/updater/status": {},
|
|
||||||
http.MethodPut + " /v1/updater/status": {},
|
func countValidRoutes() (count int) {
|
||||||
http.MethodGet + " /v1/publicip/ip": {},
|
for _, methods := range validRoutes {
|
||||||
http.MethodGet + " /v1/portforward": {},
|
count += len(methods)
|
||||||
http.MethodPut + " /v1/portforward": {},
|
}
|
||||||
|
return count
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r Role) ToLinesNode() (node *gotree.Node) {
|
func (r Role) ToLinesNode() (node *gotree.Node) {
|
||||||
|
|||||||
Reference in New Issue
Block a user