diff --git a/logfire-api/logfire_api/__init__.py b/logfire-api/logfire_api/__init__.py index 7a59708c7..abef523bb 100644 --- a/logfire-api/logfire_api/__init__.py +++ b/logfire-api/logfire_api/__init__.py @@ -191,6 +191,9 @@ def instrument_google_genai(self, *args, **kwargs) -> None: ... def instrument_litellm(self, *args, **kwargs) -> None: ... + def instrument_langchain(self, *args, **kwargs) -> ContextManager[None]: + return nullcontext() + def instrument_aiohttp_client(self, *args, **kwargs) -> None: ... def instrument_aiohttp_server(self, *args, **kwargs) -> None: ... @@ -229,6 +232,7 @@ def shutdown(self, *args, **kwargs) -> None: ... instrument_anthropic = DEFAULT_LOGFIRE_INSTANCE.instrument_anthropic instrument_google_genai = DEFAULT_LOGFIRE_INSTANCE.instrument_google_genai instrument_litellm = DEFAULT_LOGFIRE_INSTANCE.instrument_litellm + instrument_langchain = DEFAULT_LOGFIRE_INSTANCE.instrument_langchain instrument_asyncpg = DEFAULT_LOGFIRE_INSTANCE.instrument_asyncpg instrument_print = DEFAULT_LOGFIRE_INSTANCE.instrument_print instrument_celery = DEFAULT_LOGFIRE_INSTANCE.instrument_celery diff --git a/logfire-api/logfire_api/__init__.pyi b/logfire-api/logfire_api/__init__.pyi index fe9eba578..328c0b4b6 100644 --- a/logfire-api/logfire_api/__init__.pyi +++ b/logfire-api/logfire_api/__init__.pyi @@ -33,6 +33,7 @@ instrument_openai_agents = DEFAULT_LOGFIRE_INSTANCE.instrument_openai_agents instrument_anthropic = DEFAULT_LOGFIRE_INSTANCE.instrument_anthropic instrument_google_genai = DEFAULT_LOGFIRE_INSTANCE.instrument_google_genai instrument_litellm = DEFAULT_LOGFIRE_INSTANCE.instrument_litellm +instrument_langchain = DEFAULT_LOGFIRE_INSTANCE.instrument_langchain instrument_print = DEFAULT_LOGFIRE_INSTANCE.instrument_print instrument_asyncpg = DEFAULT_LOGFIRE_INSTANCE.instrument_asyncpg instrument_httpx = DEFAULT_LOGFIRE_INSTANCE.instrument_httpx diff --git a/logfire-api/logfire_api/_internal/integrations/langchain.pyi b/logfire-api/logfire_api/_internal/integrations/langchain.pyi new file mode 100644 index 000000000..1a6eba9dc --- /dev/null +++ b/logfire-api/logfire_api/_internal/integrations/langchain.pyi @@ -0,0 +1,4 @@ +import logfire +from contextlib import AbstractContextManager + +def instrument_langchain(logfire_instance: logfire.Logfire) -> AbstractContextManager[None]: ... diff --git a/logfire-api/logfire_api/_internal/main.pyi b/logfire-api/logfire_api/_internal/main.pyi index 841618427..523c2bba6 100644 --- a/logfire-api/logfire_api/_internal/main.pyi +++ b/logfire-api/logfire_api/_internal/main.pyi @@ -606,6 +606,17 @@ class Logfire: [`openinference-instrumentation-litellm`](https://pypi.org/project/openinference-instrumentation-litellm/) package, to which it passes `**kwargs`. """ + def instrument_langchain(self) -> AbstractContextManager[None]: + """Instrument LangChain to capture full execution hierarchy with tool definitions. + + This patches LangChain's BaseCallbackManager to inject a callback handler + that captures the complete execution hierarchy including chains, tools, + retrievers, and LLM calls with tool definitions. + + Returns: + A context manager that will revert the instrumentation when exited. + Use of this context manager is optional. + """ def instrument_print(self) -> AbstractContextManager[None]: """Instrument the built-in `print` function so that calls to it are logged. diff --git a/logfire/__init__.py b/logfire/__init__.py index 9638f122f..70a159aa4 100644 --- a/logfire/__init__.py +++ b/logfire/__init__.py @@ -37,6 +37,7 @@ instrument_anthropic = DEFAULT_LOGFIRE_INSTANCE.instrument_anthropic instrument_google_genai = DEFAULT_LOGFIRE_INSTANCE.instrument_google_genai instrument_litellm = DEFAULT_LOGFIRE_INSTANCE.instrument_litellm +instrument_langchain = DEFAULT_LOGFIRE_INSTANCE.instrument_langchain instrument_print = DEFAULT_LOGFIRE_INSTANCE.instrument_print instrument_asyncpg = DEFAULT_LOGFIRE_INSTANCE.instrument_asyncpg instrument_httpx = DEFAULT_LOGFIRE_INSTANCE.instrument_httpx @@ -132,6 +133,7 @@ def loguru_handler() -> Any: 'instrument_anthropic', 'instrument_google_genai', 'instrument_litellm', + 'instrument_langchain', 'instrument_print', 'instrument_asyncpg', 'instrument_httpx', diff --git a/logfire/_internal/exporters/processor_wrapper.py b/logfire/_internal/exporters/processor_wrapper.py index 21da6cff1..d32c1af45 100644 --- a/logfire/_internal/exporters/processor_wrapper.py +++ b/logfire/_internal/exporters/processor_wrapper.py @@ -328,6 +328,64 @@ def _tweak_fastapi_span(span: ReadableSpanDict): span['events'] = new_events[::-1] +def _normalize_content_block(block: dict[str, Any]) -> dict[str, Any]: + """Normalize a content block to OTel GenAI schema. + + Handles: + - Text: converts 'text' field to 'content' (OTel uses 'content') + - tool_use: converts to 'tool_call' (OTel standard) + - tool_result: converts to 'tool_call_response' (OTel standard) + """ + block_type = block.get('type', 'text') + + if block_type == 'text': + return { + 'type': 'text', + 'content': block.get('content', block.get('text', '')), + } + + if block_type == 'tool_use': + return { + 'type': 'tool_call', + 'id': block.get('id'), + 'name': block.get('name'), + 'arguments': block.get('input', block.get('arguments')), + } + + if block_type == 'tool_result': + return { + 'type': 'tool_call_response', + 'id': block.get('tool_use_id', block.get('id')), + 'response': block.get('content', block.get('response')), + } + + return block + + +def _convert_to_otel_message(msg: dict[str, Any]) -> dict[str, Any]: + """Convert a message dict to OTel GenAI message schema with role and parts.""" + otel_msg: dict[str, Any] = {'role': msg.get('role', 'user'), 'parts': []} + content = msg.get('content') + if content: + if isinstance(content, str): + otel_msg['parts'].append({'type': 'text', 'content': content}) + elif isinstance(content, list): + for block in cast(list[Any], content): + if isinstance(block, dict): + otel_msg['parts'].append(_normalize_content_block(cast('dict[str, Any]', block))) + if tool_calls := msg.get('tool_calls'): + for tc in tool_calls: + otel_msg['parts'].append( + { + 'type': 'tool_call', + 'id': tc.get('id'), + 'name': tc.get('function', {}).get('name') or tc.get('name'), + 'arguments': tc.get('function', {}).get('arguments') or tc.get('args'), + } + ) + return otel_msg + + def _transform_langchain_span(span: ReadableSpanDict): """Transform spans generated by LangSmith to work better in the Logfire UI. @@ -387,6 +445,19 @@ def _transform_langchain_span(span: ReadableSpanDict): # Remove gen_ai.system=langchain as this also interferes with costs in the UI. attributes = {k: v for k, v in attributes.items() if k != 'gen_ai.system'} + # Extract finish reason from completion data + with suppress(Exception): + completion = parsed_attributes.get('gen_ai.completion', {}) + stop_reason = ( + completion.get('generations', [[{}]])[0][0] + .get('message', {}) + .get('kwargs', {}) + .get('response_metadata', {}) + .get('stop_reason') + ) + if stop_reason: + new_attributes['gen_ai.response.finish_reasons'] = json.dumps([stop_reason]) + # Add `all_messages_events` with suppress(Exception): input_messages = parsed_attributes.get('input.value', parsed_attributes.get('gen_ai.prompt', {}))['messages'] @@ -422,6 +493,33 @@ def _transform_langchain_span(span: ReadableSpanDict): new_attributes['all_messages_events'] = json.dumps(message_events) properties['all_messages_events'] = {'type': 'array'} + # Extract OTel GenAI formatted messages + input_msgs: list[dict[str, Any]] = [] + output_msgs: list[dict[str, Any]] = [] + system_instructions: list[Any] = [] + for msg in message_events: + role = msg.get('role') + if role == 'system': + content = msg.get('content', '') + if isinstance(content, str): + system_instructions.append({'type': 'text', 'content': content}) + elif isinstance(content, list): + system_instructions.extend(cast(list[Any], content)) + elif role == 'assistant': + output_msgs.append(_convert_to_otel_message(msg)) + else: + input_msgs.append(_convert_to_otel_message(msg)) + + if input_msgs: + new_attributes['gen_ai.input.messages'] = json.dumps(input_msgs) + properties['gen_ai.input.messages'] = {'type': 'array'} + if output_msgs: + new_attributes['gen_ai.output.messages'] = json.dumps(output_msgs) + properties['gen_ai.output.messages'] = {'type': 'array'} + if system_instructions: + new_attributes['gen_ai.system_instructions'] = json.dumps(system_instructions) + properties['gen_ai.system_instructions'] = {'type': 'array'} + span['attributes'] = { **attributes, ATTRIBUTES_JSON_SCHEMA_KEY: attributes_json_schema(properties), diff --git a/logfire/_internal/integrations/langchain.py b/logfire/_internal/integrations/langchain.py new file mode 100644 index 000000000..bbf3bb230 --- /dev/null +++ b/logfire/_internal/integrations/langchain.py @@ -0,0 +1,633 @@ +"""LangChain/LangGraph instrumentation for capturing tool definitions. + +This module provides callback-based instrumentation for LangChain that captures +tool definitions, which are not available through LangSmith's OTEL integration. +""" + +from __future__ import annotations + +import json +from contextlib import AbstractContextManager, contextmanager +from contextvars import Token +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, cast +from uuid import UUID + +from opentelemetry import context as context_api +from opentelemetry.context import Context +from opentelemetry.trace import SpanKind + +if TYPE_CHECKING: + from ..main import Logfire + +# GenAI semantic convention attribute names (inline to keep LangChain instrumentation self-contained) +OPERATION_NAME = 'gen_ai.operation.name' +REQUEST_MODEL = 'gen_ai.request.model' +RESPONSE_MODEL = 'gen_ai.response.model' +RESPONSE_FINISH_REASONS = 'gen_ai.response.finish_reasons' +INPUT_TOKENS = 'gen_ai.usage.input_tokens' +OUTPUT_TOKENS = 'gen_ai.usage.output_tokens' +INPUT_MESSAGES = 'gen_ai.input.messages' +OUTPUT_MESSAGES = 'gen_ai.output.messages' +SYSTEM_INSTRUCTIONS = 'gen_ai.system_instructions' +TOOL_DEFINITIONS = 'gen_ai.tool.definitions' +CONVERSATION_ID = 'gen_ai.conversation.id' + + +try: + from langchain_core.callbacks.base import BaseCallbackHandler + + _BASE_CLASS: type[Any] = BaseCallbackHandler +except ImportError: + _BASE_CLASS: type[Any] = object # pyright: ignore[reportConstantRedefinition] + + +@dataclass +class SpanWithToken: + """Container for span and its context token.""" + + span: Any + token: Token[Context] | None = None + + +def _detach_span_from_context(token: Token[Context]) -> None: + """Detach span from context using token.""" + try: + context_api.detach(token) + except ValueError: + pass + + +def _normalize_content_block(block: dict[str, Any]) -> dict[str, Any]: + """Normalize a content block to OTel GenAI schema. + + Handles: + - Text: converts 'text' field to 'content' (OTel uses 'content') + - tool_use: converts to 'tool_call' (OTel standard) + - tool_result: converts to 'tool_call_response' (OTel standard) + """ + block_type = block.get('type', 'text') + + if block_type == 'text': + return { + 'type': 'text', + 'content': block.get('content', block.get('text', '')), + } + + if block_type == 'tool_use': + return { + 'type': 'tool_call', + 'id': block.get('id'), + 'name': block.get('name'), + 'arguments': block.get('input', block.get('arguments')), + } + + if block_type == 'tool_result': + return { + 'type': 'tool_call_response', + 'id': block.get('tool_use_id', block.get('id')), + 'response': block.get('content', block.get('response')), + } + + return block + + +class LogfireLangchainCallbackHandler(_BASE_CLASS): + """LangChain callback handler that captures full execution hierarchy. + + This handler captures: + - Chain execution (on_chain_start/end) + - Tool execution (on_tool_start/end) + - Retriever execution (on_retriever_start/end) + - LLM calls with tool definitions (on_chat_model_start/on_llm_start) + + Uses parent_run_id for hierarchy instead of context propagation. + """ + + def __init__(self, logfire: Logfire): + super().__init__() # pyright: ignore[reportUnknownMemberType] + self.run_inline = True + self._logfire = logfire + self._run_span_mapping: dict[str, SpanWithToken] = {} + + def _get_span_by_run_id(self, run_id: UUID) -> Any | None: + """Get span from run_id mapping.""" + if st := self._run_span_mapping.get(str(run_id)): + return st.span + return None + + def _get_parent_span(self, parent_run_id: UUID | None) -> Any | None: + """Get parent span from parent_run_id mapping.""" + if parent_run_id: + if st := self._run_span_mapping.get(str(parent_run_id)): + return st.span + return None + + def _get_span_name(self, serialized: dict[str, Any], default: str = 'unknown') -> str: + """Extract span name from serialized dict.""" + return serialized.get('name', serialized.get('id', [default])[-1]) + + def _extract_conversation_id(self, metadata: dict[str, Any] | None) -> str | None: + """Extract thread_id from metadata for gen_ai.conversation.id.""" + if metadata: + return metadata.get('thread_id') + return None + + def _start_span( + self, + span_name: str, + run_id: UUID, + parent_run_id: UUID | None = None, + span_kind: SpanKind = SpanKind.INTERNAL, + conversation_id: str | None = None, + **span_data: Any, + ) -> Any: + """Start a span with proper parent linkage using parent_run_id.""" + parent_span = self._get_parent_span(parent_run_id) + + parent_token = None + parent_context = parent_span.get_context() if parent_span else None + if parent_context: + parent_token = context_api.attach(parent_context) + + try: + span = self._logfire.span( + span_name, + _span_kind=span_kind, + **span_data, + ) + span.start() + if conversation_id: + span.set_attribute(CONVERSATION_ID, conversation_id) + finally: + if parent_token is not None: + context_api.detach(parent_token) + + self._run_span_mapping[str(run_id)] = SpanWithToken(span, None) + return span + + def _end_span( + self, + run_id: UUID, + outputs: Any = None, + error: BaseException | None = None, + ) -> None: + """End span and clean up mapping.""" + st = self._run_span_mapping.pop(str(run_id), None) + if not st: + return + + try: + if error and st.span.is_recording(): + st.span.record_exception(error, escaped=True) + st.span.end() + finally: + if st.token: + _detach_span_from_context(st.token) + + def _extract_tool_definitions(self, kwargs: dict[str, Any]) -> list[dict[str, Any]]: + """Extract tool definitions from invocation_params.tools.""" + raw_tools = cast(list[Any], kwargs.get('invocation_params', {}).get('tools', [])) + tools: list[dict[str, Any]] = [] + for raw_tool in raw_tools: + if raw_tool.get('type') == 'function': + tools.append(raw_tool) + elif 'name' in raw_tool: + tools.append( + { + 'type': 'function', + 'function': { + 'name': raw_tool.get('name'), + 'description': raw_tool.get('description'), + 'parameters': raw_tool.get('input_schema', raw_tool.get('parameters')), + }, + } + ) + return tools + + def _convert_messages_to_otel(self, messages: list[list[Any]]) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: + """Convert LangChain messages to OTel GenAI format.""" + input_msgs: list[dict[str, Any]] = [] + system_instructions: list[dict[str, Any]] = [] + + for msg_list in messages: + for msg in msg_list: + msg_type = getattr(msg, 'type', 'unknown') + content = getattr(msg, 'content', str(msg)) + + if msg_type == 'system': + if isinstance(content, str): + system_instructions.append({'type': 'text', 'content': content}) + elif isinstance(content, list): + for item in cast(list[Any], content): + if isinstance(item, dict): + item_dict = cast(dict[str, Any], item) + if 'text' in item_dict and 'content' not in item_dict: + system_instructions.append( + { + 'type': item_dict.get('type', 'text'), + 'content': item_dict['text'], + } + ) + else: + system_instructions.append(item_dict) + elif isinstance(item, str): + system_instructions.append({'type': 'text', 'content': item}) + elif msg_type == 'tool': + tool_call_id = getattr(msg, 'tool_call_id', None) + response_content = content if isinstance(content, str) else str(content) + parts: list[dict[str, Any]] = [ + { + 'type': 'tool_call_response', + 'id': tool_call_id, + 'response': response_content, + } + ] + input_msgs.append({'role': 'tool', 'parts': parts}) + else: + otel_role = {'human': 'user', 'ai': 'assistant'}.get(msg_type, msg_type) + parts = [] + + if isinstance(content, str): + parts.append({'type': 'text', 'content': content}) + elif isinstance(content, list): + for item in cast(list[Any], content): + if isinstance(item, dict): + item_dict = cast(dict[str, Any], item) + if item_dict.get('type') == 'tool_use': + continue + parts.append(_normalize_content_block(item_dict)) + elif isinstance(item, str): + parts.append({'type': 'text', 'content': item}) + + if tool_calls := getattr(msg, 'tool_calls', None): + for tc in cast(list[Any], tool_calls): + if isinstance(tc, dict): + tc_dict = cast(dict[str, Any], tc) + parts.append( + { + 'type': 'tool_call', + 'id': tc_dict.get('id'), + 'name': tc_dict.get('name'), + 'arguments': tc_dict.get('args'), + } + ) + + input_msgs.append({'role': otel_role, 'parts': parts}) + + return input_msgs, system_instructions + + def on_chain_start( + self, + serialized: dict[str, Any], + inputs: dict[str, Any], + *, + run_id: UUID, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, + run_type: str | None = None, + name: str | None = None, + **kwargs: Any, + ) -> None: + """Called when a chain starts - creates parent span for hierarchy.""" + span_name = name or self._get_span_name(serialized, 'chain') + conversation_id = self._extract_conversation_id(metadata) + self._start_span(span_name, run_id, parent_run_id, SpanKind.INTERNAL, conversation_id=conversation_id) + + def on_chain_end( + self, + outputs: dict[str, Any], + *, + run_id: UUID, + inputs: dict[str, Any] | None = None, + **kwargs: Any, + ) -> None: + """Called when a chain ends.""" + self._end_span(run_id) + + def on_chain_error( + self, + error: BaseException, + *, + run_id: UUID, + **kwargs: Any, + ) -> None: + """Called when a chain errors.""" + self._end_span(run_id, error=error) + + def on_tool_start( + self, + serialized: dict[str, Any], + input_str: str, + *, + run_id: UUID, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, + inputs: dict[str, Any] | None = None, + name: str | None = None, + **kwargs: Any, + ) -> None: + """Called when a tool starts.""" + span_name = name or self._get_span_name(serialized, 'tool') + conversation_id = self._extract_conversation_id(metadata) + self._start_span(span_name, run_id, parent_run_id, SpanKind.INTERNAL, conversation_id=conversation_id) + + def on_tool_end( + self, + output: Any, + *, + run_id: UUID, + **kwargs: Any, + ) -> None: + """Called when a tool ends.""" + self._end_span(run_id) + + def on_tool_error( + self, + error: BaseException, + *, + run_id: UUID, + **kwargs: Any, + ) -> None: + """Called when a tool errors.""" + self._end_span(run_id, error=error) + + def on_retriever_start( + self, + serialized: dict[str, Any], + query: str, + *, + run_id: UUID, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, + name: str | None = None, + **kwargs: Any, + ) -> None: + """Called when a retriever starts.""" + span_name = name or self._get_span_name(serialized, 'retriever') + conversation_id = self._extract_conversation_id(metadata) + self._start_span(span_name, run_id, parent_run_id, SpanKind.INTERNAL, conversation_id=conversation_id) + + def on_retriever_end( + self, + documents: Any, + *, + run_id: UUID, + **kwargs: Any, + ) -> None: + """Called when a retriever ends.""" + self._end_span(run_id) + + def on_retriever_error( + self, + error: BaseException, + *, + run_id: UUID, + **kwargs: Any, + ) -> None: + """Called when a retriever errors.""" + self._end_span(run_id, error=error) + + def on_chat_model_start( + self, + serialized: dict[str, Any], + messages: list[list[Any]], + *, + run_id: UUID, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, + name: str | None = None, + **kwargs: Any, + ) -> None: + """Called when a chat model starts - captures tool definitions.""" + invocation_params = kwargs.get('invocation_params', {}) + model = invocation_params.get('model', invocation_params.get('model_name', 'unknown')) + + span_name = name or self._get_span_name(serialized, f'chat {model}') + span_data: dict[str, Any] = { + OPERATION_NAME: 'chat', + REQUEST_MODEL: model, + } + + if tools := self._extract_tool_definitions(kwargs): + span_data[TOOL_DEFINITIONS] = tools + + try: + input_msgs, system_instructions = self._convert_messages_to_otel(messages) + if input_msgs: + span_data[INPUT_MESSAGES] = input_msgs + if system_instructions: + span_data[SYSTEM_INSTRUCTIONS] = system_instructions + except Exception: + pass + + conversation_id = self._extract_conversation_id(metadata) + self._start_span( + span_name, run_id, parent_run_id, SpanKind.CLIENT, conversation_id=conversation_id, **span_data + ) + + def on_llm_start( + self, + serialized: dict[str, Any], + prompts: list[str], + *, + run_id: UUID, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, + name: str | None = None, + **kwargs: Any, + ) -> None: + """Called when a non-chat LLM starts.""" + invocation_params = kwargs.get('invocation_params', {}) + model = invocation_params.get('model', invocation_params.get('model_name', 'unknown')) + + span_name = name or self._get_span_name(serialized, f'llm {model}') + span_data: dict[str, Any] = { + OPERATION_NAME: 'completion', + REQUEST_MODEL: model, + } + + if tools := self._extract_tool_definitions(kwargs): + span_data[TOOL_DEFINITIONS] = tools + + conversation_id = self._extract_conversation_id(metadata) + self._start_span( + span_name, run_id, parent_run_id, SpanKind.CLIENT, conversation_id=conversation_id, **span_data + ) + + def on_llm_end( + self, + response: Any, + *, + run_id: UUID, + **kwargs: Any, + ) -> None: + """Called when LLM ends.""" + span = self._get_span_by_run_id(run_id) + if not span: + return + + try: + generations = getattr(response, 'generations', [[]]) + if generations and generations[0]: + gen = generations[0][0] + message = getattr(gen, 'message', None) + if message: + response_metadata = getattr(message, 'response_metadata', {}) or {} + if stop_reason := response_metadata.get('stop_reason', response_metadata.get('finish_reason')): + span.set_attribute(RESPONSE_FINISH_REASONS, json.dumps([stop_reason])) + + if model_name := response_metadata.get('model_name', response_metadata.get('model')): + span.set_attribute(RESPONSE_MODEL, model_name) + + content = getattr(message, 'content', '') + output_msg: dict[str, Any] = {'role': 'assistant', 'parts': []} + + if isinstance(content, str): + output_msg['parts'].append({'type': 'text', 'content': content}) + elif isinstance(content, list): + for item in cast(list[Any], content): + if isinstance(item, dict): + item_dict = cast(dict[str, Any], item) + if item_dict.get('type') == 'tool_use': + continue + output_msg['parts'].append(_normalize_content_block(item_dict)) + elif isinstance(item, str): + output_msg['parts'].append({'type': 'text', 'content': item}) + + if tool_calls := getattr(message, 'tool_calls', None): + for tc in cast(list[Any], tool_calls): + if isinstance(tc, dict): + tc_dict = cast(dict[str, Any], tc) + output_msg['parts'].append( + { + 'type': 'tool_call', + 'id': tc_dict.get('id'), + 'name': tc_dict.get('name'), + 'arguments': tc_dict.get('args'), + } + ) + + span.set_attribute(OUTPUT_MESSAGES, [output_msg]) + + llm_output = cast(dict[str, Any], getattr(response, 'llm_output', {}) or {}) + usage = cast(dict[str, Any], llm_output.get('usage') or llm_output.get('token_usage') or {}) + if input_tokens := usage.get('input_tokens', usage.get('prompt_tokens')): + span.set_attribute(INPUT_TOKENS, input_tokens) + if output_tokens := usage.get('output_tokens', usage.get('completion_tokens')): + span.set_attribute(OUTPUT_TOKENS, output_tokens) + except Exception: + pass + finally: + self._end_span(run_id) + + def on_llm_error( + self, + error: BaseException, + *, + run_id: UUID, + **kwargs: Any, + ) -> None: + """Called when LLM errors.""" + self._end_span(run_id, error=error) + + +_original_callback_manager_init: Any = None +_logfire_instance: Logfire | None = None +_handler_instance: LogfireLangchainCallbackHandler | None = None + + +def _patch_callback_manager(logfire: Logfire) -> None: + """Patch BaseCallbackManager to inject our handler.""" + global _original_callback_manager_init, _logfire_instance, _handler_instance + + try: + from langchain_core.callbacks import BaseCallbackManager + except ImportError as e: + raise ImportError( + 'langchain-core is required for LangChain instrumentation. Install it with: pip install langchain-core' + ) from e + + if _original_callback_manager_init is not None: + return + + _logfire_instance = logfire + _handler_instance = None + _original_callback_manager_init = BaseCallbackManager.__init__ + + def patched_init(self: Any, *args: Any, **kwargs: Any) -> None: + global _handler_instance + _original_callback_manager_init(self, *args, **kwargs) + + for handler in list(getattr(self, 'handlers', [])) + list(getattr(self, 'inheritable_handlers', [])): + if isinstance(handler, LogfireLangchainCallbackHandler): + return + + if _logfire_instance is not None: + if _handler_instance is None: + _handler_instance = LogfireLangchainCallbackHandler(_logfire_instance) + self.add_handler(_handler_instance, inherit=True) + + BaseCallbackManager.__init__ = patched_init + + +def _unpatch_callback_manager() -> None: + """Restore original BaseCallbackManager.__init__.""" + global _original_callback_manager_init, _logfire_instance, _handler_instance + + if _original_callback_manager_init is None: + return + + try: + from langchain_core.callbacks import BaseCallbackManager + + BaseCallbackManager.__init__ = _original_callback_manager_init + except ImportError: + pass + + _original_callback_manager_init = None + _logfire_instance = None + _handler_instance = None + + +def instrument_langchain(logfire: Logfire) -> AbstractContextManager[None]: + """Instrument LangChain to capture full execution hierarchy. + + This patches LangChain's BaseCallbackManager to inject a callback handler + that captures the complete execution hierarchy including chains, tools, + retrievers, and LLMs with tool definitions. + + The patching happens immediately when this function is called. + Returns a context manager that can be used to uninstrument if needed. + + Args: + logfire: The Logfire instance to use for creating spans. + + Returns: + A context manager for optional cleanup/uninstrumentation. + + Example: + ```python + import logfire + + logfire.configure() + logfire.instrument_langchain() + + # Now LangChain operations will be traced with full hierarchy + ``` + """ + _patch_callback_manager(logfire) + + @contextmanager + def cleanup_context(): + try: + yield + finally: + _unpatch_callback_manager() + + return cleanup_context() diff --git a/logfire/_internal/main.py b/logfire/_internal/main.py index 75ad44384..f9bae7ce2 100644 --- a/logfire/_internal/main.py +++ b/logfire/_internal/main.py @@ -26,7 +26,7 @@ from opentelemetry.context import Context from opentelemetry.metrics import CallbackT, Counter, Histogram, UpDownCounter from opentelemetry.sdk.trace import ReadableSpan, Span -from opentelemetry.trace import SpanContext +from opentelemetry.trace import SpanContext, SpanKind from opentelemetry.util import types as otel_types from typing_extensions import LiteralString, ParamSpec @@ -188,6 +188,7 @@ def _span( _span_name: str | None = None, _level: LevelName | int | None = None, _links: Sequence[tuple[SpanContext, otel_types.Attributes]] = (), + _span_kind: SpanKind | None = None, ) -> LogfireSpan: try: if _level is not None: @@ -243,6 +244,7 @@ def _span( self._spans_tracer, json_schema_properties, links=_links, + span_kind=_span_kind, ) except Exception: log_internal_error() @@ -540,6 +542,7 @@ def span( _span_name: str | None = None, _level: LevelName | None = None, _links: Sequence[tuple[SpanContext, otel_types.Attributes]] = (), + _span_kind: SpanKind | None = None, **attributes: Any, ) -> LogfireSpan: """Context manager for creating a span. @@ -559,6 +562,7 @@ def span( _tags: An optional sequence of tags to include in the span. _level: An optional log level name. _links: An optional sequence of links to other spans. Each link is a tuple of a span context and attributes. + _span_kind: An optional span kind (e.g., SpanKind.CLIENT for external calls). attributes: The arguments to include in the span and format the message template with. Attributes starting with an underscore are not allowed. """ @@ -571,6 +575,7 @@ def span( _span_name=_span_name, _level=_level, _links=_links, + _span_kind=_span_kind, ) @overload @@ -1367,6 +1372,42 @@ def instrument_litellm(self, **kwargs: Any): self._warn_if_not_initialized_for_instrumentation() instrument_litellm(self, **kwargs) + def instrument_langchain(self) -> AbstractContextManager[None]: + """Instrument LangChain to capture full execution hierarchy with tool definitions. + + This patches LangChain's BaseCallbackManager to inject a callback handler + that captures the complete execution hierarchy including chains, tools, + retrievers, and LLM calls with tool definitions. + + The instrumentation complements LangSmith's OTEL integration by adding: + - Tool definitions (gen_ai.tool.definitions) + - Input/output messages in OTel GenAI format + - System instructions + - Conversation tracking via thread_id + + Example usage: + + ```python + import logfire + from langchain_anthropic import ChatAnthropic + from langchain_core.messages import HumanMessage + + logfire.configure() + logfire.instrument_langchain() + + model = ChatAnthropic(model='claude-3-haiku-20240307') + response = model.invoke([HumanMessage(content='Hello!')]) + ``` + + Returns: + A context manager that will revert the instrumentation when exited. + Use of this context manager is optional. + """ + from .integrations.langchain import instrument_langchain + + self._warn_if_not_initialized_for_instrumentation() + return instrument_langchain(self) + def instrument_print(self) -> AbstractContextManager[None]: """Instrument the built-in `print` function so that calls to it are logged. @@ -2382,12 +2423,14 @@ def __init__( tracer: _ProxyTracer, json_schema_properties: JsonSchemaProperties, links: Sequence[tuple[SpanContext, otel_types.Attributes]], + span_kind: SpanKind | None = None, ) -> None: self._span_name = span_name self._otlp_attributes = otlp_attributes self._tracer = tracer self._json_schema_properties = json_schema_properties self._links = list(trace_api.Link(context=context, attributes=attributes) for context, attributes in links) + self._span_kind = span_kind self._added_attributes = False self._token: None | Token[Context] = None @@ -2402,11 +2445,14 @@ def __getattr__(self, name: str) -> Any: def _start(self): if self._span is not None: return - self._span = self._tracer.start_span( - name=self._span_name, - attributes=self._otlp_attributes, - links=self._links, - ) + kwargs: dict[str, Any] = { + 'name': self._span_name, + 'attributes': self._otlp_attributes, + 'links': self._links, + } + if self._span_kind is not None: + kwargs['kind'] = self._span_kind + self._span = self._tracer.start_span(**kwargs) @handle_internal_errors def _attach(self): @@ -2441,6 +2487,49 @@ def __exit__(self, exc_type: type[BaseException] | None, exc_value: BaseExceptio self._span.record_exception(exc_value, escaped=True) self._end() + def start(self) -> None: + """Start the span without entering a context manager. + + Use this for callback-based instrumentation where spans are started + and ended at different callback points rather than using `with` blocks. + + Note: You must call `end()` to properly close the span. + """ + self._start() + + def end(self) -> None: + """End the span without exiting a context manager. + + Use this for callback-based instrumentation where spans are started + and ended at different callback points rather than using `with` blocks. + + Note: This does NOT detach the span from context. If you attached + the span to context, you must detach it separately. + """ + self._end() + + def get_context(self) -> Context | None: + """Get the OpenTelemetry context with this span. + + Returns the context that can be used to create child spans + with this span as parent. Returns None if span not started. + + Example: + parent_span = logfire.span("parent") + parent_span.start() + if ctx := parent_span.get_context(): + token = context_api.attach(ctx) + try: + # Child spans created here will have parent_span as parent + child_span = logfire.span("child") + ... + finally: + context_api.detach(token) + """ + if self._span is None: + return None + return trace_api.set_span_in_context(self._span) + @property def message_template(self) -> str | None: # pragma: no cover return self._get_attribute(ATTRIBUTES_MESSAGE_TEMPLATE_KEY, None)