diff --git a/logfire/_internal/integrations/llm_providers/llm_provider.py b/logfire/_internal/integrations/llm_providers/llm_provider.py index 5f56784e2..fd579f05f 100644 --- a/logfire/_internal/integrations/llm_providers/llm_provider.py +++ b/logfire/_internal/integrations/llm_providers/llm_provider.py @@ -11,6 +11,9 @@ from ...constants import ONE_SECOND_IN_NANOSECONDS from ...utils import is_instrumentation_suppressed, log_internal_error, suppress_instrumentation +from .semconv import ( + PROVIDER_NAME, +) if TYPE_CHECKING: from ...main import Logfire, LogfireSpan @@ -28,6 +31,7 @@ def instrument_llm_provider( get_endpoint_config_fn: Callable[[Any], EndpointConfig], on_response_fn: Callable[[Any, LogfireSpan], Any], is_async_client_fn: Callable[[type[Any]], bool], + override_provider: str | None = None, ) -> AbstractContextManager[None]: """Instruments the provided `client` (or clients) with `logfire`. @@ -53,6 +57,7 @@ def instrument_llm_provider( get_endpoint_config_fn, on_response_fn, is_async_client_fn, + override_provider, ) for c in cast('Iterable[Any]', client) ] @@ -95,6 +100,9 @@ def _instrumentation_setup(*args: Any, **kwargs: Any) -> Any: return None, None, kwargs span_data['async'] = is_async + if override_provider is not None: + span_data['gen_ai.system'] = override_provider + span_data[PROVIDER_NAME] = override_provider if kwargs.get('stream') and stream_state_cls: stream_cls = kwargs['stream_cls'] diff --git a/logfire/_internal/integrations/llm_providers/openai.py b/logfire/_internal/integrations/llm_providers/openai.py index 260617abb..254c3c2b8 100644 --- a/logfire/_internal/integrations/llm_providers/openai.py +++ b/logfire/_internal/integrations/llm_providers/openai.py @@ -95,6 +95,7 @@ def get_endpoint_config(options: FinalRequestOptions) -> EndpointConfig: span_data: dict[str, Any] = { 'request_data': json_data, 'gen_ai.request.model': json_data.get('model'), + 'gen_ai.system': 'openai', PROVIDER_NAME: 'openai', OPERATION_NAME: 'chat', } @@ -117,6 +118,7 @@ def get_endpoint_config(options: FinalRequestOptions) -> EndpointConfig: json_data.get('input'), json_data.get('instructions'), ), + 'gen_ai.system': 'openai', PROVIDER_NAME: 'openai', OPERATION_NAME: 'chat', } @@ -131,6 +133,7 @@ def get_endpoint_config(options: FinalRequestOptions) -> EndpointConfig: span_data = { 'request_data': json_data, 'gen_ai.request.model': json_data.get('model'), + 'gen_ai.system': 'openai', PROVIDER_NAME: 'openai', OPERATION_NAME: 'text_completion', } @@ -144,6 +147,7 @@ def get_endpoint_config(options: FinalRequestOptions) -> EndpointConfig: span_data = { 'request_data': json_data, 'gen_ai.request.model': json_data.get('model'), + 'gen_ai.system': 'openai', PROVIDER_NAME: 'openai', OPERATION_NAME: 'embeddings', } @@ -156,6 +160,7 @@ def get_endpoint_config(options: FinalRequestOptions) -> EndpointConfig: span_data = { 'request_data': json_data, 'gen_ai.request.model': json_data.get('model'), + 'gen_ai.system': 'openai', PROVIDER_NAME: 'openai', OPERATION_NAME: 'image_generation', } @@ -168,6 +173,7 @@ def get_endpoint_config(options: FinalRequestOptions) -> EndpointConfig: span_data = { 'request_data': json_data, 'url': url, + 'gen_ai.system': 'openai', PROVIDER_NAME: 'openai', } if 'model' in json_data: @@ -266,7 +272,11 @@ def on_response(response: ResponseT, span: LogfireSpan) -> ResponseT: on_response(response.parse(), span) # type: ignore return cast('ResponseT', response) - span.set_attribute('gen_ai.system', 'openai') + provider = (getattr(span, 'attributes', {}) or {}).get(PROVIDER_NAME, None) + if provider is None: + provider = 'openai' + span.set_attribute('gen_ai.system', provider) + span.set_attribute(PROVIDER_NAME, provider) if isinstance(response_model := getattr(response, 'model', None), str): span.set_attribute('gen_ai.response.model', response_model) @@ -282,7 +292,7 @@ def on_response(response: ResponseT, span: LogfireSpan) -> ResponseT: ) span.set_attribute( 'operation.cost', - float(calc_price(usage_data.usage, model_ref=response_model, provider_id='openai').total_price), + float(calc_price(usage_data.usage, model_ref=response_model, provider_id=provider).total_price), ) except Exception: pass diff --git a/logfire/_internal/main.py b/logfire/_internal/main.py index f6c0c1b8e..44949d7ae 100644 --- a/logfire/_internal/main.py +++ b/logfire/_internal/main.py @@ -1179,6 +1179,7 @@ def instrument_openai( | None = None, *, suppress_other_instrumentation: bool = True, + override_provider: None | str = None, ) -> AbstractContextManager[None]: """Instrument an OpenAI client so that spans are automatically created for each request. @@ -1227,6 +1228,13 @@ def instrument_openai( enabled. In reality, this means the HTTPX instrumentation, which could otherwise be called since OpenAI uses HTTPX to make HTTP requests. + override_provider: If provided, override the provider name for the instrumented client, e.g. 'openrouter'. + Do this to get: + - Correct attribution in span attributes like `gen_ai.system` + - Cost calculation in the span attribute `operation.cost`, subject to `genai_prices` package support + - Cost calculation in the Logfire UI + The default provider is 'openai'. + Returns: A context manager that will revert the instrumentation when exited. Use of this context manager is optional. @@ -1245,6 +1253,7 @@ def instrument_openai( get_endpoint_config, on_response, is_async_client, + override_provider, ) def instrument_openai_agents(self) -> None: diff --git a/tests/otel_integrations/test_openai.py b/tests/otel_integrations/test_openai.py index acb46008b..ed4b7f750 100644 --- a/tests/otel_integrations/test_openai.py +++ b/tests/otel_integrations/test_openai.py @@ -4,6 +4,7 @@ from collections.abc import AsyncIterator, Iterator from io import BytesIO from typing import Any +from unittest.mock import MagicMock import httpx import openai @@ -26,6 +27,8 @@ from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor import logfire +from logfire._internal.integrations.llm_providers.openai import on_response +from logfire._internal.integrations.llm_providers.semconv import PROVIDER_NAME from logfire._internal.utils import get_version, suppress_instrumentation from logfire.testing import TestExporter @@ -666,6 +669,7 @@ def test_sync_chat_empty_response_chunk(instrumented_client: openai.Client, expo 'stream': True, }, 'gen_ai.provider.name': 'openai', + 'gen_ai.system': 'openai', 'gen_ai.request.model': 'gpt-4', 'gen_ai.operation.name': 'chat', 'async': False, @@ -676,6 +680,7 @@ def test_sync_chat_empty_response_chunk(instrumented_client: openai.Client, expo 'properties': { 'request_data': {'type': 'object'}, 'gen_ai.provider.name': {}, + 'gen_ai.system': {}, 'gen_ai.request.model': {}, 'gen_ai.operation.name': {}, 'async': {}, @@ -707,6 +712,7 @@ def test_sync_chat_empty_response_chunk(instrumented_client: openai.Client, expo 'logfire.msg': "streaming response from 'gpt-4' took 1.00s", 'gen_ai.request.model': 'gpt-4', 'gen_ai.provider.name': 'openai', + 'gen_ai.system': 'openai', 'logfire.span_type': 'log', 'gen_ai.operation.name': 'chat', 'logfire.tags': ('LLM',), @@ -718,6 +724,7 @@ def test_sync_chat_empty_response_chunk(instrumented_client: openai.Client, expo 'request_data': {'type': 'object'}, 'gen_ai.request.model': {}, 'gen_ai.provider.name': {}, + 'gen_ai.system': {}, 'async': {}, 'gen_ai.operation.name': {}, 'duration': {}, @@ -758,6 +765,7 @@ def test_sync_chat_empty_response_choices(instrumented_client: openai.Client, ex 'stream': True, }, 'gen_ai.provider.name': 'openai', + 'gen_ai.system': 'openai', 'gen_ai.request.model': 'gpt-4', 'gen_ai.operation.name': 'chat', 'async': False, @@ -768,6 +776,7 @@ def test_sync_chat_empty_response_choices(instrumented_client: openai.Client, ex 'properties': { 'request_data': {'type': 'object'}, 'gen_ai.provider.name': {}, + 'gen_ai.system': {}, 'gen_ai.request.model': {}, 'gen_ai.operation.name': {}, 'async': {}, @@ -799,6 +808,7 @@ def test_sync_chat_empty_response_choices(instrumented_client: openai.Client, ex 'logfire.msg': "streaming response from 'gpt-4' took 1.00s", 'gen_ai.request.model': 'gpt-4', 'gen_ai.provider.name': 'openai', + 'gen_ai.system': 'openai', 'logfire.span_type': 'log', 'gen_ai.operation.name': 'chat', 'logfire.tags': ('LLM',), @@ -810,6 +820,7 @@ def test_sync_chat_empty_response_choices(instrumented_client: openai.Client, ex 'request_data': {'type': 'object'}, 'gen_ai.request.model': {}, 'gen_ai.provider.name': {}, + 'gen_ai.system': {}, 'async': {}, 'gen_ai.operation.name': {}, 'duration': {}, @@ -900,6 +911,7 @@ def test_sync_chat_tool_call_stream(instrumented_client: openai.Client, exporter ], }, 'gen_ai.provider.name': 'openai', + 'gen_ai.system': 'openai', 'gen_ai.request.model': 'gpt-4', 'gen_ai.operation.name': 'chat', 'gen_ai.tool.definitions': [ @@ -930,6 +942,7 @@ def test_sync_chat_tool_call_stream(instrumented_client: openai.Client, exporter 'properties': { 'request_data': {'type': 'object'}, 'gen_ai.provider.name': {}, + 'gen_ai.system': {}, 'gen_ai.request.model': {}, 'gen_ai.operation.name': {}, 'gen_ai.tool.definitions': {}, @@ -984,6 +997,7 @@ def test_sync_chat_tool_call_stream(instrumented_client: openai.Client, exporter }, 'gen_ai.request.model': 'gpt-4', 'gen_ai.provider.name': 'openai', + 'gen_ai.system': 'openai', 'async': False, 'gen_ai.operation.name': 'chat', 'gen_ai.tool.definitions': [ @@ -1043,6 +1057,7 @@ def test_sync_chat_tool_call_stream(instrumented_client: openai.Client, exporter 'request_data': {'type': 'object'}, 'gen_ai.request.model': {}, 'gen_ai.provider.name': {}, + 'gen_ai.system': {}, 'async': {}, 'gen_ai.operation.name': {}, 'gen_ai.tool.definitions': {}, @@ -1170,6 +1185,7 @@ async def test_async_chat_tool_call_stream( ], }, 'gen_ai.provider.name': 'openai', + 'gen_ai.system': 'openai', 'gen_ai.request.model': 'gpt-4', 'gen_ai.operation.name': 'chat', 'gen_ai.tool.definitions': [ @@ -1200,6 +1216,7 @@ async def test_async_chat_tool_call_stream( 'properties': { 'request_data': {'type': 'object'}, 'gen_ai.provider.name': {}, + 'gen_ai.system': {}, 'gen_ai.request.model': {}, 'gen_ai.operation.name': {}, 'gen_ai.tool.definitions': {}, @@ -1254,6 +1271,7 @@ async def test_async_chat_tool_call_stream( }, 'gen_ai.request.model': 'gpt-4', 'gen_ai.provider.name': 'openai', + 'gen_ai.system': 'openai', 'async': True, 'gen_ai.operation.name': 'chat', 'gen_ai.tool.definitions': [ @@ -1313,6 +1331,7 @@ async def test_async_chat_tool_call_stream( 'request_data': {'type': 'object'}, 'gen_ai.request.model': {}, 'gen_ai.provider.name': {}, + 'gen_ai.system': {}, 'async': {}, 'gen_ai.operation.name': {}, 'gen_ai.tool.definitions': {}, @@ -1391,6 +1410,7 @@ def test_sync_chat_completions_stream(instrumented_client: openai.Client, export 'stream': True, }, 'gen_ai.provider.name': 'openai', + 'gen_ai.system': 'openai', 'gen_ai.request.model': 'gpt-4', 'gen_ai.operation.name': 'chat', 'async': False, @@ -1401,6 +1421,7 @@ def test_sync_chat_completions_stream(instrumented_client: openai.Client, export 'properties': { 'request_data': {'type': 'object'}, 'gen_ai.provider.name': {}, + 'gen_ai.system': {}, 'gen_ai.request.model': {}, 'gen_ai.operation.name': {}, 'async': {}, @@ -1435,6 +1456,7 @@ def test_sync_chat_completions_stream(instrumented_client: openai.Client, export 'logfire.msg': "streaming response from 'gpt-4' took 1.00s", 'gen_ai.request.model': 'gpt-4', 'gen_ai.provider.name': 'openai', + 'gen_ai.system': 'openai', 'logfire.span_type': 'log', 'gen_ai.operation.name': 'chat', 'logfire.tags': ('LLM',), @@ -1458,6 +1480,7 @@ def test_sync_chat_completions_stream(instrumented_client: openai.Client, export 'request_data': {'type': 'object'}, 'gen_ai.request.model': {}, 'gen_ai.provider.name': {}, + 'gen_ai.system': {}, 'async': {}, 'gen_ai.operation.name': {}, 'duration': {}, @@ -1515,6 +1538,7 @@ async def test_async_chat_completions_stream( 'stream': True, }, 'gen_ai.provider.name': 'openai', + 'gen_ai.system': 'openai', 'gen_ai.request.model': 'gpt-4', 'gen_ai.operation.name': 'chat', 'async': True, @@ -1525,6 +1549,7 @@ async def test_async_chat_completions_stream( 'properties': { 'request_data': {'type': 'object'}, 'gen_ai.provider.name': {}, + 'gen_ai.system': {}, 'gen_ai.request.model': {}, 'gen_ai.operation.name': {}, 'async': {}, @@ -1559,6 +1584,7 @@ async def test_async_chat_completions_stream( 'logfire.msg': "streaming response from 'gpt-4' took 1.00s", 'gen_ai.request.model': 'gpt-4', 'gen_ai.provider.name': 'openai', + 'gen_ai.system': 'openai', 'logfire.span_type': 'log', 'gen_ai.operation.name': 'chat', 'logfire.tags': ('LLM',), @@ -1582,6 +1608,7 @@ async def test_async_chat_completions_stream( 'request_data': {'type': 'object'}, 'gen_ai.request.model': {}, 'gen_ai.provider.name': {}, + 'gen_ai.system': {}, 'async': {}, 'gen_ai.operation.name': {}, 'duration': {}, @@ -1709,6 +1736,7 @@ def test_responses_stream(exporter: TestExporter) -> None: {'event.name': 'gen_ai.user.message', 'content': 'What is four plus five?', 'role': 'user'} ], 'request_data': {'model': 'gpt-4.1', 'stream': True}, + 'gen_ai.system': 'openai', 'gen_ai.request.model': 'gpt-4.1', 'gen_ai.operation.name': 'chat', 'async': False, @@ -1720,6 +1748,7 @@ def test_responses_stream(exporter: TestExporter) -> None: 'gen_ai.provider.name': {}, 'events': {'type': 'array'}, 'request_data': {'type': 'object'}, + 'gen_ai.system': {}, 'gen_ai.request.model': {}, 'gen_ai.operation.name': {}, 'async': {}, @@ -1755,6 +1784,7 @@ def test_responses_stream(exporter: TestExporter) -> None: }, ], 'gen_ai.request.model': 'gpt-4.1', + 'gen_ai.system': 'openai', 'async': False, 'gen_ai.operation.name': 'chat', 'duration': 1.0, @@ -1765,6 +1795,7 @@ def test_responses_stream(exporter: TestExporter) -> None: 'gen_ai.provider.name': {}, 'events': {'type': 'array'}, 'gen_ai.request.model': {}, + 'gen_ai.system': {}, 'async': {}, 'gen_ai.operation.name': {}, 'duration': {}, @@ -1804,6 +1835,7 @@ def test_completions_stream(instrumented_client: openai.Client, exporter: TestEx 'stream': True, }, 'gen_ai.provider.name': 'openai', + 'gen_ai.system': 'openai', 'gen_ai.request.model': 'gpt-3.5-turbo-instruct', 'gen_ai.operation.name': 'text_completion', 'async': False, @@ -1814,6 +1846,7 @@ def test_completions_stream(instrumented_client: openai.Client, exporter: TestEx 'properties': { 'request_data': {'type': 'object'}, 'gen_ai.provider.name': {}, + 'gen_ai.system': {}, 'gen_ai.request.model': {}, 'gen_ai.operation.name': {}, 'async': {}, @@ -1845,6 +1878,7 @@ def test_completions_stream(instrumented_client: openai.Client, exporter: TestEx 'logfire.msg': "streaming response from 'gpt-3.5-turbo-instruct' took 1.00s", 'gen_ai.request.model': 'gpt-3.5-turbo-instruct', 'gen_ai.provider.name': 'openai', + 'gen_ai.system': 'openai', 'logfire.span_type': 'log', 'gen_ai.operation.name': 'text_completion', 'logfire.tags': ('LLM',), @@ -1856,6 +1890,7 @@ def test_completions_stream(instrumented_client: openai.Client, exporter: TestEx 'request_data': {'type': 'object'}, 'gen_ai.request.model': {}, 'gen_ai.provider.name': {}, + 'gen_ai.system': {}, 'async': {}, 'gen_ai.operation.name': {}, 'duration': {}, @@ -2680,6 +2715,7 @@ def test_openrouter_streaming_reasoning(exporter: TestExporter) -> None: 'stream': True, }, 'gen_ai.provider.name': 'openai', + 'gen_ai.system': 'openai', 'gen_ai.request.model': 'google/gemini-2.5-flash', 'gen_ai.operation.name': 'chat', 'async': False, @@ -2690,6 +2726,7 @@ def test_openrouter_streaming_reasoning(exporter: TestExporter) -> None: 'properties': { 'request_data': {'type': 'object'}, 'gen_ai.provider.name': {}, + 'gen_ai.system': {}, 'gen_ai.request.model': {}, 'gen_ai.operation.name': {}, 'async': {}, @@ -2721,6 +2758,7 @@ def test_openrouter_streaming_reasoning(exporter: TestExporter) -> None: }, 'gen_ai.request.model': 'google/gemini-2.5-flash', 'gen_ai.provider.name': 'openai', + 'gen_ai.system': 'openai', 'async': False, 'gen_ai.operation.name': 'chat', 'duration': 1.0, @@ -2777,6 +2815,7 @@ def test_openrouter_streaming_reasoning(exporter: TestExporter) -> None: 'request_data': {'type': 'object'}, 'gen_ai.request.model': {}, 'gen_ai.provider.name': {}, + 'gen_ai.system': {}, 'async': {}, 'gen_ai.operation.name': {}, 'duration': {}, @@ -2803,3 +2842,231 @@ def test_openrouter_streaming_reasoning(exporter: TestExporter) -> None: }, ] ) + + +def test_override_provider_sync(exporter: TestExporter) -> None: + """Test that override_provider sets gen_ai.system correctly for sync clients.""" + with httpx.Client(transport=MockTransport(request_handler)) as httpx_client: + openai_client = openai.Client(api_key='foobar', http_client=httpx_client) + logfire.instrument_openai(openai_client, override_provider='openrouter') + + response = openai_client.chat.completions.create( + model='gpt-4', + messages=[ + {'role': 'system', 'content': 'You are a helpful assistant.'}, + {'role': 'user', 'content': 'What is four plus five?'}, + ], + ) + + assert response.choices[0].message.content == 'Nine' + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + assert len(spans) == 1 + assert spans[0]['attributes']['gen_ai.system'] == 'openrouter' + + +async def test_override_provider_async(exporter: TestExporter) -> None: + """Test that override_provider sets gen_ai.system correctly for async clients.""" + async with httpx.AsyncClient(transport=MockTransport(request_handler)) as httpx_client: + openai_client = openai.AsyncClient(api_key='foobar', http_client=httpx_client) + logfire.instrument_openai(openai_client, override_provider='custom-provider') + + response = await openai_client.chat.completions.create( + model='gpt-4', + messages=[ + {'role': 'system', 'content': 'You are a helpful assistant.'}, + {'role': 'user', 'content': 'What is four plus five?'}, + ], + ) + + assert response.choices[0].message.content == 'Nine' + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + assert len(spans) == 1 + assert spans[0]['attributes']['gen_ai.system'] == 'custom-provider' + + +def test_override_provider_streaming(exporter: TestExporter) -> None: + """Test that override_provider works correctly with streaming responses.""" + with httpx.Client(transport=MockTransport(request_handler)) as httpx_client: + openai_client = openai.Client(api_key='foobar', http_client=httpx_client) + logfire.instrument_openai(openai_client, override_provider='openrouter') + + response = openai_client.chat.completions.create( + model='gpt-4', + messages=[ + {'role': 'system', 'content': 'You are a helpful assistant.'}, + {'role': 'user', 'content': 'What is four plus five?'}, + ], + stream=True, + ) + + # Consume the stream + for _ in response: + pass + + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + # First span is the request span + request_span = next(s for s in spans if 'Chat Completion' in s['name']) + assert request_span['attributes']['gen_ai.system'] == 'openrouter' + + +def test_default_provider_is_openai(exporter: TestExporter) -> None: + """Test that when override_provider is not set, gen_ai.system defaults to 'openai'.""" + with httpx.Client(transport=MockTransport(request_handler)) as httpx_client: + openai_client = openai.Client(api_key='foobar', http_client=httpx_client) + # Not passing override_provider, so it should default to 'openai' + logfire.instrument_openai(openai_client) + + response = openai_client.chat.completions.create( + model='gpt-4', + messages=[ + {'role': 'system', 'content': 'You are a helpful assistant.'}, + {'role': 'user', 'content': 'What is four plus five?'}, + ], + ) + + assert response.choices[0].message.content == 'Nine' + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + assert spans[0]['attributes']['gen_ai.system'] == 'openai' + + +@pytest.mark.parametrize( + ('span_attributes', 'should_set_provider'), + [ + pytest.param({}, True, id='empty_attributes_sets_openai'), + pytest.param(None, True, id='none_attributes_sets_openai'), + pytest.param({PROVIDER_NAME: 'openrouter'}, False, id='existing_provider_name_not_overwritten'), + # Edge case: gen_ai.system is set but PROVIDER_NAME is not - code looks for PROVIDER_NAME, + # so it will set both attributes to 'openai' since PROVIDER_NAME is missing + pytest.param({'gen_ai.system': 'custom_provider'}, True, id='gen_ai_system_only_still_sets_provider_name'), + ], +) +def test_on_response_gen_ai_system_behavior(span_attributes: dict[str, str] | None, should_set_provider: bool) -> None: + """Test that on_response sets gen_ai.system and gen_ai.provider.name to 'openai' only when PROVIDER_NAME not already present.""" + mock_span = MagicMock() + mock_span.attributes = span_attributes + + response = chat_completion.ChatCompletion( + id='test_id', + choices=[ + chat_completion.Choice( + finish_reason='stop', + index=0, + message=chat_completion_message.ChatCompletionMessage( + content='Test response', + role='assistant', + ), + ), + ], + created=1634720000, + model='gpt-4', + object='chat.completion', + usage=completion_usage.CompletionUsage( + completion_tokens=1, + prompt_tokens=2, + total_tokens=3, + ), + ) + + on_response(response, mock_span) + + gen_ai_system_calls = [call for call in mock_span.set_attribute.call_args_list if call[0][0] == 'gen_ai.system'] + provider_name_calls = [call for call in mock_span.set_attribute.call_args_list if call[0][0] == PROVIDER_NAME] + + if should_set_provider: + # Both gen_ai.system and PROVIDER_NAME should be set to 'openai' + assert any(call[0] == ('gen_ai.system', 'openai') for call in gen_ai_system_calls), ( + f"Expected set_attribute('gen_ai.system', 'openai') to be called, got {gen_ai_system_calls}" + ) + assert any(call[0] == (PROVIDER_NAME, 'openai') for call in provider_name_calls), ( + f"Expected set_attribute('{PROVIDER_NAME}', 'openai') to be called, got {provider_name_calls}" + ) + else: + # Neither attribute should be set when PROVIDER_NAME already exists + assert len(gen_ai_system_calls) == 0, ( + f"Expected no calls to set_attribute with 'gen_ai.system', got {gen_ai_system_calls}" + ) + assert len(provider_name_calls) == 0, ( + f"Expected no calls to set_attribute with '{PROVIDER_NAME}', got {provider_name_calls}" + ) + + +@pytest.mark.parametrize( + ('span_attributes', 'expected_provider_id'), + [ + pytest.param({}, 'openai', id='no_provider_uses_openai'), + pytest.param(None, 'openai', id='none_attributes_uses_openai'), + pytest.param({PROVIDER_NAME: 'openai'}, 'openai', id='openai_provider_uses_openai'), + pytest.param({PROVIDER_NAME: 'openrouter'}, 'openrouter', id='openrouter_provider_uses_openrouter'), + pytest.param({PROVIDER_NAME: 'azure'}, 'azure', id='azure_provider_uses_azure'), + ], +) +def test_on_response_calc_price_uses_correct_provider( + span_attributes: dict[str, str] | None, expected_provider_id: str, monkeypatch: pytest.MonkeyPatch +) -> None: + """Test that on_response uses the correct provider_id (from PROVIDER_NAME attribute) when calculating price.""" + from unittest.mock import patch + + mock_span = MagicMock() + mock_span.attributes = span_attributes + + response = chat_completion.ChatCompletion( + id='test_id', + choices=[ + chat_completion.Choice( + finish_reason='stop', + index=0, + message=chat_completion_message.ChatCompletionMessage( + content='Test response', + role='assistant', + ), + ), + ], + created=1634720000, + model='gpt-4', + object='chat.completion', + usage=completion_usage.CompletionUsage( + completion_tokens=10, + prompt_tokens=20, + total_tokens=30, + ), + ) + + with patch('genai_prices.calc_price') as mock_calc_price: + # Setup mock to return a valid price result + mock_price_result = MagicMock() + mock_price_result.total_price = 0.001 + mock_calc_price.return_value = mock_price_result + + on_response(response, mock_span) + + # Verify calc_price was called with the expected provider_id + assert mock_calc_price.called, 'calc_price should have been called' + call_kwargs = mock_calc_price.call_args + assert call_kwargs.kwargs.get('provider_id') == expected_provider_id, ( + f"Expected calc_price to be called with provider_id='{expected_provider_id}', " + f"but got provider_id='{call_kwargs.kwargs.get('provider_id')}'" + ) + + +def test_override_provider_with_client_none(exporter: TestExporter) -> None: + """Test that override_provider works when client=None (instrumenting both OpenAI and AsyncOpenAI classes).""" + with httpx.Client(transport=MockTransport(request_handler)) as httpx_client: + # Instrument both classes via the tuple path (client=None defaults to (openai.OpenAI, openai.AsyncOpenAI)) + with logfire.instrument_openai(openai_client=None, override_provider='openrouter'): + # Create a new client instance after instrumenting the class + openai_client = openai.Client(api_key='foobar', http_client=httpx_client) + + response = openai_client.chat.completions.create( + model='gpt-4', + messages=[ + {'role': 'system', 'content': 'You are a helpful assistant.'}, + {'role': 'user', 'content': 'What is four plus five?'}, + ], + ) + + assert response.choices[0].message.content == 'Nine' + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + assert len(spans) == 1 + # Verify that override_provider was passed through to the class instrumentation + assert spans[0]['attributes']['gen_ai.system'] == 'openrouter' + assert spans[0]['attributes']['gen_ai.provider.name'] == 'openrouter' diff --git a/tests/test_llm_provider.py b/tests/test_llm_provider.py index 4f2e6c827..e78d326b3 100644 --- a/tests/test_llm_provider.py +++ b/tests/test_llm_provider.py @@ -7,6 +7,7 @@ from typing import Any from unittest.mock import Mock +import pytest from opentelemetry import trace import logfire @@ -14,6 +15,7 @@ instrument_llm_provider, record_streaming, ) +from logfire._internal.integrations.llm_providers.semconv import PROVIDER_NAME from logfire._internal.integrations.llm_providers.types import EndpointConfig, StreamState from logfire.propagate import get_context from logfire.testing import TestExporter @@ -175,3 +177,136 @@ async def test_async_streaming_preserves_original_context(exporter: TestExporter assert streaming['context']['trace_id'] == expected_trace_id assert request['parent']['span_id'] == expected_span_id assert streaming['parent']['span_id'] == expected_span_id + + +@pytest.mark.parametrize( + ('override_provider', 'expected_gen_ai_system'), + [ + pytest.param('openrouter', 'openrouter', id='sets_custom_provider'), + pytest.param(None, None, id='none_does_not_set_attribute'), + ], +) +def test_override_provider_sync( + exporter: TestExporter, override_provider: str | None, expected_gen_ai_system: str | None +) -> None: + """Test that override_provider parameter controls the gen_ai.system and gen_ai.provider.name attributes for sync clients.""" + client = MockSyncClient() + instrument_llm_provider( + logfire=logfire.DEFAULT_LOGFIRE_INSTANCE, + client=client, + suppress_otel=False, + scope_suffix='test', + get_endpoint_config_fn=get_endpoint_config, + on_response_fn=on_response, + is_async_client_fn=is_async_client, + override_provider=override_provider, + ) + + client.request(options=MockOptions()) + + spans = exporter.exported_spans_as_dict() + request = next(s for s in spans if 'Test with' in s['name']) + + if expected_gen_ai_system is None: + # When override_provider is None, gen_ai.system should not be set by instrument_llm_provider + # (it would be set later by on_response for OpenAI) + assert 'gen_ai.system' not in request['attributes'] + assert PROVIDER_NAME not in request['attributes'] + else: + assert request['attributes']['gen_ai.system'] == expected_gen_ai_system + assert request['attributes'][PROVIDER_NAME] == expected_gen_ai_system + + +@pytest.mark.parametrize( + ('override_provider', 'expected_gen_ai_system'), + [ + pytest.param('openrouter', 'openrouter', id='sets_custom_provider'), + pytest.param(None, None, id='none_does_not_set_attribute'), + ], +) +async def test_override_provider_async( + exporter: TestExporter, override_provider: str | None, expected_gen_ai_system: str | None +) -> None: + """Test that override_provider parameter controls the gen_ai.system and gen_ai.provider.name attributes for async clients.""" + client = MockAsyncClient() + instrument_llm_provider( + logfire=logfire.DEFAULT_LOGFIRE_INSTANCE, + client=client, + suppress_otel=False, + scope_suffix='test', + get_endpoint_config_fn=get_endpoint_config, + on_response_fn=on_response, + is_async_client_fn=is_async_client, + override_provider=override_provider, + ) + + await client.request(options=MockOptions()) + + spans = exporter.exported_spans_as_dict() + request = next(s for s in spans if 'Test with' in s['name']) + + if expected_gen_ai_system is None: + assert 'gen_ai.system' not in request['attributes'] + assert PROVIDER_NAME not in request['attributes'] + else: + assert request['attributes']['gen_ai.system'] == expected_gen_ai_system + assert request['attributes'][PROVIDER_NAME] == expected_gen_ai_system + + +async def test_override_provider_with_tuple_of_client_instances(exporter: TestExporter) -> None: + """Test that override_provider is passed through when instrumenting a tuple of client instances.""" + sync_client = MockSyncClient() + async_client = MockAsyncClient() + + instrument_llm_provider( + logfire=logfire.DEFAULT_LOGFIRE_INSTANCE, + client=(sync_client, async_client), + suppress_otel=False, + scope_suffix='test', + get_endpoint_config_fn=get_endpoint_config, + on_response_fn=on_response, + is_async_client_fn=is_async_client, + override_provider='openrouter', + ) + + # Test sync client + sync_client.request(options=MockOptions()) + + # Test async client + await async_client.request(options=MockOptions()) + + spans = exporter.exported_spans_as_dict() + request_spans = [s for s in spans if 'Test with' in s['name']] + + assert len(request_spans) == 2 + + for span in request_spans: + assert span['attributes']['gen_ai.system'] == 'openrouter' + assert span['attributes'][PROVIDER_NAME] == 'openrouter' + + +def test_override_provider_with_tuple_of_client_classes(exporter: TestExporter) -> None: + """Test that override_provider is passed through when instrumenting a tuple of client classes (like instrument_openai with client=None).""" + # This simulates what happens when you call instrument_openai(client=None) + # which internally passes (openai.OpenAI, openai.AsyncOpenAI) + instrument_llm_provider( + logfire=logfire.DEFAULT_LOGFIRE_INSTANCE, + client=(MockSyncClient, MockAsyncClient), + suppress_otel=False, + scope_suffix='test', + get_endpoint_config_fn=get_endpoint_config, + on_response_fn=on_response, + is_async_client_fn=is_async_client, + override_provider='openrouter', + ) + + # Create new instances - the class itself is instrumented + sync_client = MockSyncClient() + sync_client.request(options=MockOptions()) + + spans = exporter.exported_spans_as_dict() + request_spans = [s for s in spans if 'Test with' in s['name']] + + assert len(request_spans) == 1 + assert request_spans[0]['attributes']['gen_ai.system'] == 'openrouter' + assert request_spans[0]['attributes'][PROVIDER_NAME] == 'openrouter'