Add tool aware collator to mask tool response correctly #2356
`src/oumi/builders/collators.py`:

```diff
@@ -19,6 +19,9 @@
 from oumi.core.collators.text_completions_collator_with_padding import (
     TextCompletionsCollatorWithPadding,
 )
+from oumi.core.collators.tool_aware_completions_collator import (
+    ToolAwareCompletionsCollator,
+)
 from oumi.core.collators.vision_language_collator_with_padding import (
     VisionLanguageCollatorWithPadding,
 )
```
```diff
@@ -52,6 +55,8 @@ def build_data_collator(
     - "text_with_padding": Uses `TextCollatorWithPadding`.
     - "text_completions_only_with_padding": Uses
       `TextCompletionsCollatorWithPadding`.
+    - "tool_aware_completions_only": Uses `ToolAwareCompletionsCollator`.
+      Correctly masks tool results in tool-calling conversations.
     - "vision_language_with_padding": Uses `VisionLanguageCollatorWithPadding`.
     - "vision_language_sft": Uses `VisionLanguageSftCollator`.
```
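The builder's error messages (in the next hunk) indicate that the templates are supplied through `collator_kwargs`. A hedged sketch of what selecting this collator might look like in a training config — the key nesting within a full Oumi config file is an assumption here; only `collator_name`, `collator_kwargs`, and the individual kwarg names come from the diff:

```yaml
# Illustrative fragment only: the placement of these keys inside a full
# Oumi training config is assumed, not taken from Oumi documentation.
collator_name: tool_aware_completions_only
collator_kwargs:
  response_template: "<|im_start|>assistant\n"
  end_of_turn_template: "<|im_end|>"
  mask_tool_calls: true                     # re-mask assistant spans containing tool calls
  tool_call_start_template: "<tool_call>"
```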
```diff
@@ -149,6 +154,31 @@ def build_data_collator(
             debug=debug,
             **kwargs,
         )
+    elif collator_name == "tool_aware_completions_only":
+        response_template = kwargs.pop("response_template", None)
+        end_of_turn_template = kwargs.pop("end_of_turn_template", None)
+        mask_tool_calls = kwargs.pop("mask_tool_calls", False)
+        tool_call_start_template = kwargs.pop("tool_call_start_template", None)
+
+        if not response_template:
```
**Contributor:** If I remember correctly we made a collator that can identify this itself?

**Contributor (author):** Based on my understanding and research, there isn't a collator in the codebase that auto-identifies the response template. They all need it provided explicitly or fall back to Llama defaults. (See `oumi/src/oumi/builders/collators.py`, line 133 at commit `81611e4`.)
```diff
+            raise ValueError(
+                f"'response_template' is required for '{collator_name}'. "
+                "Set it in collator_kwargs (e.g. '<|im_start|>assistant\\n')."
+            )
+        if not end_of_turn_template:
+            raise ValueError(
+                f"'end_of_turn_template' is required for '{collator_name}'. "
+                "Set it in collator_kwargs (e.g. '<|im_end|>')."
+            )
+
+        return ToolAwareCompletionsCollator(
+            response_template=response_template,
+            end_of_turn_template=end_of_turn_template,
+            mask_tool_calls=mask_tool_calls,
+            tool_call_start_template=tool_call_start_template,
+            tokenizer=tokenizer,
+            **kwargs,
+        )

     raise ValueError(f"Unknown data collator name: '{collator_name}'")
```

shanghongsim marked this conversation as resolved. (Outdated)
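The branch above pops its four collator-specific kwargs before forwarding the remainder, so unrelated entries in `collator_kwargs` still reach the collator's constructor. A minimal standalone sketch of that pop-and-validate pattern — `pop_and_validate` is a hypothetical helper written for illustration, not Oumi code:

```python
# Hypothetical helper mirroring the builder branch above: pop the
# collator-specific kwargs, validate the required ones, and return the
# leftovers so they can be forwarded via **kwargs.
def pop_and_validate(collator_kwargs: dict) -> tuple[dict, dict]:
    kwargs = dict(collator_kwargs)  # Copy so pops don't mutate the caller's dict.
    specific = {
        "response_template": kwargs.pop("response_template", None),
        "end_of_turn_template": kwargs.pop("end_of_turn_template", None),
        "mask_tool_calls": kwargs.pop("mask_tool_calls", False),
        "tool_call_start_template": kwargs.pop("tool_call_start_template", None),
    }
    if not specific["response_template"]:
        raise ValueError("'response_template' is required.")
    if not specific["end_of_turn_template"]:
        raise ValueError("'end_of_turn_template' is required.")
    return specific, kwargs


specific, rest = pop_and_validate(
    {
        "response_template": "<|im_start|>assistant\n",
        "end_of_turn_template": "<|im_end|>",
        "debug": True,  # Unrelated kwarg: passes through untouched.
    }
)
print(rest)  # {'debug': True}
```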
`src/oumi/core/collators/tool_aware_completions_collator.py` (new file, +210 lines):

```python
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import Any

from transformers.data.data_collator import DataCollatorForLanguageModeling


class ToolAwareCompletionsCollator(DataCollatorForLanguageModeling):
    r"""Completion-only collator that correctly masks tool results and tool calls.

    The standard ``DataCollatorForCompletionOnlyLM`` uses template matching to
    decide which tokens to train on. It works well for simple user/assistant
    conversations, but breaks when a third role (``tool``) sits between two
    assistant turns: the tool-result span ends up unmasked because the collator
    only knows about the user-role marker as an "instruction" boundary.

    This collator takes a different approach:

    1. Start with **all labels masked** (-100).
    2. Find every assistant response span by locating ``response_template``
       tokens and scanning forward to the next ``end_of_turn_template``.
    3. **Unmask** those spans so the model trains on them.
    4. Optionally **re-mask** spans that contain tool-call content (controlled
       by ``mask_tool_calls``), so the model only trains on plain-text replies.

    Because the algorithm never relies on user/instruction markers, it handles
    any number of tool turns, parallel tool calls, and multi-turn conversations
    correctly.

    Args:
        response_template: String or token-ID list that marks the *start* of an
            assistant response (e.g. ``"<|im_start|>assistant\n"`` for SmolLM2
            or ``"[/INST]"`` for Llama-2).
        end_of_turn_template: String or token-ID list that marks the *end* of a
            turn (e.g. ``"<|im_end|>"`` for SmolLM2 or ``"</s>"`` for Llama).
        mask_tool_calls: When ``True``, assistant spans that contain
            ``tool_call_start_template`` are re-masked. Set this to ``True``
            if you only want to train on plain-text final answers. Defaults to
            ``False`` (train on all assistant output including tool calls).
        tool_call_start_template: String or token-ID list that marks the start
            of a tool-call block inside an assistant turn (e.g.
            ``"<tool_call>"``). Required when ``mask_tool_calls=True``.
        ignore_index: Value used for masked labels. Must match the
            ``ignore_index`` of the loss function (default: -100).
    """

    def __init__(
        self,
        response_template: str | list[int],
        end_of_turn_template: str | list[int],
        *args,
        mask_tool_calls: bool = False,
        tool_call_start_template: str | list[int] | None = None,
        ignore_index: int = -100,
        mlm: bool = False,
        **kwargs,
    ):
        """Initializes ToolAwareCompletionsCollator."""
        super().__init__(*args, mlm=mlm, **kwargs)
        self.ignore_index = ignore_index

        if isinstance(response_template, str):
            self.response_token_ids: list[int] = self.tokenizer.encode(
                response_template, add_special_tokens=False
            )
        else:
            self.response_token_ids = list(response_template)

        if isinstance(end_of_turn_template, str):
            self.end_of_turn_token_ids: list[int] = self.tokenizer.encode(
                end_of_turn_template, add_special_tokens=False
            )
        else:
            self.end_of_turn_token_ids = list(end_of_turn_template)

        self.mask_tool_calls = mask_tool_calls
        self.tool_call_start_token_ids: list[int] | None = None
        if mask_tool_calls:
            if tool_call_start_template is None:
                raise ValueError(
                    "tool_call_start_template must be provided "
                    "when mask_tool_calls=True"
                )
            if isinstance(tool_call_start_template, str):
                self.tool_call_start_token_ids = self.tokenizer.encode(
                    tool_call_start_template, add_special_tokens=False
                )
            else:
                self.tool_call_start_token_ids = list(tool_call_start_template)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _find_pattern(self, seq: list[int], pattern: list[int]) -> list[int]:
        """Return all start positions where *pattern* appears in *seq*."""
        plen = len(pattern)
        if plen == 0:
            return []
        first = pattern[0]
        positions = []
        for i in range(len(seq) - plen + 1):
            if seq[i] == first and seq[i : i + plen] == pattern:
                positions.append(i)
        return positions

    def _span_contains(
        self, seq: list[int], span_start: int, span_end: int, pattern: list[int]
    ) -> bool:
        """Return True if *pattern* appears anywhere in seq[span_start:span_end]."""
        plen = len(pattern)
        if plen == 0:
            return False
        first = pattern[0]
        for i in range(span_start, span_end - plen + 1):
            if seq[i] == first and seq[i : i + plen] == pattern:
                return True
        return False

    # ------------------------------------------------------------------
    # Main collation
    # ------------------------------------------------------------------

    def torch_call(
        self, examples: list[list[int] | Any | dict[str, Any]]
    ) -> dict[str, Any]:
        """Collates examples and applies tool-aware label masking."""
        # Let the base class handle padding and create labels = input_ids.
        batch = super().torch_call(examples)

        resp_len = len(self.response_token_ids)
        pad_token_id = self.tokenizer.pad_token_id

        for i in range(len(examples)):
            # Step 1: mask everything.
            batch["labels"][i, :] = self.ignore_index

            seq: list[int] = batch["input_ids"][i].tolist()

            # Compute the effective sequence length excluding trailing padding.
            # This prevents false matches when end_of_turn_token_ids overlaps
            # with the pad token (common: e.g. <|im_end|> = eos = pad).
            if pad_token_id is not None:
                n = len(seq)
                while n > 0 and seq[n - 1] == pad_token_id:
                    n -= 1
            else:
                n = len(seq)

            # Step 2: find every assistant response start position
            # (within the non-padded region only).
            resp_positions = self._find_pattern(seq[:n], self.response_token_ids)

            if len(resp_positions) == 0:
                warnings.warn(
                    f"Could not find response template in the following instance: "
                    f"{self.tokenizer.decode(batch['input_ids'][i])}. "
                    "This instance will be ignored in loss calculation.",
                    UserWarning,
                )
                continue

            for resp_pos in resp_positions:
                # Content starts right after the response_template tokens.
                content_start = resp_pos + resp_len

                # Step 3: find the next end_of_turn after content_start
                # (within the non-padded region only).
                eot_positions = self._find_pattern(
                    seq[content_start:n], self.end_of_turn_token_ids
                )
                if eot_positions:
                    content_end = content_start + eot_positions[0]
                else:
                    # No closing marker found: unmask to end of real content.
                    content_end = n

                if content_start >= content_end:
                    continue

                # Step 4: optionally skip tool-call spans.
                if self.mask_tool_calls and self.tool_call_start_token_ids is not None:
                    if self._span_contains(
                        seq,
                        content_start,
                        content_end,
                        self.tool_call_start_token_ids,
                    ):
                        # Leave this span masked.
                        continue

                # Step 5: unmask this assistant response span.
                batch["labels"][i, content_start:content_end] = batch["input_ids"][
                    i, content_start:content_end
                ]

        return batch
```
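The masking algorithm from the docstring can be demonstrated end to end on toy data. The sketch below re-implements steps 1-5 with plain Python lists standing in for token tensors; no tokenizer or torch is needed, and `RESP`, `EOT`, `TOOL`, and `PAD` are arbitrary stand-in token ids chosen for the example, not real vocabulary:

```python
# Toy re-implementation of the tool-aware masking algorithm using ints
# as "token ids". Stand-ins (all arbitrary): RESP = response template,
# EOT = end of turn, TOOL = tool-call start marker, PAD = pad token.
IGNORE = -100
RESP = [1, 2]
EOT = [9]
TOOL = [7]
PAD = 0


def find_pattern(seq, pattern):
    """All start positions where `pattern` occurs in `seq`."""
    return [
        i
        for i in range(len(seq) - len(pattern) + 1)
        if seq[i : i + len(pattern)] == pattern
    ]


def mask_labels(seq, mask_tool_calls=False):
    labels = [IGNORE] * len(seq)            # Step 1: mask everything.
    n = len(seq)
    while n > 0 and seq[n - 1] == PAD:      # Exclude trailing padding.
        n -= 1
    for pos in find_pattern(seq[:n], RESP):  # Step 2: response starts.
        start = pos + len(RESP)
        eots = find_pattern(seq[start:n], EOT)  # Step 3: next end of turn.
        end = start + eots[0] if eots else n
        if mask_tool_calls and find_pattern(seq[start:end], TOOL):
            continue                        # Step 4: leave tool-call span masked.
        labels[start:end] = seq[start:end]  # Step 5: unmask the span.
    return labels


# user turn (5, 5), assistant tool call (7, 8), tool result (6, 6),
# assistant final answer (3, 4), then padding.
seq = [5, 5, 1, 2, 7, 8, 9, 6, 6, 1, 2, 3, 4, 9, 0, 0]
print(mask_labels(seq))
# [-100, -100, -100, -100, 7, 8, -100, -100, -100, -100, -100, 3, 4, -100, -100, -100]
print(mask_labels(seq, mask_tool_calls=True))
# [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 3, 4, -100, -100, -100]
```

With `mask_tool_calls=True`, the first assistant span (which contains the stand-in tool-call token `7`) stays masked while the final plain-text answer is still unmasked; the tool-result tokens (`6, 6`) are masked in both modes, which is the failure case the PR fixes.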
**Contributor:** It seems like it inherits padding support from the base class, should we add "WithPadding" to its name? Or are we moving away from that naming?

**Contributor (author):** Just to clarify, this class actually uses composition rather than inheritance (it wraps a transformers `DataCollatorForCompletionOnlyLM` instance internally). The padding behavior comes from that inner collator's chain, so the "WithPadding" in the name describes what the collator does, not something it inherits. That said, I agree the naming could be better. @oelachqar was suggesting a refactor of all collators into a universal one with options (`padding=True`, `completion_only=True`), actually. Exploring that now to see if it can be done in this PR.
DataCollatorForCompletionOnlyLMinstance internally). The padding behavior comes from that inner collator's chain. So the "WithPadding" in the name is describing what the collator does, not something it inherits. That said, I agree the naming could be better. @oelachqar was suggesting a refactor of all collators into a universal one with options (padding=True, completion_only=True) actually. Exploring that now to see if that can be done in this PR