100 changes: 100 additions & 0 deletions examples/ascend_examples/qwen3_4B_dpo_megatron.yaml
@@ -0,0 +1,100 @@
defaults:
  - ../config/deepspeed_zero@_here_
  - ../config/deepspeed_zero2@_here_
  - ../config/deepspeed_zero3@_here_
  - ../config/deepspeed_zero3_cpuoffload@_here_

hydra:
  run:
    dir: .
  output_subdir: null

exp_name: "qwen3-4B-dpo-config"
seed: 42
logging_dir: ./output/logs
output_dir: ./output
system_envs:
  USE_MODELSCOPE: '1'

checkpoint_config:
  type: file_system
  output_dir: ./ckpt


track_name: None


max_steps: 500
save_steps: 500
logging_steps: 1
eval_steps: 100
resume_from_checkpoint: false

sequence_length: 512
train_batch_size: 64
val_batch_size: 64

# local_rank: -1
num_nodes: 1
num_gpus_per_node: 4

pretrain: Qwen/Qwen3-4B

ipo: false
beta: 0.1
label_smoothing: 0.0

chosen_key: chosen
rejected_key: rejected

validation:
  data_args:
    template: qwen3
    file_name: data/comparison_gpt4_data_zh.json

actor_train:
  model_args:
    disable_gradient_checkpointing: false
    dtype: bf16
    model_type: ~
  training_args:
    lr_scheduler_type: constant
    learning_rate: 1.0e-6
    weight_decay: 0
    per_device_train_batch_size: 16
    gradient_accumulation_steps: 1
    warmup_steps: 20
    num_train_epochs: 10
  data_args:
    template: qwen3
    file_name:
      - data/comparison_gpt4_data_zh.json
    dataset_dir: data
    preprocessing_num_workers: 1
  strategy_args:
    strategy_name: megatron_train
    strategy_config:
      tensor_model_parallel_size: 1
      pipeline_model_parallel_size: 1
      expert_model_parallel_size: 1
      use_distributed_optimizer: true
      recompute_granularity: full
  device_mapping: list(range(0,2))
  infer_batch_size: 16


reference:
  model_args:
    disable_gradient_checkpointing: true
    dtype: bf16
    model_type: ~
  data_args:
    template: qwen3
  strategy_args:
    strategy_name: megatron_infer
    strategy_config:
      tensor_model_parallel_size: 1
      pipeline_model_parallel_size: 1
      expert_model_parallel_size: 1
  device_mapping: list(range(2,4))
  infer_batch_size: 16
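Note on placement: num_gpus_per_node is 4 and the two roles use disjoint device ranges. The device_mapping values are Python-style expressions; a small sketch of how they resolve (the assumption that the framework evaluates these strings as Python is inferred from their syntax, not confirmed by this config):

actor_train_devices = eval("list(range(0,2))")   # -> [0, 1]
reference_devices = eval("list(range(2,4))")     # -> [2, 3]

# Together the two roles cover the 4 devices declared by num_gpus_per_node,
# so DPO training and reference-model inference run on disjoint NPUs.
assert sorted(actor_train_devices + reference_devices) == [0, 1, 2, 3]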
5 changes: 5 additions & 0 deletions examples/ascend_examples/run_dpo_pipeline.sh
@@ -0,0 +1,5 @@
#!/bin/bash
set +x

CONFIG_PATH=$(basename $(dirname $0))
python examples/start_dpo_pipeline.py --config_path $CONFIG_PATH --config_name qwen3_4B_dpo_megatron
44 changes: 37 additions & 7 deletions mcore_adapter/src/mcore_adapter/initialize.py
@@ -13,23 +13,53 @@
logger = get_logger(__name__)


def _patch_megatron_for_npu():
    if not current_platform.is_npu():
        return

    import torch_npu  # noqa: F401

    import megatron.core.tensor_parallel.random as meg_random

    if not hasattr(meg_random, "_npu_patched"):
        meg_random.initialize_rng_tracker()

        def patched_set(new_state, device=-1, graph_safe=False):
            torch.npu.set_rng_state(new_state)
            return

        def patched_get(device="npu", clone=False, graph_safe=False):
            return torch.npu.get_rng_state()

        meg_random._set_cuda_rng_state = patched_set
        meg_random._get_cuda_rng_state = patched_get

        rng_state = torch.npu.get_rng_state()
        meg_random._CUDA_RNG_STATE_TRACKER.states_["model-parallel-rng"] = rng_state
        meg_random._CUDA_RNG_STATE_TRACKER.states_["data-parallel-rng"] = rng_state

        meg_random._npu_patched = True

    if not hasattr(torch.cuda, "_npu_patched"):
        _original_cuda_current_device = torch.cuda.current_device
        torch.cuda.current_device = lambda: torch.npu.current_device()
        torch.cuda._npu_patched = True


def is_distribute_initialized():
    return mpu.model_parallel_is_initialized()


def _set_random_seed(seed_):
    """Set random seed for reproducability."""
    if seed_ is not None and seed_ > 0:
        seed = seed_  # TuningFactory dataloader requires seed be the same for all ranks
-        # # Ensure that different pipeline MP stages get different seeds.
-        # seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank())
-        # # Ensure different data parallel ranks get different seeds
-        # if data_parallel_random_init:
-        #     seed = seed + (10 * mpu.get_data_parallel_rank())
-        seed = seed_
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
-        if current_platform.device_count() > 0:
        if current_platform.is_npu():
            _patch_megatron_for_npu()
        elif current_platform.is_cuda() and current_platform.device_count() > 0:
            tensor_parallel.model_parallel_cuda_manual_seed(seed)
    else:
        raise ValueError("Seed ({}) should be a positive integer.".format(seed))
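The _npu_patched sentinel above is an idempotent monkey-patching pattern: patch once, remember that you did, and keep a handle to the original. A minimal standalone sketch with a dummy module (illustrative names only, not the mcore_adapter or Megatron API):

import types

fake_rng_module = types.SimpleNamespace(get_state=lambda: "cuda-state")

def patch_once(mod):
    if getattr(mod, "_npu_patched", False):
        return  # already patched; repeated calls are no-ops
    mod._original_get_state = mod.get_state  # keep a reference to the original
    mod.get_state = lambda: "npu-state"
    mod._npu_patched = True

patch_once(fake_rng_module)
patch_once(fake_rng_module)          # second call does nothing
print(fake_rng_module.get_state())   # -> "npu-state"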
4 changes: 4 additions & 0 deletions mcore_adapter/src/mcore_adapter/models/model_config.py
@@ -17,6 +17,7 @@

from ..constants import HUGGINGFACE_AUTOMAP_CACHE, MCA_CONFIG_NAME
from ..initialize import initialize_megatron
from ..platforms import current_platform
from ..training_args import DistributingParallelArguments, TrainingArguments
from ..utils import get_logger
from .converter.template import get_template
@@ -297,6 +298,9 @@ class McaModelConfig(TransformerConfig, PretrainedConfig):
    )

    def __post_init__(self):
        if current_platform.is_npu() and self.transformer_impl == "transformer_engine":
            self.transformer_impl = "local"

        if self.virtual_pipeline_model_parallel_size is None and self.overlap_p2p_comm:
            self.overlap_p2p_comm = False
            logger.warning("Non-interleaved pipeline parallelism does not support overlapping p2p communication!")
46 changes: 40 additions & 6 deletions mcore_adapter/src/mcore_adapter/models/model_factory.py
@@ -44,6 +44,17 @@
logger = get_logger(__name__)


def _replace_with_rmsnorm(submodules, attr_name):
    if hasattr(submodules, attr_name):
        norm = getattr(submodules, attr_name)
        if isinstance(norm, type):
            norm_name = norm.__name__
        else:
            norm_name = norm.__class__.__name__
        if norm_name in ("TENorm", "FusedLayerNorm") or not norm_name.endswith("RMSNorm"):
            setattr(submodules, attr_name, RMSNorm)


class VirtualModels:
    # a wrapper for model list to support virtual pipeline model parallel
    def __init__(self, cls, config: "McaModelConfig", *args, **kwargs):
@@ -369,8 +380,13 @@ def _get_transformer_layer_spec(self, config: Optional["McaModelConfig"] = None)
                transformer_block_spec.layer_norm = RMSNorm
            for transformer_layer_spec in transformer_block_spec.layer_specs:
                if not use_te and config.normalization == "RMSNorm":
-                    transformer_layer_spec.submodules.input_layernorm = RMSNorm
-                    transformer_layer_spec.submodules.pre_mlp_layernorm = RMSNorm
                    input_layernorm = transformer_layer_spec.submodules.input_layernorm
                    if current_platform.is_npu() and not input_layernorm.__class__.__name__.endswith("RMSNorm"):
                        transformer_layer_spec.submodules.input_layernorm = RMSNorm
                        transformer_layer_spec.submodules.pre_mlp_layernorm = RMSNorm
                    elif not current_platform.is_npu():
                        transformer_layer_spec.submodules.input_layernorm = RMSNorm
                        transformer_layer_spec.submodules.pre_mlp_layernorm = RMSNorm
                if getattr(transformer_layer_spec.submodules.mlp.submodules, "shared_experts", None):
                    transformer_layer_spec.submodules.mlp.submodules.shared_experts.params["gate"] = (
                        config.moe_use_shared_expert_gate
@@ -381,10 +397,28 @@ def _get_transformer_layer_spec(self, config: Optional["McaModelConfig"] = None)
                config.num_moe_experts, config.moe_grouped_gemm, qk_layernorm=config.qk_layernorm
            )
        else:
-            module_spec = get_gpt_layer_local_spec(
-                config.num_moe_experts, config.moe_grouped_gemm, qk_layernorm=config.qk_layernorm
-            )
-            if config.normalization == "RMSNorm":
            if current_platform.is_npu():
                module_spec = get_gpt_layer_local_spec(
                    config.num_moe_experts,
                    config.moe_grouped_gemm,
                    qk_layernorm=config.qk_layernorm,
                    normalization=config.normalization,
                )
            else:
                module_spec = get_gpt_layer_local_spec(
                    config.num_moe_experts,
                    config.moe_grouped_gemm,
                    qk_layernorm=config.qk_layernorm,
                )
            if not use_te and config.normalization == "RMSNorm" and current_platform.is_npu():
                module_spec.layer_norm = RMSNorm
                _replace_with_rmsnorm(module_spec.submodules, "input_layernorm")
                _replace_with_rmsnorm(module_spec.submodules, "pre_mlp_layernorm")
                self_attn = module_spec.submodules.self_attention
                if hasattr(self_attn, "submodules"):
                    _replace_with_rmsnorm(self_attn.submodules, "q_layernorm")
                    _replace_with_rmsnorm(self_attn.submodules, "k_layernorm")
            elif not use_te and config.normalization == "RMSNorm":
                module_spec.submodules.input_layernorm = RMSNorm
                module_spec.submodules.pre_mlp_layernorm = RMSNorm
        return module_spec
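A standalone illustration of the _replace_with_rmsnorm predicate above, using throwaway stand-in classes rather than the real Megatron spec objects:

from types import SimpleNamespace

class TENorm: ...
class FusedRMSNorm: ...
class RMSNorm: ...

def _replace_with_rmsnorm(submodules, attr_name):
    if hasattr(submodules, attr_name):
        norm = getattr(submodules, attr_name)
        norm_name = norm.__name__ if isinstance(norm, type) else norm.__class__.__name__
        if norm_name in ("TENorm", "FusedLayerNorm") or not norm_name.endswith("RMSNorm"):
            setattr(submodules, attr_name, RMSNorm)

spec = SimpleNamespace(input_layernorm=TENorm, pre_mlp_layernorm=FusedRMSNorm)
_replace_with_rmsnorm(spec, "input_layernorm")    # TENorm is swapped for RMSNorm
_replace_with_rmsnorm(spec, "pre_mlp_layernorm")  # name already ends with "RMSNorm", kept
print(spec.input_layernorm is RMSNorm, spec.pre_mlp_layernorm is FusedRMSNorm)  # True True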
7 changes: 7 additions & 0 deletions mcore_adapter/src/mcore_adapter/training_args.py
@@ -2,6 +2,13 @@
from dataclasses import dataclass, field, fields
from typing import Literal, Optional, Union

try:
    # NPU patch
    import flashinfer
    import mindspeed.megatron_adaptor
except ImportError:
    pass

from megatron.core.transformer.pipeline_parallel_layer_layout import PipelineParallelLayerLayout
from transformers import Seq2SeqTrainingArguments as HFSeq2SeqTrainingArguments
from transformers import TrainingArguments as HFTrainingArguments
25 changes: 25 additions & 0 deletions roll/distributed/strategy/megatron_strategy.py
@@ -422,6 +422,13 @@ def inner_forward_step(self, loss_func, data_iterator: Iterator[DataProto], mode
        else:
            input_ids = self._get_feature_on_this_cp_rank(input_ids, "input_ids")
            attention_mask = self._get_feature_on_this_cp_rank(attention_mask, "attention_mask")

        if hasattr(torch, "npu") and torch.npu.is_available() and attention_mask is not None:
            attention_mask = attention_mask.bool()
            B, S = attention_mask.shape
            attention_mask = attention_mask[:, None, None, :]  # [B,1,1,S]
            attention_mask = attention_mask.expand(B, 1, S, S)  # [B,1,S,S]

        if labels is not None:
            labels = self._get_feature_on_this_cp_rank(labels, "labels")
        position_ids = None
@@ -1132,6 +1139,24 @@ def train_step(self, batch: DataProto, loss_func: Callable):
            # Optimizer states only need to be loaded when stepping the optimizer
            self.load_states(include=[OffloadStateType.optimizer_states])

        # Ensure FP32 main params are on the correct device (they may be on CPU after
        # the offload/reload cycle on certain platforms like Ascend NPU).
        if current_platform.is_npu():
            optimizers = (
                self.optimizer.chained_optimizers
                if hasattr(self.optimizer, 'chained_optimizers')
                else [self.optimizer]
            )
            expected_device = torch.device(
                f'{current_platform.device_type}:{current_platform.current_device()}'
            )
            for opt in optimizers:
                for param_groups_attr in ('shard_fp32_from_float16_groups', 'shard_fp32_groups'):
                    for group in getattr(opt, param_groups_attr, []):
                        for param in group:
                            if param.device != expected_device:
                                param.data = param.data.to(expected_device, non_blocking=False)

        update_successful, grad_norm, num_zeros_in_grad = self.optimizer.step()
        if is_offload_optimizer_states_in_train_step:
            self.offload_states(include=[OffloadStateType.optimizer_states], non_blocking=True)
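For context on the mask handling added to inner_forward_step above, a standalone sketch of the same reshaping in plain PyTorch (the [B, S] padding-mask layout with 1 = real token is an assumption inferred from the surrounding code, not stated in this diff):

import torch

# Hypothetical [B, S] padding mask: 1 = real token, 0 = padding (assumed layout).
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])

mask = attention_mask.bool()
B, S = mask.shape
mask = mask[:, None, None, :]   # [B, 1, 1, S]
mask = mask.expand(B, 1, S, S)  # [B, 1, S, S]: every query position shares the same key mask

print(mask.shape)   # torch.Size([2, 1, 4, 4])
print(mask[0, 0])   # the padded key position (last column) is False in every row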
5 changes: 2 additions & 3 deletions roll/pipeline/base_worker.py
@@ -451,8 +451,7 @@ async def offload_states_partial(self, target_dp_ranks: List[int]):

# Verify offloaded workers have near-zero GPU memory usage
if self.rank_info.dp_rank in target_dp_ranks:
import torch
gpu_memory_gb = torch.cuda.memory_allocated() / 1024**3
gpu_memory_gb = current_platform.memory_allocated() / 1024**3
if gpu_memory_gb > 1.0:
raise RuntimeError(
f"GPU memory not properly offloaded for Worker {self.rank} (DP {self.rank_info.dp_rank}): "
@@ -501,7 +500,7 @@ async def generate(self, data: DataProto):
        global_step = data.meta_info.get("global_step", 0)
        self.logger.info(f"{self.worker_name} generate global step {global_step}")

-        data = data.to("cuda")
        data = data.to(current_platform.device_type)
        data.meta_info["micro_batch_size"] = self.worker_config.infer_batch_size

        output = await self.strategy.generate(batch=data, generation_config=generation_config)
21 changes: 13 additions & 8 deletions roll/platforms/__init__.py
@@ -25,26 +25,31 @@ def _init_platform() -> Platform:
    Returns:
        An instance of a subclass of Platform corresponding to the detected hardware.
    """
    try:
        import torch_npu  # noqa: F401

        if hasattr(torch, "npu") and torch.npu.is_available():
            logger.debug("Detected NPU (torch_npu). Initializing NPU platform.")
            [Collaborator review comment on the line above: "plz change log_level to info"]
            return NpuPlatform()
    except ImportError:
        pass

    if torch.cuda.is_available():
        device_name = torch.cuda.get_device_name().upper()
        logger.debug(f"Detected CUDA device: {device_name}")

        if "NVIDIA" in device_name:
            logger.debug("Initializing CUDA platform (NVIDIA).")
            return CudaPlatform()
        elif "AMD" in device_name:
            logger.debug("Initializing ROCm platform (AMD).")
            return RocmPlatform()

        logger.warning("Unrecognized CUDA device. Falling back to UnknownPlatform.")
        return UnknownPlatform()
-    else:
-        try:
-            import torch_npu  # noqa: F401
-
-            logger.debug("Detected torch_npu. Initializing NPU platform.")
-            return NpuPlatform()
-        except ImportError:
-            logger.debug("No supported accelerator detected. Initializing CPU platform.")
-            return CpuPlatform()
    logger.debug("No supported accelerator detected. Initializing CPU platform.")
    return CpuPlatform()


# Global singleton representing the current platform in use.
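The module exposes current_platform as a singleton; the calls below are the ones other files in this PR rely on (an illustrative sketch — the exact method set on each Platform subclass is assumed from that usage, not verified here):

from roll.platforms import current_platform

if current_platform.is_npu():
    device = f"{current_platform.device_type}:{current_platform.current_device()}"
    print("running on", device)
print("allocated GiB:", current_platform.memory_allocated() / 1024**3)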
10 changes: 8 additions & 2 deletions roll/third_party/megatron/optimizer.py
@@ -69,9 +69,15 @@ def get_megatron_optimizer(
    optimizers = []
    model_chunk_offset = 0
    kwargs = {}
-    if "config_overrides" in inspect.signature(_get_param_groups_and_buffers).parameters:
    _param_groups_sig = inspect.signature(_get_param_groups_and_buffers).parameters
    if "config_overrides" in _param_groups_sig:
        # config_overrides is required in mcore-core>=0.16
-        kwargs = {"config_overrides": None}
        kwargs["config_overrides"] = None
    if "no_weight_decay_cond" in _param_groups_sig:
        # no_weight_decay_cond, scale_lr_cond, lr_mult are required in newer mcore versions
        kwargs["no_weight_decay_cond"] = no_weight_decay_cond
        kwargs["scale_lr_cond"] = scale_lr_cond
        kwargs["lr_mult"] = lr_mult
    for dense_model_chunks, overlap_param_gather_with_optimizer_step in zip(
        all_dense_model_chunks, overlap_param_gather_with_optimizer_step_flags
    ):
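The signature probing above is a general pattern for staying compatible across megatron-core versions: only pass keyword arguments that the installed version's function actually declares. A self-contained sketch against a stand-in function (not the real _get_param_groups_and_buffers):

import inspect

def _stand_in(model_chunks, no_weight_decay_cond=None, lr_mult=1.0):
    return model_chunks, no_weight_decay_cond, lr_mult

params = inspect.signature(_stand_in).parameters
kwargs = {}
if "config_overrides" in params:      # not declared by the stand-in, so skipped
    kwargs["config_overrides"] = None
if "no_weight_decay_cond" in params:  # declared, so forwarded
    kwargs["no_weight_decay_cond"] = lambda name, p: name.endswith(".bias")

print(_stand_in(["chunk0"], **kwargs))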