Add span-based masking to DataCollatorForCompletionOnlyLM #2368
Closed
shanghongsim wants to merge 14 commits into main
Merge tool-aware masking directly into the existing TRL-derived collator. Adds masking_method parameter with three strategies: assistant_turn, assistant_turn_no_tools, and final_assistant_turn.
- Add _find_pattern, _span_contains, _apply_span_masking to DataCollatorForCompletionOnlyLM
- Rename instruction_prefix/response_prefix to instruction_template/response_template
- Add end_of_turn_template, tool_call_start_template, masking_method params
- Update builder to pass new params and support span-based masking
- Add comprehensive tests for all masking strategies
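As a rough illustration of what span-based masking means here, the sketch below works on plain lists of token IDs. It is not the PR's implementation: the function names, the token IDs, and the choice to leave the response template itself masked are all assumptions made for the example. Only the overall idea (mask everything, then unmask assistant spans up to and including the end-of-turn tokens) comes from the PR text.

```python
IGNORE_INDEX = -100  # conventional label value ignored by the loss


def find_pattern(seq, pattern, start=0):
    """Return the first index >= start where pattern occurs in seq, or -1."""
    m = len(pattern)
    if m == 0:
        return -1
    for i in range(start, len(seq) - m + 1):
        if seq[i:i + m] == pattern:
            return i
    return -1


def labels_for_all_assistant_turns(input_ids, response_ids, eot_ids):
    """Mask everything, then unmask each assistant span (EOT included)."""
    labels = [IGNORE_INDEX] * len(input_ids)
    pos = 0
    while True:
        resp = find_pattern(input_ids, response_ids, pos)
        if resp == -1:
            break
        content_start = resp + len(response_ids)
        eot = find_pattern(input_ids, eot_ids, content_start)
        # Assumption: a turn with no end-of-turn marker runs to the end.
        end = len(input_ids) if eot == -1 else eot + len(eot_ids)
        labels[content_start:end] = input_ids[content_start:end]
        pos = end
    return labels


def labels_for_final_assistant_turn(input_ids, response_ids):
    """Mask all tokens up to and including the last response template."""
    labels = [IGNORE_INDEX] * len(input_ids)
    last, pos = -1, 0
    while True:
        found = find_pattern(input_ids, response_ids, pos)
        if found == -1:
            break
        last, pos = found, found + 1
    if last != -1:
        start = last + len(response_ids)
        labels[start:] = input_ids[start:]
    return labels


# Hypothetical IDs: 1 = user marker, [2, 3] = response template, 9 = EOT.
ids = [1, 10, 11, 9, 2, 3, 20, 21, 9, 1, 12, 9, 2, 3, 30, 9]
all_turns = labels_for_all_assistant_turns(ids, [2, 3], [9])
final_turn = labels_for_final_assistant_turn(ids, [2, 3])
```

With these inputs, `all_turns` keeps `[20, 21, 9]` and `[30, 9]` as labels, while `final_turn` keeps only `[30, 9]`.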
The wrapper never uses list[int] directly — it passes through to DataCollatorForCompletionOnlyLM. All callers pass strings. Update tests to pass string templates instead of token ID lists.
Raise ValueError at init time when masking_method is assistant_turn or assistant_turn_no_tools but end_of_turn_template is None. Previously this silently passed init and crashed with AssertionError on the first batch during training.
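The fail-fast check described in this commit could look roughly like the sketch below. The class name `SpanMaskingSettings` is hypothetical and stands in for the collator's `__init__`; only the parameter names and the two masking methods that require `end_of_turn_template` are taken from the commit message.

```python
class SpanMaskingSettings:
    """Hypothetical stand-in for the collator's constructor arguments."""

    # Methods that need an end-of-turn marker to bound each assistant span.
    NEEDS_END_OF_TURN = frozenset({"assistant_turn", "assistant_turn_no_tools"})

    def __init__(self, response_template, masking_method=None,
                 end_of_turn_template=None):
        if (masking_method in self.NEEDS_END_OF_TURN
                and end_of_turn_template is None):
            # Fail at construction time instead of on the first batch.
            raise ValueError(
                f"masking_method='{masking_method}' requires "
                "end_of_turn_template to be set."
            )
        self.response_template = response_template
        self.masking_method = masking_method
        self.end_of_turn_template = end_of_turn_template
```

Validating in the constructor surfaces a misconfiguration when the collator is built rather than partway into a training run.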
…class level Replace 4 repeated isinstance/encode blocks with a single helper method. Move _KNOWN_MASKING_METHODS from __init__ local to class-level constant.
Replace reference to mask_tool_calls=True with the correct usage: masking_method='assistant_turn_no_tools'.
Emit DeprecationWarning when the collator infers _legacy_instruction_response masking from the presence of instruction_template. Guides users toward masking_method.
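A minimal sketch of how such a warning might be emitted, using only the standard `warnings` module. The function name `infer_masking_method` and the warning text are assumptions for illustration; the `_legacy_instruction_response` value and the trigger (an `instruction_template` without an explicit `masking_method`) follow the commit message.

```python
import warnings

LEGACY_METHOD = "_legacy_instruction_response"


def infer_masking_method(masking_method=None, instruction_template=None):
    """Hypothetical helper: fall back to the legacy path with a warning."""
    if masking_method is None and instruction_template is not None:
        warnings.warn(
            "Inferring legacy instruction/response masking from "
            "instruction_template is deprecated; set masking_method "
            "explicitly.",
            DeprecationWarning,
            stacklevel=2,  # attribute the warning to the caller
        )
        return LEGACY_METHOD
    return masking_method
```

Note that Python hides `DeprecationWarning` by default outside `__main__` and test runners, so users may need `-W default` or a warnings filter to see it.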
When masking_method is not explicit and both end_of_turn_template and tool_call_start_template are present, infer assistant_turn_no_tools instead of assistant_turn.
Replace nested if/elif chain with a classmethod that validates or infers masking_method from template presence. Keeps __init__ focused on tokenization and validation.
8309130 to 4c081d7
Comment on lines +102 to +110

    if masking_method not in cls._KNOWN_MASKING_METHODS:
        valid = sorted(
            cls._KNOWN_MASKING_METHODS - {"_legacy_instruction_response"}
        )
        raise ValueError(
            f"Unknown masking_method='{masking_method}'. "
            f"Must be one of: {valid}"
        )
    return masking_method

This comment was marked as outdated.
Comment on lines +94 to +96

    Priority (first match wins):
    1. Explicit masking_method (validated)
    2. end_of_turn + tool_call_start → assistant_turn_no_tools

This comment was marked as outdated.
Drop the assistant_turn_no_tools masking mode and the tool_call_start_template parameter it depended on, simplifying the span-based masking to only support assistant_turn and final_assistant_turn strategies.
Remove hardcoded Llama-style instruction/response template defaults from build_data_collator — callers must now provide response_template explicitly. Inline the single-use _collate method in the wrapper.
Replace the classmethod + intermediate booleans with a simple if/elif chain directly in __init__. Rename _KNOWN_MASKING_METHODS to _VALID_MASKING_METHODS since the legacy path is internal.
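For orientation, a simple if/elif resolution of this kind might look like the following. This is a sketch under assumptions: the class name `MaskingChoice` is hypothetical, and the inference order and the final fallback to `final_assistant_turn` are guesses, since the PR text does not spell them out. The constant name `_VALID_MASKING_METHODS`, the two public strategies, and the internal legacy value come from the commits above.

```python
_VALID_MASKING_METHODS = frozenset({"assistant_turn", "final_assistant_turn"})
_LEGACY_METHOD = "_legacy_instruction_response"  # internal, not user-selectable


class MaskingChoice:
    """Hypothetical stand-in showing only the masking_method resolution."""

    def __init__(self, masking_method=None, instruction_template=None,
                 end_of_turn_template=None):
        if masking_method is not None:
            # An explicit choice always wins, but must be a public method.
            if masking_method not in _VALID_MASKING_METHODS:
                raise ValueError(
                    f"Unknown masking_method='{masking_method}'. "
                    f"Must be one of: {sorted(_VALID_MASKING_METHODS)}"
                )
            resolved = masking_method
        elif instruction_template is not None:
            resolved = _LEGACY_METHOD  # assumed: legacy configs keep working
        elif end_of_turn_template is not None:
            resolved = "assistant_turn"  # assumed inference rule
        else:
            resolved = "final_assistant_turn"  # assumed fallback
        self.masking_method = resolved
```

Keeping the legacy value out of the valid set means it can only be reached by inference, never requested by name.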
Comment on lines +130 to +134

    if not kwargs.get("response_template"):
        raise ValueError(
            "'text_completions_only_with_padding' requires a "
            "response_template. Provide it via collator_kwargs."
        )

This comment was marked as outdated.
…tant_turns Rename for clarity: the parameter selects what to train on, not what to mask. assistant_turn becomes all_assistant_turns to distinguish from final_assistant_turn.
86c00e7 to df0d473
Comment on lines +139 to 144

            ignore_index=(
                label_ignore_index if label_ignore_index is not None else -100
            ),
            **kwargs,
        )
    raise ValueError(f"Unknown data collator name: '{collator_name}'")

Bug: Passing ignore_index in collator_kwargs will cause a TypeError due to receiving multiple values for the same keyword argument in the TextCompletionsCollatorWithPadding constructor.
Severity: MEDIUM
Suggested Fix
Before calling the TextCompletionsCollatorWithPadding constructor, remove the ignore_index key from the kwargs dictionary to prevent the argument collision. For example: kwargs.pop("ignore_index", None).
Prompt for AI Agent
Review the code at the location below. A potential bug has been identified by an AI
agent. Verify if this is a real issue. If it is, propose a fix; if not, explain why it's
not valid.
Location: src/oumi/builders/collators.py#L139-L144
Potential issue: The `build_data_collator` function calls the
`TextCompletionsCollatorWithPadding` constructor, passing `ignore_index` as an explicit
keyword argument while also unpacking `collator_kwargs` into the same call. If a user's
configuration includes `ignore_index` within `collator_kwargs`, the argument will be
passed twice. This will raise a `TypeError` because the constructor receives multiple
values for the `ignore_index` keyword argument. The code does not currently handle this
potential collision by removing the key from the `kwargs` dictionary before the call.
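The collision described above is ordinary Python behavior and can be reproduced without the library. In the sketch below, `make_collator`, `build_unsafe`, and `build_safe` are hypothetical stand-ins for the constructor and for `build_data_collator`; only the `ignore_index` / `label_ignore_index` names and the suggested `kwargs.pop` fix come from the review comment.

```python
def make_collator(ignore_index=-100, **extra):
    """Hypothetical stand-in for the collator constructor."""
    return {"ignore_index": ignore_index, **extra}


def build_unsafe(label_ignore_index=None, **kwargs):
    # Passes ignore_index explicitly AND forwards kwargs unchanged, so a
    # user-supplied ignore_index arrives twice and Python raises TypeError.
    return make_collator(
        ignore_index=(
            label_ignore_index if label_ignore_index is not None else -100
        ),
        **kwargs,
    )


def build_safe(label_ignore_index=None, **kwargs):
    kwargs = dict(kwargs)             # avoid mutating the caller's dict
    kwargs.pop("ignore_index", None)  # the fix suggested in the review
    return make_collator(
        ignore_index=(
            label_ignore_index if label_ignore_index is not None else -100
        ),
        **kwargs,
    )


try:
    build_unsafe(ignore_index=0)
    collision = None
except TypeError as exc:
    collision = str(exc)  # mentions multiple values for 'ignore_index'

fixed = build_safe(label_ignore_index=-1, ignore_index=0,
                   response_template="<|assistant|>")
```

One trade-off of the `pop` approach is that a user-supplied `ignore_index` is silently discarded; raising a clear error or honoring the user's value would be alternative designs.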
Merged into #2369 which now targets main directly.
Description
Extend DataCollatorForCompletionOnlyLM with span-based masking (detecting assistant turn boundaries via response + end-of-turn tokens), in addition to the instruction-based masking (matching instruction/response string pairs) that is currently supported. Additionally, allow explicit control of which part of the response to train on using the train_target parameter:
- assistant_turn: Masks everything, then unmasks each assistant response bounded by response_template..end_of_turn_template (inclusive of EOT).
- final_assistant_turn: Masks all tokens before the last response_template occurrence.

When train_target is not set, the mode is inferred from template presence. train_target ultimately has the final say on which template is used. Existing configs using collator_kwargs with instruction_template + response_template continue to work via the legacy path (with a deprecation warning). New configs should use train_target explicitly.

Related issues
N/A
Part 1 of 2; part 2 is #2369 (adds the MaskingMethod enum for YAML config).

Before submitting

Reviewers
At least one review from a member of oumi-ai/oumi-staff is required.