Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions examples/conversion/hf_megatron_roundtrip_multi_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
from megatron.bridge import AutoBridge
from megatron.bridge.models.decorators import torchrun_main
from megatron.bridge.models.hf_pretrained.utils import is_safe_repo
from megatron.bridge.utils.common_utils import fix_gpt_oss_export_transpose, get_hf_model_type


HF_MODEL_ID = "meta-llama/Llama-3.2-1B"
Expand Down Expand Up @@ -189,11 +188,7 @@ def main(
all_match = True
fp8_skip_count = 0
fp8_skip_samples: list[str] = []
# TODO: Remove fix_gpt_oss_export_transpose once GPT-OSS bridge export is fixed.
weight_iter = bridge.export_hf_weights(megatron_model, show_progress=False)
if get_hf_model_type(bridge) == "gpt_oss":
weight_iter = fix_gpt_oss_export_transpose(weight_iter)
for name, param in weight_iter:
for name, param in bridge.export_hf_weights(megatron_model, show_progress=False):
if is_rank_0:
original_param = bridge.hf_pretrained.state[name]
compare_param = param
Expand Down
6 changes: 0 additions & 6 deletions src/megatron/bridge/models/conversion/auto_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,12 +822,6 @@ def _filter_quant(gen):

generator = _filter_quant(generator)

# TODO: Remove once GPT-OSS bridge export no longer transposes per-expert weights.
from megatron.bridge.utils.common_utils import fix_gpt_oss_export_transpose, get_hf_model_type

if get_hf_model_type(self) == "gpt_oss":
generator = fix_gpt_oss_export_transpose(generator)

# Check if the state source is SafeTensorsStateSource for streaming save.
if (
hasattr(self.hf_pretrained, "state")
Expand Down
8 changes: 1 addition & 7 deletions src/megatron/bridge/models/conversion/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,8 @@ def weights_verification_table(bridge, megatron_model) -> Table:
table.add_column("Device")
table.add_column("Matches Original", justify="center")

# TODO: Remove fix_gpt_oss_export_transpose once GPT-OSS bridge export is fixed.
from megatron.bridge.utils.common_utils import fix_gpt_oss_export_transpose, get_hf_model_type

weight_iter = bridge.export_hf_weights(megatron_model, show_progress=True)
if get_hf_model_type(bridge) == "gpt_oss":
weight_iter = fix_gpt_oss_export_transpose(weight_iter)
# Check each weight against the original HF-model
for name, param in weight_iter:
for name, param in bridge.export_hf_weights(megatron_model, show_progress=True):
original_param = bridge.hf_pretrained.state[name]
table.add_row(
name,
Expand Down
12 changes: 2 additions & 10 deletions src/megatron/bridge/models/gpt_oss/gpt_oss_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,7 @@ def maybe_modify_loaded_hf_weight(
) -> torch.Tensor:
"""Load weights from HuggingFace state dict with MXFP4 dequantization support.

down_proj transpose on export is handled in GPTOSSMLPDownProjMapping.megatron_to_hf,
which transposes the per-expert weight from Megatron's [in, out] storage to
HF's expected [out, in] layout.
down_proj is handled in GPTOSSMLPDownProjMapping.

gate_up_proj is handled directly in GPTOSSMLPGateUpProjMapping.hf_to_megatron via
_align_expert_weight_to_shape, which auto-detects the orientation difference between
Expand Down Expand Up @@ -212,11 +210,7 @@ def mapping_registry(self) -> MegatronMappingRegistry:


class GPTOSSMLPDownProjMapping(AutoMapping):
"""MLPDownProj for expert weights in GPT-OSS models.

megatron_to_hf transposes the per-expert weight from Megatron's [in, out]
storage to HF's expected [out, in] layout.
"""
"""MLPDownProj for expert weights in GPT-OSS models."""

is_grouped_export = True

Expand All @@ -235,8 +229,6 @@ def hf_to_megatron(self, hf_weights: torch.Tensor, megatron_module: nn.Module) -
def megatron_to_hf(self, megatron_weights: torch.Tensor, megatron_module: nn.Module) -> Dict[str, torch.Tensor]:
if megatron_weights is None:
return super().megatron_to_hf(megatron_weights, megatron_module)
if len(megatron_weights.shape) == 2:
megatron_weights = megatron_weights.transpose(0, 1)
return super().megatron_to_hf(megatron_weights.contiguous(), megatron_module)


Expand Down
27 changes: 0 additions & 27 deletions src/megatron/bridge/utils/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,33 +251,6 @@ def _wrapped(self, name, value):
return module


def get_hf_model_type(bridge) -> str | None:
"""Extract the HF ``model_type`` string from a bridge instance.

Works with both ``AutoBridge`` (reads ``bridge.hf_pretrained.config.model_type``)
and registered bridge subclasses (falls back to the ``MODEL_TYPE`` class attribute).
"""
hf_config = getattr(getattr(bridge, "hf_pretrained", None), "config", None)
return getattr(hf_config, "model_type", None) or getattr(bridge, "MODEL_TYPE", None)


# TODO: Remove once GPT-OSS bridge export no longer transposes per-expert weights.
_GPT_OSS_TRANSPOSED_SUFFIXES = ("mlp.experts.down_proj",)


def fix_gpt_oss_export_transpose(gen):
"""Wrap a weight generator to undo GPT-OSS per-expert transpose on export.

The GPT-OSS bridge transposes down_proj expert weights in ``megatron_to_hf``
for vLLM compatibility. This wrapper transposes them back so saved
checkpoints match the original HF layout.
"""
for name, tensor in gen:
if name.endswith(_GPT_OSS_TRANSPOSED_SUFFIXES):
tensor = tensor.transpose(-2, -1).contiguous()
yield name, tensor


def extract_expert_number_from_param(param_name: str) -> int:
"""Extract the expert number from a parameter name.
Args:
Expand Down
Loading