diff --git a/modelopt/torch/export/moe_utils.py b/modelopt/torch/export/moe_utils.py index 8981d614843..e325e5346f1 100644 --- a/modelopt/torch/export/moe_utils.py +++ b/modelopt/torch/export/moe_utils.py @@ -110,11 +110,19 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None: and w_quantizer._amax.dim() >= 1 ): amax = w_quantizer._amax + # Per-block _amax (NVFP4 static) collapses the row axis we want + # to slice on; restore it so dim-0 slicing splits gate/up. + if amax.numel() != fused_total and amax.numel() % fused_total == 0: + amax = amax.contiguous().view(fused_total, amax.numel() // fused_total) amax_dim0 = amax.shape[0] if fused_total % amax_dim0 == 0: slice_start = fused_start * amax_dim0 // fused_total slice_end = (fused_start + weight_slice.shape[0]) * amax_dim0 // fused_total - w_quantizer.amax = amax[slice_start:slice_end].contiguous() + sliced = amax[slice_start:slice_end].contiguous() + # The amax setter refuses shape changes; drop _amax first. + if hasattr(w_quantizer, "_amax"): + delattr(w_quantizer, "_amax") + w_quantizer.amax = sliced else: warnings.warn( f"Expert {idx} {proj_name}: fused amax dim0 ({amax_dim0}) does not " diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index a58aa4c9895..73ae63a5a56 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -1134,6 +1134,19 @@ def _unpatch_revert_weight_conversion(patches: list[tuple[Any, Any]]) -> None: mod.revert_weight_conversion = original +def _sanitize_generation_config_for_save(model: torch.nn.Module) -> None: + """Force ``do_sample=True`` when generation_config has ``top_k``/``top_p`` set. + + Newer transformers reject ``do_sample=False`` mixed with sampling attrs in + ``save_pretrained``'s strict validate. + """ + gc = getattr(model, "generation_config", None) + if gc is None: + return + if getattr(gc, "top_k", None) is not None or getattr(gc, "top_p", None) is not None: + gc.do_sample = True + + def export_speculative_decoding( model: torch.nn.Module, dtype: torch.dtype | None = None, @@ -1228,6 +1241,8 @@ def export_hf_checkpoint( # modeling_utils does `from core_model_loading import revert_weight_conversion`. _patches = _patch_revert_weight_conversion() + _sanitize_generation_config_for_save(model) + try: model.save_pretrained( export_dir, diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index fe4c3f77ce6..1a80c89e6c9 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -52,7 +52,6 @@ promote_nvfp4_static_quantizers, quantizer_attr_names, reduce_amax, - weight_attr_names, ) from .utils.calib_utils import _GPTQ_HELPER_REGISTRY, GPTQHelper @@ -66,6 +65,127 @@ "svdquant", ] + +# Sibling groups that share an FP8 scale-of-scales: members feed the same input +# (Q/K/V) or get fused at deployment (gate/up), so divergent global_amax would +# split their FP8 grids. +_GROUPED_WEIGHT_QUANTIZER_PATTERNS: tuple[tuple[str, ...], ...] = ( + ("q_proj", "k_proj", "v_proj"), + ("gate_proj", "up_proj"), # Llama/Qwen/Mistral + ("w1", "w3"), # Mixtral +) + + +def _collect_grouped_linears(model: nn.Module) -> list[list[nn.Module]]: + """Collect sibling groups (Q/K/V, gate/up) with calibrated NVFP4-static weight quantizers.""" + groups: list[list[nn.Module]] = [] + wq_attr = quantizer_attr_names("weight").weight_quantizer + for parent in model.modules(): + for sibling_names in _GROUPED_WEIGHT_QUANTIZER_PATTERNS: + members = [] + for n in sibling_names: + child = getattr(parent, n, None) + wq = getattr(child, wq_attr, None) if child is not None else None + if ( + isinstance(wq, TensorQuantizer) + and not wq._disabled + and wq.is_nvfp4_static + and getattr(wq, "_amax", None) is not None + ): + members.append(child) + if len(members) >= 2: + groups.append(members) + return groups + + +@torch.no_grad() +def _bootstrap_uncalibrated_weight_quantizers(model: nn.Module) -> int: + """Populate ``_amax`` from weights for quantizers the forward pass didn't reach. + + Dead MoE experts that received no tokens are otherwise skipped by + ``mse_calibrate``, leaving export to derive separate per-half amax for + gate/up and break the gate==up ``weight_scale_2`` invariant. Weight access + runs inside ``enable_weight_access_and_writeback`` so FSDP / TP / offload + shards get gathered before calibration reads them. + """ + name_to_module = dict(model.named_modules()) + n = 0 + for module in name_to_module.values(): + if not isinstance(module, QuantModule): + continue + with enable_weight_access_and_writeback(module, model, name_to_module): + for weight, q in module.iter_weights_for_calibration(): + if not isinstance(q, TensorQuantizer) or q._disabled or q._dynamic: + continue + if q._calibrator is None: + continue + if getattr(q, "_amax", None) is not None and not torch.all(q._amax == 0): + continue + q.disable_quant() + q.enable_calib() + q(weight) + if q._calibrator.compute_amax() is not None: + q.load_calib_amax() + q.enable_quant() + q.disable_calib() + if hasattr(q._calibrator, "reset"): + q._calibrator.reset() + n += 1 + return n + + +@torch.no_grad() +def _sync_grouped_weight_global_amax(model: nn.Module) -> int: + """Unify NVFP4 ``global_amax`` across Q/K/V and gate/up sibling weight quantizers. + + Reuses ``preprocess_linear_fusion`` (which performs the same unification at + export time) to keep the FP8 scale-of-scales consistent across siblings + during MSE / local-Hessian search. Must run after ``max_calibrate``. + + Sibling discovery is name-based via ``_GROUPED_WEIGHT_QUANTIZER_PATTERNS`` + (q/k/v_proj, gate/up_proj, w1/w3) — models with non-matching attribute + names (e.g. ``wqkv``, fused ``qkv_proj``, DeepSeek variants) silently fall + back to per-module global_amax. A warning is emitted when the model has + NVFP4-static weight quantizers but no groups were matched. + """ + # Inline: quant_utils imports enable_stats_collection/finish_stats_collection/svd + # from this module, so top-level would deadlock the cycle. + from modelopt.torch.export.quant_utils import preprocess_linear_fusion + + wq_attr = quantizer_attr_names("weight").weight_quantizer + n_groups = 0 + for group in _collect_grouped_linears(model): + for child in group: + wq = getattr(child, wq_attr) + if not isinstance(wq, NVFP4StaticQuantizer): + NVFP4StaticQuantizer.from_tensor_quantizer( + wq, global_amax=reduce_amax(wq._amax, axis=None) + ) + preprocess_linear_fusion(group) + n_groups += 1 + + if n_groups == 0: + # Surface architectures whose Q/K/V or gate/up siblings don't match the + # pattern list — without this, sibling-sync is a silent no-op. + has_nvfp4_static = any( + isinstance(m, TensorQuantizer) + and not m._disabled + and m.is_nvfp4_static + and getattr(m, "_amax", None) is not None + for m in model.modules() + ) + if has_nvfp4_static: + warnings.warn( + "_sync_grouped_weight_global_amax found NVFP4-static weight quantizers " + "but no Q/K/V or gate/up sibling groups matching " + "_GROUPED_WEIGHT_QUANTIZER_PATTERNS. Per-block FP8 grids will not be " + "unified across siblings; if this model uses non-standard projection " + "names (e.g. wqkv, fused qkv_proj), extend the pattern list.", + stacklevel=2, + ) + return n_groups + + CalibratorFactory: TypeAlias = Callable[ [torch.Tensor, int | tuple | list | None, Callable[..., torch.Tensor]], _Calibrator ] @@ -346,32 +466,25 @@ def mse_calibrate( See :class:`MseCalibConfig ` for details on the remaining arguments. """ - # Step 1: First get initial amax using max calibration + # Step 1: max calibration; then populate _amax for dead experts so step 3 + # doesn't skip them, and unify NVFP4 global_amax across Q/K/V and gate/up + # siblings so MSE searches against a consistent FP8 grid. max_calibrate(model, forward_loop, distributed_sync) + _bootstrap_uncalibrated_weight_quantizers(model) + _sync_grouped_weight_global_amax(model) - # Step 2: Replace calibrators with MseCalibrator for enabled quantizers - # and identify weight quantizers - weight_quantizers = [] - seen_modules = set() - + # Step 2: replace calibrators with MseCalibrator for enabled quantizers. for name, module in list(model.named_modules()): if isinstance(module, TensorQuantizer) and not module._disabled: if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"): - # Get the initial amax from max calibration initial_amax = module._amax.clone().detach() + is_nvfp4_static = module.is_nvfp4_static - is_nvfp4_static = ( - module.is_static_block_quant - and module._num_bits == (2, 1) - and module._block_sizes is not None - and module._block_sizes.get("scale_bits") == (4, 3) - ) - - if is_nvfp4_static: - # Compute and set global_amax + # _sync_grouped_weight_global_amax may have already promoted + + # unified global_amax across the sibling group; only promote + # standalone (non-grouped) NVFP4-static quantizers here. + if is_nvfp4_static and not isinstance(module, NVFP4StaticQuantizer): global_amax = reduce_amax(initial_amax, axis=None) - - # Convert to NVFP4StaticQuantizer in-place NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax) if fp8_scale_sweep: @@ -412,52 +525,50 @@ def mse_calibrate( quant_func=partial(_mse_quant_func, quantizer=module), ) - # Identify weight quantizers by checking if they have corresponding weight parameters + # Step 3: calibrate weight quantizers via iter_weights_for_calibration. + # The fused-experts override yields one pair per expert per projection, so + # every per-expert quantizer is MSE-calibrated (not just routed ones). name_to_module = dict(model.named_modules()) + seen_modules: set[int] = set() + pbar = tqdm(desc="MSE weight calibration") + n_calibrated = 0 for parent_module in name_to_module.values(): - if parent_module in seen_modules: + if id(parent_module) in seen_modules or not isinstance(parent_module, QuantModule): continue - for weight_name in weight_attr_names(parent_module): - weight_quantizer_name = quantizer_attr_names(weight_name).weight_quantizer - weight_quantizer = getattr(parent_module, weight_quantizer_name, None) - if isinstance(weight_quantizer, TensorQuantizer) and weight_quantizer.is_enabled: - if getattr(weight_quantizer, "_calibrator", None) is not None: - weight_quantizers.append((parent_module, weight_name, weight_quantizer)) - seen_modules.add(parent_module) - - # Step 3: Calibrate weight quantizers ONE AT A TIME with immediate amax computation - # This prevents massive memory accumulation seen in large models - for idx, (parent_module, weight_name, weight_quantizer) in enumerate( - tqdm(weight_quantizers, desc="MSE weight calibration") - ): - # Enable calibration mode for the weight quantizer - weight_quantizer.disable_quant() - weight_quantizer.enable_calib() + seen_modules.add(id(parent_module)) with enable_weight_access_and_writeback(parent_module, model, name_to_module): - weight = getattr(parent_module, weight_name) - weight_quantizer(weight) + for weight, weight_quantizer in parent_module.iter_weights_for_calibration(): + if not ( + isinstance(weight_quantizer, TensorQuantizer) + and weight_quantizer.is_enabled + and getattr(weight_quantizer, "_calibrator", None) is not None + ): + continue + weight_quantizer.disable_quant() + weight_quantizer.enable_calib() + weight_quantizer(weight) - # IMMEDIATELY compute amax and reset calibrator to free memory - cal = getattr(weight_quantizer, "_calibrator", None) - if cal is not None and cal.compute_amax() is not None: - weight_quantizer.load_calib_amax() + cal = weight_quantizer._calibrator + if cal.compute_amax() is not None: + weight_quantizer.load_calib_amax() - weight_quantizer.enable_quant() - weight_quantizer.disable_calib() + weight_quantizer.enable_quant() + weight_quantizer.disable_calib() - # Synchronize ALL CUDA devices before resetting to ensure all async operations complete - # This is critical for multi-GPU setups where tensors may be on different devices - if torch.cuda.is_available(): - for dev_id in range(torch.cuda.device_count()): - torch.cuda.synchronize(torch.device(f"cuda:{dev_id}")) + if torch.cuda.is_available(): + for dev_id in range(torch.cuda.device_count()): + torch.cuda.synchronize(torch.device(f"cuda:{dev_id}")) - if cal is not None and hasattr(cal, "reset"): - cal.reset() + if hasattr(cal, "reset"): + cal.reset() - if (idx + 1) % 10 == 0 and torch.cuda.is_available(): - for dev_id in range(torch.cuda.device_count()): - torch.cuda.synchronize(torch.device(f"cuda:{dev_id}")) - torch.cuda.empty_cache() + pbar.update(1) + n_calibrated += 1 + if n_calibrated % 10 == 0 and torch.cuda.is_available(): + for dev_id in range(torch.cuda.device_count()): + torch.cuda.synchronize(torch.device(f"cuda:{dev_id}")) + torch.cuda.empty_cache() + pbar.close() if torch.cuda.is_available(): for dev_id in range(torch.cuda.device_count()): @@ -612,6 +723,8 @@ def forward(self, input, *args, **kwargs): print_rank_0("local_hessian: Running max calibration for all quantizers...") max_calibrate(model, forward_loop, distributed_sync) + _sync_grouped_weight_global_amax(model) + # Setup helpers for all quantized linear modules name_to_module = dict(model.named_modules()) weight_quantizers_info = [] @@ -666,14 +779,9 @@ def quant_func(x, amax, quantizer=weight_quantizer): return xq - is_nvfp4_static = ( - weight_quantizer.is_static_block_quant - and weight_quantizer._num_bits == (2, 1) - and weight_quantizer._block_sizes is not None - and weight_quantizer._block_sizes.get("scale_bits") == (4, 3) - ) + is_nvfp4_static = weight_quantizer.is_nvfp4_static - if is_nvfp4_static: + if is_nvfp4_static and not isinstance(weight_quantizer, NVFP4StaticQuantizer): global_amax = reduce_amax(initial_amax, axis=None) NVFP4StaticQuantizer.from_tensor_quantizer(weight_quantizer, global_amax=global_amax) diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index 3ff7401ec3e..fa540b8fdf5 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -514,6 +514,16 @@ def is_mx_format(self): and self.block_sizes.get("scale_bits", None) == (8, 0) ) + @property + def is_nvfp4_static(self): + """True for E2M1 weights + E4M3 per-block scales in static layout (format-only check).""" + return ( + self.is_static_block_quant + and self._num_bits == (2, 1) + and self._block_sizes is not None + and self._block_sizes.get("scale_bits") == (4, 3) + ) + def is_mxfp(self, bits): """Check if is MXFP4/MXFP6/MXFP8.""" if bits == 4: diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index 77f26b20602..1873ecda528 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -900,6 +900,24 @@ def forward(self, *args, **kwargs): self._down_proj_linear = False return super().forward(*args, **kwargs) + def iter_weights_for_calibration(self): + """Yield ``(weight_slice, quantizer)`` per-expert pairs. + + The base impl uses singular ``*_weight_quantizer`` and skips fused- + experts modules, so weight-only calibration never reaches per-expert + quantizers without this override. + """ + for weight_name, quantizers_name in ( + ("gate_up_proj", "gate_up_proj_weight_quantizers"), + ("down_proj", "down_proj_weight_quantizers"), + ): + weight = getattr(self, weight_name, None) + quantizers = getattr(self, quantizers_name, None) + if weight is None or quantizers is None: + continue + for idx, q in enumerate(quantizers): + yield weight[idx], q + def fold_weight(self, keep_attrs: bool = False): """Fold per-expert weight quantizers into the fused 3-D weights. diff --git a/modelopt/torch/quantization/utils/core_utils.py b/modelopt/torch/quantization/utils/core_utils.py index 1a177e04dc8..cea3d4260e4 100644 --- a/modelopt/torch/quantization/utils/core_utils.py +++ b/modelopt/torch/quantization/utils/core_utils.py @@ -957,13 +957,7 @@ def promote_nvfp4_static_quantizers(model: nn.Module) -> int: for _name, module in list(model.named_modules()): if isinstance(module, TensorQuantizer) and not module._disabled: if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"): - is_nvfp4_static = ( - module.is_static_block_quant - and module._num_bits == (2, 1) - and module._block_sizes is not None - and module._block_sizes.get("scale_bits") == (4, 3) - ) - if is_nvfp4_static: + if module.is_nvfp4_static: initial_amax = module._amax.clone().detach() global_amax = reduce_amax(initial_amax, axis=None) NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax) diff --git a/tests/unit/torch/quantization/plugins/test_fused_experts.py b/tests/unit/torch/quantization/plugins/test_fused_experts.py index e0ce2f0c66e..19e1ed49197 100644 --- a/tests/unit/torch/quantization/plugins/test_fused_experts.py +++ b/tests/unit/torch/quantization/plugins/test_fused_experts.py @@ -388,6 +388,110 @@ def _spy_export(wrapper, dtype): if QuantModuleRegistry.get(expert_type) is not None: QuantModuleRegistry.unregister(expert_type) + def test_per_block_amax_reshape_for_fused_export(self, monkeypatch): + """Per-block ``_amax`` (NVFP4 static, row axis collapsed) must be reshaped + before dim-0 slicing so gate's blocks and up's blocks are split correctly. + + Regression for the bug where a flat per-block ``_amax`` of shape + ``(fused_total * blocks_per_row,)`` was sliced naively, producing wrong + per-projection scales. The fix reshapes to ``(fused_total, blocks_per_row)`` + before slicing on dim-0 when ``amax.numel() % fused_total == 0``. + """ + from modelopt.torch.export.moe_utils import _export_fused_experts + + experts = _SyntheticFusedExperts() + expert_type = type(experts) + if QuantModuleRegistry.get(expert_type) is None: + QuantModuleRegistry.register({expert_type: "test.SyntheticFusedExperts"})( + _QuantFusedExperts + ) + try: + converted = QuantModuleRegistry.convert(experts) + + # Per-block amax: 4 blocks per row. Distinct values per row so we can + # detect whether the reshape correctly preserves the row→block layout. + blocks_per_row = 4 + fused_total = 2 * INTERMEDIATE_DIM # gate_up rows + for idx in range(NUM_EXPERTS): + # Gate rows take values 1..INTERMEDIATE_DIM, up rows 101..101+INTERMEDIATE_DIM. + gate_amax = ( + torch.arange(1, INTERMEDIATE_DIM + 1).float().repeat_interleave(blocks_per_row) + ) + up_amax = ( + torch.arange(101, 101 + INTERMEDIATE_DIM) + .float() + .repeat_interleave(blocks_per_row) + ) + # Flat shape (fused_total * blocks_per_row,) — row axis collapsed. + flat = torch.cat([gate_amax, up_amax]) + assert flat.numel() == fused_total * blocks_per_row + + wq = converted.gate_up_proj_weight_quantizers[idx] + wq._disabled = False + wq.amax = flat + + # down_proj quantizers also need to look calibrated (otherwise + # the export-time fallback would compute amax from each weight + # slice and we'd skip the new reshape branch). Set a 1-D per-row + # amax that matches dim-0 of down_proj (so amax.numel() == fused_total + # for down). That intentionally does NOT exercise the new branch + # for down — we only want to exercise it for gate_up. + dwq = converted.down_proj_weight_quantizers[idx] + dwq._disabled = False + dwq.amax = torch.ones(HIDDEN_DIM) + + seen = {} + + def _spy_export(wrapper, dtype): + w = wrapper.weight.data + wq = wrapper.weight_quantizer + amax = wq._amax.detach().clone() if hasattr(wq, "_amax") else None + for idx in range(NUM_EXPERTS): + g_slice = converted.gate_up_proj.data[idx, :INTERMEDIATE_DIM, :] + u_slice = converted.gate_up_proj.data[idx, INTERMEDIATE_DIM:, :] + if w.shape == g_slice.shape and torch.equal(w, g_slice): + seen[(idx, "gate_proj")] = amax + return + if w.shape == u_slice.shape and torch.equal(w, u_slice): + seen[(idx, "up_proj")] = amax + return + + monkeypatch.setattr( + "modelopt.torch.export.unified_export_hf._export_quantized_weight", + _spy_export, + ) + + _export_fused_experts(converted, torch.float16) + + # gate's amax should contain values 1..INTERMEDIATE_DIM repeated + # blocks_per_row times, reshaped to (INTERMEDIATE_DIM, blocks_per_row); + # up's amax should contain 101..101+INTERMEDIATE_DIM same shape. + for idx in range(NUM_EXPERTS): + g_amax = seen.get((idx, "gate_proj")) + u_amax = seen.get((idx, "up_proj")) + assert g_amax is not None and u_amax is not None, ( + f"Expert {idx}: missing recorded amax" + ) + assert g_amax.shape[0] == INTERMEDIATE_DIM, ( + f"Expert {idx} gate amax dim-0 should be {INTERMEDIATE_DIM} " + f"after reshape+slice, got {g_amax.shape}" + ) + assert u_amax.shape[0] == INTERMEDIATE_DIM, ( + f"Expert {idx} up amax dim-0 should be {INTERMEDIATE_DIM}, got {u_amax.shape}" + ) + # First block of first row carries the marker value. + assert g_amax.flatten()[0].item() == 1.0, ( + f"Expert {idx} gate amax[0,0] should be 1.0 (gate row 0 marker), " + f"got {g_amax.flatten()[0].item()} — reshape probably didn't restore row axis" + ) + assert u_amax.flatten()[0].item() == 101.0, ( + f"Expert {idx} up amax[0,0] should be 101.0 (up row 0 marker), " + f"got {u_amax.flatten()[0].item()} — slice probably didn't separate gate from up" + ) + finally: + if QuantModuleRegistry.get(expert_type) is not None: + QuantModuleRegistry.unregister(expert_type) + # --------------------------------------------------------------------------- # Tests for force_eager_experts_impl_on_the_fly @@ -529,6 +633,101 @@ def forward_loop(m): self._cleanup_registry(expert_type) + def test_bootstrap_populates_dead_expert_quantizers(self): + """`_bootstrap_uncalibrated_weight_quantizers` fills `_amax` on experts the + forward pass never routed to. + + Regression for the dead-expert MSE skip: with partial routing during max + calibration, never-routed experts' weight quantizers stay with + ``_amax=None``; bootstrap must run the calibrator on the per-expert weight + slice (via ``iter_weights_for_calibration``) to populate them so MSE + doesn't skip them. + """ + import modelopt.torch.quantization as mtq + from modelopt.torch.quantization.model_calib import ( + _bootstrap_uncalibrated_weight_quantizers, + ) + + model = _TinyMoEModel() + expert_type = type(model.moe.experts) + self._cleanup_registry(expert_type) + + quant_cfg = { + "quant_cfg": [ + {"quantizer_name": "*", "enable": False}, + { + "quantizer_name": "*gate_up_proj_weight_quantizer", + "cfg": {"num_bits": 8, "axis": 0}, + }, + { + "quantizer_name": "*down_proj_weight_quantizer", + "cfg": {"num_bits": 8, "axis": 0}, + }, + ], + "algorithm": "max", + } + + # Forward loop that routes only to experts 0 and 1 (deterministic). + # Bypasses the router and calls experts directly with crafted indices. + live = {0, 1} + dead = {idx for idx in range(NUM_EXPERTS) if idx not in live} + assert dead, "Test requires at least one dead expert" + + def partial_forward(m): + torch.manual_seed(0) + seq_len = 8 + hidden = torch.randn(seq_len, HIDDEN_DIM) + top_k_index = torch.zeros(seq_len, TOP_K, dtype=torch.long) + top_k_index[:, 0] = 0 + top_k_index[:, 1] = 1 + top_k_weights = torch.ones(seq_len, TOP_K) / TOP_K + with torch.no_grad(): + m.moe.experts(hidden, top_k_index, top_k_weights) + + mtq.quantize(model, quant_cfg, forward_loop=partial_forward) + + experts = model.moe.experts + + # Pre-bootstrap: dead experts have no/zero _amax. + for idx in dead: + gu_q = experts.gate_up_proj_weight_quantizers[idx] + d_q = experts.down_proj_weight_quantizers[idx] + assert getattr(gu_q, "_amax", None) is None or torch.all(gu_q._amax == 0), ( + f"Dead expert {idx} gate_up_proj should be uncalibrated pre-bootstrap" + ) + assert getattr(d_q, "_amax", None) is None or torch.all(d_q._amax == 0), ( + f"Dead expert {idx} down_proj should be uncalibrated pre-bootstrap" + ) + + n_bootstrapped = _bootstrap_uncalibrated_weight_quantizers(model) + assert n_bootstrapped >= 2 * len(dead), ( + f"Expected ≥{2 * len(dead)} bootstrapped (gate_up + down per dead expert), " + f"got {n_bootstrapped}" + ) + + # Post-bootstrap: every expert has populated _amax matching max(|weight|). + for idx in range(NUM_EXPERTS): + gu_q = experts.gate_up_proj_weight_quantizers[idx] + d_q = experts.down_proj_weight_quantizers[idx] + assert gu_q._amax is not None and not torch.all(gu_q._amax == 0), ( + f"Expert {idx} gate_up_proj _amax not populated after bootstrap" + ) + assert d_q._amax is not None and not torch.all(d_q._amax == 0), ( + f"Expert {idx} down_proj _amax not populated after bootstrap" + ) + + # For dead experts, bootstrap reads max(|weight|). Sanity-check it matches + # the actual weight tensor's per-row max (axis=0 reduces over hidden_dim). + for idx in dead: + expected = experts.gate_up_proj.data[idx].abs().amax(dim=1) + got = experts.gate_up_proj_weight_quantizers[idx]._amax.flatten() + assert torch.allclose(got, expected, atol=1e-4), ( + f"Expert {idx} bootstrap amax should equal per-row max(|weight|); " + f"max diff {(got - expected).abs().max().item()}" + ) + + self._cleanup_registry(expert_type) + # --------------------------------------------------------------------------- # Tests for export enumeration — guards the bug where fused-experts were