Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 23 additions & 7 deletions src/oumi/builders/collators.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ def build_data_collator(

- "text_with_padding": Uses `TextCollatorWithPadding`.
- "text_completions_only_with_padding": Uses
`TextCompletionsCollatorWithPadding`.
`TextCompletionsCollatorWithPadding`. Supports optional
``end_of_turn_template`` for tool-aware span-based masking.
- "vision_language_with_padding": Uses `VisionLanguageCollatorWithPadding`.
- "vision_language_sft": Uses `VisionLanguageSftCollator`.

Expand Down Expand Up @@ -129,13 +130,22 @@ def build_data_collator(
# Extract instruction and response templates from kwargs if provided
instruction_template = kwargs.pop("instruction_template", None)
response_template = kwargs.pop("response_template", None)
end_of_turn_template = kwargs.pop("end_of_turn_template", None)
mask_tool_calls = kwargs.pop("mask_tool_calls", False)
tool_call_start_template = kwargs.pop("tool_call_start_template", None)

# Only default to Llama-style instruction template when NOT using
# span-based masking (end_of_turn_template makes instruction_prefix
# unnecessary since masking is handled by response/eot spans).
if end_of_turn_template is None:
instruction_prefix = (
instruction_template
if instruction_template
else "<|start_header_id|>user<|end_header_id|>\n\n"
)
else:
instruction_prefix = instruction_template # may be None, that's fine

# Default to Llama-style templates if not provided
instruction_prefix = (
instruction_template
if instruction_template
else "<|start_header_id|>user<|end_header_id|>\n\n"
)
response_prefix = (
response_template
if response_template
Expand All @@ -147,6 +157,12 @@ def build_data_collator(
instruction_prefix=instruction_prefix,
response_prefix=response_prefix,
debug=debug,
end_of_turn_template=end_of_turn_template,
mask_tool_calls=mask_tool_calls,
tool_call_start_template=tool_call_start_template,
ignore_index=(
label_ignore_index if label_ignore_index is not None else -100
),
**kwargs,
)
raise ValueError(f"Unknown data collator name: '{collator_name}'")
Expand Down
22 changes: 19 additions & 3 deletions src/oumi/core/collators/text_completions_collator_with_padding.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,22 +27,38 @@ class TextCompletionsCollatorWithPadding:
def __init__(
self,
tokenizer: BaseTokenizer,
instruction_prefix: str,
response_prefix: str,
response_prefix: str | list[int],
instruction_prefix: str | list[int] | None = None,
debug: bool = False,
end_of_turn_template: str | list[int] | None = None,
mask_tool_calls: bool = False,
tool_call_start_template: str | list[int] | None = None,
ignore_index: int = -100,
):
"""Custom collator for text LLM training.

Args:
tokenizer: The tokenizer used for encoding the data.
instruction_prefix: The prefix marking the beginning of the user instruction.
response_prefix: The prefix marking the beginning of the assistant response.
instruction_prefix: The prefix marking the beginning of the user instruction.
Optional when using span-based masking via end_of_turn_template.
debug: If True, enables debug mode for logging.
end_of_turn_template: String or token-ID list marking end of turn.
When provided, enables span-based masking for tool-aware conversations.
mask_tool_calls: When True, re-masks assistant spans containing tool calls.
tool_call_start_template: String or token-ID list marking tool-call start.
Required when mask_tool_calls=True.
ignore_index: Value used for masked labels. Must match the ignore_index
of the loss function (default: -100).
"""
self._default_collator = DataCollatorForCompletionOnlyLM(
tokenizer=tokenizer,
instruction_template=instruction_prefix,
response_template=response_prefix,
end_of_turn_template=end_of_turn_template,
mask_tool_calls=mask_tool_calls,
tool_call_start_template=tool_call_start_template,
ignore_index=ignore_index,
)

if not hasattr(tokenizer, "pad_token_id") or tokenizer.pad_token_id is None:
Expand Down
134 changes: 133 additions & 1 deletion src/oumi/core/collators/trl_data_collator_for_completion_only_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ def __init__(
response_template: str | list[int],
instruction_template: str | list[int] | None = None,
*args,
end_of_turn_template: str | list[int] | None = None,
mask_tool_calls: bool = False,
tool_call_start_template: str | list[int] | None = None,
mlm: bool = False,
ignore_index: int = -100,
padding_free: bool = False,
Expand Down Expand Up @@ -60,6 +63,32 @@ def __init__(
# The user already provides the token ids
self.response_token_ids = response_template

# Tool-aware span-based masking parameters
self.end_of_turn_template = end_of_turn_template
if isinstance(end_of_turn_template, str):
self.end_of_turn_token_ids: list[int] | None = self.tokenizer.encode(
end_of_turn_template, add_special_tokens=False
)
elif end_of_turn_template is not None:
self.end_of_turn_token_ids = list(end_of_turn_template)
else:
self.end_of_turn_token_ids = None

self.mask_tool_calls = mask_tool_calls
self.tool_call_start_token_ids: list[int] | None = None
if mask_tool_calls:
if tool_call_start_template is None:
raise ValueError(
"tool_call_start_template must be provided "
"when mask_tool_calls=True"
)
if isinstance(tool_call_start_template, str):
self.tool_call_start_token_ids = self.tokenizer.encode(
tool_call_start_template, add_special_tokens=False
)
else:
self.tool_call_start_token_ids = list(tool_call_start_template)

if (
not self.mlm
and self.instruction_template
Expand All @@ -78,13 +107,116 @@ def __init__(
self.ignore_index = ignore_index
self.padding_free = padding_free

@staticmethod
def _find_pattern(seq: list[int], pattern: list[int]) -> list[int]:
"""Return all start positions where *pattern* appears in *seq*."""
plen = len(pattern)
if plen == 0:
return []
first = pattern[0]
positions = []
for i in range(len(seq) - plen + 1):
if seq[i] == first and seq[i : i + plen] == pattern:
positions.append(i)
return positions

@staticmethod
def _span_contains(
seq: list[int], span_start: int, span_end: int, pattern: list[int]
) -> bool:
"""Return True if *pattern* appears anywhere in seq[span_start:span_end]."""
plen = len(pattern)
if plen == 0:
return False
first = pattern[0]
for i in range(span_start, span_end - plen + 1):
if seq[i] == first and seq[i : i + plen] == pattern:
return True
return False

def _apply_span_masking(
    self, batch: dict[str, Any], examples: list[list[int] | Any | dict[str, Any]]
) -> None:
    """Apply span-based label masking for tool-aware conversations.

    This masks everything, then selectively unmasks assistant response
    spans delimited by response_template and end_of_turn_template.

    Note: the end-of-turn tokens themselves are NOT unmasked — each span
    ends where the end-of-turn pattern begins, so EOT tokens stay at
    ``ignore_index`` (matching TRL's completion-only collator behavior).

    Args:
        batch: Collated batch; ``batch["labels"]`` is modified in place.
        examples: The pre-collation examples; only their count is used here.
    """
    resp_ids = self.response_token_ids
    eot_ids = self.end_of_turn_token_ids
    assert eot_ids is not None  # Caller checks end_of_turn_template is not None
    resp_len = len(resp_ids)
    pad_token_id = self.tokenizer.pad_token_id

    for i in range(len(examples)):
        # Step 1: mask everything.
        batch["labels"][i, :] = self.ignore_index

        seq: list[int] = batch["input_ids"][i].tolist()

        # Compute effective sequence length excluding trailing padding.
        # Prevents false matches when end_of_turn_token_ids overlaps
        # with the pad token (common: e.g. <|im_end|> = eos = pad).
        if pad_token_id is not None:
            n = len(seq)
            while n > 0 and seq[n - 1] == pad_token_id:
                n -= 1
        else:
            n = len(seq)

        # Step 2: find every assistant response start position.
        resp_positions = self._find_pattern(seq[:n], resp_ids)

        if len(resp_positions) == 0:
            # No response template found: the whole example stays fully
            # masked and contributes nothing to the loss.
            warnings.warn(
                f"Could not find response template in the following instance: "
                f"{self.tokenizer.decode(batch['input_ids'][i])}. "
                "This instance will be ignored in loss calculation.",
                UserWarning,
            )
            continue

        for resp_pos in resp_positions:
            content_start = resp_pos + resp_len

            # Step 3: find the next end_of_turn after content_start.
            # If none exists (e.g. the final turn was truncated), the
            # span extends to the end of the unpadded sequence.
            eot_positions = self._find_pattern(seq[content_start:n], eot_ids)
            if eot_positions:
                content_end = content_start + eot_positions[0]
            else:
                content_end = n

            # Empty span (EOT immediately follows the response prefix):
            # nothing to unmask.
            if content_start >= content_end:
                continue

            # Step 4: optionally skip tool-call spans. Spans containing
            # the tool-call start pattern remain fully masked.
            if self.mask_tool_calls and self.tool_call_start_token_ids is not None:
                if self._span_contains(
                    seq,
                    content_start,
                    content_end,
                    self.tool_call_start_token_ids,
                ):
                    continue

            # Step 5: unmask this assistant response span.
            batch["labels"][i, content_start:content_end] = batch["input_ids"][
                i, content_start:content_end
            ]
Comment on lines +195 to +205
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: The _apply_span_masking method incorrectly excludes EOT tokens from training labels, preventing the model from learning to emit stopping signals.
Severity: HIGH

Suggested Fix

To include the EOT tokens in the training labels, the content_end calculation in _apply_span_masking should be adjusted. It should be set to content_start + eot_positions[0] + len(eot_ids) to ensure the slice includes the full EOT token sequence in the unmasked labels.

Prompt for AI Agent
Review the code at the location below. A potential bug has been identified by an AI
agent.
Verify if this is a real issue. If it is, propose a fix; if not, explain why it's not
valid.

Location: src/oumi/core/collators/trl_data_collator_for_completion_only_lm.py#L184-L205

Potential issue: When using `end_of_turn_template` for span-based masking, the
`_apply_span_masking` method incorrectly excludes end-of-turn (EOT) tokens from the
training labels. The slice `batch["labels"][i, content_start:content_end]` is calculated
such that `content_end` marks the beginning of the EOT token sequence, effectively
masking it. Standard supervised fine-tuning requires including these tokens in the loss
so the model learns to generate them as stopping signals. Models trained with this
collator will not learn when to stop generating, leading to uncontrolled output during
inference. This behavior appears to be an unintended side effect of a change focused on
masking tool responses.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should not be an issue. TRL's collator also masks out the end-of-turn (EOT) tokens.


# ------------------------------------------------------------------
# Main collation
# ------------------------------------------------------------------

def torch_call(
self, examples: list[list[int] | Any | dict[str, Any]]
) -> dict[str, Any]:
"""Collates a list of examples into a batch."""
batch = super().torch_call(examples)

if self.instruction_template is None:
if self.end_of_turn_template is not None:
self._apply_span_masking(batch, examples)
elif self.instruction_template is None:
for i in range(len(examples)):
response_token_ids_start_idx = None

Expand Down
6 changes: 6 additions & 0 deletions src/oumi/core/configs/params/data_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,14 @@ class DatasetSplitParams(BaseParams):

- "text_with_padding": Dynamically pads the inputs received to
the longest length.
- "text_completions_only_with_padding": Uses template matching to
mask non-assistant tokens. Works for simple user/assistant turns.
Supports optional ``end_of_turn_template`` in ``collator_kwargs``
for tool-aware span-based masking. When set, also supports
``mask_tool_calls=True`` and ``tool_call_start_template``.
- "vision_language_with_padding": Uses VisionLanguageCollator
for image+text multi-modal data.
- "vision_language_sft": Uses VisionLanguageSftCollator.

If None, then a default collator will be assigned.
"""
Expand Down
44 changes: 44 additions & 0 deletions tests/unit/builders/test_collators.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,50 @@ def test_build_collator_from_config_with_collator(
assert callable(collator)


def test_build_data_collator_text_completions_with_tool_kwargs(mock_tokenizer):
    """Builder forwards tool-aware kwargs to the completions-only collator."""
    collator_name = "text_completions_only_with_padding"
    resp = "<|im_start|>assistant\n"
    eot = "<|im_end|>"

    def make_collator(**extra):
        # All variants share the same base arguments.
        return build_data_collator(
            collator_name,
            mock_tokenizer,
            max_length=None,
            response_template=resp,
            end_of_turn_template=eot,
            **extra,
        )

    # Basic build with end_of_turn_template.
    collator = make_collator()
    assert collator is not None
    assert callable(collator)

    # Default label_ignore_index is forwarded to the inner TRL collator.
    assert collator._default_collator.ignore_index == constants.LABEL_IGNORE_INDEX

    # Custom label_ignore_index is forwarded.
    collator_custom = make_collator(label_ignore_index=-200)
    assert collator_custom._default_collator.ignore_index == -200

    # Tool-call masking kwargs are accepted.
    collator_tc = make_collator(
        mask_tool_calls=True,
        tool_call_start_template="<tool_call>",
    )
    assert collator_tc is not None
    assert callable(collator_tc)


def test_build_collator_from_config_no_collator(mock_tokenizer):
training_config = TrainingConfig(
data=DataParams(
Expand Down
Loading
Loading