Add span-based masking to DataCollatorForCompletionOnlyLM#2368

Closed
shanghongsim wants to merge 14 commits into main from shanghong/collator-span-masking

Conversation

@shanghongsim
Contributor

@shanghongsim shanghongsim commented Apr 14, 2026

Description

Extend DataCollatorForCompletionOnlyLM with span-based masking, which detects assistant turn boundaries via the response and end-of-turn tokens, in addition to the instruction-based masking (matching instruction/response string pairs) it currently supports. Also add a train_target parameter that explicitly controls which part of the response to train on:

  • assistant_turn: Masks everything, then unmasks each assistant response that starts after response_template and ends at end_of_turn_template (inclusive of EOT).
  • final_assistant_turn: Masks all tokens before the last response_template occurrence.

When train_target is not set, the mode is inferred from which templates are present. When it is set, train_target has the final say on which template is used. Existing configs that pass instruction_template + response_template through collator_kwargs continue to work via the legacy path (with a deprecation warning). New configs should set train_target explicitly.
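
The two train_target modes can be illustrated on plain token-ID lists. This is a minimal sketch, not the PR's implementation: find_pattern and mask_labels are hypothetical names, the token IDs are made up, and it assumes the response template tokens themselves stay masked in both modes.

```python
IGNORE_INDEX = -100  # label value that the loss is expected to skip


def find_pattern(seq, pattern):
    """Return every start index at which `pattern` occurs in `seq`."""
    n = len(pattern)
    return [i for i in range(len(seq) - n + 1) if seq[i:i + n] == pattern]


def mask_labels(input_ids, response_ids, eot_ids, train_target):
    """Build labels for one sequence (hypothetical helper, assumes non-empty templates)."""
    labels = [IGNORE_INDEX] * len(input_ids)  # start with everything masked
    starts = find_pattern(input_ids, response_ids)
    if not starts:
        return labels  # no assistant turn found: nothing to train on

    if train_target == "final_assistant_turn":
        # Unmask everything after the last response template.
        begin = starts[-1] + len(response_ids)
        labels[begin:] = input_ids[begin:]
        return labels

    if train_target == "assistant_turn":
        eot_starts = find_pattern(input_ids, eot_ids)
        for start in starts:
            begin = start + len(response_ids)
            # First end-of-turn at or after the response start, EOT included.
            ends = [e for e in eot_starts if e >= begin]
            end = ends[0] + len(eot_ids) if ends else len(input_ids)
            labels[begin:end] = input_ids[begin:end]
        return labels

    raise ValueError(f"Unknown train_target: {train_target!r}")


# Toy sequence: [9, 9] plays the response template, [2] plays end-of-turn.
ids = [1, 9, 9, 5, 6, 2, 3, 9, 9, 7, 8, 2]
all_turns = mask_labels(ids, [9, 9], [2], "assistant_turn")
last_turn = mask_labels(ids, [9, 9], [2], "final_assistant_turn")
```

In this toy input, assistant_turn keeps both responses (each with its EOT token) as labels, while final_assistant_turn keeps only the tokens after the second template.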

Related issues

N/A
Part 1 of 2 — part 2 is #2369 (adds the MaskingMethod enum for YAML config).

Before submitting

  • This PR only changes documentation.
  • Did you read the contributor guideline Pull Request guidelines?
  • Did you link the issue(s) related to this PR in the section above?
  • Did you add / update tests where needed?

Reviewers

At least one review from a member of oumi-ai/oumi-staff is required.

Merge tool-aware masking directly into the existing TRL-derived collator.
Adds masking_method parameter with three strategies: assistant_turn,
assistant_turn_no_tools, and final_assistant_turn.

- Add _find_pattern, _span_contains, _apply_span_masking to DataCollatorForCompletionOnlyLM
- Rename instruction_prefix/response_prefix to instruction_template/response_template
- Add end_of_turn_template, tool_call_start_template, masking_method params
- Update builder to pass new params and support span-based masking
- Add comprehensive tests for all masking strategies

Comment thread src/oumi/core/collators/trl_data_collator_for_completion_only_lm.py Outdated
The wrapper never uses list[int] directly — it passes through to
DataCollatorForCompletionOnlyLM. All callers pass strings.
Update tests to pass string templates instead of token ID lists.
Raise ValueError at init time when masking_method is assistant_turn
or assistant_turn_no_tools but end_of_turn_template is None.
Previously this silently passed init and crashed with AssertionError
on the first batch during training.
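
A fail-fast check of this kind might look like the sketch below. It is illustrative only: validate_masking_config is a hypothetical standalone function, whereas the commit describes the check as living in the collator's init.

```python
from typing import Optional

# Methods that need an end-of-turn marker to close each span (per the commit message).
_SPAN_METHODS = {"assistant_turn", "assistant_turn_no_tools"}


def validate_masking_config(
    masking_method: Optional[str], end_of_turn_template: Optional[str]
) -> None:
    """Raise at construction time instead of failing on the first batch."""
    if masking_method in _SPAN_METHODS and end_of_turn_template is None:
        raise ValueError(
            f"masking_method={masking_method!r} requires end_of_turn_template."
        )


validate_masking_config("final_assistant_turn", None)  # passes: no EOT needed
```
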
@shanghongsim shanghongsim requested a review from oelachqar April 14, 2026 20:34
…class level

Replace 4 repeated isinstance/encode blocks with a single helper method.
Move _KNOWN_MASKING_METHODS from __init__ local to class-level constant.
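
One possible shape for such a helper is sketched here. The name encode_template is hypothetical, and the tokenizer is assumed to expose a Hugging Face style encode(text, add_special_tokens=False) method; the actual helper in the PR may differ.

```python
from typing import List, Optional, Sequence, Union


def encode_template(
    tokenizer, template: Union[str, Sequence[int], None]
) -> Optional[List[int]]:
    """Normalize a template given as a string, as token IDs, or as None."""
    if template is None:
        return None
    if isinstance(template, str):
        # Skip special tokens so BOS/EOS do not end up in the match pattern.
        return tokenizer.encode(template, add_special_tokens=False)
    return list(template)  # already token IDs
```
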
Comment thread src/oumi/core/configs/params/data_params.py
@shanghongsim shanghongsim requested review from jgreer013 and lefft April 14, 2026 21:53
Replace reference to mask_tool_calls=True with the correct usage:
masking_method='assistant_turn_no_tools'.
Comment thread src/oumi/core/collators/trl_data_collator_for_completion_only_lm.py Outdated
Emit DeprecationWarning when the collator infers
_legacy_instruction_response masking from the presence of
instruction_template. Guides users toward masking_method.
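
The legacy inference with its warning could be sketched as follows. resolve_masking_method is a hypothetical name, only the legacy branch from this commit is shown, and the remaining inference rules are deliberately left out (the function returns None for them).

```python
import warnings
from typing import Optional


def resolve_masking_method(
    masking_method: Optional[str], instruction_template: Optional[str]
) -> Optional[str]:
    """Return the explicit method, or infer the legacy path with a warning."""
    if masking_method is not None:
        return masking_method  # an explicit choice is never overridden here
    if instruction_template is not None:
        warnings.warn(
            "Inferring legacy instruction/response masking from "
            "instruction_template is deprecated; set masking_method explicitly.",
            DeprecationWarning,
            stacklevel=2,  # attribute the warning to the caller
        )
        return "_legacy_instruction_response"
    return None  # other inference rules are outside this sketch
```
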
Comment thread src/oumi/builders/collators.py Outdated
When masking_method is not explicit and both end_of_turn_template and
tool_call_start_template are present, infer assistant_turn_no_tools
instead of assistant_turn.
Replace nested if/elif chain with a classmethod that validates
or infers masking_method from template presence. Keeps __init__
focused on tokenization and validation.
@shanghongsim shanghongsim force-pushed the shanghong/collator-span-masking branch from 8309130 to 4c081d7 Compare April 14, 2026 22:50
Comment on lines +102 to +110
if masking_method not in cls._KNOWN_MASKING_METHODS:
    valid = sorted(
        cls._KNOWN_MASKING_METHODS - {"_legacy_instruction_response"}
    )
    raise ValueError(
        f"Unknown masking_method='{masking_method}'. "
        f"Must be one of: {valid}"
    )
return masking_method

This comment was marked as outdated.

Comment on lines +94 to +96
Priority (first match wins):
1. Explicit masking_method (validated)
2. end_of_turn + tool_call_start → assistant_turn_no_tools

This comment was marked as outdated.

Drop the assistant_turn_no_tools masking mode and the
tool_call_start_template parameter it depended on, simplifying
the span-based masking to only support assistant_turn and
final_assistant_turn strategies.
Remove hardcoded Llama-style instruction/response template defaults
from build_data_collator — callers must now provide response_template
explicitly. Inline the single-use _collate method in the wrapper.
Replace the classmethod + intermediate booleans with a simple
if/elif chain directly in __init__. Rename _KNOWN_MASKING_METHODS
to _VALID_MASKING_METHODS since the legacy path is internal.
Comment on lines +130 to +134
if not kwargs.get("response_template"):
    raise ValueError(
        "'text_completions_only_with_padding' requires a "
        "response_template. Provide it via collator_kwargs."
    )

This comment was marked as outdated.

…tant_turns

Rename for clarity: the parameter selects what to train on, not
what to mask. assistant_turn becomes all_assistant_turns to
distinguish from final_assistant_turn.
@shanghongsim shanghongsim force-pushed the shanghong/collator-span-masking branch from 86c00e7 to df0d473 Compare April 16, 2026 22:04
Comment on lines +139 to 144
        ignore_index=(
            label_ignore_index if label_ignore_index is not None else -100
        ),
        **kwargs,
    )
raise ValueError(f"Unknown data collator name: '{collator_name}'")

Bug: Passing ignore_index in collator_kwargs will cause a TypeError due to receiving multiple values for the same keyword argument in the TextCompletionsCollatorWithPadding constructor.
Severity: MEDIUM

Suggested Fix

Before calling the TextCompletionsCollatorWithPadding constructor, remove the ignore_index key from the kwargs dictionary to prevent the argument collision. For example: kwargs.pop("ignore_index", None).

Prompt for AI Agent
Review the code at the location below. A potential bug has been identified by an AI
agent. Verify if this is a real issue. If it is, propose a fix; if not, explain why it's
not valid.

Location: src/oumi/builders/collators.py#L139-L144

Potential issue: The `build_data_collator` function calls the
`TextCompletionsCollatorWithPadding` constructor, passing `ignore_index` as an explicit
keyword argument while also unpacking `collator_kwargs` into the same call. If a user's
configuration includes `ignore_index` within `collator_kwargs`, the argument will be
passed twice. This will raise a `TypeError` because the constructor receives multiple
values for the `ignore_index` keyword argument. The code does not currently handle this
potential collision by removing the key from the `kwargs` dictionary before the call.
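
The collision and the suggested fix can be reproduced in isolation. In this sketch make_collator is a hypothetical stand-in for the TextCompletionsCollatorWithPadding constructor, and build_collator is a simplified stand-in for the relevant branch of build_data_collator.

```python
def make_collator(*, ignore_index=-100, **kwargs):
    """Stand-in constructor: returns its arguments so they can be inspected."""
    return {"ignore_index": ignore_index, **kwargs}


def build_collator(collator_kwargs, label_ignore_index=None):
    """Simplified builder that applies the suggested kwargs.pop fix."""
    kwargs = dict(collator_kwargs)  # copy so the caller's dict is not mutated
    kwargs.pop("ignore_index", None)  # drop the duplicate before unpacking
    return make_collator(
        ignore_index=(
            label_ignore_index if label_ignore_index is not None else -100
        ),
        **kwargs,
    )


# Without the pop, passing the key explicitly and via ** raises TypeError.
try:
    make_collator(ignore_index=-100, **{"ignore_index": 0})
    collided = False
except TypeError:
    collided = True

safe = build_collator({"ignore_index": 0, "response_template": "x"})
```

Note that popping the key silently discards the value from collator_kwargs. Raising a clear error or preferring the user-supplied value are alternative designs; the review comment only proposes the pop.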

@shanghongsim
Contributor Author

Merged into #2369 which now targets main directly.
