Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions bionemo-recipes/models/mixtral/modeling_mixtral_te.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,9 +279,22 @@ def _restack_from_views(self) -> None:
device = torch.cuda.current_device()
for attr_name in ("experts_gate_up_weight", "experts_down_weight"):
old_param = getattr(self, attr_name)
new_data = torch.empty_like(old_param, device=device)
torch.nn.init.normal_(new_data, mean=0.0, std=self.initializer_range)
setattr(self, attr_name, nn.Parameter(new_data))
if isinstance(old_param.data, DTensor):
# FSDP2 has sharded this param; materialize the local shard on CUDA
# and reconstruct the DTensor wrapper so FSDP2 can manage it.
local_data = old_param.data.to_local()
new_local = torch.empty(local_data.shape, dtype=local_data.dtype, device=device)
torch.nn.init.normal_(new_local, mean=0.0, std=self.initializer_range)
new_dtensor = DTensor.from_local(
new_local,
device_mesh=old_param.data.device_mesh,
placements=old_param.data.placements,
)
setattr(self, attr_name, nn.Parameter(new_dtensor))
else:
new_data = torch.empty_like(old_param, device=device)
torch.nn.init.normal_(new_data, mean=0.0, std=self.initializer_range)
setattr(self, attr_name, nn.Parameter(new_data))

# Re-sync views to point to the new stacked parameter
self._sync_expert_views()
Expand All @@ -298,13 +311,15 @@ def _sync_expert_views(self) -> None:
gate_up_w = self.experts_gate_up_weight
if isinstance(gate_up_w, DTensor):
gate_up_w = gate_up_w.to_local()
for i in range(self.num_local_experts):
num_local = gate_up_w.shape[0]
for i in range(num_local):
object.__setattr__(self.experts_gate_up, f"weight{i}", gate_up_w[i])

down_w = self.experts_down_weight
if isinstance(down_w, DTensor):
down_w = down_w.to_local()
for i in range(self.num_local_experts):
num_local_down = down_w.shape[0]
for i in range(num_local_down):
object.__setattr__(self.experts_down, f"weight{i}", down_w[i])

def set_ep_group(self, ep_group: dist.ProcessGroup, ep_mesh: DeviceMesh) -> None:
Expand Down Expand Up @@ -865,6 +880,8 @@ def _unpad_input(hidden_states, attention_mask, unused_mask=None):
class HFInferenceParams(InferenceParams):
"""Extension of the InferenceParams class to support HF generate() and beam search."""

# Required by transformers >= 5.4 _valid_auto_compile_criteria(); this
# custom TE-based cache is not compatible with torch.compile generate().
is_compileable = False

def get_seq_length(self, layer_idx: int = 0) -> int:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -285,9 +285,22 @@ def _restack_from_views(self) -> None:
device = torch.cuda.current_device()
for attr_name in ("experts_gate_up_weight", "experts_down_weight"):
old_param = getattr(self, attr_name)
new_data = torch.empty_like(old_param, device=device)
torch.nn.init.normal_(new_data, mean=0.0, std=self.initializer_range)
setattr(self, attr_name, nn.Parameter(new_data))
if isinstance(old_param.data, DTensor):
# FSDP2 has sharded this param; materialize the local shard on CUDA
# and reconstruct the DTensor wrapper so FSDP2 can manage it.
local_data = old_param.data.to_local()
new_local = torch.empty(local_data.shape, dtype=local_data.dtype, device=device)
torch.nn.init.normal_(new_local, mean=0.0, std=self.initializer_range)
new_dtensor = DTensor.from_local(
new_local,
device_mesh=old_param.data.device_mesh,
placements=old_param.data.placements,
)
setattr(self, attr_name, nn.Parameter(new_dtensor))
else:
new_data = torch.empty_like(old_param, device=device)
torch.nn.init.normal_(new_data, mean=0.0, std=self.initializer_range)
setattr(self, attr_name, nn.Parameter(new_data))

# Re-sync views to point to the new stacked parameter
self._sync_expert_views()
Expand All @@ -304,13 +317,15 @@ def _sync_expert_views(self) -> None:
gate_up_w = self.experts_gate_up_weight
if isinstance(gate_up_w, DTensor):
gate_up_w = gate_up_w.to_local()
for i in range(self.num_local_experts):
num_local = gate_up_w.shape[0]
for i in range(num_local):
object.__setattr__(self.experts_gate_up, f"weight{i}", gate_up_w[i])

down_w = self.experts_down_weight
if isinstance(down_w, DTensor):
down_w = down_w.to_local()
for i in range(self.num_local_experts):
num_local_down = down_w.shape[0]
for i in range(num_local_down):
object.__setattr__(self.experts_down, f"weight{i}", down_w[i])

def set_ep_group(self, ep_group: dist.ProcessGroup, ep_mesh: DeviceMesh) -> None:
Expand Down Expand Up @@ -871,6 +886,8 @@ def _unpad_input(hidden_states, attention_mask, unused_mask=None):
class HFInferenceParams(InferenceParams):
"""Extension of the InferenceParams class to support HF generate() and beam search."""

# Required by transformers >= 5.4 _valid_auto_compile_criteria(); this
# custom TE-based cache is not compatible with torch.compile generate().
is_compileable = False

def get_seq_length(self, layer_idx: int = 0) -> int:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Shared test utilities for distributed (EP/FSDP) tests."""

import os
import sys
from dataclasses import dataclass, field
from pathlib import Path

import torch


sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from modeling_mixtral_te import NVMixtralConfig


def create_small_mixtral_config(**overrides) -> NVMixtralConfig:
    """Create a small Mixtral config suitable for testing.

    Any keyword argument overrides the corresponding small-model default
    below; unrecognized keys are passed straight through to
    ``NVMixtralConfig``.
    """
    # Tiny dimensions keep test-time memory and runtime negligible.
    params = dict(
        hidden_size=128,
        intermediate_size=256,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=2,
        num_local_experts=4,
        num_experts_per_tok=2,
        max_position_embeddings=128,
        vocab_size=1000,
        attn_input_format="bshd",
        self_attn_mask_type="causal",
        router_jitter_noise=0.0,  # deterministic routing for reproducible tests
    )
    params.update(overrides)
    return NVMixtralConfig(**params)


def get_dummy_batch(vocab_size: int, seq_len: int = 32, batch_size: int = 2, device: str = "cuda"):
    """Create a simple dummy batch for testing.

    The RNG is reseeded with a fixed value (42) on every call, so repeated
    calls with the same arguments produce identical token ids. The attention
    mask is all ones (no padding) and the labels are a copy of the inputs.
    """
    torch.manual_seed(42)  # deterministic contents across calls
    shape = (batch_size, seq_len)
    tokens = torch.randint(0, vocab_size, shape, device=device)
    return {
        "input_ids": tokens,
        "attention_mask": torch.ones_like(tokens),
        "labels": tokens.clone(),
    }


def _env_default(var: str, fallback: str) -> str:
    """Return the value of environment variable *var*, writing *fallback* into
    the environment first if the variable is not already set."""
    return os.environ.setdefault(var, fallback)


@dataclass(frozen=True)
class DistributedConfig:
    """Distributed environment configuration.

    Field defaults are read from the standard torchrun environment variables
    (RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT). Because the
    defaults use ``os.environ.setdefault``, constructing this object also
    ensures those variables exist for a later process-group init.
    """

    rank: int = field(default_factory=lambda: int(_env_default("RANK", "0")))
    local_rank: int = field(default_factory=lambda: int(_env_default("LOCAL_RANK", "0")))
    world_size: int = field(default_factory=lambda: int(_env_default("WORLD_SIZE", "1")))
    _master_addr: str = field(default_factory=lambda: _env_default("MASTER_ADDR", "localhost"))
    _master_port: str = field(default_factory=lambda: _env_default("MASTER_PORT", "12355"))

    def is_main_process(self) -> bool:
        """Return True if this is the global rank 0 process."""
        return self.rank == 0
Loading
Loading