Extend DataCollatorForCompletionOnlyLM to support correct tool result masking #2369
shanghongsim wants to merge 31 commits into main from
Conversation
Merge tool-aware masking directly into the existing TRL-derived collator. Adds a masking_method parameter with three strategies: assistant_turn, assistant_turn_no_tools, and final_assistant_turn.
- Add _find_pattern, _span_contains, _apply_span_masking to DataCollatorForCompletionOnlyLM
- Rename instruction_prefix/response_prefix to instruction_template/response_template
- Add end_of_turn_template, tool_call_start_template, masking_method params
- Update builder to pass new params and support span-based masking
- Add comprehensive tests for all masking strategies
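As a rough illustration of one of the helper additions, a token-ID pattern finder like `_find_pattern` might look as follows (a standalone sketch with a hypothetical free-function signature; the actual methods live on `DataCollatorForCompletionOnlyLM`):

```python
def find_pattern(ids: list[int], pattern: list[int]) -> list[int]:
    """Return the start index of every occurrence of `pattern` in `ids`."""
    if not pattern or len(pattern) > len(ids):
        return []
    return [
        i
        for i in range(len(ids) - len(pattern) + 1)
        if ids[i : i + len(pattern)] == pattern
    ]
```

Span-based masking then uses such start indices to delimit assistant-turn spans.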
The wrapper never uses list[int] directly — it passes through to DataCollatorForCompletionOnlyLM. All callers pass strings. Update tests to pass string templates instead of token ID lists.
Raise ValueError at init time when masking_method is assistant_turn or assistant_turn_no_tools but end_of_turn_template is None. Previously this silently passed init and crashed with AssertionError on the first batch during training.
Replace 4 repeated isinstance/encode blocks with a single helper method. Move _KNOWN_MASKING_METHODS from an __init__ local to a class-level constant.
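The deduplicated helper might look roughly like this (a free-function sketch; the names and the fake tokenizer are illustrative, not the PR's actual code):

```python
class FakeTokenizer:
    """Stand-in for a HF tokenizer; real code would call tokenizer.encode."""

    def encode(self, text, add_special_tokens=False):
        return [ord(c) for c in text]


def encode_template(template, tokenizer):
    """Encode a string template to token IDs once; pass through None or ID lists."""
    if template is None:
        return None
    if isinstance(template, str):
        return tokenizer.encode(template, add_special_tokens=False)
    return list(template)
```

Each of the four call sites then collapses to a single `encode_template(...)` call.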
Replace reference to mask_tool_calls=True with the correct usage: masking_method='assistant_turn_no_tools'.
Emit DeprecationWarning when the collator infers _legacy_instruction_response masking from the presence of instruction_template. Guides users toward masking_method.
When masking_method is not explicit and both end_of_turn_template and tool_call_start_template are present, infer assistant_turn_no_tools instead of assistant_turn.
Replace nested if/elif chain with a classmethod that validates or infers masking_method from template presence. Keeps __init__ focused on tokenization and validation.
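A sketch of what such a validate-or-infer helper could look like (written as a free function for brevity; the inference rules are taken from the commit messages above, including the init-time ValueError and the assistant_turn_no_tools inference):

```python
_KNOWN_MASKING_METHODS = {
    "assistant_turn",
    "assistant_turn_no_tools",
    "final_assistant_turn",
}


def resolve_masking_method(explicit, end_of_turn_template, tool_call_start_template):
    """Validate an explicit masking_method, or infer one from template presence."""
    if explicit is not None:
        if explicit not in _KNOWN_MASKING_METHODS:
            raise ValueError(f"Unknown masking_method: {explicit!r}")
        if explicit != "final_assistant_turn" and end_of_turn_template is None:
            raise ValueError(f"{explicit} requires end_of_turn_template")
        return explicit
    # Not explicit: infer from which templates were provided.
    if end_of_turn_template is not None and tool_call_start_template is not None:
        return "assistant_turn_no_tools"
    if end_of_turn_template is not None:
        return "assistant_turn"
    return "final_assistant_turn"
```

Keeping this logic out of `__init__` leaves the constructor focused on tokenization and validation.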
Drop the ASSISTANT_TURN_NO_TOOLS enum member and the tool_call_start_template from builder templates, matching the core collator simplification in PR1.
MaskingMethod was confusing — it sounded like the assistant turns were being masked. TrainTarget with ALL_ASSISTANT_TURNS / FINAL_ASSISTANT_TURN makes the intent clear: select what to train on.
Instead of checking for marker tokens in the tokenizer vocabulary and returning hardcoded template strings, render the tokenizer's own chat template with sentinel content and extract response_template and end_of_turn_template from the output. This works for any model with a chat template without requiring per-family hardcoded entries.
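The render-and-extract idea can be sketched as follows (a toy renderer stands in for `tokenizer.apply_chat_template`; sentinel values and helper names are hypothetical):

```python
SENTINEL_USER = "USER_SENTINEL_7f3a"
SENTINEL_ASST = "ASST_SENTINEL_9c1b"


def fake_chat_template(messages):
    """Toy renderer standing in for tokenizer.apply_chat_template."""
    return "".join(f"<|{m['role']}|>{m['content']}<|eot|>" for m in messages)


def extract_templates(render):
    """Render a sentinel conversation, then slice out the assistant header
    (response_template) and the turn suffix (end_of_turn_template)."""
    rendered = render(
        [
            {"role": "user", "content": SENTINEL_USER},
            {"role": "assistant", "content": SENTINEL_ASST},
        ]
    )
    user_end = rendered.index(SENTINEL_USER) + len(SENTINEL_USER)
    asst_start = rendered.index(SENTINEL_ASST)
    response_template = rendered[user_end:asst_start]
    end_of_turn_template = rendered[asst_start + len(SENTINEL_ASST) :]
    # The slice between the two turns starts with the user turn's EOT; strip it.
    if response_template.startswith(end_of_turn_template):
        response_template = response_template[len(end_of_turn_template) :]
    return response_template, end_of_turn_template
```

Because the extraction runs against the tokenizer's own rendered output, no per-family table of marker tokens is needed.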
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Remove mutual exclusivity check so users can set train_target for auto-resolution and override individual templates via collator_kwargs.
Move train_target inference from the collator into build_collator_from_config so the builder is the single decision point. The collator now receives a resolved train_target and only validates structural invariants.
- Builder handles new config (train_target set, auto-detect templates) and old config (infer train_target from collator_kwargs)
- Collator: remove _resolve_train_target, make train_target required, derive _VALID_TRAIN_TARGETS from TrainTarget enum
- Add builder tests for old-recipe inference paths
if not response_template.strip() or not end_of_turn_template.strip():
    raise ValueError(_FALLBACK_MSG)
        debug=debug,
        ignore_index=(
            label_ignore_index if label_ignore_index is not None else -100
        ),
        **kwargs,
    )
raise ValueError(f"Unknown data collator name: '{collator_name}'")
Bug: The TextCompletionsCollatorWithPadding constructor can receive ignore_index both as an explicit argument and within **kwargs, causing a TypeError if a user customizes it.
Severity: HIGH
Suggested Fix
Before calling the TextCompletionsCollatorWithPadding constructor, pop ignore_index from the kwargs dictionary. Use the popped value to determine the final ignore_index value to be passed, preventing the argument duplication.
Prompt for AI Agent
Review the code at the location below. A potential bug has been identified by an AI
agent. Verify if this is a real issue. If it is, propose a fix; if not, explain why it's
not valid.
Location: src/oumi/builders/collators.py#L243-L249
Potential issue: When building the `text_completions_only_with_padding` collator in
`build_data_collator`, the `ignore_index` parameter is passed explicitly while also
being potentially present in the `**kwargs` passed to the
`TextCompletionsCollatorWithPadding` constructor. If a user specifies `ignore_index` in
their `collator_kwargs` configuration, this results in a `TypeError` because the
`__init__` method receives multiple values for the same keyword argument, causing a
runtime crash during training setup.
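The suggested fix could be sketched like this (hypothetical names; the real builder lives in src/oumi/builders/collators.py):

```python
def build_text_completions_collator(collator_cls, label_ignore_index, **kwargs):
    """Pop ignore_index from user kwargs so it is never passed twice;
    an explicit user value wins over the config-level default."""
    user_value = kwargs.pop("ignore_index", None)
    if user_value is not None:
        resolved = user_value
    elif label_ignore_index is not None:
        resolved = label_ignore_index
    else:
        resolved = -100
    return collator_cls(ignore_index=resolved, **kwargs)


class DummyCollator:
    """Stand-in for TextCompletionsCollatorWithPadding."""

    def __init__(self, ignore_index, **kwargs):
        self.ignore_index = ignore_index
```

With the pop in place, a user-supplied `collator_kwargs["ignore_index"]` no longer collides with the explicit argument.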
    Raises:
        ValueError: If templates cannot be extracted.
    """
    msgs = [
This is really clever, but we should maybe include a system instruction here?
Like modify it to extract the system instruction? We no longer rely on the system instruction for masking
No I mean ensure that when we extract the different template portions, we do so using an example conversation that also includes a system instruction
Or rather, perhaps do it twice, once on a conversation with a system instruction, and once on a conversation without.
Why? Because I'm a bit concerned of the scenario where this logic doesn't use system instructions when extracting, but the user does in their data, and somehow the resulting templates we extract using this methodology don't wind up working when the data includes the presence of a system instruction for some reason.
# Locate boundaries around the second turn pair
# to avoid system-prompt effects on the first turn.
try:
    a1 = rendered.index(_SENTINEL_ASST)
nit: first_asst_start, and add _start to second turns
except ValueError:
    raise ValueError(_FALLBACK_MSG)

# End-of-turn: common token-ID prefix of the two strings that
extract this to its own method
assert isinstance(_eot_decoded, str)
end_of_turn_template = _eot_decoded

# Response template: strip the EOT prefix to get just the assistant header.
Same here, extract to its own method
if not response_template.strip() or not end_of_turn_template.strip():
    raise ValueError(_FALLBACK_MSG)

# Qwen3 and similar reasoning models inject <think>...</think> into
This description scares me, I wonder if we could have a better workaround or a louder error
from oumi.core.configs.params.data_params import TrainTarget


class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
It's not clear to me what this collator does differently than the other one.
Previously, it would find instruction_template and response_template pairs, then mask everything in between (see L324). However, there isn't an instruction_template before the tool result, so tool result content does not get masked properly. Some models (like Qwen) accidentally mask it correctly because their chat template stores the tool result under role.user. But for most models (like Llama etc.), the tool result is stored under a role.tool with a separate tag, so masking is incorrectly applied for the reasons above. The new approach basically masks everything, then only unmasks portions between the response_template and end_of_turn_template. This is more robust to the different roles and to the specifics of how tool results are formatted and stored by different chat templates.
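A toy sketch of the mask-everything-then-unmask idea (standalone and simplified; the real collator operates on tokenized batches):

```python
IGNORE_INDEX = -100


def span_unmask_labels(input_ids, assistant_spans):
    """Start from fully masked labels, then copy back only the tokens inside
    each (start, end) assistant span, i.e. response_template..end_of_turn."""
    labels = [IGNORE_INDEX] * len(input_ids)
    for start, end in assistant_spans:
        labels[start:end] = input_ids[start:end]
    return labels
```

Anything not inside an assistant span, including tool results, stays at the ignore index regardless of its role tag.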
lefft left a comment
Can we be sure to check whether template detection and collation happens properly when user or assistant turns begin with (one or more) leading \ns? We had issues with this breaking tokenization and leading to records being skipped in enterprise, would like to understand whether these changes will help, harm, or not impact that issue.
)


def _resolve_collator_templates(
Has this been tested with gpt-oss models? They use a non-standard system for marking boundaries within responses, so wondering if boundary detection will work as it does for Qwen/Llama style templates.
collator_kwargs["train_target"] = "all_assistant_turns"
elif has_inst:
    warnings.warn(
        "Instruction-based masking is deprecated.\n"
Is there a plan/intention to eliminate support for instruction-based masking altogether, or will we just keep it in a "deprecated" state?
Want to make sure we retain compatibility with existing enterprise configs (or verify that moving to this style doesn't change training dynamics before adapting enterprise configs).
Description
Context
Currently, the completions-only collator does not mask tool responses (Role.TOOL) correctly for all chat templates. For some (like Qwen), it accidentally gets it right, while for Llama chat templates masking is consistently applied wrongly. This PR makes masking correct for all chat templates, supporting all four roles (system, user, assistant, and tool).
Without `instruction_template`, tool results are masked correctly. But with `instruction_template`, things get weird and some tool results are unmasked. Omitting `instruction_template` causes the collator to find the last `response_template` and mask everything before it (only the final assistant turn trains). With `instruction_template`, the collator searches for `instruction_template` and `response_template` pairs and masks everything in between.

Case 1: without `instruction_template`

Without an instruction template, the collator masks everything before the last response template, so the loss only sees the final assistant turn.

Case 2: with `instruction_template`

With `instruction_template`, it masks everything between an instruction and a response template. Since there isn't an instruction template before the tool result, the tool result does not get masked properly. In case 2, the loss sees the tool result, which is incorrect.
To confirm my hypothesis that the missing instruction template before the tool result is the cause, I experimented with adding a user turn before the tool call, and masking worked correctly in that case.
ToolAwareCompletionsCollator

With the new `ToolAwareCompletionsCollator`, tool results are masked properly.

Without `instruction_template`:
[screenshot: masking output without instruction_template]

With `instruction_template`:
[screenshot: masking output with instruction_template]

Change 1: Extend `DataCollatorForCompletionOnlyLM` with span-based masking

Extend `DataCollatorForCompletionOnlyLM` with span-based masking (detecting assistant turn boundaries via response + end-of-turn tokens) beyond the instruction-based masking (matching instruction/response string pairs) supported currently.
Change 2: `train_target` enum for explicit control of which part of the response to train on

- assistant_turn: Masks everything, then unmasks each assistant response bounded by `response_template`..`end_of_turn_template` (inclusive of EOT).
- final_assistant_turn: Masks all tokens before the last `response_template` occurrence.

When `train_target` is set, it has final authority regardless of which templates have been provided. When `train_target` is not set, it is inferred from the templates provided. Eventually, we want users to specify `train_target` instead of relying on combinations of templates to control behavior. However, since we have many configs in production that do not use `train_target` and set templates explicitly, this shall be the interim solution to ensure old configs do not break. Existing configs using `collator_kwargs` with `instruction_template` + `response_template` continue to work via the legacy path (with a deprecation warning). New configs should use `train_target` explicitly. Users can set `train_target` and override the templates manually by setting `collator_kwargs`.

Before — users must provide model-specific token strings:
After — users express intent, templates auto-resolve from tokenizer:
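The before/after snippets did not survive extraction; a reconstruction of roughly what they might show (field names from this PR's description; the token strings are illustrative Llama-style values, not verified config):

```yaml
# Before: model-specific token strings in collator_kwargs (illustrative values)
collator_name: completions_only
collator_kwargs:
  instruction_template: "<|start_header_id|>user<|end_header_id|>"
  response_template: "<|start_header_id|>assistant<|end_header_id|>"

# After: express intent; templates auto-resolve from the tokenizer
collator_name: completions_only
train_target: all_assistant_turns
```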
Change 3: Auto detection of collator templates
Instead of fragile vocab and collator-template matching (which requires us to maintain collator templates for all popular models), shift to a more robust approach: apply the chat template to a known test conversation, then find the assistant boundary strings in the rendered output.
Migration and deprecation considerations:
- `collator_name` -> removed support for this with the removal of the default Llama template fallback. This should be OK as no configs in production specify `collator_name` without `collator_kwargs`. Fewer than 10 configs in OSS fall into this edge case, and they are being updated in this PR.

Open issues for next time
- `train_target` into `collator_kwargs` -> I considered putting `train_target` inside `collator_kwargs` instead of making it a separate field, but `collator_kwargs` is an opaque dict. Users wouldn't know `train_target` exists unless they read the docs. A top-level field makes it more visible. This will eventually be solved when we shift to `CollatorParams`.
- `CollatorParams`: This would mean changing every YAML from one format to the other. This is much cleaner, but it's a substantial refactor. To manage the scope of this PR, it shall be deferred to a future PR.
- `build_collator_from_config` builds a single collator from the training split's settings, which is reused for all splits (train, validation, test). Validation/test split-specific collator settings (e.g. a different `train_target`) are silently ignored. This is pre-existing behavior not introduced by this PR, but worth noting as a known limitation. Fixing it would require refactoring the training loop to build per-split collators.
N/A
Before submitting
Reviewers
At least one review from a member of oumi-ai/oumi-staff is required.