Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion temporalio/contrib/google_adk_agents/_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ class TemporalMcpToolSetProvider:
within Temporal workflows.
"""

def __init__(self, name: str, toolset_factory: Callable[[Any | None], McpToolset]):
def __init__(
self, name: str, toolset_factory: Callable[[Any | None], McpToolset]
) -> None:
"""Initializes the toolset provider.

Args:
Expand Down Expand Up @@ -215,20 +217,23 @@ def __init__(
name: str,
config: ActivityConfig | None = None,
factory_argument: Any | None = None,
local_toolset: Callable[[Any | None], McpToolset] | None = None,
):
"""Initializes the Temporal MCP toolset.

Args:
name: Name of the toolset (used for activity naming).
config: Optional activity configuration.
factory_argument: Optional argument passed to toolset factory.
local_toolset: Optional factory for a temporal toolset for local execution when running outside a durable workflow.
"""
super().__init__()
self._name = name
self._factory_argument = factory_argument
self._config = config or ActivityConfig(
start_to_close_timeout=timedelta(minutes=1)
)
self._local_toolset = local_toolset

async def get_tools(
self, readonly_context: ReadonlyContext | None = None
Expand All @@ -241,6 +246,14 @@ async def get_tools(
Returns:
List of available tools wrapped as Temporal activities.
"""
# If executed outside a workflow, like when doing local adk runs, use the mcp server directly
if not workflow.in_workflow():
if self._local_toolset is None:
raise ValueError(
"Attempted to execute an MCP tool declared with TemporalMcpToolSet outside of a Workflow. Either use McpToolSet or pass a copy of your MCP toolset provider into local_toolset."
)
return await self._local_toolset(None).get_tools(readonly_context)

tool_results: list[_ToolResult] = await workflow.execute_activity(
self._name + "-list-tools",
_GetToolsArguments(self._factory_argument),
Expand Down
9 changes: 9 additions & 0 deletions temporalio/contrib/google_adk_agents/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse

import temporalio.workflow
from temporalio import activity, workflow
from temporalio.workflow import ActivityConfig

Expand Down Expand Up @@ -67,6 +68,14 @@ async def generate_content_async(
Yields:
The responses from the model.
"""
# If executed outside a workflow, like when doing local adk runs, use the model directly
if not temporalio.workflow.in_workflow():
async for response in LLMRegistry.new_llm(
self._model_name
).generate_content_async(llm_request, stream=stream):
yield response
return

responses = await workflow.execute_activity(
invoke_model,
args=[llm_request],
Expand Down
9 changes: 9 additions & 0 deletions temporalio/contrib/google_adk_agents/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import inspect
from typing import Any, Callable

import temporalio.workflow
from temporalio import workflow


Expand All @@ -29,6 +30,14 @@ async def wrapper(*args: Any, **kw: Any):
# Decorator kwargs are defaults.
options = kwargs.copy()

if not temporalio.workflow.in_workflow():
# If executed outside a workflow, like when doing local adk runs, use the function directly
result = activity_def(*args, **kw)
if inspect.isawaitable(result):
return await result
else:
return result

return await workflow.execute_activity(activity_def, *activity_args, **options)

# Copy metadata
Expand Down
210 changes: 151 additions & 59 deletions tests/contrib/google_adk_agents/test_google_adk_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@
import os
import uuid
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Iterator
from collections.abc import AsyncGenerator
from datetime import timedelta
from typing import Any

import pytest
from google.adk import Agent, Runner
Expand Down Expand Up @@ -64,6 +65,19 @@ async def get_weather(city: str) -> str: # type: ignore[reportUnusedParameter]
return "Warm and sunny. 17 degrees."


def weather_agent(model_name: str) -> Agent:
# Wraps 'get_weather' activity as a Tool
weather_tool = temporalio.contrib.google_adk_agents.workflow.activity_tool(
get_weather, start_to_close_timeout=timedelta(seconds=60)
)

return Agent(
name="test_agent",
model=TemporalModel(model_name),
tools=[weather_tool],
)


@workflow.defn
class WeatherAgent:
@workflow.run
Expand All @@ -73,17 +87,7 @@ async def run(self, prompt: str, model_name: str) -> Event | None:
# 1. Define Agent using Temporal Helpers
# Note: AgentPlugin in the Runner automatically handles Runtime setup
# and Model Activity interception. We use standard ADK models now.

# Wraps 'get_weather' activity as a Tool
weather_tool = temporalio.contrib.google_adk_agents.workflow.activity_tool(
get_weather, start_to_close_timeout=timedelta(seconds=60)
)

agent = Agent(
name="test_agent",
model=TemporalModel(model_name),
tools=[weather_tool],
)
agent = weather_agent(model_name)

# 2. Create runner
runner = InMemoryRunner(
Expand Down Expand Up @@ -357,21 +361,38 @@ async def test_multi_agent(client: Client, use_local_model: bool):
assert result == "haiku"


def example_toolset(_: Any | None) -> McpToolset:
return McpToolset(
connection_params=StdioConnectionParams(
server_params=StdioServerParameters(
command="npx",
args=[
"-y",
"@modelcontextprotocol/server-filesystem",
os.path.dirname(os.path.abspath(__file__)),
],
),
),
)


def mcp_agent(model_name: str) -> Agent:
return Agent(
name="test_agent",
# instruction="Always use your tools to answer questions.",
model=TemporalModel(model_name),
tools=[TemporalMcpToolSet("test_set", local_toolset=example_toolset)],
)


@workflow.defn
class McpAgent:
@workflow.run
async def run(self, prompt: str, model_name: str) -> str:
logger.info("Workflow started.")

# 1. Define Agent using Temporal Helpers
# Note: AgentPlugin in the Runner automatically handles Runtime setup
# and Model Activity interception. We use standard ADK models now.
agent = Agent(
name="test_agent",
# instruction="Always use your tools to answer questions.",
model=TemporalModel(model_name),
tools=[TemporalMcpToolSet("test_set")],
)
agent = mcp_agent(model_name)

# 2. Create Session (uses runtime.new_uuid() -> workflow.uuid4())
session_service = InMemorySessionService()
Expand Down Expand Up @@ -408,39 +429,36 @@ async def run(self, prompt: str, model_name: str) -> str:
return last_event.content.parts[0].text


class McpModel(BaseLlm):
responses: list[LlmResponse] = [
LlmResponse(
content=Content(
role="model",
parts=[
Part(
function_call=FunctionCall(
args={"path": os.path.dirname(os.path.abspath(__file__))},
name="list_directory",
class McpModel(TestModel):
def responses(self) -> list[LlmResponse]:
return [
LlmResponse(
content=Content(
role="model",
parts=[
Part(
function_call=FunctionCall(
args={
"path": os.path.dirname(os.path.abspath(__file__))
},
name="list_directory",
)
)
)
],
)
),
LlmResponse(
content=Content(
role="model",
parts=[Part(text="Some files.")],
)
),
]
response_iter: Iterator[LlmResponse] = iter(responses)
],
)
),
LlmResponse(
content=Content(
role="model",
parts=[Part(text="Some files.")],
)
),
]

@classmethod
def supported_models(cls) -> list[str]:
return ["mcp_model"]

async def generate_content_async(
self, llm_request: LlmRequest, stream: bool = False
) -> AsyncGenerator[LlmResponse, None]:
yield next(self.response_iter)


@pytest.mark.parametrize("use_local_model", [True, False])
@pytest.mark.asyncio
Expand All @@ -455,18 +473,7 @@ async def test_mcp_agent(client: Client, use_local_model: bool):
toolset_providers=[
TemporalMcpToolSetProvider(
"test_set",
lambda _: McpToolset(
connection_params=StdioConnectionParams(
server_params=StdioServerParameters(
command="npx",
args=[
"-y",
"@modelcontextprotocol/server-filesystem",
os.path.dirname(os.path.abspath(__file__)),
],
),
),
),
example_toolset,
)
],
)
Expand Down Expand Up @@ -567,3 +574,88 @@ async def test_single_agent_telemetry(client: Client):
async def test_unsetting_timeout():
model = TemporalModel("", ActivityConfig(start_to_close_timeout=None))
assert model._activity_config.get("start_to_close_timeout", None) is None


@pytest.mark.asyncio
async def test_agent_outside_workflow():
"""Test that an agent using TemporalModel and activity_tool works outside a Temporal workflow."""
LLMRegistry.register(WeatherModel)

agent = weather_agent("weather_model")

runner = InMemoryRunner(
agent=agent,
app_name="test_app_local",
)

session = await runner.session_service.create_session(
app_name="test_app_local", user_id="test"
)

last_event = None
async with Aclosing(
runner.run_async(
user_id="test",
session_id=session.id,
new_message=types.Content(
role="user", parts=[types.Part(text="What is the weather in New York?")]
),
)
) as agen:
async for event in agen:
last_event = event

assert last_event is not None
assert last_event.content is not None
assert last_event.content.parts is not None
assert last_event.content.parts[0].text == "warm and sunny"


@pytest.mark.asyncio
@pytest.mark.skip # Doesn't work well in CI currently
async def test_mcp_agent_outside_workflow():
"""Test that an agent using TemporalMcpToolSet works outside a Temporal workflow."""
LLMRegistry.register(McpModel)

agent = mcp_agent("mcp_model")

session_service = InMemorySessionService()
session = await session_service.create_session(
app_name="test_app_local", user_id="test"
)

runner = Runner(
agent=agent,
app_name="test_app_local",
session_service=session_service,
)

last_event = None
async with Aclosing(
runner.run_async(
user_id="test",
session_id=session.id,
new_message=types.Content(
role="user",
parts=[types.Part(text="What files are in the current directory?")],
),
)
) as agen:
async for event in agen:
last_event = event

assert last_event is not None
assert last_event.content is not None
assert last_event.content.parts is not None
assert last_event.content.parts[0].text == "Some files."


@pytest.mark.asyncio
async def test_mcp_toolset_outside_workflow_no_local_toolset():
"""Test that TemporalMcpToolSet raises ValueError outside a workflow with no local_toolset."""
toolset = TemporalMcpToolSet("test_set_no_local")
with pytest.raises(
ValueError,
match="Attempted to execute an MCP tool",
):
await toolset.get_tools()
Loading