diff --git a/bionemo-recipes/models/codonfm/collator.py b/bionemo-recipes/models/codonfm/collator.py new file mode 100644 index 0000000000..8dd2087b12 --- /dev/null +++ b/bionemo-recipes/models/codonfm/collator.py @@ -0,0 +1,574 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Context parallel data collation and distribution utilities. + +Provides model-agnostic utilities for context parallelism (CP): +- DataCollatorForContextParallel: Splits THD/BSHD batches into CP shards +- ContextParallelDataLoaderWrapper: Distributes shards across CP ranks +- _split_batch_by_cp_rank: Core batch splitting logic + +Adapted from bionemo-recipes/models/esm2/collator.py for use with CodonFM's +custom CodonTHDCollator. The utilities here work with any collator that produces +THD format output (input_ids, labels, cu_seq_lens_q/k). +""" + +import logging +import threading +from dataclasses import dataclass, field +from typing import Any, TypedDict + +import nvtx +import torch +from transformers import DataCollator + + +logger = logging.getLogger(__name__) + + +@dataclass +class DataCollatorForContextParallel: + """A collator that is aware of context parallelism. + + For the case of context parallelism, padded sequences will be returned from the wrapped collator, and then split + into shards for each context parallelism rank. 
+ + The shards are then typically sent to the ContextParallelDataLoaderWrapper which will scatter them to the + appropriate GPUs. + + Note: + When used with the ContextParallelDataLoaderWrapper and both context parallelism and tensor parallelism are + used, the collator inspects the ordering of the mesh dimensions to determine the layout of the flattened batch. + + If "cp" comes before "tp" in the mesh dimension names (CP row-major), the flattened batch will be: + [(cp0, tp0), (cp0, tp1), ..., (cp1, tp0), (cp1, tp1), ...] + + If "tp" comes before "cp" (TP row-major), the flattened batch will be: + [(tp0, cp0), (tp0, cp1), ..., (tp1, cp0), (tp1, cp1), ...] + + Args: + collator: The collator to use for the batch. + device_mesh: The device mesh with named dimensions. Must contain either a "cp" dimension for context parallelism + and/or a "tp" dimension for tensor parallelism. + qkv_format: The format of the query-key-value (QKV) tensor. + is_causal_lm: Whether the collator is for a causal language model. If True, the labels will be shifted before + being split into CP shards, and will be returned in the `shift_labels` field. + + """ + + collator: DataCollator + device_mesh: torch.distributed.device_mesh.DeviceMesh + qkv_format: str = "thd" + is_causal_lm: bool = False + + # Derived fields, initialized in __post_init__. 
+ cp_world_size: int = field(init=False) + tp_world_size: int | None = field(init=False) + _is_cp_row_major: bool = field(init=False) + + def __post_init__(self): + """Initialize the cp_world_size, tp_world_size, and _is_cp_row_major fields based on the device mesh.""" + dim_names = self.device_mesh.mesh_dim_names + if dim_names is None: + raise ValueError("device_mesh must have mesh_dim_names") + + self.cp_world_size = self.device_mesh.size(dim_names.index("cp")) if "cp" in dim_names else 1 + self.tp_world_size = self.device_mesh.size(dim_names.index("tp")) if "tp" in dim_names else None + + # Determine whether CP is the row (outer) dimension of the 2D mesh. + # When flattened, the row-major dimension's index changes slowest. + # If "cp" comes before "tp" in mesh_dim_names, CP is the row dimension. + if "cp" in dim_names and "tp" in dim_names: + self._is_cp_row_major = dim_names.index("cp") < dim_names.index("tp") + else: + self._is_cp_row_major = True + + def __call__(self, features) -> list[dict[str, Any]]: + """Process batches of data and create shards for each context parallelism rank. + + Args: + features: List of tokenized sequences, each containing 'input_ids' and optionally 'labels'. + + Returns: + A list of dictionaries, each containing a shard of the batch for a given context parallelism rank. + """ + batch = self.collator(features) + + # Remove the attention mask from the batch, it's not valid for CP. + batch.pop("attention_mask", None) + + if self.is_causal_lm: + labels = torch.nn.functional.pad(batch["labels"], (0, 1), value=-100) + batch["labels"] = labels[..., 1:].contiguous() + + combined_batch = [] + for cp_rank in range(self.cp_world_size): + input_ids_sharded, labels_sharded = _split_batch_by_cp_rank( + cu_seqlens_padded=batch.get("cu_seq_lens_q_padded", None), # This will be None for BSHD format. 
+ input_ids_padded=batch["input_ids"], + labels_padded=batch["labels"], + qvk_format=self.qkv_format, + cp_rank=cp_rank, + cp_world_size=self.cp_world_size, + ) + batch_shard = dict(batch) + batch_shard["input_ids"] = input_ids_sharded + if self.is_causal_lm: + batch_shard["shift_labels"] = labels_sharded + batch_shard["labels"] = None + else: + batch_shard["labels"] = labels_sharded + # Now determine the max length of the sequence. + if self.qkv_format == "thd": + seqlens_q = batch_shard["cu_seq_lens_q_padded"][1:] - batch_shard["cu_seq_lens_q_padded"][:-1] + max_length = seqlens_q.max().item() + elif self.qkv_format == "bshd": + max_length = batch["input_ids"].shape[1] + else: + raise ValueError(f"Unsupported qvk_format: {self.qkv_format}!") + + batch_shard["max_length_k"] = batch_shard["max_length_q"] = ((max_length + 63) // 64) * 64 + combined_batch.append(batch_shard) + + if self.tp_world_size is not None: + # Replicate each CP shard for TP ranks. The ordering depends on which dimension forms the rows in the + # flattened mesh. + if self._is_cp_row_major: + # Flattened mesh: [(cp0,tp0), (cp0,tp1), (cp1,tp0), (cp1,tp1)] + # Output: [cp0, cp0, cp1, cp1] + combined_batch = [batch for batch in combined_batch for _ in range(self.tp_world_size)] + else: + # Flattened mesh: [(tp0,cp0), (tp0,cp1), (tp1,cp0), (tp1,cp1)] + # Output: [cp0, cp1, cp0, cp1] + combined_batch = [ + combined_batch[cp_rank] for _ in range(self.tp_world_size) for cp_rank in range(self.cp_world_size) + ] + + return combined_batch + + +class ContextParallelDataLoaderWrapper: + """A dataloader that is aware of context and tensor parallelism.""" + + def __init__( + self, + dataloader: torch.utils.data.DataLoader | None, + cp_tp_mesh: torch.distributed.device_mesh.DeviceMesh, + ): + """A dataloader wrapper that distributes the data across the context and tensor parallelism groups. 
+ + This class materializes a single dataloader for each data parallel mesh rank, and splits / replicates the data + from this dataloader across the context and tensor parallelism groups. + + Args: + dataloader: The dataloader to use. + cp_tp_mesh: The context parallel mesh, or a flattened, combined context parallel and tensor parallel mesh. + If a flattened mesh is provided, the cp / tp dimensions should be in the order they appeared in the + mesh_dim_names as passed to DataCollatorForContextParallel. + """ + if cp_tp_mesh.get_local_rank() == 0: + assert dataloader is not None, "dataloader must be provided on rank 0" + self.dataloader = dataloader + + else: + assert dataloader is None, "Dataloader on non-rank 0 will not be used" + + self.cp_tp_rank = cp_tp_mesh.get_local_rank() + self.cp_tp_group = cp_tp_mesh.get_group() + self.num_cp_tp_ranks = cp_tp_mesh.size() + self._iterator = None + self._prefetch_thread: threading.Thread | None = None + self._prefetch_result: Any = None + self._cuda_device: int | None = None + + logger.debug( + "Created ContextParallelDataLoaderWrapper on global rank %s, cp rank %s", + torch.distributed.get_rank() if torch.distributed.is_initialized() else "", + self.cp_tp_rank, + ) + + def __iter__(self): + """Make the dataloader iterable.""" + if self.cp_tp_rank == 0: + self._iterator = iter(self.dataloader) # < --- collator output. + self.close() + # Capture CUDA device from main thread; torch.cuda.set_device is per-thread, + # so the background thread needs to set it explicitly. 
+ self._cuda_device = torch.cuda.current_device() if torch.cuda.is_available() else None + self._kick_prefetch() + return self + + @nvtx.annotate("ContextParallelDataLoaderWrapper __next__", color="blue") + def __next__(self): + """Get the batch from the dataloader for the current CP rank.""" + if self._prefetch_thread is not None: + self._prefetch_thread.join() + result = self._prefetch_result + if isinstance(result, Exception): + self._prefetch_thread = None + raise result + self._kick_prefetch() + return result + + def _kick_prefetch(self): + """Start a background thread to prefetch exactly one batch via scatter.""" + self._prefetch_thread = threading.Thread(target=self._do_one_prefetch, daemon=True) + self._prefetch_thread.start() + + def _do_one_prefetch(self): + """Fetch one batch in the background. + + This function calls the _send_data_to_cp_tp_ranks function to materialize the next batches for all ranks in the + given CP/TP group, and uses torch.distributed.scatter_object_list to scatter these batches to their + corresponding ranks. The result is stored in _prefetch_result, and returned when __next__ is called. + """ + if self._cuda_device is not None: + torch.cuda.set_device(self._cuda_device) + try: + self._prefetch_result = self._send_data_to_cp_tp_ranks() + except StopIteration as e: + self._prefetch_result = e + except Exception as e: + self._prefetch_result = e + + def close(self): + """Stop the prefetch thread. Must be called before destroy_process_group().""" + if self._prefetch_thread is not None: + self._prefetch_thread.join(timeout=10) + self._prefetch_thread = None + + @nvtx.annotate("ContextParallelDataLoaderWrapper _send_data_to_cp_tp_ranks", color="green") + def _send_data_to_cp_tp_ranks(self): + """Send data to all the CP/TP ranks. + + This function will get the batch from the dataloader on CP rank 0, and then determine + the shards for all the different CP group members. 
+ combined_batch = [, , ..., ] + Then it will scatter the shards to the different CP group members. + The shards are then combined into a single batch and returned to the caller + for the current CP rank. + + If tensor parallelism is also being used, the combined batch will look like: + combined_batch = [, , ..., , , ...] + where there are cp_world_size shards, and each shard is replicated tp_world_size times. The ordering of the + shards depends on which dimension forms the rows in the flattened mesh. + + Scalability: + Rank 0's work grows linearly with CP size, but the other ranks do not need to store all the shards so they + do not grow linearly with CP size. + + Args: + None + + Returns: + batch: The batch for the current CP/TP rank. + + """ + try: + with nvtx.annotate("ContextParallelDataLoaderWrapper next batch", color="green"): + combined_batch = next(self._iterator) if self.cp_tp_rank == 0 else None + except StopIteration as ex: + # If we encounter a StopIteration in the dataloader, we want to raise this error on all the CP ranks, so + # that the dataloader can be restarted. 
+ combined_batch = [ex] * self.num_cp_tp_ranks + + batch_on_this_rank = _scatter_batch_to_cp_tp_ranks(combined_batch, self.cp_tp_group) + + if isinstance(batch_on_this_rank, StopIteration): + raise batch_on_this_rank + + return batch_on_this_rank + + def state_dict(self): + """Get the state dict by delegating to the dataloader.""" + if self.cp_tp_rank != 0: + return {} + elif hasattr(self.dataloader, "state_dict"): + return {"dataloader": self.dataloader.state_dict()} + else: + logger.warning( + "Attempting to get the state dict of the dataloader, but the dataloader does not support state_dict, " + "returning empty dict" + ) + return {"dataloader": {}} + + def load_state_dict(self, state_dict): + """Load the state dict by delegating to the dataloader.""" + if self.cp_tp_rank != 0: + return + elif hasattr(self.dataloader, "load_state_dict"): + self.dataloader.load_state_dict(state_dict["dataloader"]) + else: + logger.warning( + "Attempting to load the state dict of the dataloader, but the dataloader does not support " + "load_state_dict, returning without loading the state dict." + ) + return + + @property + def num_workers(self): + """Get the number of workers of the dataloader.""" + if self.cp_tp_rank != 0: + return 0 + else: + return self.dataloader.num_workers + + +def _find_seq_dim(tensor: torch.Tensor, seq_len: int) -> int: + """Find which dimension of tensor matches the expected sequence length. + + Args: + tensor: The tensor to inspect. + seq_len: The expected sequence length to match against tensor dimensions. + + Returns: + The dimension index that matches the sequence length. + + Raises: + ValueError: If no dimension matches the expected sequence length. 
+ """ + if tensor.ndim == 1: + if tensor.shape[0] == seq_len: + return 0 + raise ValueError(f"1D tensor shape {tensor.shape} doesn't match sequence length {seq_len}") + elif tensor.ndim >= 2: + if tensor.shape[1] == seq_len: + return 1 + elif tensor.shape[0] == seq_len: + return 0 + raise ValueError(f"Tensor shape {tensor.shape} doesn't match sequence length {seq_len} in dim 0 or 1") + raise ValueError(f"Unexpected tensor ndim={tensor.ndim}") + + +def _process_tensor_thd( + val: torch.Tensor | None, + seq_len: int, + slice_sizes: torch.Tensor, + cu_seqlens_padded: torch.Tensor, + cp_rank: int, + total_slices: int, +) -> torch.Tensor | None: + """Extract the THD context-parallel shard for a single tensor. + + For each sequence in the batch, selects two slices (one from the beginning and one from the end) + corresponding to the given CP rank, following the zigzag CP sharding pattern. + + Args: + val: The tensor to shard, or None (returned as-is). + seq_len: Total sequence length (from cu_seqlens_padded[-1]). + slice_sizes: Per-sequence slice sizes, computed as sequence_lengths // total_slices. + cu_seqlens_padded: Cumulative sequence lengths including padding. + cp_rank: The context parallelism rank index. + total_slices: Total number of slices per sequence (2 * cp_world_size). + + Returns: + The sharded tensor for the given CP rank, or None if val is None. 
+ """ + if val is None: + return val + + seq_dim = _find_seq_dim(val, seq_len) + + cp_rank_slices = [] + for slice_size, seq_start in zip(slice_sizes, cu_seqlens_padded[:-1]): + # 1st segment + cp_rank_slices.append( + torch.arange( + seq_start + (cp_rank * slice_size), + seq_start + ((cp_rank + 1) * slice_size), + device=val.device, + ) + ) + + # 2nd segment + cp_rank_slices.append( + torch.arange( + seq_start + ((total_slices - cp_rank - 1) * slice_size), + seq_start + ((total_slices - cp_rank) * slice_size), + device=val.device, + ) + ) + + return val.index_select(seq_dim, torch.cat(cp_rank_slices)) + + +def _process_tensor_bshd( + val: torch.Tensor | None, + cp_rank: int, + cp_world_size: int, +) -> torch.Tensor | None: + """Extract the BSHD context-parallel shard for a single tensor. + + Splits a BSHD-format tensor along the sequence dimension (dim=1) into 2*cp_world_size chunks, + then selects the two chunks corresponding to the given CP rank (zigzag pattern). + + Args: + val: The tensor to shard, or None (returned as-is). + cp_rank: The context parallelism rank index. + cp_world_size: Total number of context parallelism ranks. + + Returns: + The sharded tensor for the given CP rank, or None if val is None. + + Raises: + ValueError: If the tensor has fewer than 2 dimensions or its sequence length + is not divisible by 2 * cp_world_size. 
+ """ + if val is None: + return val + + if val.ndim < 2: + raise ValueError(f"BSHD format requires at least 2D tensors, got {val.ndim}D") + + seq_len = val.shape[1] + + # Calculate chunk size + total_chunks = 2 * cp_world_size + chunk_size = seq_len // total_chunks + + if seq_len % total_chunks != 0: + raise ValueError( + f"Sequence length {seq_len} must be divisible by {total_chunks} " + f"(2 * cp_world_size) for BSHD context parallelism" + ) + + # Determine which chunks this rank should get + # Rank 0 gets chunks [0, total_chunks-1] + # Rank 1 gets chunks [1, total_chunks-2] + # Rank k gets chunks [k, total_chunks-k-1] + chunk_indices = [cp_rank, total_chunks - cp_rank - 1] + + # Collect slices for this rank + rank_slices = [] + for chunk_idx in chunk_indices: + start_idx = chunk_idx * chunk_size + end_idx = start_idx + chunk_size + rank_slices.append(torch.arange(start_idx, end_idx, device=val.device)) + + # Concatenate indices for all chunks this rank should get + indices = torch.cat(rank_slices) + + # Select along sequence dimension (dim=1) + return val.index_select(1, indices) + + +# TODO(@jomitchell): Once this gets merged: https://github.com/NVIDIA/TransformerEngine/pull/2387 +# we can replace this with the one in TransformerEngine. +@nvtx.annotate("collator._split_batch_by_cp_rank", color="green") +def _split_batch_by_cp_rank( + cu_seqlens_padded: torch.Tensor | None, + input_ids_padded: torch.Tensor, + labels_padded: torch.Tensor, + cp_group: torch.distributed.ProcessGroup | None = None, + qvk_format: str = "thd", + cp_rank: int | None = None, + cp_world_size: int | None = None, +): + """Slice batch input along sequence dimension into multiple chunks for THD or BSHD format. + + This function is intended for use in self attention. It will not work for cross attention because + it does not handle the case where the sequence length of the query and key are different. + Which are parallelized across GPUs in a context parallel group. 
+ This version works with variable-length sequences using cumulative sequence lengths for THD format, + and with padded sequences for BSHD format. + + Args: + cu_seqlens_padded: Cumulative sequence length. Required for THD format, optional for BSHD format. + input_ids_padded: Input IDs. + labels_padded: Labels. + cp_group: Context parallel group. + qvk_format: Format of the input data ("thd" or "bshd"). + cp_world_size: The size of the context parallelism group. + cp_rank: Optional manual CP rank index. + """ + if qvk_format not in ["thd", "bshd", "sbhd"]: + raise ValueError(f"Unsupported qvk_format: {qvk_format}!") + + if cp_world_size is None or cp_world_size <= 1: + # No splitting needed + return input_ids_padded, labels_padded + + if cp_rank is None: + cp_rank = torch.distributed.get_rank(group=cp_group) + elif not (0 <= cp_rank < cp_world_size): + raise ValueError(f"cp_rank must be in [0, {cp_world_size}), but received {cp_rank}.") + + if qvk_format == "thd": + if cu_seqlens_padded is None: + raise ValueError("cu_seqlens_padded is required for THD format") + + # Calculate the chunk sizes for each sequence + total_slices_of_any_sequence = 2 * cp_world_size + slice_sizes = (cu_seqlens_padded[1:] - cu_seqlens_padded[:-1]) // total_slices_of_any_sequence + + # Ensure cu_seqlens_padded[-1] is a Python int, not a 0-dim tensor + last_elem = cu_seqlens_padded[-1] + seq_len_val = last_elem.item() if isinstance(last_elem, torch.Tensor) else last_elem + + input_ids_padded = _process_tensor_thd( + input_ids_padded, seq_len_val, slice_sizes, cu_seqlens_padded, cp_rank, total_slices_of_any_sequence + ) + labels_padded = _process_tensor_thd( + labels_padded, seq_len_val, slice_sizes, cu_seqlens_padded, cp_rank, total_slices_of_any_sequence + ) + + elif qvk_format == "bshd": + input_ids_padded = _process_tensor_bshd(input_ids_padded, cp_rank, cp_world_size) + labels_padded = _process_tensor_bshd(labels_padded, cp_rank, cp_world_size) + + else: + raise ValueError(f"Support not 
implemented yet for qvk_format: {qvk_format}!") + + return input_ids_padded, labels_padded + + +class BatchType(TypedDict): + """The fields in the batch dictionary for THD context parallel.""" + + input_ids: torch.Tensor + labels: torch.Tensor | None + shift_labels: torch.Tensor | None + cu_seq_lens_q: torch.Tensor + cu_seq_lens_k: torch.Tensor + cu_seq_lens_q_padded: torch.Tensor + cu_seq_lens_k_padded: torch.Tensor + max_length_q: int + max_length_k: int + pad_between_seqs: bool + + +@nvtx.annotate("collator._scatter_batch_to_cp_tp_ranks", color="green") +def _scatter_batch_to_cp_tp_ranks( + all_batches: list[BatchType] | list[StopIteration], cp_tp_group: torch.distributed.ProcessGroup | None = None +) -> BatchType | StopIteration: + """Scatter a batch to all the CP ranks. + + Args: + all_batches (list[BatchType] | list[StopIteration]): A list of already-sharded batches to scatter to the CP/TP + ranks. + cp_tp_group (torch.distributed.ProcessGroup | None): The process group to scatter the batches to. + + Returns: + BatchType | StopIteration: The batch on this rank. + """ + scatter_object_output_list = [None] + # Note: This does not provide an async_op handle. Thus its blocking. + torch.distributed.scatter_object_list( + scatter_object_output_list=scatter_object_output_list, + scatter_object_input_list=all_batches, + group=cp_tp_group, + group_src=0, + ) + return scatter_object_output_list[0] diff --git a/bionemo-recipes/models/codonfm/dataset.py b/bionemo-recipes/models/codonfm/dataset.py index f114df98d1..0acc37c7e4 100644 --- a/bionemo-recipes/models/codonfm/dataset.py +++ b/bionemo-recipes/models/codonfm/dataset.py @@ -162,6 +162,7 @@ def __init__( max_seq_length: int = 512, mlm_probability: float = 0.15, seed: int = 42, + pad_sequences_to_be_divisible_by: int | None = None, ): """Initialize. @@ -170,11 +171,14 @@ def __init__( max_seq_length: Maximum sequence length per sample. mlm_probability: Probability of masking a token. 
seed: Random seed for reproducible masking. + pad_sequences_to_be_divisible_by: If set, each individual sequence is padded + to be divisible by this value. Used for context parallelism. """ self.tokenizer = tokenizer self.max_seq_length = max_seq_length self.mlm_probability = mlm_probability self.rng = random.Random(seed) + self.pad_sequences_to_be_divisible_by = pad_sequences_to_be_divisible_by def __call__(self, batch: list[dict[str, str]]) -> dict[str, torch.Tensor]: """Collate a batch into THD packed format. @@ -216,15 +220,40 @@ def __call__(self, batch: list[dict[str, str]]) -> dict[str, torch.Tensor]: cu_seq_lens = torch.zeros(len(seq_lengths) + 1, dtype=torch.int32) cu_seq_lens[1:] = torch.cumsum(torch.tensor(seq_lengths, dtype=torch.int32), dim=0) - return { - "input_ids": torch.tensor(all_ids, dtype=torch.long).unsqueeze(0), - "labels": torch.tensor(all_labels, dtype=torch.long).unsqueeze(0), + input_ids = torch.tensor(all_ids, dtype=torch.long).unsqueeze(0) + labels_tensor = torch.tensor(all_labels, dtype=torch.long).unsqueeze(0) + + result = { + "input_ids": input_ids, + "labels": labels_tensor, "cu_seq_lens_q": cu_seq_lens, "cu_seq_lens_k": cu_seq_lens, "max_length_q": max(seq_lengths), "max_length_k": max(seq_lengths), } + # Per-sequence padding for context parallelism + if self.pad_sequences_to_be_divisible_by is not None: + from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import ( + pad_thd_sequences_for_cp, + ) + + input_ids_padded, labels_padded, cu_seqlens_padded = pad_thd_sequences_for_cp( + input_ids.squeeze(0), + labels_tensor.squeeze(0), + cu_seq_lens, + self.pad_sequences_to_be_divisible_by, + padding_token_id=self.tokenizer.pad_token_id, + padding_label_id=-100, + ) + result["input_ids"] = input_ids_padded.unsqueeze(0) + result["labels"] = labels_padded.unsqueeze(0) + result["cu_seq_lens_q_padded"] = cu_seqlens_padded.to(torch.int32) + result["cu_seq_lens_k_padded"] = cu_seqlens_padded.to(torch.int32) + 
result["pad_between_seqs"] = True + + return result + def create_bshd_dataloader( dist_config: DistributedConfig, diff --git a/bionemo-recipes/models/codonfm/tests/test_cp_bshd.py b/bionemo-recipes/models/codonfm/tests/test_cp_bshd.py new file mode 100644 index 0000000000..9dc7d72fa7 --- /dev/null +++ b/bionemo-recipes/models/codonfm/tests/test_cp_bshd.py @@ -0,0 +1,357 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Context parallel equivalence test for CodonFM in BSHD format. + +Verifies that running the model with context parallelism (CP=2) produces +equivalent losses, logits, and gradients compared to a non-distributed run. +""" + +import os +import subprocess +import sys +import tempfile +from dataclasses import dataclass, field +from pathlib import Path + + +# When launched via torchrun, conftest.py sys.path setup doesn't run. +# Ensure the model directory (parent of tests/) is on sys.path for bare module imports. 
+sys.path.insert(0, Path(__file__).resolve().parent.parent.as_posix()) + +import pytest +import torch +from collator import _split_batch_by_cp_rank +from dataset import CodonMLMCollator +from modeling_codonfm_te import MODEL_PRESETS, CodonFMConfig, CodonFMForMaskedLM +from tokenizer import CodonTokenizer +from torch.distributed.device_mesh import init_device_mesh + + +requires_multi_gpu = pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 2, + reason="Test requires at least 2 GPUs", +) + + +# Test codon sequences (DNA strings of 3-mer codons) - chosen to produce 32-token sequences with special tokens +TEST_CODON_SEQUENCES = [ + "ATGCGTAAAGCTGTTCAGGATCTGAATGCCATCTATGCGATGCGTAAAGCTGTTCAGGATCTGAATGCCATCTATGCG", + "ATGGATCGTACCGCTGAACAGCGTCTGATCAAAGCCATGGATCGTACCGCTGAACAGCGTCTGATCAAAGCCATGGAT", +] + + +def get_dummy_data_bshd_with_padding(): + """Create dummy BSHD data with padding for CP. + + Returns: + A dictionary containing the padded input ids and labels in BSHD format [batch, seq_len]. + """ + tokenizer = CodonTokenizer() + collator = CodonMLMCollator( + tokenizer=tokenizer, + max_seq_length=32, # Pad to 32 for CP divisibility (32 / (2*2) = 8) + mlm_probability=0.0, # No masking for deterministic testing + ) + samples = [{"sequence": seq} for seq in TEST_CODON_SEQUENCES] + batch = collator(samples) + batch["labels"] = batch["input_ids"].clone() # Identity for testing CP sanity. + del batch["attention_mask"] + return batch + + +def create_model_checkpoint(tmp_path): + """Create a small CodonFM model checkpoint. + + Args: + tmp_path: The path to save the model checkpoint. + + Returns: + The path to the saved model checkpoint. 
+ """ + config = CodonFMConfig( + attn_input_format="bshd", + dtype=torch.bfloat16, + **MODEL_PRESETS["encodon_200k"], + ) + model = CodonFMForMaskedLM(config) + model.save_pretrained(tmp_path / "codonfm_checkpoint") + return tmp_path / "codonfm_checkpoint" + + +def get_batch_for_cp_rank(batch, cp_rank, cp_world_size): + """Get a batch shard for a given context parallelism rank. + + Args: + batch: The batch to get a shard of. + cp_rank: The context parallelism rank. + cp_world_size: The size of the context parallelism group. + + Returns: + A dictionary containing the shard of the batch. + """ + input_ids_sharded, labels_sharded = _split_batch_by_cp_rank( + cu_seqlens_padded=None, + input_ids_padded=batch["input_ids"], + labels_padded=batch["labels"], + qvk_format="bshd", + cp_rank=cp_rank, + cp_world_size=cp_world_size, + ) + batch_shard = dict(batch) + batch_shard["input_ids"] = input_ids_sharded + batch_shard["labels"] = labels_sharded + return batch_shard + + +@dataclass(frozen=True) +class DistributedConfig: + """Class to track distributed ranks and handle basic distributed training setup.""" + + rank: int = field(default_factory=lambda: int(os.environ.setdefault("RANK", "0"))) + local_rank: int = field(default_factory=lambda: int(os.environ.setdefault("LOCAL_RANK", "0"))) + world_size: int = field(default_factory=lambda: int(os.environ.setdefault("WORLD_SIZE", "1"))) + _master_addr: str = field(default_factory=lambda: os.environ.setdefault("MASTER_ADDR", "localhost")) + _master_port: str = field(default_factory=lambda: os.environ.setdefault("MASTER_PORT", "12355")) + + def is_main_process(self) -> bool: + """This is the global rank 0 process.""" + return self.rank == 0 + + +def test_context_parallel_equivalence_1process(): + """Smoke test to ensure CP works with 1 process and matches non-distributed results.""" + cmd = [ + "torchrun", + "--nproc_per_node=1", + os.path.relpath(__file__), + ] + result = subprocess.run( + cmd, + check=False, + text=True, + 
stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + timeout=240, + ) + if result.returncode != 0: + print(f"STDOUT:\n{result.stdout}") + print(f"STDERR:\n{result.stderr}") + pytest.fail(f"Command failed with exit code {result.returncode}") + + +@requires_multi_gpu +def test_context_parallel_equivalence_2process(): + """Test CP equivalence between 2 processes. + + Compares (1) Losses, (2) Logits, and (3) Gradients from distributed CP vs non-distributed runs. + """ + cmd = [ + "torchrun", + "--nproc_per_node=2", + os.path.relpath(__file__), + ] + result = subprocess.run( + cmd, + check=False, + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + timeout=240, + ) + if result.returncode != 0: + print(f"STDOUT:\n{result.stdout}") + print(f"STDERR:\n{result.stderr}") + pytest.fail(f"Command failed with exit code {result.returncode}") + + +if __name__ == "__main__": + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + model_ckpt = create_model_checkpoint(tmp_path) + + input_data_bshd_padded = get_dummy_data_bshd_with_padding() + + config = CodonFMConfig.from_pretrained(model_ckpt) + config.attn_input_format = "bshd" + config.dtype = torch.bfloat16 + model = CodonFMForMaskedLM(config) + model.load_state_dict(torch.load(model_ckpt / "model.safetensors", weights_only=True), strict=False) + model = model.to(dtype=torch.bfloat16, device="cuda") + model.train() + + input_data_bshd_padded = { + k: v.to("cuda") if isinstance(v, torch.Tensor) else v for k, v in input_data_bshd_padded.items() + } + outputs_nondistributed = model(**input_data_bshd_padded) + loss_nondistributed = outputs_nondistributed.loss + loss_nondistributed.backward() + + # Clone everything we need for later comparison BEFORE deleting + loss_nondistributed_for_comparison = loss_nondistributed.detach().clone().cpu() + logits_nondistributed_for_comparison = outputs_nondistributed.logits.detach().clone().cpu() + + # Sample gradients from a few layers for comparison + 
sample_layers = [ + model.encoder.layers[0].self_attention.core_attention, + model.encoder.layers[0].self_attention.layernorm_qkv, + ] + + gradients_nondistributed = {} + for i, layer in enumerate(sample_layers): + for name, param in layer.named_parameters(): + if param.grad is not None: + key = f"layer_{i}.{name}" + gradients_nondistributed[key] = param.grad.detach().clone().cpu() + + # Now setup distributed training for CP. + dist_config = DistributedConfig() + device = torch.device(f"cuda:{dist_config.local_rank}") + + # Clean up everything from non-distributed run + del model, outputs_nondistributed, loss_nondistributed, input_data_bshd_padded + torch.cuda.empty_cache() + torch.cuda.synchronize() + + # Initialize distributed training + torch.distributed.init_process_group(backend="nccl", device_id=device) + torch.cuda.set_device(dist_config.local_rank) + + ddp_size = 1 + cp_size = torch.distributed.get_world_size() + device_mesh = init_device_mesh( + "cuda", + mesh_shape=(ddp_size, cp_size), + mesh_dim_names=("ddp", "cp"), + ) + + # Re-initialize the model on the new device + config = CodonFMConfig.from_pretrained(model_ckpt) + config.attn_input_format = "bshd" + config.dtype = torch.bfloat16 + model = CodonFMForMaskedLM(config) + model.load_state_dict(torch.load(model_ckpt / "model.safetensors", weights_only=True), strict=False) + model = model.to(dtype=torch.bfloat16, device=device) + model.train() + model.zero_grad(set_to_none=True) + + group_fsdp_cp = device_mesh[("ddp", "cp")]._flatten("dp_cp").get_group() + model = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[dist_config.local_rank], + output_device=dist_config.local_rank, + process_group=group_fsdp_cp, + ) + cp_group = device_mesh["cp"].get_group() + cp_rank = device_mesh.get_local_rank("cp") + cp_world_size = torch.distributed.get_world_size(group=cp_group) + + # Set up context parallelism for each layer + for transformer_layer in model.module.encoder.layers: + 
transformer_layer.set_context_parallel_group( + cp_group, torch.distributed.get_process_group_ranks(device_mesh["cp"].get_group()), torch.cuda.Stream() + ) + + model.zero_grad(set_to_none=True) + + # Create FRESH batch data for CP + batch = get_dummy_data_bshd_with_padding() + batch = {k: v.detach().to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} + batch_cp = get_batch_for_cp_rank(batch, cp_rank=cp_rank, cp_world_size=cp_world_size) + batch_cp["max_length_q"] = batch_cp["max_length_k"] = 32 + + torch.distributed.barrier(group=cp_group) + + outputs_cp = model(**batch_cp) + + # Gather losses from all cp ranks + losses_list = [torch.zeros_like(outputs_cp.loss) for _ in range(cp_world_size)] + torch.distributed.all_gather(losses_list, outputs_cp.loss, group=cp_group) + + if cp_rank == 0: + average_cp_loss = torch.mean(torch.stack(losses_list)) + torch.testing.assert_close( + average_cp_loss.cpu(), + loss_nondistributed_for_comparison, + atol=0.1, + rtol=0.05, + ) + + # Gather logits from all CP ranks + logits_contiguous = outputs_cp.logits.contiguous() + logits_list = [torch.zeros_like(logits_contiguous) for _ in range(cp_world_size)] + torch.distributed.all_gather(logits_list, logits_contiguous, group=cp_group) + + if cp_rank == 0: + # Reconstruct the full logits from CP-split chunks for BSHD format + batch_size, seq_len_sharded, vocab_size = logits_list[0].shape + seq_len_full = batch["input_ids"].shape[1] + total_chunks = 2 * cp_world_size + chunk_size = seq_len_full // total_chunks + + reconstructed_logits = torch.zeros( + (batch_size, seq_len_full, vocab_size), dtype=torch.bfloat16, device=logits_list[0].device + ) + + for batch_idx in range(batch_size): + for cp_idx, logits_shard in enumerate(logits_list): + chunk_indices = [cp_idx, total_chunks - cp_idx - 1] + for chunk_pos, chunk_idx in enumerate(chunk_indices): + start_idx = chunk_idx * chunk_size + end_idx = start_idx + chunk_size + shard_start = chunk_pos * chunk_size + shard_end 
= shard_start + chunk_size + reconstructed_logits[batch_idx, start_idx:end_idx, :] = logits_shard[ + batch_idx, shard_start:shard_end, : + ] + + assert reconstructed_logits.shape == logits_nondistributed_for_comparison.shape + torch.testing.assert_close( + reconstructed_logits.cpu(), + logits_nondistributed_for_comparison, + atol=0.29, + rtol=0.01, + ) + + # Test gradient synchronization with DDP + outputs_cp.loss.backward() + + sample_layers_cp = [ + model.module.encoder.layers[0].self_attention.core_attention, + model.module.encoder.layers[0].self_attention.layernorm_qkv, + ] + + gradients_cp = {} + for i, layer in enumerate(sample_layers_cp): + for name, param in layer.named_parameters(): + if param.grad is not None: + key = f"layer_{i}.{name}" + gradients_cp[key] = param.grad.detach().clone().cpu() + + if cp_rank == 0: + for key in gradients_nondistributed.keys(): + if key in gradients_cp: + grad_cp = gradients_cp[key] + grad_nondist = gradients_nondistributed[key] + + torch.testing.assert_close( + grad_cp, + grad_nondist, + atol=2e-3, + rtol=1e-2, + msg=lambda x: f"Gradients don't match for {key}: {x}", + ) + + torch.distributed.destroy_process_group() diff --git a/bionemo-recipes/models/codonfm/tests/test_cp_thd.py b/bionemo-recipes/models/codonfm/tests/test_cp_thd.py new file mode 100644 index 0000000000..073c401cfb --- /dev/null +++ b/bionemo-recipes/models/codonfm/tests/test_cp_thd.py @@ -0,0 +1,359 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Context parallel equivalence test for CodonFM in THD format. + +Verifies that running the model with context parallelism (CP=2) produces +equivalent losses, logits, and gradients compared to a non-distributed run. +""" + +import os +import subprocess +import sys +import tempfile +from dataclasses import dataclass, field +from pathlib import Path + + +# When launched via torchrun, conftest.py sys.path setup doesn't run. +# Ensure the model directory (parent of tests/) is on sys.path for bare module imports. +sys.path.insert(0, Path(__file__).resolve().parent.parent.as_posix()) + +import pytest +import torch +from collator import _split_batch_by_cp_rank +from dataset import CodonTHDCollator +from modeling_codonfm_te import MODEL_PRESETS, CodonFMConfig, CodonFMForMaskedLM +from tokenizer import CodonTokenizer +from torch.distributed.device_mesh import init_device_mesh + + +requires_multi_gpu = pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 2, + reason="Test requires at least 2 GPUs", +) + +# TODO(@jomitchell): Delete once https://nvbugspro.nvidia.com/bug/5458694 is fixed. 
# Test codon sequences (DNA strings of 3-mer codons)
TEST_CODON_SEQUENCES = [
    "ATGCGTAAAGCTGTTCAGGATCTGAATGCCATCTATGCGATGCGTAAAGCTGTTCAGGATCTGAATGCCATCTATGCG",
    "ATGGATCGTACCGCTGAACAGCGTCTGATCAAAGCCATGGATCGTACCGCTGAACAGCGTCTGATCAAAGCCATGGAT",
]


def get_dummy_data_thd_with_padding(cp_size: int = 2):
    """Build a dummy THD-format batch whose sequences are padded for CP.

    Args:
        cp_size: Context parallel size, determines padding divisibility.

    Returns:
        A dictionary containing the padded input ids, labels, and cu seq lens.
    """
    thd_collator = CodonTHDCollator(
        tokenizer=CodonTokenizer(),
        max_seq_length=512,
        mlm_probability=0.0,  # No masking for deterministic testing
        pad_sequences_to_be_divisible_by=2 * cp_size,
    )
    features = [{"sequence": seq} for seq in TEST_CODON_SEQUENCES]
    thd_batch = thd_collator(features)
    thd_batch["labels"] = thd_batch["input_ids"].clone()  # Identity for testing CP sanity.
    return thd_batch


def create_model_checkpoint(tmp_path):
    """Create a small CodonFM model checkpoint.

    Args:
        tmp_path: The path to save the model checkpoint.

    Returns:
        The path to the saved model checkpoint.
    """
    ckpt_dir = tmp_path / "codonfm_checkpoint"
    small_config = CodonFMConfig(
        attn_input_format="thd",
        dtype=torch.bfloat16,
        **MODEL_PRESETS["encodon_200k"],
    )
    CodonFMForMaskedLM(small_config).save_pretrained(ckpt_dir)
    return ckpt_dir


def get_batch_for_cp_rank(batch, cp_rank, cp_world_size):
    """Get a batch shard for a given context parallelism rank.

    Args:
        batch: The batch to get a shard of.
        cp_rank: The context parallelism rank.
        cp_world_size: The size of the context parallelism group.

    Returns:
        A dictionary containing the shard of the batch.
    """
    sharded_inputs, sharded_labels = _split_batch_by_cp_rank(
        cu_seqlens_padded=batch["cu_seq_lens_q_padded"],
        input_ids_padded=batch["input_ids"],
        labels_padded=batch["labels"],
        qvk_format="thd",
        cp_rank=cp_rank,
        cp_world_size=cp_world_size,
    )
    shard = {**batch, "input_ids": sharded_inputs, "labels": sharded_labels}
    # Max length is derived from the full padded per-sequence lengths, rounded up
    # to a multiple of 64 (from TE code).
    per_seq_lens = shard["cu_seq_lens_q_padded"].diff()
    longest = int(per_seq_lens.max())
    shard["max_length_q"] = shard["max_length_k"] = -(-longest // 64) * 64
    return shard


@dataclass(frozen=True)
class DistributedConfig:
    """Class to track distributed ranks and handle basic distributed training setup."""

    # Each default factory reads the corresponding env var, inserting a default if absent.
    rank: int = field(default_factory=lambda: int(os.environ.setdefault("RANK", "0")))
    local_rank: int = field(default_factory=lambda: int(os.environ.setdefault("LOCAL_RANK", "0")))
    world_size: int = field(default_factory=lambda: int(os.environ.setdefault("WORLD_SIZE", "1")))
    _master_addr: str = field(default_factory=lambda: os.environ.setdefault("MASTER_ADDR", "localhost"))
    _master_port: str = field(default_factory=lambda: os.environ.setdefault("MASTER_PORT", "12355"))

    def is_main_process(self) -> bool:
        """This is the global rank 0 process."""
        return self.rank == 0
+ """ + cmd = [ + "torchrun", + "--nproc_per_node=2", + os.path.relpath(__file__), + ] + result = subprocess.run( + cmd, + check=False, + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + timeout=240, + ) + if result.returncode != 0: + print(f"STDOUT:\n{result.stdout}") + print(f"STDERR:\n{result.stderr}") + pytest.fail(f"Command failed with exit code {result.returncode}") + + +if __name__ == "__main__": + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + model_ckpt = create_model_checkpoint(tmp_path) + + input_data_thd_padded = get_dummy_data_thd_with_padding() + + config = CodonFMConfig.from_pretrained(model_ckpt) + config.attn_input_format = "thd" + config.dtype = torch.bfloat16 + model = CodonFMForMaskedLM(config) + model.load_state_dict(torch.load(model_ckpt / "model.safetensors", weights_only=True), strict=False) + model = model.to(dtype=torch.bfloat16, device="cuda") + model.train() + + input_data_thd_padded = { + k: v.to("cuda") if isinstance(v, torch.Tensor) else v for k, v in input_data_thd_padded.items() + } + outputs_nondistributed = model(**input_data_thd_padded) + loss_nondistributed = outputs_nondistributed.loss + loss_nondistributed.backward() + + # Clone everything we need for later comparison BEFORE deleting + loss_nondistributed_for_comparison = loss_nondistributed.detach().clone().cpu() + logits_nondistributed_for_comparison = outputs_nondistributed.logits.detach().clone().cpu() + + # Sample gradients from a few layers for comparison + sample_layers = [ + model.encoder.layers[0].self_attention.core_attention, + model.encoder.layers[0].self_attention.layernorm_qkv, + ] + + gradients_nondistributed = {} + for i, layer in enumerate(sample_layers): + for name, param in layer.named_parameters(): + if param.grad is not None: + key = f"layer_{i}.{name}" + gradients_nondistributed[key] = param.grad.detach().clone().cpu() + + # Now setup distributed training for CP. 
+ dist_config = DistributedConfig() + device = torch.device(f"cuda:{dist_config.local_rank}") + + # Clean up everything from non-distributed run + del model, outputs_nondistributed, loss_nondistributed, input_data_thd_padded + torch.cuda.empty_cache() + torch.cuda.synchronize() + + # Initialize distributed training + torch.distributed.init_process_group(backend="nccl", device_id=device) + torch.cuda.set_device(dist_config.local_rank) + + ddp_size = 1 + cp_size = 2 + device_mesh = init_device_mesh( + "cuda", + mesh_shape=(ddp_size, cp_size), + mesh_dim_names=("ddp", "cp"), + ) + + # Re-initialize the model on the new device + config = CodonFMConfig.from_pretrained(model_ckpt) + config.attn_input_format = "thd" + config.dtype = torch.bfloat16 + model = CodonFMForMaskedLM(config) + model.load_state_dict(torch.load(model_ckpt / "model.safetensors", weights_only=True), strict=False) + model = model.to(dtype=torch.bfloat16, device=device) + model.train() + model.zero_grad(set_to_none=True) + + group_fsdp_cp = device_mesh[("ddp", "cp")]._flatten("dp_cp").get_group() + model = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[dist_config.local_rank], + output_device=dist_config.local_rank, + process_group=group_fsdp_cp, + ) + cp_group = device_mesh["cp"].get_group() + cp_rank = device_mesh.get_local_rank("cp") + cp_world_size = torch.distributed.get_world_size(group=cp_group) + + # Set up context parallelism for each layer + for transformer_layer in model.module.encoder.layers: + transformer_layer.set_context_parallel_group( + cp_group, torch.distributed.get_process_group_ranks(device_mesh["cp"].get_group()), torch.cuda.Stream() + ) + + model.zero_grad(set_to_none=True) + + # Create FRESH batch data for CP + batch = get_dummy_data_thd_with_padding() + batch = {k: v.detach().to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} + batch_cp = get_batch_for_cp_rank(batch, cp_rank=cp_rank, cp_world_size=cp_world_size) + + 
torch.distributed.barrier(group=cp_group) + + outputs_cp = model(**batch_cp) + + # Gather losses from all cp ranks + losses_list = [torch.zeros_like(outputs_cp.loss) for _ in range(cp_world_size)] + torch.distributed.all_gather(losses_list, outputs_cp.loss, group=cp_group) + + if cp_rank == 0: + average_cp_loss = torch.mean(torch.stack(losses_list)) + torch.testing.assert_close( + average_cp_loss.cpu(), + loss_nondistributed_for_comparison, + atol=0.1, + rtol=0.05, + ) + + # Gather logits from all CP ranks + logits_contiguous = outputs_cp.logits.contiguous() + logits_list = [torch.zeros_like(logits_contiguous) for _ in range(cp_world_size)] + torch.distributed.all_gather(logits_list, logits_contiguous, group=cp_group) + + if cp_rank == 0: + # Reconstruct the full logits from CP-split chunks dynamically + cu_seqlens = batch["cu_seq_lens_q_padded"].cpu() + num_seqs = len(cu_seqlens) - 1 + total_tokens = int(cu_seqlens[-1].item()) + vocab_size = logits_nondistributed_for_comparison.shape[-1] + + reconstructed_logits = torch.zeros((total_tokens, vocab_size), dtype=torch.bfloat16) + + cp_offset_rank0 = 0 + cp_offset_rank1 = 0 + + for seq_idx in range(num_seqs): + seq_start = int(cu_seqlens[seq_idx].item()) + seq_end = int(cu_seqlens[seq_idx + 1].item()) + seq_len = seq_end - seq_start + chunk_size = seq_len // (2 * cp_world_size) + + for chunk_idx in range(2 * cp_world_size): + chunk_start_in_seq = seq_start + chunk_idx * chunk_size + chunk_end_in_seq = chunk_start_in_seq + chunk_size + + if chunk_idx == 0 or chunk_idx == 3: # Chunks for CP rank 0 + reconstructed_logits[chunk_start_in_seq:chunk_end_in_seq, :] = logits_list[0][ + cp_offset_rank0 : cp_offset_rank0 + chunk_size, : + ] + cp_offset_rank0 += chunk_size + else: # Chunks 1, 2 for CP rank 1 + reconstructed_logits[chunk_start_in_seq:chunk_end_in_seq, :] = logits_list[1][ + cp_offset_rank1 : cp_offset_rank1 + chunk_size, : + ] + cp_offset_rank1 += chunk_size + + assert reconstructed_logits.shape == 
logits_nondistributed_for_comparison.shape + cosine_sim = torch.nn.functional.cosine_similarity( + reconstructed_logits.flatten().float().cuda(), + logits_nondistributed_for_comparison.flatten().float().cuda(), + dim=0, + ) + assert cosine_sim > 0.99, f"Logits cosine similarity too low: {cosine_sim}" + + # Test gradient synchronization with DDP + outputs_cp.loss.backward() + + sample_layers_cp = [ + model.module.encoder.layers[0].self_attention.core_attention, + model.module.encoder.layers[0].self_attention.layernorm_qkv, + ] + + gradients_cp = {} + for i, layer in enumerate(sample_layers_cp): + for name, param in layer.named_parameters(): + if param.grad is not None: + key = f"layer_{i}.{name}" + gradients_cp[key] = param.grad.detach().clone().cpu() + + if cp_rank == 0: + for key in gradients_nondistributed.keys(): + if key in gradients_cp: + grad_cp = gradients_cp[key] + grad_nondist = gradients_nondistributed[key] + + cosine_sim = torch.nn.functional.cosine_similarity( + grad_cp.flatten().float(), grad_nondist.flatten().float(), dim=0 + ) + assert cosine_sim > 0.8, f"Gradient cosine similarity too low for {key}: {cosine_sim}" + + torch.distributed.destroy_process_group() diff --git a/bionemo-recipes/recipes/codonfm_native_te/collator.py b/bionemo-recipes/recipes/codonfm_native_te/collator.py new file mode 100644 index 0000000000..8dd2087b12 --- /dev/null +++ b/bionemo-recipes/recipes/codonfm_native_te/collator.py @@ -0,0 +1,574 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Context parallel data collation and distribution utilities. + +Provides model-agnostic utilities for context parallelism (CP): +- DataCollatorForContextParallel: Splits THD/BSHD batches into CP shards +- ContextParallelDataLoaderWrapper: Distributes shards across CP ranks +- _split_batch_by_cp_rank: Core batch splitting logic + +Adapted from bionemo-recipes/models/esm2/collator.py for use with CodonFM's +custom CodonTHDCollator. The utilities here work with any collator that produces +THD format output (input_ids, labels, cu_seq_lens_q/k). +""" + +import logging +import threading +from dataclasses import dataclass, field +from typing import Any, TypedDict + +import nvtx +import torch +from transformers import DataCollator + + +logger = logging.getLogger(__name__) + + +@dataclass +class DataCollatorForContextParallel: + """A collator that is aware of context parallelism. + + For the case of context parallelism, padded sequences will be returned from the wrapped collator, and then split + into shards for each context parallelism rank. + + The shards are then typically sent to the ContextParallelDataLoaderWrapper which will scatter them to the + appropriate GPUs. + + Note: + When used with the ContextParallelDataLoaderWrapper and both context parallelism and tensor parallelism are + used, the collator inspects the ordering of the mesh dimensions to determine the layout of the flattened batch. + + If "cp" comes before "tp" in the mesh dimension names (CP row-major), the flattened batch will be: + [(cp0, tp0), (cp0, tp1), ..., (cp1, tp0), (cp1, tp1), ...] 
+ + If "tp" comes before "cp" (TP row-major), the flattened batch will be: + [(tp0, cp0), (tp0, cp1), ..., (tp1, cp0), (tp1, cp1), ...] + + Args: + collator: The collator to use for the batch. + device_mesh: The device mesh with named dimensions. Must contain either a "cp" dimension for context parallelism + and/or a "tp" dimension for tensor parallelism. + qkv_format: The format of the query-key-value (QKV) tensor. + is_causal_lm: Whether the collator is for a causal language model. If True, the labels will be shifted before + being split into CP shards, and will be returned in the `shift_labels` field. + + """ + + collator: DataCollator + device_mesh: torch.distributed.device_mesh.DeviceMesh + qkv_format: str = "thd" + is_causal_lm: bool = False + + # Derived fields, initialized in __post_init__. + cp_world_size: int = field(init=False) + tp_world_size: int | None = field(init=False) + _is_cp_row_major: bool = field(init=False) + + def __post_init__(self): + """Initialize the cp_world_size, tp_world_size, and _is_cp_row_major fields based on the device mesh.""" + dim_names = self.device_mesh.mesh_dim_names + if dim_names is None: + raise ValueError("device_mesh must have mesh_dim_names") + + self.cp_world_size = self.device_mesh.size(dim_names.index("cp")) if "cp" in dim_names else 1 + self.tp_world_size = self.device_mesh.size(dim_names.index("tp")) if "tp" in dim_names else None + + # Determine whether CP is the row (outer) dimension of the 2D mesh. + # When flattened, the row-major dimension's index changes slowest. + # If "cp" comes before "tp" in mesh_dim_names, CP is the row dimension. + if "cp" in dim_names and "tp" in dim_names: + self._is_cp_row_major = dim_names.index("cp") < dim_names.index("tp") + else: + self._is_cp_row_major = True + + def __call__(self, features) -> list[dict[str, Any]]: + """Process batches of data and create shards for each context parallelism rank. 
+ + Args: + features: List of tokenized sequences, each containing 'input_ids' and optionally 'labels'. + + Returns: + A list of dictionaries, each containing a shard of the batch for a given context parallelism rank. + """ + batch = self.collator(features) + + # Remove the attention mask from the batch, it's not valid for CP. + batch.pop("attention_mask", None) + + if self.is_causal_lm: + labels = torch.nn.functional.pad(batch["labels"], (0, 1), value=-100) + batch["labels"] = labels[..., 1:].contiguous() + + combined_batch = [] + for cp_rank in range(self.cp_world_size): + input_ids_sharded, labels_sharded = _split_batch_by_cp_rank( + cu_seqlens_padded=batch.get("cu_seq_lens_q_padded", None), # This will be None for BSHD format. + input_ids_padded=batch["input_ids"], + labels_padded=batch["labels"], + qvk_format=self.qkv_format, + cp_rank=cp_rank, + cp_world_size=self.cp_world_size, + ) + batch_shard = dict(batch) + batch_shard["input_ids"] = input_ids_sharded + if self.is_causal_lm: + batch_shard["shift_labels"] = labels_sharded + batch_shard["labels"] = None + else: + batch_shard["labels"] = labels_sharded + # Now determine the max length of the sequence. + if self.qkv_format == "thd": + seqlens_q = batch_shard["cu_seq_lens_q_padded"][1:] - batch_shard["cu_seq_lens_q_padded"][:-1] + max_length = seqlens_q.max().item() + elif self.qkv_format == "bshd": + max_length = batch["input_ids"].shape[1] + else: + raise ValueError(f"Unsupported qvk_format: {self.qkv_format}!") + + batch_shard["max_length_k"] = batch_shard["max_length_q"] = ((max_length + 63) // 64) * 64 + combined_batch.append(batch_shard) + + if self.tp_world_size is not None: + # Replicate each CP shard for TP ranks. The ordering depends on which dimension forms the rows in the + # flattened mesh. 
+ if self._is_cp_row_major: + # Flattened mesh: [(cp0,tp0), (cp0,tp1), (cp1,tp0), (cp1,tp1)] + # Output: [cp0, cp0, cp1, cp1] + combined_batch = [batch for batch in combined_batch for _ in range(self.tp_world_size)] + else: + # Flattened mesh: [(tp0,cp0), (tp0,cp1), (tp1,cp0), (tp1,cp1)] + # Output: [cp0, cp1, cp0, cp1] + combined_batch = [ + combined_batch[cp_rank] for _ in range(self.tp_world_size) for cp_rank in range(self.cp_world_size) + ] + + return combined_batch + + +class ContextParallelDataLoaderWrapper: + """A dataloader that is aware of context and tensor parallelism.""" + + def __init__( + self, + dataloader: torch.utils.data.DataLoader | None, + cp_tp_mesh: torch.distributed.device_mesh.DeviceMesh, + ): + """A dataloader wrapper that distributes the data across the context and tensor parallelism groups. + + This class materializes a single dataloader for each data parallel mesh rank, and splits / replicates the data + from this dataloader across the context and tensor parallelism groups. + + Args: + dataloader: The dataloader to use. + cp_tp_mesh: The context parallel mesh, or a flattened, combined context parallel and tensor parallel mesh. + If a flattened mesh is provided, the cp / tp dimensions should be in the order they appeared in the + mesh_dim_names as passed to DataCollatorForContextParallel. 
+ """ + if cp_tp_mesh.get_local_rank() == 0: + assert dataloader is not None, "dataloader must be provided on rank 0" + self.dataloader = dataloader + + else: + assert dataloader is None, "Dataloader on non-rank 0 will not be used" + + self.cp_tp_rank = cp_tp_mesh.get_local_rank() + self.cp_tp_group = cp_tp_mesh.get_group() + self.num_cp_tp_ranks = cp_tp_mesh.size() + self._iterator = None + self._prefetch_thread: threading.Thread | None = None + self._prefetch_result: Any = None + self._cuda_device: int | None = None + + logger.debug( + "Created ContextParallelDataLoaderWrapper on global rank %s, cp rank %s", + torch.distributed.get_rank() if torch.distributed.is_initialized() else "", + self.cp_tp_rank, + ) + + def __iter__(self): + """Make the dataloader iterable.""" + if self.cp_tp_rank == 0: + self._iterator = iter(self.dataloader) # < --- collator output. + self.close() + # Capture CUDA device from main thread; torch.cuda.set_device is per-thread, + # so the background thread needs to set it explicitly. + self._cuda_device = torch.cuda.current_device() if torch.cuda.is_available() else None + self._kick_prefetch() + return self + + @nvtx.annotate("ContextParallelDataLoaderWrapper __next__", color="blue") + def __next__(self): + """Get the batch from the dataloader for the current CP rank.""" + if self._prefetch_thread is not None: + self._prefetch_thread.join() + result = self._prefetch_result + if isinstance(result, Exception): + self._prefetch_thread = None + raise result + self._kick_prefetch() + return result + + def _kick_prefetch(self): + """Start a background thread to prefetch exactly one batch via scatter.""" + self._prefetch_thread = threading.Thread(target=self._do_one_prefetch, daemon=True) + self._prefetch_thread.start() + + def _do_one_prefetch(self): + """Fetch one batch in the background. 
+ + This function calls the _send_data_to_cp_tp_ranks function to materialize the next batches for all ranks in the + given CP/TP group, and uses torch.distributed.scatter_object_list to scatter these batches to their + corresponding ranks. The result is stored in _prefetch_result, and returned when __next__ is called. + """ + if self._cuda_device is not None: + torch.cuda.set_device(self._cuda_device) + try: + self._prefetch_result = self._send_data_to_cp_tp_ranks() + except StopIteration as e: + self._prefetch_result = e + except Exception as e: + self._prefetch_result = e + + def close(self): + """Stop the prefetch thread. Must be called before destroy_process_group().""" + if self._prefetch_thread is not None: + self._prefetch_thread.join(timeout=10) + self._prefetch_thread = None + + @nvtx.annotate("ContextParallelDataLoaderWrapper _send_data_to_cp_tp_ranks", color="green") + def _send_data_to_cp_tp_ranks(self): + """Send data to all the CP/TP ranks. + + This function will get the batch from the dataloader on CP rank 0, and then determine + the shards for all the different CP group members. + combined_batch = [, , ..., ] + Then it will scatter the shards to the different CP group members. + The shards are then combined into a single batch and returned to the caller + for the current CP rank. + + If tensor parallelism is also being used, the combined batch will look like: + combined_batch = [, , ..., , , ...] + where there are cp_world_size shards, and each shard is replicated tp_world_size times. The ordering of the + shards depends on which dimension forms the rows in the flattened mesh. + + Scalability: + Rank 0's work grows linearly with CP size, but the other ranks do not need to store all the shards so they + do not grow linearly with CP size. + + Args: + None + + Returns: + batch: The batch for the current CP/TP rank. 
+ + """ + try: + with nvtx.annotate("ContextParallelDataLoaderWrapper next batch", color="green"): + combined_batch = next(self._iterator) if self.cp_tp_rank == 0 else None + except StopIteration as ex: + # If we encounter a StopIteration in the dataloader, we want to raise this error on all the CP ranks, so + # that the dataloader can be restarted. + combined_batch = [ex] * self.num_cp_tp_ranks + + batch_on_this_rank = _scatter_batch_to_cp_tp_ranks(combined_batch, self.cp_tp_group) + + if isinstance(batch_on_this_rank, StopIteration): + raise batch_on_this_rank + + return batch_on_this_rank + + def state_dict(self): + """Get the state dict by delegating to the dataloader.""" + if self.cp_tp_rank != 0: + return {} + elif hasattr(self.dataloader, "state_dict"): + return {"dataloader": self.dataloader.state_dict()} + else: + logger.warning( + "Attempting to get the state dict of the dataloader, but the dataloader does not support state_dict, " + "returning empty dict" + ) + return {"dataloader": {}} + + def load_state_dict(self, state_dict): + """Load the state dict by delegating to the dataloader.""" + if self.cp_tp_rank != 0: + return + elif hasattr(self.dataloader, "load_state_dict"): + self.dataloader.load_state_dict(state_dict["dataloader"]) + else: + logger.warning( + "Attempting to load the state dict of the dataloader, but the dataloader does not support " + "load_state_dict, returning without loading the state dict." + ) + return + + @property + def num_workers(self): + """Get the number of workers of the dataloader.""" + if self.cp_tp_rank != 0: + return 0 + else: + return self.dataloader.num_workers + + +def _find_seq_dim(tensor: torch.Tensor, seq_len: int) -> int: + """Find which dimension of tensor matches the expected sequence length. + + Args: + tensor: The tensor to inspect. + seq_len: The expected sequence length to match against tensor dimensions. + + Returns: + The dimension index that matches the sequence length. 
+ + Raises: + ValueError: If no dimension matches the expected sequence length. + """ + if tensor.ndim == 1: + if tensor.shape[0] == seq_len: + return 0 + raise ValueError(f"1D tensor shape {tensor.shape} doesn't match sequence length {seq_len}") + elif tensor.ndim >= 2: + if tensor.shape[1] == seq_len: + return 1 + elif tensor.shape[0] == seq_len: + return 0 + raise ValueError(f"Tensor shape {tensor.shape} doesn't match sequence length {seq_len} in dim 0 or 1") + raise ValueError(f"Unexpected tensor ndim={tensor.ndim}") + + +def _process_tensor_thd( + val: torch.Tensor | None, + seq_len: int, + slice_sizes: torch.Tensor, + cu_seqlens_padded: torch.Tensor, + cp_rank: int, + total_slices: int, +) -> torch.Tensor | None: + """Extract the THD context-parallel shard for a single tensor. + + For each sequence in the batch, selects two slices (one from the beginning and one from the end) + corresponding to the given CP rank, following the zigzag CP sharding pattern. + + Args: + val: The tensor to shard, or None (returned as-is). + seq_len: Total sequence length (from cu_seqlens_padded[-1]). + slice_sizes: Per-sequence slice sizes, computed as sequence_lengths // total_slices. + cu_seqlens_padded: Cumulative sequence lengths including padding. + cp_rank: The context parallelism rank index. + total_slices: Total number of slices per sequence (2 * cp_world_size). + + Returns: + The sharded tensor for the given CP rank, or None if val is None. 
+ """ + if val is None: + return val + + seq_dim = _find_seq_dim(val, seq_len) + + cp_rank_slices = [] + for slice_size, seq_start in zip(slice_sizes, cu_seqlens_padded[:-1]): + # 1st segment + cp_rank_slices.append( + torch.arange( + seq_start + (cp_rank * slice_size), + seq_start + ((cp_rank + 1) * slice_size), + device=val.device, + ) + ) + + # 2nd segment + cp_rank_slices.append( + torch.arange( + seq_start + ((total_slices - cp_rank - 1) * slice_size), + seq_start + ((total_slices - cp_rank) * slice_size), + device=val.device, + ) + ) + + return val.index_select(seq_dim, torch.cat(cp_rank_slices)) + + +def _process_tensor_bshd( + val: torch.Tensor | None, + cp_rank: int, + cp_world_size: int, +) -> torch.Tensor | None: + """Extract the BSHD context-parallel shard for a single tensor. + + Splits a BSHD-format tensor along the sequence dimension (dim=1) into 2*cp_world_size chunks, + then selects the two chunks corresponding to the given CP rank (zigzag pattern). + + Args: + val: The tensor to shard, or None (returned as-is). + cp_rank: The context parallelism rank index. + cp_world_size: Total number of context parallelism ranks. + + Returns: + The sharded tensor for the given CP rank, or None if val is None. + + Raises: + ValueError: If the tensor has fewer than 2 dimensions or its sequence length + is not divisible by 2 * cp_world_size. 
+ """ + if val is None: + return val + + if val.ndim < 2: + raise ValueError(f"BSHD format requires at least 2D tensors, got {val.ndim}D") + + seq_len = val.shape[1] + + # Calculate chunk size + total_chunks = 2 * cp_world_size + chunk_size = seq_len // total_chunks + + if seq_len % total_chunks != 0: + raise ValueError( + f"Sequence length {seq_len} must be divisible by {total_chunks} " + f"(2 * cp_world_size) for BSHD context parallelism" + ) + + # Determine which chunks this rank should get + # Rank 0 gets chunks [0, total_chunks-1] + # Rank 1 gets chunks [1, total_chunks-2] + # Rank k gets chunks [k, total_chunks-k-1] + chunk_indices = [cp_rank, total_chunks - cp_rank - 1] + + # Collect slices for this rank + rank_slices = [] + for chunk_idx in chunk_indices: + start_idx = chunk_idx * chunk_size + end_idx = start_idx + chunk_size + rank_slices.append(torch.arange(start_idx, end_idx, device=val.device)) + + # Concatenate indices for all chunks this rank should get + indices = torch.cat(rank_slices) + + # Select along sequence dimension (dim=1) + return val.index_select(1, indices) + + +# TODO(@jomitchell): Once this gets merged: https://github.com/NVIDIA/TransformerEngine/pull/2387 +# we can replace this with the one in TransformerEngine. +@nvtx.annotate("collator._split_batch_by_cp_rank", color="green") +def _split_batch_by_cp_rank( + cu_seqlens_padded: torch.Tensor | None, + input_ids_padded: torch.Tensor, + labels_padded: torch.Tensor, + cp_group: torch.distributed.ProcessGroup | None = None, + qvk_format: str = "thd", + cp_rank: int | None = None, + cp_world_size: int | None = None, +): + """Slice batch input along sequence dimension into multiple chunks for THD or BSHD format. + + This function is intended for use in self attention. It will not work for cross attention because + it does not handle the case where the sequence length of the query and key are different. + Which are parallelized across GPUs in a context parallel group. 
+ This version works with variable-length sequences using cumulative sequence lengths for THD format, + and with padded sequences for BSHD format. + + Args: + cu_seqlens_padded: Cumulative sequence length. Required for THD format, optional for BSHD format. + input_ids_padded: Input IDs. + labels_padded: Labels. + cp_group: Context parallel group. + qvk_format: Format of the input data ("thd" or "bshd"). + cp_world_size: The size of the context parallelism group. + cp_rank: Optional manual CP rank index. + """ + if qvk_format not in ["thd", "bshd", "sbhd"]: + raise ValueError(f"Unsupported qvk_format: {qvk_format}!") + + if cp_world_size is None or cp_world_size <= 1: + # No splitting needed + return input_ids_padded, labels_padded + + if cp_rank is None: + cp_rank = torch.distributed.get_rank(group=cp_group) + elif not (0 <= cp_rank < cp_world_size): + raise ValueError(f"cp_rank must be in [0, {cp_world_size}), but received {cp_rank}.") + + if qvk_format == "thd": + if cu_seqlens_padded is None: + raise ValueError("cu_seqlens_padded is required for THD format") + + # Calculate the chunk sizes for each sequence + total_slices_of_any_sequence = 2 * cp_world_size + slice_sizes = (cu_seqlens_padded[1:] - cu_seqlens_padded[:-1]) // total_slices_of_any_sequence + + # Ensure cu_seqlens_padded[-1] is a Python int, not a 0-dim tensor + last_elem = cu_seqlens_padded[-1] + seq_len_val = last_elem.item() if isinstance(last_elem, torch.Tensor) else last_elem + + input_ids_padded = _process_tensor_thd( + input_ids_padded, seq_len_val, slice_sizes, cu_seqlens_padded, cp_rank, total_slices_of_any_sequence + ) + labels_padded = _process_tensor_thd( + labels_padded, seq_len_val, slice_sizes, cu_seqlens_padded, cp_rank, total_slices_of_any_sequence + ) + + elif qvk_format == "bshd": + input_ids_padded = _process_tensor_bshd(input_ids_padded, cp_rank, cp_world_size) + labels_padded = _process_tensor_bshd(labels_padded, cp_rank, cp_world_size) + + else: + raise ValueError(f"Support not 
implemented yet for qvk_format: {qvk_format}!")
+
+    return input_ids_padded, labels_padded
+
+
+class BatchType(TypedDict):
+    """The fields in the batch dictionary for THD context parallel."""
+
+    input_ids: torch.Tensor
+    labels: torch.Tensor | None
+    shift_labels: torch.Tensor | None
+    cu_seq_lens_q: torch.Tensor
+    cu_seq_lens_k: torch.Tensor
+    cu_seq_lens_q_padded: torch.Tensor
+    cu_seq_lens_k_padded: torch.Tensor
+    max_length_q: int
+    max_length_k: int
+    pad_between_seqs: bool
+
+
+@nvtx.annotate("collator._scatter_batch_to_cp_tp_ranks", color="green")
+def _scatter_batch_to_cp_tp_ranks(
+    all_batches: list[BatchType] | list[StopIteration], cp_tp_group: torch.distributed.ProcessGroup | None = None
+) -> BatchType | StopIteration:
+    """Scatter a batch to all the CP/TP ranks.
+
+    Args:
+        all_batches (list[BatchType] | list[StopIteration]): A list of already-sharded batches to scatter to the CP/TP
+            ranks.
+        cp_tp_group (torch.distributed.ProcessGroup | None): The process group to scatter the batches to.
+
+    Returns:
+        BatchType | StopIteration: The batch on this rank.
+    """
+    scatter_object_output_list = [None]
+    # Note: This does not provide an async_op handle, so it is blocking.
+ torch.distributed.scatter_object_list( + scatter_object_output_list=scatter_object_output_list, + scatter_object_input_list=all_batches, + group=cp_tp_group, + group_src=0, + ) + return scatter_object_output_list[0] diff --git a/bionemo-recipes/recipes/codonfm_native_te/dataset.py b/bionemo-recipes/recipes/codonfm_native_te/dataset.py index c9b3652745..55f521239b 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/dataset.py +++ b/bionemo-recipes/recipes/codonfm_native_te/dataset.py @@ -16,6 +16,7 @@ """Dataset and dataloader utilities for CodonFM pretraining.""" import json +import logging import random from pathlib import Path @@ -27,6 +28,9 @@ from torch.utils.data import DataLoader, Dataset, DistributedSampler +logger = logging.getLogger(__name__) + + BASES = "ACGT" @@ -250,6 +254,7 @@ def __init__( mlm_probability: float = 0.15, seed: int = 42, pad_to_multiple_of: int | None = None, + pad_sequences_to_be_divisible_by: int | None = None, ): """Initialize. @@ -260,12 +265,18 @@ def __init__( seed: Random seed for reproducible masking. pad_to_multiple_of: If set, pad total tokens to a multiple of this value. Required for FP8 (8), MXFP8 (16), or NVFP4 (32) with THD format. + pad_sequences_to_be_divisible_by: If set, each individual sequence is padded + to be divisible by this value. Used for context parallelism. + Cannot be used together with pad_to_multiple_of. """ + if pad_sequences_to_be_divisible_by is not None and pad_to_multiple_of is not None: + raise ValueError("pad_sequences_to_be_divisible_by and pad_to_multiple_of cannot be used together") self.tokenizer = tokenizer self.max_seq_length = max_seq_length self.mlm_probability = mlm_probability self.rng = random.Random(seed) self.pad_to_multiple_of = pad_to_multiple_of + self.pad_sequences_to_be_divisible_by = pad_sequences_to_be_divisible_by def __call__(self, batch: list[dict[str, str]]) -> dict[str, torch.Tensor]: """Collate a batch into THD packed format. 
@@ -334,7 +345,7 @@ def __call__(self, batch: list[dict[str, str]]) -> dict[str, torch.Tensor]: cu_seq_lens = torch.cat([cu_seq_lens, torch.tensor(pad_cu_lens, dtype=cu_seq_lens.dtype)]) max_length = max(max_length, min(remainder, self.max_seq_length)) - return { + result = { "input_ids": input_ids, "labels": labels, "cu_seq_lens_q": cu_seq_lens, @@ -343,6 +354,29 @@ def __call__(self, batch: list[dict[str, str]]) -> dict[str, torch.Tensor]: "max_length_k": max_length, } + # Per-sequence padding for context parallelism: each sequence is padded individually + # so its length is divisible by pad_sequences_to_be_divisible_by. + if self.pad_sequences_to_be_divisible_by is not None: + from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import ( + pad_thd_sequences_for_cp, + ) + + input_ids_padded, labels_padded, cu_seqlens_padded = pad_thd_sequences_for_cp( + input_ids.squeeze(0), + labels.squeeze(0), + cu_seq_lens, + self.pad_sequences_to_be_divisible_by, + padding_token_id=self.tokenizer.pad_token_id, + padding_label_id=-100, + ) + result["input_ids"] = input_ids_padded.unsqueeze(0) + result["labels"] = labels_padded.unsqueeze(0) + result["cu_seq_lens_q_padded"] = cu_seqlens_padded.to(torch.int32) + result["cu_seq_lens_k_padded"] = cu_seqlens_padded.to(torch.int32) + result["pad_between_seqs"] = True + + return result + def _create_dataset(data_path: str, max_seq_length: int, seed: int) -> Dataset: """Create the appropriate dataset based on data_path format. 
@@ -480,3 +514,83 @@ def create_thd_dataloader( ) return dataloader, sampler + + +def create_cp_dataloader( + dist_config: DistributedConfig, + *, + cp_mesh: torch.distributed.device_mesh.DeviceMesh, + data_path: str, + micro_batch_size: int = 2, + max_seq_length: int = 512, + mlm_probability: float = 0.15, + num_workers: int = 1, + seed: int = 42, + pad_to_multiple_of: int | None = None, + pad_sequences_to_be_divisible_by: int | None = None, +) -> tuple: + """Create a Context-parallel aware THD dataloader. + + Wraps the THD dataloader with CP-aware collation and distribution across ranks. + Only CP rank 0 loads data; other ranks receive shards via scatter. + + Args: + dist_config: Distributed configuration. + cp_mesh: The context parallel mesh. + data_path: Path to parquet file, memmap directory, or 'synthetic'. + micro_batch_size: Number of sequences to pack per batch. + max_seq_length: Maximum sequence length per sample. + mlm_probability: MLM masking probability. + num_workers: Number of dataloader workers. + seed: Random seed. + pad_to_multiple_of: Unused when pad_sequences_to_be_divisible_by is set. + pad_sequences_to_be_divisible_by: Per-sequence padding divisor for CP. + Defaults to cp_mesh.size() * 2 if not provided. + + Returns: + Tuple of (ContextParallelDataLoaderWrapper, DistributedSampler or None). 
+ """ + from collator import ContextParallelDataLoaderWrapper, DataCollatorForContextParallel + + # Ensure pad_sequences_to_be_divisible_by is set for CP + if pad_sequences_to_be_divisible_by is None: + logger.info("pad_sequences_to_be_divisible_by not provided, using cp_mesh.size() * 2") + pad_sequences_to_be_divisible_by = cp_mesh.size() * 2 + + if cp_mesh.get_local_rank() == 0: + tokenizer = CodonTokenizer() + dataset = _create_dataset(data_path, max_seq_length, seed) + + sampler = DistributedSampler( + dataset, + rank=dist_config.rank, + num_replicas=dist_config.world_size, + seed=seed, + ) + + collator = CodonTHDCollator( + tokenizer=tokenizer, + max_seq_length=max_seq_length, + mlm_probability=mlm_probability, + pad_sequences_to_be_divisible_by=pad_sequences_to_be_divisible_by, + ) + + train_dataloader = DataLoader( + dataset, + sampler=sampler, + batch_size=micro_batch_size, + collate_fn=collator, + num_workers=num_workers, + pin_memory=True, + ) + + # Wrap collator with CP-aware splitting + train_dataloader.collate_fn = DataCollatorForContextParallel( + collator=train_dataloader.collate_fn, + device_mesh=cp_mesh, + ) + else: + train_dataloader = None + sampler = None + + return ContextParallelDataLoaderWrapper(train_dataloader, cp_mesh), sampler diff --git a/bionemo-recipes/recipes/codonfm_native_te/hydra_config/L0_sanity_cp.yaml b/bionemo-recipes/recipes/codonfm_native_te/hydra_config/L0_sanity_cp.yaml new file mode 100644 index 0000000000..5da1db9833 --- /dev/null +++ b/bionemo-recipes/recipes/codonfm_native_te/hydra_config/L0_sanity_cp.yaml @@ -0,0 +1,36 @@ +defaults: + - defaults + - _self_ + +# Training config +model_preset: encodon_200k +num_train_steps: 250 + +# Whether to use context parallelism or not. 
+cp_size: 2 + +use_sequence_packing: true +dataset: + data_path: train.parquet + micro_batch_size: 2 + num_workers: 0 + max_seq_length: 512 + pad_sequences_to_be_divisible_by: 16 + +# WandB config +wandb_init_args: + name: "codonfm_native_te_cp_sanity" + mode: "offline" + +# Learning rate scheduler config +lr_scheduler_kwargs: + num_warmup_steps: 100 + +checkpoint: + ckpt_dir: null + resume_from_checkpoint: true + save_every_n_steps: 50 + save_final_model: false + +logger: + frequency: 1 diff --git a/bionemo-recipes/recipes/codonfm_native_te/hydra_config/defaults.yaml b/bionemo-recipes/recipes/codonfm_native_te/hydra_config/defaults.yaml index 3a97660834..d21d2d2115 100644 --- a/bionemo-recipes/recipes/codonfm_native_te/hydra_config/defaults.yaml +++ b/bionemo-recipes/recipes/codonfm_native_te/hydra_config/defaults.yaml @@ -6,6 +6,7 @@ grad_acc_steps: 1 # Gradient accumulation steps. Effective batch = micro_batch_ use_meta_device: true use_sequence_packing: false +cp_size: 1 dataset: data_path: ??? diff --git a/bionemo-recipes/recipes/codonfm_native_te/tests/test_train_two_gpu.py b/bionemo-recipes/recipes/codonfm_native_te/tests/test_train_two_gpu.py new file mode 100644 index 0000000000..dfe7770b9a --- /dev/null +++ b/bionemo-recipes/recipes/codonfm_native_te/tests/test_train_two_gpu.py @@ -0,0 +1,71 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# These tests don't check convergence, they just check that the CP training script runs successfully on multiple GPUs. + +import subprocess + +import pytest +import torch + + +requires_multi_gpu = pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 2, + reason="Test requires at least 2 GPUs", +) + +# TODO(@jomitchell): Delete once https://nvbugspro.nvidia.com/bug/5458694 is fixed. +requires_datacenter_hardware = pytest.mark.skipif( + not torch.cuda.is_available() + or not any( + gpu_name in torch.cuda.get_device_name(0).upper() for gpu_name in ["H100", "H200", "B100", "B200", "B300"] + ), + reason="Test requires datacenter hardware (H100, H200, B100, B200, B300)", +) + + +def run_train_cmd(cmd, recipe_path): + """Run a training command and check for errors.""" + result = subprocess.run( + cmd, + check=False, + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + timeout=240, + cwd=str(recipe_path), + ) + + if result.returncode != 0: + print(f"STDOUT:\n{result.stdout}") + print(f"STDERR:\n{result.stderr}") + pytest.fail(f"Command:\n{' '.join(cmd)}\nfailed with exit code {result.returncode}") + + +@requires_multi_gpu +@requires_datacenter_hardware +def test_multi_gpu_train_te_fsdp2_cp(tmp_path, recipe_path): + run_train_cmd( + [ + "torchrun", + "--nproc_per_node=2", + "train_fsdp2_cp.py", + "--config-name", + "L0_sanity_cp", + "num_train_steps=4", + "cp_size=2", + ], + recipe_path, + ) diff --git a/bionemo-recipes/recipes/codonfm_native_te/train_fsdp2_cp.py b/bionemo-recipes/recipes/codonfm_native_te/train_fsdp2_cp.py new file mode 100644 index 0000000000..61a0972041 --- /dev/null +++ b/bionemo-recipes/recipes/codonfm_native_te/train_fsdp2_cp.py @@ -0,0 +1,266 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""FSDP2 + Context Parallel training script for CodonFM with TransformerEngine layers.""" + +import logging +from contextlib import nullcontext +from pathlib import Path + +import hydra +import nvdlfw_inspect.api as debug_api +import torch +from checkpoint import load_checkpoint_fsdp2, save_checkpoint_fsdp2, save_final_model_fsdp2, should_save_checkpoint +from dataset import create_cp_dataloader +from distributed_config import DistributedConfig +from modeling_codonfm_te import MODEL_PRESETS, CodonFMConfig, CodonFMForMaskedLM +from omegaconf import DictConfig, OmegaConf +from perf_logger import PerfLogger +from quantization import WandBQuantLogger, initialize_quant_stats_logging, resolve_layer_precision +from scheduler import get_linear_schedule_with_warmup +from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.fsdp import fully_shard +from torch.optim import AdamW +from transformer_engine.common.recipe import Format + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +@hydra.main(config_path="hydra_config", config_name="L0_sanity_cp", version_base="1.2") +def main(args: DictConfig) -> float | None: + """Train CodonFM with TE layers using FSDP2 + Context Parallelism. + + Returns: + float: The minimum loss value seen during training. 
+ """ + logging.getLogger("httpx").setLevel(logging.WARNING) + + # Initialize distributed configuration + dist_config = DistributedConfig() + logger.info("Initializing distributed training: %s", dist_config) + device = torch.device(f"cuda:{dist_config.local_rank}") + torch.distributed.init_process_group(backend="nccl", device_id=device) + torch.cuda.set_device(dist_config.local_rank) + + # Validate that world_size is divisible by cp_size + if dist_config.world_size % args.cp_size != 0: + raise ValueError( + f"world_size ({dist_config.world_size}) must be divisible by cp_size ({args.cp_size}). " + f"Set cp_size to a divisor of world_size." + ) + + # Calculate DP size (number of data parallel replicas) + dp_size = dist_config.world_size // args.cp_size + + # Create a device mesh for DP and CP. + device_mesh = init_device_mesh( + "cuda", + mesh_shape=(dp_size, args.cp_size), + mesh_dim_names=("dp", "cp"), + ) + + # Our flattened group must have at least 2 ranks to enable Context Parallelism. 
+ if dp_size * args.cp_size <= 1: + cp_dp_mesh = device_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_shard_cp") + else: + cp_dp_mesh = device_mesh + + logger.info( + f"Creating device mesh: world_size={dist_config.world_size}, dp_size={dp_size}, cp_size={args.cp_size}" + ) + + perf_logger = None + try: + # Build model config from preset + preset_overrides = MODEL_PRESETS[args.model_preset] + + # Resolve layer-wise quantization assignments + num_layers = preset_overrides.get("num_hidden_layers", 12) + layer_precision = resolve_layer_precision( + num_layers=num_layers, + fp8_enabled=args.fp8_config.enabled, + fp4_enabled=args.fp4_config.enabled, + fp8_layers=OmegaConf.to_container(args.fp8_layers, resolve=True) if args.fp8_layers is not None else None, + fp4_layers=OmegaConf.to_container(args.fp4_layers, resolve=True) if args.fp4_layers is not None else None, + ) + + # Initialize quant stats logging if enabled + if args.quant_stats_config.enabled: + wandb_logger = None + if args.quant_stats_config.log_to_wandb and dist_config.is_main_process(): + wandb_logger = WandBQuantLogger() + initialize_quant_stats_logging( + quant_stats_file=args.quant_stats_config.quant_stats_file, + quant_log_dir=args.quant_stats_config.quant_log_dir, + rank=dist_config.rank, + layer_precision=layer_precision, + statistics_logger=wandb_logger, + ) + + # Create quantization recipes + fp8_recipe = None + fp4_recipe = None + if args.fp8_config.enabled: + fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)( + fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs + ) + if args.fp4_config.enabled: + fp4_recipe = hydra.utils.get_class(args.fp4_config.fp4_recipe)( + fp4_format=Format[args.fp4_config.fp4_format], **args.fp4_config.fp4_recipe_kwargs + ) + + if args.use_fp32_master_weights: + raise ValueError("FP32 master weights are not supported with FSDP2+CP. Use train_fsdp2.py instead.") + + # Context Parallelism requires THD Sequence Packing. 
+ assert args.use_sequence_packing, "Context Parallelism requires THD Sequence Packing." + + config = CodonFMConfig( + attn_input_format="thd", + max_position_embeddings=args.dataset.max_seq_length, + layer_precision=layer_precision, + **preset_overrides, + ) + + with torch.device("meta") if args.use_meta_device else nullcontext(): + model = CodonFMForMaskedLM(config, fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe) + + logger.info("Initialized Model:\n%s", model) + + # Apply FSDP2 sharding with CP-aware mesh + for layer in model.encoder.layers: + fully_shard(layer, mesh=cp_dp_mesh) + # Set CP group for layer if CP is enabled. + if args.cp_size > 1: + logger.debug(f"Rank {dist_config.rank}: Setting CP group for layer {layer}") + layer.set_context_parallel_group( + device_mesh["cp"].get_group(), + torch.distributed.get_process_group_ranks(device_mesh["cp"].get_group()), + torch.cuda.Stream(), + ) + fully_shard(model, mesh=cp_dp_mesh) + + # Initialize weights from meta device + if args.use_meta_device: + model.init_empty_weights() + + # Assign layer names for debug API + if args.quant_stats_config.enabled: + debug_api.infer_and_assign_layer_names(model) + + # Create optimizer and scheduler + optimizer = AdamW(model.parameters(), **OmegaConf.to_container(args.adamw_kwargs, resolve=True)) + scheduler = get_linear_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs) + + # Create CP dataloader + dataloader_kwargs = OmegaConf.to_container(args.dataset, resolve=True) + train_dataloader, sampler = create_cp_dataloader( + dist_config, + cp_mesh=device_mesh["cp"], + **dataloader_kwargs, + ) + + # Resume from checkpoint if available + ckpt_path = Path(args.checkpoint.ckpt_dir) / "train_fsdp2_cp" if args.checkpoint.ckpt_dir else None + if args.checkpoint.resume_from_checkpoint and ckpt_path: + model, optimizer, scheduler, start_step, epoch = load_checkpoint_fsdp2( + model=model, + optimizer=optimizer, + scheduler=scheduler, + ckpt_path=ckpt_path, + dist_config=dist_config, 
+ ) + else: + start_step = 0 + epoch = 0 + + perf_logger = PerfLogger(dist_config, args) + + # Training loop + step = start_step + while step < args.num_train_steps: + for batch in train_dataloader: + batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} # noqa: PLW2901 + + # Forward pass + outputs = model(**batch) + + # Backward pass + loss = outputs.loss + loss.backward() + + # Log micro-batch data + perf_logger.log_micro_step(step=step, batch=batch, outputs=outputs) + + # Grad clip + total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item() + + # Optimizer step + optimizer.step() + scheduler.step() + optimizer.zero_grad() + + perf_logger.log_step( + step=step, + grad_norm=total_norm, + lr=optimizer.param_groups[0]["lr"], + ) + + if ckpt_path and should_save_checkpoint(step, args.checkpoint.save_every_n_steps): + save_checkpoint_fsdp2( + model=model, + optimizer=optimizer, + scheduler=scheduler, + ckpt_path=ckpt_path, + step=step, + epoch=epoch, + dist_config=dist_config, + max_checkpoints=args.checkpoint.max_checkpoints, + ) + + step += 1 + if step >= args.num_train_steps: + break + + # Dataloader exhausted, incrementing epoch + epoch += 1 + if sampler is not None: + sampler.set_epoch(epoch) + + # Save final model + if args.checkpoint.save_final_model and ckpt_path: + save_final_model_fsdp2( + model=model, + config=config, + save_directory=ckpt_path / "final_model", + dist_config=dist_config, + ) + + return float(perf_logger.min_loss.item()) + finally: + if perf_logger is not None: + perf_logger.finish() + else: + try: + debug_api.end_debug() + except RuntimeError: + pass + torch.distributed.destroy_process_group() + + +if __name__ == "__main__": + main()