Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions internal/auth/generic/generic.go
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ func (a AuthService) validateOpaqueToken(ctx context.Context, tokenStr string) e

body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return fmt.Errorf("failed to read introspection response: %w", err)
return &MCPAuthError{Code: http.StatusInternalServerError, Message: fmt.Sprintf("failed to read introspection response: %v", err), ScopesRequired: a.ScopesRequired}
}

var introspectResp struct {
Expand All @@ -324,7 +324,7 @@ func (a AuthService) validateOpaqueToken(ctx context.Context, tokenStr string) e
}

if err := json.Unmarshal(body, &introspectResp); err != nil {
return fmt.Errorf("failed to parse introspection response: %w", err)
return &MCPAuthError{Code: http.StatusInternalServerError, Message: fmt.Sprintf("failed to parse introspection response: %v", err), ScopesRequired: a.ScopesRequired}
}

if !introspectResp.Active {
Expand Down
7 changes: 7 additions & 0 deletions internal/server/mcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,13 @@ func mcpRouter(s *Server) (chi.Router, error) {
r.Use(middleware.AllowContentType("application/json", "application/json-rpc", "application/jsonrequest"))
r.Use(middleware.StripSlashes)
r.Use(render.SetContentType(render.ContentTypeJSON))
// Inject logger into ctx
r.Use(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := util.WithLogger(r.Context(), s.logger)
next.ServeHTTP(w, r.WithContext(ctx))
})
})
r.Use(mcpAuthMiddleware(s))

r.Get("/sse", func(w http.ResponseWriter, r *http.Request) { sseHandler(s, w, r) })
Expand Down
4 changes: 4 additions & 0 deletions internal/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,10 @@ func mcpAuthMiddleware(s *Server) func(http.Handler) http.Handler {
return
}
}
// Fail closed on unexpected errors
s.logger.ErrorContext(r.Context(), "unexpected error during MCP auth validation", "error", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}

next.ServeHTTP(w, r)
Expand Down
172 changes: 172 additions & 0 deletions internal/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package server_test

import (
"bytes"
"context"
"encoding/json"
"fmt"
Expand All @@ -25,6 +26,7 @@ import (
"reflect"
"strings"
"testing"
"time"

"github.com/google/go-cmp/cmp"
"github.com/googleapis/mcp-toolbox/internal/auth"
Expand Down Expand Up @@ -547,3 +549,173 @@ func TestLegacyAPIGone(t *testing.T) {
t.Errorf("expected response body to contain %q, got %q", want, string(body))
}
}

func TestMCPAuthMiddleware(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

// Setup telemetry and logging
otelShutdown, err := telemetry.SetupOTel(ctx, "0.0.0", "", false, "toolbox")
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
defer func() {
if err := otelShutdown(ctx); err != nil {
t.Fatalf("unexpected error shutting down otel: %s", err)
}
}()

testLogger, err := log.NewStdLogger(os.Stdout, os.Stderr, "info")
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
ctx = util.WithLogger(ctx, testLogger)

instrumentation, err := telemetry.CreateTelemetryInstrumentation("0.0.0")
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
ctx = util.WithInstrumentation(ctx, instrumentation)

// Setup mock introspection server
var mockResponse map[string]any
var mockStatus int
var mockRawResponse string

mockOIDC := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/.well-known/openid-configuration" {
w.Header().Set("Content-Type", "application/json")
fmt.Fprintf(w, `{"issuer": "http://%s", "jwks_uri": "http://%s/jwks", "introspection_endpoint": "http://%s/introspect"}`, r.Host, r.Host, r.Host)
return
}
if r.URL.Path == "/jwks" {
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, `{"keys": []}`)
return
}
if r.URL.Path == "/introspect" {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(mockStatus)
if mockRawResponse != "" {
_, _ = w.Write([]byte(mockRawResponse))
} else {
_ = json.NewEncoder(w).Encode(mockResponse)
}
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer mockOIDC.Close()

// Configure the server
addr, port := "127.0.0.1", 5004
cfg := server.ServerConfig{
Version: "0.0.0",
Address: addr,
Port: port,
ToolboxUrl: "https://my-toolbox.example.com",
AllowedHosts: []string{"*"},
AuthServiceConfigs: map[string]auth.AuthServiceConfig{
"generic1": generic.Config{
Name: "generic1",
Type: generic.AuthServiceType,
McpEnabled: true,
AuthorizationServer: mockOIDC.URL,
ScopesRequired: []string{"mcp"},
},
},
}

// Initialize and start the server
s, err := server.NewServer(ctx, cfg)
if err != nil {
t.Fatalf("unable to initialize server: %v", err)
}

if err := s.Listen(ctx); err != nil {
t.Fatalf("unable to start server: %v", err)
}

errCh := make(chan error)
go func() {
defer close(errCh)
if err := s.Serve(ctx); err != nil && err != http.ErrServerClosed {
errCh <- err
}
}()
defer func() {
if err := s.Shutdown(ctx); err != nil {
t.Errorf("failed to cleanly shutdown server: %v", err)
}
}()

tests := []struct {
name string
token string
setupMock func()
wantStatusCode int
}{
{
name: "valid opaque token",
token: "valid-token",
setupMock: func() {
mockStatus = http.StatusOK
mockResponse = map[string]any{
"active": true,
"scope": "mcp",
"aud": "test-audience",
"exp": time.Now().Add(time.Hour).Unix(),
}
mockRawResponse = ""
},
wantStatusCode: http.StatusOK,
},
{
name: "insufficient scope",
token: "bad-scope-token",
setupMock: func() {
mockStatus = http.StatusOK
mockResponse = map[string]any{
"active": true,
"scope": "wrong-scope",
"aud": "test-audience",
"exp": time.Now().Add(time.Hour).Unix(),
}
mockRawResponse = ""
},
wantStatusCode: http.StatusForbidden,
},
{
name: "malformed introspection",
token: "any-token",
setupMock: func() {
mockStatus = http.StatusOK
mockRawResponse = "{invalid json}"
},
wantStatusCode: http.StatusInternalServerError,
},
}

url := fmt.Sprintf("http://%s:%d/mcp", addr, port)

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
tc.setupMock()

reqBody := []byte(`{"jsonrpc":"2.0","id":1,"method":"ping"}`)
req, _ := http.NewRequest("POST", url, bytes.NewBuffer(reqBody))
req.Header.Set("Authorization", "Bearer "+tc.token)
req.Header.Set("Content-Type", "application/json")

resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()

if resp.StatusCode != tc.wantStatusCode {
t.Errorf("expected status %d, got %d", tc.wantStatusCode, resp.StatusCode)
}
})
}
}
Loading