diff --git a/internal/auth/generic/generic.go b/internal/auth/generic/generic.go index 7385a6176493..bf3d22ec8eb7 100644 --- a/internal/auth/generic/generic.go +++ b/internal/auth/generic/generic.go @@ -313,7 +313,7 @@ func (a AuthService) validateOpaqueToken(ctx context.Context, tokenStr string) e body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) if err != nil { - return fmt.Errorf("failed to read introspection response: %w", err) + return &MCPAuthError{Code: http.StatusInternalServerError, Message: fmt.Sprintf("failed to read introspection response: %v", err), ScopesRequired: a.ScopesRequired} } var introspectResp struct { @@ -324,7 +324,7 @@ func (a AuthService) validateOpaqueToken(ctx context.Context, tokenStr string) e } if err := json.Unmarshal(body, &introspectResp); err != nil { - return fmt.Errorf("failed to parse introspection response: %w", err) + return &MCPAuthError{Code: http.StatusInternalServerError, Message: fmt.Sprintf("failed to parse introspection response: %v", err), ScopesRequired: a.ScopesRequired} } if !introspectResp.Active { diff --git a/internal/server/mcp.go b/internal/server/mcp.go index ffcb5e57ad95..fe38b4c2624e 100644 --- a/internal/server/mcp.go +++ b/internal/server/mcp.go @@ -333,6 +333,13 @@ func mcpRouter(s *Server) (chi.Router, error) { r.Use(middleware.AllowContentType("application/json", "application/json-rpc", "application/jsonrequest")) r.Use(middleware.StripSlashes) r.Use(render.SetContentType(render.ContentTypeJSON)) + // Inject logger into ctx + r.Use(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := util.WithLogger(r.Context(), s.logger) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + }) r.Use(mcpAuthMiddleware(s)) r.Get("/sse", func(w http.ResponseWriter, r *http.Request) { sseHandler(s, w, r) }) diff --git a/internal/server/server.go b/internal/server/server.go index f0ccad606515..2a10214ddfa8 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -527,6 +527,10 @@ func mcpAuthMiddleware(s *Server) func(http.Handler) http.Handler { return } } + // Fail closed on unexpected errors + s.logger.ErrorContext(r.Context(), "unexpected error during MCP auth validation", "error", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return } next.ServeHTTP(w, r) diff --git a/internal/server/server_test.go b/internal/server/server_test.go index ae0510b77675..ef33a1aa7936 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -15,6 +15,7 @@ package server_test import ( + "bytes" "context" "encoding/json" "fmt" @@ -25,6 +26,7 @@ import ( "reflect" "strings" "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/googleapis/mcp-toolbox/internal/auth" @@ -547,3 +549,173 @@ func TestLegacyAPIGone(t *testing.T) { t.Errorf("expected response body to contain %q, got %q", want, string(body)) } } + +func TestMCPAuthMiddleware(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Setup telemetry and logging + otelShutdown, err := telemetry.SetupOTel(ctx, "0.0.0", "", false, "toolbox") + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + defer func() { + if err := otelShutdown(ctx); err != nil { + t.Fatalf("unexpected error shutting down otel: %s", err) + } + }() + + testLogger, err := log.NewStdLogger(os.Stdout, os.Stderr, "info") + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + ctx = util.WithLogger(ctx, testLogger) + + instrumentation, err := telemetry.CreateTelemetryInstrumentation("0.0.0") + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + ctx = util.WithInstrumentation(ctx, instrumentation) + + // Setup mock introspection server + var mockResponse map[string]any + var mockStatus int + var mockRawResponse string + + mockOIDC := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/.well-known/openid-configuration" { + w.Header().Set("Content-Type", "application/json") + fmt.Fprintf(w, `{"issuer": "http://%s", "jwks_uri": "http://%s/jwks", "introspection_endpoint": "http://%s/introspect"}`, r.Host, r.Host, r.Host) + return + } + if r.URL.Path == "/jwks" { + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"keys": []}`) + return + } + if r.URL.Path == "/introspect" { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(mockStatus) + if mockRawResponse != "" { + _, _ = w.Write([]byte(mockRawResponse)) + } else { + _ = json.NewEncoder(w).Encode(mockResponse) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer mockOIDC.Close() + + // Configure the server + addr, port := "127.0.0.1", 5004 + cfg := server.ServerConfig{ + Version: "0.0.0", + Address: addr, + Port: port, + ToolboxUrl: "https://my-toolbox.example.com", + AllowedHosts: []string{"*"}, + AuthServiceConfigs: map[string]auth.AuthServiceConfig{ + "generic1": generic.Config{ + Name: "generic1", + Type: generic.AuthServiceType, + McpEnabled: true, + AuthorizationServer: mockOIDC.URL, + ScopesRequired: []string{"mcp"}, + }, + }, + } + + // Initialize and start the server + s, err := server.NewServer(ctx, cfg) + if err != nil { + t.Fatalf("unable to initialize server: %v", err) + } + + if err := s.Listen(ctx); err != nil { + t.Fatalf("unable to start server: %v", err) + } + + errCh := make(chan error) + go func() { + defer close(errCh) + if err := s.Serve(ctx); err != nil && err != http.ErrServerClosed { + errCh <- err + } + }() + defer func() { + if err := s.Shutdown(ctx); err != nil { + t.Errorf("failed to cleanly shutdown server: %v", err) + } + }() + + tests := []struct { + name string + token string + setupMock func() + wantStatusCode int + }{ + { + name: "valid opaque token", + token: "valid-token", + setupMock: func() { + mockStatus = http.StatusOK + mockResponse = map[string]any{ + "active": true, + "scope": "mcp", + "aud": "test-audience", + "exp": time.Now().Add(time.Hour).Unix(), + } + mockRawResponse = "" + }, + wantStatusCode: http.StatusOK, + }, + { + name: "insufficient scope", + token: "bad-scope-token", + setupMock: func() { + mockStatus = http.StatusOK + mockResponse = map[string]any{ + "active": true, + "scope": "wrong-scope", + "aud": "test-audience", + "exp": time.Now().Add(time.Hour).Unix(), + } + mockRawResponse = "" + }, + wantStatusCode: http.StatusForbidden, + }, + { + name: "malformed introspection", + token: "any-token", + setupMock: func() { + mockStatus = http.StatusOK + mockRawResponse = "{invalid json}" + }, + wantStatusCode: http.StatusInternalServerError, + }, + } + + url := fmt.Sprintf("http://%s:%d/mcp", addr, port) + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + tc.setupMock() + + reqBody := []byte(`{"jsonrpc":"2.0","id":1,"method":"ping"}`) + req, _ := http.NewRequest("POST", url, bytes.NewBuffer(reqBody)) + req.Header.Set("Authorization", "Bearer "+tc.token) + req.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != tc.wantStatusCode { + t.Errorf("expected status %d, got %d", tc.wantStatusCode, resp.StatusCode) + } + }) + } +}