-
Notifications
You must be signed in to change notification settings - Fork 212
Support overriding default model provider in instrument_openai #1609
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 4 commits
48e858a
5e431dc
06ce262
5a2b35d
85c711f
8df4eef
dd9702d
33c83d1
cd656b3
7f3d01f
3140257
0e079b2
1bbd297
e8cbbaa
f33ddf3
36bba2f
fb2d2c0
3cb7f36
d78496d
8b6085c
190fb88
a200de6
fc506ab
ca42b29
85186ae
28e7ee0
3f9255a
c702827
c0d1f2d
a2aa776
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -26,6 +26,7 @@ def instrument_llm_provider( | |
| get_endpoint_config_fn: Callable[[Any], EndpointConfig], | ||
| on_response_fn: Callable[[Any, LogfireSpan], Any], | ||
| is_async_client_fn: Callable[[type[Any]], bool], | ||
| model_provider: str | None = None, | ||
| ) -> AbstractContextManager[None]: | ||
| """Instruments the provided `client` (or clients) with `logfire`. | ||
|
|
||
|
|
@@ -93,6 +94,8 @@ def _instrumentation_setup(*args: Any, **kwargs: Any) -> Any: | |
| return None, None, kwargs | ||
|
|
||
| span_data['async'] = is_async | ||
| if model_provider is not None: | ||
| span_data['overridden_model_provider'] = model_provider | ||
|
||
|
|
||
| if kwargs.get('stream') and stream_state_cls: | ||
| stream_cls = kwargs['stream_cls'] | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -183,7 +183,9 @@ def on_response(response: ResponseT, span: LogfireSpan) -> ResponseT: | |
| on_response(response.parse(), span) # type: ignore | ||
| return cast('ResponseT', response) | ||
|
|
||
| span.set_attribute('gen_ai.system', 'openai') | ||
| model_provider: str = cast(str, (span.attributes or {}).get('overridden_model_provider', "openai")) | ||
|
||
|
|
||
| span.set_attribute('gen_ai.system', model_provider) | ||
|
|
||
| if isinstance(response_model := getattr(response, 'model', None), str): | ||
| span.set_attribute('gen_ai.response.model', response_model) | ||
|
|
@@ -194,12 +196,12 @@ def on_response(response: ResponseT, span: LogfireSpan) -> ResponseT: | |
| response_data = response.model_dump() # type: ignore | ||
| usage_data = extract_usage( | ||
| response_data, | ||
| provider_id='openai', | ||
| provider_id=model_provider, | ||
|
||
| api_flavor='responses' if isinstance(response, Response) else 'chat', | ||
| ) | ||
| span.set_attribute( | ||
| 'operation.cost', | ||
| float(calc_price(usage_data.usage, model_ref=response_model, provider_id='openai').total_price), | ||
| float(calc_price(usage_data.usage, model_ref=response_model, provider_id=model_provider).total_price), | ||
| ) | ||
| except Exception: | ||
| pass | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
just call it `provider`