diff --git a/docs/deferred-tools.md b/docs/deferred-tools.md index 83636e8323..26fac2b8cb 100644 --- a/docs/deferred-tools.md +++ b/docs/deferred-tools.md @@ -169,6 +169,7 @@ print(result.all_messages()) content='Deleting files is not allowed', tool_call_id='delete_file', timestamp=datetime.datetime(...), + outcome='denied', ), UserPromptPart( content='Now create a backup of README.md', diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 13021f326f..ff4c5db6ec 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -1340,6 +1340,7 @@ async def _call_tool( tool_name=call.tool_name, content=tool_call_result.message, tool_call_id=call.tool_call_id, + outcome='denied', ), None elif isinstance(tool_call_result, exceptions.ModelRetry): m = _messages.RetryPromptPart( diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index 56baf423ea..29191ae5a1 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -1024,6 +1024,14 @@ class BaseToolReturnPart: timestamp: datetime = field(default_factory=_now_utc) """The timestamp, when the tool returned.""" + outcome: Literal['success', 'failed', 'denied'] = 'success' + """The outcome of the tool call. + + - `'success'`: The tool executed successfully. + - `'failed'`: The tool raised an error during execution. + - `'denied'`: The tool call was denied by the approval mechanism. + """ + def model_response_str(self) -> str: """Return a string representation of the content for the model.""" if isinstance(self.content, str): diff --git a/pydantic_ai_slim/pydantic_ai/ui/__init__.py b/pydantic_ai_slim/pydantic_ai/ui/__init__.py index 683d6d7335..52b9404d7d 100644 --- a/pydantic_ai_slim/pydantic_ai/ui/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/ui/__init__.py @@ -1,9 +1,13 @@ from __future__ import annotations +from typing import TYPE_CHECKING + from ._adapter import StateDeps, StateHandler, UIAdapter from ._event_stream import SSE_CONTENT_TYPE, NativeEvent, OnCompleteFunc, UIEventStream from ._messages_builder import MessagesBuilder -from ._web import DEFAULT_HTML_URL + +if TYPE_CHECKING: + from ._web import DEFAULT_HTML_URL __all__ = [ 'UIAdapter', @@ -16,3 +20,11 @@ 'MessagesBuilder', 'DEFAULT_HTML_URL', ] + + +def __getattr__(name: str) -> object: + if name == 'DEFAULT_HTML_URL': + from ._web import DEFAULT_HTML_URL + + return DEFAULT_HTML_URL + raise AttributeError(f'module {__name__!r} has no attribute {name!r}') diff --git a/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_adapter.py b/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_adapter.py index b0647c904b..b3f70660ba 100644 --- a/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_adapter.py +++ b/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_adapter.py @@ -58,8 +58,10 @@ SourceUrlUIPart, StepStartUIPart, TextUIPart, + ToolApprovalResponded, ToolInputAvailablePart, ToolOutputAvailablePart, + ToolOutputDeniedPart, ToolOutputErrorPart, ToolUIPart, UIMessage, @@ -343,7 +345,9 @@ def load_messages(cls, messages: Sequence[UIMessage]) -> list[ModelMessage]: # # The call and return metadata are combined in the output part. # So we extract and return them to the respective parts call_meta = return_meta = {} - has_tool_output = isinstance(part, (ToolOutputAvailablePart, ToolOutputErrorPart)) + has_tool_output = isinstance( + part, (ToolOutputAvailablePart, ToolOutputErrorPart, ToolOutputDeniedPart) + ) if has_tool_output: call_meta, return_meta = cls._load_builtin_tool_meta(provider_meta) @@ -360,11 +364,15 @@ def load_messages(cls, messages: Sequence[UIMessage]) -> list[ModelMessage]: # ) if has_tool_output: - output: Any | None = None - if isinstance(part, ToolOutputAvailablePart): - output = part.output - elif isinstance(part, ToolOutputErrorPart): # pragma: no branch - output = {'error_text': part.error_text, 'is_error': True} + if isinstance(part, ToolOutputErrorPart): + output: Any = part.error_text + outcome: Literal['success', 'failed', 'denied'] = 'failed' + elif isinstance(part, ToolOutputDeniedPart): + output = _denial_reason(part) + outcome = 'denied' + else: + output = part.output if isinstance(part, ToolOutputAvailablePart) else None + outcome = 'success' builder.add( BuiltinToolReturnPart( tool_name=tool_name, @@ -372,6 +380,7 @@ def load_messages(cls, messages: Sequence[UIMessage]) -> list[ModelMessage]: # content=output, provider_name=return_meta.get('provider_name') or provider_name, provider_details=return_meta.get('provider_details') or provider_details, + outcome=outcome, ) ) else: @@ -392,8 +401,20 @@ def load_messages(cls, messages: Sequence[UIMessage]) -> list[ModelMessage]: # ) elif part.state == 'output-error': builder.add( - RetryPromptPart( - tool_name=tool_name, tool_call_id=tool_call_id, content=part.error_text + ToolReturnPart( + tool_name=tool_name, + tool_call_id=tool_call_id, + content=part.error_text, + outcome='failed', + ) + ) + elif part.state == 'output-denied': + builder.add( + ToolReturnPart( + tool_name=tool_name, + tool_call_id=tool_call_id, + content=_denial_reason(part), + outcome='denied', ) ) elif isinstance(part, DataUIPart): # pragma: no cover @@ -522,20 +543,33 @@ def _dump_response_message( ) combined_provider_meta = cls._dump_builtin_tool_meta(call_meta, return_meta) - response_object = builtin_return.model_response_object() - # These `is_error`/`error_text` fields are only present when the BuiltinToolReturnPart - # was parsed from an incoming VercelAI request. We can't detect errors for other sources - # until BuiltinToolReturnPart has standardized error fields (see https://github.com/pydantic/pydantic-ai/issues/3561).3 - if response_object.get('is_error') is True and ( - (error_text := response_object.get('error_text')) is not None + if builtin_return.outcome == 'denied': + ui_parts.append( + ToolOutputDeniedPart( + type=tool_name, + tool_call_id=part.tool_call_id, + input=_safe_args_as_dict(part), + provider_executed=True, + call_provider_metadata=combined_provider_meta, + approval=ToolApprovalResponded( + id=str(uuid.uuid4()), + approved=False, + reason=builtin_return.model_response_str(), + ), + ) + ) + elif ( + builtin_return.outcome == 'failed' + or builtin_return.model_response_object().get('is_error') is True ): + response_obj = builtin_return.model_response_object() + error_text = response_obj.get('error_text', builtin_return.model_response_str()) ui_parts.append( ToolOutputErrorPart( type=tool_name, tool_call_id=part.tool_call_id, input=_safe_args_as_dict(part), error_text=error_text, - state='output-error', provider_executed=True, call_provider_metadata=combined_provider_meta, ) @@ -547,7 +581,6 @@ def _dump_response_message( tool_call_id=part.tool_call_id, input=_safe_args_as_dict(part), output=tool_return_output(builtin_return), - state='output-available', provider_executed=True, call_provider_metadata=combined_provider_meta, ) @@ -561,58 +594,90 @@ def _dump_response_message( type=tool_name, tool_call_id=part.tool_call_id, input=_safe_args_as_dict(part), - state='input-available', provider_executed=True, call_provider_metadata=call_provider_metadata, ) ) elif isinstance(part, ToolCallPart): - tool_result = tool_results.get(part.tool_call_id) - call_provider_metadata = dump_provider_metadata( - id=part.id, provider_name=part.provider_name, provider_details=part.provider_details - ) - tool_type = f'tool-{part.tool_name}' + ui_parts.extend(cls._dump_tool_call_part(part, tool_results)) + else: + assert_never(part) - if isinstance(tool_result, ToolReturnPart): - ui_parts.append( - ToolOutputAvailablePart( - type=tool_type, - tool_call_id=part.tool_call_id, - input=_safe_args_as_dict(part), - output=tool_return_output(tool_result), - state='output-available', - provider_executed=False, - call_provider_metadata=call_provider_metadata, - ) - ) - # Check for Vercel AI chunks returned by tool calls via metadata. - ui_parts.extend(_extract_metadata_ui_parts(tool_result)) - elif isinstance(tool_result, RetryPromptPart): - error_text = tool_result.model_response() - ui_parts.append( - ToolOutputErrorPart( - type=tool_type, - tool_call_id=part.tool_call_id, - input=_safe_args_as_dict(part), - error_text=error_text, - state='output-error', - provider_executed=False, - call_provider_metadata=call_provider_metadata, - ) + return ui_parts + + @staticmethod + def _dump_tool_call_part( + part: ToolCallPart, tool_results: dict[str, ToolReturnPart | RetryPromptPart] + ) -> list[UIMessagePart]: + """Convert a ToolCallPart (with optional result) into UIMessageParts.""" + tool_result = tool_results.get(part.tool_call_id) + call_provider_metadata = dump_provider_metadata( + id=part.id, provider_name=part.provider_name, provider_details=part.provider_details + ) + tool_type = f'tool-{part.tool_name}' + ui_parts: list[UIMessagePart] = [] + + if isinstance(tool_result, ToolReturnPart): + if tool_result.outcome == 'denied': + ui_parts.append( + ToolOutputDeniedPart( + type=tool_type, + tool_call_id=part.tool_call_id, + input=_safe_args_as_dict(part), + provider_executed=False, + call_provider_metadata=call_provider_metadata, + approval=ToolApprovalResponded( + id=str(uuid.uuid4()), + approved=False, + reason=tool_result.model_response_str(), + ), ) - else: - ui_parts.append( - ToolInputAvailablePart( - type=tool_type, - tool_call_id=part.tool_call_id, - input=_safe_args_as_dict(part), - state='input-available', - provider_executed=False, - call_provider_metadata=call_provider_metadata, - ) + ) + elif tool_result.outcome == 'failed': + ui_parts.append( + ToolOutputErrorPart( + type=tool_type, + tool_call_id=part.tool_call_id, + input=_safe_args_as_dict(part), + error_text=tool_result.model_response_str(), + provider_executed=False, + call_provider_metadata=call_provider_metadata, ) + ) else: - assert_never(part) + ui_parts.append( + ToolOutputAvailablePart( + type=tool_type, + tool_call_id=part.tool_call_id, + input=_safe_args_as_dict(part), + output=tool_return_output(tool_result), + provider_executed=False, + call_provider_metadata=call_provider_metadata, + ) + ) + # Check for Vercel AI chunks returned by tool calls via metadata. + ui_parts.extend(_extract_metadata_ui_parts(tool_result)) + elif isinstance(tool_result, RetryPromptPart): + ui_parts.append( + ToolOutputErrorPart( + type=tool_type, + tool_call_id=part.tool_call_id, + input=_safe_args_as_dict(part), + error_text=tool_result.model_response(), + provider_executed=False, + call_provider_metadata=call_provider_metadata, + ) + ) + else: + ui_parts.append( + ToolInputAvailablePart( + type=tool_type, + tool_call_id=part.tool_call_id, + input=_safe_args_as_dict(part), + provider_executed=False, + call_provider_metadata=call_provider_metadata, + ) + ) return ui_parts @@ -715,6 +780,13 @@ def _convert_user_prompt_part(part: UserPromptPart) -> list[UIMessagePart]: return ui_parts +def _denial_reason(part: ToolUIPart | DynamicToolUIPart) -> str: + """Extract the denial reason from a tool part's approval, or return a default message.""" + if isinstance(part.approval, ToolApprovalResponded) and part.approval.reason: + return part.approval.reason + return ToolDenied().message + + def _extract_metadata_ui_parts(tool_result: ToolReturnPart) -> list[UIMessagePart]: """Convert data-carrying chunks from tool metadata into UIMessageParts. diff --git a/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_event_stream.py b/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_event_stream.py index d6e960b5dc..bc2592c217 100644 --- a/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_event_stream.py +++ b/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_event_stream.py @@ -4,7 +4,6 @@ from collections.abc import AsyncIterator, Mapping from dataclasses import KW_ONLY, dataclass -from functools import cached_property from typing import Any, Literal from uuid import uuid4 @@ -29,7 +28,7 @@ from ...run import AgentRunResultEvent from ...tools import AgentDepsT, DeferredToolRequests from .. import UIEventStream -from ._utils import dump_provider_metadata, iter_metadata_chunks, iter_tool_approval_responses, tool_return_output +from ._utils import dump_provider_metadata, iter_metadata_chunks, tool_return_output from .request_types import RequestData from .response_types import ( BaseChunk, @@ -87,15 +86,6 @@ class VercelAIEventStream(UIEventStream[RequestData, BaseChunk, AgentDepsT, Outp _step_started: bool = False _finish_reason: FinishReason = None - @cached_property - def _denied_tool_ids(self) -> set[str]: - """Get the set of tool_call_ids that were denied by the user.""" - return { - tool_call_id - for tool_call_id, approval in iter_tool_approval_responses(self.run_input.messages) - if not approval.approved - } - @property def response_headers(self) -> Mapping[str, str] | None: return VERCEL_AI_DSP_HEADERS @@ -257,11 +247,16 @@ async def handle_builtin_tool_call_end(self, part: BuiltinToolCallPart) -> Async ) async def handle_builtin_tool_return(self, part: BuiltinToolReturnPart) -> AsyncIterator[BaseChunk]: - yield ToolOutputAvailableChunk( - tool_call_id=part.tool_call_id, - output=tool_return_output(part), - provider_executed=True, - ) + if self.sdk_version >= 6 and part.outcome == 'denied': + yield ToolOutputDeniedChunk(tool_call_id=part.tool_call_id) + elif part.outcome == 'failed': + yield ToolOutputErrorChunk(tool_call_id=part.tool_call_id, error_text=part.model_response_str()) + else: + yield ToolOutputAvailableChunk( + tool_call_id=part.tool_call_id, + output=tool_return_output(part), + provider_executed=True, + ) async def handle_file(self, part: FilePart) -> AsyncIterator[BaseChunk]: file = part.content @@ -271,11 +266,12 @@ async def handle_function_tool_result(self, event: FunctionToolResultEvent) -> A part = event.result tool_call_id = part.tool_call_id - # Check if this tool was denied by the user (only when sdk_version >= 6) - if self.sdk_version >= 6 and tool_call_id in self._denied_tool_ids: + if self.sdk_version >= 6 and isinstance(part, ToolReturnPart) and part.outcome == 'denied': yield ToolOutputDeniedChunk(tool_call_id=tool_call_id) elif isinstance(part, RetryPromptPart): yield ToolOutputErrorChunk(tool_call_id=tool_call_id, error_text=part.model_response()) + elif isinstance(part, ToolReturnPart) and part.outcome == 'failed': + yield ToolOutputErrorChunk(tool_call_id=tool_call_id, error_text=part.model_response_str()) else: yield ToolOutputAvailableChunk(tool_call_id=tool_call_id, output=tool_return_output(part)) diff --git a/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_utils.py b/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_utils.py index 353df3390f..055b185ed7 100644 --- a/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_utils.py +++ b/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_utils.py @@ -5,14 +5,20 @@ from pydantic_ai.messages import BaseToolReturnPart, ProviderDetailsDelta, ToolReturnPart from pydantic_ai.ui.vercel_ai.request_types import ( + DynamicToolApprovalRequestedPart, + DynamicToolApprovalRespondedPart, DynamicToolInputAvailablePart, DynamicToolInputStreamingPart, DynamicToolOutputAvailablePart, + DynamicToolOutputDeniedPart, DynamicToolOutputErrorPart, + ToolApprovalRequestedPart, ToolApprovalResponded, + ToolApprovalRespondedPart, ToolInputAvailablePart, ToolInputStreamingPart, ToolOutputAvailablePart, + ToolOutputDeniedPart, ToolOutputErrorPart, UIMessage, ) @@ -110,19 +116,36 @@ def iter_metadata_chunks( ToolInputAvailablePart, ToolOutputAvailablePart, ToolOutputErrorPart, + ToolApprovalRequestedPart, + ToolApprovalRespondedPart, + ToolOutputDeniedPart, DynamicToolInputStreamingPart, DynamicToolInputAvailablePart, DynamicToolOutputAvailablePart, DynamicToolOutputErrorPart, + DynamicToolApprovalRequestedPart, + DynamicToolApprovalRespondedPart, + DynamicToolOutputDeniedPart, +) + + +_APPROVAL_RESPONDED_TYPES = ( + ToolApprovalRespondedPart, + DynamicToolApprovalRespondedPart, ) def iter_tool_approval_responses( messages: list[UIMessage], ) -> Iterator[tuple[str, ToolApprovalResponded]]: - """Yield `(tool_call_id, approval)` for each responded tool approval in assistant messages.""" + """Yield `(tool_call_id, approval)` for each responded tool approval in assistant messages. + + Only ``approval-responded`` parts are matched. ``output-denied`` parts have + already been materialized into the message history by ``load_messages()`` and + must not be re-processed as deferred results. + """ for msg in messages: if msg.role == 'assistant': for part in msg.parts: - if isinstance(part, _TOOL_PART_TYPES) and isinstance(part.approval, ToolApprovalResponded): + if isinstance(part, _APPROVAL_RESPONDED_TYPES) and isinstance(part.approval, ToolApprovalResponded): yield part.tool_call_id, part.approval diff --git a/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/request_types.py b/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/request_types.py index a49c84224e..804fb468c2 100644 --- a/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/request_types.py +++ b/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/request_types.py @@ -189,7 +189,51 @@ class ToolOutputErrorPart(BaseUIPart): approval: ToolApproval | None = None -ToolUIPart = ToolInputStreamingPart | ToolInputAvailablePart | ToolOutputAvailablePart | ToolOutputErrorPart +class ToolApprovalRequestedPart(BaseUIPart): + """Tool part in approval-requested state (awaiting user decision).""" + + type: Annotated[str, Field(pattern=r'^tool-')] + tool_call_id: str + state: Literal['approval-requested'] = 'approval-requested' + input: Any | None = None + provider_executed: bool | None = None + call_provider_metadata: ProviderMetadata | None = None + approval: ToolApproval | None = None + + +class ToolApprovalRespondedPart(BaseUIPart): + """Tool part in approval-responded state (user approved/denied, execution pending).""" + + type: Annotated[str, Field(pattern=r'^tool-')] + tool_call_id: str + state: Literal['approval-responded'] = 'approval-responded' + input: Any | None = None + provider_executed: bool | None = None + call_provider_metadata: ProviderMetadata | None = None + approval: ToolApproval | None = None + + +class ToolOutputDeniedPart(BaseUIPart): + """Tool part in output-denied state (tool was denied, terminal state).""" + + type: Annotated[str, Field(pattern=r'^tool-')] + tool_call_id: str + state: Literal['output-denied'] = 'output-denied' + input: Any | None = None + provider_executed: bool | None = None + call_provider_metadata: ProviderMetadata | None = None + approval: ToolApproval | None = None + + +ToolUIPart = ( + ToolInputStreamingPart + | ToolInputAvailablePart + | ToolOutputAvailablePart + | ToolOutputErrorPart + | ToolApprovalRequestedPart + | ToolApprovalRespondedPart + | ToolOutputDeniedPart +) """Union of all tool part types.""" @@ -245,11 +289,50 @@ class DynamicToolOutputErrorPart(BaseUIPart): approval: ToolApproval | None = None +class DynamicToolApprovalRequestedPart(BaseUIPart): + """Dynamic tool part in approval-requested state (awaiting user decision).""" + + type: Literal['dynamic-tool'] = 'dynamic-tool' + tool_name: str + tool_call_id: str + state: Literal['approval-requested'] = 'approval-requested' + input: Any + call_provider_metadata: ProviderMetadata | None = None + approval: ToolApproval | None = None + + +class DynamicToolApprovalRespondedPart(BaseUIPart): + """Dynamic tool part in approval-responded state (user approved/denied, execution pending).""" + + type: Literal['dynamic-tool'] = 'dynamic-tool' + tool_name: str + tool_call_id: str + state: Literal['approval-responded'] = 'approval-responded' + input: Any + call_provider_metadata: ProviderMetadata | None = None + approval: ToolApproval | None = None + + +class DynamicToolOutputDeniedPart(BaseUIPart): + """Dynamic tool part in output-denied state (tool was denied, terminal state).""" + + type: Literal['dynamic-tool'] = 'dynamic-tool' + tool_name: str + tool_call_id: str + state: Literal['output-denied'] = 'output-denied' + input: Any + call_provider_metadata: ProviderMetadata | None = None + approval: ToolApproval | None = None + + DynamicToolUIPart = ( DynamicToolInputStreamingPart | DynamicToolInputAvailablePart | DynamicToolOutputAvailablePart | DynamicToolOutputErrorPart + | DynamicToolApprovalRequestedPart + | DynamicToolApprovalRespondedPart + | DynamicToolOutputDeniedPart ) """Union of all dynamic tool part types.""" diff --git a/tests/models/test_model_function.py b/tests/models/test_model_function.py index 1d092bd3bb..50af851cf4 100644 --- a/tests/models/test_model_function.py +++ b/tests/models/test_model_function.py @@ -262,6 +262,7 @@ def test_var_args(): 'tool_call_id': IsStr(), 'metadata': None, 'timestamp': IsStr() & IsNow(iso_string=True, tz=timezone.utc), # type: ignore[reportUnknownMemberType] + 'outcome': 'success', 'part_kind': 'tool-return', } ) diff --git a/tests/test_agent.py b/tests/test_agent.py index 47e3d5c781..8f07fbac30 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -6965,6 +6965,7 @@ def create_file(path: str, content: str) -> str: content='File cannot be deleted', tool_call_id='never_delete', timestamp=IsDatetime(), + outcome='denied', ), ], timestamp=IsNow(tz=timezone.utc), @@ -6996,6 +6997,7 @@ def create_file(path: str, content: str) -> str: content='File cannot be deleted', tool_call_id='never_delete', timestamp=IsDatetime(), + outcome='denied', ), ], timestamp=IsNow(tz=timezone.utc), diff --git a/tests/test_tools.py b/tests/test_tools.py index c6f467e8d7..fe178006ae 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -2317,6 +2317,7 @@ def bar(x: int) -> int: content='The tool call was denied.', tool_call_id='foo2', timestamp=IsDatetime(), + outcome='denied', ), ], timestamp=IsDatetime(), diff --git a/tests/test_vercel_ai.py b/tests/test_vercel_ai.py index 2a9d6f0be7..c5feeaba0e 100644 --- a/tests/test_vercel_ai.py +++ b/tests/test_vercel_ai.py @@ -17,6 +17,7 @@ BuiltinToolReturnPart, DocumentUrl, FilePart, + FunctionToolResultEvent, ImageUrl, ModelMessage, ModelRequest, @@ -58,10 +59,17 @@ from starlette.responses import StreamingResponse from pydantic_ai.ui.vercel_ai import VercelAIAdapter, VercelAIEventStream - from pydantic_ai.ui.vercel_ai._utils import dump_provider_metadata, load_provider_metadata + from pydantic_ai.ui.vercel_ai._utils import ( + dump_provider_metadata, + iter_tool_approval_responses, + load_provider_metadata, + ) from pydantic_ai.ui.vercel_ai.request_types import ( + DynamicToolApprovalRespondedPart, DynamicToolInputAvailablePart, + DynamicToolInputStreamingPart, DynamicToolOutputAvailablePart, + DynamicToolOutputDeniedPart, FileUIPart, ReasoningUIPart, SubmitMessage, @@ -69,7 +77,9 @@ ToolApprovalRequested, ToolApprovalResponded, ToolInputAvailablePart, + ToolInputStreamingPart, ToolOutputAvailablePart, + ToolOutputDeniedPart, ToolOutputErrorPart, UIMessage, ) @@ -2427,13 +2437,13 @@ def delete_file(path: str) -> str: role='assistant', parts=[ TextUIPart(text='I will delete the file for you.'), - DynamicToolInputAvailablePart( + DynamicToolApprovalRespondedPart( tool_name='delete_file', tool_call_id='delete_approved', input={'path': 'approved.txt'}, approval=ToolApprovalResponded(id='approval-456', approved=True), ), - DynamicToolInputAvailablePart( + DynamicToolApprovalRespondedPart( tool_name='delete_file', tool_call_id='delete_1', input={'path': 'test.txt'}, @@ -2514,6 +2524,7 @@ def capture_result(r: AgentRunResult[Any]) -> None: content='User cancelled the deletion', tool_call_id='delete_1', timestamp=IsDatetime(), + outcome='denied', ), ], timestamp=IsDatetime(), @@ -2559,7 +2570,7 @@ def some_tool(x: str) -> str: input={'x': 'no_approval'}, approval=None, ), - DynamicToolInputAvailablePart( + DynamicToolApprovalRespondedPart( tool_name='some_tool', tool_call_id='approved_tool', input={'x': 'approved'}, @@ -2625,7 +2636,7 @@ def delete_file(path: str) -> str: id='assistant-1', role='assistant', parts=[ - DynamicToolInputAvailablePart( + DynamicToolApprovalRespondedPart( tool_name='delete_file', tool_call_id='delete_1', input={'path': 'important.txt'}, @@ -2633,13 +2644,13 @@ def delete_file(path: str) -> str: id='denial-id', approved=False, reason='User cancelled the deletion' ), ), - DynamicToolInputAvailablePart( + DynamicToolApprovalRespondedPart( tool_name='delete_file', tool_call_id='delete_2', input={'path': 'temp.txt'}, approval=ToolApprovalResponded(id='denial-no-reason', approved=False), ), - DynamicToolInputAvailablePart( + DynamicToolApprovalRespondedPart( tool_name='delete_file', tool_call_id='delete_3', input={'path': 'ok.txt'}, @@ -2659,6 +2670,39 @@ def delete_file(path: str) -> str: assert approvals['delete_3'] is True +async def test_tool_approval_ignores_output_denied_parts(): + """Test that output-denied parts are not yielded by iter_tool_approval_responses. + + When a denied tool is retried, the assistant message accumulates both an + output-denied part (terminal, already materialized by load_messages) and an + approval-responded part (pending, needs deferred handling). Only the latter + should be extracted. + """ + messages = [ + UIMessage( + id='assistant-1', + role='assistant', + parts=[ + DynamicToolOutputDeniedPart( + tool_name='delete_file', + tool_call_id='tool_A', + input={'path': 'first.txt'}, + approval=ToolApprovalResponded(id='deny-A', approved=False, reason='Not allowed'), + ), + DynamicToolApprovalRespondedPart( + tool_name='delete_file', + tool_call_id='tool_B', + input={'path': 'second.txt'}, + approval=ToolApprovalResponded(id='deny-B', approved=False), + ), + ], + ) + ] + + results = dict(iter_tool_approval_responses(messages)) + assert results == {'tool_B': ToolApprovalResponded(id='deny-B', approved=False)} + + async def test_run_stream_with_deferred_tool_results_no_model_response(): """Test that run_stream errors when deferred_tool_results is passed without a ModelResponse in history.""" agent = Agent(model=TestModel()) @@ -3167,11 +3211,12 @@ async def test_adapter_load_messages(): ), ModelRequest( parts=[ - RetryPromptPart( - content="Can't do that", + ToolReturnPart( tool_name='get_table_of_contents', + content="Can't do that", tool_call_id='toolu_01W2yGpGQcMx7pXV2zZ4sz9g', timestamp=IsDatetime(), + outcome='failed', ) ] ), @@ -3198,10 +3243,11 @@ async def test_adapter_load_messages(): ), BuiltinToolReturnPart( tool_name='web_search', - content={'error_text': "Can't do that", 'is_error': True}, + content="Can't do that", tool_call_id='toolu_01W2yGpGQcMx7pXV2z', timestamp=IsDatetime(), provider_name='openai', + outcome='failed', ), TextPart( content='Here are the Table of Contents for both repositories:... Both products are designed to work together - Pydantic AI for building AI agents and Logfire for observing and monitoring them in production.' @@ -3939,23 +3985,20 @@ async def test_adapter_dump_messages_with_retry(): ] ) - # Verify roundtrip + # Verify roundtrip — load_messages now produces ToolReturnPart(outcome='failed') + # instead of RetryPromptPart for tool errors from the Vercel AI format reloaded_messages = VercelAIAdapter.load_messages(ui_messages) - # Content will have changed for retry prompt part, so we check it's value - # And then set it back to the original value - retry_prompt_part = reloaded_messages[2].parts[0] - assert isinstance(retry_prompt_part, RetryPromptPart) - assert retry_prompt_part == snapshot( - RetryPromptPart( - content='Tool failed with error\n\nFix the errors and try again.', + tool_error_part = reloaded_messages[2].parts[0] + assert isinstance(tool_error_part, ToolReturnPart) + assert tool_error_part == snapshot( + ToolReturnPart( tool_name='my_tool', + content='Tool failed with error\n\nFix the errors and try again.', tool_call_id='tool_789', timestamp=IsDatetime(), + outcome='failed', ) ) - retry_prompt_part.content = 'Tool failed with error' - _sync_timestamps(messages, reloaded_messages) - assert reloaded_messages == messages async def test_adapter_dump_messages_with_retry_no_tool_name(): @@ -4970,8 +5013,6 @@ async def test_adapter_tool_call_part_with_provider_metadata(): async def test_adapter_load_messages_tool_call_with_provider_metadata(): """Test loading dynamic tool part with provider_metadata preserves metadata on ToolCallPart.""" - from pydantic_ai.ui.vercel_ai.request_types import DynamicToolInputAvailablePart - ui_messages = [ UIMessage( id='msg1', @@ -5194,10 +5235,11 @@ async def test_adapter_builtin_tool_error_part_with_provider_metadata(): ), BuiltinToolReturnPart( tool_name='web_search', - content={'error_text': 'Search failed: rate limit exceeded', 'is_error': True}, + content='Search failed: rate limit exceeded', tool_call_id='bt_err_123', provider_name='openai', provider_details={'error_code': 'RATE_LIMIT'}, + outcome='failed', ), ] ), @@ -5361,11 +5403,12 @@ async def test_adapter_load_messages_builtin_tool_error_with_provider_details(): ), BuiltinToolReturnPart( tool_name='web_search', - content={'error_text': 'Search failed: rate limit exceeded', 'is_error': True}, + content='Search failed: rate limit exceeded', tool_call_id='bt_error', timestamp=IsDatetime(), provider_name='openai', provider_details={'error_code': 'RATE_LIMIT'}, + outcome='failed', ), ], timestamp=IsDatetime(), @@ -5376,8 +5419,6 @@ async def test_adapter_load_messages_builtin_tool_error_with_provider_details(): async def test_adapter_load_messages_tool_input_streaming_part(): """Test loading ToolInputStreamingPart which doesn't have call_provider_metadata yet.""" - from pydantic_ai.ui.vercel_ai.request_types import ToolInputStreamingPart - ui_messages = [ UIMessage( id='msg1', @@ -5408,8 +5449,6 @@ async def test_adapter_load_messages_tool_input_streaming_part(): async def test_adapter_load_messages_dynamic_tool_input_streaming_part(): """Test loading DynamicToolInputStreamingPart which doesn't have call_provider_metadata yet.""" - from pydantic_ai.ui.vercel_ai.request_types import DynamicToolInputStreamingPart - ui_messages = [ UIMessage( id='msg1', @@ -5507,15 +5546,12 @@ async def test_adapter_dump_messages_tool_error_with_provider_metadata(): ] ) - # Verify roundtrip + # Verify roundtrip — load_messages now produces ToolReturnPart(outcome='failed') reloaded_messages = VercelAIAdapter.load_messages(ui_messages) - # Content will have changed for retry prompt part, so we set it back to the original value - retry_prompt_part = reloaded_messages[2].parts[0] - assert isinstance(retry_prompt_part, RetryPromptPart) - assert retry_prompt_part.content == 'Tool execution failed\n\nFix the errors and try again.' - retry_prompt_part.content = 'Tool execution failed' - _sync_timestamps(messages, reloaded_messages) - assert reloaded_messages == messages + tool_error_part = reloaded_messages[2].parts[0] + assert isinstance(tool_error_part, ToolReturnPart) + assert tool_error_part.outcome == 'failed' + assert tool_error_part.content == 'Tool execution failed\n\nFix the errors and try again.' async def test_event_stream_text_with_provider_metadata(): @@ -5935,6 +5971,222 @@ async def event_generator(): ) +async def test_event_stream_builtin_tool_return_denied(): + """Test that ToolOutputDeniedChunk is emitted for a denied BuiltinToolReturnPart.""" + + async def event_generator(): + yield PartStartEvent( + index=0, + part=BuiltinToolReturnPart( + tool_name='web_search', + tool_call_id='tc_denied', + content='Blocked by policy', + outcome='denied', + ), + ) + + request = SubmitMessage( + id='foo', + messages=[ + UIMessage( + id='bar', + role='user', + parts=[TextUIPart(text='Search')], + ), + ], + ) + event_stream = VercelAIEventStream(run_input=request, sdk_version=6) + events = [ + '[DONE]' if '[DONE]' in event else json.loads(event.removeprefix('data: ')) + async for event in event_stream.encode_stream(event_stream.transform_stream(event_generator())) + ] + + assert events == snapshot( + [ + {'type': 'start'}, + {'type': 'start-step'}, + {'type': 'tool-output-denied', 'toolCallId': 'tc_denied'}, + {'type': 'finish-step'}, + {'type': 'finish'}, + '[DONE]', + ] + ) + + +async def test_event_stream_builtin_tool_return_error(): + async def event_generator(): + yield PartStartEvent( + index=0, + part=BuiltinToolReturnPart( + tool_name='web_search', + tool_call_id='tc_err', + content='Search failed', + outcome='failed', + ), + ) + + request = SubmitMessage( + id='foo', + messages=[ + UIMessage( + id='bar', + role='user', + parts=[TextUIPart(text='Search')], + ), + ], + ) + event_stream = VercelAIEventStream(run_input=request, sdk_version=6) + events = [ + '[DONE]' if '[DONE]' in event else json.loads(event.removeprefix('data: ')) + async for event in event_stream.encode_stream(event_stream.transform_stream(event_generator())) + ] + + assert events == snapshot( + [ + {'type': 'start'}, + {'type': 'start-step'}, + { + 'type': 'tool-output-error', + 'toolCallId': 'tc_err', + 'errorText': 'Search failed', + }, + {'type': 'finish-step'}, + {'type': 'finish'}, + '[DONE]', + ] + ) + + +async def test_adapter_dump_messages_tool_return_error(): + """Test that ToolReturnPart(outcome='failed') dumps as ToolOutputErrorPart.""" + messages: list[ModelMessage] = [ + ModelRequest(parts=[UserPromptPart(content='Do something')]), + ModelResponse( + parts=[ + ToolCallPart(tool_name='my_tool', args={'x': 1}, tool_call_id='tc_err'), + ] + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='my_tool', + content='Something went wrong', + tool_call_id='tc_err', + outcome='failed', + ), + ] + ), + ] + + ui_messages = VercelAIAdapter.dump_messages(messages) + assistant_parts = [msg.model_dump() for msg in ui_messages if msg.role == 'assistant'][0]['parts'] + assert assistant_parts == snapshot( + [ + { + 'type': 'tool-my_tool', + 'tool_call_id': 'tc_err', + 'state': 'output-error', + 'raw_input': None, + 'input': {'x': 1}, + 'error_text': 'Something went wrong', + 'provider_executed': False, + 'call_provider_metadata': None, + 'approval': None, + } + ] + ) + + # Verify roundtrip + reloaded = VercelAIAdapter.load_messages(ui_messages) + error_part = reloaded[2].parts[0] + assert isinstance(error_part, ToolReturnPart) + assert error_part.outcome == 'failed' + assert error_part.content == 'Something went wrong' + + +async def test_adapter_dump_messages_builtin_tool_error_backward_compat(): + """Test that old-format BuiltinToolReturnPart with is_error content is still detected as error.""" + messages: list[ModelMessage] = [ + ModelRequest(parts=[UserPromptPart(content='Search')]), + ModelResponse( + parts=[ + BuiltinToolCallPart( + tool_name='web_search', + args={'query': 'test'}, + tool_call_id='bt_old', + ), + BuiltinToolReturnPart( + tool_name='web_search', + content={'error_text': 'Rate limit exceeded', 'is_error': True}, + tool_call_id='bt_old', + ), + ] + ), + ] + + ui_messages = VercelAIAdapter.dump_messages(messages) + assistant_parts = [msg.model_dump() for msg in ui_messages if msg.role == 'assistant'][0]['parts'] + assert assistant_parts == snapshot( + [ + { + 'type': 'tool-web_search', + 'tool_call_id': 'bt_old', + 'state': 'output-error', + 'raw_input': None, + 'input': {'query': 'test'}, + 'error_text': 'Rate limit exceeded', + 'provider_executed': True, + 'call_provider_metadata': None, + 'approval': None, + } + ] + ) + + +async def test_event_stream_function_tool_return_error(): + """Test that ToolOutputErrorChunk is emitted for ToolReturnPart(outcome='failed').""" + + async def event_generator(): + yield FunctionToolResultEvent( + result=ToolReturnPart( + tool_name='my_tool', + content='Something went wrong', + tool_call_id='tc_err', + outcome='failed', + ), + ) + + request = SubmitMessage( + id='foo', + messages=[ + UIMessage( + id='bar', + role='user', + parts=[TextUIPart(text='Do something')], + ), + ], + ) + event_stream = VercelAIEventStream(run_input=request, sdk_version=6) + events = [ + '[DONE]' if '[DONE]' in event else json.loads(event.removeprefix('data: ')) + async for event in event_stream.encode_stream(event_stream.transform_stream(event_generator())) + ] + + assert events == snapshot( + [ + {'type': 'start'}, + { + 'type': 'tool-output-error', + 'toolCallId': 'tc_err', + 'errorText': 'Something went wrong', + }, + {'type': 'finish-step'}, + {'type': 'finish'}, + '[DONE]', + ] + ) + + def _sync_timestamps(original: list[ModelMessage], new: list[ModelMessage]) -> None: """Utility function to sync timestamps between original and new messages.""" for orig_msg, new_msg in zip(original, new): @@ -6067,3 +6319,182 @@ async def event_generator(): e for e in events_v6 if isinstance(e, dict) and e.get('type') == 'tool-input-start' ) assert 'providerMetadata' in tool_input_start_v6 + + +@pytest.mark.parametrize( + ('reason', 'expected_content'), + [ + pytest.param('Too dangerous', 'Too dangerous', id='explicit-reason'), + pytest.param(None, 'The tool call was denied.', id='default-reason'), + ], +) +async def test_adapter_load_messages_output_denied(reason: str | None, expected_content: str): + ui_messages = [ + UIMessage( + id='msg1', + role='assistant', + parts=[ + DynamicToolOutputDeniedPart( + tool_name='delete_file', + tool_call_id='tc_denied', + input={'path': 'important.txt'}, + approval=ToolApprovalResponded(id='deny-1', approved=False, reason=reason), + ), + ], + ) + ] + + messages = VercelAIAdapter.load_messages(ui_messages) + assert messages == [ + ModelResponse( + parts=[ToolCallPart(tool_name='delete_file', args={'path': 'important.txt'}, tool_call_id='tc_denied')], + timestamp=IsDatetime(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='delete_file', + content=expected_content, + tool_call_id='tc_denied', + timestamp=IsDatetime(), + outcome='denied', + ) + ] + ), + ] + + +async def test_adapter_load_messages_output_denied_builtin_tool(): + ui_messages = [ + UIMessage( + id='msg1', + role='assistant', + parts=[ + ToolOutputDeniedPart( + type='tool-web_search', + tool_call_id='tc_builtin_denied', + input={'query': 'secret data'}, + provider_executed=True, + approval=ToolApprovalResponded(id='deny-2', approved=False, reason='Blocked by policy'), + ), + ], + ) + ] + + messages = VercelAIAdapter.load_messages(ui_messages) + assert messages == snapshot( + [ + ModelResponse( + parts=[ + BuiltinToolCallPart( + tool_name='web_search', + args={'query': 'secret data'}, + tool_call_id='tc_builtin_denied', + ), + BuiltinToolReturnPart( + tool_name='web_search', + content='Blocked by policy', + tool_call_id='tc_builtin_denied', + timestamp=IsDatetime(), + outcome='denied', + ), + ], + timestamp=IsDatetime(), + ) + ] + ) + + +async def test_denied_dynamic_tool_round_trip(): + """Test that denied dynamic tool state survives a dump/load cycle.""" + + messages: list[ModelMessage] = [ + ModelResponse( + parts=[ToolCallPart(tool_name='delete_file', args={'path': '/tmp/x'}, tool_call_id='tc1')], + ), + ModelRequest( + parts=[ + ToolReturnPart(tool_name='delete_file', content='Too dangerous', tool_call_id='tc1', outcome='denied') + ], + ), + ] + + ui_messages = VercelAIAdapter.dump_messages(messages) + + # The denied tool should produce a ToolOutputDeniedPart with the reason preserved + assistant_parts = ui_messages[0].parts + assert len(assistant_parts) == 1 + assert isinstance(assistant_parts[0], ToolOutputDeniedPart) + assert assistant_parts[0].state == 'output-denied' + assert isinstance(assistant_parts[0].approval, ToolApprovalResponded) + assert assistant_parts[0].approval.reason == 'Too dangerous' + + # Round-trip back: the denial reason is preserved via approval.reason + loaded = VercelAIAdapter.load_messages(ui_messages) + assert loaded == snapshot( + [ + ModelResponse( + parts=[ToolCallPart(tool_name='delete_file', args={'path': '/tmp/x'}, tool_call_id='tc1')], + timestamp=IsDatetime(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='delete_file', + content='Too dangerous', + tool_call_id='tc1', + timestamp=IsDatetime(), + outcome='denied', + ) + ] + ), + ] + ) + + +async def test_denied_builtin_tool_round_trip(): + """Test that denied builtin tool state survives a dump/load cycle.""" + + messages: list[ModelMessage] = [ + ModelResponse( + parts=[ + BuiltinToolCallPart(tool_name='web_search', args={'query': 'secret'}, tool_call_id='tc2'), + BuiltinToolReturnPart( + tool_name='web_search', + content='Blocked by policy', + tool_call_id='tc2', + outcome='denied', + ), + ], + ), + ] + + ui_messages = VercelAIAdapter.dump_messages(messages) + + # The denied builtin tool should produce a ToolOutputDeniedPart with the reason preserved + assistant_parts = ui_messages[0].parts + assert len(assistant_parts) == 1 + assert isinstance(assistant_parts[0], ToolOutputDeniedPart) + assert assistant_parts[0].state == 'output-denied' + assert isinstance(assistant_parts[0].approval, ToolApprovalResponded) + assert assistant_parts[0].approval.reason == 'Blocked by policy' + + # Round-trip back + loaded = VercelAIAdapter.load_messages(ui_messages) + assert loaded == snapshot( + [ + ModelResponse( + parts=[ + BuiltinToolCallPart(tool_name='web_search', args={'query': 'secret'}, tool_call_id='tc2'), + BuiltinToolReturnPart( + tool_name='web_search', + content='Blocked by policy', + tool_call_id='tc2', + timestamp=IsDatetime(), + outcome='denied', + ), + ], + timestamp=IsDatetime(), + ) + ] + )