7 changes: 7 additions & 0 deletions modelopt/torch/quantization/calib/mse.py
@@ -181,10 +181,12 @@ def __init__(
axis: int | tuple | list | None = None,
quant_func: Callable | None = None,
error_func: Callable | None = None,
fp8_scale_sweep_stride: int = 1,
):
"""Initialize NVFP4 MSE calibrator with per-block and global amax."""
super().__init__(amax=amax, axis=axis, quant_func=quant_func, error_func=error_func)
self._global_amax = global_amax
self._fp8_scale_sweep_stride = max(1, fp8_scale_sweep_stride or 1)

def _compute_candidate_amax(self, candidates: torch.Tensor) -> torch.Tensor:
if candidates.ndim != 0: # Called during final compute amax
@@ -197,4 +199,9 @@ def _generate_candidates(self, device: torch.device) -> torch.Tensor:
fp8_values = uint8_values.view(torch.float8_e4m3fn).float()
valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0)
fp8_values = fp8_values[valid_mask]
if self._fp8_scale_sweep_stride > 1:
candidates = fp8_values[:: self._fp8_scale_sweep_stride]
if candidates[-1] != fp8_values[-1]:
candidates = torch.cat([candidates, fp8_values[-1:]])
fp8_values = candidates
return fp8_values / 448.0
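
For readers who want to inspect the sweep candidates outside the calibrator, the snippet below is a minimal standalone sketch of the subsampling logic above. It assumes only PyTorch 2.1+ (for the `float8_e4m3fn` dtype); the function name is illustrative and not part of the modelopt API.

```python
import torch

def fp8_scale_candidates(stride: int = 1) -> torch.Tensor:
    # Reinterpret all 256 uint8 bit patterns as FP8 E4M3 values.
    uint8_values = torch.arange(256, dtype=torch.uint8)
    fp8_values = uint8_values.view(torch.float8_e4m3fn).float()
    # Keep only finite, strictly positive values as scale candidates.
    fp8_values = fp8_values[torch.isfinite(fp8_values) & (fp8_values > 0)]
    if stride > 1:
        candidates = fp8_values[::stride]
        # Always keep the largest candidate so the sweep still covers the full range.
        if candidates[-1] != fp8_values[-1]:
            candidates = torch.cat([candidates, fp8_values[-1:]])
        fp8_values = candidates
    return fp8_values / 448.0  # normalize by the E4M3 max representable value

# stride=1 keeps all 126 positive finite E4M3 candidates;
# stride=4 keeps every 4th candidate plus the maximum (33 values).
print(fp8_scale_candidates(1).numel(), fp8_scale_candidates(4).numel())
```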
8 changes: 8 additions & 0 deletions modelopt/torch/quantization/config.py
@@ -1327,6 +1327,14 @@ class MseCalibConfig(QuantizeAlgorithmConfig):
"start_multiplier, and stop_multiplier are ignored.",
)

fp8_scale_sweep_stride: int | None = ModeloptField(
default=1,
ge=1,
title="Stride for FP8 scale sweep candidates.",
description="Subsample every Nth valid FP8 E4M3 scale candidate when fp8_scale_sweep is True. "
"A value of 1 preserves the exhaustive sweep.",
)

distributed_sync: bool | None = ModeloptField(
default=True,
title="Whether to sync the amax across the distributed processes.",
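
A hedged construction sketch for the new field: it assumes MseCalibConfig can be instantiated directly with keyword arguments like other QuantizeAlgorithmConfig subclasses, which this diff does not show.

```python
from modelopt.torch.quantization.config import MseCalibConfig

# Assumed direct construction; defaults for the remaining MSE fields are kept.
calib_cfg = MseCalibConfig(
    fp8_scale_sweep=True,       # sweep real FP8 E4M3 scale candidates
    fp8_scale_sweep_stride=4,   # evaluate every 4th candidate (plus the largest one)
)
print(calib_cfg.fp8_scale_sweep_stride)  # -> 4
```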
68 changes: 45 additions & 23 deletions modelopt/torch/quantization/model_calib.py
@@ -110,26 +110,44 @@ def _has_expert_parallelism(module: nn.Module) -> bool:
return ps is not None and ps.expert_model_parallel_group.is_initialized()


-def _check_moe_calibration_complete(quantizer, parallel_state):
-    """Raise error if MoE calibration is incomplete (some ranks have amax, others don't)."""
-    if isinstance(quantizer, SequentialQuantizer):
-        for _q in quantizer:
-            _check_moe_calibration_complete(_q, parallel_state)
-        return
-    for group in [
-        parallel_state.data_parallel_group,
-        parallel_state.expert_model_parallel_group,
-        parallel_state.tensor_parallel_group,
-    ]:
-        if not group.is_initialized():
-            continue
-        has_amax = getattr(quantizer, "_amax", None) is not None
-        amax_states = DistributedProcessGroup.get_dist_syncd_obj(has_amax, group, lambda objs: objs)
-        if any(amax_states) and not all(amax_states):
-            raise RuntimeError(
-                "MoE calibration incomplete: some experts received no tokens during calibration. "
-                "Increase --calib-size to ensure all experts see calibration data."
-            )
+def _is_dynamic_block_quantizer(quantizer) -> bool:
+    block_sizes = getattr(quantizer, "block_sizes", None)
+    if isinstance(block_sizes, dict):
+        return block_sizes.get("type") == "dynamic"
+    return getattr(block_sizes, "type", None) == "dynamic"
+
+
+def _iter_leaf_quantizers(quantizer):
+    if isinstance(quantizer, SequentialQuantizer):
+        for _q in quantizer:
+            yield from _iter_leaf_quantizers(_q)
+        return
+    yield quantizer
+
+
+def _check_moe_calibration_complete(quantizer, parallel_state):
+    """Raise error if MoE calibration is incomplete across distributed MoE ranks."""
+    for leaf_quantizer in _iter_leaf_quantizers(quantizer):
+        if _is_dynamic_block_quantizer(leaf_quantizer):
+            continue
+
+        has_amax = getattr(leaf_quantizer, "_amax", None) is not None
+        for group in [
+            parallel_state.data_parallel_group,
+            parallel_state.expert_model_parallel_group,
+            parallel_state.tensor_parallel_group,
+        ]:
+            if not group.is_initialized():
+                continue
+            amax_states = DistributedProcessGroup.get_dist_syncd_obj(
+                has_amax, group, lambda objs: objs
+            )
+            if any(amax_states) and not all(amax_states):
+                raise RuntimeError(
+                    "MoE calibration incomplete: some experts received no tokens during "
+                    "calibration. Increase --calib-size to ensure all experts see calibration "
+                    "data."
+                )
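
The completeness rule itself is simple; below is a toy, single-process illustration of the any/all check above, with no distributed groups involved.

```python
# Toy illustration of the completeness rule: after gathering one has-amax flag per
# rank, a mix of True and False means some experts never received calibration tokens.
def calibration_complete(has_amax_per_rank: list[bool]) -> bool:
    # Complete if every rank has an amax, or none does (e.g. a disabled quantizer);
    # a partial mix indicates starved experts.
    return all(has_amax_per_rank) or not any(has_amax_per_rank)

assert calibration_complete([True, True, True])
assert calibration_complete([False, False])
assert not calibration_complete([True, False, True])  # this case raises RuntimeError above
```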


@torch.no_grad()
@@ -175,13 +193,13 @@ def max_calibrate(

def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state):
"""Synchronize the amax across all ranks in the data parallel and expert parallel groups."""
-    if isinstance(quantizer, SequentialQuantizer):
-        for _q in quantizer:
-            sync_quantizer_amax_across_dp_ep(_q, parallel_state)
-        return
-    if getattr(quantizer, "_amax", None) is not None:
-        quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group)
-        quantizer.sync_amax_across_distributed_group(parallel_state.expert_model_parallel_group)
+    for leaf_quantizer in _iter_leaf_quantizers(quantizer):
+        if _is_dynamic_block_quantizer(leaf_quantizer):
+            continue
+        leaf_quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group)
+        leaf_quantizer.sync_amax_across_distributed_group(
+            parallel_state.expert_model_parallel_group
+        )
# TODO: create sync_bias_across_distributed_group

# Step 2: Sync amax across data parallelism
@@ -226,7 +244,7 @@ def sync_quantizer_amax_across_tp(
)
# Skip amax sync for INT4 / W4A8 block quantization
# Sync amax for NVFP4 (dynamic per-block, static per-tensor quantized scale)
-    if getattr(quantizer.block_sizes, "type", None) == "dynamic":
+    if _is_dynamic_block_quantizer(quantizer):
return

if quantizer.axis in axes_for_sync and quantizer.amax is not None:
@@ -314,6 +332,7 @@ def mse_calibrate(
start_multiplier: float = 0.25,
stop_multiplier: float = 4.0,
fp8_scale_sweep: bool = False,
fp8_scale_sweep_stride: int = 1,
):
"""Calibrate the model using MSE-based amax search.

@@ -333,6 +352,8 @@
for NVFP4 per-block quantization instead of using multipliers.
This is specifically designed for optimizing the FP8-quantized
per-block scales in NVFP4 format (default: False).
fp8_scale_sweep_stride: Subsample every Nth FP8 E4M3 candidate when
fp8_scale_sweep is enabled. A value of 1 preserves the exhaustive sweep.

See :class:`MseCalibConfig <modelopt.torch.quantization.config.MseCalibConfig>` for
details on the remaining arguments.
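
As a usage sketch tying the new knob to the standard PTQ entry point: the base config name and the dict-style algorithm override below are assumptions modeled on the YAML recipe later in this PR, not an excerpt from the docs.

```python
import modelopt.torch.quantization as mtq

config = dict(mtq.NVFP4_DEFAULT_CFG)  # assumed NVFP4 base config
config["algorithm"] = {
    "method": "mse",
    "fp8_scale_sweep": True,       # search real E4M3 scale values instead of multipliers
    "fp8_scale_sweep_stride": 4,   # coarser sweep: every 4th candidate plus the max
}
# model = mtq.quantize(model, config, forward_loop=calibration_loop)
```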
@@ -388,6 +409,7 @@
axis=module._calibrator._axis,
global_amax=module.global_amax,
quant_func=partial(_mse_quant_func, quantizer=module),
fp8_scale_sweep_stride=fp8_scale_sweep_stride,
)
continue

@@ -0,0 +1,135 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Mirrors the published nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 hf_quant_config.json:
# - MoE routed experts: NVFP4 W4A4 weight MSE, group_size 16
# HF names: mixer.experts.<N>.{up,down}_proj
# Megatron-Core names: mlp.experts.local_experts.<N>.linear_fc{1,2}
# - MoE shared experts: FP8 per-tensor
# HF names: mixer.shared_experts.{up,down}_proj
# Megatron-Core names: mlp.shared_experts.linear_fc{1,2}
# - Mamba mixer linears (mixer.{in,out}_proj): FP8 per-tensor
# - KV cache: FP8
# - Attention linears ({q,k,v}_proj): BF16 (not quantized)
Review comment: Can we double check attention out linear? IIRC, attention o_proj should be FP8.

Author reply: Responded in Slack; only 2/9 attention layers had o_proj in FP8 in the final Super NVFP4 ckpt, but we can always add it later to test whether the accuracy degradation is minimal.

# - MTP head, lm_head, output, mamba conv1d: BF16 (not quantized)
# - Latent MOE (fc1_latent_proj, fc2_latent_proj): BF16 (not quantized)
# - SSM cache: FP32 (can be set to FP16 in VLLM)
#
# Calibration: weight MSE with a stride-4 FP8-scale sweep over the e4m3 scale
# values. This keeps the FP8 static-scale path but uses a coarser candidate set.
metadata:
recipe_type: ptq
description: Super NVFP4 mixed precision — sparse MoE experts NVFP4 (W4A4, group_size 16); shared experts, mamba in/out_proj, and attention o_proj/fc1_latent_proj/fc2_latent_proj
FP8 per-tensor; FP8 KV cache; lm_head/MTP/SSM stay BF16/FP16. Weight-MSE calibration with stride-4 FP8 scale sweep.
quantize:
algorithm:
method: mse
fp8_scale_sweep: true
fp8_scale_sweep_stride: 4
quant_cfg:
- quantizer_name: '*'
enable: false

# MoE routed experts -> NVFP4 W4A4, block_size 16, e4m3 scale.
# Weight uses static block scales (chosen by MSE); activations stay dynamic.
# HF/export names: backbone.layers.*.mixer.experts.*.{up,down}_proj.
- quantizer_name: '*mixer.experts.*weight_quantizer'
enable: true
cfg:
block_sizes:
-1: 16
type: static
scale_bits: e4m3
num_bits: e2m1
- quantizer_name: '*mixer.experts.*input_quantizer'
enable: true
cfg:
block_sizes:
-1: 16
type: dynamic
scale_bits: e4m3
num_bits: e2m1
# Megatron-Core/PTQ names: decoder.layers.*.mlp.experts.local_experts.*.linear_fc{1,2}.
- quantizer_name: '*mlp.experts*weight_quantizer'
enable: true
cfg:
block_sizes:
-1: 16
type: static
scale_bits: e4m3
num_bits: e2m1
- quantizer_name: '*mlp.experts*input_quantizer'
enable: true
cfg:
block_sizes:
-1: 16
type: dynamic
scale_bits: e4m3
num_bits: e2m1

# MoE shared experts -> FP8 per-tensor.
# HF/export names: backbone.layers.*.mixer.shared_experts.{up,down}_proj.
- quantizer_name: '*mixer.shared_experts.*weight_quantizer'
enable: true
cfg:
num_bits: e4m3
axis:
- quantizer_name: '*mixer.shared_experts.*input_quantizer'
enable: true
cfg:
num_bits: e4m3
axis:
# Megatron-Core/PTQ names: decoder.layers.*.mlp.shared_experts.linear_fc{1,2}.
- quantizer_name: '*mlp.shared_experts*weight_quantizer'
enable: true
cfg:
num_bits: e4m3
axis:
- quantizer_name: '*mlp.shared_experts*input_quantizer'
enable: true
cfg:
num_bits: e4m3
axis:

# Mamba mixer linears -> FP8 per-tensor.
- quantizer_name: '*mixer.in_proj*weight_quantizer'
enable: true
cfg:
num_bits: e4m3
axis:
- quantizer_name: '*mixer.in_proj*input_quantizer'
enable: true
cfg:
num_bits: e4m3
axis:
- quantizer_name: '*mixer.out_proj*weight_quantizer'
enable: true
cfg:
num_bits: e4m3
axis:
- quantizer_name: '*mixer.out_proj*input_quantizer'
enable: true
cfg:
num_bits: e4m3
axis:

# KV cache -> FP8.
- quantizer_name: '*[kv]_bmm_quantizer'
enable: true
cfg:
num_bits: e4m3

# Stay BF16: lm_head, output projection, MoE routers/gates, MTP head.
# SSM state / mamba conv1d stay FP16.
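
For comparison with Python-side configs, the first routed-expert entry could be written roughly as the quant_cfg dict below. The tuple encodings (2, 1) for E2M1 values and (4, 3) for E4M3 scales are an assumption based on modelopt's NVFP4 configs; the exact mapping the recipe loader applies is not shown in this diff.

```python
# Hypothetical Python equivalent of the '*mixer.experts.*weight_quantizer' entry above.
NVFP4_EXPERT_WEIGHT_CFG = {
    "*mixer.experts.*weight_quantizer": {
        "enable": True,
        "num_bits": (2, 1),  # E2M1 values (assumed tuple encoding)
        "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},  # static E4M3 block scales
    },
}
```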