From 3b9c9b24bd35d7201f8b3d151616d6fea9abba5e Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Tue, 7 Apr 2026 19:41:05 +0000 Subject: [PATCH] fix(server/auth): return 404 or 405 depending on route - Fix #3275 --- .../server/middlewares/auth/middleware.go | 12 +++ .../middlewares/auth/middleware_test.go | 12 +-- internal/server/middlewares/auth/settings.go | 87 +++++++++++-------- 3 files changed, 69 insertions(+), 42 deletions(-) diff --git a/internal/server/middlewares/auth/middleware.go b/internal/server/middlewares/auth/middleware.go index b2bf5d06..1aa4c28e 100644 --- a/internal/server/middlewares/auth/middleware.go +++ b/internal/server/middlewares/auth/middleware.go @@ -3,6 +3,7 @@ package auth import ( "fmt" "net/http" + "slices" ) func New(settings Settings, debugLogger DebugLogger) ( @@ -30,6 +31,17 @@ type authHandler struct { } func (h *authHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + methods, ok := validRoutes[request.URL.Path] + if !ok { + h.logger.Debugf("url path %s is not a valid route", request.URL.Path) + http.Error(writer, http.StatusText(http.StatusNotFound), http.StatusNotFound) + return + } else if !slices.Contains(methods, request.Method) { + h.logger.Debugf("method %s is not valid for url path %s", request.Method, request.URL.Path) + http.Error(writer, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) + return + } + route := request.Method + " " + request.URL.Path roles := h.routeToRoles[route] if len(roles) == 0 { diff --git a/internal/server/middlewares/auth/middleware_test.go b/internal/server/middlewares/auth/middleware_test.go index ae731da2..badd768b 100644 --- a/internal/server/middlewares/auth/middleware_test.go +++ b/internal/server/middlewares/auth/middleware_test.go @@ -32,28 +32,28 @@ func Test_authHandler_ServeHTTP(t *testing.T) { }, makeLogger: func(ctrl *gomock.Controller) *MockDebugLogger { logger := NewMockDebugLogger(ctrl) - logger.EXPECT().Debugf("no authentication role defined for route %s", "GET /b") + logger.EXPECT().Debugf("url path %s is not a valid route", "/b") return logger }, requestMethod: http.MethodGet, requestPath: "/b", - statusCode: http.StatusUnauthorized, - responseBody: "Unauthorized\n", + statusCode: http.StatusNotFound, + responseBody: "Not Found\n", }, "authorized_none": { settings: Settings{ Roles: []Role{ - {Name: "role1", Auth: AuthNone, Routes: []string{"GET /a"}}, + {Name: "role1", Auth: AuthNone, Routes: []string{"GET /v1/portforward"}}, }, }, makeLogger: func(ctrl *gomock.Controller) *MockDebugLogger { logger := NewMockDebugLogger(ctrl) logger.EXPECT().Debugf("access to route %s authorized for role %s", - "GET /a", "role1") + "GET /v1/portforward", "role1") return logger }, requestMethod: http.MethodGet, - requestPath: "/a", + requestPath: "/v1/portforward", statusCode: http.StatusOK, }, } diff --git a/internal/server/middlewares/auth/settings.go b/internal/server/middlewares/auth/settings.go index c854e790..eaaf45ef 100644 --- a/internal/server/middlewares/auth/settings.go +++ b/internal/server/middlewares/auth/settings.go @@ -7,6 +7,7 @@ import ( "fmt" "net/http" "slices" + "strings" "github.com/qdm12/gosettings" "github.com/qdm12/gosettings/validate" @@ -38,22 +39,27 @@ func (s *Settings) SetDefaultRole(jsonRole string) error { return fmt.Errorf("validating default role: %w", err) } - authenticatedRoutes := make(map[string]struct{}, len(validRoutes)) + maxRoutes := countValidRoutes() + + authenticatedRoutes := make(map[string]struct{}, maxRoutes) for _, role := range s.Roles { for _, route := range role.Routes { authenticatedRoutes[route] = struct{}{} } } - if len(authenticatedRoutes) == len(validRoutes) { + if len(authenticatedRoutes) == maxRoutes { return nil } - unauthenticatedRoutes := make([]string, 0, len(validRoutes)) - for route := range validRoutes { - _, authenticated := authenticatedRoutes[route] - if !authenticated { - unauthenticatedRoutes = append(unauthenticatedRoutes, route) + var unauthenticatedRoutes []string + for urlPath, methods := range validRoutes { + for _, method := range methods { + route := method + " " + urlPath + _, authenticated := authenticatedRoutes[route] + if !authenticated { + unauthenticatedRoutes = append(unauthenticatedRoutes, route) + } } } @@ -101,11 +107,12 @@ type Role struct { } var ( - ErrMethodNotSupported = errors.New("authentication method not supported") - ErrAPIKeyEmpty = errors.New("api key is empty") - ErrBasicUsernameEmpty = errors.New("username is empty") - ErrBasicPasswordEmpty = errors.New("password is empty") - ErrRouteNotSupported = errors.New("route not supported by the control server") + ErrMethodNotSupported = errors.New("authentication method not supported") + ErrAPIKeyEmpty = errors.New("api key is empty") + ErrBasicUsernameEmpty = errors.New("username is empty") + ErrBasicPasswordEmpty = errors.New("password is empty") + ErrRoutePathNotSupported = errors.New("route path not supported by the control server") + ErrRouteMethodNotSupported = errors.New("route method not supported for the path") ) func (r Role) Validate() (err error) { @@ -124,39 +131,47 @@ func (r Role) Validate() (err error) { } for i, route := range r.Routes { - _, ok := validRoutes[route] + const maxRouteFields = 2 + parts := strings.SplitN(route, " ", maxRouteFields) + method, path := parts[0], parts[1] + methods, ok := validRoutes[path] if !ok { return fmt.Errorf("route %d of %d: %w: %s", - i+1, len(r.Routes), ErrRouteNotSupported, route) + i+1, len(r.Routes), ErrRoutePathNotSupported, path) + } else if !slices.Contains(methods, method) { + return fmt.Errorf("route %d of %d: %w: %s for path %s", + i+1, len(r.Routes), ErrRouteMethodNotSupported, method, path) } } return nil } +// validRoutes maps URL paths to allowed HTTP methods. // WARNING: do not mutate programmatically. -var validRoutes = map[string]struct{}{ //nolint:gochecknoglobals - http.MethodGet + " /openvpn/actions/restart": {}, - http.MethodGet + " /openvpn/portforwarded": {}, - http.MethodGet + " /openvpn/settings": {}, - http.MethodGet + " /unbound/actions/restart": {}, - http.MethodGet + " /updater/restart": {}, - http.MethodGet + " /v1/version": {}, - http.MethodGet + " /v1/vpn/status": {}, - http.MethodPut + " /v1/vpn/status": {}, - http.MethodGet + " /v1/vpn/settings": {}, - http.MethodPut + " /v1/vpn/settings": {}, - http.MethodGet + " /v1/openvpn/status": {}, - http.MethodPut + " /v1/openvpn/status": {}, - http.MethodGet + " /v1/openvpn/portforwarded": {}, - http.MethodGet + " /v1/openvpn/settings": {}, - http.MethodGet + " /v1/dns/status": {}, - http.MethodPut + " /v1/dns/status": {}, - http.MethodGet + " /v1/updater/status": {}, - http.MethodPut + " /v1/updater/status": {}, - http.MethodGet + " /v1/publicip/ip": {}, - http.MethodGet + " /v1/portforward": {}, - http.MethodPut + " /v1/portforward": {}, +var validRoutes = map[string][]string{ //nolint:gochecknoglobals + "/openvpn/actions/restart": {http.MethodGet}, + "/openvpn/portforwarded": {http.MethodGet}, + "/openvpn/settings": {http.MethodGet}, + "/unbound/actions/restart": {http.MethodGet}, + "/updater/restart": {http.MethodGet}, + "/v1/version": {http.MethodGet}, + "/v1/vpn/status": {http.MethodGet, http.MethodPut}, + "/v1/vpn/settings": {http.MethodGet, http.MethodPut}, + "/v1/openvpn/status": {http.MethodGet, http.MethodPut}, + "/v1/openvpn/portforwarded": {http.MethodGet}, + "/v1/openvpn/settings": {http.MethodGet}, + "/v1/dns/status": {http.MethodGet, http.MethodPut}, + "/v1/updater/status": {http.MethodGet, http.MethodPut}, + "/v1/publicip/ip": {http.MethodGet}, + "/v1/portforward": {http.MethodGet, http.MethodPut}, +} + +func countValidRoutes() (count int) { + for _, methods := range validRoutes { + count += len(methods) + } + return count } func (r Role) ToLinesNode() (node *gotree.Node) {