Add tool-aware collator to mask tool response correctly #2356
shanghongsim wants to merge 5 commits into main from
Conversation
Adds a new collator that masks labels for non-tool-call assistant completions, enabling models to learn tool-calling behavior specifically.
Register ToolAwareCompletionsCollator in the collator builder and add the corresponding CollatorType enum. Update collator docstrings and tests.
Replace hardcoded IGNORE = -100 with the shared constant from oumi.core.constants to be consistent with sibling test files.
The builder was not forwarding label_ignore_index to the collator constructor, so it always used the default -100 instead of the configured value. Add the missing parameter and a builder test.
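The forwarding bug described above can be sketched with simplified stand-in names (not the actual oumi builder code):

```python
# Minimal sketch of the label_ignore_index forwarding fix; the class and
# builder names are simplified stand-ins for the real oumi code.
DEFAULT_IGNORE_INDEX = -100  # value the collator falls back to


class TextCompletionsCollatorWithPadding:
    def __init__(self, label_ignore_index: int = DEFAULT_IGNORE_INDEX):
        self.label_ignore_index = label_ignore_index


def build_collator(label_ignore_index: int = DEFAULT_IGNORE_INDEX):
    # Before the fix, label_ignore_index was not passed through here, so a
    # configured value was silently replaced by the default of -100.
    return TextCompletionsCollatorWithPadding(label_ignore_index=label_ignore_index)
```

With the parameter forwarded, `build_collator(label_ignore_index=-200)` produces a collator whose `label_ignore_index` is `-200` rather than the default.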
import oumi.core.constants as constants
from oumi.core.collators.text_collator_with_padding import TextCollatorWithPadding
from oumi.core.collators.text_completions_collator_with_padding import (
It seems like it inherits padding support from the base class, should we add "WithPadding" to its name? Or are we moving away from that naming?
Just to clarify: this class actually uses composition rather than inheritance (it wraps a TRL `DataCollatorForCompletionOnlyLM` instance internally), and the padding behavior comes from that inner collator's chain. So the "WithPadding" in the name describes what the collator does, not something it inherits. That said, I agree the naming could be better. @oelachqar suggested refactoring all collators into a universal one with options (`padding=True`, `completion_only=True`); I'm exploring now whether that can be done in this PR.
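The composition-over-inheritance point can be illustrated with toy stand-ins (the real classes live in oumi and TRL; these names and bodies are illustrative only):

```python
class InnerCompletionOnlyCollator:
    """Stands in for the wrapped collator that actually handles padding."""

    def __call__(self, features):
        # Pad every input_ids list to the longest sequence with 0s.
        max_len = max(len(f["input_ids"]) for f in features)
        return [
            {"input_ids": f["input_ids"] + [0] * (max_len - len(f["input_ids"]))}
            for f in features
        ]


class TextCompletionsCollatorWithPadding:
    """Composes (wraps) the inner collator instead of inheriting from it."""

    def __init__(self):
        self._inner = InnerCompletionOnlyCollator()

    def __call__(self, features):
        # Padding comes from the wrapped collator's chain, so "WithPadding"
        # describes behavior, not an inherited base class.
        return self._inner(features)
```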
mask_tool_calls = kwargs.pop("mask_tool_calls", False)
tool_call_start_template = kwargs.pop("tool_call_start_template", None)

if not response_template:
If I remember correctly, we made a collator that can identify this itself?
Based on my understanding and research, there isn't a collator in the codebase that auto-identifies the response template. They all need it provided explicitly or fall back to Llama defaults.
oumi/src/oumi/builders/collators.py, line 133 in 81611e4
Extend the existing collator with optional end_of_turn_template, mask_tool_calls, and tool_call_start_template parameters for span-based label masking in tool-calling conversations. Remove the separate ToolAwareCompletionsCollator class and forward label_ignore_index through TextCompletionsCollatorWithPadding.
seq,
content_start,
content_end,
self.tool_call_start_token_ids,
):
continue

# Step 5: unmask this assistant response span.
batch["labels"][i, content_start:content_end] = batch["input_ids"][
    i, content_start:content_end
]
Bug: The _apply_span_masking method incorrectly excludes EOT tokens from training labels, preventing the model from learning to emit stopping signals.
Severity: HIGH
Suggested Fix
To include the EOT tokens in the training labels, the content_end calculation in _apply_span_masking should be adjusted. It should be set to content_start + eot_positions[0] + len(eot_ids) to ensure the slice includes the full EOT token sequence in the unmasked labels.
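The proposed index arithmetic can be illustrated with toy values (variable names mirror the review comment, not the actual source file; note the maintainer disputes below whether the change is needed):

```python
# Toy values illustrating the reviewer's proposed span adjustment.
eot_ids = [7, 8]        # token ids of the end-of-turn (EOT) sequence
content_start = 10      # where the assistant span begins in the batch
eot_positions = [5]     # EOT match offsets, relative to content_start

# Behavior flagged by the review: the span ends where the EOT begins,
# so the EOT tokens stay masked out of the labels.
content_end_excluding_eot = content_start + eot_positions[0]

# Proposed fix: extend the span past the EOT so it is unmasked too.
content_end_including_eot = content_start + eot_positions[0] + len(eot_ids)
```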
Location: src/oumi/core/collators/trl_data_collator_for_completion_only_lm.py#L184-L205
Potential issue: When using `end_of_turn_template` for span-based masking, the
`_apply_span_masking` method incorrectly excludes end-of-turn (EOT) tokens from the
training labels. The slice `batch["labels"][i, content_start:content_end]` is calculated
such that `content_end` marks the beginning of the EOT token sequence, effectively
masking it. Standard supervised fine-tuning requires including these tokens in the loss
so the model learns to generate them as stopping signals. Models trained with this
collator will not learn when to stop generating, leading to uncontrolled output during
inference. This behavior appears to be an unintended side effect of a change focused on
masking tool responses.
Should not be an issue: TRL's collator also masks out end-of-turn (EOT) tokens.
Description
Currently, the completions-only collator does not mask tool responses (`Role.TOOL`) correctly. Without `instruction_template`, tool results are masked correctly, but with `instruction_template` things get weird and some tool results are unmasked. Omitting `instruction_template` causes the collator to find the last `response_template` and mask everything before it (only the final assistant turn trains). With `instruction_template`, the collator searches for `instruction_template`/`response_template` pairs and masks everything in between.

Case 1: without `instruction_template`

Without an instruction template, the collator masks everything before the last response template, so the loss only sees the final assistant turn.

Case 2: with `instruction_template`

With `instruction_template`, it masks everything between an instruction and a response template. Since there isn't an instruction template before the tool result, the tool result does not get masked properly. In case 2, the loss sees the tool result, which is incorrect.

To confirm my hypothesis that the missing instruction template before the tool result is the cause, I experimented with adding a user turn before the tool call, and masking works correctly in that case.
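The two masking strategies above can be sketched on a toy token sequence (token strings stand in for real token ids; this is a simplified illustration, not the actual TRL implementation):

```python
IGNORE = -100  # label value excluded from the loss


def mask_last_response(tokens, resp):
    """Case 1 (no instruction template): mask everything up to and
    including the LAST response template, so only the final assistant
    turn contributes to the loss."""
    labels = list(tokens)
    last = len(tokens) - 1 - tokens[::-1].index(resp)
    for i in range(last + 1):
        labels[i] = IGNORE
    return labels


def mask_between_pairs(tokens, inst, resp):
    """Case 2 (with instruction template): mask each span from an
    instruction marker through the next response marker. A tool result
    with no preceding instruction marker is left unmasked (the bug)."""
    labels = list(tokens)
    i = 0
    while i < len(tokens):
        if tokens[i] == inst:
            j = i
            while j < len(tokens) and tokens[j] != resp:
                labels[j] = IGNORE
                j += 1
            if j < len(tokens):
                labels[j] = IGNORE  # mask the response marker itself
            i = j + 1
        else:
            i += 1
    return labels


# Toy conversation: user turn, assistant turn, tool result, assistant turn.
seq = ["<inst>", "u1", "<resp>", "a1", "tool_out", "<resp>", "a2"]

mask_last_response(seq, "<resp>")             # tool_out masked; only a2 trains
mask_between_pairs(seq, "<inst>", "<resp>")   # tool_out stays unmasked: it leaks into the loss
```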
ToolAwareCompletionsCollator

With the new `ToolAwareCompletionsCollator`, tool results are masked properly.

Without `instruction_template`:



With `instruction_template`:



Related issues
Fixes # (issue)
Before submitting
Reviewers
At least one review from a member of `oumi-ai/oumi-staff` is required.