diff --git a/docs/en/documentation/configuration/authentication/_index.md b/docs/en/documentation/configuration/authentication/_index.md index 2ee36488a670..abe56ae353a6 100644 --- a/docs/en/documentation/configuration/authentication/_index.md +++ b/docs/en/documentation/configuration/authentication/_index.md @@ -6,19 +6,18 @@ description: > AuthServices represent services that handle authentication and authorization. --- -AuthServices represent services that handle authentication and authorization. It -can primarily be used by [Tools](../tools/_index.md) in two different ways: +AuthServices represent services that handle authentication and authorization. They support two distinct modes of operation: -- [**Authorized Invocation**][auth-invoke] is when a tool - is validated by the auth service before the call can be invoked. Toolbox - will reject any calls that fail to validate or have an invalid token. -- [**Authenticated Parameters**][auth-params] replace the value of a parameter - with a field from an [OIDC][openid-claims] claim. Toolbox will automatically - resolve the ID token provided by the client and replace the parameter in the - tool call. +### 1. Toolbox Native Authorization +Used for specific tools to enforce authorization or resolve parameters: +- [**Authorized Invocation**][auth-invoke]: A tool is validated by the auth service before it can be invoked. Toolbox will reject any calls that fail to validate or have an invalid token. +- [**Authenticated Parameters**][auth-params]: Replaces the value of a parameter with a field from an [OIDC][openid-claims] claim. Toolbox will automatically resolve the ID token provided by the client and replace the parameter in the tool call. + +### 2. MCP Authorization +Used to secure the entire MCP server. The Model Context Protocol supports [MCP Authorization](https://modelcontextprotocol.io/docs/tutorials/security/authorization) to secure interactions between clients and servers. When enabled, all MCP endpoints require a valid token, and you can enforce granular tool-level scope authorization. **Note that this mode is currently only supported when using the `generic` auth service type.** [openid-claims]: https://openid.net/specs/openid-connect-core-1_0.html#StandardClaims -[auth-invoke]: ../tools/_index.md#authorized-invocations +[auth-invoke]: ../tools/_index.md#authorized-invocations-toolbox-native-authorization [auth-params]: ../tools/_index.md#authenticated-parameters ## Example @@ -48,13 +47,9 @@ Use environment variable replacement with the format ${ENV_NAME} instead of hardcoding your secrets into the configuration file. {{< /notice >}} -After you've configured an `authService` you'll, need to reference it in the -configuration for each tool that should use it: - -- **Authorized Invocations** for authorizing a tool call, [use the - `authRequired` field in a tool config][auth-invoke] -- **Authenticated Parameters** for using the value from a OIDC claim, [use the - `authService` field in a parameter config][auth-params] +After you've configured an `authService`, you can use it: +- For **Toolbox Native Authorization** by referencing it in your tool configuration (using `authRequired` or `authService` in parameters). +- For **MCP Authorization** by setting `mcpEnabled: true` in the auth service configuration to secure the entire server. ## Specifying ID Tokens from Clients diff --git a/docs/en/documentation/configuration/authentication/generic.md b/docs/en/documentation/configuration/authentication/generic.md index dd1b3268d300..a8ea5ece078e 100644 --- a/docs/en/documentation/configuration/authentication/generic.md +++ b/docs/en/documentation/configuration/authentication/generic.md @@ -124,6 +124,29 @@ scopesRequired: - write ``` +#### Tool-Level Scopes + +When using MCP Authorization (with `mcpEnabled: true` in the auth service), you can enforce granular tool-level scope authorization by specifying the `scopesRequired` field in the tool configuration. + +This ensures that a client can only invoke the tool if their authorization token contains all the specified scopes. + +```yaml +kind: tool +name: update_flight_status +type: postgres-sql +source: my-pg-instance +statement: | + UPDATE flights SET status = $1 WHERE flight_number = $2 +description: Update flight status +authRequired: + - my-generic-auth +scopesRequired: + - execute:sql + - write:flights +``` + +If a client attempts to invoke this tool without the required scopes, the server will return an HTTP 403 Forbidden response with a `WWW-Authenticate` header challenge indicating the missing scopes, as per the MCP Auth specification. + {{< notice tip >}} Use environment variable replacement with the format ${ENV_NAME} instead of hardcoding your secrets into the configuration file. {{< /notice >}} diff --git a/docs/en/documentation/configuration/tools/_index.md b/docs/en/documentation/configuration/tools/_index.md index 080e8d63d3f9..36a46082913c 100644 --- a/docs/en/documentation/configuration/tools/_index.md +++ b/docs/en/documentation/configuration/tools/_index.md @@ -260,7 +260,13 @@ templateParameters: | excludedValues | []string | false | Input value will be checked against this field. Regex is also supported. | | items | parameter object | true (if array) | Specify a Parameter object for the type of the values in the array (string only). | -## Authorized Invocations +## Tool-Level Scopes (MCP Authorization) + +The Model Context Protocol supports [MCP Authorization](https://modelcontextprotocol.io/docs/tutorials/security/authorization) to secure interactions between clients and servers. When using MCP Authorization in Toolbox, you can enforce granular tool-level scope authorization by specifying the `scopesRequired` field in the tool configuration. + +For detailed information on how to configure this and examples, please see the [Generic OIDC Auth](../authentication/generic.md#tool-level-scopes) documentation. + +## Authorized Invocations (Toolbox Native Authorization) You can require an authorization check for any Tool invocation request by specifying an `authRequired` field. Specify a list of @@ -279,7 +285,6 @@ authRequired: - other-auth-service ``` - ## Tool Annotations Tool annotations provide semantic metadata that helps MCP clients understand tool diff --git a/internal/auth/generic/generic.go b/internal/auth/generic/generic.go index 7385a6176493..f9747044a298 100644 --- a/internal/auth/generic/generic.go +++ b/internal/auth/generic/generic.go @@ -227,15 +227,15 @@ type MCPAuthError struct { func (e *MCPAuthError) Error() string { return e.Message } // ValidateMCPAuth handles MCP auth token validation -func (a AuthService) ValidateMCPAuth(ctx context.Context, h http.Header) error { +func (a AuthService) ValidateMCPAuth(ctx context.Context, h http.Header) (map[string]any, error) { tokenString := h.Get("Authorization") if tokenString == "" { - return &MCPAuthError{Code: http.StatusUnauthorized, Message: "missing access token", ScopesRequired: a.ScopesRequired} + return nil, &MCPAuthError{Code: http.StatusUnauthorized, Message: "missing access token", ScopesRequired: a.ScopesRequired} } headerParts := strings.Split(tokenString, " ") if len(headerParts) != 2 || strings.ToLower(headerParts[0]) != "bearer" { - return &MCPAuthError{Code: http.StatusUnauthorized, Message: "authorization header must be in the format 'Bearer '", ScopesRequired: a.ScopesRequired} + return nil, &MCPAuthError{Code: http.StatusUnauthorized, Message: "authorization header must be in the format 'Bearer '", ScopesRequired: a.ScopesRequired} } tokenStr := headerParts[1] @@ -251,40 +251,44 @@ func isJWTFormat(token string) bool { } // validateJwtToken validates a JWT token locally -func (a AuthService) validateJwtToken(ctx context.Context, tokenStr string) error { +func (a AuthService) validateJwtToken(ctx context.Context, tokenStr string) (map[string]any, error) { token, err := jwt.Parse(tokenStr, a.kf.Keyfunc) if err != nil || !token.Valid { - return &MCPAuthError{Code: http.StatusUnauthorized, Message: "invalid or expired token", ScopesRequired: a.ScopesRequired} + return nil, &MCPAuthError{Code: http.StatusUnauthorized, Message: "invalid or expired token", ScopesRequired: a.ScopesRequired} } claims, ok := token.Claims.(jwt.MapClaims) if !ok { - return &MCPAuthError{Code: http.StatusUnauthorized, Message: "invalid JWT claims format", ScopesRequired: a.ScopesRequired} + return nil, &MCPAuthError{Code: http.StatusUnauthorized, Message: "invalid JWT claims format", ScopesRequired: a.ScopesRequired} } // Validate audience aud, err := claims.GetAudience() if err != nil { - return &MCPAuthError{Code: http.StatusUnauthorized, Message: "could not parse audience from token", ScopesRequired: a.ScopesRequired} + return nil, &MCPAuthError{Code: http.StatusUnauthorized, Message: "could not parse audience from token", ScopesRequired: a.ScopesRequired} } scopeClaim, _ := claims["scope"].(string) - return a.validateClaims(ctx, aud, scopeClaim) + err = a.validateClaims(ctx, aud, scopeClaim) + if err != nil { + return nil, err + } + return claims, nil } // validateOpaqueToken validates an opaque token by calling the introspection endpoint -func (a AuthService) validateOpaqueToken(ctx context.Context, tokenStr string) error { +func (a AuthService) validateOpaqueToken(ctx context.Context, tokenStr string) (map[string]any, error) { logger, err := util.LoggerFromContext(ctx) if err != nil { - return fmt.Errorf("failed to get logger from context: %w", err) + return nil, fmt.Errorf("failed to get logger from context: %w", err) } introspectionURL := a.introspectionURL if introspectionURL == "" { introspectionURL, err = url.JoinPath(a.AuthorizationServer, "introspect") if err != nil { - return fmt.Errorf("failed to construct introspection URL: %w", err) + return nil, fmt.Errorf("failed to construct introspection URL: %w", err) } } @@ -293,7 +297,7 @@ func (a AuthService) validateOpaqueToken(ctx context.Context, tokenStr string) e req, err := http.NewRequestWithContext(ctx, "POST", introspectionURL, strings.NewReader(data.Encode())) if err != nil { - return fmt.Errorf("failed to create introspection request: %w", err) + return nil, fmt.Errorf("failed to create introspection request: %w", err) } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Accept", "application/json") @@ -302,18 +306,18 @@ func (a AuthService) validateOpaqueToken(ctx context.Context, tokenStr string) e resp, err := a.client.Do(req) if err != nil { logger.ErrorContext(ctx, "failed to call introspection endpoint: %v", err) - return &MCPAuthError{Code: http.StatusInternalServerError, Message: fmt.Sprintf("failed to call introspection endpoint: %v", err), ScopesRequired: a.ScopesRequired} + return nil, &MCPAuthError{Code: http.StatusInternalServerError, Message: fmt.Sprintf("failed to call introspection endpoint: %v", err), ScopesRequired: a.ScopesRequired} } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { logger.WarnContext(ctx, "introspection failed with status: %d", resp.StatusCode) - return &MCPAuthError{Code: http.StatusUnauthorized, Message: fmt.Sprintf("introspection failed with status: %d", resp.StatusCode), ScopesRequired: a.ScopesRequired} + return nil, &MCPAuthError{Code: http.StatusUnauthorized, Message: fmt.Sprintf("introspection failed with status: %d", resp.StatusCode), ScopesRequired: a.ScopesRequired} } body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) if err != nil { - return fmt.Errorf("failed to read introspection response: %w", err) + return nil, fmt.Errorf("failed to read introspection response: %w", err) } var introspectResp struct { @@ -324,19 +328,19 @@ func (a AuthService) validateOpaqueToken(ctx context.Context, tokenStr string) e } if err := json.Unmarshal(body, &introspectResp); err != nil { - return fmt.Errorf("failed to parse introspection response: %w", err) + return nil, fmt.Errorf("failed to parse introspection response: %w", err) } if !introspectResp.Active { logger.InfoContext(ctx, "token is not active") - return &MCPAuthError{Code: http.StatusUnauthorized, Message: "token is not active", ScopesRequired: a.ScopesRequired} + return nil, &MCPAuthError{Code: http.StatusUnauthorized, Message: "token is not active", ScopesRequired: a.ScopesRequired} } // Verify expiration (with 1 minute leeway) const leeway = 60 if introspectResp.Exp > 0 && time.Now().Unix() > (introspectResp.Exp+leeway) { logger.WarnContext(ctx, "token has expired: exp=%d, now=%d", introspectResp.Exp, time.Now().Unix()) - return &MCPAuthError{Code: http.StatusUnauthorized, Message: "token has expired", ScopesRequired: a.ScopesRequired} + return nil, &MCPAuthError{Code: http.StatusUnauthorized, Message: "token has expired", ScopesRequired: a.ScopesRequired} } // Extract audience @@ -351,11 +355,21 @@ func (a AuthService) validateOpaqueToken(ctx context.Context, tokenStr string) e aud = audArr } else { logger.WarnContext(ctx, "failed to parse aud claim in introspection response") - return &MCPAuthError{Code: http.StatusUnauthorized, Message: "invalid aud claim", ScopesRequired: a.ScopesRequired} + return nil, &MCPAuthError{Code: http.StatusUnauthorized, Message: "invalid aud claim", ScopesRequired: a.ScopesRequired} } } - return a.validateClaims(ctx, aud, introspectResp.Scope) + err = a.validateClaims(ctx, aud, introspectResp.Scope) + if err != nil { + return nil, err + } + claims := map[string]any{ + "active": introspectResp.Active, + "scope": introspectResp.Scope, + "aud": aud, + "exp": introspectResp.Exp, + } + return claims, nil } // validateClaims validates the audience and scopes of a token diff --git a/internal/auth/generic/generic_test.go b/internal/auth/generic/generic_test.go index d68b383409d5..74ac6f8eb88d 100644 --- a/internal/auth/generic/generic_test.go +++ b/internal/auth/generic/generic_test.go @@ -394,7 +394,7 @@ func TestValidateMCPAuth_Opaque(t *testing.T) { header := http.Header{} header.Set("Authorization", "Bearer "+tc.token) - err = genericAuth.ValidateMCPAuth(ctx, header) + _, err = genericAuth.ValidateMCPAuth(ctx, header) if tc.wantError { if err == nil { @@ -486,7 +486,7 @@ func TestValidateJwtToken(t *testing.T) { t.Fatalf("failed to create logger: %v", err) } ctx := util.WithLogger(context.Background(), logger) - err = genericAuth.validateJwtToken(ctx, tc.token) + _, err = genericAuth.validateJwtToken(ctx, tc.token) if tc.wantError { if err == nil { t.Fatalf("expected error, got nil") @@ -649,7 +649,7 @@ func TestValidateOpaqueToken(t *testing.T) { } ctx := util.WithLogger(context.Background(), logger) - err = genericAuth.validateOpaqueToken(ctx, tc.token) + _, err = genericAuth.validateOpaqueToken(ctx, tc.token) if tc.wantError { if err == nil { diff --git a/internal/server/config.go b/internal/server/config.go index efa947cefeac..5c55ec492266 100644 --- a/internal/server/config.go +++ b/internal/server/config.go @@ -316,6 +316,21 @@ func UnmarshalYAMLToolConfig(ctx context.Context, name string, r map[string]any) r["authRequired"] = []string{} } + // Parse scopesRequired if present + if rawScopes, ok := r["scopesRequired"]; ok { + if scopesList, ok := rawScopes.([]any); ok { + var scopes []string + for _, s := range scopesList { + if str, ok := s.(string); ok { + scopes = append(scopes, str) + } + } + r["scopesRequired"] = scopes + } else { + return nil, fmt.Errorf("scopesRequired must be a list of strings") + } + } + // validify parameter references if rawParams, ok := r["parameters"]; ok { if paramsList, ok := rawParams.([]any); ok { diff --git a/internal/server/mcp.go b/internal/server/mcp.go index ffcb5e57ad95..7332ed530849 100644 --- a/internal/server/mcp.go +++ b/internal/server/mcp.go @@ -23,6 +23,7 @@ import ( "fmt" "io" "net/http" + "strings" "sync" "time" @@ -333,6 +334,13 @@ func mcpRouter(s *Server) (chi.Router, error) { r.Use(middleware.AllowContentType("application/json", "application/json-rpc", "application/jsonrequest")) r.Use(middleware.StripSlashes) r.Use(render.SetContentType(render.ContentTypeJSON)) + // Inject logger into ctx + r.Use(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := util.WithLogger(r.Context(), s.logger) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + }) r.Use(mcpAuthMiddleware(s)) r.Get("/sse", func(w http.ResponseWriter, r *http.Request) { sseHandler(s, w, r) }) @@ -463,7 +471,6 @@ func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") ctx := r.Context() - ctx = util.WithLogger(ctx, s.logger) // Read body first so we can extract trace context body, err := io.ReadAll(r.Body) @@ -578,6 +585,21 @@ func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) { if errors.As(err, &clientServerErr) { w.WriteHeader(clientServerErr.Code) } + var mcpErr *generic.MCPAuthError + if errors.As(err, &mcpErr) { + switch mcpErr.Code { + case http.StatusForbidden: + w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer error="insufficient_scope", scope="%s", resource_metadata="%s", error_description="%s"`, strings.Join(mcpErr.ScopesRequired, " "), s.toolboxUrl+"/.well-known/oauth-protected-resource", mcpErr.Message)) + w.WriteHeader(http.StatusForbidden) + case http.StatusUnauthorized: + scopesArg := "" + if len(mcpErr.ScopesRequired) > 0 { + scopesArg = fmt.Sprintf(`, scope="%s"`, strings.Join(mcpErr.ScopesRequired, " ")) + } + w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer resource_metadata="%s"%s`, s.toolboxUrl+"/.well-known/oauth-protected-resource", scopesArg)) + w.WriteHeader(http.StatusUnauthorized) + } + } } } diff --git a/internal/server/mcp/v20251125/method.go b/internal/server/mcp/v20251125/method.go index 68fd5c3bfd57..5cc58af48aae 100644 --- a/internal/server/mcp/v20251125/method.go +++ b/internal/server/mcp/v20251125/method.go @@ -21,8 +21,11 @@ import ( "errors" "fmt" "net/http" + "slices" + "strings" "time" + "github.com/googleapis/mcp-toolbox/internal/auth/generic" "github.com/googleapis/mcp-toolbox/internal/prompts" "github.com/googleapis/mcp-toolbox/internal/server/mcp/jsonrpc" "github.com/googleapis/mcp-toolbox/internal/server/resources" @@ -197,6 +200,50 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re } logger.DebugContext(ctx, "tool invocation authorized") + // Find MCP enabled auth service + var mcpSvcName string + for _, aS := range authServices { + cfg := aS.ToConfig() + if genCfg, ok := cfg.(generic.Config); ok && genCfg.McpEnabled { + mcpSvcName = aS.GetName() + break + } + } + + toolScopes := tool.GetScopesRequired() + if mcpSvcName != "" && len(toolScopes) > 0 { + claims := util.AuthTokenClaimsFromContext(ctx) + if claims == nil { + err = &generic.MCPAuthError{ + Code: http.StatusForbidden, + Message: "missing claims for MCP authorization", + ScopesRequired: toolScopes, + } + return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err + } + + scopeClaim, _ := claims["scope"].(string) + tokenScopes := strings.Split(scopeClaim, " ") + + // Check if all required scopes are present in the token + missing := false + for _, ts := range toolScopes { + if !slices.Contains(tokenScopes, ts) { + missing = true + break + } + } + + if missing { + err = &generic.MCPAuthError{ + Code: http.StatusForbidden, + Message: "insufficient scopes for this tool", + ScopesRequired: toolScopes, + } + return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err + } + } + params, err := parameters.ParseParams(tool.GetParameters(), data, claimsFromAuth) if err != nil { err = fmt.Errorf("provided parameters were invalid: %w", err) diff --git a/internal/server/server.go b/internal/server/server.go index f0ccad606515..1e24e886156b 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -509,7 +509,8 @@ func mcpAuthMiddleware(s *Server) func(http.Handler) http.Handler { return } - if err := mcpSvc.ValidateMCPAuth(r.Context(), r.Header); err != nil { + claims, err := mcpSvc.ValidateMCPAuth(r.Context(), r.Header) + if err != nil { var mcpErr *generic.MCPAuthError if errors.As(err, &mcpErr) { switch mcpErr.Code { @@ -527,8 +528,13 @@ func mcpAuthMiddleware(s *Server) func(http.Handler) http.Handler { return } } + http.Error(w, err.Error(), http.StatusInternalServerError) + return } + ctx := util.WithAuthTokenClaims(r.Context(), claims) + r = r.WithContext(ctx) + next.ServeHTTP(w, r) }) } diff --git a/internal/util/util.go b/internal/util/util.go index d651aea0b2da..d1efc7390ece 100644 --- a/internal/util/util.go +++ b/internal/util/util.go @@ -214,3 +214,18 @@ func GenAIMetricAttrsFromContext(ctx context.Context) *GenAIMetricAttrs { } return nil } + +const authTokenClaimsKey contextKey = "authTokenClaims" + +// WithAuthTokenClaims adds auth token claims into the context as a value +func WithAuthTokenClaims(ctx context.Context, claims map[string]any) context.Context { + return context.WithValue(ctx, authTokenClaimsKey, claims) +} + +// AuthTokenClaimsFromContext retrieves the auth token claims from context +func AuthTokenClaimsFromContext(ctx context.Context) map[string]any { + if claims, ok := ctx.Value(authTokenClaimsKey).(map[string]any); ok { + return claims + } + return nil +} diff --git a/tests/auth/auth_integration_test.go b/tests/auth/auth_integration_test.go index 990c769b3c93..73ca3e780e3d 100644 --- a/tests/auth/auth_integration_test.go +++ b/tests/auth/auth_integration_test.go @@ -15,6 +15,7 @@ package auth import ( + "bytes" "context" "crypto/rand" "crypto/rsa" @@ -69,7 +70,7 @@ func TestMcpAuth(t *testing.T) { w.Header().Set("Content-Type", "application/json") _ = json.NewEncoder(w).Encode(map[string]interface{}{ "active": true, - "scope": "read:files", + "scope": "read:files execute:sql", "aud": "test-audience", "exp": time.Now().Add(time.Hour).Unix(), }) @@ -80,7 +81,12 @@ func TestMcpAuth(t *testing.T) { defer jwksServer.Close() toolsFile := map[string]any{ - "sources": map[string]any{}, + "sources": map[string]any{ + "my-sqlite": map[string]any{ + "type": "sqlite", + "database": ":memory:", + }, + }, "authServices": map[string]any{ "my-generic-auth": map[string]any{ "type": "generic", @@ -90,7 +96,14 @@ func TestMcpAuth(t *testing.T) { "mcpEnabled": true, }, }, - "tools": map[string]any{}, + "tools": map[string]any{ + "my-tool": map[string]any{ + "type": "sqlite-execute-sql", + "source": "my-sqlite", + "description": "Execute SQL on SQLite", + "scopesRequired": []string{"execute:sql"}, + }, + }, } args := []string{"--enable-api", "--toolbox-url=http://127.0.0.1:5000"} cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...) @@ -107,7 +120,8 @@ func TestMcpAuth(t *testing.T) { t.Fatalf("toolbox didn't start successfully: %s", err) } - api := "http://127.0.0.1:5000/mcp/sse" + apiSSE := "http://127.0.0.1:5000/mcp/sse" + apiMCP := "http://127.0.0.1:5000/mcp" // Generate invalid token (wrong scopes) invalidToken := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ @@ -119,7 +133,7 @@ func TestMcpAuth(t *testing.T) { invalidToken.Header["kid"] = "test-key-id" invalidSignedString, _ := invalidToken.SignedString(privateKey) - // Generate valid token (correct scopes) + // Generate valid token validToken := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ "aud": "test-audience", "scope": "read:files", @@ -129,9 +143,32 @@ func TestMcpAuth(t *testing.T) { validToken.Header["kid"] = "test-key-id" validSignedString, _ := validToken.SignedString(privateKey) + // Generate token with only read:files scope + tokenOnlyRead := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ + "aud": "test-audience", + "scope": "read:files", + "sub": "test-user", + "exp": time.Now().Add(time.Hour).Unix(), + }) + tokenOnlyRead.Header["kid"] = "test-key-id" + tokenOnlyReadStr, _ := tokenOnlyRead.SignedString(privateKey) + + // Generate token with BOTH scopes + tokenBoth := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ + "aud": "test-audience", + "scope": "read:files execute:sql", + "sub": "test-user", + "exp": time.Now().Add(time.Hour).Unix(), + }) + tokenBoth.Header["kid"] = "test-key-id" + tokenBothStr, _ := tokenBoth.SignedString(privateKey) + tests := []struct { name string token string + method string + url string + body []byte wantStatusCode int checkWWWAuth func(t *testing.T, authHeader string) }{ @@ -165,14 +202,77 @@ func TestMcpAuth(t *testing.T) { token: "this-is-an-opaque-token", wantStatusCode: http.StatusOK, }, + { + name: "403 Forbidden with insufficient tool scopes", + token: tokenOnlyReadStr, + method: http.MethodPost, + url: apiMCP, + body: func() []byte { + b, _ := json.Marshal(map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": map[string]any{ + "name": "my-tool", + "arguments": map[string]any{ + "sql": "SELECT 1;", + }, + }, + }) + return b + }(), + wantStatusCode: http.StatusForbidden, + checkWWWAuth: func(t *testing.T, authHeader string) { + if !strings.Contains(authHeader, `error="insufficient_scope"`) || !strings.Contains(authHeader, `scope="execute:sql"`) { + t.Fatalf("expected WWW-Authenticate header to contain error and tool scope, got: %s", authHeader) + } + }, + }, + { + name: "200 OK with sufficient tool scopes", + token: tokenBothStr, + method: http.MethodPost, + url: apiMCP, + body: func() []byte { + b, _ := json.Marshal(map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": map[string]any{ + "name": "my-tool", + "arguments": map[string]any{ + "sql": "SELECT 1;", + }, + }, + }) + return b + }(), + wantStatusCode: http.StatusOK, + }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - req, _ := http.NewRequest(http.MethodGet, api, nil) + method := tc.method + if method == "" { + method = http.MethodGet + } + url := tc.url + if url == "" { + url = apiSSE + } + var body io.Reader + if tc.body != nil { + body = bytes.NewBuffer(tc.body) + } + req, _ := http.NewRequest(method, url, body) if tc.token != "" { req.Header.Add("Authorization", "Bearer "+tc.token) } + if method == http.MethodPost { + req.Header.Set("Content-Type", "application/json") + req.Header.Set("MCP-Protocol-Version", "2025-11-25") + } resp, err := http.DefaultClient.Do(req) if err != nil { t.Fatalf("unable to send request: %s", err)