
Extend DataCollatorForCompletionOnlyLM to support correct tool result masking#2369

Open
shanghongsim wants to merge 31 commits into main from shanghong/masking-method-enum-v2

Conversation


@shanghongsim shanghongsim commented Apr 14, 2026

Description

Context

Currently, the completions-only collator does not consistently mask tool responses (Role.TOOL) correctly. For some chat templates (like Qwen's), it happens to be correct by accident, while for Llama chat templates masking is consistently applied incorrectly. This PR makes masking correct for all chat templates, supporting all four roles (system, user, assistant, and tool).

Without instruction_template, tool results are masked correctly, but with instruction_template some tool results are left unmasked. Omitting instruction_template causes the collator to find the last response_template and mask everything before it (so only the final assistant turn trains). With instruction_template, the collator searches for instruction_template/response_template pairs and masks everything in between.

Case 1: without instruction_template

```python
RESPONSE_TEMPLATE = "<|im_start|>assistant\n"

old_no_inst = DataCollatorForCompletionOnlyLM(
    response_template=RESPONSE_TEMPLATE,
    instruction_template=None,
    tokenizer=tokenizer,
)
b = old_no_inst.torch_call([token_ids])
labels_A = b["labels"][0].tolist()
summarise("Case A", labels_A, N)
show_masking(b["input_ids"][0].tolist(), labels_A, tokenizer)
```
[image: Case A masking output]

Without an instruction template, the collator masks everything before the last response template, so the loss only sees:

```text
<|im_start|>assistant
The weather in Paris is sunny and 18°C.<|im_end|>
```

Case 2: with instruction_template

With instruction_template, it masks everything between an instruction and a response template. Since there is no instruction template before the tool result, the tool result does not get masked properly.

```python
INSTRUCTION_TEMPLATE = "<|im_start|>user\n"

old_with_inst = DataCollatorForCompletionOnlyLM(
    response_template=RESPONSE_TEMPLATE,
    instruction_template=INSTRUCTION_TEMPLATE,
    tokenizer=tokenizer,
)
b2 = old_with_inst.torch_call([token_ids])
labels_B = b2["labels"][0].tolist()
summarise("Case B", labels_B, N)
show_masking(b2["input_ids"][0].tolist(), labels_B, tokenizer)
```
[image: Case B masking output]

In case 2, the loss sees the tool result, which is incorrect.

```text
...
<|im_start|>assistant
<tool_call>
{"name": "get_weather", "arguments": {"location": "Paris"}}
</tool_call><|im_end|>
...
<|im_start|>assistant
The weather in Paris is sunny and 18°C.<|im_end|>
```

To confirm my hypothesis that the missing instruction template before the tool result is the cause, I experimented with adding a user turn before the tool call, and masking works correctly in that case.

[image: masking output with a user turn inserted before the tool call]

ToolAwareCompletionsCollator

With the new ToolAwareCompletionsCollator, tool results are masked properly.

Without instruction_template
[image: masking without instruction_template]

With instruction_template
[image: masking with instruction_template]

Change 1: Extend DataCollatorForCompletionOnlyLM with span-based masking

Extend DataCollatorForCompletionOnlyLM with span-based masking (detecting assistant-turn boundaries via response and end-of-turn tokens), in addition to the instruction-based masking (matching instruction/response string pairs) that is currently supported.

Change 2: train_target enum for explicit control of which part of response to train on

  • assistant_turn: Masks everything, then unmasks each assistant response bounded by response_template .. end_of_turn_template (inclusive of EOT).
  • final_assistant_turn: Masks all tokens before the last response_template occurrence.
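The two strategies can be sketched roughly as follows. This is an illustrative toy version operating on token-ID lists; the helper names (`find_spans`, `mask_assistant_turns`) are assumptions for exposition, not the collator's actual internals (the PR's helpers are `_find_pattern` / `_apply_span_masking`):

```python
IGNORE_INDEX = -100  # tokens with this label contribute no loss

def find_spans(seq: list[int], pattern: list[int]) -> list[int]:
    """Return the start index of every occurrence of pattern in seq."""
    n = len(pattern)
    return [i for i in range(len(seq) - n + 1) if seq[i : i + n] == pattern]

def mask_assistant_turns(input_ids, response_ids, eot_ids, final_only=False):
    if final_only:
        # final_assistant_turn: mask every token before (and including)
        # the last occurrence of the response template.
        labels = list(input_ids)
        starts = find_spans(input_ids, response_ids)
        if starts:
            cutoff = starts[-1] + len(response_ids)
            labels[:cutoff] = [IGNORE_INDEX] * cutoff
        return labels
    # assistant_turn: mask everything, then unmask each span from the end
    # of a response template up to and including the next end-of-turn.
    labels = [IGNORE_INDEX] * len(input_ids)
    for start in find_spans(input_ids, response_ids):
        body = start + len(response_ids)
        ends = [e for e in find_spans(input_ids, eot_ids) if e >= body]
        end = (ends[0] + len(eot_ids)) if ends else len(input_ids)
        labels[body:end] = input_ids[body:end]
    return labels
```

Because the assistant_turn path starts fully masked and only unmasks assistant spans, a tool result never trains regardless of which role tag its chat template uses.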

When train_target is set, it has final authority over which masking combination is used, regardless of which templates are provided. When train_target is not set, it is inferred from the templates provided. Eventually, we want users to specify train_target instead of relying on combinations of templates to control behavior. However, since many production configs set templates explicitly and do not use train_target, this is the interim solution to keep old configs working. Existing configs using collator_kwargs with instruction_template + response_template continue to work via the legacy path (with a deprecation warning). New configs should use train_target explicitly. Users can also set train_target and override the templates manually via collator_kwargs.
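The fallback precedence described above might look roughly like this; the function name and return strings are illustrative stand-ins, not the actual builder code:

```python
import warnings

def resolve_train_target(train_target, instruction_template, end_of_turn_template):
    """Sketch of the precedence: explicit train_target wins; otherwise infer."""
    if train_target is not None:
        return train_target  # explicit setting always has final authority
    if instruction_template is not None:
        # Legacy instruction/response pair masking, kept for old configs.
        warnings.warn(
            "Instruction-based masking is deprecated; set train_target instead.",
            DeprecationWarning,
        )
        return "legacy_instruction_response"
    if end_of_turn_template is not None:
        return "assistant_turn"  # span-based masking is possible
    return "final_assistant_turn"  # only a response template: last turn only
```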

Before — users must provide model-specific token strings:

```yaml
collator_name: "text_completions_only_with_padding"
collator_kwargs:
  response_template: "<|im_start|>assistant\n"
  end_of_turn_template: "<|im_end|>"
```

After — users express intent, templates auto-resolve from tokenizer:

```yaml
collator_name: "text_completions_only_with_padding"
train_target: "assistant_turn"
```

Change 3: Auto detection of collator templates

Instead of fragile vocab/collator template matching (which requires us to maintain collator templates for all popular models), shift to a more robust approach: apply the chat template to a known test conversation, then find the assistant boundary strings in the rendered output.
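A rough sketch of the sentinel-rendering idea, assuming a Hugging Face-style `apply_chat_template`; the sentinel strings and the newline heuristic (standing in for the PR's token-ID common-prefix logic) are assumptions:

```python
SENTINEL_USER = "USER_SENTINEL_7f3a"
SENTINEL_ASST = "ASST_SENTINEL_9c1b"

def detect_templates(tokenizer):
    """Render a sentinel conversation and extract assistant boundary strings."""
    msgs = [
        {"role": "user", "content": SENTINEL_USER},
        {"role": "assistant", "content": SENTINEL_ASST},
        {"role": "user", "content": SENTINEL_USER},
        {"role": "assistant", "content": SENTINEL_ASST},
    ]
    rendered = tokenizer.apply_chat_template(msgs, tokenize=False)
    # Work from the second turn pair to sidestep BOS/system-prompt effects.
    a1 = rendered.index(SENTINEL_ASST)
    u2 = rendered.index(SENTINEL_USER, a1)
    a2 = rendered.index(SENTINEL_ASST, u2)
    between_asst_user = rendered[a1 + len(SENTINEL_ASST) : u2]
    between_user_asst = rendered[u2 + len(SENTINEL_USER) : a2]
    # End-of-turn: longest common character prefix of the two gaps.
    common = []
    for c1, c2 in zip(between_asst_user, between_user_asst):
        if c1 != c2:
            break
        common.append(c1)
    prefix = "".join(common)
    # Heuristic stand-in for the PR's token-ID common prefix: cut back to the
    # last newline so the next turn's header isn't swallowed into the EOT.
    cut = prefix.rfind("\n") + 1
    end_of_turn_template = prefix[:cut] if cut else prefix
    # Response template: whatever precedes the assistant content.
    response_template = between_user_asst[len(end_of_turn_template):]
    if not response_template.strip() or not end_of_turn_template.strip():
        raise ValueError("Could not auto-detect collator templates.")
    return response_template, end_of_turn_template
```

On a Qwen-style template this would yield `"<|im_start|>assistant\n"` and `"<|im_end|>\n"` without any per-family hardcoding.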

Migration and deprecation considerations:

  1. Configs that only specify collator_name -> support removed along with the default Llama template fallback. This should be fine, as no production configs specify collator_name without collator_kwargs. Fewer than 10 OSS configs fall into this edge case, and they are updated in this PR.
  2. Enterprise SFT configs -> these should still work with the old legacy behavior. Will update them once the OSS version is updated in the API.

Open issues for next time

  • Move train_target into collator_kwargs -> I considered putting train_target inside collator_kwargs instead of making it a separate field, but collator_kwargs is an opaque dict, so users wouldn't know train_target exists unless they read the docs. A top-level field makes it more visible. This will eventually be solved when we shift to CollatorParam.
  • CollatorParams: This would mean changing every YAML from:

```yaml
collator_name: "text_completions_only_with_padding"
collator_kwargs:
  train_target: "all_assistant_turns"
```

to

```yaml
collator:
  name: "text_completions_only_with_padding"
  train_target: "all_assistant_turns"
```

This is much cleaner, but it's a substantial refactor. To keep the scope of this PR manageable, it is deferred to a future PR.

  • Vision-specific concerns -> future PR to tidy up the vision concerns that are now scattered throughout the LM-specific code.
  • Per-split collator configuration -> build_collator_from_config builds a single collator from the training split's settings, which is reused for all splits (train, validation, test). Validation/test split-specific collator settings (e.g. a different train_target) are silently ignored. This is pre-existing behavior not introduced by this PR, but worth noting as a known limitation. Fixing it would require refactoring the training loop to build per-split collators.

Related issues

N/A

Before submitting

  • This PR only changes documentation.
  • Did you read the contributor guideline Pull Request guidelines?
  • Did you link the issue(s) related to this PR in the section above?
  • Did you add / update tests where needed?

Reviewers

At least one review from a member of oumi-ai/oumi-staff is required.

Merge tool-aware masking directly into the existing TRL-derived collator.
Adds masking_method parameter with three strategies: assistant_turn,
assistant_turn_no_tools, and final_assistant_turn.

- Add _find_pattern, _span_contains, _apply_span_masking to DataCollatorForCompletionOnlyLM
- Rename instruction_prefix/response_prefix to instruction_template/response_template
- Add end_of_turn_template, tool_call_start_template, masking_method params
- Update builder to pass new params and support span-based masking
- Add comprehensive tests for all masking strategies
@gitar-bot

gitar-bot bot commented Apr 14, 2026


Comment thread src/oumi/builders/collators.py
The wrapper never uses list[int] directly — it passes through to
DataCollatorForCompletionOnlyLM. All callers pass strings.
Update tests to pass string templates instead of token ID lists.
@shanghongsim shanghongsim force-pushed the shanghong/masking-method-enum-v2 branch from 25c06fd to 9facd20 on April 14, 2026, 20:18
Raise ValueError at init time when masking_method is assistant_turn
or assistant_turn_no_tools but end_of_turn_template is None.
Previously this silently passed init and crashed with AssertionError
on the first batch during training.
@shanghongsim shanghongsim force-pushed the shanghong/masking-method-enum-v2 branch from 9facd20 to 0ae3142 on April 14, 2026, 20:25
…class level

Replace 4 repeated isinstance/encode blocks with a single helper method.
Move _KNOWN_MASKING_METHODS from __init__ local to class-level constant.
@shanghongsim shanghongsim force-pushed the shanghong/masking-method-enum-v2 branch from 0ae3142 to 3f4dba3 on April 14, 2026, 20:40
@shanghongsim shanghongsim changed the title from "Add MaskingMethod enum for explicit SFT masking control" to "Add MaskingMethod enum for simpler masking control" on Apr 14, 2026
Comment thread src/oumi/builders/collators.py Outdated
Replace reference to mask_tool_calls=True with the correct usage:
masking_method='assistant_turn_no_tools'.
@shanghongsim shanghongsim force-pushed the shanghong/masking-method-enum-v2 branch from f957244 to 2b35a20 on April 14, 2026, 21:55
Comment thread src/oumi/builders/collators.py Outdated
Emit DeprecationWarning when the collator infers
_legacy_instruction_response masking from the presence of
instruction_template. Guides users toward masking_method.
@shanghongsim shanghongsim force-pushed the shanghong/masking-method-enum-v2 branch from 2b35a20 to 676c5b5 on April 14, 2026, 22:31
Comment thread src/oumi/builders/collators.py Outdated
When masking_method is not explicit and both end_of_turn_template and
tool_call_start_template are present, infer assistant_turn_no_tools
instead of assistant_turn.
@shanghongsim shanghongsim force-pushed the shanghong/masking-method-enum-v2 branch 2 times, most recently from c7a1308 to 9a80c42 on April 14, 2026, 22:46
Replace nested if/elif chain with a classmethod that validates
or infers masking_method from template presence. Keeps __init__
focused on tokenization and validation.
@shanghongsim shanghongsim force-pushed the shanghong/collator-span-masking branch from 8309130 to 4c081d7 on April 14, 2026, 22:50
@shanghongsim shanghongsim force-pushed the shanghong/masking-method-enum-v2 branch from 9a80c42 to 778351f on April 14, 2026, 22:50
@shanghongsim shanghongsim force-pushed the shanghong/collator-span-masking branch from 86c00e7 to df0d473 on April 16, 2026, 22:04
@shanghongsim shanghongsim force-pushed the shanghong/masking-method-enum-v2 branch from e9ae1ef to 9913f12 on April 16, 2026, 22:05
shanghongsim and others added 11 commits April 16, 2026 22:06
Drop the ASSISTANT_TURN_NO_TOOLS enum member and the
tool_call_start_template from builder templates, matching the
core collator simplification in PR1.
MaskingMethod was confusing — it sounded like the assistant turns
were being masked. TrainTarget with ALL_ASSISTANT_TURNS /
FINAL_ASSISTANT_TURN makes the intent clear: select what to
train on.
Instead of checking for marker tokens in the tokenizer vocabulary and
returning hardcoded template strings, render the tokenizer's own chat
template with sentinel content and extract response_template and
end_of_turn_template from the output. This works for any model with a
chat template without requiring per-family hardcoded entries.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Remove mutual exclusivity check so users can set train_target for
auto-resolution and override individual templates via collator_kwargs.
@shanghongsim shanghongsim force-pushed the shanghong/masking-method-enum-v2 branch from 9913f12 to fbc7087 on April 16, 2026, 22:08
Move train_target inference from the collator into
build_collator_from_config so the builder is the single decision
point. The collator now receives a resolved train_target and only
validates structural invariants.

- Builder handles new config (train_target set, auto-detect
  templates) and old config (infer train_target from collator_kwargs)
- Collator: remove _resolve_train_target, make train_target required,
  derive _VALID_TRAIN_TARGETS from TrainTarget enum
- Add builder tests for old-recipe inference paths
@shanghongsim shanghongsim changed the title from "Add MaskingMethod enum for simpler masking control" to "Add train_target-based masking to DataCollatorForCompletionOnlyLM" on Apr 17, 2026
@shanghongsim shanghongsim changed the base branch from shanghong/collator-span-masking to main on April 17, 2026, 00:16
@shanghongsim shanghongsim changed the title from "Add train_target-based masking to DataCollatorForCompletionOnlyLM" to "Extend DataCollatorForCompletionOnlyLM to support correct tool result masking" on Apr 17, 2026
Comment thread src/oumi/builders/collators.py
Comment thread src/oumi/builders/collators.py Outdated
Comment on lines +114 to +115:

```python
if not response_template.strip() or not end_of_turn_template.strip():
    raise ValueError(_FALLBACK_MSG)
```

This comment was marked as outdated.

Comment on lines 243 to 249:

```python
        debug=debug,
        ignore_index=(
            label_ignore_index if label_ignore_index is not None else -100
        ),
        **kwargs,
    )
raise ValueError(f"Unknown data collator name: '{collator_name}'")
```

Bug: The TextCompletionsCollatorWithPadding constructor can receive ignore_index both as an explicit argument and within **kwargs, causing a TypeError if a user customizes it.
Severity: HIGH

Suggested Fix

Before calling the TextCompletionsCollatorWithPadding constructor, pop ignore_index from the kwargs dictionary. Use the popped value to determine the final ignore_index value to be passed, preventing the argument duplication.

Prompt for AI Agent:

```text
Review the code at the location below. A potential bug has been identified by an AI agent. Verify if this is a real issue. If it is, propose a fix; if not, explain why it's not valid.

Location: src/oumi/builders/collators.py#L243-L249

Potential issue: When building the `text_completions_only_with_padding` collator in `build_data_collator`, the `ignore_index` parameter is passed explicitly while also being potentially present in the `**kwargs` passed to the `TextCompletionsCollatorWithPadding` constructor. If a user specifies `ignore_index` in their `collator_kwargs` configuration, this results in a `TypeError` because the `__init__` method receives multiple values for the same keyword argument, causing a runtime crash during training setup.
```
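A minimal sketch of the suggested fix, with a hypothetical helper name; the idea is just to pop `ignore_index` out of the user kwargs before the constructor call so the keyword is passed exactly once:

```python
def resolve_ignore_index(collator_kwargs: dict, label_ignore_index) -> int:
    """Pop ignore_index from user kwargs so it can't be passed twice.

    Precedence: user-specified kwarg > config label_ignore_index > -100.
    """
    user_value = collator_kwargs.pop("ignore_index", None)
    if user_value is not None:
        return user_value
    return label_ignore_index if label_ignore_index is not None else -100
```

The builder would then call `resolve_ignore_index(kwargs, label_ignore_index)` first and forward the remaining `**kwargs`, avoiding the duplicate-keyword `TypeError`.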

Comment thread excerpt:

```python
    Raises:
        ValueError: If templates cannot be extracted.
    """
    msgs = [
```
Contributor:

This is really clever, but we should maybe include a system instruction here?

Contributor Author:

Like modify it to extract the system instruction? We are currently no longer relying on the system instruction for masking.

Contributor:

No I mean ensure that when we extract the different template portions, we do so using an example conversation that also includes a system instruction

Or rather, perhaps do it twice, once on a conversation with a system instruction, and once on a conversation without.

Why? Because I'm a bit concerned about the scenario where this logic doesn't use system instructions when extracting, but the user does in their data, and somehow the resulting templates we extract using this methodology don't wind up working when the data includes a system instruction for some reason.

```python
# Locate boundaries around the second turn pair
# to avoid system-prompt effects on the first turn.
try:
    a1 = rendered.index(_SENTINEL_ASST)
```
Contributor:

nit: first_asst_start, and add _start to second turns

```python
except ValueError:
    raise ValueError(_FALLBACK_MSG)

# End-of-turn: common token-ID prefix of the two strings that
```
Contributor:

extract this to its own method

```python
assert isinstance(_eot_decoded, str)
end_of_turn_template = _eot_decoded

# Response template: strip the EOT prefix to get just the assistant header.
```
Contributor:

Same here, extract to its own method

```python
if not response_template.strip() or not end_of_turn_template.strip():
    raise ValueError(_FALLBACK_MSG)

# Qwen3 and similar reasoning models inject <think>...</think> into
```
Contributor:

This description scares me, I wonder if we could have a better workaround or a louder error

```python
from oumi.core.configs.params.data_params import TrainTarget


class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
```
Contributor:

It's not clear to me what this collator does differently than the other one.

Contributor Author:

Previously, it would find instruction_template/response_template pairs and mask everything in between (see L324). However, there is no instruction_template before the tool result, so the tool result content does not get masked properly. Some models (like Qwen) accidentally mask it correctly because their chat template stores the tool result under the user role. But for most models (like Llama), the tool result is stored under a tool role with a separate tag, so masking is incorrectly applied for the reasons above. The new approach masks everything, then unmasks only the portions between response_template and end_of_turn_template. This is more robust to the different roles and to the specifics of how tool results are formatted and stored by different chat templates.

[image: masking comparison]

Contributor @lefft left a comment:

Can we be sure to check whether template detection and collation happens properly when user or assistant turns begin with (one or more) leading \ns? We had issues with this breaking tokenization and leading to records being skipped in enterprise, would like to understand whether these changes will help, harm, or not impact that issue.

```python
)


def _resolve_collator_templates(
```
Contributor:

Has this been tested with gpt-oss models? They use a non-standard system for marking boundaries within responses, so wondering if boundary detection will work as it does for Qwen/Llama style templates.

```python
    collator_kwargs["train_target"] = "all_assistant_turns"
elif has_inst:
    warnings.warn(
        "Instruction-based masking is deprecated.\n"
```
Contributor:

Is there a plan/intention to eliminate support for instruction-based masking altogether, or will we just keep it in a "deprecated" state?

Want to make sure we retain compatibility with existing enterprise configs (or verify that moving to this style doesn't change training dynamics before adapting enterprise configs).
