-
Notifications
You must be signed in to change notification settings - Fork 751
Extend DataCollatorForCompletionOnlyLM to support correct tool result masking #2369
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 27 commits
6726786
1c2f842
190b909
2948700
ec9d609
e9ee991
ab6e081
4c081d7
6682e56
3b92ca5
e67c3db
71b38ee
df0d473
414dd4c
33a5da8
b4583e4
c6744bb
780606c
b14ab66
e56f872
33c768e
0136af9
928c274
fbc7087
d5d93f4
a06055f
0cdb2bc
1c1dadf
e30dc99
1a8e1d2
48b65a1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,6 +12,7 @@ | |
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import warnings | ||
| from collections.abc import Callable | ||
|
|
||
| import oumi.core.constants as constants | ||
|
|
@@ -27,11 +28,101 @@ | |
| from oumi.core.configs.internal.supported_models import ( | ||
| find_internal_model_config, | ||
| ) | ||
| from oumi.core.configs.params.data_params import TrainTarget | ||
| from oumi.core.tokenizers.base_tokenizer import BaseTokenizer | ||
| from oumi.utils.logging import logger | ||
|
|
||
| # This is used to set the max input length for a model with infinite size input | ||
| _VERY_LARGE_INTEGER = int(1e30) | ||
| _SENTINEL_USER = "<<__U__>>" | ||
| _SENTINEL_ASST = "<<__A__>>" | ||
| _FALLBACK_MSG = ( | ||
| "Cannot auto-detect collator templates from the chat template. " | ||
| "Provide response_template (and end_of_turn_template for " | ||
| "all_assistant_turns) via collator_kwargs." | ||
| ) | ||
|
|
||
|
|
||
| def _resolve_collator_templates( | ||
| tokenizer: "BaseTokenizer", | ||
| ) -> tuple[str, str]: | ||
| """Auto-detect response_template and end_of_turn_template. | ||
|
|
||
| Applies the chat template to a known test conversation, then finds | ||
| the assistant boundary strings in the rendered output. | ||
|
|
||
| Returns: | ||
| (response_template, end_of_turn_template) | ||
|
|
||
| Raises: | ||
| ValueError: If templates cannot be extracted. | ||
| """ | ||
| msgs = [ | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is really clever, but should we maybe include a system instruction here?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Do you mean modify it to extract the system instruction? We are currently no longer relying on the system instruction for masking.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. No, I mean ensure that when we extract the different template portions, we do so using an example conversation that also includes a system instruction. Or rather, perhaps do it twice: once on a conversation with a system instruction, and once on a conversation without one. Why? Because I'm a bit concerned about the scenario where this logic doesn't use system instructions when extracting, but the user does in their data, and somehow the resulting templates we extract using this methodology don't wind up working when the data includes a system instruction for some reason. |
||
| {"role": "user", "content": _SENTINEL_USER}, | ||
| {"role": "assistant", "content": _SENTINEL_ASST}, | ||
| {"role": "user", "content": _SENTINEL_USER}, | ||
| {"role": "assistant", "content": _SENTINEL_ASST}, | ||
| ] | ||
|
|
||
| try: | ||
| rendered = tokenizer.apply_chat_template( | ||
| msgs, tokenize=False, add_generation_prompt=False | ||
| ) | ||
| except Exception as exc: | ||
| raise ValueError(_FALLBACK_MSG) from exc | ||
|
|
||
| if not isinstance(rendered, str): | ||
| raise ValueError(_FALLBACK_MSG) | ||
|
|
||
| # Locate boundaries around the second turn pair | ||
| # to avoid system-prompt effects on the first turn. | ||
| try: | ||
| a1 = rendered.index(_SENTINEL_ASST) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: first_asst_start, and add _start to second turns |
||
| first_asst_end = a1 + len(_SENTINEL_ASST) | ||
| second_user = rendered.index(_SENTINEL_USER, first_asst_end) | ||
| second_user_end = second_user + len(_SENTINEL_USER) | ||
| second_asst = rendered.index(_SENTINEL_ASST, second_user_end) | ||
| second_asst_end = second_asst + len(_SENTINEL_ASST) | ||
| except ValueError: | ||
| raise ValueError(_FALLBACK_MSG) | ||
|
|
||
| # End-of-turn: common token-ID prefix of the two strings that | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. extract this to its own method |
||
| # follow assistant content (mid-conversation vs. end-of-sequence). | ||
| after_ids = tokenizer.encode(rendered[second_asst_end:], add_special_tokens=False) | ||
| between_ids = tokenizer.encode( | ||
| rendered[first_asst_end:second_user], add_special_tokens=False | ||
| ) | ||
| eot_len = 0 | ||
| for a, b in zip(after_ids, between_ids): | ||
| if a != b: | ||
| break | ||
| eot_len += 1 | ||
| eot_ids = after_ids[:eot_len] | ||
| end_of_turn_template: str = tokenizer.decode(eot_ids, skip_special_tokens=False) | ||
|
shanghongsim marked this conversation as resolved.
Outdated
|
||
|
|
||
|
shanghongsim marked this conversation as resolved.
|
||
| # Response template: strip the EOT prefix to get just the assistant header. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here, extract to its own method |
||
| resp_ids = tokenizer.encode( | ||
| rendered[second_user_end:second_asst], add_special_tokens=False | ||
| ) | ||
| if eot_len > 0 and resp_ids[:eot_len] == eot_ids: | ||
| resp_ids = resp_ids[eot_len:] | ||
| response_template: str = tokenizer.decode(resp_ids, skip_special_tokens=False) | ||
|
|
||
| if not response_template.strip(): | ||
| raise ValueError(_FALLBACK_MSG) | ||
|
|
||
| # Qwen3 and similar reasoning models inject <think>...</think> into | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This description scares me, I wonder if we could have a better workaround or a louder error |
||
| # every assistant turn via their chat template. If training data was | ||
| # formatted without thinking tokens the response_template won't match | ||
| # and every example will be silently masked. | ||
| if "<think>" in response_template: | ||
| logger.warning( | ||
| "The extracted response_template contains <think> tokens " | ||
| "(from the model's chat template). If you're training without " | ||
| "thinking tokens, use collator_kwargs to specify " | ||
| "response_template manually." | ||
| ) | ||
|
|
||
| return response_template, end_of_turn_template | ||
|
|
||
|
|
||
| def build_data_collator( | ||
|
|
@@ -51,7 +142,8 @@ def build_data_collator( | |
|
|
||
| - "text_with_padding": Uses `TextCollatorWithPadding`. | ||
| - "text_completions_only_with_padding": Uses | ||
| `TextCompletionsCollatorWithPadding`. | ||
| `TextCompletionsCollatorWithPadding`. Supports optional | ||
| ``end_of_turn_template`` for tool-aware span-based masking. | ||
| - "vision_language_with_padding": Uses `VisionLanguageCollatorWithPadding`. | ||
| - "vision_language_sft": Uses `VisionLanguageSftCollator`. | ||
|
|
||
|
|
@@ -126,27 +218,20 @@ def build_data_collator( | |
| **kwargs, | ||
| ) | ||
| elif collator_name == "text_completions_only_with_padding": | ||
| # Extract instruction and response templates from kwargs if provided | ||
| instruction_template = kwargs.pop("instruction_template", None) | ||
| response_template = kwargs.pop("response_template", None) | ||
|
|
||
| # Default to Llama-style templates if not provided | ||
| instruction_prefix = ( | ||
| instruction_template | ||
| if instruction_template | ||
| else "<|start_header_id|>user<|end_header_id|>\n\n" | ||
| ) | ||
| response_prefix = ( | ||
| response_template | ||
| if response_template | ||
| else "<|start_header_id|>assistant<|end_header_id|>\n\n" | ||
| ) | ||
| if not kwargs.get("response_template"): | ||
| raise ValueError( | ||
| "'text_completions_only_with_padding' requires a " | ||
| "response_template. Either set train_target in your config " | ||
| "(which auto-resolves templates from the tokenizer) or " | ||
| "provide response_template via collator_kwargs." | ||
| ) | ||
|
|
||
| return TextCompletionsCollatorWithPadding( | ||
|
shanghongsim marked this conversation as resolved.
|
||
| tokenizer=tokenizer, | ||
| instruction_prefix=instruction_prefix, | ||
| response_prefix=response_prefix, | ||
| debug=debug, | ||
| ignore_index=( | ||
| label_ignore_index if label_ignore_index is not None else -100 | ||
| ), | ||
| **kwargs, | ||
| ) | ||
| raise ValueError(f"Unknown data collator name: '{collator_name}'") | ||
|
Comment on lines
243
to
249
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Bug: [description truncated]. Suggested fix: before calling the [content truncated]. Prompt for AI Agent |
||
|
|
@@ -206,9 +291,55 @@ def build_collator_from_config( | |
| "trust_remote_code", config.model.trust_remote_code | ||
| ) | ||
|
|
||
| # Merge collator_kwargs from config with the existing kwargs | ||
| # Config kwargs take precedence over automatically determined kwargs | ||
| # --- Resolve train_target and templates --- | ||
| config_collator_kwargs = train_split.collator_kwargs or {} | ||
|
|
||
| if collator_name == "text_completions_only_with_padding": | ||
| if train_split.train_target is not None: | ||
| # Path 1: train_target is set, auto-detect templates from | ||
| # the tokenizer's chat template. Falls back to user-provided | ||
| # response_template in collator_kwargs if auto-detection fails. | ||
| collator_kwargs["train_target"] = train_split.train_target.value | ||
|
|
||
| try: | ||
| response_template, end_of_turn_template = _resolve_collator_templates( | ||
| tokenizer | ||
| ) | ||
| collator_kwargs["response_template"] = response_template | ||
| if train_split.train_target == TrainTarget.ALL_ASSISTANT_TURNS: | ||
| collator_kwargs["end_of_turn_template"] = end_of_turn_template | ||
| except ValueError: | ||
| if config_collator_kwargs.get("response_template") is None: | ||
| raise | ||
|
shanghongsim marked this conversation as resolved.
|
||
|
|
||
| elif config_collator_kwargs.get("response_template") is not None: | ||
| # Path 2: train_target not set, templates provided manually | ||
| # via collator_kwargs. Infer train_target from which templates | ||
| # are present. | ||
| has_eot = config_collator_kwargs.get("end_of_turn_template") is not None | ||
| has_inst = config_collator_kwargs.get("instruction_template") is not None | ||
| if has_eot: | ||
| collator_kwargs["train_target"] = "all_assistant_turns" | ||
| elif has_inst: | ||
| warnings.warn( | ||
| "Instruction-based masking is deprecated. " | ||
| "Use train_target='all_assistant_turns' with " | ||
| "end_of_turn_template for multi-turn conversations, " | ||
| "or train_target='final_assistant_turn' " | ||
| "for single-turn completions.", | ||
| DeprecationWarning, | ||
| stacklevel=2, | ||
| ) | ||
| collator_kwargs["train_target"] = "_legacy_instruction_response" | ||
| else: | ||
| collator_kwargs["train_target"] = "final_assistant_turn" | ||
| else: | ||
| raise ValueError( | ||
| "'text_completions_only_with_padding' requires either " | ||
| "train_target or response_template in collator_kwargs." | ||
| ) | ||
|
|
||
| # User-provided collator_kwargs override auto-resolved values | ||
| collator_kwargs.update(config_collator_kwargs) | ||
|
|
||
| return build_data_collator( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Has this been tested with gpt-oss models? They use a non-standard system for marking boundaries within responses, so wondering if boundary detection will work as it does for Qwen/Llama style templates.