Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -232,10 +232,24 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
**Limitation**: The callable is *not serializable*; values provided via YAML/JSON configs are ignored.
tool_call_error_function (Callable[[Exception, FunctionCall], str | None] | None, optional):
Callable that handles exceptions raised during tool execution. When provided,
tool exceptions are no longer silently caught and stringified. Instead, the callable
receives the exception and the FunctionCall, and should return:
- A ``str``: used as the error content in the ``FunctionExecutionResult`` (the error is handled).
- ``None``: the exception is re-raised and propagates up the call stack (the error is fatal).
This allows fine-grained control over which errors are recoverable and which should
halt execution. When not set (default), all tool errors are caught and their string
representation is passed to the model.
**Limitation**: The callable is *not serializable*; values provided via YAML/JSON configs are ignored.
.. note::
`tool_call_summary_formatter` is intended for in-code use only. It cannot currently be saved or restored via
configuration files.
`tool_call_summary_formatter` and `tool_call_error_function` are intended for in-code use only.
They cannot currently be saved or restored via configuration files.
memory (Sequence[Memory] | None, optional): The memory store to use for the agent. Defaults to `None`.
metadata (Dict[str, str] | None, optional): Optional metadata for tracking.
Expand Down Expand Up @@ -739,6 +753,7 @@ def __init__(
max_tool_iterations: int = 1,
tool_call_summary_format: str = "{result}",
tool_call_summary_formatter: Callable[[FunctionCall, FunctionExecutionResult], str] | None = None,
tool_call_error_function: Callable[[Exception, FunctionCall], str | None] | None = None,
output_content_type: type[BaseModel] | None = None,
output_content_type_format: str | None = None,
memory: Sequence[Memory] | None = None,
Expand Down Expand Up @@ -832,7 +847,7 @@ def __init__(
else:
self._workbench = [workbench]
else:
self._workbench = [StaticStreamWorkbench(self._tools)]
self._workbench = [StaticStreamWorkbench(self._tools, raise_on_error=tool_call_error_function is not None)]

if model_context is not None:
self._model_context = model_context
Expand All @@ -856,6 +871,7 @@ def __init__(

self._tool_call_summary_format = tool_call_summary_format
self._tool_call_summary_formatter = tool_call_summary_formatter
self._tool_call_error_function = tool_call_error_function
self._is_running = False

@property
Expand Down Expand Up @@ -1007,6 +1023,7 @@ async def on_messages_stream(
output_content_type=output_content_type,
message_id=message_id,
format_string=self._output_content_type_format,
tool_call_error_function=self._tool_call_error_function,
):
yield output_event

Expand Down Expand Up @@ -1135,6 +1152,7 @@ async def _process_model_result(
output_content_type: type[BaseModel] | None,
message_id: str,
format_string: str | None = None,
tool_call_error_function: Callable[[Exception, FunctionCall], str | None] | None = None,
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:
"""
Handle final or partial responses from model_result, including tool calls, handoffs,
Expand Down Expand Up @@ -1197,19 +1215,25 @@ async def _execute_tool_calls(
function_calls: List[FunctionCall],
stream_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | None],
) -> List[Tuple[FunctionCall, FunctionExecutionResult]]:
results = await asyncio.gather(
*[
cls._execute_tool_call(
tool_call=call,
workbench=workbench,
handoff_tools=handoff_tools,
agent_name=agent_name,
cancellation_token=cancellation_token,
stream=stream_queue,
)
for call in function_calls
]
)
try:
results = await asyncio.gather(
*[
cls._execute_tool_call(
tool_call=call,
workbench=workbench,
handoff_tools=handoff_tools,
agent_name=agent_name,
cancellation_token=cancellation_token,
stream=stream_queue,
tool_call_error_function=tool_call_error_function,
)
for call in function_calls
]
)
except Exception:
# Ensure the stream gets the sentinel so the consumer loop doesn't hang.
stream_queue.put_nowait(None)
raise
# Signal the end of streaming by putting None in the queue.
stream_queue.put_nowait(None)
return results
Expand Down Expand Up @@ -1540,8 +1564,17 @@ async def _execute_tool_call(
agent_name: str,
cancellation_token: CancellationToken,
stream: asyncio.Queue[BaseAgentEvent | BaseChatMessage | None],
tool_call_error_function: Callable[[Exception, FunctionCall], str | None] | None = None,
) -> Tuple[FunctionCall, FunctionExecutionResult]:
"""Execute a single tool call and return the result."""
"""Execute a single tool call and return the result.
Args:
tool_call_error_function: Optional callable that handles tool execution errors.
When provided, the workbench is expected to raise exceptions (raise_on_error=True).
The callable receives the exception and the FunctionCall, and should return:
- A string: used as the error content in FunctionExecutionResult (error is handled).
- None: the exception is re-raised (error is fatal).
"""
# Load the arguments from the tool call.
try:
arguments = json.loads(tool_call.arguments)
Expand Down Expand Up @@ -1577,32 +1610,48 @@ async def _execute_tool_call(
for wb in workbench:
tools = await wb.list_tools()
if any(t["name"] == tool_call.name for t in tools):
if isinstance(wb, StaticStreamWorkbench):
tool_result: ToolResult | None = None
async for event in wb.call_tool_stream(
name=tool_call.name,
arguments=arguments,
cancellation_token=cancellation_token,
call_id=tool_call.id,
):
if isinstance(event, ToolResult):
tool_result = event
elif isinstance(event, BaseAgentEvent) or isinstance(event, BaseChatMessage):
await stream.put(event)
else:
warnings.warn(
f"Unexpected event type: {type(event)} in tool call streaming.",
UserWarning,
stacklevel=2,
)
assert isinstance(tool_result, ToolResult), "Tool result should not be None in streaming mode."
else:
tool_result = await wb.call_tool(
name=tool_call.name,
arguments=arguments,
cancellation_token=cancellation_token,
call_id=tool_call.id,
)
try:
if isinstance(wb, StaticStreamWorkbench):
tool_result: ToolResult | None = None
async for event in wb.call_tool_stream(
name=tool_call.name,
arguments=arguments,
cancellation_token=cancellation_token,
call_id=tool_call.id,
):
if isinstance(event, ToolResult):
tool_result = event
elif isinstance(event, BaseAgentEvent) or isinstance(event, BaseChatMessage):
await stream.put(event)
else:
warnings.warn(
f"Unexpected event type: {type(event)} in tool call streaming.",
UserWarning,
stacklevel=2,
)
assert isinstance(tool_result, ToolResult), "Tool result should not be None in streaming mode."
else:
tool_result = await wb.call_tool(
name=tool_call.name,
arguments=arguments,
cancellation_token=cancellation_token,
call_id=tool_call.id,
)
except Exception as e:
if tool_call_error_function is not None:
error_result = tool_call_error_function(e, tool_call)
if error_result is None:
raise
return (
tool_call,
FunctionExecutionResult(
content=error_result,
call_id=tool_call.id,
is_error=True,
name=tool_call.name,
),
)
raise
return (
tool_call,
FunctionExecutionResult(
Expand Down
116 changes: 116 additions & 0 deletions python/packages/autogen-agentchat/tests/test_assistant_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,18 @@ def _echo_function(input: str) -> str:
return input


def _pass_fail_tool(input: str) -> str:
"""A tool that always raises a RuntimeError for testing error handling.

Args:
input: Input string (unused)

Returns:
Never returns - always raises RuntimeError
"""
raise RuntimeError("pass fail tool")


class MockMemory(Memory):
"""Mock memory implementation for testing.

Expand Down Expand Up @@ -3560,3 +3572,107 @@ async def test_anthropic_basic_text_response(self) -> None:
usage = client.total_usage()
assert usage.prompt_tokens > 0
assert usage.completion_tokens > 0


class TestToolCallErrorFunction:
    """Tests for the tool_call_error_function parameter of AssistantAgent."""

    @pytest.mark.asyncio
    async def test_tool_error_raises_when_function_returns_none(self) -> None:
        """A ``None`` return from the error handler marks the error fatal: it must re-raise."""
        client = ReplayChatCompletionClient(
            [
                CreateResult(
                    finish_reason="function_calls",
                    content=[FunctionCall(id="1", arguments=json.dumps({"input": "task"}), name="_pass_fail_tool")],
                    usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
                    cached=False,
                ),
            ]
        )
        client._model_info["function_calling"] = True  # pyright: ignore

        def fatal_handler(exc: Exception, call: FunctionCall) -> str | None:
            # None signals the agent to re-raise the tool exception.
            return None

        assistant = AssistantAgent(
            name="test_agent",
            model_client=client,
            tools=[_pass_fail_tool],
            tool_call_error_function=fatal_handler,
        )

        with pytest.raises(RuntimeError, match="pass fail tool"):
            await assistant.run(task="test")

    @pytest.mark.asyncio
    async def test_tool_error_handled_when_function_returns_string(self) -> None:
        """A string return from the error handler becomes the error content of the result."""
        client = ReplayChatCompletionClient(
            [
                CreateResult(
                    finish_reason="function_calls",
                    content=[FunctionCall(id="1", arguments=json.dumps({"input": "task"}), name="_pass_fail_tool")],
                    usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
                    cached=False,
                ),
                CreateResult(
                    finish_reason="stop",
                    content="handled the error",
                    usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
                    cached=False,
                ),
            ]
        )
        client._model_info["function_calling"] = True  # pyright: ignore

        def recovering_handler(exc: Exception, call: FunctionCall) -> str | None:
            return f"Custom error: {exc}"

        assistant = AssistantAgent(
            name="test_agent",
            model_client=client,
            tools=[_pass_fail_tool],
            tool_call_error_function=recovering_handler,
        )

        outcome = await assistant.run(task="test")
        # The handler's string must surface as an error result, not an exception.
        exec_events = [m for m in outcome.messages if isinstance(m, ToolCallExecutionEvent)]
        assert len(exec_events) == 1
        first_result = exec_events[0].content[0]
        assert first_result.is_error is True
        assert "Custom error" in first_result.content

    @pytest.mark.asyncio
    async def test_default_behavior_without_error_function(self) -> None:
        """With no tool_call_error_function, tool errors are stringified as before."""
        client = ReplayChatCompletionClient(
            [
                CreateResult(
                    finish_reason="function_calls",
                    content=[FunctionCall(id="1", arguments=json.dumps({"input": "task"}), name="_pass_fail_tool")],
                    usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
                    cached=False,
                ),
                CreateResult(
                    finish_reason="stop",
                    content="ok",
                    usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
                    cached=False,
                ),
            ]
        )
        client._model_info["function_calling"] = True  # pyright: ignore

        assistant = AssistantAgent(
            name="test_agent",
            model_client=client,
            tools=[_pass_fail_tool],
        )

        # Default path: no raise — the error string is fed back to the model.
        outcome = await assistant.run(task="test")
        exec_events = [m for m in outcome.messages if isinstance(m, ToolCallExecutionEvent)]
        assert len(exec_events) == 1
        first_result = exec_events[0].content[0]
        assert first_result.is_error is True
        assert "pass fail tool" in first_result.content
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,23 @@ class StaticWorkbench(Workbench, Component[StaticWorkbenchConfig]):
names to override configurations for name and/or description. This allows
customizing how tools appear to consumers while maintaining the underlying
tool functionality.
raise_on_error (bool): If True, exceptions from tool execution will be
re-raised instead of being caught and converted to error results.
Defaults to False.
"""

component_provider_override = "autogen_core.tools.StaticWorkbench"
component_config_schema = StaticWorkbenchConfig

def __init__(
self, tools: List[BaseTool[Any, Any]], tool_overrides: Optional[Dict[str, ToolOverride]] = None
self,
tools: List[BaseTool[Any, Any]],
tool_overrides: Optional[Dict[str, ToolOverride]] = None,
raise_on_error: bool = False,
) -> None:
self._tools = tools
self._tool_overrides = tool_overrides or {}
self._raise_on_error = raise_on_error

# Build reverse mapping from override names to original names for call_tool
self._override_name_to_original: Dict[str, str] = {}
Expand Down Expand Up @@ -119,6 +126,8 @@ async def call_tool(
is_error = False
result_str = tool.return_value_as_string(actual_tool_output)
except Exception as e:
if self._raise_on_error:
raise
result_str = self._format_errors(e)
is_error = True
return ToolResult(name=name, result=[TextResultContent(content=result_str)], is_error=is_error)
Expand Down Expand Up @@ -205,6 +214,8 @@ async def call_tool_stream(
previous_result = result
actual_tool_output = previous_result
except Exception as e:
if self._raise_on_error:
raise
# If there was a previous result before the exception, yield it first
if previous_result is not None:
yield previous_result
Expand All @@ -220,6 +231,8 @@ async def call_tool_stream(
is_error = False
result_str = tool.return_value_as_string(actual_tool_output)
except Exception as e:
if self._raise_on_error:
raise
result_str = self._format_errors(e)
is_error = True
yield ToolResult(name=tool.name, result=[TextResultContent(content=result_str)], is_error=is_error)