10 changes: 9 additions & 1 deletion modelopt/torch/export/moe_utils.py
@@ -110,11 +110,19 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None:
and w_quantizer._amax.dim() >= 1
):
amax = w_quantizer._amax
# Per-block _amax (NVFP4 static) collapses the row axis we want
# to slice on; restore it so dim-0 slicing splits gate/up.
if amax.numel() != fused_total and amax.numel() % fused_total == 0:
amax = amax.contiguous().view(fused_total, amax.numel() // fused_total)
amax_dim0 = amax.shape[0]
if fused_total % amax_dim0 == 0:
slice_start = fused_start * amax_dim0 // fused_total
slice_end = (fused_start + weight_slice.shape[0]) * amax_dim0 // fused_total
w_quantizer.amax = amax[slice_start:slice_end].contiguous()
sliced = amax[slice_start:slice_end].contiguous()
# The amax setter refuses shape changes; drop _amax first.
if hasattr(w_quantizer, "_amax"):
delattr(w_quantizer, "_amax")
w_quantizer.amax = sliced
else:
warnings.warn(
f"Expert {idx} {proj_name}: fused amax dim0 ({amax_dim0}) does not "
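Not part of the diff: a minimal standalone sketch of the reshape-and-slice arithmetic the new comments describe. The shapes (8 fused rows, 2 scale blocks per row) and the helper name `slice_fused_amax` are hypothetical; the real code additionally deletes `_amax` before reassigning, because the quantizer's `amax` setter rejects shape changes.

```python
import torch

def slice_fused_amax(amax: torch.Tensor, fused_total: int,
                     fused_start: int, slice_rows: int) -> torch.Tensor:
    """Slice out the amax rows belonging to one half (gate or up) of a fused weight."""
    # Per-block amax arrives flat (one entry per row-block); restore the row
    # axis so dim-0 indexing lines up with the fused weight's rows.
    if amax.numel() != fused_total and amax.numel() % fused_total == 0:
        amax = amax.contiguous().view(fused_total, amax.numel() // fused_total)
    amax_dim0 = amax.shape[0]
    assert fused_total % amax_dim0 == 0, "amax rows must evenly divide fused rows"
    start = fused_start * amax_dim0 // fused_total
    end = (fused_start + slice_rows) * amax_dim0 // fused_total
    return amax[start:end].contiguous()

# Hypothetical fused gate_up weight: 8 rows (4 gate + 4 up), 2 scale blocks per row.
per_block_amax = torch.arange(16.0)  # flat per-block amax, numel = 8 * 2
gate_amax = slice_fused_amax(per_block_amax, fused_total=8, fused_start=0, slice_rows=4)
up_amax = slice_fused_amax(per_block_amax, fused_total=8, fused_start=4, slice_rows=4)
print(gate_amax.shape, up_amax.shape)  # torch.Size([4, 2]) torch.Size([4, 2])
```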
15 changes: 15 additions & 0 deletions modelopt/torch/export/unified_export_hf.py
@@ -1134,6 +1134,19 @@ def _unpatch_revert_weight_conversion(patches: list[tuple[Any, Any]]) -> None:
mod.revert_weight_conversion = original


def _sanitize_generation_config_for_save(model: torch.nn.Module) -> None:
"""Force ``do_sample=True`` when generation_config has ``top_k``/``top_p`` set.

Newer transformers versions reject ``do_sample=False`` mixed with sampling
attributes in ``save_pretrained``'s strict validation.
"""
gc = getattr(model, "generation_config", None)
if gc is None:
return
if getattr(gc, "top_k", None) is not None or getattr(gc, "top_p", None) is not None:
gc.do_sample = True


def export_speculative_decoding(
model: torch.nn.Module,
dtype: torch.dtype | None = None,
@@ -1228,6 +1241,8 @@ def export_hf_checkpoint(
# modeling_utils does `from core_model_loading import revert_weight_conversion`.
_patches = _patch_revert_weight_conversion()

_sanitize_generation_config_for_save(model)

try:
model.save_pretrained(
export_dir,
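Not part of the diff: a tiny sketch of the sanitization rule on a stand-in config object. `SimpleNamespace` stands in for the model's `generation_config`; the real helper operates on whatever object transformers attaches there.

```python
from types import SimpleNamespace

def sanitize_generation_config(gc) -> None:
    # Mirrors the new helper: if sampling attributes are set, force do_sample on
    # so save_pretrained's strict validation does not reject the config.
    if gc is None:
        return
    if getattr(gc, "top_k", None) is not None or getattr(gc, "top_p", None) is not None:
        gc.do_sample = True

cfg = SimpleNamespace(do_sample=False, top_k=20, top_p=None)  # stand-in config
sanitize_generation_config(cfg)
print(cfg.do_sample)  # True
```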
234 changes: 171 additions & 63 deletions modelopt/torch/quantization/model_calib.py
@@ -52,7 +52,6 @@
promote_nvfp4_static_quantizers,
quantizer_attr_names,
reduce_amax,
weight_attr_names,
)
from .utils.calib_utils import _GPTQ_HELPER_REGISTRY, GPTQHelper

@@ -66,6 +65,127 @@
"svdquant",
]


# Sibling groups that share an FP8 scale-of-scales: members are fed the same input
# (Q/K/V) or get fused at deployment (gate/up), so a divergent global_amax would
# split their FP8 grids.
_GROUPED_WEIGHT_QUANTIZER_PATTERNS: tuple[tuple[str, ...], ...] = (
("q_proj", "k_proj", "v_proj"),
("gate_proj", "up_proj"), # Llama/Qwen/Mistral
("w1", "w3"), # Mixtral
)


def _collect_grouped_linears(model: nn.Module) -> list[list[nn.Module]]:
"""Collect sibling groups (Q/K/V, gate/up) with calibrated NVFP4-static weight quantizers."""
groups: list[list[nn.Module]] = []
wq_attr = quantizer_attr_names("weight").weight_quantizer
for parent in model.modules():
for sibling_names in _GROUPED_WEIGHT_QUANTIZER_PATTERNS:
members = []
for n in sibling_names:
child = getattr(parent, n, None)
wq = getattr(child, wq_attr, None) if child is not None else None
if (
isinstance(wq, TensorQuantizer)
and not wq._disabled
and wq.is_nvfp4_static
and getattr(wq, "_amax", None) is not None
):
members.append(child)
if len(members) >= 2:
groups.append(members)
return groups


@torch.no_grad()
def _bootstrap_uncalibrated_weight_quantizers(model: nn.Module) -> int:
"""Populate ``_amax`` from weights for quantizers the forward pass didn't reach.

Dead MoE experts that received no tokens are otherwise skipped by
``mse_calibrate``, leaving export to derive separate per-half amax for
gate/up and break the gate==up ``weight_scale_2`` invariant. Weight access
runs inside ``enable_weight_access_and_writeback`` so FSDP / TP / offload
shards get gathered before calibration reads them.
"""
name_to_module = dict(model.named_modules())
n = 0
for module in name_to_module.values():
if not isinstance(module, QuantModule):
continue
with enable_weight_access_and_writeback(module, model, name_to_module):
for weight, q in module.iter_weights_for_calibration():
if not isinstance(q, TensorQuantizer) or q._disabled or q._dynamic:
continue
if q._calibrator is None:
continue
if getattr(q, "_amax", None) is not None and not torch.all(q._amax == 0):
continue
q.disable_quant()
q.enable_calib()
q(weight)
if q._calibrator.compute_amax() is not None:
q.load_calib_amax()
q.enable_quant()
q.disable_calib()
if hasattr(q._calibrator, "reset"):
q._calibrator.reset()
n += 1
return n


@torch.no_grad()
def _sync_grouped_weight_global_amax(model: nn.Module) -> int:
"""Unify NVFP4 ``global_amax`` across Q/K/V and gate/up sibling weight quantizers.

Reuses ``preprocess_linear_fusion`` (which performs the same unification at
export time) to keep the FP8 scale-of-scales consistent across siblings
during MSE / local-Hessian search. Must run after ``max_calibrate``.

Sibling discovery is name-based via ``_GROUPED_WEIGHT_QUANTIZER_PATTERNS``
(q/k/v_proj, gate/up_proj, w1/w3) — models with non-matching attribute
names (e.g. ``wqkv``, fused ``qkv_proj``, DeepSeek variants) silently fall
back to per-module global_amax. A warning is emitted when the model has
NVFP4-static weight quantizers but no groups were matched.
"""
# Imported inline: quant_utils imports enable_stats_collection/finish_stats_collection/svd
# from this module, so a top-level import here would create a circular import.
from modelopt.torch.export.quant_utils import preprocess_linear_fusion

wq_attr = quantizer_attr_names("weight").weight_quantizer
n_groups = 0
for group in _collect_grouped_linears(model):
for child in group:
wq = getattr(child, wq_attr)
if not isinstance(wq, NVFP4StaticQuantizer):
NVFP4StaticQuantizer.from_tensor_quantizer(
wq, global_amax=reduce_amax(wq._amax, axis=None)
)
preprocess_linear_fusion(group)
n_groups += 1

if n_groups == 0:
# Surface architectures whose Q/K/V or gate/up siblings don't match the
# pattern list — without this, sibling-sync is a silent no-op.
has_nvfp4_static = any(
isinstance(m, TensorQuantizer)
and not m._disabled
and m.is_nvfp4_static
and getattr(m, "_amax", None) is not None
for m in model.modules()
)
if has_nvfp4_static:
warnings.warn(
"_sync_grouped_weight_global_amax found NVFP4-static weight quantizers "
"but no Q/K/V or gate/up sibling groups matching "
"_GROUPED_WEIGHT_QUANTIZER_PATTERNS. Per-block FP8 grids will not be "
"unified across siblings; if this model uses non-standard projection "
"names (e.g. wqkv, fused qkv_proj), extend the pattern list.",
stacklevel=2,
)
Comment on lines +167 to +185 (Contributor):

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Avoid warning on models that have nothing to synchronize.

This warns whenever any calibrated NVFP4-static quantizer exists and no group matched. That also catches models with only standalone NVFP4 linears, so mse_calibrate() / local_hessian_calibrate() will emit a misleading warning even though there are no Q/K/V or gate/up siblings to unify. Please gate this on the presence of candidate sibling attributes rather than on has_nvfp4_static alone.

Possible guard
-    if n_groups == 0:
+    if n_groups == 0:
+        has_candidate_siblings = any(
+            sum(1 for name in sibling_names if hasattr(parent, name)) >= 2
+            for parent in model.modules()
+            for sibling_names in _GROUPED_WEIGHT_QUANTIZER_PATTERNS
+        )
         # Surface architectures whose Q/K/V or gate/up siblings don't match the
         # pattern list — without this, sibling-sync is a silent no-op.
         has_nvfp4_static = any(
             isinstance(m, TensorQuantizer)
             and not m._disabled
             and m.is_nvfp4_static
             and getattr(m, "_amax", None) is not None
             for m in model.modules()
         )
-        if has_nvfp4_static:
+        if has_candidate_siblings and has_nvfp4_static:
             warnings.warn(
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/torch/quantization/model_calib.py` around lines 167-185: the
warning in _sync_grouped_weight_global_amax currently fires whenever any
NVFP4-static TensorQuantizer exists but no groups matched; change this to first
check whether the model contains any potential sibling candidates before
warning. Scan model.modules() for modules whose names or attributes match the
grouping heuristics (i.e. the same criteria used by
_GROUPED_WEIGHT_QUANTIZER_PATTERNS or the sibling attribute tests used when
building groups), and only emit the warnings.warn when such candidate siblings
exist, n_groups == 0, and has_nvfp4_static is true; update the logic around
n_groups, has_nvfp4_static, and the use of _GROUPED_WEIGHT_QUANTIZER_PATTERNS
in _sync_grouped_weight_global_amax to reflect this guarded check.

return n_groups


CalibratorFactory: TypeAlias = Callable[
[torch.Tensor, int | tuple | list | None, Callable[..., torch.Tensor]], _Calibrator
]
@@ -346,32 +466,25 @@ def mse_calibrate(
See :class:`MseCalibConfig <modelopt.torch.quantization.config.MseCalibConfig>` for
details on the remaining arguments.
"""
# Step 1: First get initial amax using max calibration
# Step 1: max calibration; then populate _amax for dead experts so step 3
# doesn't skip them, and unify NVFP4 global_amax across Q/K/V and gate/up
# siblings so MSE searches against a consistent FP8 grid.
max_calibrate(model, forward_loop, distributed_sync)
_bootstrap_uncalibrated_weight_quantizers(model)
_sync_grouped_weight_global_amax(model)

# Step 2: Replace calibrators with MseCalibrator for enabled quantizers
# and identify weight quantizers
weight_quantizers = []
seen_modules = set()

# Step 2: replace calibrators with MseCalibrator for enabled quantizers.
for name, module in list(model.named_modules()):
if isinstance(module, TensorQuantizer) and not module._disabled:
if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"):
# Get the initial amax from max calibration
initial_amax = module._amax.clone().detach()
is_nvfp4_static = module.is_nvfp4_static

is_nvfp4_static = (
module.is_static_block_quant
and module._num_bits == (2, 1)
and module._block_sizes is not None
and module._block_sizes.get("scale_bits") == (4, 3)
)

if is_nvfp4_static:
# Compute and set global_amax
# _sync_grouped_weight_global_amax may have already promoted +
# unified global_amax across the sibling group; only promote
# standalone (non-grouped) NVFP4-static quantizers here.
if is_nvfp4_static and not isinstance(module, NVFP4StaticQuantizer):
global_amax = reduce_amax(initial_amax, axis=None)

# Convert to NVFP4StaticQuantizer in-place
NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax)

if fp8_scale_sweep:
@@ -412,52 +525,50 @@
quant_func=partial(_mse_quant_func, quantizer=module),
)

# Identify weight quantizers by checking if they have corresponding weight parameters
# Step 3: calibrate weight quantizers via iter_weights_for_calibration.
# The fused-experts override yields one pair per expert per projection, so
# every per-expert quantizer is MSE-calibrated (not just routed ones).
name_to_module = dict(model.named_modules())
seen_modules: set[int] = set()
pbar = tqdm(desc="MSE weight calibration")
n_calibrated = 0
for parent_module in name_to_module.values():
if parent_module in seen_modules:
if id(parent_module) in seen_modules or not isinstance(parent_module, QuantModule):
continue
for weight_name in weight_attr_names(parent_module):
weight_quantizer_name = quantizer_attr_names(weight_name).weight_quantizer
weight_quantizer = getattr(parent_module, weight_quantizer_name, None)
if isinstance(weight_quantizer, TensorQuantizer) and weight_quantizer.is_enabled:
if getattr(weight_quantizer, "_calibrator", None) is not None:
weight_quantizers.append((parent_module, weight_name, weight_quantizer))
seen_modules.add(parent_module)

# Step 3: Calibrate weight quantizers ONE AT A TIME with immediate amax computation
# This prevents massive memory accumulation seen in large models
for idx, (parent_module, weight_name, weight_quantizer) in enumerate(
tqdm(weight_quantizers, desc="MSE weight calibration")
):
# Enable calibration mode for the weight quantizer
weight_quantizer.disable_quant()
weight_quantizer.enable_calib()
seen_modules.add(id(parent_module))
with enable_weight_access_and_writeback(parent_module, model, name_to_module):
weight = getattr(parent_module, weight_name)
weight_quantizer(weight)
for weight, weight_quantizer in parent_module.iter_weights_for_calibration():
if not (
isinstance(weight_quantizer, TensorQuantizer)
and weight_quantizer.is_enabled
and getattr(weight_quantizer, "_calibrator", None) is not None
):
continue
weight_quantizer.disable_quant()
weight_quantizer.enable_calib()
weight_quantizer(weight)

# IMMEDIATELY compute amax and reset calibrator to free memory
cal = getattr(weight_quantizer, "_calibrator", None)
if cal is not None and cal.compute_amax() is not None:
weight_quantizer.load_calib_amax()
cal = weight_quantizer._calibrator
if cal.compute_amax() is not None:
weight_quantizer.load_calib_amax()

weight_quantizer.enable_quant()
weight_quantizer.disable_calib()
weight_quantizer.enable_quant()
weight_quantizer.disable_calib()

# Synchronize ALL CUDA devices before resetting to ensure all async operations complete
# This is critical for multi-GPU setups where tensors may be on different devices
if torch.cuda.is_available():
for dev_id in range(torch.cuda.device_count()):
torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
if torch.cuda.is_available():
for dev_id in range(torch.cuda.device_count()):
torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))

if cal is not None and hasattr(cal, "reset"):
cal.reset()
if hasattr(cal, "reset"):
cal.reset()

if (idx + 1) % 10 == 0 and torch.cuda.is_available():
for dev_id in range(torch.cuda.device_count()):
torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
torch.cuda.empty_cache()
pbar.update(1)
n_calibrated += 1
if n_calibrated % 10 == 0 and torch.cuda.is_available():
for dev_id in range(torch.cuda.device_count()):
torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
torch.cuda.empty_cache()
pbar.close()

if torch.cuda.is_available():
for dev_id in range(torch.cuda.device_count()):
@@ -612,6 +723,8 @@ def forward(self, input, *args, **kwargs):
print_rank_0("local_hessian: Running max calibration for all quantizers...")
max_calibrate(model, forward_loop, distributed_sync)

_sync_grouped_weight_global_amax(model)

# Setup helpers for all quantized linear modules
name_to_module = dict(model.named_modules())
weight_quantizers_info = []
@@ -666,14 +779,9 @@ def quant_func(x, amax, quantizer=weight_quantizer):

return xq

is_nvfp4_static = (
weight_quantizer.is_static_block_quant
and weight_quantizer._num_bits == (2, 1)
and weight_quantizer._block_sizes is not None
and weight_quantizer._block_sizes.get("scale_bits") == (4, 3)
)
is_nvfp4_static = weight_quantizer.is_nvfp4_static

if is_nvfp4_static:
if is_nvfp4_static and not isinstance(weight_quantizer, NVFP4StaticQuantizer):
global_amax = reduce_amax(initial_amax, axis=None)
NVFP4StaticQuantizer.from_tensor_quantizer(weight_quantizer, global_amax=global_amax)

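Not part of the diff: a toy numeric illustration of the invariant `_sync_grouped_weight_global_amax` enforces, namely one shared `global_amax` per sibling group. In the real code the promotion goes through `NVFP4StaticQuantizer.from_tensor_quantizer` and the unification itself is delegated to `preprocess_linear_fusion`; the plain-tensor version below only shows the intended end state.

```python
import torch

def full_reduce_amax(t: torch.Tensor) -> torch.Tensor:
    # Analogue of reduce_amax(..., axis=None): one scalar per quantizer.
    return t.abs().amax()

# Hypothetical per-block amax tables for a gate/up sibling pair.
per_block_amax = {
    "gate_proj": torch.tensor([0.8, 1.1, 0.9]),
    "up_proj": torch.tensor([1.6, 0.7, 1.2]),
}

# Without sync, per-module global_amax diverges: 1.1 vs 1.6, i.e. two FP8 grids.
global_amax = {name: full_reduce_amax(a) for name, a in per_block_amax.items()}

# Sibling sync: every member of the group adopts the group-wide maximum (1.6),
# so the fused deployment weight shares a single scale-of-scales.
shared = torch.stack(list(global_amax.values())).amax()
global_amax = {name: shared.clone() for name in global_amax}
print({name: float(v) for name, v in global_amax.items()})
```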
10 changes: 10 additions & 0 deletions modelopt/torch/quantization/nn/modules/tensor_quantizer.py
@@ -514,6 +514,16 @@ def is_mx_format(self):
and self.block_sizes.get("scale_bits", None) == (8, 0)
)

@property
def is_nvfp4_static(self):
"""True for E2M1 weights + E4M3 per-block scales in static layout (format-only check)."""
return (
self.is_static_block_quant
and self._num_bits == (2, 1)
and self._block_sizes is not None
and self._block_sizes.get("scale_bits") == (4, 3)
)

def is_mxfp(self, bits):
"""Check if is MXFP4/MXFP6/MXFP8."""
if bits == 4:
18 changes: 18 additions & 0 deletions modelopt/torch/quantization/plugins/huggingface.py
@@ -900,6 +900,24 @@ def forward(self, *args, **kwargs):
self._down_proj_linear = False
return super().forward(*args, **kwargs)

def iter_weights_for_calibration(self):
"""Yield ``(weight_slice, quantizer)`` per-expert pairs.

The base impl uses singular ``*_weight_quantizer`` and skips fused-
experts modules, so weight-only calibration never reaches per-expert
quantizers without this override.
"""
for weight_name, quantizers_name in (
("gate_up_proj", "gate_up_proj_weight_quantizers"),
("down_proj", "down_proj_weight_quantizers"),
):
weight = getattr(self, weight_name, None)
quantizers = getattr(self, quantizers_name, None)
if weight is None or quantizers is None:
continue
for idx, q in enumerate(quantizers):
yield weight[idx], q

def fold_weight(self, keep_attrs: bool = False):
"""Fold per-expert weight quantizers into the fused 3-D weights.

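Not part of the diff: a self-contained sketch of how a weight-only calibration loop consumes the new iterator. Both classes below are stand-ins (the real module is the quantized fused-experts wrapper and the real quantizer is `TensorQuantizer`, which carries far more state); the point is that every per-expert quantizer is visited, whether or not the router sent it tokens.

```python
import torch

class FakeQuantizer:
    """Stand-in for TensorQuantizer: records the max-abs seen while calibrating."""
    def __init__(self):
        self.amax = None
        self._calibrating = False
    def disable_quant(self): pass
    def enable_quant(self): pass
    def enable_calib(self): self._calibrating = True
    def disable_calib(self): self._calibrating = False
    def __call__(self, w):
        if self._calibrating:
            seen = w.abs().amax()
            self.amax = seen if self.amax is None else torch.maximum(self.amax, seen)
        return w

class FakeFusedExperts:
    """Stand-in for the fused-experts module: 3-D weight, one quantizer per expert."""
    def __init__(self, n_experts=4, rows=6, cols=8):
        self.gate_up_proj = torch.randn(n_experts, rows, cols)
        self.gate_up_proj_weight_quantizers = [FakeQuantizer() for _ in range(n_experts)]
    def iter_weights_for_calibration(self):
        for idx, q in enumerate(self.gate_up_proj_weight_quantizers):
            yield self.gate_up_proj[idx], q

module = FakeFusedExperts()
for weight, quantizer in module.iter_weights_for_calibration():
    quantizer.disable_quant(); quantizer.enable_calib()
    quantizer(weight)  # every expert is calibrated, routed or not
    quantizer.enable_quant(); quantizer.disable_calib()
print([round(float(q.amax), 3) for q in module.gate_up_proj_weight_quantizers])
```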