diff --git a/CLAUDE.md b/CLAUDE.md index dc5d683ed..25147b8a9 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -75,6 +75,10 @@ Some tests are decorated with `@pytest.mark.vcr()` and use `pytest-recording` to The `logfire-api` package is a no-op shim that libraries can depend on to avoid hard dependencies on Logfire itself. It provides minimal 'implementations' in `logfire-api/logfire_api/__init__.py`, which needs to be kept up to date with the public API of the `logfire` module, especially if `test_logfire_api.py` starts failing. The rest is just `.pyi` stubs which should be ignored and are autogenerated when needed during release. +# Integrations + +`instrument_*` methods in `main.py` should use proper type annotations for their client parameters, not `Any`. The third-party types are imported under `if TYPE_CHECKING:` at the top of `main.py` (aliased if needed to avoid name collisions). ImportError handling and client resolution (e.g., building the list of all client classes when `None` is passed) should live in the integration module under `logfire/_internal/integrations/`, not in `main.py`. See `instrument_openai`, `instrument_anthropic`, and `instrument_azure_ai_inference` for examples. + # Misc Use `git push origin HEAD` to push, not just `git push`, so that it pushes to the current branch without needing to set upstream explicitly. diff --git a/docs/integrations/llms/azure-ai-inference.md b/docs/integrations/llms/azure-ai-inference.md new file mode 100644 index 000000000..5279a0f14 --- /dev/null +++ b/docs/integrations/llms/azure-ai-inference.md @@ -0,0 +1,138 @@ +--- +title: Pydantic Logfire Azure AI Inference Integration +description: "Instrument calls to Azure AI Inference with logfire.instrument_azure_ai_inference(). Track chat completions, embeddings, streaming responses, and token usage." +integration: logfire +--- +# Azure AI Inference + +**Logfire** supports instrumenting calls to [Azure AI Inference](https://pypi.org/project/azure-ai-inference/) with the [`logfire.instrument_azure_ai_inference()`][logfire.Logfire.instrument_azure_ai_inference] method. + +```python hl_lines="11-12" skip-run="true" skip-reason="external-connection" +from azure.ai.inference import ChatCompletionsClient +from azure.core.credentials import AzureKeyCredential + +import logfire + +client = ChatCompletionsClient( + endpoint='https://my-endpoint.inference.ai.azure.com', + credential=AzureKeyCredential('my-api-key'), +) + +logfire.configure() +logfire.instrument_azure_ai_inference(client) + +response = client.complete( + model='gpt-4', + messages=[ + {'role': 'system', 'content': 'You are a helpful assistant.'}, + {'role': 'user', 'content': 'Please write me a limerick about Python logging.'}, + ], +) +print(response.choices[0].message.content) +``` + +With that you get: + +* a span around the call which records duration and captures any exceptions that might occur +* Human-readable display of the conversation with the agent +* details of the response, including the number of tokens used + +## Installation + +Install Logfire with the `azure-ai-inference` extra: + +{{ install_logfire(extras=['azure-ai-inference']) }} + +## Methods covered + +The following methods are covered: + +- [`ChatCompletionsClient.complete`](https://learn.microsoft.com/python/api/azure-ai-inference/azure.ai.inference.chatcompletionsclient) - with and without `stream=True` +- [`EmbeddingsClient.embed`](https://learn.microsoft.com/python/api/azure-ai-inference/azure.ai.inference.embeddingsclient) + +All methods are covered with both sync (`azure.ai.inference`) and async (`azure.ai.inference.aio`) clients. + +## Streaming Responses + +When instrumenting streaming responses, Logfire creates two spans - one around the initial request and one around the streamed response. + +```python skip-run="true" skip-reason="external-connection" +from azure.ai.inference import ChatCompletionsClient +from azure.core.credentials import AzureKeyCredential + +import logfire + +client = ChatCompletionsClient( + endpoint='https://my-endpoint.inference.ai.azure.com', + credential=AzureKeyCredential('my-api-key'), +) + +logfire.configure() +logfire.instrument_azure_ai_inference(client) + +response = client.complete( + model='gpt-4', + messages=[{'role': 'user', 'content': 'Write Python to show a tree of files.'}], + stream=True, +) +for chunk in response: + if chunk.choices: + delta = chunk.choices[0].delta + if delta and delta.content: + print(delta.content, end='', flush=True) +``` + +## Embeddings + +You can also instrument the `EmbeddingsClient`: + +```python skip-run="true" skip-reason="external-connection" +from azure.ai.inference import EmbeddingsClient +from azure.core.credentials import AzureKeyCredential + +import logfire + +client = EmbeddingsClient( + endpoint='https://my-endpoint.inference.ai.azure.com', + credential=AzureKeyCredential('my-api-key'), +) + +logfire.configure() +logfire.instrument_azure_ai_inference(client) + +response = client.embed( + model='text-embedding-ada-002', + input=['Hello world'], +) +print(len(response.data[0].embedding)) +``` + +## Async Support + +Async clients from `azure.ai.inference.aio` are fully supported: + +```python skip-run="true" skip-reason="external-connection" +from azure.ai.inference.aio import ChatCompletionsClient +from azure.core.credentials import AzureKeyCredential + +import logfire + +client = ChatCompletionsClient( + endpoint='https://my-endpoint.inference.ai.azure.com', + credential=AzureKeyCredential('my-api-key'), +) + +logfire.configure() +logfire.instrument_azure_ai_inference(client) +``` + +## Global Instrumentation + +If no client is passed, all `ChatCompletionsClient` and `EmbeddingsClient` classes (both sync and async) are instrumented: + +```python skip-run="true" skip-reason="external-connection" +import logfire + +logfire.configure() +logfire.instrument_azure_ai_inference() +``` diff --git a/logfire-api/logfire_api/__init__.py b/logfire-api/logfire_api/__init__.py index ab431b5ea..d83abca09 100644 --- a/logfire-api/logfire_api/__init__.py +++ b/logfire-api/logfire_api/__init__.py @@ -187,6 +187,9 @@ def instrument_print(self, *args, **kwargs) -> ContextManager[None]: def instrument_openai_agents(self, *args, **kwargs) -> None: ... + def instrument_azure_ai_inference(self, *args, **kwargs) -> ContextManager[None]: + return nullcontext() + def instrument_google_genai(self, *args, **kwargs) -> None: ... def instrument_litellm(self, *args, **kwargs) -> None: ... @@ -230,6 +233,7 @@ def shutdown(self, *args, **kwargs) -> None: ... instrument_openai = DEFAULT_LOGFIRE_INSTANCE.instrument_openai instrument_openai_agents = DEFAULT_LOGFIRE_INSTANCE.instrument_openai_agents instrument_anthropic = DEFAULT_LOGFIRE_INSTANCE.instrument_anthropic + instrument_azure_ai_inference = DEFAULT_LOGFIRE_INSTANCE.instrument_azure_ai_inference instrument_google_genai = DEFAULT_LOGFIRE_INSTANCE.instrument_google_genai instrument_litellm = DEFAULT_LOGFIRE_INSTANCE.instrument_litellm instrument_dspy = DEFAULT_LOGFIRE_INSTANCE.instrument_dspy diff --git a/logfire/__init__.py b/logfire/__init__.py index 8c0d01e02..2170e96f7 100644 --- a/logfire/__init__.py +++ b/logfire/__init__.py @@ -45,6 +45,7 @@ instrument_openai = DEFAULT_LOGFIRE_INSTANCE.instrument_openai instrument_openai_agents = DEFAULT_LOGFIRE_INSTANCE.instrument_openai_agents instrument_anthropic = DEFAULT_LOGFIRE_INSTANCE.instrument_anthropic +instrument_azure_ai_inference = DEFAULT_LOGFIRE_INSTANCE.instrument_azure_ai_inference instrument_google_genai = DEFAULT_LOGFIRE_INSTANCE.instrument_google_genai instrument_litellm = DEFAULT_LOGFIRE_INSTANCE.instrument_litellm instrument_dspy = DEFAULT_LOGFIRE_INSTANCE.instrument_dspy @@ -152,6 +153,7 @@ def loguru_handler() -> Any: 'instrument_openai', 'instrument_openai_agents', 'instrument_anthropic', + 'instrument_azure_ai_inference', 'instrument_google_genai', 'instrument_litellm', 'instrument_dspy', diff --git a/logfire/_internal/integrations/llm_providers/azure_ai_inference.py b/logfire/_internal/integrations/llm_providers/azure_ai_inference.py new file mode 100644 index 000000000..57954ddb8 --- /dev/null +++ b/logfire/_internal/integrations/llm_providers/azure_ai_inference.py @@ -0,0 +1,695 @@ +# pyright: reportUnknownVariableType=false, reportUnknownMemberType=false, reportUnknownArgumentType=false +from __future__ import annotations + +import json +from collections.abc import AsyncIterator, Iterator +from contextlib import AbstractContextManager, ExitStack, contextmanager, nullcontext +from typing import TYPE_CHECKING, Any, cast + +from opentelemetry.trace import SpanKind + +from logfire import attach_context, get_context + +from ...constants import ONE_SECOND_IN_NANOSECONDS +from ...utils import handle_internal_errors, is_instrumentation_suppressed, log_internal_error, suppress_instrumentation +from .semconv import ( + INPUT_MESSAGES, + INPUT_TOKENS, + OPERATION_NAME, + OUTPUT_MESSAGES, + OUTPUT_TOKENS, + PROVIDER_NAME, + REQUEST_FREQUENCY_PENALTY, + REQUEST_MAX_TOKENS, + REQUEST_MODEL, + REQUEST_PRESENCE_PENALTY, + REQUEST_SEED, + REQUEST_STOP_SEQUENCES, + REQUEST_TEMPERATURE, + REQUEST_TOP_P, + RESPONSE_FINISH_REASONS, + RESPONSE_ID, + RESPONSE_MODEL, + SYSTEM_INSTRUCTIONS, + TOOL_DEFINITIONS, + BlobPart, + ChatMessage, + InputMessages, + MessagePart, + OutputMessage, + OutputMessages, + Role, + SystemInstructions, + TextPart, + ToolCallPart, + ToolCallResponsePart, + UriPart, +) + +if TYPE_CHECKING: + from ...main import Logfire, LogfireSpan + +__all__ = ('instrument_azure_ai_inference',) + +AZURE_PROVIDER = 'azure.ai.inference' + +CHAT_MSG_TEMPLATE = 'Chat completion with {request_data[model]!r}' +CHAT_MSG_TEMPLATE_NO_MODEL = 'Chat completion' +EMBED_MSG_TEMPLATE = 'Embeddings with {request_data[model]!r}' +EMBED_MSG_TEMPLATE_NO_MODEL = 'Embeddings' +STREAM_MSG_TEMPLATE = 'streaming response from {request_data[model]!r} took {duration:.2f}s' +STREAM_MSG_TEMPLATE_NO_MODEL = 'streaming response took {duration:.2f}s' + + +# --- Main instrumentation entry point --- + + +def instrument_azure_ai_inference( + logfire_instance: Logfire, + client: Any, + suppress_other_instrumentation: bool, +) -> AbstractContextManager[None]: + """Instrument Azure AI Inference clients.""" + if client is None: + try: + from azure.ai.inference import ChatCompletionsClient, EmbeddingsClient + except ImportError: # pragma: no cover + raise RuntimeError( + 'The `logfire.instrument_azure_ai_inference()` method ' + 'requires the `azure-ai-inference` package.\n' + 'You can install this with:\n' + " pip install 'logfire[azure-ai-inference]'" + ) + + clients: list[Any] = [ChatCompletionsClient, EmbeddingsClient] + try: + from azure.ai.inference.aio import ( + ChatCompletionsClient as AsyncChatCompletionsClient, + EmbeddingsClient as AsyncEmbeddingsClient, + ) + + clients.extend([AsyncChatCompletionsClient, AsyncEmbeddingsClient]) + except ImportError: # pragma: no cover + pass + client = clients + + if isinstance(client, (tuple, list)): + context_managers = [ + instrument_azure_ai_inference(logfire_instance, c, suppress_other_instrumentation) for c in client + ] + + @contextmanager + def uninstrument_all() -> Iterator[None]: + with ExitStack() as stack: + for cm in context_managers: + stack.enter_context(cm) + yield + + return uninstrument_all() + + if getattr(client, '_is_instrumented_by_logfire', False): + return nullcontext() + + client_cls = client if isinstance(client, type) else type(client) + is_async = _is_async_client(client_cls) + client_type = _get_client_type(client_cls) + + if client_type is None: # pragma: no cover + return nullcontext() + + logfire_llm = logfire_instance.with_settings(custom_scope_suffix='azure_ai_inference', tags=['LLM']) + client._is_instrumented_by_logfire = True + + if client_type == 'chat': + method_name = 'complete' + original = client.complete + client._original_logfire_method = original + client.complete = _make_instrumented_complete(original, logfire_llm, suppress_other_instrumentation, is_async) + else: + method_name = 'embed' + original = client.embed + client._original_logfire_method = original + client.embed = _make_instrumented_embed(original, logfire_llm, suppress_other_instrumentation, is_async) + + @contextmanager + def uninstrument() -> Iterator[None]: + try: + yield + finally: + setattr(client, method_name, client._original_logfire_method) + del client._original_logfire_method + client._is_instrumented_by_logfire = False + + return uninstrument() + + +# --- Client type detection --- + + +def _is_async_client(client_cls: type[Any]) -> bool: + return 'aio' in client_cls.__module__ + + +def _get_client_type(client_cls: type[Any]) -> str | None: + name = client_cls.__name__ + if 'ChatCompletions' in name: + return 'chat' + if 'Embeddings' in name: + return 'embeddings' + return None # pragma: no cover + + +# --- Instrumented method factories --- + + +def _make_instrumented_complete( + original: Any, + logfire_llm: Logfire, + suppress: bool, + is_async: bool, +) -> Any: + if is_async: + + async def instrumented_complete(*args: Any, **kwargs: Any) -> Any: + if is_instrumentation_suppressed(): # pragma: no cover + return await original(*args, **kwargs) + try: + span_data = _build_chat_span_data(args, kwargs) + except Exception: # pragma: no cover + log_internal_error() + return await original(*args, **kwargs) + + is_streaming = kwargs.get('stream', False) + original_context = get_context() + msg = CHAT_MSG_TEMPLATE if span_data['request_data']['model'] else CHAT_MSG_TEMPLATE_NO_MODEL + + with logfire_llm.span(msg, _span_kind=SpanKind.CLIENT, **span_data) as span: + if suppress: + with suppress_instrumentation(): + response = await original(*args, **kwargs) + else: + response = await original(*args, **kwargs) + + if is_streaming: + return _AsyncStreamWrapper(response, logfire_llm, span_data, original_context) + _on_chat_response(response, span, span_data) + return response + + return instrumented_complete + else: + + def instrumented_complete_sync(*args: Any, **kwargs: Any) -> Any: + if is_instrumentation_suppressed(): # pragma: no cover + return original(*args, **kwargs) + try: + span_data = _build_chat_span_data(args, kwargs) + except Exception: # pragma: no cover + log_internal_error() + return original(*args, **kwargs) + + is_streaming = kwargs.get('stream', False) + original_context = get_context() + msg = CHAT_MSG_TEMPLATE if span_data['request_data']['model'] else CHAT_MSG_TEMPLATE_NO_MODEL + + with logfire_llm.span(msg, _span_kind=SpanKind.CLIENT, **span_data) as span: + if suppress: + with suppress_instrumentation(): + response = original(*args, **kwargs) + else: + response = original(*args, **kwargs) + + if is_streaming: + return _SyncStreamWrapper(response, logfire_llm, span_data, original_context) + _on_chat_response(response, span, span_data) + return response + + return instrumented_complete_sync + + +def _make_instrumented_embed( + original: Any, + logfire_llm: Logfire, + suppress: bool, + is_async: bool, +) -> Any: + if is_async: + + async def instrumented_embed(*args: Any, **kwargs: Any) -> Any: + if is_instrumentation_suppressed(): # pragma: no cover + return await original(*args, **kwargs) + try: + span_data = _build_embed_span_data(args, kwargs) + except Exception: # pragma: no cover + log_internal_error() + return await original(*args, **kwargs) + + msg = EMBED_MSG_TEMPLATE if span_data['request_data']['model'] else EMBED_MSG_TEMPLATE_NO_MODEL + + with logfire_llm.span(msg, _span_kind=SpanKind.CLIENT, **span_data) as span: + if suppress: + with suppress_instrumentation(): + response = await original(*args, **kwargs) + else: + response = await original(*args, **kwargs) + _on_embed_response(response, span, span_data) + return response + + return instrumented_embed + else: + + def instrumented_embed_sync(*args: Any, **kwargs: Any) -> Any: + if is_instrumentation_suppressed(): # pragma: no cover + return original(*args, **kwargs) + try: + span_data = _build_embed_span_data(args, kwargs) + except Exception: # pragma: no cover + log_internal_error() + return original(*args, **kwargs) + + msg = EMBED_MSG_TEMPLATE if span_data['request_data']['model'] else EMBED_MSG_TEMPLATE_NO_MODEL + + with logfire_llm.span(msg, _span_kind=SpanKind.CLIENT, **span_data) as span: + if suppress: + with suppress_instrumentation(): + response = original(*args, **kwargs) + else: + response = original(*args, **kwargs) + _on_embed_response(response, span, span_data) + return response + + return instrumented_embed_sync + + +# --- Span data builders --- + + +def _build_chat_span_data( + args: tuple[Any, ...], + kwargs: dict[str, Any], +) -> dict[str, Any]: + params = _extract_params(args, kwargs) + messages = params.get('messages', []) + model = params.get('model') + + span_data: dict[str, Any] = { + 'request_data': {'model': model}, + PROVIDER_NAME: AZURE_PROVIDER, + OPERATION_NAME: 'chat', + } + if model: + span_data[REQUEST_MODEL] = model + + _extract_request_parameters(params, span_data) + + if messages: + input_messages, system_instructions = convert_messages_to_semconv(messages) + span_data[INPUT_MESSAGES] = input_messages + if system_instructions: + span_data[SYSTEM_INSTRUCTIONS] = system_instructions + + return span_data + + +def _build_embed_span_data( + args: tuple[Any, ...], + kwargs: dict[str, Any], +) -> dict[str, Any]: + params = _extract_params(args, kwargs) + model = params.get('model') + + span_data: dict[str, Any] = { + 'request_data': {'model': model}, + PROVIDER_NAME: AZURE_PROVIDER, + OPERATION_NAME: 'embeddings', + } + if model: + span_data[REQUEST_MODEL] = model + + return span_data + + +def _extract_params(args: tuple[Any, ...], kwargs: dict[str, Any]) -> dict[str, Any]: + """Extract parameters from method call, handling both body and keyword styles.""" + if 'body' in kwargs and isinstance(kwargs['body'], dict): + return kwargs['body'] + for arg in args: + if isinstance(arg, dict) and ('messages' in arg or 'input' in arg): + return arg + return kwargs + + +def _extract_request_parameters(params: dict[str, Any], span_data: dict[str, Any]) -> None: + if (max_tokens := params.get('max_tokens')) is not None: + span_data[REQUEST_MAX_TOKENS] = max_tokens + if (temperature := params.get('temperature')) is not None: + span_data[REQUEST_TEMPERATURE] = temperature + if (top_p := params.get('top_p')) is not None: + span_data[REQUEST_TOP_P] = top_p + if (frequency_penalty := params.get('frequency_penalty')) is not None: + span_data[REQUEST_FREQUENCY_PENALTY] = frequency_penalty + if (presence_penalty := params.get('presence_penalty')) is not None: + span_data[REQUEST_PRESENCE_PENALTY] = presence_penalty + if (seed := params.get('seed')) is not None: + span_data[REQUEST_SEED] = seed + if (stop := params.get('stop')) is not None: + span_data[REQUEST_STOP_SEQUENCES] = json.dumps(stop) + if (tools := params.get('tools')) is not None: + span_data[TOOL_DEFINITIONS] = json.dumps([t if isinstance(t, dict) else t.as_dict() for t in tools]) + + +# --- Response processors --- + + +def _backfill_model(response: Any, span: LogfireSpan, span_data: dict[str, Any], operation: str = 'chat') -> None: + """If the request model was None, backfill it from the response model.""" + model = getattr(response, 'model', None) + if not model: + return + request_data = span_data.get('request_data') + if not isinstance(request_data, dict) or request_data.get('model') is not None: + return + request_data['model'] = model + span.set_attribute('request_data', request_data) + span.set_attribute(REQUEST_MODEL, model) + if operation == 'chat': + span.message = f'Chat completion with {model!r}' + else: + span.message = f'Embeddings with {model!r}' + + +@handle_internal_errors +def _on_chat_response(response: Any, span: LogfireSpan, span_data: dict[str, Any]) -> None: + _backfill_model(response, span, span_data) + choices = getattr(response, 'choices', []) + usage = getattr(response, 'usage', None) + + output_messages = convert_response_to_semconv(response) + if output_messages: + span.set_attribute(OUTPUT_MESSAGES, output_messages) + + model = getattr(response, 'model', None) + if model: + span.set_attribute(RESPONSE_MODEL, model) + + response_id = getattr(response, 'id', None) + if response_id: + span.set_attribute(RESPONSE_ID, response_id) + + if usage: + prompt_tokens = getattr(usage, 'prompt_tokens', None) + if prompt_tokens is not None: + span.set_attribute(INPUT_TOKENS, prompt_tokens) + completion_tokens = getattr(usage, 'completion_tokens', None) + if completion_tokens is not None: + span.set_attribute(OUTPUT_TOKENS, completion_tokens) + + finish_reasons = [str(c.finish_reason) for c in choices if getattr(c, 'finish_reason', None)] + if finish_reasons: + span.set_attribute(RESPONSE_FINISH_REASONS, finish_reasons) + + +@handle_internal_errors +def _on_embed_response(response: Any, span: LogfireSpan, span_data: dict[str, Any]) -> None: + _backfill_model(response, span, span_data, operation='embeddings') + usage = getattr(response, 'usage', None) + + model = getattr(response, 'model', None) + if model: + span.set_attribute(RESPONSE_MODEL, model) + + response_id = getattr(response, 'id', None) + if response_id: + span.set_attribute(RESPONSE_ID, response_id) + + if usage: + prompt_tokens = getattr(usage, 'prompt_tokens', None) + if prompt_tokens is not None: + span.set_attribute(INPUT_TOKENS, prompt_tokens) + + +# --- Message conversion --- + + +def convert_messages_to_semconv(messages: list[Any]) -> tuple[InputMessages, SystemInstructions]: + """Convert Azure AI Inference messages to OTel GenAI semconv format.""" + input_messages: InputMessages = [] + system_instructions: SystemInstructions = [] + + for msg in messages: + msg_dict = _msg_to_dict(msg) + role: str = msg_dict.get('role', 'user') + content = msg_dict.get('content') + + if role in ('system', 'developer'): + if isinstance(content, str): + system_instructions.append(TextPart(type='text', content=content)) + continue + + if role == 'tool': + tool_call_id = msg_dict.get('tool_call_id', '') + input_messages.append( + ChatMessage( + role='tool', + parts=[ + ToolCallResponsePart( + type='tool_call_response', + id=tool_call_id, + response=content if isinstance(content, str) else str(content) if content else '', + ) + ], + ) + ) + continue + + parts: list[MessagePart] = [] + if isinstance(content, str) and content: + parts.append(TextPart(type='text', content=content)) + elif isinstance(content, list): + for item in content: + parts.append(_convert_content_item(item)) + + tool_calls = msg_dict.get('tool_calls') + if tool_calls: + for tc in tool_calls: + tc_dict = tc if isinstance(tc, dict) else (tc.as_dict() if hasattr(tc, 'as_dict') else {}) + func = tc_dict.get('function', {}) + parts.append( + ToolCallPart( + type='tool_call', + id=tc_dict.get('id', ''), + name=func.get('name', ''), + arguments=func.get('arguments'), + ) + ) + + chat_role: Role = cast('Role', role if role in ('user', 'assistant') else 'user') + input_messages.append(ChatMessage(role=chat_role, parts=parts)) + + return input_messages, system_instructions + + +def _msg_to_dict(msg: Any) -> dict[str, Any]: + """Convert an Azure message object or dict to a plain dict.""" + if isinstance(msg, dict): + return msg + if hasattr(msg, 'as_dict'): + return msg.as_dict() + return {} # pragma: no cover + + +def _convert_content_item(item: Any) -> MessagePart: + """Convert a content item (text, image, audio) to semconv format.""" + if isinstance(item, str): + return TextPart(type='text', content=item) + + item_dict = item if isinstance(item, dict) else (item.as_dict() if hasattr(item, 'as_dict') else {}) + item_type = item_dict.get('type', 'text') + + if item_type == 'text': + return TextPart(type='text', content=item_dict.get('text', '')) + elif item_type == 'image_url': + image_url = item_dict.get('image_url', {}) + return UriPart(type='uri', uri=image_url.get('url', ''), modality='image') + elif item_type == 'input_audio': + audio = item_dict.get('input_audio', {}) + return BlobPart( + type='blob', + content=audio.get('data', ''), + media_type=f'audio/{audio.get("format", "wav")}', + modality='audio', + ) + else: # pragma: no cover + return cast('MessagePart', item_dict) + + +def convert_response_to_semconv(response: Any) -> OutputMessages: + """Convert a ChatCompletions response to OTel GenAI semconv format.""" + output_messages: OutputMessages = [] + + for choice in getattr(response, 'choices', []): + message = getattr(choice, 'message', None) + if not message: + continue + + parts: list[MessagePart] = [] + content = getattr(message, 'content', None) + if content: + parts.append(TextPart(type='text', content=content)) + + tool_calls = getattr(message, 'tool_calls', None) + if tool_calls: + for tc in tool_calls: + func = getattr(tc, 'function', None) + if func: + parts.append( + ToolCallPart( + type='tool_call', + id=getattr(tc, 'id', ''), + name=getattr(func, 'name', ''), + arguments=getattr(func, 'arguments', None), + ) + ) + + output_msg: OutputMessage = { + 'role': cast('Role', getattr(message, 'role', 'assistant')), + 'parts': parts, + } + finish_reason = getattr(choice, 'finish_reason', None) + if finish_reason: + output_msg['finish_reason'] = str(finish_reason) + output_messages.append(output_msg) + + return output_messages + + +# --- Streaming wrappers --- + + +class _SyncStreamWrapper: + """Wraps a sync streaming response to record chunks and emit a streaming info span.""" + + def __init__( + self, + wrapped: Any, + logfire_llm: Logfire, + span_data: dict[str, Any], + original_context: Any, + ) -> None: + self._wrapped = wrapped + self._logfire_llm = logfire_llm + self._span_data = span_data + self._original_context = original_context + self._chunks: list[str] = [] + + def __enter__(self) -> _SyncStreamWrapper: + if hasattr(self._wrapped, '__enter__'): + self._wrapped.__enter__() + return self + + def __exit__(self, *args: Any) -> None: + if hasattr(self._wrapped, '__exit__'): + self._wrapped.__exit__(*args) + + def __iter__(self) -> Iterator[Any]: + timer = self._logfire_llm._config.advanced.ns_timestamp_generator # type: ignore + start = timer() + try: + for chunk in self._wrapped: + self._record_chunk(chunk) + yield chunk + finally: + duration = (timer() - start) / ONE_SECOND_IN_NANOSECONDS + has_model = self._span_data.get('request_data', {}).get('model') is not None + msg = STREAM_MSG_TEMPLATE if has_model else STREAM_MSG_TEMPLATE_NO_MODEL + with attach_context(self._original_context): + self._logfire_llm.info(msg, duration=duration, **self._get_stream_attributes()) + + def _record_chunk(self, chunk: Any) -> None: + if self._span_data.get('request_data', {}).get('model') is None: + model = getattr(chunk, 'model', None) + if model: + self._span_data['request_data']['model'] = model + self._span_data[REQUEST_MODEL] = model + for choice in getattr(chunk, 'choices', []): + delta = getattr(choice, 'delta', None) + if delta: + content = getattr(delta, 'content', None) + if content: + self._chunks.append(content) + + def _get_stream_attributes(self) -> dict[str, Any]: + result = dict(**self._span_data) + combined = ''.join(self._chunks) + if self._chunks: + result[OUTPUT_MESSAGES] = [ + OutputMessage( + role='assistant', + parts=[TextPart(type='text', content=combined)], + ) + ] + return result + + +class _AsyncStreamWrapper: + """Wraps an async streaming response to record chunks and emit a streaming info span.""" + + def __init__( + self, + wrapped: Any, + logfire_llm: Logfire, + span_data: dict[str, Any], + original_context: Any, + ) -> None: + self._wrapped = wrapped + self._logfire_llm = logfire_llm + self._span_data = span_data + self._original_context = original_context + self._chunks: list[str] = [] + + async def __aenter__(self) -> _AsyncStreamWrapper: + if hasattr(self._wrapped, '__aenter__'): + await self._wrapped.__aenter__() + return self + + async def __aexit__(self, *args: Any) -> None: + if hasattr(self._wrapped, '__aexit__'): + await self._wrapped.__aexit__(*args) + + async def __aiter__(self) -> AsyncIterator[Any]: + timer = self._logfire_llm._config.advanced.ns_timestamp_generator # type: ignore + start = timer() + try: + async for chunk in self._wrapped: + self._record_chunk(chunk) + yield chunk + finally: + duration = (timer() - start) / ONE_SECOND_IN_NANOSECONDS + has_model = self._span_data.get('request_data', {}).get('model') is not None + msg = STREAM_MSG_TEMPLATE if has_model else STREAM_MSG_TEMPLATE_NO_MODEL + with attach_context(self._original_context): + self._logfire_llm.info(msg, duration=duration, **self._get_stream_attributes()) + + def _record_chunk(self, chunk: Any) -> None: + if self._span_data.get('request_data', {}).get('model') is None: + model = getattr(chunk, 'model', None) + if model: + self._span_data['request_data']['model'] = model + self._span_data[REQUEST_MODEL] = model + for choice in getattr(chunk, 'choices', []): + delta = getattr(choice, 'delta', None) + if delta: + content = getattr(delta, 'content', None) + if content: + self._chunks.append(content) + + def _get_stream_attributes(self) -> dict[str, Any]: + result = dict(**self._span_data) + combined = ''.join(self._chunks) + if self._chunks: + result[OUTPUT_MESSAGES] = [ + OutputMessage( + role='assistant', + parts=[TextPart(type='text', content=combined)], + ) + ] + return result diff --git a/logfire/_internal/main.py b/logfire/_internal/main.py index 472910163..f3b9e9429 100644 --- a/logfire/_internal/main.py +++ b/logfire/_internal/main.py @@ -79,6 +79,14 @@ import openai import pydantic_ai.models import requests + from azure.ai.inference import ( + ChatCompletionsClient as AzureChatCompletionsClient, + EmbeddingsClient as AzureEmbeddingsClient, + ) + from azure.ai.inference.aio import ( + ChatCompletionsClient as AsyncAzureChatCompletionsClient, + EmbeddingsClient as AsyncAzureEmbeddingsClient, + ) from django.http import HttpRequest, HttpResponse from fastapi import FastAPI from flask.app import Flask @@ -1389,6 +1397,76 @@ def instrument_anthropic( is_async_client, ) + def instrument_azure_ai_inference( + self, + azure_ai_inference_client: ( + AzureChatCompletionsClient + | AzureEmbeddingsClient + | AsyncAzureChatCompletionsClient + | AsyncAzureEmbeddingsClient + | type[AzureChatCompletionsClient] + | type[AzureEmbeddingsClient] + | type[AsyncAzureChatCompletionsClient] + | type[AsyncAzureEmbeddingsClient] + | None + ) = None, + *, + suppress_other_instrumentation: bool = True, + ) -> AbstractContextManager[None]: + """Instrument an Azure AI Inference client so that spans are automatically created for each request. + + Supports both the sync and async clients from the + [`azure-ai-inference`](https://pypi.org/project/azure-ai-inference/) package: + + - [`ChatCompletionsClient.complete`](https://learn.microsoft.com/python/api/azure-ai-inference/azure.ai.inference.chatcompletionsclient) - with and without `stream=True` + - [`EmbeddingsClient.embed`](https://learn.microsoft.com/python/api/azure-ai-inference/azure.ai.inference.embeddingsclient) + + Example usage: + + ```python skip-run="true" skip-reason="external-connection" + from azure.ai.inference import ChatCompletionsClient + from azure.core.credentials import AzureKeyCredential + + import logfire + + client = ChatCompletionsClient( + endpoint='https://my-endpoint.inference.ai.azure.com', + credential=AzureKeyCredential('my-api-key'), + ) + + logfire.configure() + logfire.instrument_azure_ai_inference(client) + + response = client.complete( + model='gpt-4', + messages=[{'role': 'user', 'content': 'What is four plus five?'}], + ) + print(response.choices[0].message.content) + ``` + + Args: + azure_ai_inference_client: The Azure AI Inference client or class to instrument: + + - `None` (the default) to instrument all Azure AI Inference client classes. + - A `ChatCompletionsClient` or `EmbeddingsClient` class or instance (sync or async). + + suppress_other_instrumentation: If True, suppress any other OTEL instrumentation that may be otherwise + enabled. In reality, this means the Azure Core tracing instrumentation, which could otherwise be + called since the Azure SDK uses its own pipeline to make HTTP requests. + + Returns: + A context manager that will revert the instrumentation when exited. + Use of this context manager is optional. + """ + from .integrations.llm_providers.azure_ai_inference import instrument_azure_ai_inference + + self._warn_if_not_initialized_for_instrumentation() + return instrument_azure_ai_inference( + self, + azure_ai_inference_client, + suppress_other_instrumentation, + ) + def instrument_google_genai(self, **kwargs: Any): """Instrument the [Google Gen AI SDK (`google-genai`)](https://googleapis.github.io/python-genai/). diff --git a/mkdocs.yml b/mkdocs.yml index 67b619d78..f79d650c9 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -114,6 +114,7 @@ nav: - OpenAI: integrations/llms/openai.md - Google Gen AI: integrations/llms/google-genai.md - Anthropic: integrations/llms/anthropic.md + - Azure AI Inference: integrations/llms/azure-ai-inference.md - LangChain: integrations/llms/langchain.md - LiteLLM: integrations/llms/litellm.md - DSPy: integrations/llms/dspy.md diff --git a/pyproject.toml b/pyproject.toml index e4e1a35cb..333ab4d9a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -80,6 +80,7 @@ requests = ["opentelemetry-instrumentation-requests >= 0.42b0"] mysql = ["opentelemetry-instrumentation-mysql >= 0.42b0"] sqlite3 = ["opentelemetry-instrumentation-sqlite3 >= 0.42b0"] aws-lambda = ["opentelemetry-instrumentation-aws-lambda >= 0.42b0"] +azure-ai-inference = ["azure-ai-inference >= 1.0.0b1"] google-genai = ["opentelemetry-instrumentation-google-genai >= 0.4b0"] litellm = ["openinference-instrumentation-litellm >= 0"] dspy = ["openinference-instrumentation-dspy >= 0"] @@ -159,6 +160,7 @@ dev = [ "cryptography >= 44.0.0", "cloudpickle>=3.0.0", "anthropic>=0.27.0", + "azure-ai-inference>=1.0.0b1", "sqlmodel>=0.0.15", "mypy>=1.10.0", "celery>=5.4.0", diff --git a/tests/otel_integrations/test_azure_ai_inference.py b/tests/otel_integrations/test_azure_ai_inference.py new file mode 100644 index 000000000..43bdd0c93 --- /dev/null +++ b/tests/otel_integrations/test_azure_ai_inference.py @@ -0,0 +1,1197 @@ +# pyright: reportCallIssue=false, reportArgumentType=false, reportPrivateUsage=false +from __future__ import annotations as _annotations + +from collections.abc import AsyncIterator, Iterator +from datetime import datetime +from typing import Any + +import pytest +from azure.ai.inference.models import ( + ChatChoice, + ChatCompletions, + ChatResponseMessage, + CompletionsUsage, + EmbeddingItem, + EmbeddingsResult, + EmbeddingsUsage, + StreamingChatChoiceUpdate, + StreamingChatCompletionsUpdate, + StreamingChatResponseMessageUpdate, +) +from inline_snapshot import snapshot + +import logfire +from logfire.testing import TestExporter + + +def _make_chat_response( + content: str = 'Nine', + finish_reason: str = 'stop', + tool_calls: list[Any] | None = None, +) -> ChatCompletions: + message_kwargs: dict[str, Any] = {'role': 'assistant', 'content': content} + if tool_calls is not None: + message_kwargs['tool_calls'] = tool_calls + return ChatCompletions( + id='test-id', + model='gpt-4', + created=datetime(2024, 1, 1), + choices=[ + ChatChoice( + index=0, + finish_reason=finish_reason, + message=ChatResponseMessage(**message_kwargs), + ) + ], + usage=CompletionsUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15), + ) + + +def _make_tool_response() -> ChatCompletions: + return _make_chat_response( + content='', + finish_reason='tool_calls', + tool_calls=[ + { + 'id': 'call_1', + 'type': 'function', + 'function': {'name': 'get_weather', 'arguments': '{"city": "London"}'}, + } + ], + ) + + +def _make_streaming_chunks() -> list[StreamingChatCompletionsUpdate]: + return [ + StreamingChatCompletionsUpdate( + id='test-id', + model='gpt-4', + created=datetime(2024, 1, 1), + choices=[ + StreamingChatChoiceUpdate( + index=0, delta=StreamingChatResponseMessageUpdate(role='assistant', content='') + ) + ], + ), + StreamingChatCompletionsUpdate( + id='test-id', + model='gpt-4', + created=datetime(2024, 1, 1), + choices=[ + StreamingChatChoiceUpdate(index=0, delta=StreamingChatResponseMessageUpdate(content='The answer')) + ], + ), + StreamingChatCompletionsUpdate( + id='test-id', + model='gpt-4', + created=datetime(2024, 1, 1), + choices=[ + StreamingChatChoiceUpdate(index=0, delta=StreamingChatResponseMessageUpdate(content=' is secret')) + ], + ), + StreamingChatCompletionsUpdate( + id='test-id', + model='gpt-4', + created=datetime(2024, 1, 1), + choices=[ + StreamingChatChoiceUpdate( + index=0, finish_reason='stop', delta=StreamingChatResponseMessageUpdate(content='') + ) + ], + ), + ] + + +def _make_embed_response() -> EmbeddingsResult: + return EmbeddingsResult( + id='test-id', + model='text-embedding-ada-002', + data=[EmbeddingItem(embedding=[0.1, 0.2, 0.3], index=0)], + usage=EmbeddingsUsage(prompt_tokens=5, total_tokens=5), + ) + + +class MockChatCompletionsClient: + """Mock ChatCompletionsClient that returns preconfigured responses.""" + + __module__ = 'azure.ai.inference' + + def __init__(self, response: Any = None, stream_chunks: list[Any] | None = None) -> None: + self._response = response or _make_chat_response() + self._stream_chunks = stream_chunks + + def complete(self, *args: Any, **kwargs: Any) -> Any: + if kwargs.get('stream'): + return iter(self._stream_chunks or _make_streaming_chunks()) + return self._response + + +class MockAsyncChatCompletionsClient: + """Mock async ChatCompletionsClient.""" + + __module__ = 'azure.ai.inference.aio' + + def __init__(self, response: Any = None, stream_chunks: list[Any] | None = None) -> None: + self._response = response or _make_chat_response() + self._stream_chunks = stream_chunks + + async def complete(self, *args: Any, **kwargs: Any) -> Any: + if kwargs.get('stream'): + return _async_iter(self._stream_chunks or _make_streaming_chunks()) + return self._response + + +class MockEmbeddingsClient: + """Mock EmbeddingsClient.""" + + __module__ = 'azure.ai.inference' + + def __init__(self, response: Any = None) -> None: + self._response = response or _make_embed_response() + + def embed(self, **kwargs: Any) -> Any: + return self._response + + +class MockAsyncEmbeddingsClient: + """Mock async EmbeddingsClient.""" + + __module__ = 'azure.ai.inference.aio' + + def __init__(self, response: Any = None) -> None: + self._response = response or _make_embed_response() + + async def embed(self, **kwargs: Any) -> Any: + return self._response + + +async def _async_iter(items: list[Any]) -> Any: + for item in items: + yield item + + +def test_sync_chat(exporter: TestExporter) -> None: + client = MockChatCompletionsClient() + with logfire.instrument_azure_ai_inference(client): + response = client.complete( + model='gpt-4', + messages=[ + {'role': 'system', 'content': 'You are helpful.'}, + {'role': 'user', 'content': 'What is four plus five?'}, + ], + temperature=0.5, + ) + assert response.choices[0].message.content == 'Nine' + assert exporter.exported_spans_as_dict(parse_json_attributes=True) == snapshot( + [ + { + 'name': 'Chat completion with {request_data[model]!r}', + 'context': {'trace_id': 1, 'span_id': 1, 'is_remote': False}, + 'parent': None, + 'start_time': 1000000000, + 'end_time': 2000000000, + 'attributes': { + 'code.filepath': 'test_azure_ai_inference.py', + 'code.function': 'test_sync_chat', + 'code.lineno': 123, + 'request_data': {'model': 'gpt-4'}, + 'gen_ai.provider.name': 'azure.ai.inference', + 'gen_ai.operation.name': 'chat', + 'gen_ai.request.model': 'gpt-4', + 'gen_ai.request.temperature': 0.5, + 'gen_ai.input.messages': [ + {'role': 'user', 'parts': [{'type': 'text', 'content': 'What is four plus five?'}]} + ], + 'gen_ai.system_instructions': [{'type': 'text', 'content': 'You are helpful.'}], + 'logfire.msg_template': 'Chat completion with {request_data[model]!r}', + 'logfire.msg': "Chat completion with 'gpt-4'", + 'logfire.tags': ('LLM',), + 'logfire.span_type': 'span', + 'gen_ai.output.messages': [ + { + 'role': 'assistant', + 'parts': [{'type': 'text', 'content': 'Nine'}], + 'finish_reason': 'CompletionsFinishReason.STOPPED', + } + ], + 'gen_ai.response.model': 'gpt-4', + 'gen_ai.response.id': 'test-id', + 'gen_ai.usage.input_tokens': 10, + 'gen_ai.usage.output_tokens': 5, + 'gen_ai.response.finish_reasons': ['CompletionsFinishReason.STOPPED'], + 'logfire.json_schema': { + 'type': 'object', + 'properties': { + 'request_data': {'type': 'object'}, + 'gen_ai.provider.name': {}, + 'gen_ai.operation.name': {}, + 'gen_ai.request.model': {}, + 'gen_ai.request.temperature': {}, + 'gen_ai.input.messages': {'type': 'array'}, + 'gen_ai.system_instructions': {'type': 'array'}, + 'gen_ai.output.messages': { + 'type': 'array', + 'items': { + 'type': 'object', + 'properties': { + 'role': { + 'type': 'string', + 'title': 'ChatRole', + 'x-python-datatype': 'Enum', + 'enum': ['system', 'user', 'assistant', 'tool', 'developer'], + } + }, + }, + }, + 'gen_ai.response.model': {}, + 'gen_ai.response.id': {}, + 'gen_ai.usage.input_tokens': {}, + 'gen_ai.usage.output_tokens': {}, + 'gen_ai.response.finish_reasons': {'type': 'array'}, + }, + }, + }, + } + ] + ) + + +def test_sync_chat_streaming(exporter: TestExporter) -> None: + client = MockChatCompletionsClient() + with logfire.instrument_azure_ai_inference(client): + response = client.complete( + model='gpt-4', + messages=[{'role': 'user', 'content': 'Tell me a secret'}], + stream=True, + ) + chunks = list(response) + assert len(chunks) == 4 + assert exporter.exported_spans_as_dict(parse_json_attributes=True) == snapshot( + [ + { + 'name': 'Chat completion with {request_data[model]!r}', + 'context': {'trace_id': 1, 'span_id': 1, 'is_remote': False}, + 'parent': None, + 'start_time': 1000000000, + 'end_time': 2000000000, + 'attributes': { + 'code.filepath': 'test_azure_ai_inference.py', + 'code.function': 'test_sync_chat_streaming', + 'code.lineno': 123, + 'request_data': {'model': 'gpt-4'}, + 'gen_ai.provider.name': 'azure.ai.inference', + 'gen_ai.operation.name': 'chat', + 'gen_ai.request.model': 'gpt-4', + 'gen_ai.input.messages': [ + {'role': 'user', 'parts': [{'type': 'text', 'content': 'Tell me a secret'}]} + ], + 'logfire.msg_template': 'Chat completion with {request_data[model]!r}', + 'logfire.msg': "Chat completion with 'gpt-4'", + 'logfire.json_schema': { + 'type': 'object', + 'properties': { + 'request_data': {'type': 'object'}, + 'gen_ai.provider.name': {}, + 'gen_ai.operation.name': {}, + 'gen_ai.request.model': {}, + 'gen_ai.input.messages': {'type': 'array'}, + }, + }, + 'logfire.tags': ('LLM',), + 'logfire.span_type': 'span', + 'gen_ai.response.model': 'gpt-4', + }, + }, + { + 'name': 'streaming response from {request_data[model]!r} took {duration:.2f}s', + 'context': {'trace_id': 2, 'span_id': 3, 'is_remote': False}, + 'parent': None, + 'start_time': 5000000000, + 'end_time': 5000000000, + 'attributes': { + 'logfire.span_type': 'log', + 'logfire.level_num': 9, + 'logfire.msg_template': 'streaming response from {request_data[model]!r} took {duration:.2f}s', + 'logfire.msg': "streaming response from 'gpt-4' took 1.00s", + 'code.filepath': 'test_azure_ai_inference.py', + 'code.function': 'test_sync_chat_streaming', + 'code.lineno': 123, + 'duration': 1.0, + 'request_data': {'model': 'gpt-4'}, + 'gen_ai.provider.name': 'azure.ai.inference', + 'gen_ai.operation.name': 'chat', + 'gen_ai.request.model': 'gpt-4', + 'gen_ai.input.messages': [ + {'role': 'user', 'parts': [{'type': 'text', 'content': 'Tell me a secret'}]} + ], + 'gen_ai.output.messages': [ + {'role': 'assistant', 'parts': [{'type': 'text', 'content': 'The answer is secret'}]} + ], + 'logfire.json_schema': { + 'type': 'object', + 'properties': { + 'duration': {}, + 'request_data': {'type': 'object'}, + 'gen_ai.provider.name': {}, + 'gen_ai.operation.name': {}, + 'gen_ai.request.model': {}, + 'gen_ai.input.messages': {'type': 'array'}, + 'gen_ai.output.messages': {'type': 'array'}, + }, + }, + 'logfire.tags': ('LLM',), + 'gen_ai.response.model': 'gpt-4', + }, + }, + ] + ) + + +def test_sync_chat_tool_calls(exporter: TestExporter) -> None: + client = MockChatCompletionsClient(response=_make_tool_response()) + with logfire.instrument_azure_ai_inference(client): + client.complete( + model='gpt-4', + messages=[{'role': 'user', 'content': 'What is the weather?'}], + ) + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + assert len(spans) == 1 + attrs = spans[0]['attributes'] + # Check tool calls in semconv output + output_msgs = attrs['gen_ai.output.messages'] + assert len(output_msgs) == 1 + tool_part = output_msgs[0]['parts'][0] + assert tool_part['type'] == 'tool_call' + assert tool_part['name'] == 'get_weather' + + +@pytest.mark.anyio +async def test_async_chat(exporter: TestExporter) -> None: + client = MockAsyncChatCompletionsClient() + with logfire.instrument_azure_ai_inference(client): + response = await client.complete( + model='gpt-4', + messages=[{'role': 'user', 'content': 'What is four plus five?'}], + ) + assert response.choices[0].message.content == 'Nine' + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + assert len(spans) == 1 + assert spans[0]['attributes']['gen_ai.response.model'] == 'gpt-4' + assert spans[0]['attributes']['gen_ai.usage.input_tokens'] == 10 + + +@pytest.mark.anyio +async def test_async_chat_streaming(exporter: TestExporter) -> None: + client = MockAsyncChatCompletionsClient() + with logfire.instrument_azure_ai_inference(client): + response = await client.complete( + model='gpt-4', + messages=[{'role': 'user', 'content': 'Tell me a secret'}], + stream=True, + ) + chunks = [chunk async for chunk in response] + assert len(chunks) == 4 + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + assert len(spans) == 2 + # First span: the request + assert spans[0]['attributes']['logfire.msg'] == "Chat completion with 'gpt-4'" + # Second span: streaming info + assert 'streaming response from' in spans[1]['attributes']['logfire.msg'] + + +def test_sync_embeddings(exporter: TestExporter) -> None: + client = MockEmbeddingsClient() + with logfire.instrument_azure_ai_inference(client): + response = client.embed( + model='text-embedding-ada-002', + input=['Hello world'], + ) + assert len(response.data) == 1 + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + assert len(spans) == 1 + attrs = spans[0]['attributes'] + assert attrs['gen_ai.provider.name'] == 'azure.ai.inference' + assert attrs['gen_ai.operation.name'] == 'embeddings' + assert attrs['gen_ai.request.model'] == 'text-embedding-ada-002' + assert attrs['gen_ai.response.model'] == 'text-embedding-ada-002' + assert attrs['gen_ai.usage.input_tokens'] == 5 + + +@pytest.mark.anyio +async def test_async_embeddings(exporter: TestExporter) -> None: + client = MockAsyncEmbeddingsClient() + with logfire.instrument_azure_ai_inference(client): + response = await client.embed( + model='text-embedding-ada-002', + input=['Hello world'], + ) + assert len(response.data) == 1 + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + assert len(spans) == 1 + assert spans[0]['attributes']['gen_ai.operation.name'] == 'embeddings' + + +def test_uninstrumentation(exporter: TestExporter) -> None: + client = MockChatCompletionsClient() + with logfire.instrument_azure_ai_inference(client): + client.complete(model='gpt-4', messages=[{'role': 'user', 'content': 'Hi'}]) + assert len(exporter.exported_spans_as_dict()) == 1 + + # After exiting context, client should be uninstrumented + exporter.clear() + client.complete(model='gpt-4', messages=[{'role': 'user', 'content': 'Hi'}]) + assert len(exporter.exported_spans_as_dict()) == 0 + + +def test_double_instrumentation(exporter: TestExporter) -> None: + client = MockChatCompletionsClient() + with logfire.instrument_azure_ai_inference(client): + with logfire.instrument_azure_ai_inference(client): + client.complete(model='gpt-4', messages=[{'role': 'user', 'content': 'Hi'}]) + # Should only produce one span (not double-instrumented) + assert len(exporter.exported_spans_as_dict()) == 1 + + +def test_no_model_backfill(exporter: TestExporter) -> None: + """When request has no model, backfill from response.""" + client = MockChatCompletionsClient() + with logfire.instrument_azure_ai_inference(client): + client.complete( + messages=[{'role': 'user', 'content': 'Hi'}], + ) + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + assert len(spans) == 1 + attrs = spans[0]['attributes'] + # Model backfilled from response + assert attrs['logfire.msg'] == "Chat completion with 'gpt-4'" + assert attrs['gen_ai.request.model'] == 'gpt-4' + assert attrs['gen_ai.response.model'] == 'gpt-4' + + +def test_no_model_streaming_backfill(exporter: TestExporter) -> None: + """When streaming request has no model, backfill from first chunk.""" + client = MockChatCompletionsClient() + with logfire.instrument_azure_ai_inference(client): + response = client.complete( + messages=[{'role': 'user', 'content': 'Hi'}], + stream=True, + ) + list(response) + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + assert len(spans) == 2 + # Streaming info span should have model from chunks + assert spans[1]['attributes']['request_data']['model'] == 'gpt-4' + assert spans[1]['attributes']['gen_ai.request.model'] == 'gpt-4' + + +@pytest.mark.anyio +async def test_no_model_async_streaming_backfill(exporter: TestExporter) -> None: + """When async streaming request has no model, backfill from first chunk.""" + client = MockAsyncChatCompletionsClient() + with logfire.instrument_azure_ai_inference(client): + response = await client.complete( + messages=[{'role': 'user', 'content': 'Hi'}], + stream=True, + ) + async for _ in response: + pass + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + assert len(spans) == 2 + assert spans[1]['attributes']['request_data']['model'] == 'gpt-4' + + +def test_no_model_embed_backfill(exporter: TestExporter) -> None: + """When embed request has no model, backfill from response.""" + client = MockEmbeddingsClient() + with logfire.instrument_azure_ai_inference(client): + client.embed(input=['Hello']) + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + assert len(spans) == 1 + attrs = spans[0]['attributes'] + assert attrs['logfire.msg'] == "Embeddings with 'text-embedding-ada-002'" + assert attrs['gen_ai.request.model'] == 'text-embedding-ada-002' + + +@pytest.mark.anyio +async def test_no_model_async_embed_backfill(exporter: TestExporter) -> None: + """When async embed request has no model, backfill from response.""" + client = MockAsyncEmbeddingsClient() + with logfire.instrument_azure_ai_inference(client): + await client.embed(input=['Hello']) + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + assert len(spans) == 1 + assert attrs_msg(spans[0]) == "Embeddings with 'text-embedding-ada-002'" + + +def attrs_msg(span: dict[str, Any]) -> str: + return span['attributes']['logfire.msg'] + + +def test_suppress_false(exporter: TestExporter) -> None: + """Test with suppress_other_instrumentation=False.""" + client = MockChatCompletionsClient() + with logfire.instrument_azure_ai_inference(client, suppress_other_instrumentation=False): + client.complete(model='gpt-4', messages=[{'role': 'user', 'content': 'Hi'}]) + assert len(exporter.exported_spans_as_dict()) == 1 + + +@pytest.mark.anyio +async def test_suppress_false_async(exporter: TestExporter) -> None: + """Test with suppress_other_instrumentation=False for async.""" + client = MockAsyncChatCompletionsClient() + with logfire.instrument_azure_ai_inference(client, suppress_other_instrumentation=False): + await client.complete(model='gpt-4', messages=[{'role': 'user', 'content': 'Hi'}]) + assert len(exporter.exported_spans_as_dict()) == 1 + + +def test_suppress_false_embed(exporter: TestExporter) -> None: + """Test embed with suppress=False.""" + client = MockEmbeddingsClient() + with logfire.instrument_azure_ai_inference(client, suppress_other_instrumentation=False): + client.embed(model='text-embedding-ada-002', input=['Hi']) + assert len(exporter.exported_spans_as_dict()) == 1 + + +@pytest.mark.anyio +async def test_suppress_false_async_embed(exporter: TestExporter) -> None: + """Test async embed with suppress=False.""" + client = MockAsyncEmbeddingsClient() + with logfire.instrument_azure_ai_inference(client, suppress_other_instrumentation=False): + await client.embed(model='text-embedding-ada-002', input=['Hi']) + assert len(exporter.exported_spans_as_dict()) == 1 + + +def test_list_instrumentation(exporter: TestExporter) -> None: + """Test instrumenting a list of clients.""" + chat_client = MockChatCompletionsClient() + embed_client = MockEmbeddingsClient() + with logfire.instrument_azure_ai_inference([chat_client, embed_client]): + chat_client.complete(model='gpt-4', messages=[{'role': 'user', 'content': 'Hi'}]) + embed_client.embed(model='text-embedding-ada-002', input=['Hello']) + assert len(exporter.exported_spans_as_dict()) == 2 + + # After exiting, both should be uninstrumented + exporter.clear() + chat_client.complete(model='gpt-4', messages=[{'role': 'user', 'content': 'Hi'}]) + embed_client.embed(model='text-embedding-ada-002', input=['Hello']) + assert len(exporter.exported_spans_as_dict()) == 0 + + +def test_request_parameters(exporter: TestExporter) -> None: + """Test that all request parameters are captured.""" + client = MockChatCompletionsClient() + with logfire.instrument_azure_ai_inference(client): + client.complete( + model='gpt-4', + messages=[{'role': 'user', 'content': 'Hi'}], + temperature=0.7, + max_tokens=100, + top_p=0.9, + frequency_penalty=0.5, + presence_penalty=0.3, + seed=42, + stop=['\n', 'END'], + ) + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + attrs = spans[0]['attributes'] + assert attrs['gen_ai.request.temperature'] == 0.7 + assert attrs['gen_ai.request.max_tokens'] == 100 + assert attrs['gen_ai.request.top_p'] == 0.9 + assert attrs['gen_ai.request.frequency_penalty'] == 0.5 + assert attrs['gen_ai.request.presence_penalty'] == 0.3 + assert attrs['gen_ai.request.seed'] == 42 + assert attrs['gen_ai.request.stop_sequences'] == ['\n', 'END'] + + +def test_extract_params_body_style(exporter: TestExporter) -> None: + """Test that body-style parameters are extracted.""" + client = MockChatCompletionsClient() + with logfire.instrument_azure_ai_inference(client): + client.complete(body={'model': 'gpt-4', 'messages': [{'role': 'user', 'content': 'Hi'}]}) + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + assert spans[0]['attributes']['gen_ai.request.model'] == 'gpt-4' + + +def test_content_item_conversion() -> None: + """Test conversion of multimodal content items.""" + from logfire._internal.integrations.llm_providers.azure_ai_inference import convert_messages_to_semconv + + messages = [ + { + 'role': 'user', + 'content': [ + 'plain string item', + {'type': 'text', 'text': 'text item'}, + {'type': 'image_url', 'image_url': {'url': 'https://example.com/img.png'}}, + {'type': 'input_audio', 'input_audio': {'data': 'base64data', 'format': 'mp3'}}, + ], + }, + ] + input_msgs, _ = convert_messages_to_semconv(messages) + parts = input_msgs[0]['parts'] + assert parts[0] == {'type': 'text', 'content': 'plain string item'} + assert parts[1] == {'type': 'text', 'content': 'text item'} + assert parts[2] == {'type': 'uri', 'uri': 'https://example.com/img.png', 'modality': 'image'} + assert parts[3] == {'type': 'blob', 'content': 'base64data', 'media_type': 'audio/mp3', 'modality': 'audio'} + + +def test_stream_context_manager(exporter: TestExporter) -> None: + """Test that sync stream wrapper supports context manager protocol.""" + client = MockChatCompletionsClient() + with logfire.instrument_azure_ai_inference(client): + response = client.complete( + model='gpt-4', + messages=[{'role': 'user', 'content': 'Hi'}], + stream=True, + ) + # Use as context manager + with response: + for _ in response: + pass + assert len(exporter.exported_spans_as_dict()) == 2 + + +@pytest.mark.anyio +async def test_async_stream_context_manager(exporter: TestExporter) -> None: + """Test that async stream wrapper supports async context manager protocol.""" + client = MockAsyncChatCompletionsClient() + with logfire.instrument_azure_ai_inference(client): + response = await client.complete( + model='gpt-4', + messages=[{'role': 'user', 'content': 'Hi'}], + stream=True, + ) + async with response: + async for _ in response: + pass + assert len(exporter.exported_spans_as_dict()) == 2 + + +def test_positional_arg_extraction() -> None: + """Test that _extract_params handles positional args correctly.""" + from logfire._internal.integrations.llm_providers.azure_ai_inference import _extract_params + + # Single positional arg with 'messages' key + result = _extract_params(({'model': 'gpt-4', 'messages': [{'role': 'user', 'content': 'Hi'}]},), {}) + assert result['model'] == 'gpt-4' + + # Multiple args, first doesn't match, second does (covers loop iteration) + result = _extract_params(('not-a-dict', {'messages': [{'role': 'user', 'content': 'Hi'}]}), {}) + assert 'messages' in result + + # No matching arg, falls back to kwargs + result = _extract_params(('not-a-dict',), {'model': 'gpt-4'}) + assert result == {'model': 'gpt-4'} + + +def test_tools_with_as_dict(exporter: TestExporter) -> None: + """Test that tool objects with as_dict() are handled.""" + + class MockTool: + def as_dict(self) -> dict[str, Any]: + return {'type': 'function', 'function': {'name': 'my_tool', 'parameters': {}}} + + client = MockChatCompletionsClient() + with logfire.instrument_azure_ai_inference(client): + client.complete( + model='gpt-4', + messages=[{'role': 'user', 'content': 'Hi'}], + tools=[MockTool()], + ) + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + tool_defs = spans[0]['attributes']['gen_ai.tool.definitions'] + assert tool_defs == [{'type': 'function', 'function': {'name': 'my_tool', 'parameters': {}}}] + + +def test_backfill_no_model_in_response(exporter: TestExporter) -> None: + """Test backfill when response also has no model.""" + response = ChatCompletions( + id='test-id', + model=None, + created=datetime(2024, 1, 1), + choices=[ + ChatChoice( + index=0, + finish_reason='stop', + message=ChatResponseMessage(role='assistant', content='Hello'), + ) + ], + usage=CompletionsUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15), + ) + client = MockChatCompletionsClient(response=response) + with logfire.instrument_azure_ai_inference(client): + client.complete(messages=[{'role': 'user', 'content': 'Hi'}]) + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + # Model stays None - no backfill + assert spans[0]['attributes']['logfire.msg'] == 'Chat completion' + + +def test_minimal_chat_response(exporter: TestExporter) -> None: + """Test response with no model, no id, no usage, no finish_reason. + + Exercises the false branches in _on_chat_response. + """ + response = ChatCompletions( + id=None, + model=None, + created=datetime(2024, 1, 1), + choices=[ + ChatChoice( + index=0, + finish_reason=None, + message=ChatResponseMessage(role='assistant', content='Hi'), + ) + ], + usage=None, + ) + client = MockChatCompletionsClient(response=response) + with logfire.instrument_azure_ai_inference(client): + client.complete(model='gpt-4', messages=[{'role': 'user', 'content': 'Hi'}]) + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + assert len(spans) == 1 + + +def test_choice_without_message() -> None: + """Test response with a choice that has no message. + + Exercises the `if not message: continue` path in convert_response_to_semconv. + """ + from logfire._internal.integrations.llm_providers.azure_ai_inference import convert_response_to_semconv + + class FakeChoice: + message = None + finish_reason = 'stop' + + class FakeResponse: + choices = [FakeChoice()] + + output = convert_response_to_semconv(FakeResponse()) + assert output == [] + + +def test_minimal_embed_response(exporter: TestExporter) -> None: + """Test embed response with no model, no id, no usage. + + Exercises the false branches in _on_embed_response. + """ + response = EmbeddingsResult( + id=None, + model=None, + data=[EmbeddingItem(embedding=[0.1], index=0)], + usage=None, + ) + client = MockEmbeddingsClient(response=response) + with logfire.instrument_azure_ai_inference(client): + client.embed(model='text-embedding-ada-002', input=['Hi']) + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + assert len(spans) == 1 + + +def test_response_empty_choices(exporter: TestExporter) -> None: + """Test response with empty choices list. + + Exercises the false branch of `if output_messages:` in _on_chat_response. + """ + + class FakeResponse: + id = 'test-id' + model = 'gpt-4' + choices: list[Any] = [] + usage = None + + client = MockChatCompletionsClient(response=FakeResponse()) + with logfire.instrument_azure_ai_inference(client): + client.complete(model='gpt-4', messages=[{'role': 'user', 'content': 'Hi'}]) + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + assert len(spans) == 1 + + +def test_usage_with_none_tokens(exporter: TestExporter) -> None: + """Test response with usage but None prompt/completion tokens. + + Exercises false branches of prompt_tokens/completion_tokens checks. + """ + + class FakeUsage: + prompt_tokens = None + completion_tokens = None + + class FakeMessage: + role = 'assistant' + content = 'Hi' + tool_calls = None + + class FakeChoice: + index = 0 + finish_reason = None + message = FakeMessage() + + class FakeResponse: + id = None + model = None + choices = [FakeChoice()] + usage = FakeUsage() + + client = MockChatCompletionsClient(response=FakeResponse()) + with logfire.instrument_azure_ai_inference(client): + client.complete(model='gpt-4', messages=[{'role': 'user', 'content': 'Hi'}]) + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + assert len(spans) == 1 + + +def test_embed_usage_with_none_tokens(exporter: TestExporter) -> None: + """Test embed response with usage but None prompt_tokens. + + Exercises false branch of prompt_tokens check in _on_embed_response. + """ + + class FakeUsage: + prompt_tokens = None + + class FakeResponse: + id = None + model = None + data = [] + usage = FakeUsage() + + client = MockEmbeddingsClient(response=FakeResponse()) + with logfire.instrument_azure_ai_inference(client): + client.embed(model='text-embedding-ada-002', input=['Hi']) + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + assert len(spans) == 1 + + +def test_no_messages_in_request(exporter: TestExporter) -> None: + """Test chat completion with no messages.""" + client = MockChatCompletionsClient() + with logfire.instrument_azure_ai_inference(client): + client.complete(model='gpt-4', messages=[]) + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + assert 'gen_ai.input.messages' not in spans[0]['attributes'] + + +def test_system_message_non_string_content() -> None: + """Test system message with non-string content. + + Exercises the false branch of isinstance(content, str) for system messages. + """ + from logfire._internal.integrations.llm_providers.azure_ai_inference import convert_messages_to_semconv + + messages = [ + {'role': 'system', 'content': None}, + {'role': 'user', 'content': 'Hi'}, + ] + input_msgs, system_instructions = convert_messages_to_semconv(messages) + assert system_instructions == [] + assert len(input_msgs) == 1 + + +def test_response_no_output_messages() -> None: + """Test convert_response_to_semconv with empty choices. + + Exercises the false branch of `if output_messages:` in _on_chat_response. + """ + from logfire._internal.integrations.llm_providers.azure_ai_inference import convert_response_to_semconv + + class EmptyResponse: + choices = [] + + output = convert_response_to_semconv(EmptyResponse()) + assert output == [] + + +def test_response_tool_call_no_function() -> None: + """Test response tool call without function attribute. + + Exercises the false branch of `if func:` in convert_response_to_semconv. + """ + from logfire._internal.integrations.llm_providers.azure_ai_inference import convert_response_to_semconv + + class FakeToolCall: + id = 'tc1' + function = None + + class FakeMessage: + role = 'assistant' + content = None + tool_calls = [FakeToolCall()] + + class FakeChoice: + message = FakeMessage() + finish_reason = None + + class FakeResponse: + choices = [FakeChoice()] + + output = convert_response_to_semconv(FakeResponse()) + assert len(output) == 1 + assert output[0]['parts'] == [] + + +def test_stream_wrapped_with_context_manager(exporter: TestExporter) -> None: + """Test sync stream where wrapped object has __enter__/__exit__.""" + + class ContextManagerIterator: + def __init__(self, items: list[Any]) -> None: + self.items = items + self.entered = False + self.exited = False + + def __enter__(self) -> ContextManagerIterator: + self.entered = True + return self + + def __exit__(self, *args: Any) -> None: + self.exited = True + + def __iter__(self) -> Iterator[Any]: + return iter(self.items) + + chunks = _make_streaming_chunks() + wrapped = ContextManagerIterator(chunks) + + class MockChatCompletionsClientWithCM: + __module__ = 'azure.ai.inference' + + def complete(self, **kwargs: Any) -> Any: + if kwargs.get('stream'): + return wrapped + return _make_chat_response() + + client = MockChatCompletionsClientWithCM() + with logfire.instrument_azure_ai_inference(client): + response = client.complete( + model='gpt-4', + messages=[{'role': 'user', 'content': 'Hi'}], + stream=True, + ) + with response: + for _ in response: + pass + assert wrapped.entered + assert wrapped.exited + assert len(exporter.exported_spans_as_dict()) == 2 + + +@pytest.mark.anyio +async def test_async_stream_wrapped_with_context_manager(exporter: TestExporter) -> None: + """Test async stream where wrapped object has __aenter__/__aexit__.""" + + class AsyncContextManagerIterator: + def __init__(self, items: list[Any]) -> None: + self.items = items + self.entered = False + self.exited = False + + async def __aenter__(self) -> AsyncContextManagerIterator: + self.entered = True + return self + + async def __aexit__(self, *args: Any) -> None: + self.exited = True + + async def __aiter__(self) -> AsyncIterator[Any]: + for item in self.items: + yield item + + chunks = _make_streaming_chunks() + wrapped = AsyncContextManagerIterator(chunks) + + class MockAsyncChatCompletionsClientWithCM: + __module__ = 'azure.ai.inference.aio' + + async def complete(self, **kwargs: Any) -> Any: + if kwargs.get('stream'): + return wrapped + return _make_chat_response() + + client = MockAsyncChatCompletionsClientWithCM() + with logfire.instrument_azure_ai_inference(client): + response = await client.complete( + model='gpt-4', + messages=[{'role': 'user', 'content': 'Hi'}], + stream=True, + ) + async with response: + async for _ in response: + pass + assert wrapped.entered + assert wrapped.exited + assert len(exporter.exported_spans_as_dict()) == 2 + + +def test_streaming_empty_chunks(exporter: TestExporter) -> None: + """Test streaming with chunks that have no choices or no content.""" + empty_chunks = [ + StreamingChatCompletionsUpdate( + id='test-id', + model='gpt-4', + created=datetime(2024, 1, 1), + choices=[], + ), + StreamingChatCompletionsUpdate( + id='test-id', + model='gpt-4', + created=datetime(2024, 1, 1), + choices=[StreamingChatChoiceUpdate(index=0, delta=None)], + ), + StreamingChatCompletionsUpdate( + id='test-id', + model='gpt-4', + created=datetime(2024, 1, 1), + choices=[StreamingChatChoiceUpdate(index=0, delta=StreamingChatResponseMessageUpdate(content=None))], + ), + ] + client = MockChatCompletionsClient(stream_chunks=empty_chunks) + with logfire.instrument_azure_ai_inference(client): + response = client.complete( + model='gpt-4', + messages=[{'role': 'user', 'content': 'Hi'}], + stream=True, + ) + list(response) + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + # Streaming info span should NOT have output messages (no actual content) + assert 'gen_ai.output.messages' not in spans[1]['attributes'] + + +@pytest.mark.anyio +async def test_async_streaming_empty_chunks(exporter: TestExporter) -> None: + """Test async streaming with chunks that have no choices or no content.""" + empty_chunks = [ + StreamingChatCompletionsUpdate( + id='test-id', + model='gpt-4', + created=datetime(2024, 1, 1), + choices=[], + ), + StreamingChatCompletionsUpdate( + id='test-id', + model='gpt-4', + created=datetime(2024, 1, 1), + choices=[StreamingChatChoiceUpdate(index=0, delta=None)], + ), + ] + client = MockAsyncChatCompletionsClient(stream_chunks=empty_chunks) + with logfire.instrument_azure_ai_inference(client): + response = await client.complete( + messages=[{'role': 'user', 'content': 'Hi'}], + stream=True, + ) + async for _ in response: + pass + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + assert 'gen_ai.output.messages' not in spans[1]['attributes'] + + +def test_streaming_no_model_chunks(exporter: TestExporter) -> None: + """Test streaming where both request and chunk have no model. + + Exercises false branch of `if model:` in _record_chunk (sync). + """ + no_model_chunks = [ + StreamingChatCompletionsUpdate( + id='test-id', + model=None, + created=datetime(2024, 1, 1), + choices=[StreamingChatChoiceUpdate(index=0, delta=StreamingChatResponseMessageUpdate(content='Hi'))], + ), + ] + client = MockChatCompletionsClient(stream_chunks=no_model_chunks) + with logfire.instrument_azure_ai_inference(client): + response = client.complete( + messages=[{'role': 'user', 'content': 'Hi'}], + stream=True, + ) + list(response) + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + assert len(spans) == 2 + + +@pytest.mark.anyio +async def test_async_streaming_no_model_chunks(exporter: TestExporter) -> None: + """Test async streaming where both request and chunk have no model. + + Exercises false branch of `if model:` in _record_chunk (async). + """ + no_model_chunks = [ + StreamingChatCompletionsUpdate( + id='test-id', + model=None, + created=datetime(2024, 1, 1), + choices=[StreamingChatChoiceUpdate(index=0, delta=StreamingChatResponseMessageUpdate(content='Hi'))], + ), + ] + client = MockAsyncChatCompletionsClient(stream_chunks=no_model_chunks) + with logfire.instrument_azure_ai_inference(client): + response = await client.complete( + messages=[{'role': 'user', 'content': 'Hi'}], + stream=True, + ) + async for _ in response: + pass + spans = exporter.exported_spans_as_dict(parse_json_attributes=True) + assert len(spans) == 2 + + +def test_message_conversion_with_typed_objects() -> None: + """Test that Azure SDK typed message objects are converted correctly.""" + from azure.ai.inference.models import SystemMessage, UserMessage + + from logfire._internal.integrations.llm_providers.azure_ai_inference import convert_messages_to_semconv + + messages = [ + SystemMessage(content='You are helpful.'), + UserMessage(content='Hello'), + ] + input_msgs, system_instructions = convert_messages_to_semconv(messages) + assert system_instructions == [{'type': 'text', 'content': 'You are helpful.'}] + assert input_msgs == [{'role': 'user', 'parts': [{'type': 'text', 'content': 'Hello'}]}] + + +def test_message_conversion_with_tool_messages() -> None: + """Test that tool messages are converted correctly.""" + from logfire._internal.integrations.llm_providers.azure_ai_inference import convert_messages_to_semconv + + messages = [ + {'role': 'user', 'content': 'What is the weather?'}, + { + 'role': 'assistant', + 'content': '', + 'tool_calls': [ + { + 'id': 'call_1', + 'type': 'function', + 'function': {'name': 'get_weather', 'arguments': '{"city": "London"}'}, + }, + ], + }, + {'role': 'tool', 'content': '72F', 'tool_call_id': 'call_1'}, + ] + input_msgs, system_instructions = convert_messages_to_semconv(messages) + assert len(input_msgs) == 3 + assert system_instructions == [] + # User message + assert input_msgs[0] == {'role': 'user', 'parts': [{'type': 'text', 'content': 'What is the weather?'}]} + # Assistant with tool call + assert input_msgs[1]['role'] == 'assistant' + tool_part: Any = input_msgs[1]['parts'][0] + assert tool_part['type'] == 'tool_call' + assert tool_part['name'] == 'get_weather' + # Tool response + assert input_msgs[2]['role'] == 'tool' + tool_resp: Any = input_msgs[2]['parts'][0] + assert tool_resp['type'] == 'tool_call_response' + assert tool_resp['id'] == 'call_1' + assert tool_resp['response'] == '72F' + + +def test_global_instrumentation() -> None: + """Test passing None instruments all client classes via the integration module.""" + from logfire._internal.integrations.llm_providers.azure_ai_inference import instrument_azure_ai_inference + + # Call with client=None so the integration module resolves client classes itself + cm = instrument_azure_ai_inference(logfire.DEFAULT_LOGFIRE_INSTANCE, None, True) + # Just verify it returns a context manager without error + with cm: + pass diff --git a/uv.lock b/uv.lock index fcc647ece..102a21162 100644 --- a/uv.lock +++ b/uv.lock @@ -425,6 +425,33 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3a/2a/7cc015f5b9f5db42b7d48157e23356022889fc354a2813c15934b7cb5c0e/attrs-25.4.0-py3-none-any.whl", hash = "sha256:adcf7e2a1fb3b36ac48d97835bb6d8ade15b8dcce26aba8bf1d14847b57a3373", size = 67615, upload-time = "2025-10-06T13:54:43.17Z" }, ] +[[package]] +name = "azure-ai-inference" +version = "1.0.0b9" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "azure-core" }, + { name = "isodate" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4e/6a/ed85592e5c64e08c291992f58b1a94dab6869f28fb0f40fd753dced73ba6/azure_ai_inference-1.0.0b9.tar.gz", hash = "sha256:1feb496bd84b01ee2691befc04358fa25d7c344d8288e99364438859ad7cd5a4", size = 182408, upload-time = "2025-02-15T00:37:28.464Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4f/0f/27520da74769db6e58327d96c98e7b9a07ce686dff582c9a5ec60b03f9dd/azure_ai_inference-1.0.0b9-py3-none-any.whl", hash = "sha256:49823732e674092dad83bb8b0d1b65aa73111fab924d61349eb2a8cdc0493990", size = 124885, upload-time = "2025-02-15T00:37:29.964Z" }, +] + +[[package]] +name = "azure-core" +version = "1.38.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "requests" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/00/fe/5c7710bc611a4070d06ba801de9a935cc87c3d4b689c644958047bdf2cba/azure_core-1.38.2.tar.gz", hash = "sha256:67562857cb979217e48dc60980243b61ea115b77326fa93d83b729e7ff0482e7", size = 363734, upload-time = "2026-02-18T19:33:05.6Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/42/23/6371a551800d3812d6019cd813acd985f9fac0fedc1290129211a73da4ae/azure_core-1.38.2-py3-none-any.whl", hash = "sha256:074806c75cf239ea284a33a66827695ef7aeddac0b4e19dda266a93e4665ead9", size = 217957, upload-time = "2026-02-18T19:33:07.696Z" }, +] + [[package]] name = "babel" version = "2.18.0" @@ -2442,6 +2469,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/97/25/0e84a6322e5fdb1bf67870b2269151449f4894987b26c78718918dd64ea6/inline_snapshot-0.32.0-py3-none-any.whl", hash = "sha256:b522ae2c891f666e80213c5f9677ec6fd4a2a7d334ab9d6ce745675bec6a40f0", size = 84087, upload-time = "2026-02-13T19:51:52.604Z" }, ] +[[package]] +name = "isodate" +version = "0.7.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/54/4d/e940025e2ce31a8ce1202635910747e5a87cc3a6a6bb2d00973375014749/isodate-0.7.2.tar.gz", hash = "sha256:4cd1aa0f43ca76f4a6c6c0292a85f40b35ec2e43e315b59f06e6d32171a953e6", size = 29705, upload-time = "2024-10-08T23:04:11.5Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/15/aa/0aca39a37d3c7eb941ba736ede56d689e7be91cab5d9ca846bde3999eba6/isodate-0.7.2-py3-none-any.whl", hash = "sha256:28009937d8031054830160fce6d409ed342816b543597cece116d966c6d99e15", size = 22320, upload-time = "2024-10-08T23:04:09.501Z" }, +] + [[package]] name = "itsdangerous" version = "2.2.0" @@ -3258,6 +3294,9 @@ asyncpg = [ aws-lambda = [ { name = "opentelemetry-instrumentation-aws-lambda" }, ] +azure-ai-inference = [ + { name = "azure-ai-inference" }, +] celery = [ { name = "opentelemetry-instrumentation-celery" }, ] @@ -3338,6 +3377,7 @@ dev = [ { name = "anthropic" }, { name = "asyncpg" }, { name = "attrs" }, + { name = "azure-ai-inference" }, { name = "boto3" }, { name = "botocore" }, { name = "celery" }, @@ -3463,6 +3503,7 @@ docs = [ [package.metadata] requires-dist = [ + { name = "azure-ai-inference", marker = "extra == 'azure-ai-inference'", specifier = ">=1.0.0b1" }, { name = "executing", specifier = ">=2.0.1" }, { name = "httpx", marker = "extra == 'datasets'", specifier = ">=0.27.2" }, { name = "openinference-instrumentation-dspy", marker = "extra == 'dspy'", specifier = ">=0" }, @@ -3504,7 +3545,7 @@ requires-dist = [ { name = "tomli", marker = "python_full_version < '3.11'", specifier = ">=2.0.1" }, { name = "typing-extensions", specifier = ">=4.1.0" }, ] -provides-extras = ["system-metrics", "asgi", "wsgi", "aiohttp", "aiohttp-client", "aiohttp-server", "celery", "django", "fastapi", "flask", "httpx", "starlette", "sqlalchemy", "asyncpg", "psycopg", "psycopg2", "pymongo", "redis", "requests", "mysql", "sqlite3", "aws-lambda", "google-genai", "litellm", "dspy", "datasets", "variables"] +provides-extras = ["system-metrics", "asgi", "wsgi", "aiohttp", "aiohttp-client", "aiohttp-server", "celery", "django", "fastapi", "flask", "httpx", "starlette", "sqlalchemy", "asyncpg", "psycopg", "psycopg2", "pymongo", "redis", "requests", "mysql", "sqlite3", "aws-lambda", "azure-ai-inference", "google-genai", "litellm", "dspy", "datasets", "variables"] [package.metadata.requires-dev] dev = [ @@ -3513,6 +3554,7 @@ dev = [ { name = "anthropic", specifier = ">=0.27.0" }, { name = "asyncpg", specifier = ">=0.30.0" }, { name = "attrs", specifier = ">=23.1.0" }, + { name = "azure-ai-inference", specifier = ">=1.0.0b1" }, { name = "boto3", specifier = ">=1.28.57" }, { name = "botocore", specifier = ">=1.31.57" }, { name = "celery", specifier = ">=5.4.0" },