Skip to content
57 changes: 52 additions & 5 deletions py/packages/genkit/src/genkit/_ai/_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from genkit._ai._resource import ResourceArgument, ResourceInput, find_matching_resource, resolve_resources
from genkit._ai._tools import Tool, ToolInterruptError
from genkit._core._action import Action, ActionKind, ActionRunContext
from genkit._core._dap import GENKIT_DYNAMIC_ACTION_PROVIDER_ATTR
from genkit._core._error import GenkitError
from genkit._core._logger import get_logger
from genkit._core._model import GenerateActionOptions
Expand All @@ -61,6 +62,46 @@
logger = get_logger(__name__)


async def expand_wildcard_tools(registry: Registry, tool_names: list[str]) -> list[str]:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This works a bit differently from the JavaScript implementation, and there is a critical bug in how the expanded tool names are returned.

Currently, the function strips the provider prefix and appends only the unqualified tool name:

tool_name = meta.get('name')
if tool_name:
    expanded.append(str(tool_name))

This creates a serious conflict resolution issue because it drops the explicit provider binding. When the generation loop later attempts to resolve this unqualified name (e.g., 'echo'), the registry will fall back to querying all registered DAPs in registration order to find a match.

The Bug Scenario:

  1. You have two MCP servers registered: mcp1 (registered first) and mcp2 (registered second).
  2. Both servers expose a tool named 'echo'.
  3. The user explicitly requests tools from only the second server: tools=['mcp2:tool/*'].
  4. This function expands that request to just ['echo'].
  5. When the framework resolves 'echo', it asks the DAPs in order. mcp1 replies first saying "I have echo!" and its tool is executed instead of mcp2's.

The Fix: To maintain the explicit provider binding (and match how the JS implementation behaves), we should reconstruct the fully-qualified DAP key so the registry knows exactly which provider to query later on:

tool_name = meta.get('name')
if tool_name:
    expanded.append(f'/dynamic-action-provider/{provider_name}:{action_type}/{tool_name}')

This guarantees the executed tool will strictly belong to the provider the user requested.

It would also be a good idea to have a test for this scenario.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, good catch — fixed this issue. The scenario above should also be blocked by assert_valid_tool_names(), which raises if two tools in scope share the same short name.

"""Expand DAP wildcard tool names into individual tool names.

A wildcard has the form ``<provider>:tool/*`` (or ``<provider>:tool/<prefix>*``).
Non-wildcard names are passed through unchanged.

A name is only treated as a wildcard when it contains ``:`` and ends with
``*``. A wildcard is also passed through unchanged when its provider cannot
be resolved, when the resolved action carries no dynamic-action-provider
attribute, or when the part after the colon contains no ``/``.

Args:
registry: Registry used to resolve the provider named before the colon
as an ``ActionKind.DYNAMIC_ACTION_PROVIDER`` action.
tool_names: Tool names to expand; may mix plain names and wildcards.

Returns:
A new list built in input order. Each expanded wildcard is replaced by
the ``name`` of every metadata entry the provider lists for it, so a
wildcard that matches nothing contributes no entries.
"""
expanded: list[str] = []
for name in tool_names:
# Fast path: anything that is not "<something>:<something>*" is kept as-is.
if ':' not in name or not name.endswith('*'):
expanded.append(name)
continue

# Split at the first colon: provider on the left, "<type>/<pattern>" on the right.
colon = name.index(':')
provider_name = name[:colon]
rest = name[colon + 1:] # e.g. "tool/*" or "tool/prefix*"

# Unknown provider: leave the wildcard in place instead of failing here.
provider_action = await registry.resolve_action(ActionKind.DYNAMIC_ACTION_PROVIDER, provider_name)
if provider_action is None:
expanded.append(name)
continue

# The DAP object is read off the provider action; without it nothing can be listed.
dap = getattr(provider_action, GENKIT_DYNAMIC_ACTION_PROVIDER_ATTR, None)
if dap is None:
expanded.append(name)
continue

# No "<type>/<pattern>" separator, so there is nothing to split into type and pattern.
if '/' not in rest:
expanded.append(name)
continue

# Pattern matching itself is delegated to the provider's list_action_metadata.
action_type, action_pattern = rest.split('/', 1)
metas = await dap.list_action_metadata(action_type, action_pattern)
for meta in metas:
# Entries without a truthy 'name' are skipped.
tool_name = meta.get('name')
if tool_name:
# NOTE(review): only the bare name reported by the provider is appended;
# provider_name and action_type are not carried into the result. If two
# providers expose the same tool name, a later lookup by bare name may not
# be tied to this provider - confirm against the registry's resolution rules.
expanded.append(str(tool_name))

return expanded


def tools_to_action_names(
tools: Sequence[str | Tool] | None,
) -> list[str] | None:
Expand Down Expand Up @@ -158,20 +199,26 @@ async def _generate_action(
context: dict[str, Any] | None = None,
) -> ModelResponse:
"""Execute a generation request with tool calling and middleware support."""
model, tools, format_def = await resolve_parameters(registry, raw_request)
effective_registry = registry if registry.is_child else registry.new_child()

if raw_request.tools:
raw_request = raw_request.model_copy()
raw_request.tools = await expand_wildcard_tools(effective_registry, raw_request.tools)

model, tools, format_def = await resolve_parameters(effective_registry, raw_request)

raw_request, formatter = apply_format(raw_request, format_def)

if raw_request.resources:
raw_request = await apply_resources(registry, raw_request)
raw_request = await apply_resources(effective_registry, raw_request)

assert_valid_tool_names(raw_request)

(
revised_request,
interrupted_response,
resumed_tool_message,
) = await _resolve_resume_options(registry, raw_request)
) = await _resolve_resume_options(effective_registry, raw_request)

# NOTE: in the future we should make it possible to interrupt a restart, but
# at the moment it's too complicated because it's not clear how to return a
Expand Down Expand Up @@ -374,7 +421,7 @@ def message_parser(msg: Message) -> Any: # noqa: ANN401
revised_model_msg,
tool_msg,
transfer_preamble,
) = await resolve_tool_requests(registry, raw_request, generated_msg)
) = await resolve_tool_requests(effective_registry, raw_request, generated_msg)

# if an interrupt message is returned, stop the tool loop and return a
# response.
Expand Down Expand Up @@ -408,7 +455,7 @@ def message_parser(msg: Message) -> Any: # noqa: ANN401

# then recursively call for another loop
return await _generate_action(
registry,
effective_registry,
raw_request=next_request,
# middleware: middleware,
current_turn=current_turn + 1,
Expand Down
236 changes: 236 additions & 0 deletions py/packages/genkit/tests/genkit/ai/dynamic_tools_generate_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

"""Tests for DAP-backed tool resolution in the generate loop."""

import pytest
from pydantic import BaseModel

from genkit._ai._generate import expand_wildcard_tools, generate_action
from genkit._ai._testing import define_programmable_model
from genkit._core._action import Action, ActionKind, ActionRunContext
from genkit._core._dap import DapValue, define_dynamic_action_provider
from genkit._core._model import GenerateActionOptions, ModelRequest
from genkit._core._registry import Registry
from genkit._core._typing import (
FinishReason,
Part,
Role,
TextPart,
ToolRequest,
ToolRequestPart,
)
from genkit import Genkit, Message, ModelResponse, ModelResponseChunk


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _text_response(text: str) -> ModelResponse:
    """Build a model response that ends the turn with a single text part."""
    text_part = Part(root=TextPart(text=text))
    reply = Message(role=Role.MODEL, content=[text_part])
    return ModelResponse(message=reply, finish_reason=FinishReason.STOP)


def _tool_call_response(tool_name: str, input: dict) -> ModelResponse:
    """Build a model response that requests one call to ``tool_name``.

    The tool name is reused as the request ``ref``.
    """
    request = ToolRequest(name=tool_name, input=input, ref=tool_name)
    request_part = Part(root=ToolRequestPart(tool_request=request))
    reply = Message(role=Role.MODEL, content=[request_part])
    return ModelResponse(message=reply, finish_reason=FinishReason.STOP)


# ---------------------------------------------------------------------------
# expand_wildcard_tools
# ---------------------------------------------------------------------------

@pytest.mark.asyncio
async def test_expand_wildcard_all() -> None:
    """A bare ``provider:tool/*`` pattern yields every tool the DAP lists."""
    registry = Registry()

    async def tool_fn(x: str) -> str:
        return x

    # Register the tools in a fixed order; the provider serves these same actions.
    provided = [
        registry.register_action(
            name=tool_name,
            kind=ActionKind.TOOL,
            fn=tool_fn,
            metadata={'name': tool_name},
        )
        for tool_name in ('echo', 'ping')
    ]

    async def dap_fn() -> DapValue:
        # Hand out a fresh list on every call, as a literal would.
        return {'tool': list(provided)}

    define_dynamic_action_provider(registry, 'mcp', dap_fn)

    expanded_names = await expand_wildcard_tools(registry, ['mcp:tool/*'])
    assert sorted(expanded_names) == ['echo', 'ping']


@pytest.mark.asyncio
async def test_expand_wildcard_prefix() -> None:
    """A ``provider:tool/<prefix>*`` pattern yields only the tools matching the prefix."""
    registry = Registry()

    async def tool_fn(x: str) -> str:
        return x

    # Two tools share the 'get_' prefix; 'set_alarm' must be filtered out.
    provided = [
        registry.register_action(
            name=tool_name,
            kind=ActionKind.TOOL,
            fn=tool_fn,
            metadata={'name': tool_name},
        )
        for tool_name in ('get_weather', 'get_time', 'set_alarm')
    ]

    async def dap_fn() -> DapValue:
        # Hand out a fresh list on every call, as a literal would.
        return {'tool': list(provided)}

    define_dynamic_action_provider(registry, 'mcp', dap_fn)

    expanded_names = await expand_wildcard_tools(registry, ['mcp:tool/get_*'])
    assert sorted(expanded_names) == ['get_time', 'get_weather']


@pytest.mark.asyncio
async def test_non_wildcard_names_pass_through() -> None:
    """Plain tool names come back untouched and in the same order."""
    plain_names = ['my_tool', 'other_tool']
    expanded_names = await expand_wildcard_tools(Registry(), plain_names)
    assert expanded_names == ['my_tool', 'other_tool']

Comment thread
huangjeff5 marked this conversation as resolved.

# ---------------------------------------------------------------------------
# DAP tools resolved inside generate loop
# ---------------------------------------------------------------------------

@pytest.mark.asyncio
async def test_dap_tool_resolved_in_generate() -> None:
    """generate resolves a tool that lives only in a DAP, calls it, and gets final answer.

    The echo tool is built as a bare ``Action`` and is never registered in the
    root registry, so the 'mcp' dynamic action provider is the only place it
    can be found. The programmable model first requests the tool and then
    returns plain text, which drives one full tool-call round trip.
    """
    ai = Genkit()
    pm, _ = define_programmable_model(ai)

    call_log: list[str] = []

    class EchoInput(BaseModel):
        text: str

    async def echo_fn(inp: EchoInput) -> str:
        call_log.append(inp.text)
        return f'echoed:{inp.text}'

    # Deliberately not ai.registry.register_action(...): registering the tool
    # in the root registry would let the test pass without any DAP lookup.
    echo_action = Action(
        name='echo',
        kind=ActionKind.TOOL,
        fn=echo_fn,
        metadata={'name': 'echo'},
    )

    async def dap_fn() -> DapValue:
        return {'tool': [echo_action]}

    ai.define_dynamic_action_provider('mcp', dap_fn)

    # Turn 1: model asks to call 'echo'. Turn 2: model returns the final text.
    pm.responses = [
        _tool_call_response('echo', {'text': 'hello'}),
        _text_response('done'),
    ]

    response = await ai.generate(
        model='programmableModel',
        prompt='use echo',
        tools=['echo'],
    )

    assert response.text == 'done'
    # The tool body ran exactly once, with the input the model supplied.
    assert call_log == ['hello']


@pytest.mark.asyncio
async def test_dap_tools_do_not_pollute_root_registry() -> None:
    """A tool served only by a DAP is not stored in the root registry by generate."""
    ai = Genkit()
    programmable_model, _ = define_programmable_model(ai)

    class Inp(BaseModel):
        x: str

    async def tool_fn(inp: Inp) -> str:
        return inp.x

    # Built directly instead of through register_action, so the root registry
    # holds no entry for this tool before generate runs.
    dap_only_action = Action(
        name='dap_only_tool',
        kind=ActionKind.TOOL,
        fn=tool_fn,
        metadata={'name': 'dap_only_tool'},
    )

    async def dap_fn() -> DapValue:
        return {'tool': [dap_only_action]}

    ai.define_dynamic_action_provider('mcp', dap_fn)

    # The model answers with text straight away, so no tool call is requested.
    programmable_model.responses = [_text_response('no tools called')]

    await ai.generate(model='programmableModel', prompt='hi', tools=['dap_only_tool'])

    # Inspect the root registry's own tool table for a leaked entry.
    registered_tools = ai.registry._entries.get(ActionKind.TOOL, {})
    assert 'dap_only_tool' not in registered_tools


@pytest.mark.asyncio
async def test_wildcard_tools_in_generate() -> None:
    """Passing ``tools=['mcp:tool/*']`` to generate makes the provider's tools callable."""
    ai = Genkit()
    programmable_model, _ = define_programmable_model(ai)

    invocations: list[str] = []

    class InpA(BaseModel):
        x: str

    class InpB(BaseModel):
        x: str

    async def tool_a_fn(inp: InpA) -> str:
        invocations.append(f'a:{inp.x}')
        return f'a:{inp.x}'

    async def tool_b_fn(inp: InpB) -> str:
        invocations.append(f'b:{inp.x}')
        return f'b:{inp.x}'

    action_a = ai.registry.register_action(
        name='tool_a',
        kind=ActionKind.TOOL,
        fn=tool_a_fn,
        metadata={'name': 'tool_a'},
    )
    action_b = ai.registry.register_action(
        name='tool_b',
        kind=ActionKind.TOOL,
        fn=tool_b_fn,
        metadata={'name': 'tool_b'},
    )

    async def dap_fn() -> DapValue:
        return {'tool': [action_a, action_b]}

    ai.define_dynamic_action_provider('mcp', dap_fn)

    # The model requests tool_a once, then finishes with plain text.
    scripted_turns = [
        _tool_call_response('tool_a', {'x': 'hi'}),
        _text_response('finished'),
    ]
    programmable_model.responses = scripted_turns

    response = await ai.generate(model='programmableModel', prompt='use a tool', tools=['mcp:tool/*'])

    assert response.text == 'finished'
    # Only tool_a ran; tool_b was offered but never requested.
    assert invocations == ['a:hi']
Loading