From edb3e0c5efe2cdd63403646192ad7e39ed71efc8 Mon Sep 17 00:00:00 2001 From: sjhddh Date: Sun, 12 Apr 2026 19:13:28 +0200 Subject: [PATCH 1/2] fix: enforce toolset/promptset boundary on tools/call and prompts/get tools/call and prompts/get resolved tools and prompts via the global resourceMgr without verifying membership in the current toolset or promptset. This allowed clients connected to a scoped toolset to invoke any tool by name, bypassing toolset access boundaries (IDOR). Add ContainsTool/ContainsPrompt membership checks and enforce them in all four MCP protocol versions before delegating to resourceMgr. Return the same "does not exist" error on boundary violation as on a truly missing tool/prompt so clients cannot distinguish the two cases. Fixes #2755 --- internal/prompts/promptsets.go | 10 +++ internal/prompts/promptsets_test.go | 64 +++++++++++++++++++ internal/server/mcp/v20241105/method.go | 20 ++++-- internal/server/mcp/v20250326/method.go | 20 ++++-- internal/server/mcp/v20250618/method.go | 20 ++++-- internal/server/mcp/v20251125/method.go | 20 ++++-- internal/tools/toolsets.go | 10 +++ internal/tools/toolsets_test.go | 85 +++++++++++++++++++++++++ 8 files changed, 233 insertions(+), 16 deletions(-) create mode 100644 internal/tools/toolsets_test.go diff --git a/internal/prompts/promptsets.go b/internal/prompts/promptsets.go index 2d3cf6315246..a147ca9c8917 100644 --- a/internal/prompts/promptsets.go +++ b/internal/prompts/promptsets.go @@ -36,6 +36,16 @@ func (p Promptset) ToConfig() PromptsetConfig { return p.PromptsetConfig } +// ContainsPrompt reports whether the promptset includes a prompt with the given name. +func (p Promptset) ContainsPrompt(name string) bool { + for _, n := range p.PromptNames { + if n == name { + return true + } + } + return false +} + type PromptsetManifest struct { ServerVersion string `json:"serverVersion"` PromptsManifest map[string]Manifest `json:"prompts"` diff --git a/internal/prompts/promptsets_test.go b/internal/prompts/promptsets_test.go index 170120de5e1c..8ed1900cb58c 100644 --- a/internal/prompts/promptsets_test.go +++ b/internal/prompts/promptsets_test.go @@ -65,6 +65,70 @@ func newMockPrompt(name, desc string) prompts.Prompt { } } +func TestPromptset_ContainsPrompt(t *testing.T) { + t.Parallel() + + promptset := prompts.Promptset{ + PromptsetConfig: prompts.PromptsetConfig{ + Name: "test-promptset", + PromptNames: []string{"greet", "summarize"}, + }, + } + + tests := []struct { + name string + promptName string + want bool + }{ + { + name: "prompt exists in promptset", + promptName: "greet", + want: true, + }, + { + name: "another prompt exists in promptset", + promptName: "summarize", + want: true, + }, + { + name: "prompt not in promptset", + promptName: "admin_prompt", + want: false, + }, + { + name: "empty prompt name", + promptName: "", + want: false, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + got := promptset.ContainsPrompt(tc.promptName) + if got != tc.want { + t.Errorf("ContainsPrompt(%q) = %v, want %v", tc.promptName, got, tc.want) + } + }) + } +} + +func TestPromptset_ContainsPrompt_EmptyPromptset(t *testing.T) { + t.Parallel() + + promptset := prompts.Promptset{ + PromptsetConfig: prompts.PromptsetConfig{ + Name: "empty-promptset", + PromptNames: []string{}, + }, + } + + if promptset.ContainsPrompt("anything") { + t.Error("ContainsPrompt should return false for empty promptset") + } +} + func TestPromptsetConfig_Initialize(t *testing.T) { t.Parallel() diff --git a/internal/server/mcp/v20241105/method.go b/internal/server/mcp/v20241105/method.go index efcbab26f64c..b92b0fb841c5 100644 --- a/internal/server/mcp/v20241105/method.go +++ b/internal/server/mcp/v20241105/method.go @@ -42,11 +42,11 @@ func ProcessMethod(ctx context.Context, id jsonrpc.RequestId, method string, too case TOOLS_LIST: return toolsListHandler(id, toolset, body) case TOOLS_CALL: - return toolsCallHandler(ctx, id, resourceMgr, body, header) + return toolsCallHandler(ctx, id, toolset, resourceMgr, body, header) case PROMPTS_LIST: return promptsListHandler(ctx, id, promptset, body) case PROMPTS_GET: - return promptsGetHandler(ctx, id, resourceMgr, body) + return promptsGetHandler(ctx, id, promptset, resourceMgr, body) default: err := fmt.Errorf("invalid method %s", method) return jsonrpc.NewError(id, jsonrpc.METHOD_NOT_FOUND, err.Error(), nil), err @@ -87,7 +87,7 @@ func toolsListHandler(id jsonrpc.RequestId, toolset tools.Toolset, body []byte) } // toolsCallHandler generate a response for tools call. -func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *resources.ResourceManager, body []byte, header http.Header) (any, error) { +func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, toolset tools.Toolset, resourceMgr *resources.ResourceManager, body []byte, header http.Header) (any, error) { authServices := resourceMgr.GetAuthServiceMap() // retrieve logger from context @@ -114,6 +114,12 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re attribute.String("gen_ai.operation.name", "execute_tool"), ) + // Verify tool belongs to the current toolset before resolving globally. + if !toolset.ContainsTool(toolName) { + err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName) + return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err + } + tool, ok := resourceMgr.GetTool(toolName) if !ok { err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName) @@ -339,7 +345,7 @@ func promptsListHandler(ctx context.Context, id jsonrpc.RequestId, promptset pro } // promptsGetHandler handles the "prompts/get" method. -func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *resources.ResourceManager, body []byte) (any, error) { +func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, promptset prompts.Promptset, resourceMgr *resources.ResourceManager, body []byte) (any, error) { // retrieve logger from context logger, err := util.LoggerFromContext(ctx) if err != nil { @@ -361,6 +367,12 @@ func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *r span.SetName(fmt.Sprintf("%s %s", PROMPTS_GET, promptName)) span.SetAttributes(attribute.String("gen_ai.prompt.name", promptName)) + // Verify prompt belongs to the current promptset before resolving globally. + if !promptset.ContainsPrompt(promptName) { + err := fmt.Errorf("prompt with name %q does not exist", promptName) + return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err + } + prompt, ok := resourceMgr.GetPrompt(promptName) if !ok { err := fmt.Errorf("prompt with name %q does not exist", promptName) diff --git a/internal/server/mcp/v20250326/method.go b/internal/server/mcp/v20250326/method.go index 1d4292f38467..a24e00a57ff7 100644 --- a/internal/server/mcp/v20250326/method.go +++ b/internal/server/mcp/v20250326/method.go @@ -42,11 +42,11 @@ func ProcessMethod(ctx context.Context, id jsonrpc.RequestId, method string, too case TOOLS_LIST: return toolsListHandler(id, toolset, body) case TOOLS_CALL: - return toolsCallHandler(ctx, id, resourceMgr, body, header) + return toolsCallHandler(ctx, id, toolset, resourceMgr, body, header) case PROMPTS_LIST: return promptsListHandler(ctx, id, promptset, body) case PROMPTS_GET: - return promptsGetHandler(ctx, id, resourceMgr, body) + return promptsGetHandler(ctx, id, promptset, resourceMgr, body) default: err := fmt.Errorf("invalid method %s", method) return jsonrpc.NewError(id, jsonrpc.METHOD_NOT_FOUND, err.Error(), nil), err @@ -87,7 +87,7 @@ func toolsListHandler(id jsonrpc.RequestId, toolset tools.Toolset, body []byte) } // toolsCallHandler generate a response for tools call. -func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *resources.ResourceManager, body []byte, header http.Header) (any, error) { +func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, toolset tools.Toolset, resourceMgr *resources.ResourceManager, body []byte, header http.Header) (any, error) { authServices := resourceMgr.GetAuthServiceMap() // retrieve logger from context @@ -114,6 +114,12 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re attribute.String("gen_ai.operation.name", "execute_tool"), ) + // Verify tool belongs to the current toolset before resolving globally. + if !toolset.ContainsTool(toolName) { + err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName) + return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err + } + tool, ok := resourceMgr.GetTool(toolName) if !ok { err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName) @@ -338,7 +344,7 @@ func promptsListHandler(ctx context.Context, id jsonrpc.RequestId, promptset pro } // promptsGetHandler handles the "prompts/get" method. -func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *resources.ResourceManager, body []byte) (any, error) { +func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, promptset prompts.Promptset, resourceMgr *resources.ResourceManager, body []byte) (any, error) { // retrieve logger from context logger, err := util.LoggerFromContext(ctx) if err != nil { @@ -360,6 +366,12 @@ func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *r span.SetName(fmt.Sprintf("%s %s", PROMPTS_GET, promptName)) span.SetAttributes(attribute.String("gen_ai.prompt.name", promptName)) + // Verify prompt belongs to the current promptset before resolving globally. + if !promptset.ContainsPrompt(promptName) { + err := fmt.Errorf("prompt with name %q does not exist", promptName) + return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err + } + prompt, ok := resourceMgr.GetPrompt(promptName) if !ok { err := fmt.Errorf("prompt with name %q does not exist", promptName) diff --git a/internal/server/mcp/v20250618/method.go b/internal/server/mcp/v20250618/method.go index 529bd90e3870..1cd34520af78 100644 --- a/internal/server/mcp/v20250618/method.go +++ b/internal/server/mcp/v20250618/method.go @@ -42,11 +42,11 @@ func ProcessMethod(ctx context.Context, id jsonrpc.RequestId, method string, too case TOOLS_LIST: return toolsListHandler(id, toolset, body) case TOOLS_CALL: - return toolsCallHandler(ctx, id, resourceMgr, body, header) + return toolsCallHandler(ctx, id, toolset, resourceMgr, body, header) case PROMPTS_LIST: return promptsListHandler(ctx, id, promptset, body) case PROMPTS_GET: - return promptsGetHandler(ctx, id, resourceMgr, body) + return promptsGetHandler(ctx, id, promptset, resourceMgr, body) default: err := fmt.Errorf("invalid method %s", method) return jsonrpc.NewError(id, jsonrpc.METHOD_NOT_FOUND, err.Error(), nil), err @@ -80,7 +80,7 @@ func toolsListHandler(id jsonrpc.RequestId, toolset tools.Toolset, body []byte) } // toolsCallHandler generate a response for tools call. -func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *resources.ResourceManager, body []byte, header http.Header) (any, error) { +func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, toolset tools.Toolset, resourceMgr *resources.ResourceManager, body []byte, header http.Header) (any, error) { authServices := resourceMgr.GetAuthServiceMap() // retrieve logger from context @@ -107,6 +107,12 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re attribute.String("gen_ai.operation.name", "execute_tool"), ) + // Verify tool belongs to the current toolset before resolving globally. + if !toolset.ContainsTool(toolName) { + err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName) + return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err + } + tool, ok := resourceMgr.GetTool(toolName) if !ok { err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName) @@ -332,7 +338,7 @@ func promptsListHandler(ctx context.Context, id jsonrpc.RequestId, promptset pro } // promptsGetHandler handles the "prompts/get" method. -func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *resources.ResourceManager, body []byte) (any, error) { +func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, promptset prompts.Promptset, resourceMgr *resources.ResourceManager, body []byte) (any, error) { // retrieve logger from context logger, err := util.LoggerFromContext(ctx) if err != nil { @@ -354,6 +360,12 @@ func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *r span.SetName(fmt.Sprintf("%s %s", PROMPTS_GET, promptName)) span.SetAttributes(attribute.String("gen_ai.prompt.name", promptName)) + // Verify prompt belongs to the current promptset before resolving globally. + if !promptset.ContainsPrompt(promptName) { + err := fmt.Errorf("prompt with name %q does not exist", promptName) + return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err + } + prompt, ok := resourceMgr.GetPrompt(promptName) if !ok { err := fmt.Errorf("prompt with name %q does not exist", promptName) diff --git a/internal/server/mcp/v20251125/method.go b/internal/server/mcp/v20251125/method.go index 68fd5c3bfd57..d336443cd999 100644 --- a/internal/server/mcp/v20251125/method.go +++ b/internal/server/mcp/v20251125/method.go @@ -42,11 +42,11 @@ func ProcessMethod(ctx context.Context, id jsonrpc.RequestId, method string, too case TOOLS_LIST: return toolsListHandler(id, toolset, body) case TOOLS_CALL: - return toolsCallHandler(ctx, id, resourceMgr, body, header) + return toolsCallHandler(ctx, id, toolset, resourceMgr, body, header) case PROMPTS_LIST: return promptsListHandler(ctx, id, promptset, body) case PROMPTS_GET: - return promptsGetHandler(ctx, id, resourceMgr, body) + return promptsGetHandler(ctx, id, promptset, resourceMgr, body) default: err := fmt.Errorf("invalid method %s", method) return jsonrpc.NewError(id, jsonrpc.METHOD_NOT_FOUND, err.Error(), nil), err @@ -80,7 +80,7 @@ func toolsListHandler(id jsonrpc.RequestId, toolset tools.Toolset, body []byte) } // toolsCallHandler generate a response for tools call. -func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *resources.ResourceManager, body []byte, header http.Header) (any, error) { +func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, toolset tools.Toolset, resourceMgr *resources.ResourceManager, body []byte, header http.Header) (any, error) { authServices := resourceMgr.GetAuthServiceMap() // retrieve logger from context @@ -107,6 +107,12 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re attribute.String("gen_ai.operation.name", "execute_tool"), ) + // Verify tool belongs to the current toolset before resolving globally. + if !toolset.ContainsTool(toolName) { + err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName) + return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err + } + tool, ok := resourceMgr.GetTool(toolName) if !ok { err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName) @@ -332,7 +338,7 @@ func promptsListHandler(ctx context.Context, id jsonrpc.RequestId, promptset pro } // promptsGetHandler handles the "prompts/get" method. -func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *resources.ResourceManager, body []byte) (any, error) { +func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, promptset prompts.Promptset, resourceMgr *resources.ResourceManager, body []byte) (any, error) { // retrieve logger from context logger, err := util.LoggerFromContext(ctx) if err != nil { @@ -354,6 +360,12 @@ func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *r span.SetName(fmt.Sprintf("%s %s", PROMPTS_GET, promptName)) span.SetAttributes(attribute.String("gen_ai.prompt.name", promptName)) + // Verify prompt belongs to the current promptset before resolving globally. + if !promptset.ContainsPrompt(promptName) { + err := fmt.Errorf("prompt with name %q does not exist", promptName) + return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err + } + prompt, ok := resourceMgr.GetPrompt(promptName) if !ok { err := fmt.Errorf("prompt with name %q does not exist", promptName) diff --git a/internal/tools/toolsets.go b/internal/tools/toolsets.go index b429ef5b19df..008ba844e5c6 100644 --- a/internal/tools/toolsets.go +++ b/internal/tools/toolsets.go @@ -35,6 +35,16 @@ func (t Toolset) ToConfig() ToolsetConfig { return t.ToolsetConfig } +// ContainsTool reports whether the toolset includes a tool with the given name. +func (t Toolset) ContainsTool(name string) bool { + for _, n := range t.ToolNames { + if n == name { + return true + } + } + return false +} + type ToolsetManifest struct { ServerVersion string `json:"serverVersion"` ToolsManifest map[string]Manifest `json:"tools"` diff --git a/internal/tools/toolsets_test.go b/internal/tools/toolsets_test.go new file mode 100644 index 000000000000..e72e49d3e878 --- /dev/null +++ b/internal/tools/toolsets_test.go @@ -0,0 +1,85 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools_test + +import ( + "testing" + + "github.com/googleapis/mcp-toolbox/internal/tools" +) + +func TestToolset_ContainsTool(t *testing.T) { + t.Parallel() + + toolset := tools.Toolset{ + ToolsetConfig: tools.ToolsetConfig{ + Name: "test-toolset", + ToolNames: []string{"echo", "list_tables"}, + }, + } + + tests := []struct { + name string + toolName string + want bool + }{ + { + name: "tool exists in toolset", + toolName: "echo", + want: true, + }, + { + name: "another tool exists in toolset", + toolName: "list_tables", + want: true, + }, + { + name: "tool not in toolset", + toolName: "admin_delete", + want: false, + }, + { + name: "empty tool name", + toolName: "", + want: false, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + got := toolset.ContainsTool(tc.toolName) + if got != tc.want { + t.Errorf("ContainsTool(%q) = %v, want %v", tc.toolName, got, tc.want) + } + }) + } +} + +func TestToolset_ContainsTool_EmptyToolset(t *testing.T) { + t.Parallel() + + toolset := tools.Toolset{ + ToolsetConfig: tools.ToolsetConfig{ + Name: "empty-toolset", + ToolNames: []string{}, + }, + } + + if toolset.ContainsTool("anything") { + t.Error("ContainsTool should return false for empty toolset") + } +} From 1e5a7f28f28a6c5adb5c3542ab235a90e6673c62 Mon Sep 17 00:00:00 2001 From: sjhddh Date: Fri, 24 Apr 2026 20:31:20 +0200 Subject: [PATCH 2/2] perf: use O(1) map lookups for ContainsTool/ContainsPrompt Addresses review feedback on the boundary-enforcement gate: previously ContainsTool/ContainsPrompt performed a linear scan of ToolNames/PromptNames on every tools/call and prompts/get. For large toolsets/promptsets this is a measurable overhead on a hot path. Introduce unexported toolNameSet/promptNameSet maps, populated during Initialize, and use them for O(1) membership checks. Fall back to the linear scan when the Toolset/Promptset was constructed directly (e.g., in tests that bypass Initialize), preserving the zero-value contract. Test comparisons that diff full Toolset/Promptset structs now use cmp.AllowUnexported / cmpopts.IgnoreUnexported so the private cache doesn't affect equality. All existing tests pass. --- internal/prompts/promptsets.go | 15 ++++++++++++--- internal/prompts/promptsets_test.go | 5 +++-- internal/server/resources/resources_test.go | 4 ++-- internal/server/server_test.go | 4 ++-- internal/tools/toolsets.go | 9 +++++++++ 5 files changed, 28 insertions(+), 9 deletions(-) diff --git a/internal/prompts/promptsets.go b/internal/prompts/promptsets.go index a147ca9c8917..252473136b50 100644 --- a/internal/prompts/promptsets.go +++ b/internal/prompts/promptsets.go @@ -27,9 +27,10 @@ type PromptsetConfig struct { type Promptset struct { PromptsetConfig - Prompts []*Prompt `yaml:",inline"` - Manifest PromptsetManifest `yaml:",inline"` - McpManifest []McpManifest `yaml:",inline"` + Prompts []*Prompt `yaml:",inline"` + Manifest PromptsetManifest `yaml:",inline"` + McpManifest []McpManifest `yaml:",inline"` + promptNameSet map[string]struct{} } func (p Promptset) ToConfig() PromptsetConfig { @@ -37,7 +38,13 @@ func (p Promptset) ToConfig() PromptsetConfig { } // ContainsPrompt reports whether the promptset includes a prompt with the given name. +// When built via Initialize, lookups are O(1) via promptNameSet; for Promptsets +// constructed directly (e.g., in tests), falls back to a linear scan of PromptNames. func (p Promptset) ContainsPrompt(name string) bool { + if p.promptNameSet != nil { + _, ok := p.promptNameSet[name] + return ok + } for _, n := range p.PromptNames { if n == name { return true @@ -64,6 +71,7 @@ func (t PromptsetConfig) Initialize(serverVersion string, promptsMap map[string] ServerVersion: serverVersion, PromptsManifest: make(map[string]Manifest, len(t.PromptNames)), } + promptset.promptNameSet = make(map[string]struct{}, len(t.PromptNames)) for _, promptName := range t.PromptNames { prompt, ok := promptsMap[promptName] if !ok { @@ -72,6 +80,7 @@ func (t PromptsetConfig) Initialize(serverVersion string, promptsMap map[string] promptset.Prompts = append(promptset.Prompts, &prompt) promptset.Manifest.PromptsManifest[promptName] = prompt.Manifest() promptset.McpManifest = append(promptset.McpManifest, prompt.McpManifest()) + promptset.promptNameSet[promptName] = struct{}{} } return promptset, nil diff --git a/internal/prompts/promptsets_test.go b/internal/prompts/promptsets_test.go index 8ed1900cb58c..d84e734e9d63 100644 --- a/internal/prompts/promptsets_test.go +++ b/internal/prompts/promptsets_test.go @@ -19,6 +19,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "github.com/googleapis/mcp-toolbox/internal/prompts" "github.com/googleapis/mcp-toolbox/internal/util/parameters" ) @@ -271,7 +272,7 @@ func TestPromptsetConfig_Initialize(t *testing.T) { t.Errorf("Initialize() error mismatch:\n want to contain: %q\n got: %q", tc.wantErr, err.Error()) } // Also check that the partially populated struct matches - if diff := cmp.Diff(tc.want, got, cmp.AllowUnexported(mockPrompt{})); diff != "" { + if diff := cmp.Diff(tc.want, got, cmp.AllowUnexported(mockPrompt{}), cmpopts.IgnoreUnexported(prompts.Promptset{})); diff != "" { t.Errorf("Initialize() partial result on error mismatch (-want +got):\n%s", diff) } } else { @@ -279,7 +280,7 @@ func TestPromptsetConfig_Initialize(t *testing.T) { t.Fatalf("Initialize() returned unexpected error: %v", err) } // Using cmp.AllowUnexported because mockPrompt is unexported - if diff := cmp.Diff(tc.want, got, cmp.AllowUnexported(mockPrompt{})); diff != "" { + if diff := cmp.Diff(tc.want, got, cmp.AllowUnexported(mockPrompt{}), cmpopts.IgnoreUnexported(prompts.Promptset{})); diff != "" { t.Errorf("Initialize() result mismatch (-want +got):\n%s", diff) } } diff --git a/internal/server/resources/resources_test.go b/internal/server/resources/resources_test.go index b1a6da811838..64940e2e881f 100644 --- a/internal/server/resources/resources_test.go +++ b/internal/server/resources/resources_test.go @@ -74,7 +74,7 @@ func TestUpdateServer(t *testing.T) { } gotToolset, _ := resMgr.GetToolset("example-toolset") - if diff := cmp.Diff(gotToolset, newToolsets["example-toolset"]); diff != "" { + if diff := cmp.Diff(gotToolset, newToolsets["example-toolset"], cmp.AllowUnexported(tools.Toolset{})); diff != "" { t.Errorf("error updating server, toolset (-want +got):\n%s", diff) } @@ -84,7 +84,7 @@ func TestUpdateServer(t *testing.T) { } gotPromptset, _ := resMgr.GetPromptset("example-promptset") - if diff := cmp.Diff(gotPromptset, newPromptsets["example-promptset"]); diff != "" { + if diff := cmp.Diff(gotPromptset, newPromptsets["example-promptset"], cmp.AllowUnexported(prompts.Promptset{})); diff != "" { t.Errorf("error updating server, promptset (-want +got):\n%s", diff) } diff --git a/internal/server/server_test.go b/internal/server/server_test.go index ae0510b77675..c697ce731854 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -190,7 +190,7 @@ func TestUpdateServer(t *testing.T) { } gotToolset, _ := s.ResourceMgr.GetToolset("example-toolset") - if diff := cmp.Diff(gotToolset, newToolsets["example-toolset"]); diff != "" { + if diff := cmp.Diff(gotToolset, newToolsets["example-toolset"], cmp.AllowUnexported(tools.Toolset{})); diff != "" { t.Errorf("error updating server, toolset (-want +got):\n%s", diff) } @@ -200,7 +200,7 @@ func TestUpdateServer(t *testing.T) { } gotPromptset, _ := s.ResourceMgr.GetPromptset("example-promptset") - if diff := cmp.Diff(gotPromptset, newPromptsets["example-promptset"]); diff != "" { + if diff := cmp.Diff(gotPromptset, newPromptsets["example-promptset"], cmp.AllowUnexported(prompts.Promptset{})); diff != "" { t.Errorf("error updating server, promptset (-want +got):\n%s", diff) } } diff --git a/internal/tools/toolsets.go b/internal/tools/toolsets.go index 008ba844e5c6..ab0a21e6770f 100644 --- a/internal/tools/toolsets.go +++ b/internal/tools/toolsets.go @@ -29,6 +29,7 @@ type Toolset struct { Tools []*Tool `yaml:",inline"` Manifest ToolsetManifest `yaml:",inline"` McpManifest []McpManifest `yaml:",inline"` + toolNameSet map[string]struct{} } func (t Toolset) ToConfig() ToolsetConfig { @@ -36,7 +37,13 @@ func (t Toolset) ToConfig() ToolsetConfig { } // ContainsTool reports whether the toolset includes a tool with the given name. +// When built via Initialize, lookups are O(1) via toolNameSet; for Toolsets +// constructed directly (e.g., in tests), falls back to a linear scan of ToolNames. func (t Toolset) ContainsTool(name string) bool { + if t.toolNameSet != nil { + _, ok := t.toolNameSet[name] + return ok + } for _, n := range t.ToolNames { if n == name { return true @@ -63,6 +70,7 @@ func (t ToolsetConfig) Initialize(serverVersion string, toolsMap map[string]Tool ServerVersion: serverVersion, ToolsManifest: make(map[string]Manifest), } + toolset.toolNameSet = make(map[string]struct{}, len(t.ToolNames)) for _, toolName := range t.ToolNames { tool, ok := toolsMap[toolName] if !ok { @@ -71,6 +79,7 @@ func (t ToolsetConfig) Initialize(serverVersion string, toolsMap map[string]Tool toolset.Tools = append(toolset.Tools, &tool) toolset.Manifest.ToolsManifest[toolName] = tool.Manifest() toolset.McpManifest = append(toolset.McpManifest, tool.McpManifest()) + toolset.toolNameSet[toolName] = struct{}{} } return toolset, nil