diff --git a/configs/projects/coalm/405b_train.yaml b/configs/projects/coalm/405b_train.yaml index d58fcb1726..0071aa4fab 100644 --- a/configs/projects/coalm/405b_train.yaml +++ b/configs/projects/coalm/405b_train.yaml @@ -18,12 +18,14 @@ data: shuffle: True seed: 42 collator_name: "text_completions_only_with_padding" + train_target: "ALL_ASSISTANT_TURNS" seed: 42 validation: datasets: - dataset_name: "text_sft_jsonl" dataset_path: "/path/to/validation/dataset.jsonl" collator_name: "text_completions_only_with_padding" + train_target: "ALL_ASSISTANT_TURNS" seed: 42 training: diff --git a/configs/projects/coalm/70b_train.yaml b/configs/projects/coalm/70b_train.yaml index e61c0d6db1..9332039d51 100644 --- a/configs/projects/coalm/70b_train.yaml +++ b/configs/projects/coalm/70b_train.yaml @@ -18,12 +18,14 @@ data: shuffle: True seed: 42 collator_name: "text_completions_only_with_padding" + train_target: "ALL_ASSISTANT_TURNS" seed: 42 validation: datasets: - dataset_name: "text_sft_jsonl" dataset_path: "/path/to/validation/dataset.jsonl" collator_name: "text_completions_only_with_padding" + train_target: "ALL_ASSISTANT_TURNS" seed: 42 training: diff --git a/configs/projects/coalm/8b_train.yaml b/configs/projects/coalm/8b_train.yaml index 85bf4b49b2..4e6ceff6c4 100644 --- a/configs/projects/coalm/8b_train.yaml +++ b/configs/projects/coalm/8b_train.yaml @@ -18,12 +18,14 @@ data: shuffle: True seed: 42 collator_name: "text_completions_only_with_padding" + train_target: "ALL_ASSISTANT_TURNS" seed: 42 validation: datasets: - dataset_name: "text_sft_jsonl" dataset_path: "/path/to/validation/dataset.jsonl" collator_name: "text_completions_only_with_padding" + train_target: "ALL_ASSISTANT_TURNS" seed: 42 training: diff --git a/configs/projects/halloumi/8b_train.yaml b/configs/projects/halloumi/8b_train.yaml index a922b404cb..fe80de3bea 100644 --- a/configs/projects/halloumi/8b_train.yaml +++ b/configs/projects/halloumi/8b_train.yaml @@ -66,6 +66,7 @@ data: seed: 42 
collator_name: "text_completions_only_with_padding" + train_target: "ALL_ASSISTANT_TURNS" seed: 42 validation: datasets: @@ -78,6 +79,7 @@ data: } collator_name: "text_completions_only_with_padding" + train_target: "ALL_ASSISTANT_TURNS" seed: 42 training: diff --git a/src/oumi/builders/collators.py b/src/oumi/builders/collators.py index 4656b29b97..376decda31 100644 --- a/src/oumi/builders/collators.py +++ b/src/oumi/builders/collators.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings from collections.abc import Callable import oumi.core.constants as constants @@ -27,11 +28,114 @@ from oumi.core.configs.internal.supported_models import ( find_internal_model_config, ) +from oumi.core.configs.params.data_params import TrainTarget from oumi.core.tokenizers.base_tokenizer import BaseTokenizer from oumi.utils.logging import logger -# This is used to set the max input length for a model with infinite size input _VERY_LARGE_INTEGER = int(1e30) +_SENTINEL_USER = "<<__U__>>" +_SENTINEL_ASST = "<<__A__>>" +_FIX_HINT = ( + "Fix: provide response_template (and end_of_turn_template for " + "all_assistant_turns) in collator_kwargs." +) + + +def _resolve_collator_templates( + tokenizer: "BaseTokenizer", +) -> tuple[str, str]: + """Auto-detect response_template and end_of_turn_template. + + Applies the chat template to a known test conversation, then finds + the assistant boundary strings in the rendered output. + + Returns: + (response_template, end_of_turn_template) + + Raises: + ValueError: If templates cannot be extracted. 
+ """ + msgs = [ + {"role": "user", "content": _SENTINEL_USER}, + {"role": "assistant", "content": _SENTINEL_ASST}, + {"role": "user", "content": _SENTINEL_USER}, + {"role": "assistant", "content": _SENTINEL_ASST}, + ] + + try: + rendered = tokenizer.apply_chat_template( + msgs, tokenize=False, add_generation_prompt=False + ) + except Exception as exc: + raise ValueError( + f"Tokenizer has no chat template or it failed to render.\n{_FIX_HINT}" + ) from exc + + if not isinstance(rendered, str): + raise ValueError( + f"Chat template returned a non-string type ({type(rendered).__name__}).\n" + f"{_FIX_HINT}" + ) + + # Locate boundaries around the second turn pair + # to avoid system-prompt effects on the first turn. + try: + a1 = rendered.index(_SENTINEL_ASST) + first_asst_end = a1 + len(_SENTINEL_ASST) + second_user = rendered.index(_SENTINEL_USER, first_asst_end) + second_user_end = second_user + len(_SENTINEL_USER) + second_asst = rendered.index(_SENTINEL_ASST, second_user_end) + second_asst_end = second_asst + len(_SENTINEL_ASST) + except ValueError: + raise ValueError( + "Could not locate assistant turn boundaries in the rendered " + f"chat template.\n{_FIX_HINT}" + ) + + # End-of-turn: common token-ID prefix of the two strings that + # follow assistant content (mid-conversation vs. end-of-sequence). + after_ids = tokenizer.encode(rendered[second_asst_end:], add_special_tokens=False) + between_ids = tokenizer.encode( + rendered[first_asst_end:second_user], add_special_tokens=False + ) + eot_len = 0 + for a, b in zip(after_ids, between_ids): + if a != b: + break + eot_len += 1 + eot_ids = after_ids[:eot_len] + _eot_decoded = tokenizer.decode(eot_ids, skip_special_tokens=False) + assert isinstance(_eot_decoded, str) + end_of_turn_template = _eot_decoded + + # Response template: strip the EOT prefix to get just the assistant header. 
+    resp_ids = tokenizer.encode(
+        rendered[second_user_end:second_asst], add_special_tokens=False
+    )
+    if eot_len > 0 and resp_ids[:eot_len] == eot_ids:
+        resp_ids = resp_ids[eot_len:]
+    _resp_decoded = tokenizer.decode(resp_ids, skip_special_tokens=False)
+    assert isinstance(_resp_decoded, str)
+    response_template = _resp_decoded
+
+    if not response_template.strip():
+        raise ValueError(f"Extracted response_template is empty.\n{_FIX_HINT}")
+    if not end_of_turn_template.strip():
+        raise ValueError(f"Extracted end_of_turn_template is empty.\n{_FIX_HINT}")
+
+    # Qwen3 and similar reasoning models inject <think>...</think> into
+    # every assistant turn via their chat template. If training data was
+    # formatted without thinking tokens the response_template won't match
+    # and every example will be silently masked.
+    if "<think>" in response_template:
+        logger.warning(
+            "The extracted response_template contains <think> tokens "
+            "(from the model's chat template). If you're training without "
+            "thinking tokens, use collator_kwargs to specify "
+            "response_template manually."
+        )
+
+    return response_template, end_of_turn_template
 
 
 def build_data_collator(
@@ -51,7 +155,8 @@ def build_data_collator(
 
     - "text_with_padding": Uses `TextCollatorWithPadding`.
     - "text_completions_only_with_padding": Uses
-      `TextCompletionsCollatorWithPadding`.
+      `TextCompletionsCollatorWithPadding`. Supports optional
+      ``end_of_turn_template`` for tool-aware span-based masking.
     - "vision_language_with_padding": Uses
       `VisionLanguageCollatorWithPadding`.
     - "vision_language_sft": Uses `VisionLanguageSftCollator`.
@@ -126,27 +231,19 @@ def build_data_collator( **kwargs, ) elif collator_name == "text_completions_only_with_padding": - # Extract instruction and response templates from kwargs if provided - instruction_template = kwargs.pop("instruction_template", None) - response_template = kwargs.pop("response_template", None) - - # Default to Llama-style templates if not provided - instruction_prefix = ( - instruction_template - if instruction_template - else "<|start_header_id|>user<|end_header_id|>\n\n" - ) - response_prefix = ( - response_template - if response_template - else "<|start_header_id|>assistant<|end_header_id|>\n\n" - ) + if not kwargs.get("response_template"): + raise ValueError( + "'text_completions_only_with_padding' requires a response_template.\n" + "Fix: set train_target in your data config (auto-resolves templates " + "from the tokenizer), or provide response_template in collator_kwargs." + ) return TextCompletionsCollatorWithPadding( tokenizer=tokenizer, - instruction_prefix=instruction_prefix, - response_prefix=response_prefix, debug=debug, + ignore_index=( + label_ignore_index if label_ignore_index is not None else -100 + ), **kwargs, ) raise ValueError(f"Unknown data collator name: '{collator_name}'") @@ -206,9 +303,66 @@ def build_collator_from_config( "trust_remote_code", config.model.trust_remote_code ) - # Merge collator_kwargs from config with the existing kwargs - # Config kwargs take precedence over automatically determined kwargs + # --- Resolve train_target and templates --- config_collator_kwargs = train_split.collator_kwargs or {} + + if collator_name == "text_completions_only_with_padding": + if train_split.train_target is not None: + # Path 1: train_target is set, auto-detect templates from + # the tokenizer's chat template. Falls back to user-provided + # response_template in collator_kwargs if auto-detection fails. 
+            collator_kwargs["train_target"] = train_split.train_target.value
+
+            try:
+                response_template, end_of_turn_template = _resolve_collator_templates(
+                    tokenizer
+                )
+                collator_kwargs["response_template"] = response_template
+                if train_split.train_target == TrainTarget.ALL_ASSISTANT_TURNS:
+                    collator_kwargs["end_of_turn_template"] = end_of_turn_template
+            except ValueError:
+                if config_collator_kwargs.get("response_template") is None:
+                    raise
+
+            if (
+                train_split.train_target == TrainTarget.ALL_ASSISTANT_TURNS
+                and "end_of_turn_template" not in collator_kwargs
+                and config_collator_kwargs.get("end_of_turn_template") is None
+            ):
+                raise ValueError(
+                    "train_target='all_assistant_turns' requires end_of_turn_template, "
+                    "but auto-detection failed.\n"
+                    "Fix: provide end_of_turn_template in collator_kwargs."
+                )
+
+        elif config_collator_kwargs.get("response_template") is not None:
+            # Path 2: train_target not set, templates provided manually
+            # via collator_kwargs. Infer train_target from which templates
+            # are present.
+            has_eot = config_collator_kwargs.get("end_of_turn_template") is not None
+            has_inst = config_collator_kwargs.get("instruction_template") is not None
+            if has_eot:
+                collator_kwargs["train_target"] = "all_assistant_turns"
+            elif has_inst:
+                warnings.warn(
+                    "Instruction-based masking is deprecated.\n"
+                    "Use train_target='all_assistant_turns' "
+                    "or train_target='final_assistant_turn' instead.",
+                    DeprecationWarning,
+                    stacklevel=2,
+                )
+                collator_kwargs["train_target"] = "_legacy_instruction_response"
+            else:
+                collator_kwargs["train_target"] = "final_assistant_turn"
+        else:
+            raise ValueError(
+                "'text_completions_only_with_padding' collator requires"
+                " configuration.\n"
+                "Fix: set train_target in your data config, "
+                "or provide response_template in collator_kwargs."
+ ) + + # User-provided collator_kwargs override auto-resolved values collator_kwargs.update(config_collator_kwargs) return build_data_collator( diff --git a/src/oumi/core/collators/text_completions_collator_with_padding.py b/src/oumi/core/collators/text_completions_collator_with_padding.py index e567e47be9..50c814f02f 100644 --- a/src/oumi/core/collators/text_completions_collator_with_padding.py +++ b/src/oumi/core/collators/text_completions_collator_with_padding.py @@ -27,22 +27,34 @@ class TextCompletionsCollatorWithPadding: def __init__( self, tokenizer: BaseTokenizer, - instruction_prefix: str, - response_prefix: str, + response_template: str, + train_target: str, + instruction_template: str | None = None, debug: bool = False, + end_of_turn_template: str | None = None, + ignore_index: int = -100, ): """Custom collator for text LLM training. Args: tokenizer: The tokenizer used for encoding the data. - instruction_prefix: The prefix marking the beginning of the user instruction. - response_prefix: The prefix marking the beginning of the assistant response. + response_template: String marking assistant response start. + instruction_template: String marking user instruction start. debug: If True, enables debug mode for logging. + train_target: Training target — ``"all_assistant_turns"`` + or ``"final_assistant_turn"``. + end_of_turn_template: String marking the end of a turn. + Required for ``all_assistant_turns``. + ignore_index: Value used for masked labels. Must match the ignore_index + of the loss function (default: -100). 
""" self._default_collator = DataCollatorForCompletionOnlyLM( tokenizer=tokenizer, - instruction_template=instruction_prefix, - response_template=response_prefix, + instruction_template=instruction_template, + response_template=response_template, + train_target=train_target, + end_of_turn_template=end_of_turn_template, + ignore_index=ignore_index, ) if not hasattr(tokenizer, "pad_token_id") or tokenizer.pad_token_id is None: @@ -55,7 +67,7 @@ def _collate(self, inputs: list[Any]) -> dict[str, Any]: result = self._default_collator(inputs) return result - def __call__(self, batch) -> dict[str, Any]: + def __call__(self, batch: list[dict[str, Any]]) -> dict[str, Any]: """Pads to the longest length present in the batch. Args: diff --git a/src/oumi/core/collators/trl_data_collator_for_completion_only_lm.py b/src/oumi/core/collators/trl_data_collator_for_completion_only_lm.py index 2e47212dbb..26778ebe61 100644 --- a/src/oumi/core/collators/trl_data_collator_for_completion_only_lm.py +++ b/src/oumi/core/collators/trl_data_collator_for_completion_only_lm.py @@ -19,18 +19,64 @@ import torch from transformers.data.data_collator import DataCollatorForLanguageModeling +from oumi.core.configs.params.data_params import TrainTarget + class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling): - """Data collator used for completion tasks. + """Data collator for completion-only training. + + Masks input labels so that the loss is only computed on specific + tokens (typically assistant responses), while ignoring other tokens + (system prompts, user messages, padding). + + The ``train_target`` parameter selects the training target: + + **``all_assistant_turns``**: + Span-based masking for multi-turn and tool-calling conversations. + Masks everything, then unmarks each assistant response span bounded + by ``response_template`` .. ``end_of_turn_template`` (inclusive of EOT). + Correctly handles interleaved tool results and parallel tool calls. 
- Copied from `trl`'s `DataCollatorForCompletionOnlyLM` class. + **``final_assistant_turn``**: + Masks all tokens before the *last* ``response_template`` occurrence. + Only the final assistant response is trained on. Suitable for + single-turn completions. + + Args: + response_template: String or token IDs marking the start of an + assistant response. Required for all modes. + instruction_template: String or token IDs marking the start of a + user instruction. Legacy — only used with the instruction+response + fallback path. + train_target: One of ``"all_assistant_turns"``, + ``"final_assistant_turn"``, ``"_legacy_instruction_response"``. + Resolved by the builder before construction. + end_of_turn_template: String or token IDs marking the end of a + conversational turn. Required for ``all_assistant_turns`` mode. + mlm: Whether to use masked language modeling. Default False. + ignore_index: Label value for masked tokens. Default -100. + padding_free: Remove padding and add position_ids. Default False. """ + _VALID_TRAIN_TARGETS = {t.value for t in TrainTarget} | { + "_legacy_instruction_response", + } + + def _tokenize_template(self, template: str | list[int] | None) -> list[int] | None: + """Encode a template string into token IDs, or pass through if already IDs.""" + if template is None: + return None + if isinstance(template, str): + return self.tokenizer.encode(template, add_special_tokens=False) + return list(template) + def __init__( self, response_template: str | list[int], instruction_template: str | list[int] | None = None, *args, + train_target: str, + end_of_turn_template: str | list[int] | None = None, mlm: bool = False, ignore_index: int = -100, padding_free: bool = False, @@ -39,26 +85,33 @@ def __init__( """Initializes the DataCollatorForCompletionOnlyLM.""" super().__init__(*args, mlm=mlm, **kwargs) + # Tokenize templates. 
self.instruction_template = instruction_template - if isinstance(instruction_template, str): - # The user provides a string, must tokenize - self.instruction_token_ids = self.tokenizer.encode( - self.instruction_template, # type: ignore - add_special_tokens=False, - ) - else: - # The user already provides the token ids - self.instruction_token_ids = instruction_template - + self.instruction_token_ids = self._tokenize_template(instruction_template) self.response_template = response_template - if isinstance(response_template, str): - # The user provides a string, must tokenize - self.response_token_ids = self.tokenizer.encode( - self.response_template, add_special_tokens=False + self.response_token_ids: list[int] = self._tokenize_template(response_template) # type: ignore[assignment] + self.end_of_turn_template = end_of_turn_template + self.end_of_turn_token_ids = self._tokenize_template(end_of_turn_template) + + if train_target not in self._VALID_TRAIN_TARGETS: + valid = sorted(self._VALID_TRAIN_TARGETS - {"_legacy_instruction_response"}) + raise ValueError( + f"Unknown train_target='{train_target}'. 
Must be one of: {valid}" ) - else: - # The user already provides the token ids - self.response_token_ids = response_template + self.train_target = train_target + + if self.train_target == "all_assistant_turns": + if end_of_turn_template is None: + raise ValueError( + "end_of_turn_template must be provided " + f"when train_target='{self.train_target}'" + ) + if self.train_target == "_legacy_instruction_response": + if instruction_template is None: + raise ValueError( + "instruction_template must be provided " + f"when train_target='{self.train_target}'" + ) if ( not self.mlm @@ -78,13 +131,102 @@ def __init__( self.ignore_index = ignore_index self.padding_free = padding_free + @staticmethod + def _find_pattern(seq: list[int], pattern: list[int]) -> list[int]: + """Return all start positions where *pattern* appears in *seq*.""" + plen = len(pattern) + if plen == 0: + return [] + first = pattern[0] + positions = [] + for i in range(len(seq) - plen + 1): + if seq[i] == first and seq[i : i + plen] == pattern: + positions.append(i) + return positions + + def _apply_span_masking( + self, batch: dict[str, Any], examples: list[list[int] | Any | dict[str, Any]] + ) -> None: + """Apply span-based masking for multi-turn conversations. + + Masks all labels, then unmarks assistant response spans bounded by + response_template and end_of_turn_template (inclusive — the EOT token + is unmasked so the model learns to produce it). + """ + resp_ids = self.response_token_ids + eot_ids = self.end_of_turn_token_ids + assert eot_ids is not None # Caller checks end_of_turn_template is not None + resp_len = len(resp_ids) + pad_token_id = self.tokenizer.pad_token_id + + for i in range(len(examples)): + # Step 1: mask everything. + batch["labels"][i, :] = self.ignore_index + + seq: list[int] = batch["input_ids"][i].tolist() + + # Compute effective sequence length excluding trailing padding. + # Prevents false matches when end_of_turn_token_ids overlaps + # with the pad token (common: e.g. 
<|im_end|> = eos = pad). + if pad_token_id is not None: + n = len(seq) + while n > 0 and seq[n - 1] == pad_token_id: + n -= 1 + else: + n = len(seq) + + # Step 2: find every assistant response start position. + resp_positions = self._find_pattern(seq[:n], resp_ids) + + if len(resp_positions) == 0: + warnings.warn( + f"Could not find response template in the following instance: " + f"{self.tokenizer.decode(batch['input_ids'][i])}. " + "This instance will be ignored in loss calculation.", + UserWarning, + ) + continue + + for resp_pos in resp_positions: + content_start = resp_pos + resp_len + + # Step 3: find the next end_of_turn after content_start. + eot_positions = self._find_pattern(seq[content_start:n], eot_ids) + if eot_positions: + content_end = content_start + eot_positions[0] + else: + content_end = n + + if content_start >= content_end: + continue + + # Step 4: unmask this assistant response span, including the + # end-of-turn token so the model learns when to stop. + if eot_positions: + eot_len = len(self.end_of_turn_token_ids) # type: ignore + unmask_end = content_end + eot_len + else: + # No EOT found — content_end == n (end of real content). + # Do NOT extend past n or we'd unmask into padding. + unmask_end = content_end + batch["labels"][i, content_start:unmask_end] = batch["input_ids"][ + i, content_start:unmask_end + ] + + # ------------------------------------------------------------------ + # Main collation + # ------------------------------------------------------------------ + def torch_call( self, examples: list[list[int] | Any | dict[str, Any]] ) -> dict[str, Any]: """Collates a list of examples into a batch.""" batch = super().torch_call(examples) - if self.instruction_template is None: + if self.train_target == "all_assistant_turns": + self._apply_span_masking(batch, examples) + elif self.train_target == "final_assistant_turn": + # Response-only: unmask only the final assistant response. 
for i in range(len(examples)): response_token_ids_start_idx = None diff --git a/src/oumi/core/configs/__init__.py b/src/oumi/core/configs/__init__.py index 8ed6769f5d..f56cc9ffe7 100644 --- a/src/oumi/core/configs/__init__.py +++ b/src/oumi/core/configs/__init__.py @@ -94,6 +94,7 @@ DatasetSplit, DatasetSplitParams, MixtureStrategy, + TrainTarget, ) from oumi.core.configs.params.evaluation_params import ( EvaluationBackend, @@ -190,6 +191,7 @@ "LMHarnessTaskParams", "LoraWeightInitialization", "MixedPrecisionDtype", + "TrainTarget", "MixtureStrategy", "ModelParams", "PeftParams", diff --git a/src/oumi/core/configs/params/data_params.py b/src/oumi/core/configs/params/data_params.py index 30f11a39c2..ac74a08811 100644 --- a/src/oumi/core/configs/params/data_params.py +++ b/src/oumi/core/configs/params/data_params.py @@ -52,6 +52,28 @@ def get_literal_value(self) -> Literal["first_exhausted", "all_exhausted"]: raise ValueError("Unsupported value for MixtureStrategy") +class TrainTarget(str, Enum): + """Controls which tokens contribute to the loss during training. + + Used with the ``text_completions_only_with_padding`` collator to + select the training target. Template tokens are auto-resolved + from the tokenizer vocabulary. + + Members: + ALL_ASSISTANT_TURNS: Train on all assistant response turns including + tool calls. Uses span-based masking: system prompts, user + messages, and tool results are masked; everything between the + assistant header and the end-of-turn token (inclusive) is + unmasked. + FINAL_ASSISTANT_TURN: Train only on the final assistant response. + Masks all tokens before the last ``response_template`` + occurrence. Suitable for single-turn completions. 
+ """ + + ALL_ASSISTANT_TURNS = "all_assistant_turns" + FINAL_ASSISTANT_TURN = "final_assistant_turn" + + @dataclass class DatasetParams(BaseParams): dataset_name: str = MISSING @@ -197,8 +219,13 @@ class DatasetSplitParams(BaseParams): - "text_with_padding": Dynamically pads the inputs received to the longest length. + - "text_completions_only_with_padding": Uses template matching to + mask non-assistant tokens. Works for simple user/assistant turns. + Supports optional ``end_of_turn_template`` in ``collator_kwargs`` + for span-based masking. - "vision_language_with_padding": Uses VisionLanguageCollator for image+text multi-modal data. + - "vision_language_sft": Uses VisionLanguageSftCollator. If None, then a default collator will be assigned. """ @@ -210,6 +237,16 @@ class DatasetSplitParams(BaseParams): and can be used to customize collator behavior beyond the default parameters. """ + train_target: TrainTarget | None = None + """High-level training target for ``text_completions_only_with_padding``. + + When set, the builder auto-detects ``response_template`` and + ``end_of_turn_template`` from the tokenizer's chat template. + Use ``collator_kwargs`` to override individual auto-resolved values. + + See :class:`TrainTarget` for available options. + """ + pack: bool = False """Whether to pack the text into constant-length chunks. @@ -266,6 +303,20 @@ class DatasetSplitParams(BaseParams): def __post_init__(self): """Verifies params.""" + # Convert string train_target to enum if needed + if isinstance(self.train_target, str): + self.train_target = TrainTarget(self.train_target) + + if ( + self.train_target is not None + and self.collator_name != "text_completions_only_with_padding" + ): + raise ValueError( + "`train_target` requires " + "collator_name='text_completions_only_with_padding', " + f"got '{self.collator_name}'." 
+ ) + if any([dataset.mixture_proportion is not None for dataset in self.datasets]): if not all( [dataset.mixture_proportion is not None for dataset in self.datasets] diff --git a/tests/unit/builders/test_collators.py b/tests/unit/builders/test_collators.py index ca2ddb86de..62db4b1956 100644 --- a/tests/unit/builders/test_collators.py +++ b/tests/unit/builders/test_collators.py @@ -12,6 +12,7 @@ ModelParams, TrainingConfig, TrainingParams, + TrainTarget, ) @@ -249,3 +250,339 @@ def test_build_collator_from_config_collator_kwargs_override(mock_tokenizer): assert callable(collator) # Verify that the config kwargs override the model-determined kwargs assert collator._allow_multi_image_inputs is False + + +# --------------------------------------------------------------------------- +# TrainTarget / builder auto-detection tests +# --------------------------------------------------------------------------- + + +def _chatml_tokenizer(): + """Mock tokenizer that renders ChatML format.""" + tok = MagicMock() + tok.pad_token_id = 0 + tok.model_max_length = 2048 + + def _apply(messages, **kw): + out = "".join( + f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages + ) + if kw.get("add_generation_prompt"): + out += "<|im_start|>assistant\n" + return out + + tok.apply_chat_template = MagicMock(side_effect=_apply) + + # The production code encodes/decodes three substrings from the + # rendered template. Map each to stable token IDs so the + # common-prefix logic works. 
+ _encode_map = { + "<|im_end|>\n": [101, 10], + "<|im_end|>\n<|im_start|>user\n": [101, 10, 100, 20], + "<|im_end|>\n<|im_start|>assistant\n": [101, 10, 100, 30], + "<|im_start|>assistant\n": [100, 30], + } + _decode_map = { + (101, 10): "<|im_end|>\n", + (100, 30): "<|im_start|>assistant\n", + } + tok.encode = MagicMock(side_effect=lambda text, **kw: _encode_map[text]) + tok.decode = MagicMock(side_effect=lambda ids, **kw: _decode_map[tuple(ids)]) + return tok + + +def _llama3_tokenizer(): + """Mock tokenizer that renders Llama-3 format.""" + tok = MagicMock() + tok.pad_token_id = 0 + tok.model_max_length = 2048 + + def _apply(messages, **kw): + parts = ["<|begin_of_text|>"] + for m in messages: + parts.append( + f"<|start_header_id|>{m['role']}<|end_header_id|>\n\n" + f"{m['content']}<|eot_id|>" + ) + if kw.get("add_generation_prompt"): + parts.append("<|start_header_id|>assistant<|end_header_id|>\n\n") + return "".join(parts) + + tok.apply_chat_template = MagicMock(side_effect=_apply) + + _encode_map = { + "<|eot_id|>": [203], + "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n": [ + 203, + 201, + 20, + 202, + 10, + ], + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n": [ + 203, + 201, + 30, + 202, + 10, + ], + "<|start_header_id|>assistant<|end_header_id|>\n\n": [201, 30, 202, 10], + } + _decode_map = { + (203,): "<|eot_id|>", + (201, 30, 202, 10): "<|start_header_id|>assistant<|end_header_id|>\n\n", + } + tok.encode = MagicMock(side_effect=lambda text, **kw: _encode_map[text]) + tok.decode = MagicMock(side_effect=lambda ids, **kw: _decode_map[tuple(ids)]) + return tok + + +def _unknown_tokenizer(): + """Mock tokenizer with no chat template.""" + tok = MagicMock() + tok.pad_token_id = 0 + tok.model_max_length = 2048 + tok.apply_chat_template = MagicMock( + side_effect=Exception("No chat template configured") + ) + return tok + + +def test_build_data_collator_text_completions_with_tool_kwargs(mock_tokenizer): + """Build completions collator 
with end_of_turn_template + custom ignore index.""" + collator = build_data_collator( + "text_completions_only_with_padding", + mock_tokenizer, + max_length=512, + label_ignore_index=-200, + response_template="<|assistant|>", + end_of_turn_template="<|end|>", + train_target="all_assistant_turns", + ) + assert collator is not None + assert callable(collator) + inner = collator._default_collator + assert inner.ignore_index == -200 + + +def test_train_target_all_assistant_turns(): + """ChatML auto-detection with ALL_ASSISTANT_TURNS train target.""" + tok = _chatml_tokenizer() + config = TrainingConfig( + data=DataParams( + train=DatasetSplitParams( + collator_name="text_completions_only_with_padding", + train_target=TrainTarget.ALL_ASSISTANT_TURNS, + datasets=[DatasetParams(dataset_name="dummy", split="train")], + ) + ), + model=ModelParams( + model_name="MlpEncoder", + tokenizer_name="openai-community/gpt2", + model_max_length=512, + ), + ) + collator = build_collator_from_config(config, tokenizer=tok) + assert collator is not None + inner = collator._default_collator + assert inner.response_template == "<|im_start|>assistant\n" + + +def test_train_target_final_assistant_turn(): + """ChatML auto-detection with FINAL_ASSISTANT_TURN train target.""" + tok = _chatml_tokenizer() + config = TrainingConfig( + data=DataParams( + train=DatasetSplitParams( + collator_name="text_completions_only_with_padding", + train_target=TrainTarget.FINAL_ASSISTANT_TURN, + datasets=[DatasetParams(dataset_name="dummy", split="train")], + ) + ), + model=ModelParams( + model_name="MlpEncoder", + tokenizer_name="openai-community/gpt2", + model_max_length=512, + ), + ) + collator = build_collator_from_config(config, tokenizer=tok) + assert collator is not None + inner = collator._default_collator + assert inner.response_template == "<|im_start|>assistant\n" + + +def test_train_target_llama3(): + """Llama-3 auto-detection with ALL_ASSISTANT_TURNS train target.""" + tok = _llama3_tokenizer() + 
config = TrainingConfig( + data=DataParams( + train=DatasetSplitParams( + collator_name="text_completions_only_with_padding", + train_target=TrainTarget.ALL_ASSISTANT_TURNS, + datasets=[DatasetParams(dataset_name="dummy", split="train")], + ) + ), + model=ModelParams( + model_name="MlpEncoder", + tokenizer_name="openai-community/gpt2", + model_max_length=512, + ), + ) + collator = build_collator_from_config(config, tokenizer=tok) + assert collator is not None + inner = collator._default_collator + assert ( + inner.response_template == "<|start_header_id|>assistant<|end_header_id|>\n\n" + ) + + +def test_train_target_unknown_tokenizer(): + """Error when tokenizer vocab does not match any known chat format.""" + tok = _unknown_tokenizer() + config = TrainingConfig( + data=DataParams( + train=DatasetSplitParams( + collator_name="text_completions_only_with_padding", + train_target=TrainTarget.ALL_ASSISTANT_TURNS, + datasets=[DatasetParams(dataset_name="dummy", split="train")], + ) + ), + model=ModelParams( + model_name="MlpEncoder", + tokenizer_name="openai-community/gpt2", + model_max_length=512, + ), + ) + with pytest.raises(ValueError, match="no chat template"): + build_collator_from_config(config, tokenizer=tok) + + +def test_train_target_with_collator_kwargs_override(): + """collator_kwargs overrides auto-resolved templates when train_target is set.""" + tok = _chatml_tokenizer() + config = TrainingConfig( + data=DataParams( + train=DatasetSplitParams( + collator_name="text_completions_only_with_padding", + train_target=TrainTarget.ALL_ASSISTANT_TURNS, + collator_kwargs={"response_template": "<|im_end|>\n"}, + datasets=[DatasetParams(dataset_name="dummy", split="train")], + ) + ), + model=ModelParams( + model_name="MlpEncoder", + tokenizer_name="openai-community/gpt2", + model_max_length=512, + ), + ) + collator = build_collator_from_config(config, tokenizer=tok) + assert collator is not None + inner = collator._default_collator + # Auto-resolved would be 
"<|im_start|>assistant\n"; user override wins + assert inner.response_template == "<|im_end|>\n" + + +def test_train_target_on_wrong_collator(): + """train_target is only valid for text_completions_only_with_padding.""" + with pytest.raises(ValueError, match="train_target.*requires"): + DatasetSplitParams( + collator_name="text_with_padding", + train_target=TrainTarget.ALL_ASSISTANT_TURNS, + datasets=[DatasetParams(dataset_name="dummy", split="train")], + ) + + +def test_legacy_instruction_template_backward_compat(mock_tokenizer): + """Legacy path: instruction_template + response_template → _legacy + warning.""" + config = TrainingConfig( + data=DataParams( + train=DatasetSplitParams( + collator_name="text_completions_only_with_padding", + collator_kwargs={ + "response_template": "<|assistant|>", + "instruction_template": "<|user|>", + }, + datasets=[DatasetParams(dataset_name="dummy", split="train")], + ) + ), + model=ModelParams( + model_name="MlpEncoder", + tokenizer_name="openai-community/gpt2", + model_max_length=512, + ), + ) + with pytest.warns( + DeprecationWarning, match="Instruction-based masking is deprecated" + ): + collator = build_collator_from_config(config, tokenizer=mock_tokenizer) + assert collator is not None + inner = collator._default_collator + assert inner.response_template == "<|assistant|>" + assert inner.instruction_template == "<|user|>" + assert inner.train_target == "_legacy_instruction_response" + + +def test_bare_collator_name_raises_without_templates(mock_tokenizer): + """Bare collator_name without kwargs or train_target raises an error.""" + config = TrainingConfig( + data=DataParams( + train=DatasetSplitParams( + collator_name="text_completions_only_with_padding", + datasets=[DatasetParams(dataset_name="dummy", split="train")], + ) + ), + model=ModelParams( + model_name="MlpEncoder", + tokenizer_name="openai-community/gpt2", + model_max_length=512, + ), + ) + with pytest.raises(ValueError, match="response_template"): + 
build_collator_from_config(config, tokenizer=mock_tokenizer) + + +def test_old_recipe_response_only_sets_final(mock_tokenizer): + """Old recipe: response_template only → final_assistant_turn.""" + config = TrainingConfig( + data=DataParams( + train=DatasetSplitParams( + collator_name="text_completions_only_with_padding", + collator_kwargs={ + "response_template": "<|assistant|>", + }, + datasets=[DatasetParams(dataset_name="dummy", split="train")], + ) + ), + model=ModelParams( + model_name="MlpEncoder", + tokenizer_name="openai-community/gpt2", + model_max_length=512, + ), + ) + collator = build_collator_from_config(config, tokenizer=mock_tokenizer) + assert collator is not None + assert collator._default_collator.train_target == "final_assistant_turn" + + +def test_old_recipe_eot_sets_all_assistant(mock_tokenizer): + """Old recipe: response_template + end_of_turn_template → all_assistant_turns.""" + config = TrainingConfig( + data=DataParams( + train=DatasetSplitParams( + collator_name="text_completions_only_with_padding", + collator_kwargs={ + "response_template": "<|assistant|>", + "end_of_turn_template": "<|end|>", + }, + datasets=[DatasetParams(dataset_name="dummy", split="train")], + ) + ), + model=ModelParams( + model_name="MlpEncoder", + tokenizer_name="openai-community/gpt2", + model_max_length=512, + ), + ) + collator = build_collator_from_config(config, tokenizer=mock_tokenizer) + assert collator is not None + assert collator._default_collator.train_target == "all_assistant_turns" diff --git a/tests/unit/core/collators/test_text_completions_collator_with_padding.py b/tests/unit/core/collators/test_text_completions_collator_with_padding.py index f98b618ec2..258784d507 100644 --- a/tests/unit/core/collators/test_text_completions_collator_with_padding.py +++ b/tests/unit/core/collators/test_text_completions_collator_with_padding.py @@ -1,10 +1,12 @@ import functools +import warnings from unittest.mock import MagicMock import numpy as np import pytest 
import torch +import oumi.core.constants as constants from oumi.builders import build_tokenizer from oumi.core.collators.text_completions_collator_with_padding import ( TextCompletionsCollatorWithPadding, @@ -13,6 +15,15 @@ from oumi.core.tokenizers.base_tokenizer import BaseTokenizer from oumi.utils import logging +IGNORE = constants.LABEL_IGNORE_INDEX + +# Template strings for span-masking tests — chosen to be unambiguous in GPT-2's vocab. +_RESP_STR = " ASSISTANT_RESPONSE_START" +_EOT_STR = " TURN_ENDS_HERE" + +# Arbitrary token IDs used as "content" that must not appear in any template. +_SENTINELS = [601, 602, 603, 604, 605, 606, 607, 608] + @pytest.fixture def mock_tokenizer(): @@ -50,8 +61,9 @@ def test_success_basic(): collator = TextCompletionsCollatorWithPadding( tokenizer=tokenizer, - instruction_prefix=instruction_prefix, - response_prefix=response_prefix, + instruction_template=instruction_prefix, + response_template=response_prefix, + train_target="_legacy_instruction_response", ) assert callable(collator) @@ -174,8 +186,9 @@ def test_debug_logging(caplog): collator = TextCompletionsCollatorWithPadding( tokenizer=tokenizer, - instruction_prefix=instruction_prefix, - response_prefix=response_prefix, + instruction_template=instruction_prefix, + response_template=response_prefix, + train_target="_legacy_instruction_response", debug=True, ) assert callable(collator) @@ -238,3 +251,228 @@ def test_debug_logging(caplog): assert "'input_ids':" in log_text assert "'attention_mask':" in log_text assert "'labels':" in log_text + + +# =========================================================================== +# Span-based masking tests (tool-aware collation) +# =========================================================================== + + +@functools.cache +def get_template_token_ids() -> tuple[list[int], list[int]]: + """Return (resp_ids, eot_ids) encoded once and cached.""" + tokenizer, _ = create_test_tokenizer() + resp = tokenizer.encode(_RESP_STR, 
add_special_tokens=False) + eot = tokenizer.encode(_EOT_STR, add_special_tokens=False) + forbidden = set(resp) | set(eot) + for sentinel in _SENTINELS: + assert sentinel not in forbidden, ( + f"Sentinel {sentinel} collides with a template token ID. Adjust _SENTINELS." + ) + return resp, eot + + +def make_span_collator() -> TextCompletionsCollatorWithPadding: + tokenizer, _ = create_test_tokenizer() + return TextCompletionsCollatorWithPadding( + tokenizer=tokenizer, + response_template=_RESP_STR, + train_target="all_assistant_turns", + end_of_turn_template=_EOT_STR, + ) + + +def get_span_labels(collator, seq: list[int]) -> list[int]: + return collator([{"input_ids": seq}])["labels"][0].tolist() + + +def flat(*parts: list[int]) -> list[int]: + result = [] + for p in parts: + result.extend(p) + return result + + +# --------------------------------------------------------------------------- +# Single assistant turn +# --------------------------------------------------------------------------- + + +def test_span_single_turn_content_is_unmasked(): + resp, eot = get_template_token_ids() + prefix = [_SENTINELS[0], _SENTINELS[1]] + content = [_SENTINELS[2], _SENTINELS[3]] + seq = flat(prefix, resp, content, eot) + + labels = get_span_labels(make_span_collator(), seq) + + n_prefix = len(prefix) + len(resp) + assert all(v == IGNORE for v in labels[:n_prefix]) + assert labels[n_prefix : n_prefix + len(content)] == content + # EOT tokens are unmasked (model learns to produce the stop token) + assert labels[n_prefix + len(content) : n_prefix + len(content) + len(eot)] == eot + + +# --------------------------------------------------------------------------- +# Multiple assistant turns +# --------------------------------------------------------------------------- + + +def test_span_two_turns_both_unmasked(): + resp, eot = get_template_token_ids() + turn1 = [_SENTINELS[0], _SENTINELS[1]] + middle = [_SENTINELS[2]] + turn2 = [_SENTINELS[3], _SENTINELS[4]] + seq = flat(resp, turn1, 
eot, middle, resp, turn2, eot) + + labels = get_span_labels(make_span_collator(), seq) + + t1_start = len(resp) + t1_end = t1_start + len(turn1) + assert labels[t1_start:t1_end] == turn1 + + t2_start = t1_end + len(eot) + len(middle) + len(resp) + t2_end = t2_start + len(turn2) + assert labels[t2_start:t2_end] == turn2 + + +def test_span_content_between_turns_is_masked(): + resp, eot = get_template_token_ids() + turn1 = [_SENTINELS[0]] + between = [_SENTINELS[1], _SENTINELS[2]] + turn2 = [_SENTINELS[3]] + seq = flat(resp, turn1, eot, between, resp, turn2, eot) + + labels = get_span_labels(make_span_collator(), seq) + + between_start = len(resp) + len(turn1) + len(eot) + for i in range(len(between)): + assert labels[between_start + i] == IGNORE + + +def test_span_masking_requires_end_of_turn_template(): + tokenizer, _ = create_test_tokenizer() + with pytest.raises(ValueError, match="end_of_turn_template"): + TextCompletionsCollatorWithPadding( + tokenizer=tokenizer, + response_template=_RESP_STR, + train_target="all_assistant_turns", + end_of_turn_template=None, + ) + + +# --------------------------------------------------------------------------- +# Edge cases +# --------------------------------------------------------------------------- + + +def test_span_no_response_template_all_masked(): + seq = [_SENTINELS[0], _SENTINELS[1], _SENTINELS[2]] + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + labels = get_span_labels(make_span_collator(), seq) + + assert all(v == IGNORE for v in labels) + assert any("response template" in str(w.message).lower() for w in caught) + + +def test_span_no_eot_unmasked_to_end_of_sequence(): + resp, _ = get_template_token_ids() + content = [_SENTINELS[0], _SENTINELS[1]] + seq = flat(resp, content) + + labels = get_span_labels(make_span_collator(), seq) + + assert labels[len(resp) :] == content + + +def test_span_empty_content_span(): + resp, eot = get_template_token_ids() + seq = flat(resp, eot) + 
+ labels = get_span_labels(make_span_collator(), seq) + + assert all(v == IGNORE for v in labels) + + +def test_span_padding_matching_eot_does_not_false_match(): + """When pad_token_id matches the EOT token, padding must not be treated as + a real end-of-turn boundary.""" + tokenizer, pad_token_id = create_test_tokenizer() + resp, _ = get_template_token_ids() + + # Use pad_token_id itself as the EOT template — worst case scenario. + eot_ids = [pad_token_id] + content = [_SENTINELS[0], _SENTINELS[1]] + # Sequence: [RESP] content (no real EOT), then padding + seq = flat(resp, content) + [pad_token_id] * 5 + + collator = TextCompletionsCollatorWithPadding( + tokenizer=tokenizer, + response_template=_RESP_STR, + train_target="all_assistant_turns", + end_of_turn_template=str(tokenizer.decode(eot_ids)), + ) + batch = collator([{"input_ids": seq}]) + labels = batch["labels"][0].tolist() + + # Content should be unmasked — the padding should not act as an EOT. + content_start = len(resp) + assert labels[content_start : content_start + len(content)] == content + # Padding should be masked. 
+ assert all(v == IGNORE for v in labels[content_start + len(content) :]) + + +# --------------------------------------------------------------------------- +# Batch processing +# --------------------------------------------------------------------------- + + +def test_span_batch_two_examples_processed_independently(): + resp, eot = get_template_token_ids() + _, pad_token_id = create_test_tokenizer() + content_a = [_SENTINELS[0], _SENTINELS[1]] + content_b = [_SENTINELS[2]] + seq_a = flat(resp, content_a, eot) + seq_b = flat(resp, content_b, eot) + + max_len = max(len(seq_a), len(seq_b)) + pad_a = [pad_token_id] * (max_len - len(seq_a)) + pad_b = [pad_token_id] * (max_len - len(seq_b)) + + collator = make_span_collator() + batch = collator([{"input_ids": seq_a + pad_a}, {"input_ids": seq_b + pad_b}]) + labels_a = batch["labels"][0].tolist() + labels_b = batch["labels"][1].tolist() + + assert labels_a[len(resp) : len(resp) + len(content_a)] == content_a + assert labels_b[len(resp) : len(resp) + len(content_b)] == content_b + + +def test_span_batch_bad_example_does_not_affect_others(): + resp, eot = get_template_token_ids() + _, pad_token_id = create_test_tokenizer() + good_seq = flat(resp, [_SENTINELS[0]], eot) + bad_seq = [_SENTINELS[1], _SENTINELS[2]] + + max_len = max(len(good_seq), len(bad_seq)) + pad_good = [pad_token_id] * (max_len - len(good_seq)) + pad_bad = [pad_token_id] * (max_len - len(bad_seq)) + + collator = make_span_collator() + with warnings.catch_warnings(record=True): + warnings.simplefilter("always") + batch = collator( + [{"input_ids": good_seq + pad_good}, {"input_ids": bad_seq + pad_bad}] + ) + + assert batch["labels"][0].tolist()[len(resp)] == _SENTINELS[0] + assert all(v == IGNORE for v in batch["labels"][1].tolist()) + + +def test_span_labels_shape_matches_input_ids(): + resp, eot = get_template_token_ids() + seq = flat(resp, [_SENTINELS[0], _SENTINELS[1]], eot) + batch = make_span_collator()([{"input_ids": seq}]) + assert 
batch["labels"].shape == batch["input_ids"].shape diff --git a/tests/unit/core/datasets/test_base_sft_dataset.py b/tests/unit/core/datasets/test_base_sft_dataset.py index 5d51415fa0..74afa92bdf 100644 --- a/tests/unit/core/datasets/test_base_sft_dataset.py +++ b/tests/unit/core/datasets/test_base_sft_dataset.py @@ -26,8 +26,9 @@ def _get_hf_collator_result(conversation, tokenizer): collator = TextCompletionsCollatorWithPadding( tokenizer=tokenizer, - instruction_prefix=_INSTRUCTION_PREFIX, - response_prefix=_RESPONSE_PREFIX, + instruction_template=_INSTRUCTION_PREFIX, + response_template=_RESPONSE_PREFIX, + train_target="_legacy_instruction_response", ) return collator(batch)