
Add tool aware collator to mask tool response correctly#2356

Closed
shanghongsim wants to merge 5 commits into main from shanghong/tool-aware-collator

Conversation

@shanghongsim
Contributor

@shanghongsim shanghongsim commented Apr 8, 2026

Description

Currently, the completions-only collator does not mask tool responses (Role.TOOL) correctly.

Without instruction_template, tool results happen to be masked correctly: the collator finds the last response_template and masks everything before it, so only the final assistant turn is trained. With instruction_template, the collator instead searches for instruction_template/response_template pairs and masks everything in between, and some tool results end up unmasked.

Case 1: without instruction_template

from trl import DataCollatorForCompletionOnlyLM

RESPONSE_TEMPLATE = "<|im_start|>assistant\n"

old_no_inst = DataCollatorForCompletionOnlyLM(
    response_template=RESPONSE_TEMPLATE,
    instruction_template=None,
    tokenizer=tokenizer,
)
b = old_no_inst.torch_call([token_ids])  # token_ids: a pre-tokenized conversation
labels_A = b["labels"][0].tolist()
summarise("Case A", labels_A, N)  # summarise/show_masking/N: local demo helpers
show_masking(b["input_ids"][0].tolist(), labels_A, tokenizer)
[screenshot: Case A masking output]

Without an instruction template, the collator masks everything before the last response template, so the loss only sees:

<|im_start|>assistant
The weather in Paris is sunny and 18°C.<|im_end|>
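The no-instruction_template behavior can be sketched in pure Python (a toy reimplementation with made-up token IDs, not TRL's actual code):

```python
IGNORE_INDEX = -100  # label value that the loss function skips


def find_all(seq, pattern):
    """Return the start indices of every occurrence of pattern in seq."""
    n = len(pattern)
    return [i for i in range(len(seq) - n + 1) if seq[i:i + n] == pattern]


def mask_before_last_response(input_ids, response_ids):
    """Mimic the no-instruction_template path: mask everything before the
    LAST response template, so only the final assistant turn is trained."""
    labels = [IGNORE_INDEX] * len(input_ids)
    starts = find_all(input_ids, response_ids)
    if starts:
        begin = starts[-1] + len(response_ids)
        labels[begin:] = input_ids[begin:]
    return labels


# Toy sequence: 9 stands for the assistant template, 5 for tool-result tokens.
labels = mask_before_last_response([8, 1, 9, 2, 5, 5, 9, 3], [9])
# Everything before the last 9 is ignored; only the final answer token trains.
```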

Case 2: with instruction_template

With instruction_template, the collator masks everything between an instruction template and a response template. Since there is no instruction template before the tool result, the tool result does not get masked properly.

INSTRUCTION_TEMPLATE = "<|im_start|>user\n"

old_with_inst = DataCollatorForCompletionOnlyLM(
    response_template=RESPONSE_TEMPLATE,
    instruction_template=INSTRUCTION_TEMPLATE,
    tokenizer=tokenizer,
)
b2 = old_with_inst.torch_call([token_ids])
labels_B = b2["labels"][0].tolist()
summarise("Case B", labels_B, N)
show_masking(b2["input_ids"][0].tolist(), labels_B, tokenizer)
[screenshot: Case B masking output]

In case 2, the loss sees the tool result, which is incorrect:

...
<|im_start|>assistant
<tool_call>
{"name": "get_weather", "arguments": {"location": "Paris"}}
</tool_call><|im_end|>
...
<|im_start|>assistant
The weather in Paris is sunny and 18°C.<|im_end|>
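The pairing logic that causes the leak can be sketched the same way (again a toy reimplementation, not TRL's code):

```python
IGNORE_INDEX = -100


def find_all(seq, pattern):
    """Return the start indices of every occurrence of pattern in seq."""
    n = len(pattern)
    return [i for i in range(len(seq) - n + 1) if seq[i:i + n] == pattern]


def mask_with_pairs(input_ids, instruction_ids, response_ids):
    """Mimic the instruction_template path: train from the end of each
    response template up to the next instruction template. A tool-result
    turn with no instruction template in front of it falls inside a
    trained span and leaks into the loss."""
    labels = [IGNORE_INDEX] * len(input_ids)
    inst_starts = find_all(input_ids, instruction_ids)
    for start in find_all(input_ids, response_ids):
        begin = start + len(response_ids)
        end = next((j for j in inst_starts if j > begin), len(input_ids))
        labels[begin:end] = input_ids[begin:end]
    return labels


# 8 = user template, 9 = assistant template, 5 = tool-result tokens.
labels = mask_with_pairs([8, 1, 9, 2, 5, 5, 9, 3], [8], [9])
# The tool-result tokens (5, 5) end up UNMASKED: there is no user turn
# between the assistant tool call and the tool result.
```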

To confirm my hypothesis that the missing instruction template before the tool result is the cause, I experimented with adding a user turn before the tool call; masking works correctly in that case.

[screenshot: masking with a user turn added before the tool call]

ToolAwareCompletionsCollator

With the new ToolAwareCompletionsCollator, tool results are masked properly.

Without instruction_template
[screenshot: tool results masked correctly without instruction_template]

With instruction_template
[screenshot: tool results masked correctly with instruction_template]
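The tool-aware fix can be sketched as span-based masking: unmask only spans that open with the assistant template and close at the next end-of-turn marker (a toy reimplementation; the real class name and parameters follow the PR description):

```python
IGNORE_INDEX = -100


def find_all(seq, pattern):
    """Return the start indices of every occurrence of pattern in seq."""
    n = len(pattern)
    return [i for i in range(len(seq) - n + 1) if seq[i:i + n] == pattern]


def mask_tool_aware(input_ids, response_ids, end_of_turn_ids):
    """Unmask ONLY assistant spans: each span starts right after a response
    template and ends at the next end-of-turn marker. Tool-result turns
    never start with the response template, so they always stay masked,
    with or without an instruction template."""
    labels = [IGNORE_INDEX] * len(input_ids)
    eot_starts = find_all(input_ids, end_of_turn_ids)
    for start in find_all(input_ids, response_ids):
        begin = start + len(response_ids)
        end = next((j for j in eot_starts if j >= begin), len(input_ids))
        labels[begin:end] = input_ids[begin:end]
    return labels


# 8 = user, 9 = assistant, 7 = end-of-turn, 5 = tool-result tokens.
labels = mask_tool_aware([8, 1, 7, 9, 2, 7, 5, 5, 7, 9, 3, 7], [9], [7])
# Only the assistant content tokens (2 and 3) are trained; the tool-result
# tokens (5, 5) stay masked.
```

Note that this sketch, like TRL's collator, excludes the end-of-turn token itself from the trained span, which is relevant to the EOT discussion later in this thread.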

Related issues

Fixes # (issue)

Before submitting

  • This PR only changes documentation. (You can ignore the following checks in that case)
  • Did you read the contributor guideline Pull Request guidelines?
  • Did you link the issue(s) related to this PR in the section above?
  • Did you add / update tests where needed?

Reviewers

At least one review from a member of oumi-ai/oumi-staff is required.

Adds a new collator that masks labels for non-tool-call assistant
completions, enabling models to learn tool-calling behavior specifically.
Register ToolAwareCompletionsCollator in the collator builder and add
the corresponding CollatorType enum. Update collator docstrings and tests.
Replace hardcoded IGNORE = -100 with the shared constant from
oumi.core.constants to be consistent with sibling test files.
@shanghongsim shanghongsim marked this pull request as ready for review April 9, 2026 00:12
@gitar-bot

gitar-bot bot commented Apr 9, 2026


Comment thread: src/oumi/builders/collators.py (Outdated)
The builder was not forwarding label_ignore_index to the collator
constructor, so it always used the default -100 instead of the
configured value. Add the missing parameter and a builder test.

import oumi.core.constants as constants
from oumi.core.collators.text_collator_with_padding import TextCollatorWithPadding
from oumi.core.collators.text_completions_collator_with_padding import (
Contributor


It seems like it inherits padding support from the base class, should we add "WithPadding" to its name? Or are we moving away from that naming?

Contributor Author


Just to clarify: this class actually uses composition rather than inheritance (it wraps a TRL DataCollatorForCompletionOnlyLM instance internally). The padding behavior comes from that inner collator's chain, so the "WithPadding" in the name describes what the collator does, not something it inherits. That said, I agree the naming could be better. @oelachqar actually suggested refactoring all collators into one universal collator with options (padding=True, completion_only=True); I'm exploring now whether that can be done in this PR.
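The composition pattern described above can be sketched as follows (a simplified illustration; the real class in oumi.core.collators has more responsibilities):

```python
class TextCompletionsCollatorWithPadding:
    """Simplified sketch: the collator WRAPS an inner completion-only
    collator rather than inheriting from it; padding and label masking
    come from the wrapped object."""

    def __init__(self, inner_collator):
        self._collator = inner_collator  # composition, not inheritance

    def __call__(self, batch):
        # Delegate tokenization/padding/masking to the inner collator.
        return self._collator(batch)


# Usage with a stand-in inner collator (a real one would be a TRL-style
# completion-only collator):
wrapped = TextCompletionsCollatorWithPadding(lambda batch: {"n": len(batch)})
```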

Comment thread: src/oumi/builders/collators.py (Outdated)
mask_tool_calls = kwargs.pop("mask_tool_calls", False)
tool_call_start_template = kwargs.pop("tool_call_start_template", None)

if not response_template:
Contributor


If I remember correctly we made a collator that can identify this itself?

Contributor Author


Based on my understanding and research, there isn't a collator in the codebase that auto-identifies the response template; they all need it provided explicitly or fall back to Llama defaults.

# Default to Llama-style templates if not provided

Extend the existing collator with optional end_of_turn_template,
mask_tool_calls, and tool_call_start_template parameters for
span-based label masking in tool-calling conversations. Remove the
separate ToolAwareCompletionsCollator class and forward
label_ignore_index through TextCompletionsCollatorWithPadding.
Comment on lines +195 to +205
seq,
content_start,
content_end,
self.tool_call_start_token_ids,
):
continue

# Step 5: unmask this assistant response span.
batch["labels"][i, content_start:content_end] = batch["input_ids"][
i, content_start:content_end
]


Bug: The _apply_span_masking method incorrectly excludes EOT tokens from training labels, preventing the model from learning to emit stopping signals.
Severity: HIGH

Suggested Fix

To include the EOT tokens in the training labels, the content_end calculation in _apply_span_masking should be adjusted. It should be set to content_start + eot_positions[0] + len(eot_ids) to ensure the slice includes the full EOT token sequence in the unmasked labels.

Prompt for AI Agent
Review the code at the location below. A potential bug has been identified by an AI
agent.
Verify if this is a real issue. If it is, propose a fix; if not, explain why it's not
valid.

Location: src/oumi/core/collators/trl_data_collator_for_completion_only_lm.py#L184-L205

Potential issue: When using `end_of_turn_template` for span-based masking, the
`_apply_span_masking` method incorrectly excludes end-of-turn (EOT) tokens from the
training labels. The slice `batch["labels"][i, content_start:content_end]` is calculated
such that `content_end` marks the beginning of the EOT token sequence, effectively
masking it. Standard supervised fine-tuning requires including these tokens in the loss
so the model learns to generate them as stopping signals. Models trained with this
collator will not learn when to stop generating, leading to uncontrolled output during
inference. This behavior appears to be an unintended side effect of a change focused on
masking tool responses.

Contributor Author


This should not be an issue: TRL's collator also masks out end-of-turn (EOT) tokens.

@shanghongsim
Contributor Author

This is v1.

v2 with improved interface is in these PRs:

#2368

#2369
