Skip to content
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 24 additions & 14 deletions src/oumi/builders/collators.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ def build_data_collator(

- "text_with_padding": Uses `TextCollatorWithPadding`.
- "text_completions_only_with_padding": Uses
`TextCompletionsCollatorWithPadding`.
`TextCompletionsCollatorWithPadding`. Supports optional
``end_of_turn_template`` for tool-aware span-based masking.
- "vision_language_with_padding": Uses `VisionLanguageCollatorWithPadding`.
- "vision_language_sft": Uses `VisionLanguageSftCollator`.

Expand Down Expand Up @@ -129,24 +130,33 @@ def build_data_collator(
# Extract instruction and response templates from kwargs if provided
instruction_template = kwargs.pop("instruction_template", None)
response_template = kwargs.pop("response_template", None)
end_of_turn_template = kwargs.pop("end_of_turn_template", None)
masking_method = kwargs.pop("masking_method", None)
tool_call_start_template = kwargs.pop("tool_call_start_template", None)

# Only default to Llama-style instruction template when NOT using
# span-based masking (end_of_turn_template makes it unnecessary).
if end_of_turn_template is None:
instruction_template = (
instruction_template
if instruction_template
else "<|start_header_id|>user<|end_header_id|>\n\n"
)
Comment thread
shanghongsim marked this conversation as resolved.
Outdated
Comment on lines +130 to +134

This comment was marked as outdated.


# Default to Llama-style templates if not provided
instruction_prefix = (
instruction_template
if instruction_template
else "<|start_header_id|>user<|end_header_id|>\n\n"
)
response_prefix = (
response_template
if response_template
else "<|start_header_id|>assistant<|end_header_id|>\n\n"
)
if not response_template:
response_template = "<|start_header_id|>assistant<|end_header_id|>\n\n"

return TextCompletionsCollatorWithPadding(
tokenizer=tokenizer,
instruction_prefix=instruction_prefix,
response_prefix=response_prefix,
response_template=response_template,
instruction_template=instruction_template,
debug=debug,
masking_method=masking_method,
end_of_turn_template=end_of_turn_template,
tool_call_start_template=tool_call_start_template,
ignore_index=(
label_ignore_index if label_ignore_index is not None else -100
),
**kwargs,
)
raise ValueError(f"Unknown data collator name: '{collator_name}'")
Comment on lines +139 to 144
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Passing ignore_index in collator_kwargs will cause a TypeError due to receiving multiple values for the same keyword argument in the TextCompletionsCollatorWithPadding constructor.
Severity: MEDIUM

Suggested Fix

Before calling the TextCompletionsCollatorWithPadding constructor, remove the ignore_index key from the kwargs dictionary to prevent the argument collision. For example: kwargs.pop("ignore_index", None).

Prompt for AI Agent
Review the code at the location below. A potential bug has been identified by an AI
agent. Verify if this is a real issue. If it is, propose a fix; if not, explain why it's
not valid.

Location: src/oumi/builders/collators.py#L139-L144

Potential issue: The `build_data_collator` function calls the
`TextCompletionsCollatorWithPadding` constructor, passing `ignore_index` as an explicit
keyword argument while also unpacking `collator_kwargs` into the same call. If a user's
configuration includes `ignore_index` within `collator_kwargs`, the argument will be
passed twice. This will raise a `TypeError` because the constructor receives multiple
values for the `ignore_index` keyword argument. The code does not currently handle this
potential collision by removing the key from the `kwargs` dictionary before the call.

Expand Down
28 changes: 22 additions & 6 deletions src/oumi/core/collators/text_completions_collator_with_padding.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,22 +27,38 @@ class TextCompletionsCollatorWithPadding:
def __init__(
self,
tokenizer: BaseTokenizer,
instruction_prefix: str,
response_prefix: str,
response_template: str,
instruction_template: str | None = None,
debug: bool = False,
masking_method: str | None = None,
end_of_turn_template: str | None = None,
tool_call_start_template: str | None = None,
ignore_index: int = -100,
):
"""Custom collator for text LLM training.

Args:
tokenizer: The tokenizer used for encoding the data.
instruction_prefix: The prefix marking the beginning of the user instruction.
response_prefix: The prefix marking the beginning of the assistant response.
response_template: String marking assistant response start.
instruction_template: String marking user instruction start.
debug: If True, enables debug mode for logging.
masking_method: Masking strategy — ``"assistant_turn"``,
``"assistant_turn_no_tools"``, or ``"final_assistant_turn"``.
end_of_turn_template: String marking the end of a turn.
Required for ``assistant_turn`` and ``assistant_turn_no_tools``.
tool_call_start_template: String marking tool-call start.
Required for ``assistant_turn_no_tools``.
ignore_index: Value used for masked labels. Must match the ignore_index
of the loss function (default: -100).
"""
self._default_collator = DataCollatorForCompletionOnlyLM(
tokenizer=tokenizer,
instruction_template=instruction_prefix,
response_template=response_prefix,
instruction_template=instruction_template,
response_template=response_template,
masking_method=masking_method,
end_of_turn_template=end_of_turn_template,
tool_call_start_template=tool_call_start_template,
ignore_index=ignore_index,
)

if not hasattr(tokenizer, "pad_token_id") or tokenizer.pad_token_id is None:
Expand Down
234 changes: 215 additions & 19 deletions src/oumi/core/collators/trl_data_collator_for_completion_only_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,73 @@


class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
"""Data collator used for completion tasks.
"""Data collator for completion-only training.

Copied from `trl`'s `DataCollatorForCompletionOnlyLM` class.
Masks input labels so that the loss is only computed on specific
tokens (typically assistant responses), while ignoring other tokens
(system prompts, user messages, padding).

The ``masking_method`` parameter selects the masking strategy:

**``assistant_turn``**:
Span-based masking for multi-turn and tool-calling conversations.
    Masks everything, then unmasks each assistant response span bounded
by ``response_template`` .. ``end_of_turn_template`` (inclusive of EOT).
Correctly handles interleaved tool results and parallel tool calls.

**``assistant_turn_no_tools``**:
Same as ``assistant_turn``, but additionally re-masks assistant
turns that contain tool-call content. Requires
``tool_call_start_template``. Only natural-language responses
contribute to the loss.

**``final_assistant_turn``**:
Masks all tokens before the *last* ``response_template`` occurrence.
Only the final assistant response is trained on. Suitable for
single-turn completions.

Args:
response_template: String or token IDs marking the start of an
assistant response. Required for all modes.
instruction_template: String or token IDs marking the start of a
user instruction. Legacy — only used with the instruction+response
fallback path.
masking_method: One of ``"assistant_turn"``,
``"assistant_turn_no_tools"``, ``"final_assistant_turn"``.
When None, inferred from template presence for backward compat.
end_of_turn_template: String or token IDs marking the end of a
conversational turn. Required for ``assistant_turn`` and
``assistant_turn_no_tools`` modes.
tool_call_start_template: String or token IDs marking the start
of a tool-call block. Required for ``assistant_turn_no_tools``.
mlm: Whether to use masked language modeling. Default False.
ignore_index: Label value for masked tokens. Default -100.
padding_free: Remove padding and add position_ids. Default False.
"""

_KNOWN_MASKING_METHODS = {
"assistant_turn",
"assistant_turn_no_tools",
"final_assistant_turn",
"_legacy_instruction_response",
}

def _tokenize_template(self, template: str | list[int] | None) -> list[int] | None:
"""Encode a template string into token IDs, or pass through if already IDs."""
if template is None:
return None
if isinstance(template, str):
return self.tokenizer.encode(template, add_special_tokens=False)
return list(template)

def __init__(
self,
response_template: str | list[int],
instruction_template: str | list[int] | None = None,
*args,
masking_method: str | None = None,
end_of_turn_template: str | list[int] | None = None,
tool_call_start_template: str | list[int] | None = None,
mlm: bool = False,
ignore_index: int = -100,
padding_free: bool = False,
Expand All @@ -39,26 +96,51 @@ def __init__(
"""Initializes the DataCollatorForCompletionOnlyLM."""
super().__init__(*args, mlm=mlm, **kwargs)

# Tokenize templates.
self.instruction_template = instruction_template
if isinstance(instruction_template, str):
# The user provides a string, must tokenize
self.instruction_token_ids = self.tokenizer.encode(
self.instruction_template, # type: ignore
add_special_tokens=False,
)
self.instruction_token_ids = self._tokenize_template(instruction_template)
self.response_template = response_template
self.response_token_ids: list[int] = self._tokenize_template(response_template) # type: ignore[assignment]
self.end_of_turn_template = end_of_turn_template
self.end_of_turn_token_ids = self._tokenize_template(end_of_turn_template)

# Infer masking_method from template presence for backward compatibility.
if masking_method is not None:
if masking_method not in self._KNOWN_MASKING_METHODS:
valid_methods = sorted(
self._KNOWN_MASKING_METHODS - {"_legacy_instruction_response"}
)
raise ValueError(
f"Unknown masking_method='{masking_method}'. "
f"Must be one of: {valid_methods}"
)
self.masking_method = masking_method
elif end_of_turn_template is not None:
self.masking_method = "assistant_turn"
Comment thread
shanghongsim marked this conversation as resolved.
Outdated
elif instruction_template is None:
self.masking_method = "final_assistant_turn"
else:
# The user already provides the token ids
self.instruction_token_ids = instruction_template
self.masking_method = "_legacy_instruction_response"

self.response_template = response_template
if isinstance(response_template, str):
# The user provides a string, must tokenize
self.response_token_ids = self.tokenizer.encode(
self.response_template, add_special_tokens=False
# Validate required templates for each masking method.
if self.masking_method in ("assistant_turn", "assistant_turn_no_tools"):
if end_of_turn_template is None:
raise ValueError(
"end_of_turn_template must be provided "
f"when masking_method='{self.masking_method}'"
)

self.mask_tool_calls = self.masking_method == "assistant_turn_no_tools"
self.tool_call_start_token_ids: list[int] | None = None
if self.mask_tool_calls:
if tool_call_start_template is None:
raise ValueError(
"tool_call_start_template must be provided "
"when masking_method='assistant_turn_no_tools'"
)
self.tool_call_start_token_ids = self._tokenize_template(
tool_call_start_template
)
else:
# The user already provides the token ids
self.response_token_ids = response_template

if (
not self.mlm
Expand All @@ -78,13 +160,127 @@ def __init__(
self.ignore_index = ignore_index
self.padding_free = padding_free

@staticmethod
def _find_pattern(seq: list[int], pattern: list[int]) -> list[int]:
"""Return all start positions where *pattern* appears in *seq*."""
plen = len(pattern)
if plen == 0:
return []
first = pattern[0]
positions = []
for i in range(len(seq) - plen + 1):
if seq[i] == first and seq[i : i + plen] == pattern:
positions.append(i)
return positions

@staticmethod
def _span_contains(
seq: list[int], span_start: int, span_end: int, pattern: list[int]
) -> bool:
"""Return True if *pattern* appears anywhere in seq[span_start:span_end]."""
plen = len(pattern)
if plen == 0:
return False
first = pattern[0]
for i in range(span_start, span_end - plen + 1):
if seq[i] == first and seq[i : i + plen] == pattern:
return True
return False

    def _apply_span_masking(
        self, batch: dict[str, Any], examples: list[list[int] | Any | dict[str, Any]]
    ) -> None:
        """Apply span-based masking for tool-aware conversations.

        Masks all labels, then unmasks assistant response spans bounded by
        response_template and end_of_turn_template (inclusive — the EOT token
        is unmasked so the model learns to produce it). Optionally re-masks
        spans that contain tool-call content.

        Args:
            batch: Collated batch containing ``input_ids`` and ``labels``
                tensors; ``labels`` is modified in place.
                NOTE(review): assumes tensor-style slice assignment (torch),
                as produced by the parent collator's ``torch_call``.
            examples: The pre-collation examples; only its length is used
                here, to iterate over batch rows.
        """
        resp_ids = self.response_token_ids
        eot_ids = self.end_of_turn_token_ids
        assert eot_ids is not None  # Caller checks end_of_turn_template is not None
        resp_len = len(resp_ids)
        pad_token_id = self.tokenizer.pad_token_id

        for i in range(len(examples)):
            # Step 1: mask everything.
            batch["labels"][i, :] = self.ignore_index

            seq: list[int] = batch["input_ids"][i].tolist()

            # Compute effective sequence length excluding trailing padding.
            # Prevents false matches when end_of_turn_token_ids overlaps
            # with the pad token (common: e.g. <|im_end|> = eos = pad).
            if pad_token_id is not None:
                n = len(seq)
                while n > 0 and seq[n - 1] == pad_token_id:
                    n -= 1
            else:
                n = len(seq)

            # Step 2: find every assistant response start position.
            resp_positions = self._find_pattern(seq[:n], resp_ids)

            if len(resp_positions) == 0:
                # No assistant response at all — leave the row fully masked
                # so it contributes nothing to the loss.
                warnings.warn(
                    f"Could not find response template in the following instance: "
                    f"{self.tokenizer.decode(batch['input_ids'][i])}. "
                    "This instance will be ignored in loss calculation.",
                    UserWarning,
                )
                continue

            for resp_pos in resp_positions:
                content_start = resp_pos + resp_len

                # Step 3: find the next end_of_turn after content_start.
                # Only the first EOT bounds this span; later EOTs belong to
                # later turns.
                eot_positions = self._find_pattern(seq[content_start:n], eot_ids)
                if eot_positions:
                    content_end = content_start + eot_positions[0]
                else:
                    content_end = n

                if content_start >= content_end:
                    continue

                # Step 4: optionally skip tool-call spans.
                # In "assistant_turn_no_tools" mode a span containing a
                # tool-call marker stays fully masked.
                if self.mask_tool_calls and self.tool_call_start_token_ids is not None:
                    if self._span_contains(
                        seq,
                        content_start,
                        content_end,
                        self.tool_call_start_token_ids,
                    ):
                        continue

                # Step 5: unmask this assistant response span, including the
                # end-of-turn token so the model learns when to stop.
                if eot_positions:
                    eot_len = len(self.end_of_turn_token_ids)  # type: ignore
                    unmask_end = content_end + eot_len
                else:
                    # No EOT found — content_end == n (end of real content).
                    # Do NOT extend past n or we'd unmask into padding.
                    unmask_end = content_end
                batch["labels"][i, content_start:unmask_end] = batch["input_ids"][
                    i, content_start:unmask_end
                ]

# ------------------------------------------------------------------
# Main collation
# ------------------------------------------------------------------

def torch_call(
self, examples: list[list[int] | Any | dict[str, Any]]
) -> dict[str, Any]:
"""Collates a list of examples into a batch."""
batch = super().torch_call(examples)

if self.instruction_template is None:
if self.masking_method in ("assistant_turn", "assistant_turn_no_tools"):
self._apply_span_masking(batch, examples)
elif self.masking_method == "final_assistant_turn":
# Response-only: unmask only the final assistant response.
for i in range(len(examples)):
response_token_ids_start_idx = None

Expand Down
6 changes: 6 additions & 0 deletions src/oumi/core/configs/params/data_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,14 @@ class DatasetSplitParams(BaseParams):

- "text_with_padding": Dynamically pads the inputs received to
the longest length.
- "text_completions_only_with_padding": Uses template matching to
mask non-assistant tokens. Works for simple user/assistant turns.
Supports optional ``end_of_turn_template`` in ``collator_kwargs``
for tool-aware span-based masking. When set, also supports
``mask_tool_calls=True`` and ``tool_call_start_template``.
- "vision_language_with_padding": Uses VisionLanguageCollator
for image+text multi-modal data.
- "vision_language_sft": Uses VisionLanguageSftCollator.

Comment thread
shanghongsim marked this conversation as resolved.
If None, then a default collator will be assigned.
"""
Expand Down
Loading
Loading