diff --git a/src/oumi/builders/collators.py b/src/oumi/builders/collators.py index 4656b29b97..468f821841 100644 --- a/src/oumi/builders/collators.py +++ b/src/oumi/builders/collators.py @@ -51,7 +51,8 @@ def build_data_collator( - "text_with_padding": Uses `TextCollatorWithPadding`. - "text_completions_only_with_padding": Uses - `TextCompletionsCollatorWithPadding`. + `TextCompletionsCollatorWithPadding`. Supports optional + ``end_of_turn_template`` for tool-aware span-based masking. - "vision_language_with_padding": Uses `VisionLanguageCollatorWithPadding`. - "vision_language_sft": Uses `VisionLanguageSftCollator`. @@ -129,13 +130,22 @@ def build_data_collator( # Extract instruction and response templates from kwargs if provided instruction_template = kwargs.pop("instruction_template", None) response_template = kwargs.pop("response_template", None) + end_of_turn_template = kwargs.pop("end_of_turn_template", None) + mask_tool_calls = kwargs.pop("mask_tool_calls", False) + tool_call_start_template = kwargs.pop("tool_call_start_template", None) + + # Only default to Llama-style instruction template when NOT using + # span-based masking (end_of_turn_template makes instruction_prefix + # unnecessary since masking is handled by response/eot spans). + if end_of_turn_template is None: + instruction_prefix = ( + instruction_template + if instruction_template + else "<|start_header_id|>user<|end_header_id|>\n\n" + ) + else: + instruction_prefix = instruction_template # may be None, that's fine - # Default to Llama-style templates if not provided - instruction_prefix = ( - instruction_template - if instruction_template - else "<|start_header_id|>user<|end_header_id|>\n\n" - ) response_prefix = ( response_template if response_template @@ -147,6 +157,12 @@ def build_data_collator( instruction_prefix=instruction_prefix, response_prefix=response_prefix, debug=debug, + end_of_turn_template=end_of_turn_template, + mask_tool_calls=mask_tool_calls, + tool_call_start_template=tool_call_start_template, + ignore_index=( + label_ignore_index if label_ignore_index is not None else -100 + ), **kwargs, ) raise ValueError(f"Unknown data collator name: '{collator_name}'") diff --git a/src/oumi/core/collators/text_completions_collator_with_padding.py b/src/oumi/core/collators/text_completions_collator_with_padding.py index e567e47be9..6622cb3153 100644 --- a/src/oumi/core/collators/text_completions_collator_with_padding.py +++ b/src/oumi/core/collators/text_completions_collator_with_padding.py @@ -27,22 +27,38 @@ class TextCompletionsCollatorWithPadding: def __init__( self, tokenizer: BaseTokenizer, - instruction_prefix: str, - response_prefix: str, + response_prefix: str | list[int], + instruction_prefix: str | list[int] | None = None, debug: bool = False, + end_of_turn_template: str | list[int] | None = None, + mask_tool_calls: bool = False, + tool_call_start_template: str | list[int] | None = None, + ignore_index: int = -100, ): """Custom collator for text LLM training. Args: tokenizer: The tokenizer used for encoding the data. - instruction_prefix: The prefix marking the beginning of the user instruction. response_prefix: The prefix marking the beginning of the assistant response. + instruction_prefix: The prefix marking the beginning of the user instruction. + Optional when using span-based masking via end_of_turn_template. debug: If True, enables debug mode for logging. + end_of_turn_template: String or token-ID list marking end of turn. + When provided, enables span-based masking for tool-aware conversations. + mask_tool_calls: When True, re-masks assistant spans containing tool calls. + tool_call_start_template: String or token-ID list marking tool-call start. + Required when mask_tool_calls=True. + ignore_index: Value used for masked labels. Must match the ignore_index + of the loss function (default: -100). """ self._default_collator = DataCollatorForCompletionOnlyLM( tokenizer=tokenizer, instruction_template=instruction_prefix, response_template=response_prefix, + end_of_turn_template=end_of_turn_template, + mask_tool_calls=mask_tool_calls, + tool_call_start_template=tool_call_start_template, + ignore_index=ignore_index, ) if not hasattr(tokenizer, "pad_token_id") or tokenizer.pad_token_id is None: diff --git a/src/oumi/core/collators/trl_data_collator_for_completion_only_lm.py b/src/oumi/core/collators/trl_data_collator_for_completion_only_lm.py index 2e47212dbb..de742c1e11 100644 --- a/src/oumi/core/collators/trl_data_collator_for_completion_only_lm.py +++ b/src/oumi/core/collators/trl_data_collator_for_completion_only_lm.py @@ -31,6 +31,9 @@ def __init__( response_template: str | list[int], instruction_template: str | list[int] | None = None, *args, + end_of_turn_template: str | list[int] | None = None, + mask_tool_calls: bool = False, + tool_call_start_template: str | list[int] | None = None, mlm: bool = False, ignore_index: int = -100, padding_free: bool = False, @@ -60,6 +63,32 @@ def __init__( # The user already provides the token ids self.response_token_ids = response_template + # Tool-aware span-based masking parameters + self.end_of_turn_template = end_of_turn_template + if isinstance(end_of_turn_template, str): + self.end_of_turn_token_ids: list[int] | None = self.tokenizer.encode( + end_of_turn_template, add_special_tokens=False + ) + elif end_of_turn_template is not None: + self.end_of_turn_token_ids = list(end_of_turn_template) + else: + self.end_of_turn_token_ids = None + + self.mask_tool_calls = mask_tool_calls + self.tool_call_start_token_ids: list[int] | None = None + if mask_tool_calls: + if tool_call_start_template is None: + raise ValueError( + "tool_call_start_template must be provided " + "when mask_tool_calls=True" + ) + if isinstance(tool_call_start_template, str): + self.tool_call_start_token_ids = self.tokenizer.encode( + tool_call_start_template, add_special_tokens=False + ) + else: + self.tool_call_start_token_ids = list(tool_call_start_template) + if ( not self.mlm and self.instruction_template @@ -78,13 +107,116 @@ def __init__( self.ignore_index = ignore_index self.padding_free = padding_free + @staticmethod + def _find_pattern(seq: list[int], pattern: list[int]) -> list[int]: + """Return all start positions where *pattern* appears in *seq*.""" + plen = len(pattern) + if plen == 0: + return [] + first = pattern[0] + positions = [] + for i in range(len(seq) - plen + 1): + if seq[i] == first and seq[i : i + plen] == pattern: + positions.append(i) + return positions + + @staticmethod + def _span_contains( + seq: list[int], span_start: int, span_end: int, pattern: list[int] + ) -> bool: + """Return True if *pattern* appears anywhere in seq[span_start:span_end].""" + plen = len(pattern) + if plen == 0: + return False + first = pattern[0] + for i in range(span_start, span_end - plen + 1): + if seq[i] == first and seq[i : i + plen] == pattern: + return True + return False + + def _apply_span_masking( + self, batch: dict[str, Any], examples: list[list[int] | Any | dict[str, Any]] + ) -> None: + """Apply span-based label masking for tool-aware conversations. + + This masks everything, then selectively unmasks assistant response + spans delimited by response_template and end_of_turn_template. + """ + resp_ids = self.response_token_ids + eot_ids = self.end_of_turn_token_ids + assert eot_ids is not None # Caller checks end_of_turn_template is not None + resp_len = len(resp_ids) + pad_token_id = self.tokenizer.pad_token_id + + for i in range(len(examples)): + # Step 1: mask everything. + batch["labels"][i, :] = self.ignore_index + + seq: list[int] = batch["input_ids"][i].tolist() + + # Compute effective sequence length excluding trailing padding. + # Prevents false matches when end_of_turn_token_ids overlaps + # with the pad token (common: e.g. <|im_end|> = eos = pad). + if pad_token_id is not None: + n = len(seq) + while n > 0 and seq[n - 1] == pad_token_id: + n -= 1 + else: + n = len(seq) + + # Step 2: find every assistant response start position. + resp_positions = self._find_pattern(seq[:n], resp_ids) + + if len(resp_positions) == 0: + warnings.warn( + f"Could not find response template in the following instance: " + f"{self.tokenizer.decode(batch['input_ids'][i])}. " + "This instance will be ignored in loss calculation.", + UserWarning, + ) + continue + + for resp_pos in resp_positions: + content_start = resp_pos + resp_len + + # Step 3: find the next end_of_turn after content_start. + eot_positions = self._find_pattern(seq[content_start:n], eot_ids) + if eot_positions: + content_end = content_start + eot_positions[0] + else: + content_end = n + + if content_start >= content_end: + continue + + # Step 4: optionally skip tool-call spans. + if self.mask_tool_calls and self.tool_call_start_token_ids is not None: + if self._span_contains( + seq, + content_start, + content_end, + self.tool_call_start_token_ids, + ): + continue + + # Step 5: unmask this assistant response span. + batch["labels"][i, content_start:content_end] = batch["input_ids"][ + i, content_start:content_end + ] + + # ------------------------------------------------------------------ + # Main collation + # ------------------------------------------------------------------ + def torch_call( self, examples: list[list[int] | Any | dict[str, Any]] ) -> dict[str, Any]: """Collates a list of examples into a batch.""" batch = super().torch_call(examples) - if self.instruction_template is None: + if self.end_of_turn_template is not None: + self._apply_span_masking(batch, examples) + elif self.instruction_template is None: for i in range(len(examples)): response_token_ids_start_idx = None diff --git a/src/oumi/core/configs/params/data_params.py b/src/oumi/core/configs/params/data_params.py index 30f11a39c2..fbdfef2b58 100644 --- a/src/oumi/core/configs/params/data_params.py +++ b/src/oumi/core/configs/params/data_params.py @@ -197,8 +197,14 @@ class DatasetSplitParams(BaseParams): - "text_with_padding": Dynamically pads the inputs received to the longest length. + - "text_completions_only_with_padding": Uses template matching to + mask non-assistant tokens. Works for simple user/assistant turns. + Supports optional ``end_of_turn_template`` in ``collator_kwargs`` + for tool-aware span-based masking. When set, also supports + ``mask_tool_calls=True`` and ``tool_call_start_template``. - "vision_language_with_padding": Uses VisionLanguageCollator for image+text multi-modal data. + - "vision_language_sft": Uses VisionLanguageSftCollator. If None, then a default collator will be assigned. """ diff --git a/tests/unit/builders/test_collators.py b/tests/unit/builders/test_collators.py index ca2ddb86de..4ea64ba971 100644 --- a/tests/unit/builders/test_collators.py +++ b/tests/unit/builders/test_collators.py @@ -146,6 +146,50 @@ def test_build_collator_from_config_with_collator( assert callable(collator) +def test_build_data_collator_text_completions_with_tool_kwargs(mock_tokenizer): + collator_name = "text_completions_only_with_padding" + resp = "<|im_start|>" + "assistant\n" + eot = "<|im_end|>" + + # Basic build with end_of_turn_template + collator = build_data_collator( + collator_name, + mock_tokenizer, + max_length=None, + response_template=resp, + end_of_turn_template=eot, + ) + assert collator is not None + assert callable(collator) + + # Default label_ignore_index is forwarded + assert collator._default_collator.ignore_index == constants.LABEL_IGNORE_INDEX + + # Custom label_ignore_index is forwarded + collator_custom = build_data_collator( + collator_name, + mock_tokenizer, + max_length=None, + label_ignore_index=-200, + response_template=resp, + end_of_turn_template=eot, + ) + assert collator_custom._default_collator.ignore_index == -200 + + # With mask_tool_calls + collator_tc = build_data_collator( + collator_name, + mock_tokenizer, + max_length=None, + response_template=resp, + end_of_turn_template=eot, + mask_tool_calls=True, + tool_call_start_template="", + ) + assert collator_tc is not None + assert callable(collator_tc) + + def test_build_collator_from_config_no_collator(mock_tokenizer): training_config = TrainingConfig( data=DataParams( diff --git a/tests/unit/core/collators/test_text_completions_collator_with_padding.py b/tests/unit/core/collators/test_text_completions_collator_with_padding.py index f98b618ec2..32d898f9fe 100644 --- a/tests/unit/core/collators/test_text_completions_collator_with_padding.py +++ b/tests/unit/core/collators/test_text_completions_collator_with_padding.py @@ -1,10 +1,12 @@ import functools +import warnings from unittest.mock import MagicMock import numpy as np import pytest import torch +import oumi.core.constants as constants from oumi.builders import build_tokenizer from oumi.core.collators.text_completions_collator_with_padding import ( TextCompletionsCollatorWithPadding, @@ -13,6 +15,16 @@ from oumi.core.tokenizers.base_tokenizer import BaseTokenizer from oumi.utils import logging +IGNORE = constants.LABEL_IGNORE_INDEX + +# Template strings for span-masking tests — chosen to be unambiguous in GPT-2's vocab. +_RESP_STR = " ASSISTANT_RESPONSE_START" +_EOT_STR = " TURN_ENDS_HERE" +_TC_STR = " TOOL_CALL_BEGINS" + +# Arbitrary token IDs used as "content" that must not appear in any template. +_SENTINELS = [601, 602, 603, 604, 605, 606, 607, 608] + @pytest.fixture def mock_tokenizer(): @@ -238,3 +250,347 @@ def test_debug_logging(caplog): assert "'input_ids':" in log_text assert "'attention_mask':" in log_text assert "'labels':" in log_text + + +# =========================================================================== +# Span-based masking tests (tool-aware collation) +# =========================================================================== + + +@functools.cache +def get_template_token_ids() -> tuple[list[int], list[int], list[int]]: + """Return (resp_ids, eot_ids, tc_ids) encoded once and cached.""" + tokenizer, _ = create_test_tokenizer() + resp = tokenizer.encode(_RESP_STR, add_special_tokens=False) + eot = tokenizer.encode(_EOT_STR, add_special_tokens=False) + tc = tokenizer.encode(_TC_STR, add_special_tokens=False) + forbidden = set(resp) | set(eot) | set(tc) + for sentinel in _SENTINELS: + assert sentinel not in forbidden, ( + f"Sentinel {sentinel} collides with a template token ID. Adjust _SENTINELS." + ) + return resp, eot, tc + + +def make_span_collator( + mask_tool_calls: bool = False, +) -> TextCompletionsCollatorWithPadding: + tokenizer, _ = create_test_tokenizer() + resp_ids, eot_ids, tc_ids = get_template_token_ids() + return TextCompletionsCollatorWithPadding( + tokenizer=tokenizer, + response_prefix=resp_ids, + end_of_turn_template=eot_ids, + mask_tool_calls=mask_tool_calls, + tool_call_start_template=tc_ids if mask_tool_calls else None, + ) + + +def get_span_labels(collator, seq: list[int]) -> list[int]: + return collator([{"input_ids": seq}])["labels"][0].tolist() + + +def flat(*parts: list[int]) -> list[int]: + result = [] + for p in parts: + result.extend(p) + return result + + +# --------------------------------------------------------------------------- +# Single assistant turn +# --------------------------------------------------------------------------- + + +def test_span_single_turn_content_is_unmasked(): + resp, eot, _ = get_template_token_ids() + prefix = [_SENTINELS[0], _SENTINELS[1]] + content = [_SENTINELS[2], _SENTINELS[3]] + seq = flat(prefix, resp, content, eot) + + labels = get_span_labels(make_span_collator(), seq) + + n_prefix = len(prefix) + len(resp) + assert all(v == IGNORE for v in labels[:n_prefix]) + assert labels[n_prefix : n_prefix + len(content)] == content + assert all(v == IGNORE for v in labels[n_prefix + len(content) :]) + + +def test_span_single_turn_response_template_tokens_are_masked(): + resp, eot, _ = get_template_token_ids() + seq = flat(resp, [_SENTINELS[0]], eot) + + labels = get_span_labels(make_span_collator(), seq) + + for i in range(len(resp)): + assert labels[i] == IGNORE, f"resp template token {i} should be masked" + + +def test_span_single_turn_eot_tokens_are_masked(): + resp, eot, _ = get_template_token_ids() + content = [_SENTINELS[0]] + seq = flat(resp, content, eot) + + labels = get_span_labels(make_span_collator(), seq) + + eot_start = len(resp) + len(content) + for i in range(len(eot)): + assert labels[eot_start + i] == IGNORE, f"eot token {i} should be masked" + + +# --------------------------------------------------------------------------- +# Multiple assistant turns +# --------------------------------------------------------------------------- + + +def test_span_two_turns_both_unmasked(): + resp, eot, _ = get_template_token_ids() + turn1 = [_SENTINELS[0], _SENTINELS[1]] + middle = [_SENTINELS[2]] + turn2 = [_SENTINELS[3], _SENTINELS[4]] + seq = flat(resp, turn1, eot, middle, resp, turn2, eot) + + labels = get_span_labels(make_span_collator(), seq) + + t1_start = len(resp) + t1_end = t1_start + len(turn1) + assert labels[t1_start:t1_end] == turn1 + + t2_start = t1_end + len(eot) + len(middle) + len(resp) + t2_end = t2_start + len(turn2) + assert labels[t2_start:t2_end] == turn2 + + +def test_span_content_between_turns_is_masked(): + resp, eot, _ = get_template_token_ids() + turn1 = [_SENTINELS[0]] + between = [_SENTINELS[1], _SENTINELS[2]] + turn2 = [_SENTINELS[3]] + seq = flat(resp, turn1, eot, between, resp, turn2, eot) + + labels = get_span_labels(make_span_collator(), seq) + + between_start = len(resp) + len(turn1) + len(eot) + for i in range(len(between)): + assert labels[between_start + i] == IGNORE + + +# --------------------------------------------------------------------------- +# Tool result masking +# --------------------------------------------------------------------------- + + +def test_span_tool_result_is_masked(): + resp, eot, _ = get_template_token_ids() + tool_call_content = [_SENTINELS[0], _SENTINELS[1]] + tool_result = [_SENTINELS[2], _SENTINELS[3]] + final_answer = [_SENTINELS[4], _SENTINELS[5]] + seq = flat(resp, tool_call_content, eot, tool_result, resp, final_answer, eot) + + labels = get_span_labels(make_span_collator(), seq) + + tool_result_start = len(resp) + len(tool_call_content) + len(eot) + for i in range(len(tool_result)): + assert labels[tool_result_start + i] == IGNORE + + +def test_span_final_answer_after_tool_result_is_unmasked(): + resp, eot, _ = get_template_token_ids() + tool_call_content = [_SENTINELS[0]] + tool_result = [_SENTINELS[1]] + final_answer = [_SENTINELS[2], _SENTINELS[3]] + seq = flat(resp, tool_call_content, eot, tool_result, resp, final_answer, eot) + + labels = get_span_labels(make_span_collator(), seq) + + final_start = ( + len(resp) + len(tool_call_content) + len(eot) + len(tool_result) + len(resp) + ) + assert labels[final_start : final_start + len(final_answer)] == final_answer + + +# --------------------------------------------------------------------------- +# mask_tool_calls option +# --------------------------------------------------------------------------- + + +def test_span_tool_call_turn_unmasked_by_default(): + resp, eot, tc = get_template_token_ids() + tc_content = flat(tc, [_SENTINELS[0]]) + seq = flat(resp, tc_content, eot) + + labels = get_span_labels(make_span_collator(mask_tool_calls=False), seq) + + content_start = len(resp) + assert labels[content_start : content_start + len(tc_content)] == tc_content + + +def test_span_tool_call_turn_masked_when_option_set(): + resp, eot, tc = get_template_token_ids() + tc_content = flat(tc, [_SENTINELS[0]]) + seq = flat(resp, tc_content, eot) + + labels = get_span_labels(make_span_collator(mask_tool_calls=True), seq) + + content_start = len(resp) + assert all( + v == IGNORE for v in labels[content_start : content_start + len(tc_content)] + ) + + +def test_span_non_tool_call_turn_still_unmasked_when_mask_tool_calls_set(): + resp, eot, tc = get_template_token_ids() + tc_content = flat(tc, [_SENTINELS[0]]) + final_answer = [_SENTINELS[1], _SENTINELS[2]] + seq = flat(resp, tc_content, eot, resp, final_answer, eot) + + labels = get_span_labels(make_span_collator(mask_tool_calls=True), seq) + + final_start = len(resp) + len(tc_content) + len(eot) + len(resp) + assert labels[final_start : final_start + len(final_answer)] == final_answer + + +def test_span_mask_tool_calls_requires_template(): + tokenizer, _ = create_test_tokenizer() + resp_ids, eot_ids, _ = get_template_token_ids() + with pytest.raises(ValueError, match="tool_call_start_template"): + TextCompletionsCollatorWithPadding( + tokenizer=tokenizer, + response_prefix=resp_ids, + end_of_turn_template=eot_ids, + mask_tool_calls=True, + tool_call_start_template=None, + ) + + +# --------------------------------------------------------------------------- +# Edge cases +# --------------------------------------------------------------------------- + + +def test_span_no_response_template_all_masked(): + seq = [_SENTINELS[0], _SENTINELS[1], _SENTINELS[2]] + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + labels = get_span_labels(make_span_collator(), seq) + + assert all(v == IGNORE for v in labels) + assert any("response template" in str(w.message).lower() for w in caught) + + +def test_span_no_eot_unmasked_to_end_of_sequence(): + resp, _, _ = get_template_token_ids() + content = [_SENTINELS[0], _SENTINELS[1]] + seq = flat(resp, content) + + labels = get_span_labels(make_span_collator(), seq) + + assert labels[len(resp) :] == content + + +def test_span_empty_content_span(): + resp, eot, _ = get_template_token_ids() + seq = flat(resp, eot) + + labels = get_span_labels(make_span_collator(), seq) + + assert all(v == IGNORE for v in labels) + + +def test_span_padding_matching_eot_does_not_false_match(): + """When pad_token_id matches the EOT token, padding must not be treated as + a real end-of-turn boundary.""" + tokenizer, pad_token_id = create_test_tokenizer() + resp, _, _ = get_template_token_ids() + + # Use pad_token_id itself as the EOT template — worst case scenario. + eot_ids = [pad_token_id] + content = [_SENTINELS[0], _SENTINELS[1]] + # Sequence: [RESP] content (no real EOT), then padding + seq = flat(resp, content) + [pad_token_id] * 5 + + collator = TextCompletionsCollatorWithPadding( + tokenizer=tokenizer, + response_prefix=resp, + end_of_turn_template=eot_ids, + ) + batch = collator([{"input_ids": seq}]) + labels = batch["labels"][0].tolist() + + # Content should be unmasked — the padding should not act as an EOT. + content_start = len(resp) + assert labels[content_start : content_start + len(content)] == content + # Padding should be masked. + assert all(v == IGNORE for v in labels[content_start + len(content) :]) + + +# --------------------------------------------------------------------------- +# Batch processing +# --------------------------------------------------------------------------- + + +def test_span_batch_two_examples_processed_independently(): + resp, eot, _ = get_template_token_ids() + _, pad_token_id = create_test_tokenizer() + content_a = [_SENTINELS[0], _SENTINELS[1]] + content_b = [_SENTINELS[2]] + seq_a = flat(resp, content_a, eot) + seq_b = flat(resp, content_b, eot) + + max_len = max(len(seq_a), len(seq_b)) + pad_a = [pad_token_id] * (max_len - len(seq_a)) + pad_b = [pad_token_id] * (max_len - len(seq_b)) + + collator = make_span_collator() + batch = collator([{"input_ids": seq_a + pad_a}, {"input_ids": seq_b + pad_b}]) + labels_a = batch["labels"][0].tolist() + labels_b = batch["labels"][1].tolist() + + assert labels_a[len(resp) : len(resp) + len(content_a)] == content_a + assert labels_b[len(resp) : len(resp) + len(content_b)] == content_b + + +def test_span_batch_bad_example_does_not_affect_others(): + resp, eot, _ = get_template_token_ids() + _, pad_token_id = create_test_tokenizer() + good_seq = flat(resp, [_SENTINELS[0]], eot) + bad_seq = [_SENTINELS[1], _SENTINELS[2]] + + max_len = max(len(good_seq), len(bad_seq)) + pad_good = [pad_token_id] * (max_len - len(good_seq)) + pad_bad = [pad_token_id] * (max_len - len(bad_seq)) + + collator = make_span_collator() + with warnings.catch_warnings(record=True): + warnings.simplefilter("always") + batch = collator( + [{"input_ids": good_seq + pad_good}, {"input_ids": bad_seq + pad_bad}] + ) + + assert batch["labels"][0].tolist()[len(resp)] == _SENTINELS[0] + assert all(v == IGNORE for v in batch["labels"][1].tolist()) + + +def test_span_output_labels_is_torch_tensor(): + resp, eot, _ = get_template_token_ids() + seq = flat(resp, [_SENTINELS[0]], eot) + batch = make_span_collator()([{"input_ids": seq}]) + assert isinstance(batch["labels"], torch.Tensor) + + +def test_span_labels_shape_matches_input_ids(): + resp, eot, _ = get_template_token_ids() + seq = flat(resp, [_SENTINELS[0], _SENTINELS[1]], eot) + batch = make_span_collator()([{"input_ids": seq}]) + assert batch["labels"].shape == batch["input_ids"].shape + + +def test_span_labels_numpy_values_match_expected(): + resp, eot, _ = get_template_token_ids() + content = [_SENTINELS[0], _SENTINELS[1]] + seq = flat(resp, content, eot) + + batch = make_span_collator()([{"input_ids": seq}]) + expected = [IGNORE] * len(resp) + content + [IGNORE] * len(eot) + assert np.all(batch["labels"].numpy() == np.array([expected], dtype=np.int32))