Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,10 @@
SourceUrlUIPart,
StepStartUIPart,
TextUIPart,
ToolApprovalResponded,
ToolInputAvailablePart,
ToolOutputAvailablePart,
ToolOutputDeniedPart,
ToolOutputErrorPart,
ToolUIPart,
UIMessage,
Expand Down Expand Up @@ -295,7 +297,9 @@ def load_messages(cls, messages: Sequence[UIMessage]) -> list[ModelMessage]: #
# The call and return metadata are combined in the output part.
# So we extract and return them to the respective parts
call_meta = return_meta = {}
has_tool_output = isinstance(part, (ToolOutputAvailablePart, ToolOutputErrorPart))
has_tool_output = isinstance(
part, (ToolOutputAvailablePart, ToolOutputErrorPart, ToolOutputDeniedPart)
)

if has_tool_output:
call_meta, return_meta = cls._load_builtin_tool_meta(provider_meta)
Expand All @@ -315,8 +319,10 @@ def load_messages(cls, messages: Sequence[UIMessage]) -> list[ModelMessage]: #
output: Any | None = None
if isinstance(part, ToolOutputAvailablePart):
output = part.output
elif isinstance(part, ToolOutputErrorPart): # pragma: no branch
elif isinstance(part, ToolOutputErrorPart):
output = {'error_text': part.error_text, 'is_error': True}
elif isinstance(part, ToolOutputDeniedPart): # pragma: no branch
output = _denial_reason(part)
builder.add(
BuiltinToolReturnPart(
tool_name=tool_name,
Expand Down Expand Up @@ -348,6 +354,14 @@ def load_messages(cls, messages: Sequence[UIMessage]) -> list[ModelMessage]: #
tool_name=tool_name, tool_call_id=tool_call_id, content=part.error_text
)
)
elif part.state == 'output-denied':
builder.add(
ToolReturnPart(
tool_name=tool_name,
tool_call_id=tool_call_id,
content=_denial_reason(part),
)
)
elif isinstance(part, DataUIPart): # pragma: no cover
# Contains custom data that shouldn't be sent to the model
pass
Expand Down Expand Up @@ -635,6 +649,13 @@ def _convert_user_prompt_part(part: UserPromptPart) -> list[UIMessagePart]:
return ui_parts


def _denial_reason(part: ToolUIPart | DynamicToolUIPart) -> str:
"""Extract the denial reason from a tool part's approval, or return a default message."""
if isinstance(part.approval, ToolApprovalResponded) and part.approval.reason:
return part.approval.reason
return 'Tool call was denied.'


def _extract_metadata_ui_parts(tool_result: ToolReturnPart) -> list[UIMessagePart]:
"""Convert data-carrying chunks from tool metadata into UIMessageParts.

Expand Down
27 changes: 25 additions & 2 deletions pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,20 @@

from pydantic_ai.messages import ProviderDetailsDelta, ToolReturnPart
from pydantic_ai.ui.vercel_ai.request_types import (
DynamicToolApprovalRequestedPart,
DynamicToolApprovalRespondedPart,
DynamicToolInputAvailablePart,
DynamicToolInputStreamingPart,
DynamicToolOutputAvailablePart,
DynamicToolOutputDeniedPart,
DynamicToolOutputErrorPart,
ToolApprovalRequestedPart,
ToolApprovalResponded,
ToolApprovalRespondedPart,
ToolInputAvailablePart,
ToolInputStreamingPart,
ToolOutputAvailablePart,
ToolOutputDeniedPart,
ToolOutputErrorPart,
UIMessage,
)
Expand Down Expand Up @@ -100,19 +106,36 @@ def iter_metadata_chunks(
ToolInputAvailablePart,
ToolOutputAvailablePart,
ToolOutputErrorPart,
ToolApprovalRequestedPart,
ToolApprovalRespondedPart,
ToolOutputDeniedPart,
DynamicToolInputStreamingPart,
DynamicToolInputAvailablePart,
DynamicToolOutputAvailablePart,
DynamicToolOutputErrorPart,
DynamicToolApprovalRequestedPart,
DynamicToolApprovalRespondedPart,
DynamicToolOutputDeniedPart,
)


_APPROVAL_RESPONDED_TYPES = (
ToolApprovalRespondedPart,
DynamicToolApprovalRespondedPart,
)


def iter_tool_approval_responses(
messages: list[UIMessage],
) -> Iterator[tuple[str, ToolApprovalResponded]]:
"""Yield `(tool_call_id, approval)` for each responded tool approval in assistant messages."""
"""Yield `(tool_call_id, approval)` for each responded tool approval in assistant messages.

Only ``approval-responded`` parts are matched. ``output-denied`` parts have
already been materialized into the message history by ``load_messages()`` and
must not be re-processed as deferred results.
"""
for msg in messages:
if msg.role == 'assistant':
for part in msg.parts:
if isinstance(part, _TOOL_PART_TYPES) and isinstance(part.approval, ToolApprovalResponded):
if isinstance(part, _APPROVAL_RESPONDED_TYPES) and isinstance(part.approval, ToolApprovalResponded):
yield part.tool_call_id, part.approval
85 changes: 84 additions & 1 deletion pydantic_ai_slim/pydantic_ai/ui/vercel_ai/request_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,51 @@ class ToolOutputErrorPart(BaseUIPart):
approval: ToolApproval | None = None


ToolUIPart = ToolInputStreamingPart | ToolInputAvailablePart | ToolOutputAvailablePart | ToolOutputErrorPart
class ToolApprovalRequestedPart(BaseUIPart):
"""Tool part in approval-requested state (awaiting user decision)."""

type: Annotated[str, Field(pattern=r'^tool-')]
tool_call_id: str
state: Literal['approval-requested'] = 'approval-requested'
input: Any | None = None
provider_executed: bool | None = None
call_provider_metadata: ProviderMetadata | None = None
approval: ToolApproval | None = None


class ToolApprovalRespondedPart(BaseUIPart):
"""Tool part in approval-responded state (user approved/denied, execution pending)."""

type: Annotated[str, Field(pattern=r'^tool-')]
tool_call_id: str
state: Literal['approval-responded'] = 'approval-responded'
input: Any | None = None
provider_executed: bool | None = None
call_provider_metadata: ProviderMetadata | None = None
approval: ToolApproval | None = None


class ToolOutputDeniedPart(BaseUIPart):
"""Tool part in output-denied state (tool was denied, terminal state)."""

type: Annotated[str, Field(pattern=r'^tool-')]
tool_call_id: str
state: Literal['output-denied'] = 'output-denied'
input: Any | None = None
provider_executed: bool | None = None
call_provider_metadata: ProviderMetadata | None = None
approval: ToolApproval | None = None


ToolUIPart = (
ToolInputStreamingPart
| ToolInputAvailablePart
| ToolOutputAvailablePart
| ToolOutputErrorPart
| ToolApprovalRequestedPart
| ToolApprovalRespondedPart
| ToolOutputDeniedPart
)
"""Union of all tool part types."""


Expand Down Expand Up @@ -245,11 +289,50 @@ class DynamicToolOutputErrorPart(BaseUIPart):
approval: ToolApproval | None = None


class DynamicToolApprovalRequestedPart(BaseUIPart):
"""Dynamic tool part in approval-requested state (awaiting user decision)."""

type: Literal['dynamic-tool'] = 'dynamic-tool'
tool_name: str
tool_call_id: str
state: Literal['approval-requested'] = 'approval-requested'
input: Any
call_provider_metadata: ProviderMetadata | None = None
approval: ToolApproval | None = None


class DynamicToolApprovalRespondedPart(BaseUIPart):
"""Dynamic tool part in approval-responded state (user approved/denied, execution pending)."""

type: Literal['dynamic-tool'] = 'dynamic-tool'
tool_name: str
tool_call_id: str
state: Literal['approval-responded'] = 'approval-responded'
input: Any
call_provider_metadata: ProviderMetadata | None = None
approval: ToolApproval | None = None


class DynamicToolOutputDeniedPart(BaseUIPart):
"""Dynamic tool part in output-denied state (tool was denied, terminal state)."""

type: Literal['dynamic-tool'] = 'dynamic-tool'
tool_name: str
tool_call_id: str
state: Literal['output-denied'] = 'output-denied'
input: Any
call_provider_metadata: ProviderMetadata | None = None
approval: ToolApproval | None = None


DynamicToolUIPart = (
DynamicToolInputStreamingPart
| DynamicToolInputAvailablePart
| DynamicToolOutputAvailablePart
| DynamicToolOutputErrorPart
| DynamicToolApprovalRequestedPart
| DynamicToolApprovalRespondedPart
| DynamicToolOutputDeniedPart
)
"""Union of all dynamic tool part types."""

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ strict = true

[tool.pytest.ini_options]
testpaths = ["tests", "docs/.hooks"]
norecursedirs = ["tests/ai_sdk"]
xfail_strict = true
filterwarnings = [
"error",
Expand Down
Empty file added tests/ai_sdk/__init__.py
Empty file.
85 changes: 85 additions & 0 deletions tests/ai_sdk/helpers.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/**
* Shared test utilities for AI SDK E2E integration tests.
*/

import {
AbstractChat,
DefaultChatTransport,
type ChatState,
type ChatStatus,
type UIMessage,
} from 'ai';

const url = process.env.SERVER_URL;
if (!url) {
console.error('Set SERVER_URL environment variable');
process.exit(2);
}
const SERVER_URL: string = url;

class SimpleChatState implements ChatState<UIMessage> {
status: ChatStatus = 'ready';
error: Error | undefined = undefined;
messages: UIMessage[];

constructor(messages: UIMessage[] = []) {
this.messages = messages;
}

pushMessage(message: UIMessage) {
this.messages = [...this.messages, message];
}

popMessage() {
this.messages = this.messages.slice(0, -1);
}

replaceMessage(index: number, message: UIMessage) {
this.messages = this.messages.with(index, message);
}

snapshot<T>(thing: T): T {
return structuredClone(thing);
}
}

type SendAutomaticallyWhen = (options: { messages: UIMessage[] }) => boolean;

export class TestChat extends AbstractChat<UIMessage> {
constructor(sendAutomaticallyWhen?: SendAutomaticallyWhen) {
super({
transport: new DefaultChatTransport({ api: `${SERVER_URL}/api/chat` }),
state: new SimpleChatState(),
...(sendAutomaticallyWhen ? { sendAutomaticallyWhen } : {}),
});
}
}

export function awaitRoundTrip(chat: TestChat) {
let resolve: () => void;
const promise = new Promise<void>((r) => { resolve = r; });
let captured: Error | null = null;

chat.onError = (err) => { captured = err; resolve(); };
chat.onFinish = () => resolve();

return {
done: promise.then(() => waitForStatus(chat, ['ready', 'error'])),
error() { return captured; },
};
}

export function waitForStatus(
chat: TestChat,
statuses: ChatStatus[],
timeoutMs = 10_000,
): Promise<ChatStatus> {
return new Promise((resolve, reject) => {
const start = Date.now();
(function poll() {
if (statuses.includes(chat.status)) return resolve(chat.status);
if (Date.now() - start > timeoutMs) return reject(new Error(`Timed out (status=${chat.status})`));
setTimeout(poll, 50);
})();
});
}
Loading