From 6ddb3f792e906e6a53d40f081ee4061111e2535f Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Fri, 10 Apr 2026 00:06:38 -0700 Subject: [PATCH 1/2] Revert "[model] fix: Use adaptive transpose_on_export for GPT-OSS expert weights (#3250)" This reverts commit b139a4bef97638b235596a3962dcd3ea4290319c. Signed-off-by: Yuki Huang --- .../hf_megatron_roundtrip_multi_gpu.py | 7 +---- .../bridge/models/conversion/auto_bridge.py | 6 ----- .../bridge/models/conversion/utils.py | 8 +----- src/megatron/bridge/utils/common_utils.py | 27 ------------------- 4 files changed, 2 insertions(+), 46 deletions(-) diff --git a/examples/conversion/hf_megatron_roundtrip_multi_gpu.py b/examples/conversion/hf_megatron_roundtrip_multi_gpu.py index fdd4f3df3f..80282daf53 100644 --- a/examples/conversion/hf_megatron_roundtrip_multi_gpu.py +++ b/examples/conversion/hf_megatron_roundtrip_multi_gpu.py @@ -53,7 +53,6 @@ from megatron.bridge import AutoBridge from megatron.bridge.models.decorators import torchrun_main from megatron.bridge.models.hf_pretrained.utils import is_safe_repo -from megatron.bridge.utils.common_utils import fix_gpt_oss_export_transpose, get_hf_model_type HF_MODEL_ID = "meta-llama/Llama-3.2-1B" @@ -189,11 +188,7 @@ def main( all_match = True fp8_skip_count = 0 fp8_skip_samples: list[str] = [] - # TODO: Remove fix_gpt_oss_export_transpose once GPT-OSS bridge export is fixed. - weight_iter = bridge.export_hf_weights(megatron_model, show_progress=False) - if get_hf_model_type(bridge) == "gpt_oss": - weight_iter = fix_gpt_oss_export_transpose(weight_iter) - for name, param in weight_iter: + for name, param in bridge.export_hf_weights(megatron_model, show_progress=False): if is_rank_0: original_param = bridge.hf_pretrained.state[name] compare_param = param diff --git a/src/megatron/bridge/models/conversion/auto_bridge.py b/src/megatron/bridge/models/conversion/auto_bridge.py index fde1a47de1..0c61b75d47 100644 --- a/src/megatron/bridge/models/conversion/auto_bridge.py +++ b/src/megatron/bridge/models/conversion/auto_bridge.py @@ -822,12 +822,6 @@ def _filter_quant(gen): generator = _filter_quant(generator) - # TODO: Remove once GPT-OSS bridge export no longer transposes per-expert weights. - from megatron.bridge.utils.common_utils import fix_gpt_oss_export_transpose, get_hf_model_type - - if get_hf_model_type(self) == "gpt_oss": - generator = fix_gpt_oss_export_transpose(generator) - # Check if the state source is SafeTensorsStateSource for streaming save. if ( hasattr(self.hf_pretrained, "state") diff --git a/src/megatron/bridge/models/conversion/utils.py b/src/megatron/bridge/models/conversion/utils.py index c7f4efd7d8..6e741bab0d 100644 --- a/src/megatron/bridge/models/conversion/utils.py +++ b/src/megatron/bridge/models/conversion/utils.py @@ -43,14 +43,8 @@ def weights_verification_table(bridge, megatron_model) -> Table: table.add_column("Device") table.add_column("Matches Original", justify="center") - # TODO: Remove fix_gpt_oss_export_transpose once GPT-OSS bridge export is fixed. - from megatron.bridge.utils.common_utils import fix_gpt_oss_export_transpose, get_hf_model_type - - weight_iter = bridge.export_hf_weights(megatron_model, show_progress=True) - if get_hf_model_type(bridge) == "gpt_oss": - weight_iter = fix_gpt_oss_export_transpose(weight_iter) # Check each weight against the original HF-model - for name, param in weight_iter: + for name, param in bridge.export_hf_weights(megatron_model, show_progress=True): original_param = bridge.hf_pretrained.state[name] table.add_row( name, diff --git a/src/megatron/bridge/utils/common_utils.py b/src/megatron/bridge/utils/common_utils.py index 77d3456cb4..35a3d507bd 100644 --- a/src/megatron/bridge/utils/common_utils.py +++ b/src/megatron/bridge/utils/common_utils.py @@ -251,33 +251,6 @@ def _wrapped(self, name, value): return module -def get_hf_model_type(bridge) -> str | None: - """Extract the HF ``model_type`` string from a bridge instance. - - Works with both ``AutoBridge`` (reads ``bridge.hf_pretrained.config.model_type``) - and registered bridge subclasses (falls back to the ``MODEL_TYPE`` class attribute). - """ - hf_config = getattr(getattr(bridge, "hf_pretrained", None), "config", None) - return getattr(hf_config, "model_type", None) or getattr(bridge, "MODEL_TYPE", None) - - -# TODO: Remove once GPT-OSS bridge export no longer transposes per-expert weights. -_GPT_OSS_TRANSPOSED_SUFFIXES = ("mlp.experts.down_proj",) - - -def fix_gpt_oss_export_transpose(gen): - """Wrap a weight generator to undo GPT-OSS per-expert transpose on export. - - The GPT-OSS bridge transposes down_proj expert weights in ``megatron_to_hf`` - for vLLM compatibility. This wrapper transposes them back so saved - checkpoints match the original HF layout. - """ - for name, tensor in gen: - if name.endswith(_GPT_OSS_TRANSPOSED_SUFFIXES): - tensor = tensor.transpose(-2, -1).contiguous() - yield name, tensor - - def extract_expert_number_from_param(param_name: str) -> int: """Extract the expert number from a parameter name. Args: From c4d5effa3c178543179982f7087e47a41b06677b Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Fri, 10 Apr 2026 06:56:45 -0700 Subject: [PATCH 2/2] remove transpose in down_proj Signed-off-by: Yuki Huang --- src/megatron/bridge/models/gpt_oss/gpt_oss_bridge.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/src/megatron/bridge/models/gpt_oss/gpt_oss_bridge.py b/src/megatron/bridge/models/gpt_oss/gpt_oss_bridge.py index 40efb39a55..0039dca01a 100644 --- a/src/megatron/bridge/models/gpt_oss/gpt_oss_bridge.py +++ b/src/megatron/bridge/models/gpt_oss/gpt_oss_bridge.py @@ -106,9 +106,7 @@ def maybe_modify_loaded_hf_weight( ) -> torch.Tensor: """Load weights from HuggingFace state dict with MXFP4 dequantization support. - down_proj transpose on export is handled in GPTOSSMLPDownProjMapping.megatron_to_hf, - which transposes the per-expert weight from Megatron's [in, out] storage to - HF's expected [out, in] layout. + down_proj is handled in GPTOSSMLPDownProjMapping. gate_up_proj is handled directly in GPTOSSMLPGateUpProjMapping.hf_to_megatron via _align_expert_weight_to_shape, which auto-detects the orientation difference between @@ -212,11 +210,7 @@ def mapping_registry(self) -> MegatronMappingRegistry: class GPTOSSMLPDownProjMapping(AutoMapping): - """MLPDownProj for expert weights in GPT-OSS models. - - megatron_to_hf transposes the per-expert weight from Megatron's [in, out] - storage to HF's expected [out, in] layout. - """ + """MLPDownProj for expert weights in GPT-OSS models.""" is_grouped_export = True @@ -235,8 +229,6 @@ def hf_to_megatron(self, hf_weights: torch.Tensor, megatron_module: nn.Module) - def megatron_to_hf(self, megatron_weights: torch.Tensor, megatron_module: nn.Module) -> Dict[str, torch.Tensor]: if megatron_weights is None: return super().megatron_to_hf(megatron_weights, megatron_module) - if len(megatron_weights.shape) == 2: - megatron_weights = megatron_weights.transpose(0, 1) return super().megatron_to_hf(megatron_weights.contiguous(), megatron_module)