Skip to content
Open
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
6726786
Add span-based masking to DataCollatorForCompletionOnlyLM
shanghongsim Apr 14, 2026
1c2f842
Narrow TextCompletionsCollatorWithPadding template params to str
shanghongsim Apr 14, 2026
190b909
Validate end_of_turn_template is provided for span-based masking
shanghongsim Apr 14, 2026
2948700
Extract _tokenize_template helper and move _KNOWN_MASKING_METHODS to …
shanghongsim Apr 14, 2026
ec9d609
Fix docstring: mask_tool_calls is not a valid collator_kwargs param
shanghongsim Apr 14, 2026
e9ee991
Add deprecation warning for legacy instruction-based masking
shanghongsim Apr 14, 2026
ab6e081
Infer assistant_turn_no_tools when tool_call_start_template is provided
shanghongsim Apr 14, 2026
4c081d7
Extract _resolve_masking_method to simplify __init__ logic
shanghongsim Apr 14, 2026
6682e56
Remove tool_call_start_template and assistant_turn_no_tools masking
shanghongsim Apr 16, 2026
3b92ca5
Drop legacy Llama template fallback and neaten wrapper collator
shanghongsim Apr 16, 2026
e67c3db
Inline _resolve_masking_method into __init__
shanghongsim Apr 16, 2026
71b38ee
Rename masking_method to train_target and assistant_turn to all_assis…
shanghongsim Apr 16, 2026
df0d473
Restore _collate method in TextCompletionsCollatorWithPadding
shanghongsim Apr 16, 2026
414dd4c
Extract _resolve_train_target classmethod from __init__
shanghongsim Apr 16, 2026
33a5da8
Remove ASSISTANT_TURN_NO_TOOLS from MaskingMethod enum and builder
shanghongsim Apr 16, 2026
b4583e4
Rename MaskingMethod to TrainTarget with clearer value names
shanghongsim Apr 16, 2026
c6744bb
Replace vocab-based template detection with chat template rendering
shanghongsim Apr 16, 2026
780606c
docs: clarify _resolve_collator_templates docstring
shanghongsim Apr 16, 2026
b14ab66
Allow train_target and collator_kwargs to be used together
shanghongsim Apr 16, 2026
e56f872
Rename end_of_turn to end_of_turn_template in _resolve_collator_templ…
shanghongsim Apr 16, 2026
33c768e
Move train_target collator-name validation into config __post_init__
shanghongsim Apr 16, 2026
0136af9
Reject train_target when collator_name is not set or wrong
shanghongsim Apr 16, 2026
928c274
Remove redundant tokenizer None check in train_target block
shanghongsim Apr 16, 2026
fbc7087
Clean up _resolve_train_target and TrainTarget docstring
shanghongsim Apr 16, 2026
d5d93f4
Centralize train_target resolution in the builder
shanghongsim Apr 17, 2026
a06055f
Remove redundant tests for span-based masking and legacy collator paths
shanghongsim Apr 17, 2026
0cdb2bc
Add type annotations to fix pyright errors on tokenizer.decode calls
shanghongsim Apr 17, 2026
1c1dadf
Fix pyright: use isinstance assert for tokenizer.decode return type
shanghongsim Apr 17, 2026
e30dc99
Fix YAML configs: use enum name ALL_ASSISTANT_TURNS for OmegaConf par…
shanghongsim Apr 17, 2026
1a8e1d2
Validate end_of_turn_template early when auto-detection fails for all…
shanghongsim Apr 17, 2026
48b65a1
Differentiate error messages in _resolve_collator_templates
shanghongsim Apr 17, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions configs/projects/coalm/405b_train.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@ data:
shuffle: True
seed: 42
collator_name: "text_completions_only_with_padding"
train_target: "all_assistant_turns"
seed: 42
validation:
datasets:
- dataset_name: "text_sft_jsonl"
dataset_path: "/path/to/validation/dataset.jsonl"
collator_name: "text_completions_only_with_padding"
train_target: "all_assistant_turns"
seed: 42

training:
Expand Down
2 changes: 2 additions & 0 deletions configs/projects/coalm/70b_train.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@ data:
shuffle: True
seed: 42
collator_name: "text_completions_only_with_padding"
train_target: "all_assistant_turns"
seed: 42
validation:
datasets:
- dataset_name: "text_sft_jsonl"
dataset_path: "/path/to/validation/dataset.jsonl"
collator_name: "text_completions_only_with_padding"
train_target: "all_assistant_turns"
seed: 42

training:
Expand Down
2 changes: 2 additions & 0 deletions configs/projects/coalm/8b_train.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@ data:
shuffle: True
seed: 42
collator_name: "text_completions_only_with_padding"
train_target: "all_assistant_turns"
seed: 42
validation:
datasets:
- dataset_name: "text_sft_jsonl"
dataset_path: "/path/to/validation/dataset.jsonl"
collator_name: "text_completions_only_with_padding"
train_target: "all_assistant_turns"
seed: 42

training:
Expand Down
2 changes: 2 additions & 0 deletions configs/projects/halloumi/8b_train.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ data:
seed: 42

collator_name: "text_completions_only_with_padding"
train_target: "all_assistant_turns"
seed: 42
validation:
datasets:
Expand All @@ -78,6 +79,7 @@ data:
}

collator_name: "text_completions_only_with_padding"
train_target: "all_assistant_turns"
seed: 42

training:
Expand Down
173 changes: 152 additions & 21 deletions src/oumi/builders/collators.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from collections.abc import Callable

import oumi.core.constants as constants
Expand All @@ -27,11 +28,101 @@
from oumi.core.configs.internal.supported_models import (
find_internal_model_config,
)
from oumi.core.configs.params.data_params import TrainTarget
from oumi.core.tokenizers.base_tokenizer import BaseTokenizer
from oumi.utils.logging import logger

# This is used to set the max input length for a model with infinite size input
_VERY_LARGE_INTEGER = int(1e30)
_SENTINEL_USER = "<<__U__>>"
_SENTINEL_ASST = "<<__A__>>"
_FALLBACK_MSG = (
"Cannot auto-detect collator templates from the chat template. "
"Provide response_template (and end_of_turn_template for "
"all_assistant_turns) via collator_kwargs."
)


def _resolve_collator_templates(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Has this been tested with gpt-oss models? They use a non-standard system for marking boundaries within responses, so wondering if boundary detection will work as it does for Qwen/Llama style templates.

tokenizer: "BaseTokenizer",
) -> tuple[str, str]:
"""Auto-detect response_template and end_of_turn_template.

Applies the chat template to a known test conversation, then finds
the assistant boundary strings in the rendered output.

Returns:
(response_template, end_of_turn_template)

Raises:
ValueError: If templates cannot be extracted.
"""
msgs = [
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is really clever, but we should maybe include a system instruction here?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Like modify it to extract the system instruction? We are currently no longer relying on the system instruction for masking

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No I mean ensure that when we extract the different template portions, we do so using an example conversation that also includes a system instruction

Or rather, perhaps do it twice, once on a conversation with a system instruction, and once on a conversation without.

Why? Because I'm a bit concerned about the scenario where this logic doesn't use system instructions when extracting, but the user does in their data, and the templates we extract with this methodology then fail to work once the data includes a system instruction.

{"role": "user", "content": _SENTINEL_USER},
{"role": "assistant", "content": _SENTINEL_ASST},
{"role": "user", "content": _SENTINEL_USER},
{"role": "assistant", "content": _SENTINEL_ASST},
]

try:
rendered = tokenizer.apply_chat_template(
msgs, tokenize=False, add_generation_prompt=False
)
except Exception as exc:
raise ValueError(_FALLBACK_MSG) from exc

if not isinstance(rendered, str):
raise ValueError(_FALLBACK_MSG)

# Locate boundaries around the second turn pair
# to avoid system-prompt effects on the first turn.
try:
a1 = rendered.index(_SENTINEL_ASST)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: first_asst_start, and add _start to second turns

first_asst_end = a1 + len(_SENTINEL_ASST)
second_user = rendered.index(_SENTINEL_USER, first_asst_end)
second_user_end = second_user + len(_SENTINEL_USER)
second_asst = rendered.index(_SENTINEL_ASST, second_user_end)
second_asst_end = second_asst + len(_SENTINEL_ASST)
except ValueError:
raise ValueError(_FALLBACK_MSG)

# End-of-turn: common token-ID prefix of the two strings that
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

extract this to its own method

# follow assistant content (mid-conversation vs. end-of-sequence).
after_ids = tokenizer.encode(rendered[second_asst_end:], add_special_tokens=False)
between_ids = tokenizer.encode(
rendered[first_asst_end:second_user], add_special_tokens=False
)
eot_len = 0
for a, b in zip(after_ids, between_ids):
if a != b:
break
eot_len += 1
eot_ids = after_ids[:eot_len]
end_of_turn_template: str = tokenizer.decode(eot_ids, skip_special_tokens=False)
Comment thread
shanghongsim marked this conversation as resolved.
Outdated

Comment thread
shanghongsim marked this conversation as resolved.
# Response template: strip the EOT prefix to get just the assistant header.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here, extract to its own method

resp_ids = tokenizer.encode(
rendered[second_user_end:second_asst], add_special_tokens=False
)
if eot_len > 0 and resp_ids[:eot_len] == eot_ids:
resp_ids = resp_ids[eot_len:]
response_template: str = tokenizer.decode(resp_ids, skip_special_tokens=False)

if not response_template.strip():
raise ValueError(_FALLBACK_MSG)

# Qwen3 and similar reasoning models inject <think>...</think> into
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This description scares me, I wonder if we could have a better workaround or a louder error

# every assistant turn via their chat template. If training data was
# formatted without thinking tokens the response_template won't match
# and every example will be silently masked.
if "<think>" in response_template:
logger.warning(
"The extracted response_template contains <think> tokens "
"(from the model's chat template). If you're training without "
"thinking tokens, use collator_kwargs to specify "
"response_template manually."
)

return response_template, end_of_turn_template


def build_data_collator(
Expand All @@ -51,7 +142,8 @@ def build_data_collator(

- "text_with_padding": Uses `TextCollatorWithPadding`.
- "text_completions_only_with_padding": Uses
`TextCompletionsCollatorWithPadding`.
`TextCompletionsCollatorWithPadding`. Supports optional
``end_of_turn_template`` for tool-aware span-based masking.
- "vision_language_with_padding": Uses `VisionLanguageCollatorWithPadding`.
- "vision_language_sft": Uses `VisionLanguageSftCollator`.

Expand Down Expand Up @@ -126,27 +218,20 @@ def build_data_collator(
**kwargs,
)
elif collator_name == "text_completions_only_with_padding":
# Extract instruction and response templates from kwargs if provided
instruction_template = kwargs.pop("instruction_template", None)
response_template = kwargs.pop("response_template", None)

# Default to Llama-style templates if not provided
instruction_prefix = (
instruction_template
if instruction_template
else "<|start_header_id|>user<|end_header_id|>\n\n"
)
response_prefix = (
response_template
if response_template
else "<|start_header_id|>assistant<|end_header_id|>\n\n"
)
if not kwargs.get("response_template"):
raise ValueError(
"'text_completions_only_with_padding' requires a "
"response_template. Either set train_target in your config "
"(which auto-resolves templates from the tokenizer) or "
"provide response_template via collator_kwargs."
)

return TextCompletionsCollatorWithPadding(
Comment thread
shanghongsim marked this conversation as resolved.
tokenizer=tokenizer,
instruction_prefix=instruction_prefix,
response_prefix=response_prefix,
debug=debug,
ignore_index=(
label_ignore_index if label_ignore_index is not None else -100
),
**kwargs,
)
raise ValueError(f"Unknown data collator name: '{collator_name}'")
Comment on lines 243 to 249
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: The TextCompletionsCollatorWithPadding constructor can receive ignore_index both as an explicit argument and within **kwargs, causing a TypeError if a user customizes it.
Severity: HIGH

Suggested Fix

Before calling the TextCompletionsCollatorWithPadding constructor, pop ignore_index from the kwargs dictionary. Use the popped value to determine the final ignore_index value to be passed, preventing the argument duplication.

Prompt for AI Agent
Review the code at the location below. A potential bug has been identified by an AI
agent. Verify if this is a real issue. If it is, propose a fix; if not, explain why it's
not valid.

Location: src/oumi/builders/collators.py#L243-L249

Potential issue: When building the `text_completions_only_with_padding` collator in
`build_data_collator`, the `ignore_index` parameter is passed explicitly while also
being potentially present in the `**kwargs` passed to the
`TextCompletionsCollatorWithPadding` constructor. If a user specifies `ignore_index` in
their `collator_kwargs` configuration, this results in a `TypeError` because the
`__init__` method receives multiple values for the same keyword argument, causing a
runtime crash during training setup.

Expand Down Expand Up @@ -206,9 +291,55 @@ def build_collator_from_config(
"trust_remote_code", config.model.trust_remote_code
)

# Merge collator_kwargs from config with the existing kwargs
# Config kwargs take precedence over automatically determined kwargs
# --- Resolve train_target and templates ---
config_collator_kwargs = train_split.collator_kwargs or {}

if collator_name == "text_completions_only_with_padding":
if train_split.train_target is not None:
# Path 1: train_target is set, auto-detect templates from
# the tokenizer's chat template. Falls back to user-provided
# response_template in collator_kwargs if auto-detection fails.
collator_kwargs["train_target"] = train_split.train_target.value

try:
response_template, end_of_turn_template = _resolve_collator_templates(
tokenizer
)
collator_kwargs["response_template"] = response_template
if train_split.train_target == TrainTarget.ALL_ASSISTANT_TURNS:
collator_kwargs["end_of_turn_template"] = end_of_turn_template
except ValueError:
if config_collator_kwargs.get("response_template") is None:
raise
Comment thread
shanghongsim marked this conversation as resolved.

elif config_collator_kwargs.get("response_template") is not None:
# Path 2: train_target not set, templates provided manually
# via collator_kwargs. Infer train_target from which templates
# are present.
has_eot = config_collator_kwargs.get("end_of_turn_template") is not None
has_inst = config_collator_kwargs.get("instruction_template") is not None
if has_eot:
collator_kwargs["train_target"] = "all_assistant_turns"
elif has_inst:
warnings.warn(
"Instruction-based masking is deprecated. "
"Use train_target='all_assistant_turns' with "
"end_of_turn_template for multi-turn conversations, "
"or train_target='final_assistant_turn' "
"for single-turn completions.",
DeprecationWarning,
stacklevel=2,
)
collator_kwargs["train_target"] = "_legacy_instruction_response"
else:
collator_kwargs["train_target"] = "final_assistant_turn"
else:
raise ValueError(
"'text_completions_only_with_padding' requires either "
"train_target or response_template in collator_kwargs."
)

# User-provided collator_kwargs override auto-resolved values
collator_kwargs.update(config_collator_kwargs)

return build_data_collator(
Expand Down
26 changes: 19 additions & 7 deletions src/oumi/core/collators/text_completions_collator_with_padding.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,22 +27,34 @@ class TextCompletionsCollatorWithPadding:
def __init__(
self,
tokenizer: BaseTokenizer,
instruction_prefix: str,
response_prefix: str,
response_template: str,
train_target: str,
instruction_template: str | None = None,
debug: bool = False,
end_of_turn_template: str | None = None,
ignore_index: int = -100,
):
"""Custom collator for text LLM training.

Args:
tokenizer: The tokenizer used for encoding the data.
instruction_prefix: The prefix marking the beginning of the user instruction.
response_prefix: The prefix marking the beginning of the assistant response.
response_template: String marking assistant response start.
instruction_template: String marking user instruction start.
debug: If True, enables debug mode for logging.
train_target: Training target — ``"all_assistant_turns"``
or ``"final_assistant_turn"``.
end_of_turn_template: String marking the end of a turn.
Required for ``all_assistant_turns``.
ignore_index: Value used for masked labels. Must match the ignore_index
of the loss function (default: -100).
"""
self._default_collator = DataCollatorForCompletionOnlyLM(
tokenizer=tokenizer,
instruction_template=instruction_prefix,
response_template=response_prefix,
instruction_template=instruction_template,
response_template=response_template,
train_target=train_target,
end_of_turn_template=end_of_turn_template,
ignore_index=ignore_index,
)

if not hasattr(tokenizer, "pad_token_id") or tokenizer.pad_token_id is None:
Expand All @@ -55,7 +67,7 @@ def _collate(self, inputs: list[Any]) -> dict[str, Any]:
result = self._default_collator(inputs)
return result

def __call__(self, batch) -> dict[str, Any]:
def __call__(self, batch: list[dict[str, Any]]) -> dict[str, Any]:
"""Pads to the longest length present in the batch.

Args:
Expand Down
Loading
Loading