Skip to content
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions src/oumi/builders/collators.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
from oumi.core.collators.text_completions_collator_with_padding import (
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like it inherits padding support from the base class, should we add "WithPadding" to its name? Or are we moving away from that naming?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to clarify: this class actually uses composition rather than inheritance (it wraps a transformers DataCollatorForCompletionOnlyLM instance internally). The padding behavior comes from that inner collator's chain, so the "WithPadding" in the name describes what the collator does, not something it inherits. That said, I agree the naming could be better. @oelachqar actually suggested refactoring all collators into a single universal one with options (padding=True, completion_only=True). I'm exploring that now to see whether it can be done in this PR.

TextCompletionsCollatorWithPadding,
)
from oumi.core.collators.tool_aware_completions_collator import (
ToolAwareCompletionsCollator,
)
from oumi.core.collators.vision_language_collator_with_padding import (
VisionLanguageCollatorWithPadding,
)
Expand Down Expand Up @@ -52,6 +55,8 @@ def build_data_collator(
- "text_with_padding": Uses `TextCollatorWithPadding`.
- "text_completions_only_with_padding": Uses
`TextCompletionsCollatorWithPadding`.
- "tool_aware_completions_only": Uses `ToolAwareCompletionsCollator`.
Correctly masks tool results in tool-calling conversations.
- "vision_language_with_padding": Uses `VisionLanguageCollatorWithPadding`.
- "vision_language_sft": Uses `VisionLanguageSftCollator`.

Expand Down Expand Up @@ -149,6 +154,31 @@ def build_data_collator(
debug=debug,
**kwargs,
)
elif collator_name == "tool_aware_completions_only":
response_template = kwargs.pop("response_template", None)
end_of_turn_template = kwargs.pop("end_of_turn_template", None)
mask_tool_calls = kwargs.pop("mask_tool_calls", False)
tool_call_start_template = kwargs.pop("tool_call_start_template", None)

if not response_template:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I remember correctly we made a collator that can identify this itself?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Based on my understanding and research, there isn't a collator in the codebase that auto-identifies the response template. They all need it provided explicitly or fall back to Llama defaults.

# Default to Llama-style templates if not provided

raise ValueError(
f"'response_template' is required for '{collator_name}'. "
"Set it in collator_kwargs (e.g. '<|im_start|>assistant\\n')."
)
if not end_of_turn_template:
raise ValueError(
f"'end_of_turn_template' is required for '{collator_name}'. "
"Set it in collator_kwargs (e.g. '<|im_end|>')."
)

return ToolAwareCompletionsCollator(
response_template=response_template,
end_of_turn_template=end_of_turn_template,
mask_tool_calls=mask_tool_calls,
tool_call_start_template=tool_call_start_template,
tokenizer=tokenizer,
**kwargs,
)
Comment thread
shanghongsim marked this conversation as resolved.
Outdated
raise ValueError(f"Unknown data collator name: '{collator_name}'")


Expand Down
210 changes: 210 additions & 0 deletions src/oumi/core/collators/tool_aware_completions_collator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import Any

from transformers.data.data_collator import DataCollatorForLanguageModeling


class ToolAwareCompletionsCollator(DataCollatorForLanguageModeling):
    r"""Completion-only collator that correctly masks tool results and tool calls.

    The standard ``DataCollatorForCompletionOnlyLM`` uses template matching to
    decide which tokens to train on. It works well for simple user/assistant
    conversations, but breaks when a third role (``tool``) sits between two
    assistant turns: the tool-result span ends up unmasked because the collator
    only knows about the user-role marker as an "instruction" boundary.

    This collator takes a different approach:

    1. Start with **all labels masked** (-100).
    2. Find every assistant response span by locating ``response_template``
       tokens and scanning forward to the next ``end_of_turn_template``.
    3. **Unmask** those spans so the model trains on them.
    4. Optionally **re-mask** spans that contain tool-call content (controlled
       by ``mask_tool_calls``), so the model only trains on plain-text replies.

    Because the algorithm never relies on user/instruction markers, it handles
    any number of tool turns, parallel tool calls, and multi-turn conversations
    correctly.

    Args:
        response_template: String or token-ID list that marks the *start* of an
            assistant response (e.g. ``"<|im_start|>assistant\n"`` for SmolLM2
            or ``"[/INST]"`` for Llama-2).
        end_of_turn_template: String or token-ID list that marks the *end* of a
            turn (e.g. ``"<|im_end|>"`` for SmolLM2 or ``"</s>"`` for Llama).
        mask_tool_calls: When ``True``, assistant spans that contain
            ``tool_call_start_template`` are re-masked. Set this to ``True``
            if you only want to train on plain-text final answers. Defaults to
            ``False`` (train on all assistant output including tool calls).
        tool_call_start_template: String or token-ID list that marks the start
            of a tool-call block inside an assistant turn (e.g.
            ``"<tool_call>"``). Required when ``mask_tool_calls=True``.
        ignore_index: Value used for masked labels. Must match the
            ``ignore_index`` of the loss function (default: -100).
    """

    def __init__(
        self,
        response_template: str | list[int],
        end_of_turn_template: str | list[int],
        *args,
        mask_tool_calls: bool = False,
        tool_call_start_template: str | list[int] | None = None,
        ignore_index: int = -100,
        mlm: bool = False,
        **kwargs,
    ):
        """Initializes ToolAwareCompletionsCollator.

        String templates are tokenized once here (``add_special_tokens=False``)
        so that ``torch_call`` only ever compares token-ID lists.

        Raises:
            ValueError: If ``mask_tool_calls=True`` but no
                ``tool_call_start_template`` was provided.
        """
        # mlm=False so the base class produces causal-LM labels
        # (labels == input_ids) which we then re-mask below.
        super().__init__(*args, mlm=mlm, **kwargs)
        self.ignore_index = ignore_index

        if isinstance(response_template, str):
            self.response_token_ids: list[int] = self.tokenizer.encode(
                response_template, add_special_tokens=False
            )
        else:
            self.response_token_ids = list(response_template)

        if isinstance(end_of_turn_template, str):
            self.end_of_turn_token_ids: list[int] = self.tokenizer.encode(
                end_of_turn_template, add_special_tokens=False
            )
        else:
            self.end_of_turn_token_ids = list(end_of_turn_template)

        self.mask_tool_calls = mask_tool_calls
        self.tool_call_start_token_ids: list[int] | None = None
        if mask_tool_calls:
            if tool_call_start_template is None:
                raise ValueError(
                    "tool_call_start_template must be provided "
                    "when mask_tool_calls=True"
                )
            if isinstance(tool_call_start_template, str):
                self.tool_call_start_token_ids = self.tokenizer.encode(
                    tool_call_start_template, add_special_tokens=False
                )
            else:
                self.tool_call_start_token_ids = list(tool_call_start_template)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _find_pattern(self, seq: list[int], pattern: list[int]) -> list[int]:
        """Return all start positions where *pattern* appears in *seq*."""
        plen = len(pattern)
        if plen == 0:
            return []
        first = pattern[0]
        positions = []
        for i in range(len(seq) - plen + 1):
            # Cheap first-token check before the full slice comparison.
            if seq[i] == first and seq[i : i + plen] == pattern:
                positions.append(i)
        return positions

    def _find_first(
        self, seq: list[int], pattern: list[int], start: int, end: int
    ) -> int:
        """Return the first index in ``seq[start:end)`` where *pattern* begins.

        Returns -1 when the pattern does not occur (or is empty). Early-exits
        on the first match, unlike :meth:`_find_pattern`.
        """
        plen = len(pattern)
        if plen == 0:
            return -1
        first = pattern[0]
        for i in range(start, end - plen + 1):
            if seq[i] == first and seq[i : i + plen] == pattern:
                return i
        return -1

    def _span_contains(
        self, seq: list[int], span_start: int, span_end: int, pattern: list[int]
    ) -> bool:
        """Return True if *pattern* appears anywhere in seq[span_start:span_end]."""
        return self._find_first(seq, pattern, span_start, span_end) != -1

    # ------------------------------------------------------------------
    # Main collation
    # ------------------------------------------------------------------

    def torch_call(
        self, examples: list[list[int] | Any | dict[str, Any]]
    ) -> dict[str, Any]:
        """Collates examples and applies tool-aware label masking."""
        # Let the base class handle padding and create labels = input_ids.
        batch = super().torch_call(examples)

        resp_len = len(self.response_token_ids)
        pad_token_id = self.tokenizer.pad_token_id

        for i in range(len(examples)):
            # Step 1: mask everything.
            batch["labels"][i, :] = self.ignore_index

            seq: list[int] = batch["input_ids"][i].tolist()

            # Compute the effective sequence length excluding trailing padding.
            # This prevents false matches when end_of_turn_token_ids overlaps
            # with the pad token (common: e.g. <|im_end|> = eos = pad).
            # NOTE(review): this assumes right-padding; with a left-padded
            # tokenizer the pad run is at the front and survives — confirm
            # padding side against the training pipeline.
            if pad_token_id is not None:
                n = len(seq)
                while n > 0 and seq[n - 1] == pad_token_id:
                    n -= 1
            else:
                n = len(seq)

            # Step 2: find every assistant response start position
            # (within the non-padded region only).
            resp_positions = self._find_pattern(seq[:n], self.response_token_ids)

            if len(resp_positions) == 0:
                warnings.warn(
                    f"Could not find response template in the following instance: "
                    f"{self.tokenizer.decode(batch['input_ids'][i])}. "
                    "This instance will be ignored in loss calculation.",
                    UserWarning,
                    # Point the warning at the collator's caller, not here.
                    stacklevel=2,
                )
                continue

            for resp_pos in resp_positions:
                # Content starts right after the response_template tokens.
                content_start = resp_pos + resp_len

                # Step 3: find the next end_of_turn after content_start
                # (within the non-padded region only). Only the first match
                # matters, so stop scanning as soon as one is found.
                eot_pos = self._find_first(
                    seq, self.end_of_turn_token_ids, content_start, n
                )
                # NOTE(review): content_end excludes the end-of-turn tokens,
                # so the model is never trained to emit the end-of-turn/EOS
                # marker — confirm this is intentional.
                if eot_pos != -1:
                    content_end = eot_pos
                else:
                    # No closing marker found — unmask to end of real content.
                    content_end = n

                if content_start >= content_end:
                    continue

                # Step 4: optionally skip tool-call spans.
                if self.mask_tool_calls and self.tool_call_start_token_ids is not None:
                    if self._span_contains(
                        seq,
                        content_start,
                        content_end,
                        self.tool_call_start_token_ids,
                    ):
                        # Leave this span masked.
                        continue

                # Step 5: unmask this assistant response span.
                batch["labels"][i, content_start:content_end] = batch["input_ids"][
                    i, content_start:content_end
                ]

        return batch
8 changes: 8 additions & 0 deletions src/oumi/core/configs/params/data_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,16 @@ class DatasetSplitParams(BaseParams):

- "text_with_padding": Dynamically pads the inputs received to
the longest length.
- "text_completions_only_with_padding": Uses template matching to
mask non-assistant tokens. Works for simple user/assistant turns.
- "tool_aware_completions_only": Correctly masks tool results in
tool-calling conversations. Requires ``response_template`` and
``end_of_turn_template`` in ``collator_kwargs``. Optionally set
``mask_tool_calls=True`` and ``tool_call_start_template`` to also
mask assistant tool-call turns.
- "vision_language_with_padding": Uses VisionLanguageCollator
for image+text multi-modal data.
- "vision_language_sft": Uses VisionLanguageSftCollator.

If None, then a default collator will be assigned.
"""
Expand Down
Loading
Loading