Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 75 additions & 5 deletions src/oumi/core/processors/default_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@

import copy
from collections.abc import Callable
from functools import lru_cache
from pathlib import Path
from typing import Any

import jinja2
import PIL.Image
import transformers
from typing_extensions import override
Expand All @@ -29,6 +31,38 @@
from oumi.utils.logging import logger
from oumi.utils.str_utils import truncate_to_max_tokens_limit

_REASONING_FIELD_NAMES = frozenset({"reasoning", "reasoning_content"})
_jinja_env = jinja2.Environment()


@lru_cache(maxsize=8)
def _template_references_reasoning(template: str) -> bool:
    """Return True if a Jinja2 chat template accesses a reasoning field.

    The template source is parsed into a Jinja2 AST, and every attribute
    access (``message.reasoning_content``) and constant-key item access
    (``message['reasoning']``) node is inspected for one of the reasoning
    field names. Working on the AST — rather than a substring search —
    avoids false positives from comments or unrelated template text.

    Templates that fail to parse are treated as not referencing reasoning.
    Results are memoized (bounded LRU) since processors re-check the same
    template repeatedly.
    """
    from jinja2.nodes import Const, Getattr, Getitem

    try:
        tree = _jinja_env.parse(template)
    except jinja2.TemplateSyntaxError:
        return False

    # find_all accepts a tuple of node types and walks the whole tree.
    for node in tree.find_all((Getattr, Getitem)):
        if isinstance(node, Getattr):
            matched = node.attr in _REASONING_FIELD_NAMES
        else:
            key = node.arg
            matched = isinstance(key, Const) and key.value in _REASONING_FIELD_NAMES
        if matched:
            return True
    return False


class DefaultProcessor(BaseProcessor):
"""Default implementation of processor that wraps a worker processor.
Expand Down Expand Up @@ -226,12 +260,48 @@ def __call__(
)
return result

def _template_supports_reasoning(self) -> bool:
    """Return whether the chat template natively renders reasoning fields.

    Delegates to the module-level, LRU-cached AST inspection so that
    repeated calls for the same template string are cheap. AST-based
    detection finds real variable references to ``reasoning`` or
    ``reasoning_content`` on message objects and is not fooled by
    comments or unrelated text in the template source.
    """
    template = self.chat_template
    return _template_references_reasoning(template)

def _convert_messages_to_dicts(self, messages: list[Message]) -> list[dict]:
    """Converts Message objects to dict format for HuggingFace compatibility.

    When a message has ``reasoning_content``, the behavior depends on
    whether the chat template natively supports reasoning fields:

    - **Template supports reasoning** (e.g., Qwen3): both
      ``reasoning_content`` and ``reasoning`` keys are included so the
      template can render them natively (typically inside ``<think>``
      tags).
    - **Template does not support reasoning** (e.g., Llama, DeepSeek):
      reasoning is prepended to string ``content`` wrapped in ``<think>``
      tags so it appears in the tokenized sequence. Non-string content
      (e.g., a list of multimodal content items) is never folded — doing
      so would embed the list's Python repr into the content string —
      and falls back to the separate-key form instead.

    Args:
        messages: The conversation messages to convert.

    Returns:
        One dict per message, as produced by ``model_dump`` with
        ``None``/unset fields excluded, adjusted for reasoning as above.
    """
    template_supports = self._template_supports_reasoning()
    result = []
    for msg in messages:
        d = msg.model_dump(mode="json", exclude_none=True, exclude_unset=True)
        if msg.reasoning_content is not None:
            if template_supports:
                # Pass as separate keys — the template handles formatting.
                # ``reasoning_content`` is already present from model_dump;
                # mirror it under ``reasoning`` unless the dump set one.
                d.setdefault("reasoning", msg.reasoning_content)
            else:
                content = d.get("content", "")
                if isinstance(content, str):
                    # Fold into content — the template doesn't know about
                    # reasoning, so we prepend it with <think> tags.
                    d.pop("reasoning_content", None)
                    d["content"] = (
                        f"<think>\n{msg.reasoning_content}\n</think>\n\n{content}"
                    )
                else:
                    # Multimodal content: f-string folding would corrupt it.
                    # Keep reasoning as a separate key.
                    # NOTE(review): confirm desired handling of reasoning on
                    # multimodal messages with non-reasoning templates.
                    d.setdefault("reasoning", msg.reasoning_content)
        result.append(d)
    return result

@override
def apply_chat_template(
Expand Down
17 changes: 17 additions & 0 deletions src/oumi/core/types/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,23 @@ class Message(pydantic.BaseModel):
role: Role
"""The role of the entity sending the message (e.g., user, assistant, system)."""

reasoning_content: str | None = None
"""Optional reasoning/thinking content from the model.

Some providers (e.g., Together) return a separate ``reasoning`` or
``reasoning_content`` field for thinking models (Qwen3.x, Qwen3.5-x).
When present, ``content`` holds the visible output and
``reasoning_content`` holds the internal chain-of-thought. If the model
exhausts its token budget on reasoning, ``content`` may be empty while
``reasoning_content`` contains the partial thinking.

Named ``reasoning_content`` to match Qwen3's chat template convention.
The inference engines handle extracting from both ``reasoning`` and
``reasoning_content`` API response fields. When passed to chat
templates, both key names are included in the message dict so that
any template can find it.
"""

def model_post_init(self, __context) -> None:
"""Post-initialization method for the Message model.

Expand Down
19 changes: 18 additions & 1 deletion src/oumi/inference/anthropic_inference_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,26 @@ def _convert_api_output_to_conversation(
f"model={response.get('model')}, "
f"usage={response.get('usage')}"
)
# Thinking models include {"type": "thinking", "thinking": "..."}
# blocks alongside text blocks. Blocks without a "type" key are
# treated as text (backward compat with older response formats).
text_parts: list[str] = []
reasoning_parts: list[str] = []
for block in content_blocks:
if block.get("type") == "thinking":
thinking = block.get("thinking")
if thinking is not None and thinking != "":
reasoning_parts.append(thinking)
elif block.get("type") == "text" or "text" in block:
part = block.get("text")
if part is not None and part != "":
text_parts.append(part)
text = "".join(text_parts)
reasoning = "".join(reasoning_parts) if reasoning_parts else None
new_message = Message(
content=content_blocks[0]["text"],
content=text,
role=Role.ASSISTANT,
reasoning_content=reasoning,
)
metadata = dict(original_conversation.metadata)
usage = self._extract_usage_from_response(response)
Expand Down
19 changes: 15 additions & 4 deletions src/oumi/inference/bedrock_inference_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,12 +357,23 @@ def _extract_finish_reason_from_response(
def _convert_api_output_to_conversation(
self, response: dict[str, Any], original: Conversation
) -> Conversation:
text = ""
text_parts: list[str] = []
reasoning_parts: list[str] = []
msg = response.get("output", {}).get("message", {})
for block in msg.get("content", []):
if "text" in block:
text += block["text"]
new_message = Message(content=text, role=Role.ASSISTANT)
if block.get("type") == "thinking":
thinking = block.get("thinking")
if thinking is not None and thinking != "":
reasoning_parts.append(thinking)
elif "text" in block:
Comment thread
jgreer013 marked this conversation as resolved.
part = block.get("text")
if part is not None and part != "":
text_parts.append(part)
text = "".join(text_parts)
reasoning = "".join(reasoning_parts) if reasoning_parts else None
new_message = Message(
content=text, role=Role.ASSISTANT, reasoning_content=reasoning
)
metadata = dict(original.metadata)
finish_reason = self._extract_finish_reason_from_response(response)
if finish_reason is not None:
Expand Down
13 changes: 13 additions & 0 deletions src/oumi/inference/remote_inference_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,12 +487,25 @@ def _convert_api_output_to_conversation(
finish_reason = self._extract_finish_reason_from_response(response)
if finish_reason is not None:
metadata["finish_reason"] = finish_reason.value
# Some providers (e.g., Together) return content=null for thinking
# models when the token budget is consumed by reasoning. Default to
# empty string so Message validation doesn't fail.
content = message.get("content")
if content is None:
content = ""
# Providers use "reasoning" (Together, Kimi, DeepSeek V3.1) or
# "reasoning_content" (GLM-5) for the thinking chain.
reasoning = message.get("reasoning")
if reasoning is None:
reasoning = message.get("reasoning_content")

return Conversation(
messages=[
*original_conversation.messages,
Message(
content=content,
role=Role(message["role"]),
reasoning_content=reasoning,
),
],
metadata=metadata,
Expand Down
9 changes: 8 additions & 1 deletion src/oumi/inference/sambanova_inference_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,16 @@ def _convert_api_output_to_conversation(
if not message:
raise RuntimeError("No message found in API response")

content = message.get("content")
if content is None:
content = ""
reasoning = message.get("reasoning")
if reasoning is None:
reasoning = message.get("reasoning_content")
new_message = Message(
content=message.get("content", ""),
content=content,
role=Role.ASSISTANT,
reasoning_content=reasoning,
)

return Conversation(
Expand Down
44 changes: 40 additions & 4 deletions src/oumi/utils/conversation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ def create_list_of_message_json_dicts(
messages: list[Message],
*,
group_adjacent_same_role_turns: bool,
include_reasoning: bool = False,
) -> list[dict[str, Any]]:
"""Returns a list of JSON dictionaries representing messages.

Expand All @@ -227,6 +228,11 @@ def create_list_of_message_json_dicts(
messages: The input messages.
group_adjacent_same_role_turns: Whether to pack adjacent messages
from the same role into a single element in output list.
include_reasoning: Whether to include ``reasoning_content`` and
``reasoning`` keys in the output dicts. Set to ``True`` when
building dicts for chat templates (training). Set to ``False``
(default) when building dicts for API requests, since some
providers (e.g., Anthropic) reject unknown fields.

Returns:
list[Dict[str, Any]]: The list of messages encoded as nested JSON dicts.
Expand All @@ -235,6 +241,7 @@ def create_list_of_message_json_dicts(
result = []
idx = 0
while idx < num_messages:
start_idx = idx
end_idx = idx + 1
if group_adjacent_same_role_turns:
while end_idx < num_messages and (
Expand All @@ -258,6 +265,21 @@ def create_list_of_message_json_dicts(
idx += 1
item["content"] = content_list

if include_reasoning:
# Collect reasoning from all messages in the group and
# concatenate. Include under both key names so that any chat
# template can find it (Qwen3 uses "reasoning_content", others
# may use "reasoning").
reasoning_parts: list[str] = [
rc
for i in range(start_idx, end_idx)
if (rc := messages[i].reasoning_content) is not None
]
if reasoning_parts:
reasoning = "".join(reasoning_parts)
item["reasoning_content"] = reasoning
item["reasoning"] = reasoning

idx = end_idx
result.append(item)

Expand Down Expand Up @@ -308,12 +330,20 @@ def remove_excessive_images(
if len(filtered_items) == 1 and isinstance(filtered_items[0].content, str):
result.append(
Message(
id=message.id, content=filtered_items[0].content, role=message.role
id=message.id,
content=filtered_items[0].content,
role=message.role,
reasoning_content=message.reasoning_content,
)
)
else:
result.append(
Message(id=message.id, content=filtered_items, role=message.role)
Message(
id=message.id,
content=filtered_items,
role=message.role,
reasoning_content=message.reasoning_content,
)
)

return result
Expand Down Expand Up @@ -425,11 +455,17 @@ def truncate_text_in_content_items(
):
assert isinstance(items[0].content, str)
result[msg_idx] = Message(
id=message.id, content=items[0].content, role=message.role
id=message.id,
content=items[0].content,
role=message.role,
reasoning_content=message.reasoning_content,
)
else:
result[msg_idx] = Message(
id=message.id, content=items, role=message.role
id=message.id,
content=items,
role=message.role,
reasoning_content=message.reasoning_content,
)

return result
Loading
Loading