Skip to content
Open
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
48e858a
added changes
yiphei Jan 7, 2026
5e431dc
added changes
yiphei Jan 7, 2026
06ce262
added changes
yiphei Jan 7, 2026
5a2b35d
fix attribute name
yiphei Jan 7, 2026
85c711f
Merge branch 'main' into support-custom-system
yiphei Jan 8, 2026
8df4eef
added changes
yiphei Jan 8, 2026
dd9702d
added changes
yiphei Jan 8, 2026
33c83d1
added changes
yiphei Jan 8, 2026
cd656b3
added changes
yiphei Jan 8, 2026
7f3d01f
added changes
yiphei Jan 8, 2026
3140257
added changes
yiphei Jan 8, 2026
0e079b2
Merge branch 'main' into support-custom-system
yiphei Jan 10, 2026
1bbd297
added changes
yiphei Jan 10, 2026
e8cbbaa
added changes
yiphei Jan 10, 2026
f33ddf3
added changes
yiphei Jan 10, 2026
36bba2f
refactored tests
yiphei Jan 10, 2026
fb2d2c0
Merge branch 'main' into support-custom-system
yiphei Jan 12, 2026
3cb7f36
added changes
yiphei Jan 12, 2026
d78496d
added changes
yiphei Jan 12, 2026
8b6085c
added changes
yiphei Jan 12, 2026
190fb88
Merge branch 'main' into support-custom-system
yiphei Jan 13, 2026
a200de6
added changes
yiphei Jan 13, 2026
fc506ab
added changes
yiphei Jan 13, 2026
ca42b29
added changes
yiphei Jan 13, 2026
85186ae
added changes
yiphei Jan 13, 2026
28e7ee0
Merge branch 'main' into support-custom-system
yiphei Jan 15, 2026
3f9255a
added changes
yiphei Jan 15, 2026
c702827
added changes
yiphei Jan 15, 2026
c0d1f2d
Merge branch 'main' into support-custom-system
yiphei Jan 19, 2026
a2aa776
Merge pull request #7 from yiphei/support-custom-system-pt2
yiphei Jan 19, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions logfire/_internal/integrations/llm_providers/llm_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@

from ...constants import ONE_SECOND_IN_NANOSECONDS
from ...utils import is_instrumentation_suppressed, log_internal_error, suppress_instrumentation
from .semconv import (
PROVIDER_NAME,
)

if TYPE_CHECKING:
from ...main import Logfire, LogfireSpan
Expand All @@ -28,6 +31,7 @@ def instrument_llm_provider(
get_endpoint_config_fn: Callable[[Any], EndpointConfig],
on_response_fn: Callable[[Any, LogfireSpan], Any],
is_async_client_fn: Callable[[type[Any]], bool],
override_provider: str | None = None,
) -> AbstractContextManager[None]:
"""Instruments the provided `client` (or clients) with `logfire`.

Expand All @@ -53,6 +57,7 @@ def instrument_llm_provider(
get_endpoint_config_fn,
on_response_fn,
is_async_client_fn,
override_provider,
)
for c in cast('Iterable[Any]', client)
]
Expand Down Expand Up @@ -95,6 +100,9 @@ def _instrumentation_setup(*args: Any, **kwargs: Any) -> Any:
return None, None, kwargs

span_data['async'] = is_async
if override_provider is not None:
span_data['gen_ai.system'] = override_provider
span_data[PROVIDER_NAME] = override_provider

if kwargs.get('stream') and stream_state_cls:
stream_cls = kwargs['stream_cls']
Expand Down
7 changes: 5 additions & 2 deletions logfire/_internal/integrations/llm_providers/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,10 @@ def on_response(response: ResponseT, span: LogfireSpan) -> ResponseT:
on_response(response.parse(), span) # type: ignore
return cast('ResponseT', response)

span.set_attribute('gen_ai.system', 'openai')
provider = (getattr(span, 'attributes', {}) or {}).get('gen_ai.system', None)
if provider is None:
provider = 'openai'
span.set_attribute('gen_ai.system', provider)

if isinstance(response_model := getattr(response, 'model', None), str):
span.set_attribute('gen_ai.response.model', response_model)
Expand All @@ -282,7 +285,7 @@ def on_response(response: ResponseT, span: LogfireSpan) -> ResponseT:
)
span.set_attribute(
'operation.cost',
float(calc_price(usage_data.usage, model_ref=response_model, provider_id='openai').total_price),
float(calc_price(usage_data.usage, model_ref=response_model, provider_id=provider).total_price),
)
except Exception:
pass
Expand Down
9 changes: 9 additions & 0 deletions logfire/_internal/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -1179,6 +1179,7 @@ def instrument_openai(
| None = None,
*,
suppress_other_instrumentation: bool = True,
override_provider: None | str = None,
) -> AbstractContextManager[None]:
"""Instrument an OpenAI client so that spans are automatically created for each request.

Expand Down Expand Up @@ -1227,6 +1228,13 @@ def instrument_openai(
enabled. In reality, this means the HTTPX instrumentation, which could otherwise be called since
OpenAI uses HTTPX to make HTTP requests.

override_provider: If provided, override the provider name for the instrumented client, e.g. 'openrouter'.
Do this to get:
- Correct attribution in span attributes like `gen_ai.system`
- Cost calculation in the span attribute `operation.cost`, subject to `genai_prices` package support
- Cost calculation in the Logfire UI
The default provider is 'openai'.

Returns:
A context manager that will revert the instrumentation when exited.
Use of this context manager is optional.
Expand All @@ -1245,6 +1253,7 @@ def instrument_openai(
get_endpoint_config,
on_response,
is_async_client,
override_provider,
)

def instrument_openai_agents(self) -> None:
Expand Down
219 changes: 219 additions & 0 deletions tests/otel_integrations/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from collections.abc import AsyncIterator, Iterator
from io import BytesIO
from typing import Any
from unittest.mock import MagicMock

import httpx
import openai
Expand All @@ -26,6 +27,7 @@
from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor

import logfire
from logfire._internal.integrations.llm_providers.openai import on_response
from logfire._internal.utils import get_version, suppress_instrumentation
from logfire.testing import TestExporter

Expand Down Expand Up @@ -2803,3 +2805,220 @@ def test_openrouter_streaming_reasoning(exporter: TestExporter) -> None:
},
]
)


def test_override_provider_sync(exporter: TestExporter) -> None:
    """override_provider should surface as `gen_ai.system` on spans from a sync client."""
    with httpx.Client(transport=MockTransport(request_handler)) as transport:
        client = openai.Client(api_key='foobar', http_client=transport)
        logfire.instrument_openai(client, override_provider='openrouter')

        completion = client.chat.completions.create(
            model='gpt-4',
            messages=[
                {'role': 'system', 'content': 'You are a helpful assistant.'},
                {'role': 'user', 'content': 'What is four plus five?'},
            ],
        )

        assert completion.choices[0].message.content == 'Nine'
        exported = exporter.exported_spans_as_dict(parse_json_attributes=True)
        assert len(exported) == 1
        assert exported[0]['attributes']['gen_ai.system'] == 'openrouter'


async def test_override_provider_async(exporter: TestExporter) -> None:
    """override_provider should surface as `gen_ai.system` on spans from an async client."""
    async with httpx.AsyncClient(transport=MockTransport(request_handler)) as transport:
        client = openai.AsyncClient(api_key='foobar', http_client=transport)
        logfire.instrument_openai(client, override_provider='custom-provider')

        completion = await client.chat.completions.create(
            model='gpt-4',
            messages=[
                {'role': 'system', 'content': 'You are a helpful assistant.'},
                {'role': 'user', 'content': 'What is four plus five?'},
            ],
        )

        assert completion.choices[0].message.content == 'Nine'
        exported = exporter.exported_spans_as_dict(parse_json_attributes=True)
        assert len(exported) == 1
        assert exported[0]['attributes']['gen_ai.system'] == 'custom-provider'


def test_override_provider_streaming(exporter: TestExporter) -> None:
    """override_provider should be honored on the request span when streaming is enabled."""
    with httpx.Client(transport=MockTransport(request_handler)) as transport:
        client = openai.Client(api_key='foobar', http_client=transport)
        logfire.instrument_openai(client, override_provider='openrouter')

        stream = client.chat.completions.create(
            model='gpt-4',
            messages=[
                {'role': 'system', 'content': 'You are a helpful assistant.'},
                {'role': 'user', 'content': 'What is four plus five?'},
            ],
            stream=True,
        )

        # Drain the stream so the spans get finalized and exported.
        for _chunk in stream:
            pass

        exported = exporter.exported_spans_as_dict(parse_json_attributes=True)
        # Streaming produces multiple spans; pick out the request span by name.
        request_spans = [s for s in exported if 'Chat Completion' in s['name']]
        assert request_spans[0]['attributes']['gen_ai.system'] == 'openrouter'


def test_default_provider_is_openai(exporter: TestExporter) -> None:
    """Test that when override_provider is not set, gen_ai.system defaults to 'openai'."""
    with httpx.Client(transport=MockTransport(request_handler)) as httpx_client:
        openai_client = openai.Client(api_key='foobar', http_client=httpx_client)
        # Not passing override_provider, so it should default to 'openai'
        logfire.instrument_openai(openai_client)

        response = openai_client.chat.completions.create(
            model='gpt-4',
            messages=[
                {'role': 'system', 'content': 'You are a helpful assistant.'},
                {'role': 'user', 'content': 'What is four plus five?'},
            ],
        )

        assert response.choices[0].message.content == 'Nine'
        spans = exporter.exported_spans_as_dict(parse_json_attributes=True)
        # Assert the span count before indexing so a missing span fails with a clear
        # assertion error (not an IndexError), consistent with the sibling tests.
        assert len(spans) == 1
        assert spans[0]['attributes']['gen_ai.system'] == 'openai'


@pytest.mark.parametrize(
    ('span_attributes', 'should_set_gen_ai_system'),
    [
        pytest.param({}, True, id='empty_attributes_sets_openai'),
        pytest.param(None, True, id='none_attributes_sets_openai'),
        pytest.param({'gen_ai.system': 'openrouter'}, False, id='existing_value_not_overwritten'),
    ],
)
def test_on_response_gen_ai_system_behavior(
    span_attributes: dict[str, str] | None, should_set_gen_ai_system: bool
) -> None:
    """on_response must write gen_ai.system='openai' only when the span has no value yet."""
    mock_span = MagicMock()
    mock_span.attributes = span_attributes

    # Minimal but valid completion payload; token counts are arbitrary.
    usage = completion_usage.CompletionUsage(
        completion_tokens=1,
        prompt_tokens=2,
        total_tokens=3,
    )
    message = chat_completion_message.ChatCompletionMessage(
        content='Test response',
        role='assistant',
    )
    response = chat_completion.ChatCompletion(
        id='test_id',
        choices=[chat_completion.Choice(finish_reason='stop', index=0, message=message)],
        created=1634720000,
        model='gpt-4',
        object='chat.completion',
        usage=usage,
    )

    on_response(response, mock_span)

    # Collect only the set_attribute invocations that targeted gen_ai.system.
    gen_ai_system_calls = [
        recorded for recorded in mock_span.set_attribute.call_args_list if recorded[0][0] == 'gen_ai.system'
    ]
    if not should_set_gen_ai_system:
        assert len(gen_ai_system_calls) == 0, (
            f"Expected no calls to set_attribute with 'gen_ai.system', got {gen_ai_system_calls}"
        )
    else:
        assert any(recorded[0] == ('gen_ai.system', 'openai') for recorded in gen_ai_system_calls), (
            f"Expected set_attribute('gen_ai.system', 'openai') to be called, got {gen_ai_system_calls}"
        )


@pytest.mark.parametrize(
    ('span_attributes', 'expected_provider_id'),
    [
        pytest.param({}, 'openai', id='no_system_uses_openai'),
        pytest.param(None, 'openai', id='none_attributes_uses_openai'),
        pytest.param({'gen_ai.system': 'openai'}, 'openai', id='openai_system_uses_openai'),
        pytest.param({'gen_ai.system': 'openrouter'}, 'openrouter', id='openrouter_system_uses_openrouter'),
        pytest.param({'gen_ai.system': 'azure'}, 'azure', id='azure_system_uses_azure'),
    ],
)
def test_on_response_calc_price_uses_correct_provider(
    span_attributes: dict[str, str] | None, expected_provider_id: str, monkeypatch: pytest.MonkeyPatch
) -> None:
    """on_response must pass the span's gen_ai.system (or 'openai') as provider_id to calc_price."""
    from unittest.mock import patch

    mock_span = MagicMock()
    mock_span.attributes = span_attributes

    # Minimal but valid completion payload; token counts are arbitrary.
    usage = completion_usage.CompletionUsage(
        completion_tokens=10,
        prompt_tokens=20,
        total_tokens=30,
    )
    message = chat_completion_message.ChatCompletionMessage(
        content='Test response',
        role='assistant',
    )
    response = chat_completion.ChatCompletion(
        id='test_id',
        choices=[chat_completion.Choice(finish_reason='stop', index=0, message=message)],
        created=1634720000,
        model='gpt-4',
        object='chat.completion',
        usage=usage,
    )

    # NOTE(review): this patches the attribute on the genai_prices module itself; confirm
    # the openai integration calls it via `genai_prices.calc_price` (not a from-import),
    # otherwise the patch target should be the integration module instead.
    with patch('genai_prices.calc_price') as mock_calc_price:
        # calc_price must return something with a numeric total_price.
        mock_price_result = MagicMock()
        mock_price_result.total_price = 0.001
        mock_calc_price.return_value = mock_price_result

        on_response(response, mock_span)

        assert mock_calc_price.called, 'calc_price should have been called'
        actual_provider_id = mock_calc_price.call_args.kwargs.get('provider_id')
        assert actual_provider_id == expected_provider_id, (
            f"Expected calc_price to be called with provider_id='{expected_provider_id}', "
            f"but got provider_id='{actual_provider_id}'"
        )


def test_override_provider_with_client_none(exporter: TestExporter) -> None:
    """override_provider must also apply when instrumenting the client classes (client=None)."""
    with httpx.Client(transport=MockTransport(request_handler)) as transport:
        # client=None instruments the (openai.OpenAI, openai.AsyncOpenAI) tuple of classes.
        with logfire.instrument_openai(openai_client=None, override_provider='openrouter'):
            # A client constructed after class-level instrumentation picks it up.
            instrumented_client = openai.Client(api_key='foobar', http_client=transport)

            completion = instrumented_client.chat.completions.create(
                model='gpt-4',
                messages=[
                    {'role': 'system', 'content': 'You are a helpful assistant.'},
                    {'role': 'user', 'content': 'What is four plus five?'},
                ],
            )

            assert completion.choices[0].message.content == 'Nine'
            exported = exporter.exported_spans_as_dict(parse_json_attributes=True)
            assert len(exported) == 1
            # The override must flow through the class-instrumentation path into both
            # the legacy and the semconv provider attributes.
            attributes = exported[0]['attributes']
            assert attributes['gen_ai.system'] == 'openrouter'
            assert attributes['gen_ai.provider.name'] == 'openrouter'
Loading
Loading