diff --git a/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_adapter.py b/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_adapter.py index bcfb834555..9cffd5d648 100644 --- a/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_adapter.py +++ b/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_adapter.py @@ -53,8 +53,10 @@ SourceUrlUIPart, StepStartUIPart, TextUIPart, + ToolApprovalResponded, ToolInputAvailablePart, ToolOutputAvailablePart, + ToolOutputDeniedPart, ToolOutputErrorPart, ToolUIPart, UIMessage, @@ -295,7 +297,9 @@ def load_messages(cls, messages: Sequence[UIMessage]) -> list[ModelMessage]: # # The call and return metadata are combined in the output part. # So we extract and return them to the respective parts call_meta = return_meta = {} - has_tool_output = isinstance(part, (ToolOutputAvailablePart, ToolOutputErrorPart)) + has_tool_output = isinstance( + part, (ToolOutputAvailablePart, ToolOutputErrorPart, ToolOutputDeniedPart) + ) if has_tool_output: call_meta, return_meta = cls._load_builtin_tool_meta(provider_meta) @@ -315,8 +319,10 @@ def load_messages(cls, messages: Sequence[UIMessage]) -> list[ModelMessage]: # output: Any | None = None if isinstance(part, ToolOutputAvailablePart): output = part.output - elif isinstance(part, ToolOutputErrorPart): # pragma: no branch + elif isinstance(part, ToolOutputErrorPart): output = {'error_text': part.error_text, 'is_error': True} + elif isinstance(part, ToolOutputDeniedPart): # pragma: no branch + output = _denial_reason(part) builder.add( BuiltinToolReturnPart( tool_name=tool_name, @@ -348,6 +354,14 @@ def load_messages(cls, messages: Sequence[UIMessage]) -> list[ModelMessage]: # tool_name=tool_name, tool_call_id=tool_call_id, content=part.error_text ) ) + elif part.state == 'output-denied': + builder.add( + ToolReturnPart( + tool_name=tool_name, + tool_call_id=tool_call_id, + content=_denial_reason(part), + ) + ) elif isinstance(part, DataUIPart): # pragma: no cover # Contains custom data that shouldn't be sent to the model pass @@ -635,6 +649,13 @@ def _convert_user_prompt_part(part: UserPromptPart) -> list[UIMessagePart]: return ui_parts +def _denial_reason(part: ToolUIPart | DynamicToolUIPart) -> str: + """Extract the denial reason from a tool part's approval, or return a default message.""" + if isinstance(part.approval, ToolApprovalResponded) and part.approval.reason: + return part.approval.reason + return 'Tool call was denied.' + + def _extract_metadata_ui_parts(tool_result: ToolReturnPart) -> list[UIMessagePart]: """Convert data-carrying chunks from tool metadata into UIMessageParts. diff --git a/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_utils.py b/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_utils.py index afd172cce4..35b9235aa4 100644 --- a/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_utils.py +++ b/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_utils.py @@ -5,14 +5,20 @@ from pydantic_ai.messages import ProviderDetailsDelta, ToolReturnPart from pydantic_ai.ui.vercel_ai.request_types import ( + DynamicToolApprovalRequestedPart, + DynamicToolApprovalRespondedPart, DynamicToolInputAvailablePart, DynamicToolInputStreamingPart, DynamicToolOutputAvailablePart, + DynamicToolOutputDeniedPart, DynamicToolOutputErrorPart, + ToolApprovalRequestedPart, ToolApprovalResponded, + ToolApprovalRespondedPart, ToolInputAvailablePart, ToolInputStreamingPart, ToolOutputAvailablePart, + ToolOutputDeniedPart, ToolOutputErrorPart, UIMessage, ) @@ -100,19 +106,36 @@ def iter_metadata_chunks( ToolInputAvailablePart, ToolOutputAvailablePart, ToolOutputErrorPart, + ToolApprovalRequestedPart, + ToolApprovalRespondedPart, + ToolOutputDeniedPart, DynamicToolInputStreamingPart, DynamicToolInputAvailablePart, DynamicToolOutputAvailablePart, DynamicToolOutputErrorPart, + DynamicToolApprovalRequestedPart, + DynamicToolApprovalRespondedPart, + DynamicToolOutputDeniedPart, +) + + +_APPROVAL_RESPONDED_TYPES = ( + ToolApprovalRespondedPart, + DynamicToolApprovalRespondedPart, ) def iter_tool_approval_responses( messages: list[UIMessage], ) -> Iterator[tuple[str, ToolApprovalResponded]]: - """Yield `(tool_call_id, approval)` for each responded tool approval in assistant messages.""" + """Yield `(tool_call_id, approval)` for each responded tool approval in assistant messages. + + Only ``approval-responded`` parts are matched. ``output-denied`` parts have + already been materialized into the message history by ``load_messages()`` and + must not be re-processed as deferred results. + """ for msg in messages: if msg.role == 'assistant': for part in msg.parts: - if isinstance(part, _TOOL_PART_TYPES) and isinstance(part.approval, ToolApprovalResponded): + if isinstance(part, _APPROVAL_RESPONDED_TYPES) and isinstance(part.approval, ToolApprovalResponded): yield part.tool_call_id, part.approval diff --git a/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/request_types.py b/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/request_types.py index a49c84224e..804fb468c2 100644 --- a/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/request_types.py +++ b/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/request_types.py @@ -189,7 +189,51 @@ class ToolOutputErrorPart(BaseUIPart): approval: ToolApproval | None = None -ToolUIPart = ToolInputStreamingPart | ToolInputAvailablePart | ToolOutputAvailablePart | ToolOutputErrorPart +class ToolApprovalRequestedPart(BaseUIPart): + """Tool part in approval-requested state (awaiting user decision).""" + + type: Annotated[str, Field(pattern=r'^tool-')] + tool_call_id: str + state: Literal['approval-requested'] = 'approval-requested' + input: Any | None = None + provider_executed: bool | None = None + call_provider_metadata: ProviderMetadata | None = None + approval: ToolApproval | None = None + + +class ToolApprovalRespondedPart(BaseUIPart): + """Tool part in approval-responded state (user approved/denied, execution pending).""" + + type: Annotated[str, Field(pattern=r'^tool-')] + tool_call_id: str + state: Literal['approval-responded'] = 'approval-responded' + input: Any | None = None + provider_executed: bool | None = None + call_provider_metadata: ProviderMetadata | None = None + approval: ToolApproval | None = None + + +class ToolOutputDeniedPart(BaseUIPart): + """Tool part in output-denied state (tool was denied, terminal state).""" + + type: Annotated[str, Field(pattern=r'^tool-')] + tool_call_id: str + state: Literal['output-denied'] = 'output-denied' + input: Any | None = None + provider_executed: bool | None = None + call_provider_metadata: ProviderMetadata | None = None + approval: ToolApproval | None = None + + +ToolUIPart = ( + ToolInputStreamingPart + | ToolInputAvailablePart + | ToolOutputAvailablePart + | ToolOutputErrorPart + | ToolApprovalRequestedPart + | ToolApprovalRespondedPart + | ToolOutputDeniedPart +) """Union of all tool part types.""" @@ -245,11 +289,50 @@ class DynamicToolOutputErrorPart(BaseUIPart): approval: ToolApproval | None = None +class DynamicToolApprovalRequestedPart(BaseUIPart): + """Dynamic tool part in approval-requested state (awaiting user decision).""" + + type: Literal['dynamic-tool'] = 'dynamic-tool' + tool_name: str + tool_call_id: str + state: Literal['approval-requested'] = 'approval-requested' + input: Any + call_provider_metadata: ProviderMetadata | None = None + approval: ToolApproval | None = None + + +class DynamicToolApprovalRespondedPart(BaseUIPart): + """Dynamic tool part in approval-responded state (user approved/denied, execution pending).""" + + type: Literal['dynamic-tool'] = 'dynamic-tool' + tool_name: str + tool_call_id: str + state: Literal['approval-responded'] = 'approval-responded' + input: Any + call_provider_metadata: ProviderMetadata | None = None + approval: ToolApproval | None = None + + +class DynamicToolOutputDeniedPart(BaseUIPart): + """Dynamic tool part in output-denied state (tool was denied, terminal state).""" + + type: Literal['dynamic-tool'] = 'dynamic-tool' + tool_name: str + tool_call_id: str + state: Literal['output-denied'] = 'output-denied' + input: Any + call_provider_metadata: ProviderMetadata | None = None + approval: ToolApproval | None = None + + DynamicToolUIPart = ( DynamicToolInputStreamingPart | DynamicToolInputAvailablePart | DynamicToolOutputAvailablePart | DynamicToolOutputErrorPart + | DynamicToolApprovalRequestedPart + | DynamicToolApprovalRespondedPart + | DynamicToolOutputDeniedPart ) """Union of all dynamic tool part types.""" diff --git a/pyproject.toml b/pyproject.toml index 33a8254966..f51da4c2de 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -233,6 +233,7 @@ strict = true [tool.pytest.ini_options] testpaths = ["tests", "docs/.hooks"] +norecursedirs = ["tests/ai_sdk"] xfail_strict = true filterwarnings = [ "error", diff --git a/tests/ai_sdk/__init__.py b/tests/ai_sdk/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/ai_sdk/helpers.ts b/tests/ai_sdk/helpers.ts new file mode 100644 index 0000000000..96667df9b2 --- /dev/null +++ b/tests/ai_sdk/helpers.ts @@ -0,0 +1,85 @@ +/** + * Shared test utilities for AI SDK E2E integration tests. + */ + +import { + AbstractChat, + DefaultChatTransport, + type ChatState, + type ChatStatus, + type UIMessage, +} from 'ai'; + +const url = process.env.SERVER_URL; +if (!url) { + console.error('Set SERVER_URL environment variable'); + process.exit(2); +} +const SERVER_URL: string = url; + +class SimpleChatState implements ChatState { + status: ChatStatus = 'ready'; + error: Error | undefined = undefined; + messages: UIMessage[]; + + constructor(messages: UIMessage[] = []) { + this.messages = messages; + } + + pushMessage(message: UIMessage) { + this.messages = [...this.messages, message]; + } + + popMessage() { + this.messages = this.messages.slice(0, -1); + } + + replaceMessage(index: number, message: UIMessage) { + this.messages = this.messages.with(index, message); + } + + snapshot(thing: T): T { + return structuredClone(thing); + } +} + +type SendAutomaticallyWhen = (options: { messages: UIMessage[] }) => boolean; + +export class TestChat extends AbstractChat { + constructor(sendAutomaticallyWhen?: SendAutomaticallyWhen) { + super({ + transport: new DefaultChatTransport({ api: `${SERVER_URL}/api/chat` }), + state: new SimpleChatState(), + ...(sendAutomaticallyWhen ? { sendAutomaticallyWhen } : {}), + }); + } +} + +export function awaitRoundTrip(chat: TestChat) { + let resolve: () => void; + const promise = new Promise((r) => { resolve = r; }); + let captured: Error | null = null; + + chat.onError = (err) => { captured = err; resolve(); }; + chat.onFinish = () => resolve(); + + return { + done: promise.then(() => waitForStatus(chat, ['ready', 'error'])), + error() { return captured; }, + }; +} + +export function waitForStatus( + chat: TestChat, + statuses: ChatStatus[], + timeoutMs = 10_000, +): Promise { + return new Promise((resolve, reject) => { + const start = Date.now(); + (function poll() { + if (statuses.includes(chat.status)) return resolve(chat.status); + if (Date.now() - start > timeoutMs) return reject(new Error(`Timed out (status=${chat.status})`)); + setTimeout(poll, 50); + })(); + }); +} diff --git a/tests/ai_sdk/package-lock.json b/tests/ai_sdk/package-lock.json new file mode 100644 index 0000000000..d71da69876 --- /dev/null +++ b/tests/ai_sdk/package-lock.json @@ -0,0 +1,146 @@ +{ + "name": "pydantic-ai-ai-sdk-tests", + "lockfileVersion": 3, + "requires": true, + "packages": { + "": { + "name": "pydantic-ai-ai-sdk-tests", + "dependencies": { + "ai": "^6.0.57" + }, + "devDependencies": { + "@types/node": "^22.0.0" + } + }, + "node_modules/@ai-sdk/gateway": { + "version": "3.0.53", + "resolved": "https://registry.npmjs.org/@ai-sdk/gateway/-/gateway-3.0.53.tgz", + "integrity": "sha512-QT3FEoNARMRlk8JJVR7L98exiK9C8AGfrEJVbRxBT1yIXKs/N19o/+PsjTRVsARgDJNcy9JbJp1FspKucEat0Q==", + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/provider": "3.0.8", + "@ai-sdk/provider-utils": "4.0.15", + "@vercel/oidc": "3.1.0" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.25.76 || ^4.1.8" + } + }, + "node_modules/@ai-sdk/provider": { + "version": "3.0.8", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-3.0.8.tgz", + "integrity": "sha512-oGMAgGoQdBXbZqNG0Ze56CHjDZ1IDYOwGYxYjO5KLSlz5HiNQ9udIXsPZ61VWaHGZ5XW/jyjmr6t2xz2jGVwbQ==", + "license": "Apache-2.0", + "dependencies": { + "json-schema": "^0.4.0" + }, + "engines": { + "node": ">=18" + } + }, + "node_modules/@ai-sdk/provider-utils": { + "version": "4.0.15", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-4.0.15.tgz", + "integrity": "sha512-8XiKWbemmCbvNN0CLR9u3PQiet4gtEVIrX4zzLxnCj06AwsEDJwJVBbKrEI4t6qE8XRSIvU2irka0dcpziKW6w==", + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/provider": "3.0.8", + "@standard-schema/spec": "^1.1.0", + "eventsource-parser": "^3.0.6" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.25.76 || ^4.1.8" + } + }, + "node_modules/@opentelemetry/api": { + "version": "1.9.0", + "resolved": "https://registry.npmjs.org/@opentelemetry/api/-/api-1.9.0.tgz", + "integrity": "sha512-3giAOQvZiH5F9bMlMiv8+GSPMeqg0dbaeo58/0SlA9sxSqZhnUtxzX9/2FzyhS9sWQf5S0GJE0AKBrFqjpeYcg==", + "license": "Apache-2.0", + "engines": { + "node": ">=8.0.0" + } + }, + "node_modules/@standard-schema/spec": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/@standard-schema/spec/-/spec-1.1.0.tgz", + "integrity": "sha512-l2aFy5jALhniG5HgqrD6jXLi/rUWrKvqN/qJx6yoJsgKhblVd+iqqU4RCXavm/jPityDo5TCvKMnpjKnOriy0w==", + "license": "MIT" + }, + "node_modules/@types/node": { + "version": "22.19.11", + "resolved": "https://registry.npmjs.org/@types/node/-/node-22.19.11.tgz", + "integrity": "sha512-BH7YwL6rA93ReqeQS1c4bsPpcfOmJasG+Fkr6Y59q83f9M1WcBRHR2vM+P9eOisYRcN3ujQoiZY8uk5W+1WL8w==", + "dev": true, + "license": "MIT", + "dependencies": { + "undici-types": "~6.21.0" + } + }, + "node_modules/@vercel/oidc": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/@vercel/oidc/-/oidc-3.1.0.tgz", + "integrity": "sha512-Fw28YZpRnA3cAHHDlkt7xQHiJ0fcL+NRcIqsocZQUSmbzeIKRpwttJjik5ZGanXP+vlA4SbTg+AbA3bP363l+w==", + "license": "Apache-2.0", + "engines": { + "node": ">= 20" + } + }, + "node_modules/ai": { + "version": "6.0.95", + "resolved": "https://registry.npmjs.org/ai/-/ai-6.0.95.tgz", + "integrity": "sha512-10emBqMtiAqyR0xyVdVw2/mZgKWKSWagUHdkE54BNd9b5KSW9ez1BtAmO78nyOtLbOiDfLHUWSUtcJ276/fiKA==", + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/gateway": "3.0.53", + "@ai-sdk/provider": "3.0.8", + "@ai-sdk/provider-utils": "4.0.15", + "@opentelemetry/api": "1.9.0" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.25.76 || ^4.1.8" + } + }, + "node_modules/eventsource-parser": { + "version": "3.0.6", + "resolved": "https://registry.npmjs.org/eventsource-parser/-/eventsource-parser-3.0.6.tgz", + "integrity": "sha512-Vo1ab+QXPzZ4tCa8SwIHJFaSzy4R6SHf7BY79rFBDf0idraZWAkYrDjDj8uWaSm3S2TK+hJ7/t1CEmZ7jXw+pg==", + "license": "MIT", + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/json-schema": { + "version": "0.4.0", + "resolved": "https://registry.npmjs.org/json-schema/-/json-schema-0.4.0.tgz", + "integrity": "sha512-es94M3nTIfsEPisRafak+HDLfHXnKBhV3vU5eqPcS3flIWqcxJWgXHXiey3YrpaNsanY5ei1VoYEbOzijuq9BA==", + "license": "(AFL-2.1 OR BSD-3-Clause)" + }, + "node_modules/undici-types": { + "version": "6.21.0", + "resolved": "https://registry.npmjs.org/undici-types/-/undici-types-6.21.0.tgz", + "integrity": "sha512-iwDZqg0QAGrg9Rav5H4n0M64c3mkR59cJ6wQp+7C4nI0gsmExaedaYLNO44eT4AtBBwjbTiGPMlt2Md0T9H9JQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/zod": { + "version": "4.3.6", + "resolved": "https://registry.npmjs.org/zod/-/zod-4.3.6.tgz", + "integrity": "sha512-rftlrkhHZOcjDwkGlnUtZZkvaPHCsDATp4pGpuOOMDaTdDDXF91wuVDJoWoPsKX/3YPQ5fHuF3STjcYyKr+Qhg==", + "license": "MIT", + "peer": true, + "funding": { + "url": "https://github.com/sponsors/colinhacks" + } + } + } +} diff --git a/tests/ai_sdk/package.json b/tests/ai_sdk/package.json new file mode 100644 index 0000000000..b0012bd172 --- /dev/null +++ b/tests/ai_sdk/package.json @@ -0,0 +1,11 @@ +{ + "name": "pydantic-ai-ai-sdk-tests", + "private": true, + "type": "module", + "dependencies": { + "ai": "^6.0.57" + }, + "devDependencies": { + "@types/node": "^22.0.0" + } +} diff --git a/tests/ai_sdk/server.py b/tests/ai_sdk/server.py new file mode 100644 index 0000000000..864b4e3ba2 --- /dev/null +++ b/tests/ai_sdk/server.py @@ -0,0 +1,117 @@ +"""Starlette server for AI SDK E2E integration testing. + +Takes an agent name as a CLI argument and serves it at /api/chat. +""" + +from __future__ import annotations + +import sys +from collections.abc import AsyncIterator +from typing import Any + +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import Response +from starlette.routing import Route + +from pydantic_ai import Agent +from pydantic_ai.messages import ModelMessage, ModelRequest, ToolReturnPart +from pydantic_ai.models.function import ( + AgentInfo, + DeltaThinkingCalls, + DeltaThinkingPart, + DeltaToolCall, + DeltaToolCalls, + FunctionModel, +) +from pydantic_ai.models.test import TestModel +from pydantic_ai.tools import DeferredToolRequests +from pydantic_ai.ui.vercel_ai import VercelAIAdapter + +# --- Agents --- + +text_agent = Agent(model=TestModel(custom_output_text='Hello, world!'), output_type=str) + + +async def _thinking_stream(messages: list[ModelMessage], info: AgentInfo) -> AsyncIterator[DeltaThinkingCalls | str]: + yield {0: DeltaThinkingPart(content='Let me think about this... The answer is clear.')} + yield 'The answer is 42.' + + +thinking_agent = Agent(model=FunctionModel(stream_function=_thinking_stream), output_type=str) + +tool_agent: Agent[None, str] = Agent(model=TestModel(), output_type=str) + + +@tool_agent.tool_plain +def get_weather(city: str) -> str: + return f'Sunny in {city}' + + +def _count_denials(messages: list[ModelMessage]) -> int: + return sum( + 1 + for msg in messages + if isinstance(msg, ModelRequest) + for part in msg.parts + if isinstance(part, ToolReturnPart) and 'denied' in str(part.content).lower() + ) + + +async def _approval_stream(messages: list[ModelMessage], info: AgentInfo) -> AsyncIterator[DeltaToolCalls | str]: + denials = _count_denials(messages) + if denials == 0: + yield {0: DeltaToolCall(name='delete_file', json_args='{"path": "test.txt"}')} + elif denials == 1: + yield {0: DeltaToolCall(name='delete_file', json_args='{"path": "retry.txt"}')} + else: + yield 'Done.' + + +approval_agent: Agent[None, str | DeferredToolRequests] = Agent( + model=FunctionModel(stream_function=_approval_stream), + output_type=[str, DeferredToolRequests], +) + + +@approval_agent.tool_plain(requires_approval=True) +def delete_file(path: str) -> str: + return f'Deleted {path}' + + +multi_tool_agent: Agent[None, str] = Agent(model=TestModel(), output_type=str) + + +@multi_tool_agent.tool_plain +def lookup_weather(city: str) -> str: + return f'Rainy in {city}' + + +@multi_tool_agent.tool_plain +def lookup_time(timezone: str) -> str: + return f'12:00 {timezone}' + + +AGENTS = { + 'text': text_agent, + 'thinking': thinking_agent, + 'tool': tool_agent, + 'tool_approval': approval_agent, + 'multi_tool': multi_tool_agent, +} + + +def create_app(agent: Agent[None, Any]) -> Starlette: + async def chat_endpoint(request: Request) -> Response: + return await VercelAIAdapter.dispatch_request(request, agent=agent, sdk_version=6) + + return Starlette(routes=[Route('/api/chat', chat_endpoint, methods=['POST'])]) + + +if __name__ == '__main__': + import uvicorn + + agent_name = sys.argv[1] + port = int(sys.argv[2]) if len(sys.argv) > 2 else 8000 + agent = AGENTS[agent_name] + uvicorn.run(create_app(agent), host='127.0.0.1', port=port, log_level='warning') diff --git a/tests/ai_sdk/test_ai_sdk.py b/tests/ai_sdk/test_ai_sdk.py new file mode 100644 index 0000000000..68ef4e237c --- /dev/null +++ b/tests/ai_sdk/test_ai_sdk.py @@ -0,0 +1,106 @@ +"""Pytest orchestration for AI SDK <-> Pydantic AI integration tests. + +Starts a real HTTP server per test, runs TypeScript tests against it, and fails +if the tests fail. Requires node >= 22.6 (built-in TypeScript strip). +""" + +from __future__ import annotations + +import os +import shutil +import socket +import subprocess +import sys +import tempfile +import time +from collections.abc import Iterator +from pathlib import Path + +import pytest + +SDK_DIR = Path(__file__).parent +REPO_ROOT = SDK_DIR.parents[1] +SERVER_MODULE = 'tests.ai_sdk.server' +STARTUP_TIMEOUT = 10.0 +STARTUP_POLL = 0.25 + +pytestmark = pytest.mark.skipif(not shutil.which('node'), reason='node not installed') + + +@pytest.fixture(scope='module') +def _npm_install() -> None: + if not (SDK_DIR / 'node_modules').is_dir(): + subprocess.run(['npm', 'install'], cwd=SDK_DIR, check=True, capture_output=True) + + +def _free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(('127.0.0.1', 0)) + return s.getsockname()[1] + + +def _wait_for_server(port: int, timeout: float = STARTUP_TIMEOUT) -> None: + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + try: + with socket.create_connection(('127.0.0.1', port), timeout=0.5): + return + except OSError: + time.sleep(STARTUP_POLL) + raise TimeoutError(f'Server on port {port} did not start within {timeout}s') + + +@pytest.fixture +def server_url(request: pytest.FixtureRequest) -> Iterator[str]: + agent_name: str = request.param + port = _free_port() + log = tempfile.TemporaryFile() + proc = subprocess.Popen( + [sys.executable, '-m', SERVER_MODULE, agent_name, str(port)], + cwd=REPO_ROOT, + stdout=log, + stderr=log, + ) + try: + _wait_for_server(port) + yield f'http://127.0.0.1:{port}' + finally: + proc.terminate() + proc.wait(timeout=5) + log.seek(0) + output = log.read().decode(errors='replace') + log.close() + if output: + print(f'\n--- server log ---\n{output}--- end server log ---') + + +TEST_FILES = sorted(SDK_DIR.glob('test_*.ts')) + + +def test_agents_match_test_files() -> None: + from tests.ai_sdk.server import AGENTS + + agent_names = set(AGENTS.keys()) + test_names = {f.stem.removeprefix('test_') for f in TEST_FILES} + assert agent_names == test_names + + +@pytest.mark.parametrize( + ('test_file', 'server_url'), + [(f, f.stem.removeprefix('test_')) for f in TEST_FILES], + ids=[f.name for f in TEST_FILES], + indirect=['server_url'], +) +def test_ai_sdk(_npm_install: None, server_url: str, test_file: Path) -> None: + result = subprocess.run( + ['node', '--test', str(test_file)], + env={**os.environ, 'SERVER_URL': server_url}, + capture_output=True, + text=True, + timeout=60, + ) + + if result.returncode != 0: + pytest.fail( + f'node --test {test_file.name} exited {result.returncode}\n\nstdout:\n{result.stdout}\n\nstderr:\n{result.stderr}' + ) diff --git a/tests/ai_sdk/test_multi_tool.ts b/tests/ai_sdk/test_multi_tool.ts new file mode 100644 index 0000000000..f5a250da51 --- /dev/null +++ b/tests/ai_sdk/test_multi_tool.ts @@ -0,0 +1,28 @@ +import { describe, it } from 'node:test'; +import assert from 'node:assert/strict'; +import { isTextUIPart, isToolUIPart } from 'ai'; +import { TestChat, awaitRoundTrip } from './helpers.ts'; + +describe('multiple tool calls', () => { + it('executes multiple tools and returns text', async () => { + const chat = new TestChat(); + const trip = awaitRoundTrip(chat); + chat.sendMessage({ text: 'Weather and time please' }); + await trip.done; + assert.equal(trip.error(), null, 'request should succeed'); + + const assistant = chat.messages.find((m) => m.role === 'assistant'); + assert.ok(assistant, 'should have an assistant message'); + + const toolParts = assistant.parts.filter(isToolUIPart); + assert.equal(toolParts.length, 2, 'should have two tool parts'); + + assert.ok( + toolParts.every((p) => p.state === 'output-available'), + 'all tools should have output-available state', + ); + + const textParts = assistant.parts.filter(isTextUIPart); + assert.ok(textParts.length > 0, 'should have at least one text part after tool execution'); + }); +}); diff --git a/tests/ai_sdk/test_text.ts b/tests/ai_sdk/test_text.ts new file mode 100644 index 0000000000..90dca4c076 --- /dev/null +++ b/tests/ai_sdk/test_text.ts @@ -0,0 +1,23 @@ +import { describe, it } from 'node:test'; +import assert from 'node:assert/strict'; +import { isTextUIPart } from 'ai'; +import { TestChat, awaitRoundTrip } from './helpers.ts'; + +describe('text streaming', () => { + it('receives streamed text parts', async () => { + const chat = new TestChat(); + const trip = awaitRoundTrip(chat); + chat.sendMessage({ text: 'Say hello' }); + await trip.done; + assert.equal(trip.error(), null, 'request should succeed'); + + const assistant = chat.messages.find((m) => m.role === 'assistant'); + assert.ok(assistant, 'should have an assistant message'); + + const textParts = assistant.parts.filter(isTextUIPart); + assert.ok(textParts.length > 0, 'should have at least one text part'); + + const fullText = textParts.map((p) => p.text).join(''); + assert.ok(fullText.includes('Hello, world!'), `expected greeting, got: ${fullText}`); + }); +}); diff --git a/tests/ai_sdk/test_thinking.ts b/tests/ai_sdk/test_thinking.ts new file mode 100644 index 0000000000..2c6878744e --- /dev/null +++ b/tests/ai_sdk/test_thinking.ts @@ -0,0 +1,29 @@ +import { describe, it } from 'node:test'; +import assert from 'node:assert/strict'; +import { isReasoningUIPart, isTextUIPart } from 'ai'; +import { TestChat, awaitRoundTrip } from './helpers.ts'; + +describe('thinking', () => { + it('receives reasoning and text parts', async () => { + const chat = new TestChat(); + const trip = awaitRoundTrip(chat); + chat.sendMessage({ text: 'What is the answer?' }); + await trip.done; + assert.equal(trip.error(), null, 'request should succeed'); + + const assistant = chat.messages.find((m) => m.role === 'assistant'); + assert.ok(assistant, 'should have an assistant message'); + + const reasoningParts = assistant.parts.filter(isReasoningUIPart); + assert.ok(reasoningParts.length > 0, 'should have at least one reasoning part'); + + const reasoningText = reasoningParts.map((p) => p.text).join(''); + assert.ok(reasoningText.includes('think'), `expected thinking content, got: ${reasoningText}`); + + const textParts = assistant.parts.filter(isTextUIPart); + assert.ok(textParts.length > 0, 'should have at least one text part'); + + const fullText = textParts.map((p) => p.text).join(''); + assert.ok(fullText.includes('42'), `expected answer text, got: ${fullText}`); + }); +}); diff --git a/tests/ai_sdk/test_tool.ts b/tests/ai_sdk/test_tool.ts new file mode 100644 index 0000000000..dc3fcb7e6b --- /dev/null +++ b/tests/ai_sdk/test_tool.ts @@ -0,0 +1,27 @@ +import { describe, it } from 'node:test'; +import assert from 'node:assert/strict'; +import { isTextUIPart, isToolUIPart } from 'ai'; +import { TestChat, awaitRoundTrip } from './helpers.ts'; + +describe('tool call without approval', () => { + it('executes tool and returns text result', async () => { + const chat = new TestChat(); + const trip = awaitRoundTrip(chat); + chat.sendMessage({ text: 'What is the weather?' }); + await trip.done; + assert.equal(trip.error(), null, 'request should succeed'); + + const assistant = chat.messages.find((m) => m.role === 'assistant'); + assert.ok(assistant, 'should have an assistant message'); + + const toolParts = assistant.parts.filter(isToolUIPart); + assert.ok(toolParts.length > 0, 'should have at least one tool part'); + assert.ok( + toolParts.some((p) => p.state === 'output-available'), + 'should have a tool with output-available state', + ); + + const textParts = assistant.parts.filter(isTextUIPart); + assert.ok(textParts.length > 0, 'should have at least one text part after tool execution'); + }); +}); diff --git a/tests/ai_sdk/test_tool_approval.ts b/tests/ai_sdk/test_tool_approval.ts new file mode 100644 index 0000000000..b0e479f1a1 --- /dev/null +++ b/tests/ai_sdk/test_tool_approval.ts @@ -0,0 +1,115 @@ +import { describe, it } from 'node:test'; +import assert from 'node:assert/strict'; +import { isToolUIPart, lastAssistantMessageIsCompleteWithApprovalResponses } from 'ai'; +import { TestChat, awaitRoundTrip } from './helpers.ts'; + +function getLatestApprovalId(chat: TestChat): string { + const assistant = [...chat.messages].reverse().find((m) => m.role === 'assistant'); + assert.ok(assistant, 'should have an assistant message'); + + const toolParts = assistant.parts.filter(isToolUIPart); + const approvalPart = [...toolParts].reverse().find((p) => p.state === 'approval-requested'); + assert.ok(approvalPart, 'no tool part with state=approval-requested'); + return approvalPart.approval.id; +} + +async function sendAndGetApprovalId(chat: TestChat): Promise { + const trip = awaitRoundTrip(chat); + chat.sendMessage({ text: 'Delete test.txt' }); + await trip.done; + assert.equal(trip.error(), null, 'initial request should succeed'); + return getLatestApprovalId(chat); +} + +function createApprovalChat() { + return new TestChat(lastAssistantMessageIsCompleteWithApprovalResponses); +} + +describe('tool approval', () => { + it('returns approval-requested state on initial request', async () => { + const chat = createApprovalChat(); + await sendAndGetApprovalId(chat); + }); + + it('completes round-trip after approval', async () => { + const chat = createApprovalChat(); + const approvalId = await sendAndGetApprovalId(chat); + + const resubmit = awaitRoundTrip(chat); + await chat.addToolApprovalResponse({ id: approvalId, approved: true }); + await resubmit.done; + assert.equal(resubmit.error(), null, 'resubmit after approval should succeed'); + }); + + it('retries with new tool call after denial', async () => { + const chat = createApprovalChat(); + const approvalId = await sendAndGetApprovalId(chat); + + const resubmit = awaitRoundTrip(chat); + await chat.addToolApprovalResponse({ id: approvalId, approved: false }); + await resubmit.done; + assert.equal(resubmit.error(), null, 'resubmit after denial should succeed'); + + // Server retries after first denial — should get a new approval-requested + const retryApprovalId = getLatestApprovalId(chat); + assert.notEqual(retryApprovalId, approvalId, 'retry should produce a new approval id'); + }); + + it('completes after deny then approve', async () => { + const chat = createApprovalChat(); + const firstId = await sendAndGetApprovalId(chat); + + // Deny the first tool call — server retries + const retry = awaitRoundTrip(chat); + await chat.addToolApprovalResponse({ id: firstId, approved: false }); + await retry.done; + assert.equal(retry.error(), null, 'retry after denial should succeed'); + + // Approve the retried tool call + const secondId = getLatestApprovalId(chat); + const approve = awaitRoundTrip(chat); + await chat.addToolApprovalResponse({ id: secondId, approved: true }); + await approve.done; + assert.equal(approve.error(), null, 'approve after retry should succeed'); + }); + + it('completes after deny then deny', async () => { + const chat = createApprovalChat(); + const firstId = await sendAndGetApprovalId(chat); + + // Deny the first tool call — server retries + const retry = awaitRoundTrip(chat); + await chat.addToolApprovalResponse({ id: firstId, approved: false }); + await retry.done; + assert.equal(retry.error(), null, 'retry after first denial should succeed'); + + // Deny the retried tool call — server gives up and returns text + const secondId = getLatestApprovalId(chat); + const final = awaitRoundTrip(chat); + await chat.addToolApprovalResponse({ id: secondId, approved: false }); + await final.done; + assert.equal(final.error(), null, 'second denial should complete with text'); + }); + + it('completes after denial with reason', async () => { + const chat = createApprovalChat(); + const approvalId = await sendAndGetApprovalId(chat); + + // Deny with reason — server retries + const retry = awaitRoundTrip(chat); + await chat.addToolApprovalResponse({ + id: approvalId, + approved: false, + reason: 'Not allowed by policy', + }); + await retry.done; + assert.equal(retry.error(), null, 'resubmit after denial with reason should succeed'); + + // Deny again — server gives up + const secondId = getLatestApprovalId(chat); + const final = awaitRoundTrip(chat); + await chat.addToolApprovalResponse({ id: secondId, approved: false }); + await final.done; + assert.equal(final.error(), null, 'second denial should complete with text'); + }); +}); diff --git a/tests/test_vercel_ai.py b/tests/test_vercel_ai.py index 8b4448bcdb..46030a8b6a 100644 --- a/tests/test_vercel_ai.py +++ b/tests/test_vercel_ai.py @@ -58,6 +58,7 @@ from pydantic_ai.ui.vercel_ai import VercelAIAdapter, VercelAIEventStream from pydantic_ai.ui.vercel_ai._utils import dump_provider_metadata, load_provider_metadata from pydantic_ai.ui.vercel_ai.request_types import ( + DynamicToolApprovalRespondedPart, DynamicToolInputAvailablePart, DynamicToolOutputAvailablePart, FileUIPart, @@ -2424,13 +2425,13 @@ def delete_file(path: str) -> str: role='assistant', parts=[ TextUIPart(text='I will delete the file for you.'), - DynamicToolInputAvailablePart( + DynamicToolApprovalRespondedPart( tool_name='delete_file', tool_call_id='delete_approved', input={'path': 'approved.txt'}, approval=ToolApprovalResponded(id='approval-456', approved=True), ), - DynamicToolInputAvailablePart( + DynamicToolApprovalRespondedPart( tool_name='delete_file', tool_call_id='delete_1', input={'path': 'test.txt'}, @@ -2556,7 +2557,7 @@ def some_tool(x: str) -> str: input={'x': 'no_approval'}, approval=None, ), - DynamicToolInputAvailablePart( + DynamicToolApprovalRespondedPart( tool_name='some_tool', tool_call_id='approved_tool', input={'x': 'approved'}, @@ -2622,7 +2623,7 @@ def delete_file(path: str) -> str: id='assistant-1', role='assistant', parts=[ - DynamicToolInputAvailablePart( + DynamicToolApprovalRespondedPart( tool_name='delete_file', tool_call_id='delete_1', input={'path': 'important.txt'}, @@ -2630,13 +2631,13 @@ def delete_file(path: str) -> str: id='denial-id', approved=False, reason='User cancelled the deletion' ), ), - DynamicToolInputAvailablePart( + DynamicToolApprovalRespondedPart( tool_name='delete_file', tool_call_id='delete_2', input={'path': 'temp.txt'}, approval=ToolApprovalResponded(id='denial-no-reason', approved=False), ), - DynamicToolInputAvailablePart( + DynamicToolApprovalRespondedPart( tool_name='delete_file', tool_call_id='delete_3', input={'path': 'ok.txt'}, @@ -2656,6 +2657,42 @@ def delete_file(path: str) -> str: assert approvals['delete_3'] is True +async def test_tool_approval_ignores_output_denied_parts(): + """Test that output-denied parts are not yielded by iter_tool_approval_responses. + + When a denied tool is retried, the assistant message accumulates both an + output-denied part (terminal, already materialized by load_messages) and an + approval-responded part (pending, needs deferred handling). Only the latter + should be extracted. + """ + from pydantic_ai.ui.vercel_ai._utils import iter_tool_approval_responses + from pydantic_ai.ui.vercel_ai.request_types import DynamicToolOutputDeniedPart + + messages = [ + UIMessage( + id='assistant-1', + role='assistant', + parts=[ + DynamicToolOutputDeniedPart( + tool_name='delete_file', + tool_call_id='tool_A', + input={'path': 'first.txt'}, + approval=ToolApprovalResponded(id='deny-A', approved=False, reason='Not allowed'), + ), + DynamicToolApprovalRespondedPart( + tool_name='delete_file', + tool_call_id='tool_B', + input={'path': 'second.txt'}, + approval=ToolApprovalResponded(id='deny-B', approved=False), + ), + ], + ) + ] + + results = dict(iter_tool_approval_responses(messages)) + assert results == {'tool_B': ToolApprovalResponded(id='deny-B', approved=False)} + + async def test_run_stream_with_deferred_tool_results_no_model_response(): """Test that run_stream errors when deferred_tool_results is passed without a ModelResponse in history.""" agent = Agent(model=TestModel()) @@ -5760,3 +5797,91 @@ async def event_generator(): e for e in events_v6 if isinstance(e, dict) and e.get('type') == 'tool-input-start' ) assert 'providerMetadata' in tool_input_start_v6 + + +@pytest.mark.parametrize( + ('reason', 'expected_content'), + [ + pytest.param('Too dangerous', 'Too dangerous', id='explicit-reason'), + pytest.param(None, 'Tool call was denied.', id='default-reason'), + ], +) +async def test_adapter_load_messages_output_denied(reason: str | None, expected_content: str): + from pydantic_ai.ui.vercel_ai.request_types import DynamicToolOutputDeniedPart + + ui_messages = [ + UIMessage( + id='msg1', + role='assistant', + parts=[ + DynamicToolOutputDeniedPart( + tool_name='delete_file', + tool_call_id='tc_denied', + input={'path': 'important.txt'}, + approval=ToolApprovalResponded(id='deny-1', approved=False, reason=reason), + ), + ], + ) + ] + + messages = VercelAIAdapter.load_messages(ui_messages) + assert messages == snapshot( + [ + ModelResponse( + parts=[ToolCallPart(tool_name='delete_file', args={'path': 'important.txt'}, tool_call_id='tc_denied')], + timestamp=IsDatetime(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='delete_file', + content=expected_content, + tool_call_id='tc_denied', + timestamp=IsDatetime(), + ) + ] + ), + ] + ) + + +async def test_adapter_load_messages_output_denied_builtin_tool(): + from pydantic_ai.ui.vercel_ai.request_types import ToolOutputDeniedPart + + ui_messages = [ + UIMessage( + id='msg1', + role='assistant', + parts=[ + ToolOutputDeniedPart( + type='tool-web_search', + tool_call_id='tc_builtin_denied', + input={'query': 'secret data'}, + provider_executed=True, + approval=ToolApprovalResponded(id='deny-2', approved=False, reason='Blocked by policy'), + ), + ], + ) + ] + + messages = VercelAIAdapter.load_messages(ui_messages) + assert messages == snapshot( + [ + ModelResponse( + parts=[ + BuiltinToolCallPart( + tool_name='web_search', + args={'query': 'secret data'}, + tool_call_id='tc_builtin_denied', + ), + BuiltinToolReturnPart( + tool_name='web_search', + content='Blocked by policy', + tool_call_id='tc_builtin_denied', + timestamp=IsDatetime(), + ), + ], + timestamp=IsDatetime(), + ) + ] + )