fix(server/auth): return 404 or 405 depending on route

- Fix #3275
This commit is contained in:
Quentin McGaw
2026-04-07 19:41:05 +00:00
parent 11883aa830
commit 3b9c9b24bd
3 changed files with 69 additions and 42 deletions
@@ -3,6 +3,7 @@ package auth
import ( import (
"fmt" "fmt"
"net/http" "net/http"
"slices"
) )
func New(settings Settings, debugLogger DebugLogger) ( func New(settings Settings, debugLogger DebugLogger) (
@@ -30,6 +31,17 @@ type authHandler struct {
} }
func (h *authHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { func (h *authHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
methods, ok := validRoutes[request.URL.Path]
if !ok {
h.logger.Debugf("url path %s is not a valid route", request.URL.Path)
http.Error(writer, http.StatusText(http.StatusNotFound), http.StatusNotFound)
return
} else if !slices.Contains(methods, request.Method) {
h.logger.Debugf("method %s is not valid for url path %s", request.Method, request.URL.Path)
http.Error(writer, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
return
}
route := request.Method + " " + request.URL.Path route := request.Method + " " + request.URL.Path
roles := h.routeToRoles[route] roles := h.routeToRoles[route]
if len(roles) == 0 { if len(roles) == 0 {
@@ -32,28 +32,28 @@ func Test_authHandler_ServeHTTP(t *testing.T) {
}, },
makeLogger: func(ctrl *gomock.Controller) *MockDebugLogger { makeLogger: func(ctrl *gomock.Controller) *MockDebugLogger {
logger := NewMockDebugLogger(ctrl) logger := NewMockDebugLogger(ctrl)
logger.EXPECT().Debugf("no authentication role defined for route %s", "GET /b") logger.EXPECT().Debugf("url path %s is not a valid route", "/b")
return logger return logger
}, },
requestMethod: http.MethodGet, requestMethod: http.MethodGet,
requestPath: "/b", requestPath: "/b",
statusCode: http.StatusUnauthorized, statusCode: http.StatusNotFound,
responseBody: "Unauthorized\n", responseBody: "Not Found\n",
}, },
"authorized_none": { "authorized_none": {
settings: Settings{ settings: Settings{
Roles: []Role{ Roles: []Role{
{Name: "role1", Auth: AuthNone, Routes: []string{"GET /a"}}, {Name: "role1", Auth: AuthNone, Routes: []string{"GET /v1/portforward"}},
}, },
}, },
makeLogger: func(ctrl *gomock.Controller) *MockDebugLogger { makeLogger: func(ctrl *gomock.Controller) *MockDebugLogger {
logger := NewMockDebugLogger(ctrl) logger := NewMockDebugLogger(ctrl)
logger.EXPECT().Debugf("access to route %s authorized for role %s", logger.EXPECT().Debugf("access to route %s authorized for role %s",
"GET /a", "role1") "GET /v1/portforward", "role1")
return logger return logger
}, },
requestMethod: http.MethodGet, requestMethod: http.MethodGet,
requestPath: "/a", requestPath: "/v1/portforward",
statusCode: http.StatusOK, statusCode: http.StatusOK,
}, },
} }
+51 -36
View File
@@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"slices" "slices"
"strings"
"github.com/qdm12/gosettings" "github.com/qdm12/gosettings"
"github.com/qdm12/gosettings/validate" "github.com/qdm12/gosettings/validate"
@@ -38,22 +39,27 @@ func (s *Settings) SetDefaultRole(jsonRole string) error {
return fmt.Errorf("validating default role: %w", err) return fmt.Errorf("validating default role: %w", err)
} }
authenticatedRoutes := make(map[string]struct{}, len(validRoutes)) maxRoutes := countValidRoutes()
authenticatedRoutes := make(map[string]struct{}, maxRoutes)
for _, role := range s.Roles { for _, role := range s.Roles {
for _, route := range role.Routes { for _, route := range role.Routes {
authenticatedRoutes[route] = struct{}{} authenticatedRoutes[route] = struct{}{}
} }
} }
if len(authenticatedRoutes) == len(validRoutes) { if len(authenticatedRoutes) == maxRoutes {
return nil return nil
} }
unauthenticatedRoutes := make([]string, 0, len(validRoutes)) var unauthenticatedRoutes []string
for route := range validRoutes { for urlPath, methods := range validRoutes {
_, authenticated := authenticatedRoutes[route] for _, method := range methods {
if !authenticated { route := method + " " + urlPath
unauthenticatedRoutes = append(unauthenticatedRoutes, route) _, authenticated := authenticatedRoutes[route]
if !authenticated {
unauthenticatedRoutes = append(unauthenticatedRoutes, route)
}
} }
} }
@@ -101,11 +107,12 @@ type Role struct {
} }
var ( var (
ErrMethodNotSupported = errors.New("authentication method not supported") ErrMethodNotSupported = errors.New("authentication method not supported")
ErrAPIKeyEmpty = errors.New("api key is empty") ErrAPIKeyEmpty = errors.New("api key is empty")
ErrBasicUsernameEmpty = errors.New("username is empty") ErrBasicUsernameEmpty = errors.New("username is empty")
ErrBasicPasswordEmpty = errors.New("password is empty") ErrBasicPasswordEmpty = errors.New("password is empty")
ErrRouteNotSupported = errors.New("route not supported by the control server") ErrRoutePathNotSupported = errors.New("route path not supported by the control server")
ErrRouteMethodNotSupported = errors.New("route method not supported for the path")
) )
func (r Role) Validate() (err error) { func (r Role) Validate() (err error) {
@@ -124,39 +131,47 @@ func (r Role) Validate() (err error) {
} }
for i, route := range r.Routes { for i, route := range r.Routes {
_, ok := validRoutes[route] const maxRouteFields = 2
parts := strings.SplitN(route, " ", maxRouteFields)
method, path := parts[0], parts[1]
methods, ok := validRoutes[path]
if !ok { if !ok {
return fmt.Errorf("route %d of %d: %w: %s", return fmt.Errorf("route %d of %d: %w: %s",
i+1, len(r.Routes), ErrRouteNotSupported, route) i+1, len(r.Routes), ErrRoutePathNotSupported, path)
} else if !slices.Contains(methods, method) {
return fmt.Errorf("route %d of %d: %w: %s for path %s",
i+1, len(r.Routes), ErrRouteMethodNotSupported, method, path)
} }
} }
return nil return nil
} }
// validRoutes maps URL paths to allowed HTTP methods.
// WARNING: do not mutate programmatically. // WARNING: do not mutate programmatically.
var validRoutes = map[string]struct{}{ //nolint:gochecknoglobals var validRoutes = map[string][]string{ //nolint:gochecknoglobals
http.MethodGet + " /openvpn/actions/restart": {}, "/openvpn/actions/restart": {http.MethodGet},
http.MethodGet + " /openvpn/portforwarded": {}, "/openvpn/portforwarded": {http.MethodGet},
http.MethodGet + " /openvpn/settings": {}, "/openvpn/settings": {http.MethodGet},
http.MethodGet + " /unbound/actions/restart": {}, "/unbound/actions/restart": {http.MethodGet},
http.MethodGet + " /updater/restart": {}, "/updater/restart": {http.MethodGet},
http.MethodGet + " /v1/version": {}, "/v1/version": {http.MethodGet},
http.MethodGet + " /v1/vpn/status": {}, "/v1/vpn/status": {http.MethodGet, http.MethodPut},
http.MethodPut + " /v1/vpn/status": {}, "/v1/vpn/settings": {http.MethodGet, http.MethodPut},
http.MethodGet + " /v1/vpn/settings": {}, "/v1/openvpn/status": {http.MethodGet, http.MethodPut},
http.MethodPut + " /v1/vpn/settings": {}, "/v1/openvpn/portforwarded": {http.MethodGet},
http.MethodGet + " /v1/openvpn/status": {}, "/v1/openvpn/settings": {http.MethodGet},
http.MethodPut + " /v1/openvpn/status": {}, "/v1/dns/status": {http.MethodGet, http.MethodPut},
http.MethodGet + " /v1/openvpn/portforwarded": {}, "/v1/updater/status": {http.MethodGet, http.MethodPut},
http.MethodGet + " /v1/openvpn/settings": {}, "/v1/publicip/ip": {http.MethodGet},
http.MethodGet + " /v1/dns/status": {}, "/v1/portforward": {http.MethodGet, http.MethodPut},
http.MethodPut + " /v1/dns/status": {}, }
http.MethodGet + " /v1/updater/status": {},
http.MethodPut + " /v1/updater/status": {}, func countValidRoutes() (count int) {
http.MethodGet + " /v1/publicip/ip": {}, for _, methods := range validRoutes {
http.MethodGet + " /v1/portforward": {}, count += len(methods)
http.MethodPut + " /v1/portforward": {}, }
return count
} }
func (r Role) ToLinesNode() (node *gotree.Node) { func (r Role) ToLinesNode() (node *gotree.Node) {