diff --git a/src/oumi/core/processors/default_processor.py b/src/oumi/core/processors/default_processor.py index b7e5c4bf79..06749a7c35 100644 --- a/src/oumi/core/processors/default_processor.py +++ b/src/oumi/core/processors/default_processor.py @@ -14,9 +14,11 @@ import copy from collections.abc import Callable +from functools import lru_cache from pathlib import Path from typing import Any +import jinja2 import PIL.Image import transformers from typing_extensions import override @@ -29,6 +31,38 @@ from oumi.utils.logging import logger from oumi.utils.str_utils import truncate_to_max_tokens_limit +_REASONING_FIELD_NAMES = frozenset({"reasoning", "reasoning_content"}) +_jinja_env = jinja2.Environment() + + +@lru_cache(maxsize=8) +def _template_references_reasoning(template: str) -> bool: + """Check if a Jinja2 template references reasoning fields on messages. + + Parses the template AST and walks it to find attribute access + (``message.reasoning_content``) or item access + (``message['reasoning']``) nodes that reference reasoning fields. + This avoids false positives from comments or unrelated text. + """ + from jinja2.nodes import Const, Getattr, Getitem, Node + + try: + ast = _jinja_env.parse(template) + except jinja2.TemplateSyntaxError: + return False + + pending: list[Node] = [ast] + while pending: + node = pending.pop() + if isinstance(node, Getattr): + if node.attr in _REASONING_FIELD_NAMES: + return True + elif isinstance(node, Getitem): + if isinstance(node.arg, Const) and node.arg.value in _REASONING_FIELD_NAMES: + return True + pending.extend(node.iter_child_nodes()) + return False + class DefaultProcessor(BaseProcessor): """Default implementation of processor that wraps a worker processor. @@ -226,12 +260,48 @@ def __call__( ) return result + def _template_supports_reasoning(self) -> bool: + """Check if the chat template natively handles reasoning fields. + + Parses the Jinja2 AST to find actual variable references to + ``reasoning`` or ``reasoning_content`` on message objects, avoiding + false positives from comments or unrelated text. + """ + return _template_references_reasoning(self.chat_template) + def _convert_messages_to_dicts(self, messages: list[Message]) -> list[dict]: - """Converts Message objects to dict format for HuggingFace compatibility.""" - return [ - msg.model_dump(mode="json", exclude_none=True, exclude_unset=True) - for msg in messages - ] + """Converts Message objects to dict format for HuggingFace compatibility. + + When a message has ``reasoning_content``, the behavior depends on + whether the chat template natively supports reasoning fields: + + - **Template supports reasoning** (e.g., Qwen3): both + ``reasoning_content`` and ``reasoning`` keys are included so the + template can render them natively (typically inside ```` + tags). + - **Template does not support reasoning** (e.g., Llama, DeepSeek): + reasoning is prepended to ``content`` wrapped in ```` tags + so it appears in the tokenized sequence. + """ + template_supports = self._template_supports_reasoning() + result = [] + for msg in messages: + d = msg.model_dump(mode="json", exclude_none=True, exclude_unset=True) + if msg.reasoning_content is not None: + if template_supports: + # Pass as separate keys — template handles formatting. + if "reasoning" not in d: + d["reasoning"] = msg.reasoning_content + else: + # Fold into content — template doesn't know about + # reasoning, so we prepend it with tags. + d.pop("reasoning_content", None) + content = d.get("content", "") + d["content"] = ( + f"\n{msg.reasoning_content}\n\n\n{content}" + ) + result.append(d) + return result @override def apply_chat_template( diff --git a/src/oumi/core/types/conversation.py b/src/oumi/core/types/conversation.py index 22b090116f..095b7b6b89 100644 --- a/src/oumi/core/types/conversation.py +++ b/src/oumi/core/types/conversation.py @@ -239,6 +239,23 @@ class Message(pydantic.BaseModel): role: Role """The role of the entity sending the message (e.g., user, assistant, system).""" + reasoning_content: str | None = None + """Optional reasoning/thinking content from the model. + + Some providers (e.g., Together) return a separate ``reasoning`` or + ``reasoning_content`` field for thinking models (Qwen3.x, Qwen3.5-x). + When present, ``content`` holds the visible output and + ``reasoning_content`` holds the internal chain-of-thought. If the model + exhausts its token budget on reasoning, ``content`` may be empty while + ``reasoning_content`` contains the partial thinking. + + Named ``reasoning_content`` to match Qwen3's chat template convention. + The inference engines handle extracting from both ``reasoning`` and + ``reasoning_content`` API response fields. When passed to chat + templates, both key names are included in the message dict so that + any template can find it. + """ + def model_post_init(self, __context) -> None: """Post-initialization method for the Message model. diff --git a/src/oumi/inference/anthropic_inference_engine.py b/src/oumi/inference/anthropic_inference_engine.py index 14f304170e..05af32eeed 100644 --- a/src/oumi/inference/anthropic_inference_engine.py +++ b/src/oumi/inference/anthropic_inference_engine.py @@ -197,9 +197,26 @@ def _convert_api_output_to_conversation( f"model={response.get('model')}, " f"usage={response.get('usage')}" ) + # Thinking models include {"type": "thinking", "thinking": "..."} + # blocks alongside text blocks. Blocks without a "type" key are + # treated as text (backward compat with older response formats). + text_parts: list[str] = [] + reasoning_parts: list[str] = [] + for block in content_blocks: + if block.get("type") == "thinking": + thinking = block.get("thinking") + if thinking is not None and thinking != "": + reasoning_parts.append(thinking) + elif block.get("type") == "text" or "text" in block: + part = block.get("text") + if part is not None and part != "": + text_parts.append(part) + text = "".join(text_parts) + reasoning = "".join(reasoning_parts) if reasoning_parts else None new_message = Message( - content=content_blocks[0]["text"], + content=text, role=Role.ASSISTANT, + reasoning_content=reasoning, ) metadata = dict(original_conversation.metadata) usage = self._extract_usage_from_response(response) diff --git a/src/oumi/inference/bedrock_inference_engine.py b/src/oumi/inference/bedrock_inference_engine.py index 09604afe94..ed2ce9fe9b 100644 --- a/src/oumi/inference/bedrock_inference_engine.py +++ b/src/oumi/inference/bedrock_inference_engine.py @@ -357,12 +357,23 @@ def _extract_finish_reason_from_response( def _convert_api_output_to_conversation( self, response: dict[str, Any], original: Conversation ) -> Conversation: - text = "" + text_parts: list[str] = [] + reasoning_parts: list[str] = [] msg = response.get("output", {}).get("message", {}) for block in msg.get("content", []): - if "text" in block: - text += block["text"] - new_message = Message(content=text, role=Role.ASSISTANT) + if block.get("type") == "thinking": + thinking = block.get("thinking") + if thinking is not None and thinking != "": + reasoning_parts.append(thinking) + elif "text" in block: + part = block.get("text") + if part is not None and part != "": + text_parts.append(part) + text = "".join(text_parts) + reasoning = "".join(reasoning_parts) if reasoning_parts else None + new_message = Message( + content=text, role=Role.ASSISTANT, reasoning_content=reasoning + ) metadata = dict(original.metadata) finish_reason = self._extract_finish_reason_from_response(response) if finish_reason is not None: diff --git a/src/oumi/inference/remote_inference_engine.py b/src/oumi/inference/remote_inference_engine.py index 58e1c8d635..4354c0023c 100644 --- a/src/oumi/inference/remote_inference_engine.py +++ b/src/oumi/inference/remote_inference_engine.py @@ -487,12 +487,25 @@ def _convert_api_output_to_conversation( finish_reason = self._extract_finish_reason_from_response(response) if finish_reason is not None: metadata["finish_reason"] = finish_reason.value + # Some providers (e.g., Together) return content=null for thinking + # models when the token budget is consumed by reasoning. Default to + # empty string so Message validation doesn't fail. + content = message.get("content") + if content is None: + content = "" + # Providers use "reasoning" (Together, Kimi, DeepSeek V3.1) or + # "reasoning_content" (GLM-5) for the thinking chain. + reasoning = message.get("reasoning") + if reasoning is None: + reasoning = message.get("reasoning_content") + return Conversation( messages=[ *original_conversation.messages, Message( content=content, role=Role(message["role"]), + reasoning_content=reasoning, ), ], metadata=metadata, diff --git a/src/oumi/inference/sambanova_inference_engine.py b/src/oumi/inference/sambanova_inference_engine.py index d01bcebed5..2b5689f830 100644 --- a/src/oumi/inference/sambanova_inference_engine.py +++ b/src/oumi/inference/sambanova_inference_engine.py @@ -109,9 +109,16 @@ def _convert_api_output_to_conversation( if not message: raise RuntimeError("No message found in API response") + content = message.get("content") + if content is None: + content = "" + reasoning = message.get("reasoning") + if reasoning is None: + reasoning = message.get("reasoning_content") new_message = Message( - content=message.get("content", ""), + content=content, role=Role.ASSISTANT, + reasoning_content=reasoning, ) return Conversation( diff --git a/src/oumi/utils/conversation_utils.py b/src/oumi/utils/conversation_utils.py index 5dc2b191d6..29377e1b9e 100644 --- a/src/oumi/utils/conversation_utils.py +++ b/src/oumi/utils/conversation_utils.py @@ -218,6 +218,7 @@ def create_list_of_message_json_dicts( messages: list[Message], *, group_adjacent_same_role_turns: bool, + include_reasoning: bool = False, ) -> list[dict[str, Any]]: """Returns a list of JSON dictionaries representing messages. @@ -227,6 +228,11 @@ def create_list_of_message_json_dicts( messages: The input messages. group_adjacent_same_role_turns: Whether to pack adjacent messages from the same role into a single element in output list. + include_reasoning: Whether to include ``reasoning_content`` and + ``reasoning`` keys in the output dicts. Set to ``True`` when + building dicts for chat templates (training). Set to ``False`` + (default) when building dicts for API requests, since some + providers (e.g., Anthropic) reject unknown fields. Returns: list[Dict[str, Any]]: The list of messages encoded as nested JSON dicts. @@ -235,6 +241,7 @@ def create_list_of_message_json_dicts( result = [] idx = 0 while idx < num_messages: + start_idx = idx end_idx = idx + 1 if group_adjacent_same_role_turns: while end_idx < num_messages and ( @@ -258,6 +265,21 @@ def create_list_of_message_json_dicts( idx += 1 item["content"] = content_list + if include_reasoning: + # Collect reasoning from all messages in the group and + # concatenate. Include under both key names so that any chat + # template can find it (Qwen3 uses "reasoning_content", others + # may use "reasoning"). + reasoning_parts: list[str] = [ + rc + for i in range(start_idx, end_idx) + if (rc := messages[i].reasoning_content) is not None + ] + if reasoning_parts: + reasoning = "".join(reasoning_parts) + item["reasoning_content"] = reasoning + item["reasoning"] = reasoning + idx = end_idx result.append(item) @@ -308,12 +330,20 @@ def remove_excessive_images( if len(filtered_items) == 1 and isinstance(filtered_items[0].content, str): result.append( Message( - id=message.id, content=filtered_items[0].content, role=message.role + id=message.id, + content=filtered_items[0].content, + role=message.role, + reasoning_content=message.reasoning_content, ) ) else: result.append( - Message(id=message.id, content=filtered_items, role=message.role) + Message( + id=message.id, + content=filtered_items, + role=message.role, + reasoning_content=message.reasoning_content, + ) ) return result @@ -425,11 +455,17 @@ def truncate_text_in_content_items( ): assert isinstance(items[0].content, str) result[msg_idx] = Message( - id=message.id, content=items[0].content, role=message.role + id=message.id, + content=items[0].content, + role=message.role, + reasoning_content=message.reasoning_content, ) else: result[msg_idx] = Message( - id=message.id, content=items, role=message.role + id=message.id, + content=items, + role=message.role, + reasoning_content=message.reasoning_content, ) return result diff --git a/tests/unit/builders/test_processors.py b/tests/unit/builders/test_processors.py index 5bb7033931..eb909ba86c 100644 --- a/tests/unit/builders/test_processors.py +++ b/tests/unit/builders/test_processors.py @@ -338,3 +338,134 @@ def test_processor_apply_chat_template_multimodal_text_content(): prompt = processor.apply_chat_template(messages) assert isinstance(prompt, str) assert "Describe the following:" in prompt + + +def test_convert_messages_reasoning_with_supporting_template(): + """When chat template references reasoning, pass as separate keys.""" + from unittest.mock import MagicMock, PropertyMock + + from oumi.core.processors.default_processor import DefaultProcessor + + processor = MagicMock(spec=DefaultProcessor) + processor._convert_messages_to_dicts = ( + DefaultProcessor._convert_messages_to_dicts.__get__(processor) + ) + processor._template_supports_reasoning = ( + DefaultProcessor._template_supports_reasoning.__get__(processor) + ) + type(processor).chat_template = PropertyMock( + return_value="{% if message.reasoning_content %}...{% endif %}" + ) + + messages = [ + Message( + role=Role.ASSISTANT, + content="The answer is 4.", + reasoning_content="2+2=4", + ), + ] + result = processor._convert_messages_to_dicts(messages) + assert result[0]["reasoning_content"] == "2+2=4" + assert result[0]["reasoning"] == "2+2=4" + assert result[0]["content"] == "The answer is 4." + + +def test_convert_messages_reasoning_with_unsupporting_template(): + """When chat template doesn't reference reasoning, fold into content.""" + from unittest.mock import MagicMock, PropertyMock + + from oumi.core.processors.default_processor import DefaultProcessor + + processor = MagicMock(spec=DefaultProcessor) + processor._convert_messages_to_dicts = ( + DefaultProcessor._convert_messages_to_dicts.__get__(processor) + ) + processor._template_supports_reasoning = ( + DefaultProcessor._template_supports_reasoning.__get__(processor) + ) + type(processor).chat_template = PropertyMock( + return_value="{% for message in messages %}{{ message.content }}{% endfor %}" + ) + + messages = [ + Message( + role=Role.ASSISTANT, + content="The answer is 4.", + reasoning_content="2+2=4", + ), + ] + result = processor._convert_messages_to_dicts(messages) + assert "reasoning_content" not in result[0] + assert "reasoning" not in result[0] + assert result[0]["content"] == "\n2+2=4\n\n\nThe answer is 4." + + +def test_convert_messages_no_reasoning(): + """When message has no reasoning, no reasoning keys in output.""" + from unittest.mock import MagicMock, PropertyMock + + from oumi.core.processors.default_processor import DefaultProcessor + + processor = MagicMock(spec=DefaultProcessor) + processor._convert_messages_to_dicts = ( + DefaultProcessor._convert_messages_to_dicts.__get__(processor) + ) + processor._template_supports_reasoning = ( + DefaultProcessor._template_supports_reasoning.__get__(processor) + ) + type(processor).chat_template = PropertyMock(return_value="simple template") + + messages = [ + Message(role=Role.ASSISTANT, content="Hello"), + ] + result = processor._convert_messages_to_dicts(messages) + assert "reasoning_content" not in result[0] + assert "reasoning" not in result[0] + assert result[0]["content"] == "Hello" + + +def test_template_supports_reasoning_false_positive_comment(): + """Template with 'reasoning' in a comment should NOT match.""" + from oumi.core.processors.default_processor import _template_references_reasoning + + template = """{# This template handles reasoning models #} +{% for message in messages %}{{ message.content }}{% endfor %}""" + assert _template_references_reasoning(template) is False + + +def test_template_supports_reasoning_false_positive_string_literal(): + """Template with 'reasoning' in a string literal should NOT match.""" + from oumi.core.processors.default_processor import _template_references_reasoning + + template = ( + "{% for message in messages %}" + '{{ "reasoning_content is not used" }}' + "{{ message.content }}{% endfor %}" + ) + assert _template_references_reasoning(template) is False + + +def test_template_supports_reasoning_attribute_access(): + """Template with message.reasoning_content attribute access should match.""" + from oumi.core.processors.default_processor import _template_references_reasoning + + template = ( + "{% if message.reasoning_content %}{{ message.reasoning_content }}{% endif %}" + ) + assert _template_references_reasoning(template) is True + + +def test_template_supports_reasoning_dict_access(): + """Template with message['reasoning'] dict access should match.""" + from oumi.core.processors.default_processor import _template_references_reasoning + + template = """{% if message['reasoning'] %}{{ message['reasoning'] }}{% endif %}""" + assert _template_references_reasoning(template) is True + + +def test_template_supports_reasoning_no_reasoning(): + """Template with no reasoning references should not match.""" + from oumi.core.processors.default_processor import _template_references_reasoning + + template = """{% for m in messages %}{{ m.content }}{% endfor %}""" + assert _template_references_reasoning(template) is False diff --git a/tests/unit/inference/test_anthropic_inference_engine.py b/tests/unit/inference/test_anthropic_inference_engine.py index edf0124142..acfa1bbac0 100644 --- a/tests/unit/inference/test_anthropic_inference_engine.py +++ b/tests/unit/inference/test_anthropic_inference_engine.py @@ -97,6 +97,98 @@ def test_convert_api_output_missing_content_key(anthropic_engine): ) +def test_convert_api_output_interleaved_blocks(anthropic_engine): + """Test that interleaved thinking and text blocks are all collected.""" + original = Conversation( + messages=[Message(content="Hello", role=Role.USER)], + ) + api_response = { + "content": [ + {"type": "thinking", "thinking": "Step 1..."}, + {"type": "text", "text": "First part."}, + {"type": "thinking", "thinking": "Step 2..."}, + {"type": "text", "text": "Second part."}, + ] + } + result = anthropic_engine._convert_api_output_to_conversation( + api_response, original + ) + assert result.messages[-1].content == "First part.Second part." + assert result.messages[-1].reasoning_content == "Step 1...Step 2..." + + +def test_convert_api_output_with_thinking_blocks(anthropic_engine): + """Test that thinking blocks are extracted into Message.reasoning_content.""" + original = Conversation( + messages=[Message(content="Hello", role=Role.USER)], + ) + api_response = { + "content": [ + { + "type": "thinking", + "thinking": "Let me reason about this...", + "signature": "abc123", + }, + {"type": "text", "text": "Here is my answer."}, + ] + } + result = anthropic_engine._convert_api_output_to_conversation( + api_response, original + ) + assert result.messages[-1].content == "Here is my answer." + assert result.messages[-1].reasoning_content == "Let me reason about this..." + + +def test_convert_api_output_no_thinking_blocks(anthropic_engine): + """Test that reasoning is None when no thinking blocks present.""" + original = Conversation( + messages=[Message(content="Hello", role=Role.USER)], + ) + api_response = {"content": [{"type": "text", "text": "Simple answer."}]} + result = anthropic_engine._convert_api_output_to_conversation( + api_response, original + ) + assert result.messages[-1].content == "Simple answer." + assert result.messages[-1].reasoning_content is None + + +def test_convert_api_output_null_thinking_content(anthropic_engine): + """Test that thinking blocks with null/empty thinking are skipped.""" + original = Conversation( + messages=[Message(content="Hello", role=Role.USER)], + ) + api_response = { + "content": [ + {"type": "thinking", "thinking": None, "signature": "sig1"}, + {"type": "thinking", "thinking": "", "signature": "sig2"}, + {"type": "text", "text": "Answer."}, + ] + } + result = anthropic_engine._convert_api_output_to_conversation( + api_response, original + ) + assert result.messages[-1].content == "Answer." + assert result.messages[-1].reasoning_content is None + + +def test_convert_api_output_null_text_content(anthropic_engine): + """Test that text=null defaults to empty string.""" + original = Conversation( + messages=[Message(content="Hello", role=Role.USER)], + ) + api_response = { + "content": [ + {"type": "thinking", "thinking": "Reasoning here.", "signature": "sig"}, + {"type": "text", "text": None}, + ] + } + result = anthropic_engine._convert_api_output_to_conversation( + api_response, original + ) + assert result.messages[-1].content == "" + assert result.messages[-1].reasoning_content == "Reasoning here." + + @pytest.mark.parametrize( "api_usage,expected_usage", [ diff --git a/tests/unit/inference/test_bedrock_inference_engine.py b/tests/unit/inference/test_bedrock_inference_engine.py index a3f8d5c2d1..928f27aca9 100644 --- a/tests/unit/inference/test_bedrock_inference_engine.py +++ b/tests/unit/inference/test_bedrock_inference_engine.py @@ -101,6 +101,42 @@ def test_convert_api_output_to_conversation(bedrock_engine): assert result.conversation_id == "test_id" +@pytest.mark.skipif(boto3_import_failed, reason="boto3 not available") +def test_convert_api_output_with_thinking_blocks(bedrock_engine): + """Test that thinking blocks are extracted into Message.reasoning_content.""" + original = Conversation( + messages=[Message(content="Hello", role=Role.USER)], + ) + api_response = { + "output": { + "message": { + "content": [ + { + "type": "thinking", + "thinking": "Step 1: analyze...", + }, + {"text": "The answer is 42."}, + ] + } + } + } + result = bedrock_engine._convert_api_output_to_conversation(api_response, original) + assert result.messages[-1].content == "The answer is 42." + assert result.messages[-1].reasoning_content == "Step 1: analyze..." + + +@pytest.mark.skipif(boto3_import_failed, reason="boto3 not available") +def test_convert_api_output_no_thinking_blocks(bedrock_engine): + """Test that reasoning is None when no thinking blocks present.""" + original = Conversation( + messages=[Message(content="Hello", role=Role.USER)], + ) + api_response = {"output": {"message": {"content": [{"text": "Simple answer."}]}}} + result = bedrock_engine._convert_api_output_to_conversation(api_response, original) + assert result.messages[-1].content == "Simple answer." + assert result.messages[-1].reasoning_content is None + + @pytest.mark.skipif(boto3_import_failed, reason="boto3 not available") def test_infer_online(bedrock_engine): with patch.object(bedrock_engine, "_infer") as mock_infer: diff --git a/tests/unit/inference/test_remote_inference_engine.py b/tests/unit/inference/test_remote_inference_engine.py index d1673d3e68..0816243bd3 100644 --- a/tests/unit/inference/test_remote_inference_engine.py +++ b/tests/unit/inference/test_remote_inference_engine.py @@ -2109,6 +2109,186 @@ def test_convert_api_output_content_null_returns_empty_string(): assert result.messages[-1].role == Role.ASSISTANT +def test_convert_api_output_null_content_with_reasoning(): + """Test that null content defaults to empty string and reasoning is on message. + + Thinking models (e.g., Qwen3.5) via Together API can return content=null + when the token budget is consumed by reasoning. + """ + engine = RemoteInferenceEngine( + _get_default_model_params(), + remote_params=RemoteParams(api_url=_TARGET_SERVER), + ) + original = Conversation( + messages=[Message(content="Hello", role=Role.USER)], + ) + response = { + "choices": [ + { + "message": { + "role": "assistant", + "content": None, + "reasoning": "Thinking about this...", + }, + "finish_reason": "length", + } + ], + } + result = engine._convert_api_output_to_conversation(response, original) + assert result.messages[-1].content == "" + assert result.messages[-1].role == Role.ASSISTANT + assert result.messages[-1].reasoning_content == "Thinking about this..." + + +def test_convert_api_output_null_content_no_reasoning(): + """Test that null content with no reasoning falls back to empty string.""" + engine = RemoteInferenceEngine( + _get_default_model_params(), + remote_params=RemoteParams(api_url=_TARGET_SERVER), + ) + original = Conversation( + messages=[Message(content="Hello", role=Role.USER)], + ) + response = { + "choices": [ + { + "message": { + "role": "assistant", + "content": None, + }, + } + ], + } + result = engine._convert_api_output_to_conversation(response, original) + assert result.messages[-1].content == "" + assert result.messages[-1].role == Role.ASSISTANT + assert result.messages[-1].reasoning_content is None + + +def test_convert_api_output_null_reasoning(): + """Test that reasoning=null in response results in None, not empty string.""" + engine = RemoteInferenceEngine( + _get_default_model_params(), + remote_params=RemoteParams(api_url=_TARGET_SERVER), + ) + original = Conversation( + messages=[Message(content="Hello", role=Role.USER)], + ) + response = { + "choices": [ + { + "message": { + "role": "assistant", + "content": "Hi", + "reasoning": None, + }, + } + ], + } + result = engine._convert_api_output_to_conversation(response, original) + assert result.messages[-1].content == "Hi" + assert result.messages[-1].reasoning_content is None + + +def test_convert_api_output_empty_string_reasoning(): + """Test that empty reasoning is preserved, not treated as None.""" + engine = RemoteInferenceEngine( + _get_default_model_params(), + remote_params=RemoteParams(api_url=_TARGET_SERVER), + ) + original = Conversation( + messages=[Message(content="Hello", role=Role.USER)], + ) + response = { + "choices": [ + { + "message": { + "role": "assistant", + "content": "Hi", + "reasoning": "", + }, + } + ], + } + result = engine._convert_api_output_to_conversation(response, original) + assert result.messages[-1].content == "Hi" + assert result.messages[-1].reasoning_content == "" + + +def test_convert_api_output_reasoning_content_fallback(): + """Test that reasoning_content is used when reasoning is absent.""" + engine = RemoteInferenceEngine( + _get_default_model_params(), + remote_params=RemoteParams(api_url=_TARGET_SERVER), + ) + original = Conversation( + messages=[Message(content="Hello", role=Role.USER)], + ) + response = { + "choices": [ + { + "message": { + "role": "assistant", + "content": "Answer", + "reasoning_content": "GLM thinking...", + }, + } + ], + } + result = engine._convert_api_output_to_conversation(response, original) + assert result.messages[-1].reasoning_content == "GLM thinking..." + + +def test_convert_api_output_reasoning_preferred_over_reasoning_content(): + """Test that reasoning takes precedence over reasoning_content when both present.""" + engine = RemoteInferenceEngine( + _get_default_model_params(), + remote_params=RemoteParams(api_url=_TARGET_SERVER), + ) + original = Conversation( + messages=[Message(content="Hello", role=Role.USER)], + ) + response = { + "choices": [ + { + "message": { + "role": "assistant", + "content": "Answer", + "reasoning": "primary", + "reasoning_content": "fallback", + }, + } + ], + } + result = engine._convert_api_output_to_conversation(response, original) + assert result.messages[-1].reasoning_content == "primary" + + +def test_convert_api_output_content_with_reasoning(): + """Test that when both content and reasoning are present, both are on message.""" + engine = RemoteInferenceEngine( + _get_default_model_params(), + remote_params=RemoteParams(api_url=_TARGET_SERVER), + ) + original = Conversation( + messages=[Message(content="Hello", role=Role.USER)], + ) + response = { + "choices": [ + { + "message": { + "role": "assistant", + "content": "The answer is B.", + "reasoning": "Let me think step by step...", + }, + } + ], + } + result = engine._convert_api_output_to_conversation(response, original) + assert result.messages[-1].content == "The answer is B." + assert result.messages[-1].reasoning_content == "Let me think step by step..." + + @pytest.mark.asyncio async def test_upload_batch_file(): """Test uploading a batch file.""" diff --git a/tests/unit/inference/test_sambanova_inference_engine.py b/tests/unit/inference/test_sambanova_inference_engine.py index cdaa00630d..f2f5001072 100644 --- a/tests/unit/inference/test_sambanova_inference_engine.py +++ b/tests/unit/inference/test_sambanova_inference_engine.py @@ -76,6 +76,52 @@ def test_convert_api_output_to_conversation(sambanova_engine): assert result.conversation_id == "test_id" +def test_convert_api_output_with_reasoning(sambanova_engine): + """Test that reasoning field is extracted into Message.reasoning_content.""" + original = Conversation( + messages=[Message(content="Hello", role=Role.USER)], + ) + api_response = { + "choices": [ + { + "message": { + "content": "The answer is B.", + "role": "assistant", + "reasoning": "Let me think step by step...", + } + } + ] + } + result = sambanova_engine._convert_api_output_to_conversation( + api_response, original + ) + assert result.messages[-1].content == "The answer is B." + assert result.messages[-1].reasoning_content == "Let me think step by step..." + + +def test_convert_api_output_null_content_with_reasoning(sambanova_engine): + """Test that null content defaults to empty string with reasoning preserved.""" + original = Conversation( + messages=[Message(content="Hello", role=Role.USER)], + ) + api_response = { + "choices": [ + { + "message": { + "content": None, + "role": "assistant", + "reasoning": "Still thinking...", + } + } + ] + } + result = sambanova_engine._convert_api_output_to_conversation( + api_response, original + ) + assert result.messages[-1].content == "" + assert result.messages[-1].reasoning_content == "Still thinking..." + + def test_convert_api_output_to_conversation_error_handling(sambanova_engine): """Test error handling in API output conversion.""" original_conversation = Conversation( diff --git a/tests/unit/utils/test_conversation_utils.py b/tests/unit/utils/test_conversation_utils.py index c36655dec3..e8e41956b2 100644 --- a/tests/unit/utils/test_conversation_utils.py +++ b/tests/unit/utils/test_conversation_utils.py @@ -18,6 +18,7 @@ create_list_of_message_json_dicts, load_image_bytes_to_content_item, load_pil_image_from_content_item, + remove_excessive_images, remove_excessive_images_from_conversation, truncate_text_in_content_items, ) @@ -802,3 +803,46 @@ def test_truncate_text_in_content_items( assert truncated_messages == expected_messages else: assert truncated_messages == messages + + +def test_remove_excessive_images_preserves_reasoning(): + """Test that reasoning is preserved when images are removed.""" + messages = [ + Message( + content=[ + ContentItem(type=Type.IMAGE_URL, content="http://img1.png"), + ContentItem(type=Type.IMAGE_URL, content="http://img2.png"), + ContentItem(type=Type.TEXT, content="describe these"), + ], + role=Role.USER, + reasoning_content="User-side reasoning", + ), + ] + result = remove_excessive_images(messages, max_images=1) + # Message gets reconstructed with filtered items; reasoning must survive + assert result[0].reasoning_content == "User-side reasoning" + + +def test_truncate_text_preserves_reasoning(gpt2_tokenizer): + """Test that reasoning is preserved when text is truncated.""" + messages = [ + Message( + content="This is a very long message that should get truncated", + role=Role.ASSISTANT, + reasoning_content="My reasoning process", + ), + ] + result = truncate_text_in_content_items(messages, gpt2_tokenizer, max_tokens=3) + assert result[0].content != messages[0].content + assert result[0].reasoning_content == "My reasoning process" + + +def test_message_reasoning_content_serialization(): + """Test that reasoning_content is included in model_dump when set.""" + m = Message(content="answer", role=Role.ASSISTANT, reasoning_content="thinking") + d = m.model_dump(exclude_none=True) + assert d["reasoning_content"] == "thinking" + + m_none = Message(content="answer", role=Role.ASSISTANT) + d_none = m_none.model_dump(exclude_none=True) + assert "reasoning_content" not in d_none