diff --git a/sentry_sdk/integrations/pydantic_ai/patches/graph_nodes.py b/sentry_sdk/integrations/pydantic_ai/patches/graph_nodes.py index afb10395f4..6e638505a6 100644 --- a/sentry_sdk/integrations/pydantic_ai/patches/graph_nodes.py +++ b/sentry_sdk/integrations/pydantic_ai/patches/graph_nodes.py @@ -57,6 +57,11 @@ def _patch_graph_nodes() -> None: @wraps(original_model_request_run) async def wrapped_model_request_run(self: "Any", ctx: "Any") -> "Any": + did_stream = getattr(self, "_did_stream", None) + cached_result = getattr(self, "_result", None) + if did_stream or cached_result is not None: + return await original_model_request_run(self, ctx) + messages, model, model_settings = _extract_span_data(self, ctx) with ai_client_span(messages, None, model, model_settings) as span: @@ -83,6 +88,13 @@ def create_wrapped_stream( @asynccontextmanager @wraps(original_stream_method) async def wrapped_model_request_stream(self: "Any", ctx: "Any") -> "Any": + did_stream = getattr(self, "_did_stream", None) + if did_stream: + async with original_stream_method(self, ctx) as stream: + yield stream + # Fully delegated to the original stream; return so this generator yields only once. + return + messages, model, model_settings = _extract_span_data(self, ctx) # Create chat span for streaming request diff --git a/tests/integrations/pydantic_ai/test_pydantic_ai.py b/tests/integrations/pydantic_ai/test_pydantic_ai.py index 15627a705a..6776a45039 100644 --- a/tests/integrations/pydantic_ai/test_pydantic_ai.py +++ b/tests/integrations/pydantic_ai/test_pydantic_ai.py @@ -75,7 +75,7 @@ async def test_agent_run_async(sentry_init, capture_events, test_agent): # Find child span types (invoke_agent is the transaction, not a child span) chat_spans = [s for s in spans if s["op"] == "gen_ai.chat"] - assert len(chat_spans) >= 1 + assert len(chat_spans) == 1 # Check chat span chat_span = chat_spans[0] @@ -158,7 +158,7 @@ def test_agent_run_sync(sentry_init, capture_events, test_agent): # Find span types chat_spans = [s for s in spans if s["op"] == "gen_ai.chat"] - assert len(chat_spans) >=
1 + assert len(chat_spans) == 1 # Verify streaming flag is False for sync for chat_span in chat_spans: @@ -192,7 +192,7 @@ async def test_agent_run_stream(sentry_init, capture_events, test_agent): # Find chat spans chat_spans = [s for s in spans if s["op"] == "gen_ai.chat"] - assert len(chat_spans) >= 1 + assert len(chat_spans) == 1 # Verify streaming flag is True for streaming for chat_span in chat_spans: @@ -231,9 +231,8 @@ async def test_agent_run_stream_events(sentry_init, capture_events, test_agent): # Find chat spans spans = transaction["spans"] chat_spans = [s for s in spans if s["op"] == "gen_ai.chat"] - assert len(chat_spans) >= 1 + assert len(chat_spans) == 1 - # run_stream_events uses run() internally, so streaming should be False for chat_span in chat_spans: assert chat_span["data"]["gen_ai.response.streaming"] is False @@ -269,7 +268,7 @@ def add_numbers(a: int, b: int) -> int: tool_spans = [s for s in spans if s["op"] == "gen_ai.execute_tool"] # Should have tool spans - assert len(tool_spans) >= 1 + assert len(tool_spans) == 1 # Check tool span tool_span = tool_spans[0] @@ -502,7 +501,7 @@ async def test_model_settings(sentry_init, capture_events, test_agent_with_setti # Find chat span chat_spans = [s for s in spans if s["op"] == "gen_ai.chat"] - assert len(chat_spans) >= 1 + assert len(chat_spans) == 1 chat_span = chat_spans[0] # Check that model settings are captured @@ -548,7 +547,7 @@ async def test_system_prompt_attribute( # The transaction IS the invoke_agent span, check for messages in chat spans instead chat_spans = [s for s in spans if s["op"] == "gen_ai.chat"] - assert len(chat_spans) >= 1 + assert len(chat_spans) == 1 chat_span = chat_spans[0] @@ -587,7 +586,7 @@ async def test_error_handling(sentry_init, capture_events): await agent.run("Hello") - # At minimum, we should have a transaction - assert len(events) >= 1 + # Exactly one event should be captured: the transaction + assert len(events) == 1 transaction = [e for e in events if e.get("type") == "transaction"][0] assert
transaction["transaction"] == "invoke_agent test_error" # Transaction should complete successfully (status key may not exist if no error) @@ -681,7 +680,7 @@ async def run_agent(input_text): assert transaction["type"] == "transaction" assert transaction["transaction"] == "invoke_agent test_agent" # Each should have its own spans - assert len(transaction["spans"]) >= 1 + assert len(transaction["spans"]) == 1 @pytest.mark.asyncio @@ -721,7 +720,7 @@ async def test_message_history(sentry_init, capture_events): await agent.run("What is my name?", message_history=history) # We should have 2 transactions - assert len(events) >= 2 + assert len(events) == 2 # Check the second transaction has the full history second_transaction = events[1] @@ -755,7 +754,7 @@ async def test_gen_ai_system(sentry_init, capture_events, test_agent): # Find chat span chat_spans = [s for s in spans if s["op"] == "gen_ai.chat"] - assert len(chat_spans) >= 1 + assert len(chat_spans) == 1 chat_span = chat_spans[0] # gen_ai.system should be set from the model (TestModel -> 'test') @@ -812,7 +811,7 @@ async def test_include_prompts_true(sentry_init, capture_events, test_agent): chat_spans = [s for s in spans if s["op"] == "gen_ai.chat"] # Verify that messages are captured in chat spans - assert len(chat_spans) >= 1 + assert len(chat_spans) == 1 for chat_span in chat_spans: assert "gen_ai.request.messages" in chat_span["data"] @@ -1242,7 +1241,7 @@ async def test_invoke_agent_with_instructions( # The transaction IS the invoke_agent span, check for messages in chat spans instead chat_spans = [s for s in spans if s["op"] == "gen_ai.chat"] - assert len(chat_spans) >= 1 + assert len(chat_spans) == 1 chat_span = chat_spans[0] @@ -1366,7 +1365,7 @@ async def test_usage_data_partial(sentry_init, capture_events): spans = transaction["spans"] chat_spans = [s for s in spans if s["op"] == "gen_ai.chat"] - assert len(chat_spans) >= 1 + assert len(chat_spans) == 1 # Check that usage data fields exist (they may or 
may not be set depending on TestModel) chat_span = chat_spans[0] @@ -1461,7 +1460,7 @@ def calc_tool(value: int) -> int: chat_spans = [s for s in spans if s["op"] == "gen_ai.chat"] - # At least one chat span should exist - assert len(chat_spans) >= 1 + # Exactly two chat spans should exist + assert len(chat_spans) == 2 # Check if tool calls are captured in response for chat_span in chat_spans: @@ -1509,7 +1508,7 @@ async def test_message_formatting_with_different_parts(sentry_init, capture_even chat_spans = [s for s in spans if s["op"] == "gen_ai.chat"] # Should have chat spans - assert len(chat_spans) >= 1 + assert len(chat_spans) == 1 # Check that messages are captured chat_span = chat_spans[0] @@ -1781,7 +1780,7 @@ def test_tool(x: int) -> int: chat_spans = [s for s in spans if s["op"] == "gen_ai.chat"] # Should have chat spans - assert len(chat_spans) >= 1 + assert len(chat_spans) == 2 @pytest.mark.asyncio @@ -2762,7 +2761,7 @@ async def test_binary_content_in_agent_run(sentry_init, capture_events): (transaction,) = events chat_spans = [s for s in transaction["spans"] if s["op"] == "gen_ai.chat"] - assert len(chat_spans) >= 1 + assert len(chat_spans) == 1 chat_span = chat_spans[0] if "gen_ai.request.messages" in chat_span["data"]: