diff --git a/configs/projects/coalm/405b_train.yaml b/configs/projects/coalm/405b_train.yaml
index d58fcb1726..0071aa4fab 100644
--- a/configs/projects/coalm/405b_train.yaml
+++ b/configs/projects/coalm/405b_train.yaml
@@ -18,12 +18,14 @@ data:
shuffle: True
seed: 42
collator_name: "text_completions_only_with_padding"
+ train_target: "ALL_ASSISTANT_TURNS"
seed: 42
validation:
datasets:
- dataset_name: "text_sft_jsonl"
dataset_path: "/path/to/validation/dataset.jsonl"
collator_name: "text_completions_only_with_padding"
+ train_target: "ALL_ASSISTANT_TURNS"
seed: 42
training:
diff --git a/configs/projects/coalm/70b_train.yaml b/configs/projects/coalm/70b_train.yaml
index e61c0d6db1..9332039d51 100644
--- a/configs/projects/coalm/70b_train.yaml
+++ b/configs/projects/coalm/70b_train.yaml
@@ -18,12 +18,14 @@ data:
shuffle: True
seed: 42
collator_name: "text_completions_only_with_padding"
+ train_target: "ALL_ASSISTANT_TURNS"
seed: 42
validation:
datasets:
- dataset_name: "text_sft_jsonl"
dataset_path: "/path/to/validation/dataset.jsonl"
collator_name: "text_completions_only_with_padding"
+ train_target: "ALL_ASSISTANT_TURNS"
seed: 42
training:
diff --git a/configs/projects/coalm/8b_train.yaml b/configs/projects/coalm/8b_train.yaml
index 85bf4b49b2..4e6ceff6c4 100644
--- a/configs/projects/coalm/8b_train.yaml
+++ b/configs/projects/coalm/8b_train.yaml
@@ -18,12 +18,14 @@ data:
shuffle: True
seed: 42
collator_name: "text_completions_only_with_padding"
+ train_target: "ALL_ASSISTANT_TURNS"
seed: 42
validation:
datasets:
- dataset_name: "text_sft_jsonl"
dataset_path: "/path/to/validation/dataset.jsonl"
collator_name: "text_completions_only_with_padding"
+ train_target: "ALL_ASSISTANT_TURNS"
seed: 42
training:
diff --git a/configs/projects/halloumi/8b_train.yaml b/configs/projects/halloumi/8b_train.yaml
index a922b404cb..fe80de3bea 100644
--- a/configs/projects/halloumi/8b_train.yaml
+++ b/configs/projects/halloumi/8b_train.yaml
@@ -66,6 +66,7 @@ data:
seed: 42
collator_name: "text_completions_only_with_padding"
+ train_target: "ALL_ASSISTANT_TURNS"
seed: 42
validation:
datasets:
@@ -78,6 +79,7 @@ data:
}
collator_name: "text_completions_only_with_padding"
+ train_target: "ALL_ASSISTANT_TURNS"
seed: 42
training:
diff --git a/src/oumi/builders/collators.py b/src/oumi/builders/collators.py
index 4656b29b97..376decda31 100644
--- a/src/oumi/builders/collators.py
+++ b/src/oumi/builders/collators.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import warnings
from collections.abc import Callable
import oumi.core.constants as constants
@@ -27,11 +28,114 @@
from oumi.core.configs.internal.supported_models import (
find_internal_model_config,
)
+from oumi.core.configs.params.data_params import TrainTarget
from oumi.core.tokenizers.base_tokenizer import BaseTokenizer
from oumi.utils.logging import logger
-# This is used to set the max input length for a model with infinite size input
_VERY_LARGE_INTEGER = int(1e30)
+_SENTINEL_USER = "<<__U__>>"
+_SENTINEL_ASST = "<<__A__>>"
+_FIX_HINT = (
+ "Fix: provide response_template (and end_of_turn_template for "
+ "all_assistant_turns) in collator_kwargs."
+)
+
+
+def _resolve_collator_templates(
+ tokenizer: "BaseTokenizer",
+) -> tuple[str, str]:
+ """Auto-detect response_template and end_of_turn_template.
+
+ Applies the chat template to a known test conversation, then finds
+ the assistant boundary strings in the rendered output.
+
+ Returns:
+ (response_template, end_of_turn_template)
+
+ Raises:
+ ValueError: If templates cannot be extracted.
+ """
+ msgs = [
+ {"role": "user", "content": _SENTINEL_USER},
+ {"role": "assistant", "content": _SENTINEL_ASST},
+ {"role": "user", "content": _SENTINEL_USER},
+ {"role": "assistant", "content": _SENTINEL_ASST},
+ ]
+
+ try:
+ rendered = tokenizer.apply_chat_template(
+ msgs, tokenize=False, add_generation_prompt=False
+ )
+ except Exception as exc:
+ raise ValueError(
+ f"Tokenizer has no chat template or it failed to render.\n{_FIX_HINT}"
+ ) from exc
+
+ if not isinstance(rendered, str):
+ raise ValueError(
+ f"Chat template returned a non-string type ({type(rendered).__name__}).\n"
+ f"{_FIX_HINT}"
+ )
+
+ # Locate boundaries around the second turn pair
+ # to avoid system-prompt effects on the first turn.
+ try:
+ a1 = rendered.index(_SENTINEL_ASST)
+ first_asst_end = a1 + len(_SENTINEL_ASST)
+ second_user = rendered.index(_SENTINEL_USER, first_asst_end)
+ second_user_end = second_user + len(_SENTINEL_USER)
+ second_asst = rendered.index(_SENTINEL_ASST, second_user_end)
+ second_asst_end = second_asst + len(_SENTINEL_ASST)
+ except ValueError:
+ raise ValueError(
+ "Could not locate assistant turn boundaries in the rendered "
+ f"chat template.\n{_FIX_HINT}"
+ )
+
+ # End-of-turn: common token-ID prefix of the two strings that
+ # follow assistant content (mid-conversation vs. end-of-sequence).
+ after_ids = tokenizer.encode(rendered[second_asst_end:], add_special_tokens=False)
+ between_ids = tokenizer.encode(
+ rendered[first_asst_end:second_user], add_special_tokens=False
+ )
+ eot_len = 0
+ for a, b in zip(after_ids, between_ids):
+ if a != b:
+ break
+ eot_len += 1
+ eot_ids = after_ids[:eot_len]
+ _eot_decoded = tokenizer.decode(eot_ids, skip_special_tokens=False)
+ assert isinstance(_eot_decoded, str)
+ end_of_turn_template = _eot_decoded
+
+ # Response template: strip the EOT prefix to get just the assistant header.
+ resp_ids = tokenizer.encode(
+ rendered[second_user_end:second_asst], add_special_tokens=False
+ )
+ if eot_len > 0 and resp_ids[:eot_len] == eot_ids:
+ resp_ids = resp_ids[eot_len:]
+ _resp_decoded = tokenizer.decode(resp_ids, skip_special_tokens=False)
+ assert isinstance(_resp_decoded, str)
+ response_template = _resp_decoded
+
+ if not response_template.strip():
+ raise ValueError(f"Extracted response_template is empty.\n{_FIX_HINT}")
+ if not end_of_turn_template.strip():
+ raise ValueError(f"Extracted end_of_turn_template is empty.\n{_FIX_HINT}")
+
+    # Qwen3 and similar reasoning models inject <think>...</think> into
+    # every assistant turn via their chat template. If training data was
+    # formatted without thinking tokens the response_template won't match
+    # and every example will be silently masked.
+    if "<think>" in response_template:
+        logger.warning(
+            "The extracted response_template contains <think> tokens "
+            "(from the model's chat template). If you're training without "
+            "thinking tokens, use collator_kwargs to specify "
+            "response_template manually."
+        )
+
+ return response_template, end_of_turn_template
def build_data_collator(
@@ -51,7 +155,8 @@ def build_data_collator(
- "text_with_padding": Uses `TextCollatorWithPadding`.
- "text_completions_only_with_padding": Uses
- `TextCompletionsCollatorWithPadding`.
+ `TextCompletionsCollatorWithPadding`. Supports optional
+ ``end_of_turn_template`` for tool-aware span-based masking.
- "vision_language_with_padding": Uses `VisionLanguageCollatorWithPadding`.
- "vision_language_sft": Uses `VisionLanguageSftCollator`.
@@ -126,27 +231,19 @@ def build_data_collator(
**kwargs,
)
elif collator_name == "text_completions_only_with_padding":
- # Extract instruction and response templates from kwargs if provided
- instruction_template = kwargs.pop("instruction_template", None)
- response_template = kwargs.pop("response_template", None)
-
- # Default to Llama-style templates if not provided
- instruction_prefix = (
- instruction_template
- if instruction_template
- else "<|start_header_id|>user<|end_header_id|>\n\n"
- )
- response_prefix = (
- response_template
- if response_template
- else "<|start_header_id|>assistant<|end_header_id|>\n\n"
- )
+ if not kwargs.get("response_template"):
+ raise ValueError(
+ "'text_completions_only_with_padding' requires a response_template.\n"
+ "Fix: set train_target in your data config (auto-resolves templates "
+ "from the tokenizer), or provide response_template in collator_kwargs."
+ )
return TextCompletionsCollatorWithPadding(
tokenizer=tokenizer,
- instruction_prefix=instruction_prefix,
- response_prefix=response_prefix,
debug=debug,
+ ignore_index=(
+ label_ignore_index if label_ignore_index is not None else -100
+ ),
**kwargs,
)
raise ValueError(f"Unknown data collator name: '{collator_name}'")
@@ -206,9 +303,66 @@ def build_collator_from_config(
"trust_remote_code", config.model.trust_remote_code
)
- # Merge collator_kwargs from config with the existing kwargs
- # Config kwargs take precedence over automatically determined kwargs
+ # --- Resolve train_target and templates ---
config_collator_kwargs = train_split.collator_kwargs or {}
+
+ if collator_name == "text_completions_only_with_padding":
+ if train_split.train_target is not None:
+ # Path 1: train_target is set, auto-detect templates from
+ # the tokenizer's chat template. Falls back to user-provided
+ # response_template in collator_kwargs if auto-detection fails.
+ collator_kwargs["train_target"] = train_split.train_target.value
+
+ try:
+ response_template, end_of_turn_template = _resolve_collator_templates(
+ tokenizer
+ )
+ collator_kwargs["response_template"] = response_template
+ if train_split.train_target == TrainTarget.ALL_ASSISTANT_TURNS:
+ collator_kwargs["end_of_turn_template"] = end_of_turn_template
+ except ValueError:
+ if config_collator_kwargs.get("response_template") is None:
+ raise
+
+ if (
+ train_split.train_target == TrainTarget.ALL_ASSISTANT_TURNS
+ and "end_of_turn_template" not in collator_kwargs
+ and config_collator_kwargs.get("end_of_turn_template") is None
+ ):
+ raise ValueError(
+ "train_target='all_assistant_turns' requires end_of_turn_template, "
+ "but auto-detection failed.\n"
+ "Fix: provide end_of_turn_template in collator_kwargs."
+ )
+
+ elif config_collator_kwargs.get("response_template") is not None:
+ # Path 2: train_target not set, templates provided manually
+ # via collator_kwargs. Infer train_target from which templates
+ # are present.
+ has_eot = config_collator_kwargs.get("end_of_turn_template") is not None
+ has_inst = config_collator_kwargs.get("instruction_template") is not None
+ if has_eot:
+ collator_kwargs["train_target"] = "all_assistant_turns"
+ elif has_inst:
+            warnings.warn(
+                "Instruction-based masking is deprecated.\n"
+                "Use train_target='all_assistant_turns' "
+                "or train_target='final_assistant_turn' instead.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+ collator_kwargs["train_target"] = "_legacy_instruction_response"
+ else:
+ collator_kwargs["train_target"] = "final_assistant_turn"
+ else:
+ raise ValueError(
+ "'text_completions_only_with_padding' collator requires"
+ " configuration.\n"
+ "Fix: set train_target in your data config, "
+ "or provide response_template in collator_kwargs."
+ )
+
+ # User-provided collator_kwargs override auto-resolved values
collator_kwargs.update(config_collator_kwargs)
return build_data_collator(
diff --git a/src/oumi/core/collators/text_completions_collator_with_padding.py b/src/oumi/core/collators/text_completions_collator_with_padding.py
index e567e47be9..50c814f02f 100644
--- a/src/oumi/core/collators/text_completions_collator_with_padding.py
+++ b/src/oumi/core/collators/text_completions_collator_with_padding.py
@@ -27,22 +27,34 @@ class TextCompletionsCollatorWithPadding:
def __init__(
self,
tokenizer: BaseTokenizer,
- instruction_prefix: str,
- response_prefix: str,
+ response_template: str,
+ train_target: str,
+ instruction_template: str | None = None,
debug: bool = False,
+ end_of_turn_template: str | None = None,
+ ignore_index: int = -100,
):
"""Custom collator for text LLM training.
Args:
tokenizer: The tokenizer used for encoding the data.
- instruction_prefix: The prefix marking the beginning of the user instruction.
- response_prefix: The prefix marking the beginning of the assistant response.
+ response_template: String marking assistant response start.
+ instruction_template: String marking user instruction start.
debug: If True, enables debug mode for logging.
+ train_target: Training target — ``"all_assistant_turns"``
+ or ``"final_assistant_turn"``.
+ end_of_turn_template: String marking the end of a turn.
+ Required for ``all_assistant_turns``.
+ ignore_index: Value used for masked labels. Must match the ignore_index
+ of the loss function (default: -100).
"""
self._default_collator = DataCollatorForCompletionOnlyLM(
tokenizer=tokenizer,
- instruction_template=instruction_prefix,
- response_template=response_prefix,
+ instruction_template=instruction_template,
+ response_template=response_template,
+ train_target=train_target,
+ end_of_turn_template=end_of_turn_template,
+ ignore_index=ignore_index,
)
if not hasattr(tokenizer, "pad_token_id") or tokenizer.pad_token_id is None:
@@ -55,7 +67,7 @@ def _collate(self, inputs: list[Any]) -> dict[str, Any]:
result = self._default_collator(inputs)
return result
- def __call__(self, batch) -> dict[str, Any]:
+ def __call__(self, batch: list[dict[str, Any]]) -> dict[str, Any]:
"""Pads to the longest length present in the batch.
Args:
diff --git a/src/oumi/core/collators/trl_data_collator_for_completion_only_lm.py b/src/oumi/core/collators/trl_data_collator_for_completion_only_lm.py
index 2e47212dbb..26778ebe61 100644
--- a/src/oumi/core/collators/trl_data_collator_for_completion_only_lm.py
+++ b/src/oumi/core/collators/trl_data_collator_for_completion_only_lm.py
@@ -19,18 +19,64 @@
import torch
from transformers.data.data_collator import DataCollatorForLanguageModeling
+from oumi.core.configs.params.data_params import TrainTarget
+
class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
- """Data collator used for completion tasks.
+ """Data collator for completion-only training.
+
+ Masks input labels so that the loss is only computed on specific
+ tokens (typically assistant responses), while ignoring other tokens
+ (system prompts, user messages, padding).
+
+ The ``train_target`` parameter selects the training target:
+
+ **``all_assistant_turns``**:
+ Span-based masking for multi-turn and tool-calling conversations.
+ Masks everything, then unmarks each assistant response span bounded
+ by ``response_template`` .. ``end_of_turn_template`` (inclusive of EOT).
+ Correctly handles interleaved tool results and parallel tool calls.
- Copied from `trl`'s `DataCollatorForCompletionOnlyLM` class.
+ **``final_assistant_turn``**:
+ Masks all tokens before the *last* ``response_template`` occurrence.
+ Only the final assistant response is trained on. Suitable for
+ single-turn completions.
+
+ Args:
+ response_template: String or token IDs marking the start of an
+ assistant response. Required for all modes.
+ instruction_template: String or token IDs marking the start of a
+ user instruction. Legacy — only used with the instruction+response
+ fallback path.
+ train_target: One of ``"all_assistant_turns"``,
+ ``"final_assistant_turn"``, ``"_legacy_instruction_response"``.
+ Resolved by the builder before construction.
+ end_of_turn_template: String or token IDs marking the end of a
+ conversational turn. Required for ``all_assistant_turns`` mode.
+ mlm: Whether to use masked language modeling. Default False.
+ ignore_index: Label value for masked tokens. Default -100.
+ padding_free: Remove padding and add position_ids. Default False.
"""
+ _VALID_TRAIN_TARGETS = {t.value for t in TrainTarget} | {
+ "_legacy_instruction_response",
+ }
+
+ def _tokenize_template(self, template: str | list[int] | None) -> list[int] | None:
+ """Encode a template string into token IDs, or pass through if already IDs."""
+ if template is None:
+ return None
+ if isinstance(template, str):
+ return self.tokenizer.encode(template, add_special_tokens=False)
+ return list(template)
+
def __init__(
self,
response_template: str | list[int],
instruction_template: str | list[int] | None = None,
*args,
+ train_target: str,
+ end_of_turn_template: str | list[int] | None = None,
mlm: bool = False,
ignore_index: int = -100,
padding_free: bool = False,
@@ -39,26 +85,33 @@ def __init__(
"""Initializes the DataCollatorForCompletionOnlyLM."""
super().__init__(*args, mlm=mlm, **kwargs)
+ # Tokenize templates.
self.instruction_template = instruction_template
- if isinstance(instruction_template, str):
- # The user provides a string, must tokenize
- self.instruction_token_ids = self.tokenizer.encode(
- self.instruction_template, # type: ignore
- add_special_tokens=False,
- )
- else:
- # The user already provides the token ids
- self.instruction_token_ids = instruction_template
-
+ self.instruction_token_ids = self._tokenize_template(instruction_template)
self.response_template = response_template
- if isinstance(response_template, str):
- # The user provides a string, must tokenize
- self.response_token_ids = self.tokenizer.encode(
- self.response_template, add_special_tokens=False
+ self.response_token_ids: list[int] = self._tokenize_template(response_template) # type: ignore[assignment]
+ self.end_of_turn_template = end_of_turn_template
+ self.end_of_turn_token_ids = self._tokenize_template(end_of_turn_template)
+
+ if train_target not in self._VALID_TRAIN_TARGETS:
+ valid = sorted(self._VALID_TRAIN_TARGETS - {"_legacy_instruction_response"})
+ raise ValueError(
+ f"Unknown train_target='{train_target}'. Must be one of: {valid}"
)
- else:
- # The user already provides the token ids
- self.response_token_ids = response_template
+ self.train_target = train_target
+
+ if self.train_target == "all_assistant_turns":
+ if end_of_turn_template is None:
+ raise ValueError(
+ "end_of_turn_template must be provided "
+ f"when train_target='{self.train_target}'"
+ )
+ if self.train_target == "_legacy_instruction_response":
+ if instruction_template is None:
+ raise ValueError(
+ "instruction_template must be provided "
+ f"when train_target='{self.train_target}'"
+ )
if (
not self.mlm
@@ -78,13 +131,102 @@ def __init__(
self.ignore_index = ignore_index
self.padding_free = padding_free
+ @staticmethod
+ def _find_pattern(seq: list[int], pattern: list[int]) -> list[int]:
+ """Return all start positions where *pattern* appears in *seq*."""
+ plen = len(pattern)
+ if plen == 0:
+ return []
+ first = pattern[0]
+ positions = []
+ for i in range(len(seq) - plen + 1):
+ if seq[i] == first and seq[i : i + plen] == pattern:
+ positions.append(i)
+ return positions
+
+ def _apply_span_masking(
+ self, batch: dict[str, Any], examples: list[list[int] | Any | dict[str, Any]]
+ ) -> None:
+ """Apply span-based masking for multi-turn conversations.
+
+ Masks all labels, then unmarks assistant response spans bounded by
+ response_template and end_of_turn_template (inclusive — the EOT token
+ is unmasked so the model learns to produce it).
+ """
+ resp_ids = self.response_token_ids
+ eot_ids = self.end_of_turn_token_ids
+ assert eot_ids is not None # Caller checks end_of_turn_template is not None
+ resp_len = len(resp_ids)
+ pad_token_id = self.tokenizer.pad_token_id
+
+ for i in range(len(examples)):
+ # Step 1: mask everything.
+ batch["labels"][i, :] = self.ignore_index
+
+ seq: list[int] = batch["input_ids"][i].tolist()
+
+ # Compute effective sequence length excluding trailing padding.
+ # Prevents false matches when end_of_turn_token_ids overlaps
+ # with the pad token (common: e.g. <|im_end|> = eos = pad).
+ if pad_token_id is not None:
+ n = len(seq)
+ while n > 0 and seq[n - 1] == pad_token_id:
+ n -= 1
+ else:
+ n = len(seq)
+
+ # Step 2: find every assistant response start position.
+ resp_positions = self._find_pattern(seq[:n], resp_ids)
+
+ if len(resp_positions) == 0:
+ warnings.warn(
+ f"Could not find response template in the following instance: "
+ f"{self.tokenizer.decode(batch['input_ids'][i])}. "
+ "This instance will be ignored in loss calculation.",
+ UserWarning,
+ )
+ continue
+
+ for resp_pos in resp_positions:
+ content_start = resp_pos + resp_len
+
+ # Step 3: find the next end_of_turn after content_start.
+ eot_positions = self._find_pattern(seq[content_start:n], eot_ids)
+ if eot_positions:
+ content_end = content_start + eot_positions[0]
+ else:
+ content_end = n
+
+ if content_start >= content_end:
+ continue
+
+ # Step 4: unmask this assistant response span, including the
+ # end-of-turn token so the model learns when to stop.
+ if eot_positions:
+ eot_len = len(self.end_of_turn_token_ids) # type: ignore
+ unmask_end = content_end + eot_len
+ else:
+ # No EOT found — content_end == n (end of real content).
+ # Do NOT extend past n or we'd unmask into padding.
+ unmask_end = content_end
+ batch["labels"][i, content_start:unmask_end] = batch["input_ids"][
+ i, content_start:unmask_end
+ ]
+
+ # ------------------------------------------------------------------
+ # Main collation
+ # ------------------------------------------------------------------
+
def torch_call(
self, examples: list[list[int] | Any | dict[str, Any]]
) -> dict[str, Any]:
"""Collates a list of examples into a batch."""
batch = super().torch_call(examples)
- if self.instruction_template is None:
+ if self.train_target == "all_assistant_turns":
+ self._apply_span_masking(batch, examples)
+ elif self.train_target == "final_assistant_turn":
+ # Response-only: unmask only the final assistant response.
for i in range(len(examples)):
response_token_ids_start_idx = None
diff --git a/src/oumi/core/configs/__init__.py b/src/oumi/core/configs/__init__.py
index 8ed6769f5d..f56cc9ffe7 100644
--- a/src/oumi/core/configs/__init__.py
+++ b/src/oumi/core/configs/__init__.py
@@ -94,6 +94,7 @@
DatasetSplit,
DatasetSplitParams,
MixtureStrategy,
+ TrainTarget,
)
from oumi.core.configs.params.evaluation_params import (
EvaluationBackend,
@@ -190,6 +191,7 @@
"LMHarnessTaskParams",
"LoraWeightInitialization",
"MixedPrecisionDtype",
+ "TrainTarget",
"MixtureStrategy",
"ModelParams",
"PeftParams",
diff --git a/src/oumi/core/configs/params/data_params.py b/src/oumi/core/configs/params/data_params.py
index 30f11a39c2..ac74a08811 100644
--- a/src/oumi/core/configs/params/data_params.py
+++ b/src/oumi/core/configs/params/data_params.py
@@ -52,6 +52,28 @@ def get_literal_value(self) -> Literal["first_exhausted", "all_exhausted"]:
raise ValueError("Unsupported value for MixtureStrategy")
+class TrainTarget(str, Enum):
+ """Controls which tokens contribute to the loss during training.
+
+ Used with the ``text_completions_only_with_padding`` collator to
+ select the training target. Template tokens are auto-resolved
+ from the tokenizer vocabulary.
+
+ Members:
+ ALL_ASSISTANT_TURNS: Train on all assistant response turns including
+ tool calls. Uses span-based masking: system prompts, user
+ messages, and tool results are masked; everything between the
+ assistant header and the end-of-turn token (inclusive) is
+ unmasked.
+ FINAL_ASSISTANT_TURN: Train only on the final assistant response.
+ Masks all tokens before the last ``response_template``
+ occurrence. Suitable for single-turn completions.
+ """
+
+ ALL_ASSISTANT_TURNS = "all_assistant_turns"
+ FINAL_ASSISTANT_TURN = "final_assistant_turn"
+
+
@dataclass
class DatasetParams(BaseParams):
dataset_name: str = MISSING
@@ -197,8 +219,13 @@ class DatasetSplitParams(BaseParams):
- "text_with_padding": Dynamically pads the inputs received to
the longest length.
+ - "text_completions_only_with_padding": Uses template matching to
+ mask non-assistant tokens. Works for simple user/assistant turns.
+ Supports optional ``end_of_turn_template`` in ``collator_kwargs``
+ for span-based masking.
- "vision_language_with_padding": Uses VisionLanguageCollator
for image+text multi-modal data.
+ - "vision_language_sft": Uses VisionLanguageSftCollator.
If None, then a default collator will be assigned.
"""
@@ -210,6 +237,16 @@ class DatasetSplitParams(BaseParams):
and can be used to customize collator behavior beyond the default parameters.
"""
+ train_target: TrainTarget | None = None
+ """High-level training target for ``text_completions_only_with_padding``.
+
+ When set, the builder auto-detects ``response_template`` and
+ ``end_of_turn_template`` from the tokenizer's chat template.
+ Use ``collator_kwargs`` to override individual auto-resolved values.
+
+ See :class:`TrainTarget` for available options.
+ """
+
pack: bool = False
"""Whether to pack the text into constant-length chunks.
@@ -266,6 +303,20 @@ class DatasetSplitParams(BaseParams):
def __post_init__(self):
"""Verifies params."""
+ # Convert string train_target to enum if needed
+ if isinstance(self.train_target, str):
+ self.train_target = TrainTarget(self.train_target)
+
+ if (
+ self.train_target is not None
+ and self.collator_name != "text_completions_only_with_padding"
+ ):
+ raise ValueError(
+ "`train_target` requires "
+ "collator_name='text_completions_only_with_padding', "
+ f"got '{self.collator_name}'."
+ )
+
if any([dataset.mixture_proportion is not None for dataset in self.datasets]):
if not all(
[dataset.mixture_proportion is not None for dataset in self.datasets]
diff --git a/tests/unit/builders/test_collators.py b/tests/unit/builders/test_collators.py
index ca2ddb86de..62db4b1956 100644
--- a/tests/unit/builders/test_collators.py
+++ b/tests/unit/builders/test_collators.py
@@ -12,6 +12,7 @@
ModelParams,
TrainingConfig,
TrainingParams,
+ TrainTarget,
)
@@ -249,3 +250,339 @@ def test_build_collator_from_config_collator_kwargs_override(mock_tokenizer):
assert callable(collator)
# Verify that the config kwargs override the model-determined kwargs
assert collator._allow_multi_image_inputs is False
+
+
+# ---------------------------------------------------------------------------
+# TrainTarget / builder auto-detection tests
+# ---------------------------------------------------------------------------
+
+
+def _chatml_tokenizer():
+ """Mock tokenizer that renders ChatML format."""
+ tok = MagicMock()
+ tok.pad_token_id = 0
+ tok.model_max_length = 2048
+
+ def _apply(messages, **kw):
+ out = "".join(
+ f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages
+ )
+ if kw.get("add_generation_prompt"):
+ out += "<|im_start|>assistant\n"
+ return out
+
+ tok.apply_chat_template = MagicMock(side_effect=_apply)
+
+ # The production code encodes/decodes three substrings from the
+ # rendered template. Map each to stable token IDs so the
+ # common-prefix logic works.
+ _encode_map = {
+ "<|im_end|>\n": [101, 10],
+ "<|im_end|>\n<|im_start|>user\n": [101, 10, 100, 20],
+ "<|im_end|>\n<|im_start|>assistant\n": [101, 10, 100, 30],
+ "<|im_start|>assistant\n": [100, 30],
+ }
+ _decode_map = {
+ (101, 10): "<|im_end|>\n",
+ (100, 30): "<|im_start|>assistant\n",
+ }
+ tok.encode = MagicMock(side_effect=lambda text, **kw: _encode_map[text])
+ tok.decode = MagicMock(side_effect=lambda ids, **kw: _decode_map[tuple(ids)])
+ return tok
+
+
+def _llama3_tokenizer():
+ """Mock tokenizer that renders Llama-3 format."""
+ tok = MagicMock()
+ tok.pad_token_id = 0
+ tok.model_max_length = 2048
+
+ def _apply(messages, **kw):
+ parts = ["<|begin_of_text|>"]
+ for m in messages:
+ parts.append(
+ f"<|start_header_id|>{m['role']}<|end_header_id|>\n\n"
+ f"{m['content']}<|eot_id|>"
+ )
+ if kw.get("add_generation_prompt"):
+ parts.append("<|start_header_id|>assistant<|end_header_id|>\n\n")
+ return "".join(parts)
+
+ tok.apply_chat_template = MagicMock(side_effect=_apply)
+
+ _encode_map = {
+ "<|eot_id|>": [203],
+ "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n": [
+ 203,
+ 201,
+ 20,
+ 202,
+ 10,
+ ],
+ "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n": [
+ 203,
+ 201,
+ 30,
+ 202,
+ 10,
+ ],
+ "<|start_header_id|>assistant<|end_header_id|>\n\n": [201, 30, 202, 10],
+ }
+ _decode_map = {
+ (203,): "<|eot_id|>",
+ (201, 30, 202, 10): "<|start_header_id|>assistant<|end_header_id|>\n\n",
+ }
+ tok.encode = MagicMock(side_effect=lambda text, **kw: _encode_map[text])
+ tok.decode = MagicMock(side_effect=lambda ids, **kw: _decode_map[tuple(ids)])
+ return tok
+
+
+def _unknown_tokenizer():
+ """Mock tokenizer with no chat template."""
+ tok = MagicMock()
+ tok.pad_token_id = 0
+ tok.model_max_length = 2048
+ tok.apply_chat_template = MagicMock(
+ side_effect=Exception("No chat template configured")
+ )
+ return tok
+
+
+def test_build_data_collator_text_completions_with_tool_kwargs(mock_tokenizer):
+ """Build completions collator with end_of_turn_template + custom ignore index."""
+ collator = build_data_collator(
+ "text_completions_only_with_padding",
+ mock_tokenizer,
+ max_length=512,
+ label_ignore_index=-200,
+ response_template="<|assistant|>",
+ end_of_turn_template="<|end|>",
+ train_target="all_assistant_turns",
+ )
+ assert collator is not None
+ assert callable(collator)
+ inner = collator._default_collator
+ assert inner.ignore_index == -200
+
+
+def test_train_target_all_assistant_turns():
+ """ChatML auto-detection with ALL_ASSISTANT_TURNS train target."""
+ tok = _chatml_tokenizer()
+ config = TrainingConfig(
+ data=DataParams(
+ train=DatasetSplitParams(
+ collator_name="text_completions_only_with_padding",
+ train_target=TrainTarget.ALL_ASSISTANT_TURNS,
+ datasets=[DatasetParams(dataset_name="dummy", split="train")],
+ )
+ ),
+ model=ModelParams(
+ model_name="MlpEncoder",
+ tokenizer_name="openai-community/gpt2",
+ model_max_length=512,
+ ),
+ )
+ collator = build_collator_from_config(config, tokenizer=tok)
+ assert collator is not None
+ inner = collator._default_collator
+ assert inner.response_template == "<|im_start|>assistant\n"
+
+
+def test_train_target_final_assistant_turn():
+ """ChatML auto-detection with FINAL_ASSISTANT_TURN train target."""
+ tok = _chatml_tokenizer()
+ config = TrainingConfig(
+ data=DataParams(
+ train=DatasetSplitParams(
+ collator_name="text_completions_only_with_padding",
+ train_target=TrainTarget.FINAL_ASSISTANT_TURN,
+ datasets=[DatasetParams(dataset_name="dummy", split="train")],
+ )
+ ),
+ model=ModelParams(
+ model_name="MlpEncoder",
+ tokenizer_name="openai-community/gpt2",
+ model_max_length=512,
+ ),
+ )
+ collator = build_collator_from_config(config, tokenizer=tok)
+ assert collator is not None
+ inner = collator._default_collator
+ assert inner.response_template == "<|im_start|>assistant\n"
+
+
+def test_train_target_llama3():
+ """Llama-3 auto-detection with ALL_ASSISTANT_TURNS train target."""
+ tok = _llama3_tokenizer()
+ config = TrainingConfig(
+ data=DataParams(
+ train=DatasetSplitParams(
+ collator_name="text_completions_only_with_padding",
+ train_target=TrainTarget.ALL_ASSISTANT_TURNS,
+ datasets=[DatasetParams(dataset_name="dummy", split="train")],
+ )
+ ),
+ model=ModelParams(
+ model_name="MlpEncoder",
+ tokenizer_name="openai-community/gpt2",
+ model_max_length=512,
+ ),
+ )
+ collator = build_collator_from_config(config, tokenizer=tok)
+ assert collator is not None
+ inner = collator._default_collator
+ assert (
+ inner.response_template == "<|start_header_id|>assistant<|end_header_id|>\n\n"
+ )
+
+
+def test_train_target_unknown_tokenizer():
+ """Error when tokenizer vocab does not match any known chat format."""
+ tok = _unknown_tokenizer()
+ config = TrainingConfig(
+ data=DataParams(
+ train=DatasetSplitParams(
+ collator_name="text_completions_only_with_padding",
+ train_target=TrainTarget.ALL_ASSISTANT_TURNS,
+ datasets=[DatasetParams(dataset_name="dummy", split="train")],
+ )
+ ),
+ model=ModelParams(
+ model_name="MlpEncoder",
+ tokenizer_name="openai-community/gpt2",
+ model_max_length=512,
+ ),
+ )
+ with pytest.raises(ValueError, match="no chat template"):
+ build_collator_from_config(config, tokenizer=tok)
+
+
+def test_train_target_with_collator_kwargs_override():
+ """collator_kwargs overrides auto-resolved templates when train_target is set."""
+ tok = _chatml_tokenizer()
+ config = TrainingConfig(
+ data=DataParams(
+ train=DatasetSplitParams(
+ collator_name="text_completions_only_with_padding",
+ train_target=TrainTarget.ALL_ASSISTANT_TURNS,
+ collator_kwargs={"response_template": "<|im_end|>\n"},
+ datasets=[DatasetParams(dataset_name="dummy", split="train")],
+ )
+ ),
+ model=ModelParams(
+ model_name="MlpEncoder",
+ tokenizer_name="openai-community/gpt2",
+ model_max_length=512,
+ ),
+ )
+ collator = build_collator_from_config(config, tokenizer=tok)
+ assert collator is not None
+ inner = collator._default_collator
+ # Auto-resolved would be "<|im_start|>assistant\n"; user override wins
+ assert inner.response_template == "<|im_end|>\n"
+
+
+def test_train_target_on_wrong_collator():
+ """train_target is only valid for text_completions_only_with_padding."""
+ with pytest.raises(ValueError, match="train_target.*requires"):
+ DatasetSplitParams(
+ collator_name="text_with_padding",
+ train_target=TrainTarget.ALL_ASSISTANT_TURNS,
+ datasets=[DatasetParams(dataset_name="dummy", split="train")],
+ )
+
+
+def test_legacy_instruction_template_backward_compat(mock_tokenizer):
+    """Legacy path: instruction_template + response_template → _legacy_instruction_response + DeprecationWarning."""
+ config = TrainingConfig(
+ data=DataParams(
+ train=DatasetSplitParams(
+ collator_name="text_completions_only_with_padding",
+ collator_kwargs={
+ "response_template": "<|assistant|>",
+ "instruction_template": "<|user|>",
+ },
+ datasets=[DatasetParams(dataset_name="dummy", split="train")],
+ )
+ ),
+ model=ModelParams(
+ model_name="MlpEncoder",
+ tokenizer_name="openai-community/gpt2",
+ model_max_length=512,
+ ),
+ )
+ with pytest.warns(
+ DeprecationWarning, match="Instruction-based masking is deprecated"
+ ):
+ collator = build_collator_from_config(config, tokenizer=mock_tokenizer)
+ assert collator is not None
+ inner = collator._default_collator
+ assert inner.response_template == "<|assistant|>"
+ assert inner.instruction_template == "<|user|>"
+ assert inner.train_target == "_legacy_instruction_response"
+
+
+def test_bare_collator_name_raises_without_templates(mock_tokenizer):
+ """Bare collator_name without kwargs or train_target raises an error."""
+ config = TrainingConfig(
+ data=DataParams(
+ train=DatasetSplitParams(
+ collator_name="text_completions_only_with_padding",
+ datasets=[DatasetParams(dataset_name="dummy", split="train")],
+ )
+ ),
+ model=ModelParams(
+ model_name="MlpEncoder",
+ tokenizer_name="openai-community/gpt2",
+ model_max_length=512,
+ ),
+ )
+ with pytest.raises(ValueError, match="response_template"):
+ build_collator_from_config(config, tokenizer=mock_tokenizer)
+
+
+def test_old_recipe_response_only_sets_final(mock_tokenizer):
+ """Old recipe: response_template only → final_assistant_turn."""
+ config = TrainingConfig(
+ data=DataParams(
+ train=DatasetSplitParams(
+ collator_name="text_completions_only_with_padding",
+ collator_kwargs={
+ "response_template": "<|assistant|>",
+ },
+ datasets=[DatasetParams(dataset_name="dummy", split="train")],
+ )
+ ),
+ model=ModelParams(
+ model_name="MlpEncoder",
+ tokenizer_name="openai-community/gpt2",
+ model_max_length=512,
+ ),
+ )
+ collator = build_collator_from_config(config, tokenizer=mock_tokenizer)
+ assert collator is not None
+ assert collator._default_collator.train_target == "final_assistant_turn"
+
+
+def test_old_recipe_eot_sets_all_assistant(mock_tokenizer):
+ """Old recipe: response_template + end_of_turn_template → all_assistant_turns."""
+ config = TrainingConfig(
+ data=DataParams(
+ train=DatasetSplitParams(
+ collator_name="text_completions_only_with_padding",
+ collator_kwargs={
+ "response_template": "<|assistant|>",
+ "end_of_turn_template": "<|end|>",
+ },
+ datasets=[DatasetParams(dataset_name="dummy", split="train")],
+ )
+ ),
+ model=ModelParams(
+ model_name="MlpEncoder",
+ tokenizer_name="openai-community/gpt2",
+ model_max_length=512,
+ ),
+ )
+ collator = build_collator_from_config(config, tokenizer=mock_tokenizer)
+ assert collator is not None
+ assert collator._default_collator.train_target == "all_assistant_turns"
diff --git a/tests/unit/core/collators/test_text_completions_collator_with_padding.py b/tests/unit/core/collators/test_text_completions_collator_with_padding.py
index f98b618ec2..258784d507 100644
--- a/tests/unit/core/collators/test_text_completions_collator_with_padding.py
+++ b/tests/unit/core/collators/test_text_completions_collator_with_padding.py
@@ -1,10 +1,12 @@
import functools
+import warnings
from unittest.mock import MagicMock
import numpy as np
import pytest
import torch
+import oumi.core.constants as constants
from oumi.builders import build_tokenizer
from oumi.core.collators.text_completions_collator_with_padding import (
TextCompletionsCollatorWithPadding,
@@ -13,6 +15,15 @@
from oumi.core.tokenizers.base_tokenizer import BaseTokenizer
from oumi.utils import logging
+IGNORE = constants.LABEL_IGNORE_INDEX
+
+# Template strings for span-masking tests — chosen to be unambiguous in GPT-2's vocab.
+_RESP_STR = " ASSISTANT_RESPONSE_START"
+_EOT_STR = " TURN_ENDS_HERE"
+
+# Arbitrary token IDs used as "content" that must not appear in any template.
+_SENTINELS = [601, 602, 603, 604, 605, 606, 607, 608]
+
@pytest.fixture
def mock_tokenizer():
@@ -50,8 +61,9 @@ def test_success_basic():
collator = TextCompletionsCollatorWithPadding(
tokenizer=tokenizer,
- instruction_prefix=instruction_prefix,
- response_prefix=response_prefix,
+ instruction_template=instruction_prefix,
+ response_template=response_prefix,
+ train_target="_legacy_instruction_response",
)
assert callable(collator)
@@ -174,8 +186,9 @@ def test_debug_logging(caplog):
collator = TextCompletionsCollatorWithPadding(
tokenizer=tokenizer,
- instruction_prefix=instruction_prefix,
- response_prefix=response_prefix,
+ instruction_template=instruction_prefix,
+ response_template=response_prefix,
+ train_target="_legacy_instruction_response",
debug=True,
)
assert callable(collator)
@@ -238,3 +251,228 @@ def test_debug_logging(caplog):
assert "'input_ids':" in log_text
assert "'attention_mask':" in log_text
assert "'labels':" in log_text
+
+
+# ===========================================================================
+# Span-based masking tests (tool-aware collation)
+# ===========================================================================
+
+
+@functools.cache
+def get_template_token_ids() -> tuple[list[int], list[int]]:
+ """Return (resp_ids, eot_ids) encoded once and cached."""
+ tokenizer, _ = create_test_tokenizer()
+ resp = tokenizer.encode(_RESP_STR, add_special_tokens=False)
+ eot = tokenizer.encode(_EOT_STR, add_special_tokens=False)
+ forbidden = set(resp) | set(eot)
+ for sentinel in _SENTINELS:
+ assert sentinel not in forbidden, (
+ f"Sentinel {sentinel} collides with a template token ID. Adjust _SENTINELS."
+ )
+ return resp, eot
+
+
+def make_span_collator() -> TextCompletionsCollatorWithPadding:
+ tokenizer, _ = create_test_tokenizer()
+ return TextCompletionsCollatorWithPadding(
+ tokenizer=tokenizer,
+ response_template=_RESP_STR,
+ train_target="all_assistant_turns",
+ end_of_turn_template=_EOT_STR,
+ )
+
+
+def get_span_labels(collator, seq: list[int]) -> list[int]:
+ return collator([{"input_ids": seq}])["labels"][0].tolist()
+
+
+def flat(*parts: list[int]) -> list[int]:
+ result = []
+ for p in parts:
+ result.extend(p)
+ return result
+
+
+# ---------------------------------------------------------------------------
+# Single assistant turn
+# ---------------------------------------------------------------------------
+
+
+def test_span_single_turn_content_is_unmasked():
+ resp, eot = get_template_token_ids()
+ prefix = [_SENTINELS[0], _SENTINELS[1]]
+ content = [_SENTINELS[2], _SENTINELS[3]]
+ seq = flat(prefix, resp, content, eot)
+
+ labels = get_span_labels(make_span_collator(), seq)
+
+ n_prefix = len(prefix) + len(resp)
+ assert all(v == IGNORE for v in labels[:n_prefix])
+ assert labels[n_prefix : n_prefix + len(content)] == content
+ # EOT tokens are unmasked (model learns to produce the stop token)
+ assert labels[n_prefix + len(content) : n_prefix + len(content) + len(eot)] == eot
+
+
+# ---------------------------------------------------------------------------
+# Multiple assistant turns
+# ---------------------------------------------------------------------------
+
+
+def test_span_two_turns_both_unmasked():
+ resp, eot = get_template_token_ids()
+ turn1 = [_SENTINELS[0], _SENTINELS[1]]
+ middle = [_SENTINELS[2]]
+ turn2 = [_SENTINELS[3], _SENTINELS[4]]
+ seq = flat(resp, turn1, eot, middle, resp, turn2, eot)
+
+ labels = get_span_labels(make_span_collator(), seq)
+
+ t1_start = len(resp)
+ t1_end = t1_start + len(turn1)
+ assert labels[t1_start:t1_end] == turn1
+
+ t2_start = t1_end + len(eot) + len(middle) + len(resp)
+ t2_end = t2_start + len(turn2)
+ assert labels[t2_start:t2_end] == turn2
+
+
+def test_span_content_between_turns_is_masked():
+ resp, eot = get_template_token_ids()
+ turn1 = [_SENTINELS[0]]
+ between = [_SENTINELS[1], _SENTINELS[2]]
+ turn2 = [_SENTINELS[3]]
+ seq = flat(resp, turn1, eot, between, resp, turn2, eot)
+
+ labels = get_span_labels(make_span_collator(), seq)
+
+ between_start = len(resp) + len(turn1) + len(eot)
+ for i in range(len(between)):
+ assert labels[between_start + i] == IGNORE
+
+
+def test_span_masking_requires_end_of_turn_template():
+ tokenizer, _ = create_test_tokenizer()
+ with pytest.raises(ValueError, match="end_of_turn_template"):
+ TextCompletionsCollatorWithPadding(
+ tokenizer=tokenizer,
+ response_template=_RESP_STR,
+ train_target="all_assistant_turns",
+ end_of_turn_template=None,
+ )
+
+
+# ---------------------------------------------------------------------------
+# Edge cases
+# ---------------------------------------------------------------------------
+
+
+def test_span_no_response_template_all_masked():
+ seq = [_SENTINELS[0], _SENTINELS[1], _SENTINELS[2]]
+
+ with warnings.catch_warnings(record=True) as caught:
+ warnings.simplefilter("always")
+ labels = get_span_labels(make_span_collator(), seq)
+
+ assert all(v == IGNORE for v in labels)
+ assert any("response template" in str(w.message).lower() for w in caught)
+
+
+def test_span_no_eot_unmasked_to_end_of_sequence():
+ resp, _ = get_template_token_ids()
+ content = [_SENTINELS[0], _SENTINELS[1]]
+ seq = flat(resp, content)
+
+ labels = get_span_labels(make_span_collator(), seq)
+
+ assert labels[len(resp) :] == content
+
+
+def test_span_empty_content_span():
+ resp, eot = get_template_token_ids()
+ seq = flat(resp, eot)
+
+ labels = get_span_labels(make_span_collator(), seq)
+
+ assert all(v == IGNORE for v in labels)
+
+
+def test_span_padding_matching_eot_does_not_false_match():
+ """When pad_token_id matches the EOT token, padding must not be treated as
+ a real end-of-turn boundary."""
+ tokenizer, pad_token_id = create_test_tokenizer()
+ resp, _ = get_template_token_ids()
+
+    # Use pad_token_id itself as the EOT template — the worst-case scenario.
+ eot_ids = [pad_token_id]
+ content = [_SENTINELS[0], _SENTINELS[1]]
+ # Sequence: [RESP] content (no real EOT), then padding
+ seq = flat(resp, content) + [pad_token_id] * 5
+
+ collator = TextCompletionsCollatorWithPadding(
+ tokenizer=tokenizer,
+ response_template=_RESP_STR,
+ train_target="all_assistant_turns",
+ end_of_turn_template=str(tokenizer.decode(eot_ids)),
+ )
+ batch = collator([{"input_ids": seq}])
+ labels = batch["labels"][0].tolist()
+
+ # Content should be unmasked — the padding should not act as an EOT.
+ content_start = len(resp)
+ assert labels[content_start : content_start + len(content)] == content
+ # Padding should be masked.
+ assert all(v == IGNORE for v in labels[content_start + len(content) :])
+
+
+# ---------------------------------------------------------------------------
+# Batch processing
+# ---------------------------------------------------------------------------
+
+
+def test_span_batch_two_examples_processed_independently():
+ resp, eot = get_template_token_ids()
+ _, pad_token_id = create_test_tokenizer()
+ content_a = [_SENTINELS[0], _SENTINELS[1]]
+ content_b = [_SENTINELS[2]]
+ seq_a = flat(resp, content_a, eot)
+ seq_b = flat(resp, content_b, eot)
+
+ max_len = max(len(seq_a), len(seq_b))
+ pad_a = [pad_token_id] * (max_len - len(seq_a))
+ pad_b = [pad_token_id] * (max_len - len(seq_b))
+
+ collator = make_span_collator()
+ batch = collator([{"input_ids": seq_a + pad_a}, {"input_ids": seq_b + pad_b}])
+ labels_a = batch["labels"][0].tolist()
+ labels_b = batch["labels"][1].tolist()
+
+ assert labels_a[len(resp) : len(resp) + len(content_a)] == content_a
+ assert labels_b[len(resp) : len(resp) + len(content_b)] == content_b
+
+
+def test_span_batch_bad_example_does_not_affect_others():
+ resp, eot = get_template_token_ids()
+ _, pad_token_id = create_test_tokenizer()
+ good_seq = flat(resp, [_SENTINELS[0]], eot)
+ bad_seq = [_SENTINELS[1], _SENTINELS[2]]
+
+ max_len = max(len(good_seq), len(bad_seq))
+ pad_good = [pad_token_id] * (max_len - len(good_seq))
+ pad_bad = [pad_token_id] * (max_len - len(bad_seq))
+
+ collator = make_span_collator()
+ with warnings.catch_warnings(record=True):
+ warnings.simplefilter("always")
+ batch = collator(
+ [{"input_ids": good_seq + pad_good}, {"input_ids": bad_seq + pad_bad}]
+ )
+
+ assert batch["labels"][0].tolist()[len(resp)] == _SENTINELS[0]
+ assert all(v == IGNORE for v in batch["labels"][1].tolist())
+
+
+def test_span_labels_shape_matches_input_ids():
+ resp, eot = get_template_token_ids()
+ seq = flat(resp, [_SENTINELS[0], _SENTINELS[1]], eot)
+ batch = make_span_collator()([{"input_ids": seq}])
+ assert batch["labels"].shape == batch["input_ids"].shape
diff --git a/tests/unit/core/datasets/test_base_sft_dataset.py b/tests/unit/core/datasets/test_base_sft_dataset.py
index 5d51415fa0..74afa92bdf 100644
--- a/tests/unit/core/datasets/test_base_sft_dataset.py
+++ b/tests/unit/core/datasets/test_base_sft_dataset.py
@@ -26,8 +26,9 @@ def _get_hf_collator_result(conversation, tokenizer):
collator = TextCompletionsCollatorWithPadding(
tokenizer=tokenizer,
- instruction_prefix=_INSTRUCTION_PREFIX,
- response_prefix=_RESPONSE_PREFIX,
+ instruction_template=_INSTRUCTION_PREFIX,
+ response_template=_RESPONSE_PREFIX,
+ train_target="_legacy_instruction_response",
)
return collator(batch)