-
Notifications
You must be signed in to change notification settings - Fork 1.5k
fix: enforce toolset/promptset boundary on tools/call and prompts/get #3036
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -42,11 +42,11 @@ func ProcessMethod(ctx context.Context, id jsonrpc.RequestId, method string, too | |
| case TOOLS_LIST: | ||
| return toolsListHandler(id, toolset, body) | ||
| case TOOLS_CALL: | ||
| return toolsCallHandler(ctx, id, resourceMgr, body, header) | ||
| return toolsCallHandler(ctx, id, toolset, resourceMgr, body, header) | ||
| case PROMPTS_LIST: | ||
| return promptsListHandler(ctx, id, promptset, body) | ||
| case PROMPTS_GET: | ||
| return promptsGetHandler(ctx, id, resourceMgr, body) | ||
| return promptsGetHandler(ctx, id, promptset, resourceMgr, body) | ||
| default: | ||
| err := fmt.Errorf("invalid method %s", method) | ||
| return jsonrpc.NewError(id, jsonrpc.METHOD_NOT_FOUND, err.Error(), nil), err | ||
|
|
@@ -87,7 +87,7 @@ func toolsListHandler(id jsonrpc.RequestId, toolset tools.Toolset, body []byte) | |
| } | ||
|
|
||
| // toolsCallHandler generate a response for tools call. | ||
| func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *resources.ResourceManager, body []byte, header http.Header) (any, error) { | ||
| func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, toolset tools.Toolset, resourceMgr *resources.ResourceManager, body []byte, header http.Header) (any, error) { | ||
| authServices := resourceMgr.GetAuthServiceMap() | ||
|
|
||
| // retrieve logger from context | ||
|
|
@@ -114,6 +114,12 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re | |
| attribute.String("gen_ai.operation.name", "execute_tool"), | ||
| ) | ||
|
|
||
| // Verify tool belongs to the current toolset before resolving globally. | ||
| if !toolset.ContainsTool(toolName) { | ||
| err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName) | ||
| return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should return the same error message as non-existent tool call (invalid tool name: tool with name %q does not exist) . Otherwise if the tool doesn't exist at all, we will still return "tool %q is not part of the current toolset" which could be misleading. |
||
| } | ||
|
|
||
| tool, ok := resourceMgr.GetTool(toolName) | ||
| if !ok { | ||
| err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName) | ||
|
|
@@ -339,7 +345,7 @@ func promptsListHandler(ctx context.Context, id jsonrpc.RequestId, promptset pro | |
| } | ||
|
|
||
| // promptsGetHandler handles the "prompts/get" method. | ||
| func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *resources.ResourceManager, body []byte) (any, error) { | ||
| func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, promptset prompts.Promptset, resourceMgr *resources.ResourceManager, body []byte) (any, error) { | ||
| // retrieve logger from context | ||
| logger, err := util.LoggerFromContext(ctx) | ||
| if err != nil { | ||
|
|
@@ -361,6 +367,12 @@ func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *r | |
| span.SetName(fmt.Sprintf("%s %s", PROMPTS_GET, promptName)) | ||
| span.SetAttributes(attribute.String("gen_ai.prompt.name", promptName)) | ||
|
|
||
| // Verify prompt belongs to the current promptset before resolving globally. | ||
| if !promptset.ContainsPrompt(promptName) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar for prompt-let's treat it the same way as a non-existent prompt. |
||
| err := fmt.Errorf("prompt with name %q does not exist", promptName) | ||
| return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err | ||
| } | ||
|
|
||
| prompt, ok := resourceMgr.GetPrompt(promptName) | ||
| if !ok { | ||
| err := fmt.Errorf("prompt with name %q does not exist", promptName) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -42,11 +42,11 @@ func ProcessMethod(ctx context.Context, id jsonrpc.RequestId, method string, too | |
| case TOOLS_LIST: | ||
| return toolsListHandler(id, toolset, body) | ||
| case TOOLS_CALL: | ||
| return toolsCallHandler(ctx, id, resourceMgr, body, header) | ||
| return toolsCallHandler(ctx, id, toolset, resourceMgr, body, header) | ||
| case PROMPTS_LIST: | ||
| return promptsListHandler(ctx, id, promptset, body) | ||
| case PROMPTS_GET: | ||
| return promptsGetHandler(ctx, id, resourceMgr, body) | ||
| return promptsGetHandler(ctx, id, promptset, resourceMgr, body) | ||
| default: | ||
| err := fmt.Errorf("invalid method %s", method) | ||
| return jsonrpc.NewError(id, jsonrpc.METHOD_NOT_FOUND, err.Error(), nil), err | ||
|
|
@@ -87,7 +87,7 @@ func toolsListHandler(id jsonrpc.RequestId, toolset tools.Toolset, body []byte) | |
| } | ||
|
|
||
| // toolsCallHandler generate a response for tools call. | ||
| func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *resources.ResourceManager, body []byte, header http.Header) (any, error) { | ||
| func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, toolset tools.Toolset, resourceMgr *resources.ResourceManager, body []byte, header http.Header) (any, error) { | ||
| authServices := resourceMgr.GetAuthServiceMap() | ||
|
|
||
| // retrieve logger from context | ||
|
|
@@ -114,6 +114,12 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re | |
| attribute.String("gen_ai.operation.name", "execute_tool"), | ||
| ) | ||
|
|
||
| // Verify tool belongs to the current toolset before resolving globally. | ||
| if !toolset.ContainsTool(toolName) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same for the other errors. |
||
| err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName) | ||
| return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err | ||
| } | ||
|
|
||
| tool, ok := resourceMgr.GetTool(toolName) | ||
| if !ok { | ||
| err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName) | ||
|
|
@@ -338,7 +344,7 @@ func promptsListHandler(ctx context.Context, id jsonrpc.RequestId, promptset pro | |
| } | ||
|
|
||
| // promptsGetHandler handles the "prompts/get" method. | ||
| func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *resources.ResourceManager, body []byte) (any, error) { | ||
| func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, promptset prompts.Promptset, resourceMgr *resources.ResourceManager, body []byte) (any, error) { | ||
| // retrieve logger from context | ||
| logger, err := util.LoggerFromContext(ctx) | ||
| if err != nil { | ||
|
|
@@ -360,6 +366,12 @@ func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *r | |
| span.SetName(fmt.Sprintf("%s %s", PROMPTS_GET, promptName)) | ||
| span.SetAttributes(attribute.String("gen_ai.prompt.name", promptName)) | ||
|
|
||
| // Verify prompt belongs to the current promptset before resolving globally. | ||
| if !promptset.ContainsPrompt(promptName) { | ||
| err := fmt.Errorf("prompt with name %q does not exist", promptName) | ||
| return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err | ||
| } | ||
|
|
||
| prompt, ok := resourceMgr.GetPrompt(promptName) | ||
| if !ok { | ||
| err := fmt.Errorf("prompt with name %q does not exist", promptName) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -35,6 +35,16 @@ func (t Toolset) ToConfig() ToolsetConfig { | |
| return t.ToolsetConfig | ||
| } | ||
|
|
||
| // ContainsTool reports whether the toolset includes a tool with the given name. | ||
| func (t Toolset) ContainsTool(name string) bool { | ||
| for _, n := range t.ToolNames { | ||
| if n == name { | ||
| return true | ||
| } | ||
| } | ||
| return false | ||
| } | ||
|
Comment on lines
+39
to
+46
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| type ToolsetManifest struct { | ||
| ServerVersion string `json:"serverVersion"` | ||
| ToolsManifest map[string]Manifest `json:"tools"` | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The$O(1)$ lookups, which could be initialized once in the
ContainsPromptmethod uses a linear search ($O(N)$), which is executed on everyprompts/getrequest. While likely acceptable for small promptsets, this could impact performance as the number of prompts grows. Consider using a map forInitializemethod.