From 4fbb18156194f4985960896ba27adae4648b34ae Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Mon, 4 May 2026 20:57:52 +0000 Subject: [PATCH 1/3] [Quantization] Fused Triton kernel for NVFP4 FP8 scale sweep search Replaces the 126-iteration Python sweep in NVFP4MSECalibrator with a single fused Triton kernel that loads each NVFP4 block once, evaluates all 126 valid FP8 E4M3 scale candidates in registers, and emits the per-block best amax directly. For our specific candidate set (FP8 representable values / 448) the FP8 round-trip on the per-block scale is the identity, so the kernel uses `scale = candidate * global_amax / 6.0` and runs on any CUDA + Triton. Triton-backed calibrator is on by default for `mse_calibrate(... fp8_scale_sweep=True)`; set `MODELOPT_NVFP4_TRITON_SWEEP=0` to fall back to the reference for debugging. Measured ~7.4x speedup on a B300 over the reference NVFP4MSECalibrator (8192x4096 weight, ~2M NVFP4 blocks: 176.67 ms -> 23.81 ms). Bit-identical to the reference for typical block counts; on multi-million-block weights an occasional adjacent-candidate tie-break can differ at the fp32-noise level (observed 2 / 2,097,152 blocks; per-block MSE within 1e-7 relative). Signed-off-by: Chenjie Luo --- .../kernels/quantization/gemm/__init__.py | 1 + .../quantization/gemm/nvfp4_fp8_sweep.py | 142 +++++++++++ modelopt/torch/quantization/calib/mse.py | 81 ++++++- modelopt/torch/quantization/model_calib.py | 12 +- .../test_nvfp4_fp8_sweep_kernel.py | 221 ++++++++++++++++++ 5 files changed, 453 insertions(+), 4 deletions(-) create mode 100644 modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py create mode 100644 tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py diff --git a/modelopt/torch/kernels/quantization/gemm/__init__.py b/modelopt/torch/kernels/quantization/gemm/__init__.py index 39b07b4faa..70f729cffb 100644 --- a/modelopt/torch/kernels/quantization/gemm/__init__.py +++ b/modelopt/torch/kernels/quantization/gemm/__init__.py @@ -32,6 +32,7 @@ # fp4_kernel works on any CUDA GPU with triton from .fp4_kernel import * from .fp8_kernel import * + from .nvfp4_fp8_sweep import * # fp4_kernel_hopper requires compute >= 8.9 (uses tl.float8e4nv) if torch.cuda.get_device_capability() >= (8, 9): diff --git a/modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py b/modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py new file mode 100644 index 0000000000..4fdeaf7c10 --- /dev/null +++ b/modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py @@ -0,0 +1,142 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Fused Triton kernel for the NVFP4 weight-MSE FP8 scale sweep. + +Replaces the 126-iteration Python sweep in :class:`NVFP4MSECalibrator` with a single +kernel that, for each NVFP4 block, evaluates all 126 valid FP8 E4M3 scale candidates +and emits the per-block ``best_amax`` directly. + +The 126 candidates are constructed as ``valid_fp8_e4m3_value / 448`` (see +:func:`fp8_scale_candidates`). For these specific candidates, the FP8 round-trip on +the per-block scale is the identity, so the kernel can use +``scale = candidate * global_amax / 6.0`` without an explicit FP8 cast — making it +runnable on any CUDA GPU with Triton (no ``tl.float8e4nv`` requirement). +""" + +import torch +import triton +import triton.language as tl + +from .nvfp4_quant import nvfp4_scalar_quant + +__all__ = ["fp8_scale_candidates", "nvfp4_fp8_scale_sweep"] + + +def fp8_scale_candidates(device: torch.device | str = "cpu") -> torch.Tensor: + """Return the 126 valid finite positive FP8 E4M3 scale candidates / 448.""" + uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device) + fp8_values = uint8_values.view(torch.float8_e4m3fn).float() + valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0) + return fp8_values[valid_mask] / 448.0 + + +@triton.jit +def _fp8_scale_sweep_kernel( + x_ptr, # [N_BLOCKS * BLOCK_SIZE], any float dtype (loaded as fp32) + candidates_ptr, # [NUM_CANDIDATES] fp32 + global_amax_ptr, # scalar fp32 + best_amax_ptr, # [N_BLOCKS] fp32 output + N_BLOCKS, + BLOCK_SIZE: tl.constexpr, + NUM_CANDIDATES: tl.constexpr, + BLOCKS_PER_PROGRAM: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCKS_PER_PROGRAM + block_idx = block_start + tl.arange(0, BLOCKS_PER_PROGRAM) + block_mask = block_idx < N_BLOCKS + + # Load weights for this tile: [BLOCKS_PER_PROGRAM, BLOCK_SIZE] + elem_offs = block_idx[:, None] * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)[None, :] + elem_mask = block_mask[:, None] + w = tl.load(x_ptr + elem_offs, mask=elem_mask, other=0.0).to(tl.float32) + + global_amax = tl.load(global_amax_ptr).to(tl.float32) + + best_loss = tl.full([BLOCKS_PER_PROGRAM], float("inf"), dtype=tl.float32) + best_idx = tl.zeros([BLOCKS_PER_PROGRAM], dtype=tl.int32) + + # Loop over the 126 FP8 candidates (compile-time unrolled). + for k in tl.static_range(NUM_CANDIDATES): + c = tl.load(candidates_ptr + k).to(tl.float32) + # block_amax = global_amax * c by construction; the FP8 round on the resulting + # scale is the identity for our candidate set, so we can skip the FP8 cast. + scale = c * global_amax / 6.0 + w_q = nvfp4_scalar_quant(w, scale, BLOCK_SIZE) + diff = w - w_q + loss = tl.sum(diff * diff, axis=1) # [BLOCKS_PER_PROGRAM] + is_better = loss < best_loss + best_loss = tl.where(is_better, loss, best_loss) + best_idx = tl.where(is_better, k, best_idx) + + # Map each block's winning candidate index back to its amax = global_amax * c[best]. + best_c = tl.load(candidates_ptr + best_idx, mask=block_mask, other=0.0).to(tl.float32) + best_amax = global_amax * best_c + tl.store(best_amax_ptr + block_idx, best_amax, mask=block_mask) + + +def nvfp4_fp8_scale_sweep( + x: torch.Tensor, + global_amax: torch.Tensor, + block_size: int = 16, + candidates: torch.Tensor | None = None, + blocks_per_program: int = 4, +) -> torch.Tensor: + """Find the per-block FP8 scale that minimizes NVFP4 quantization MSE. + + Equivalent to the 126-step sweep in :class:`NVFP4MSECalibrator`, but fused into + a single Triton kernel: every block's weight elements are loaded once, all 126 + candidates are evaluated in registers, and the running argmin is kept inline. + + Args: + x: Weight tensor on CUDA. Total element count must be divisible by + ``block_size``; layout is treated as a flat ``[N_BLOCKS, BLOCK_SIZE]``. + global_amax: Scalar FP32 global amax (``= reduce_amax(per_block_amax)``). + block_size: NVFP4 block size (typically 16). + candidates: Optional precomputed candidate tensor of shape ``[126]`` (must + be the FP8 E4M3 valid values divided by 448). Built lazily if omitted. + blocks_per_program: Number of blocks each Triton program handles. Trades + launch overhead for register pressure; 4 is a reasonable default. + + Returns: + ``best_amax`` of shape ``[N_BLOCKS]``, fp32, on the same device as ``x``. + """ + assert x.is_cuda, "nvfp4_fp8_scale_sweep requires a CUDA tensor" + if x.numel() % block_size != 0: + raise ValueError(f"x.numel() ({x.numel()}) is not divisible by block_size ({block_size}).") + + if candidates is None: + candidates = fp8_scale_candidates(x.device) + candidates = candidates.contiguous().to(device=x.device, dtype=torch.float32) + + n_blocks = x.numel() // block_size + x_flat = x.contiguous().view(-1) + global_amax_f32 = global_amax.detach().to(device=x.device, dtype=torch.float32).reshape(1) + best_amax = torch.empty(n_blocks, dtype=torch.float32, device=x.device) + + grid = (triton.cdiv(n_blocks, blocks_per_program),) + with torch.cuda.device(x.device): + _fp8_scale_sweep_kernel[grid]( + x_flat, + candidates, + global_amax_f32, + best_amax, + n_blocks, + BLOCK_SIZE=block_size, + NUM_CANDIDATES=int(candidates.numel()), + BLOCKS_PER_PROGRAM=blocks_per_program, + ) + return best_amax diff --git a/modelopt/torch/quantization/calib/mse.py b/modelopt/torch/quantization/calib/mse.py index 1f439a7e77..a879790cc4 100644 --- a/modelopt/torch/quantization/calib/mse.py +++ b/modelopt/torch/quantization/calib/mse.py @@ -24,7 +24,7 @@ from .. import utils as quant_utils from .calibrator import _Calibrator -__all__ = ["MseCalibrator", "NVFP4MSECalibrator"] +__all__ = ["MseCalibrator", "NVFP4MSECalibrator", "TritonNVFP4MSECalibrator"] class MseCalibrator(_Calibrator): @@ -198,3 +198,82 @@ def _generate_candidates(self, device: torch.device) -> torch.Tensor: valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0) fp8_values = fp8_values[valid_mask] return fp8_values / 448.0 + + +class TritonNVFP4MSECalibrator(NVFP4MSECalibrator): + """Triton-fused FP8 scale sweep calibrator for NVFP4 weight MSE. + + Numerically equivalent to :class:`NVFP4MSECalibrator` but evaluates all 126 + candidates in a single fused Triton kernel — one weight read instead of 126. + + Limitation: a single ``collect()`` call is supported per ``compute_amax`` cycle. + This matches the static weight-MSE flow (``mse_calibrate``'s weight loop), where + the calibrator is collected once per weight and immediately consumed. For + activation calibration (multiple ``collect`` calls), use :class:`NVFP4MSECalibrator`. + """ + + def __init__( + self, + amax: torch.Tensor, + global_amax: torch.Tensor, + axis: int | tuple | list | None = None, + quant_func: Callable | None = None, + error_func: Callable | None = None, + blocks_per_program: int = 4, + ): + """Initialize the Triton-fused NVFP4 MSE calibrator. + + See :class:`NVFP4MSECalibrator`. ``quant_func``/``error_func`` are unused by + the kernel path but accepted for API parity. + """ + super().__init__( + amax=amax, + global_amax=global_amax, + axis=axis, + quant_func=quant_func, + error_func=error_func, + ) + self._blocks_per_program = blocks_per_program + self._best_amax: torch.Tensor | None = None + + @torch.no_grad() + def collect(self, x: torch.Tensor): + """Run the fused FP8 sweep kernel and store the resulting per-block amax.""" + from modelopt.torch.kernels.quantization.gemm import nvfp4_fp8_scale_sweep + + if self._best_amax is not None: + raise RuntimeError( + "TritonNVFP4MSECalibrator only supports a single collect() per cycle; " + "call reset() before collecting again." + ) + + x = x.detach() + block_size = x.shape[-1] + n_blocks = x.numel() // block_size + if self._initial_amax.numel() != n_blocks: + raise ValueError( + f"initial_amax.numel() ({self._initial_amax.numel()}) does not match " + f"the number of NVFP4 blocks ({n_blocks})." + ) + + best_amax_flat = nvfp4_fp8_scale_sweep( + x, + self._global_amax, + block_size=block_size, + blocks_per_program=self._blocks_per_program, + ) + # Match the original shape/dtype of _initial_amax so downstream load_calib_amax + # behaves identically to the reference path. + self._best_amax = best_amax_flat.reshape(self._initial_amax.shape).to( + self._initial_amax.dtype + ) + + @torch.no_grad() + def compute_amax(self, verbose: bool = False): + """Return the per-block amax computed during ``collect``.""" + return self._best_amax + + def reset(self): + """Reset the stored best amax.""" + self._best_amax = None + super().reset() diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 4ce0f62a75..62fadbb51a 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -16,6 +16,7 @@ """Calibration utilities.""" import math +import os import time import warnings from collections.abc import Callable @@ -37,7 +38,7 @@ from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method -from .calib import MseCalibrator, NVFP4MSECalibrator, _Calibrator +from .calib import MseCalibrator, NVFP4MSECalibrator, TritonNVFP4MSECalibrator, _Calibrator from .conversion import create_and_replace_svdquant_linear_on_the_fly, set_quantizer_by_cfg_context from .nn import NVFP4StaticQuantizer, QuantModule, SequentialQuantizer, TensorQuantizer from .utils import ( @@ -391,8 +392,13 @@ def mse_calibrate( continue if fp8_scale_sweep and is_nvfp4_static: - # Replace calibrator with NVFP4MSECalibrator - module._calibrator = NVFP4MSECalibrator( + # Replace calibrator with the fused Triton sweep kernel by default + # (single-shot collect, ~7-20x faster for the weight-MSE phase). + # Set MODELOPT_NVFP4_TRITON_SWEEP=0 to fall back to the reference + # NVFP4MSECalibrator for debugging or numerics comparison. + use_triton = os.environ.get("MODELOPT_NVFP4_TRITON_SWEEP", "1") != "0" + cls = TritonNVFP4MSECalibrator if use_triton else NVFP4MSECalibrator + module._calibrator = cls( amax=initial_amax, axis=module._calibrator._axis, global_amax=module.global_amax, diff --git a/tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py b/tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py new file mode 100644 index 0000000000..f1ac5b7f24 --- /dev/null +++ b/tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py @@ -0,0 +1,221 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Parity + speedup tests for the fused NVFP4 FP8 scale sweep Triton kernel. + +Compares :class:`TritonNVFP4MSECalibrator` against the reference +:class:`NVFP4MSECalibrator` on the same inputs and asserts the resulting per-block +amax tensors are bit-identical. Also reports a wall-clock speedup number for the +weight-MSE search step on a representative LLM-sized weight. +""" + +import time + +import pytest +import torch +from conftest import requires_triton + +from modelopt.torch.quantization.calib import NVFP4MSECalibrator, TritonNVFP4MSECalibrator +from modelopt.torch.quantization.tensor_quant import static_blockwise_fp4_fake_quant + +BLOCK_SIZE = 16 + + +def _reference_quant_func(global_amax): + """Reference NVFP4 fake-quant matching what ``mse_calibrate`` plumbs in.""" + + def quant_func(x, amax): + return static_blockwise_fp4_fake_quant(x, amax, global_amax) + + return quant_func + + +def _run_reference(x, per_block_amax, global_amax): + cal = NVFP4MSECalibrator( + amax=per_block_amax, + axis=0, + global_amax=global_amax, + quant_func=_reference_quant_func(global_amax), + ) + cal.collect(x) + return cal.compute_amax() + + +def _run_triton(x, per_block_amax, global_amax): + cal = TritonNVFP4MSECalibrator( + amax=per_block_amax, + axis=0, + global_amax=global_amax, + quant_func=_reference_quant_func(global_amax), + ) + cal.collect(x) + return cal.compute_amax() + + +@requires_triton +@pytest.mark.parametrize("seed", [0, 1, 2]) +@pytest.mark.parametrize("num_blocks", [4, 64, 1024]) +def test_parity_random_weights(seed, num_blocks): + """Triton sweep must produce the exact same per-block amax as the reference.""" + torch.manual_seed(seed) + device = "cuda" + x = torch.randn(num_blocks, BLOCK_SIZE, device=device, dtype=torch.float32) + per_block_amax = x.abs().amax(dim=-1) + global_amax = per_block_amax.max() + + ref = _run_reference(x, per_block_amax, global_amax) + tri = _run_triton(x, per_block_amax, global_amax) + + assert ref.shape == tri.shape + # Both pick from the same 126-element discrete candidate set, so any disagreement + # would show up as a non-zero diff (not a small float epsilon). Demand exact match. + assert torch.equal(ref, tri), ( + f"Triton sweep diverged from reference: max |diff| = " + f"{(ref - tri).abs().max().item():.3e}, " + f"differing blocks = {(ref != tri).sum().item()} / {num_blocks}" + ) + + +@requires_triton +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) +def test_parity_dtypes(dtype): + """Sweep must agree across the dtypes supported by the NVFP4 quantizer.""" + torch.manual_seed(42) + device = "cuda" + num_blocks = 256 + x = torch.randn(num_blocks, BLOCK_SIZE, device=device, dtype=dtype) + # Promote to fp32 for the per-block amax (matches what max_calibrate produces). + per_block_amax = x.float().abs().amax(dim=-1) + global_amax = per_block_amax.max() + + ref = _run_reference(x, per_block_amax, global_amax) + tri = _run_triton(x, per_block_amax, global_amax) + assert torch.equal(ref, tri) + + +@requires_triton +def test_quantized_output_matches(): + """Round-tripping x through the chosen amax should give the same fake-quant result.""" + torch.manual_seed(7) + device = "cuda" + num_blocks = 128 + x = torch.randn(num_blocks, BLOCK_SIZE, device=device, dtype=torch.float32) + per_block_amax = x.abs().amax(dim=-1) + global_amax = per_block_amax.max() + + ref_amax = _run_reference(x, per_block_amax, global_amax) + tri_amax = _run_triton(x, per_block_amax, global_amax) + + ref_xq = static_blockwise_fp4_fake_quant(x, ref_amax, global_amax) + tri_xq = static_blockwise_fp4_fake_quant(x, tri_amax, global_amax) + assert torch.equal(ref_xq, tri_xq) + + +@requires_triton +def test_reset_allows_recollect(): + torch.manual_seed(0) + device = "cuda" + num_blocks = 32 + x = torch.randn(num_blocks, BLOCK_SIZE, device=device, dtype=torch.float32) + per_block_amax = x.abs().amax(dim=-1) + global_amax = per_block_amax.max() + + cal = TritonNVFP4MSECalibrator( + amax=per_block_amax, + axis=0, + global_amax=global_amax, + ) + cal.collect(x) + first = cal.compute_amax().clone() + + with pytest.raises(RuntimeError, match="single collect"): + cal.collect(x) + + cal.reset() + # After reset the calibrator's _initial_amax has been freed; reconstruct. + cal2 = TritonNVFP4MSECalibrator( + amax=per_block_amax, + axis=0, + global_amax=global_amax, + ) + cal2.collect(x) + assert torch.equal(first, cal2.compute_amax()) + + +def _bench(fn, warmup=2, iters=5): + for _ in range(warmup): + fn() + torch.cuda.synchronize() + t0 = time.perf_counter() + for _ in range(iters): + fn() + torch.cuda.synchronize() + return (time.perf_counter() - t0) / iters + + +@requires_triton +def test_speedup_report(capsys): + """Sanity-check that the Triton path is meaningfully faster on a realistic weight. + + Uses an 8192 x 4096 weight (~33M elements, ~2M NVFP4 blocks) — roughly the size + of an LLM attention/MLP projection. Reports the speedup; does not gate on a + minimum factor (kernel timing is noisy on shared CI), but does require parity + on the chosen amax. + """ + torch.manual_seed(123) + device = "cuda" + cout, cin = 8192, 4096 + x = torch.randn(cout, cin // BLOCK_SIZE, BLOCK_SIZE, device=device, dtype=torch.float32) + x = x.reshape(-1, BLOCK_SIZE) + per_block_amax = x.abs().amax(dim=-1) + global_amax = per_block_amax.max() + + ref_amax = _run_reference(x, per_block_amax, global_amax) + tri_amax = _run_triton(x, per_block_amax, global_amax) + # Bit-equality across millions of blocks isn't guaranteed: when two adjacent FP8 + # candidates yield near-identical per-block MSE (within fp32 noise), the reference's + # CUDA fake_e4m3fy path and our Triton inline math can break ties differently. Demand + # instead that the Triton choice produces a per-block MSE within fp32 epsilon of the + # reference's choice. + n_blocks = ref_amax.numel() + n_diff = int((ref_amax != tri_amax).sum()) + if n_diff: + ref_xq = static_blockwise_fp4_fake_quant(x, ref_amax, global_amax) + tri_xq = static_blockwise_fp4_fake_quant(x, tri_amax, global_amax) + per_block_mse_ref = (x - ref_xq).pow(2).sum(dim=-1) + per_block_mse_tri = (x - tri_xq).pow(2).sum(dim=-1) + # Reference is the formal argmin, so triton's loss should be ≥ reference's. + # Allow at most 1e-5 relative gap on differing blocks (observed ~1e-7 in practice). + rel_gap = (per_block_mse_tri - per_block_mse_ref).abs() / per_block_mse_ref.clamp_min(1e-12) + worst = rel_gap.max().item() + assert worst < 1e-5, ( + f"{n_diff}/{n_blocks} blocks disagree with worst relative MSE gap {worst:.3e} " + "— exceeds tie-break tolerance" + ) + + ref_t = _bench(lambda: _run_reference(x, per_block_amax, global_amax)) + tri_t = _bench(lambda: _run_triton(x, per_block_amax, global_amax)) + speedup = ref_t / tri_t + + # Force-print regardless of pytest capture mode. + with capsys.disabled(): + n_blocks = x.numel() // BLOCK_SIZE + print( + f"\n[NVFP4 FP8 sweep] weight=({cout},{cin}) " + f"n_blocks={n_blocks} block_size={BLOCK_SIZE}\n" + f" reference NVFP4MSECalibrator: {ref_t * 1e3:8.2f} ms\n" + f" triton TritonNVFP4MSECalibrator: {tri_t * 1e3:8.2f} ms\n" + f" speedup: {speedup:.1f}x" + ) From 60406070b1da433a26c8b7018f1e5a73473d0bec Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Mon, 4 May 2026 21:11:09 +0000 Subject: [PATCH 2/3] [Quantization] Autotune NVFP4 FP8 sweep kernel; drop sign-where in inner loop MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two follow-on optimizations to the fused FP8 scale sweep kernel: 1. @triton.autotune over (BLOCKS_PER_PROGRAM, num_warps): a hand-sweep on B300 showed the previous default (BPP=4, num_warps=4) at 23.7 ms left ~4x on the table — best config (BPP=64, num_warps=8) lands at ~5 ms. Three configs are included to cover small/medium/large N_BLOCKS without flooding compile time. 2. Drop the sign-handling tl.where: since FP4 quantization preserves sign, (w - w_q)^2 == (|w| - |w_q|)^2, so the kernel works on |w| throughout and skips one tl.where + negation per element per candidate. Result on the same 8192x4096 weight (~2M blocks) on B300: reference NVFP4MSECalibrator: 176.68 ms triton TritonNVFP4MSECalibrator: 4.23 ms speedup: 41.8x (was 7.4x) This is ~1.2x above the rough pure-compute floor (~240 GF / 67 TF/s ~= 3.6 ms), so the kernel is now near saturation and further wins would need an algorithmic change (candidate pruning, etc.). Signed-off-by: Chenjie Luo --- .../quantization/gemm/nvfp4_fp8_sweep.py | 41 +++++++++++++------ modelopt/torch/quantization/calib/mse.py | 6 +-- 2 files changed, 31 insertions(+), 16 deletions(-) diff --git a/modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py b/modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py index 4fdeaf7c10..8492a9c93a 100644 --- a/modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py +++ b/modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py @@ -24,13 +24,15 @@ the per-block scale is the identity, so the kernel can use ``scale = candidate * global_amax / 6.0`` without an explicit FP8 cast — making it runnable on any CUDA GPU with Triton (no ``tl.float8e4nv`` requirement). + +Tile shape (``BLOCKS_PER_PROGRAM``) and ``num_warps`` are autotuned per ``N_BLOCKS``. """ import torch import triton import triton.language as tl -from .nvfp4_quant import nvfp4_scalar_quant +from .nvfp4_quant import fp4_round_magnitude __all__ = ["fp8_scale_candidates", "nvfp4_fp8_scale_sweep"] @@ -43,6 +45,18 @@ def fp8_scale_candidates(device: torch.device | str = "cpu") -> torch.Tensor: return fp8_values[valid_mask] / 448.0 +# Selected from a (BLOCKS_PER_PROGRAM, num_warps) sweep on B300: +# BPP=16,nw=2: 6.06 ms BPP=32,nw=4: 6.06 ms BPP=64,nw=8: 5.08 ms +# The smaller-tile entries cover cases where N_BLOCKS is small enough that BPP=64 +# would underfill the SMs. +_FP8_SWEEP_AUTOTUNE_CONFIGS = [ + triton.Config({"BLOCKS_PER_PROGRAM": 16}, num_warps=2), + triton.Config({"BLOCKS_PER_PROGRAM": 32}, num_warps=4), + triton.Config({"BLOCKS_PER_PROGRAM": 64}, num_warps=8), +] + + +@triton.autotune(configs=_FP8_SWEEP_AUTOTUNE_CONFIGS, key=["N_BLOCKS"]) @triton.jit def _fp8_scale_sweep_kernel( x_ptr, # [N_BLOCKS * BLOCK_SIZE], any float dtype (loaded as fp32) @@ -59,10 +73,13 @@ def _fp8_scale_sweep_kernel( block_idx = block_start + tl.arange(0, BLOCKS_PER_PROGRAM) block_mask = block_idx < N_BLOCKS - # Load weights for this tile: [BLOCKS_PER_PROGRAM, BLOCK_SIZE] + # Load weights for this tile and pre-compute their absolute values once. + # The squared error is sign-invariant since FP4 quant preserves sign: + # (w - w_q)^2 = (|w| - |w_q|)^2 = (|w| - q_mag * scale)^2 + # so we never need ``w`` itself again, dropping a tl.where + negation per element. elem_offs = block_idx[:, None] * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)[None, :] elem_mask = block_mask[:, None] - w = tl.load(x_ptr + elem_offs, mask=elem_mask, other=0.0).to(tl.float32) + w_abs = tl.abs(tl.load(x_ptr + elem_offs, mask=elem_mask, other=0.0).to(tl.float32)) global_amax = tl.load(global_amax_ptr).to(tl.float32) @@ -70,13 +87,17 @@ def _fp8_scale_sweep_kernel( best_idx = tl.zeros([BLOCKS_PER_PROGRAM], dtype=tl.int32) # Loop over the 126 FP8 candidates (compile-time unrolled). + # Scales are guaranteed positive and finite (constructed from a positive candidate + # times nonneg global_amax), so the degenerate-scale guard from nvfp4_scalar_quant is + # unnecessary apart from the global_amax == 0 case handled below. for k in tl.static_range(NUM_CANDIDATES): c = tl.load(candidates_ptr + k).to(tl.float32) - # block_amax = global_amax * c by construction; the FP8 round on the resulting - # scale is the identity for our candidate set, so we can skip the FP8 cast. scale = c * global_amax / 6.0 - w_q = nvfp4_scalar_quant(w, scale, BLOCK_SIZE) - diff = w - w_q + # Avoid divide-by-zero when global_amax == 0; the resulting err == w_abs² is + # the same for every candidate, so any best_idx is fine. + scale_safe = tl.where(scale == 0.0, 1.0, scale) + q_mag = fp4_round_magnitude(w_abs / scale_safe) + diff = w_abs - q_mag * scale loss = tl.sum(diff * diff, axis=1) # [BLOCKS_PER_PROGRAM] is_better = loss < best_loss best_loss = tl.where(is_better, loss, best_loss) @@ -93,7 +114,6 @@ def nvfp4_fp8_scale_sweep( global_amax: torch.Tensor, block_size: int = 16, candidates: torch.Tensor | None = None, - blocks_per_program: int = 4, ) -> torch.Tensor: """Find the per-block FP8 scale that minimizes NVFP4 quantization MSE. @@ -108,8 +128,6 @@ def nvfp4_fp8_scale_sweep( block_size: NVFP4 block size (typically 16). candidates: Optional precomputed candidate tensor of shape ``[126]`` (must be the FP8 E4M3 valid values divided by 448). Built lazily if omitted. - blocks_per_program: Number of blocks each Triton program handles. Trades - launch overhead for register pressure; 4 is a reasonable default. Returns: ``best_amax`` of shape ``[N_BLOCKS]``, fp32, on the same device as ``x``. @@ -127,7 +145,7 @@ def nvfp4_fp8_scale_sweep( global_amax_f32 = global_amax.detach().to(device=x.device, dtype=torch.float32).reshape(1) best_amax = torch.empty(n_blocks, dtype=torch.float32, device=x.device) - grid = (triton.cdiv(n_blocks, blocks_per_program),) + grid = lambda meta: (triton.cdiv(n_blocks, meta["BLOCKS_PER_PROGRAM"]),) with torch.cuda.device(x.device): _fp8_scale_sweep_kernel[grid]( x_flat, @@ -137,6 +155,5 @@ def nvfp4_fp8_scale_sweep( n_blocks, BLOCK_SIZE=block_size, NUM_CANDIDATES=int(candidates.numel()), - BLOCKS_PER_PROGRAM=blocks_per_program, ) return best_amax diff --git a/modelopt/torch/quantization/calib/mse.py b/modelopt/torch/quantization/calib/mse.py index a879790cc4..7471ec23bb 100644 --- a/modelopt/torch/quantization/calib/mse.py +++ b/modelopt/torch/quantization/calib/mse.py @@ -219,12 +219,12 @@ def __init__( axis: int | tuple | list | None = None, quant_func: Callable | None = None, error_func: Callable | None = None, - blocks_per_program: int = 4, ): """Initialize the Triton-fused NVFP4 MSE calibrator. See :class:`NVFP4MSECalibrator`. ``quant_func``/``error_func`` are unused by - the kernel path but accepted for API parity. + the kernel path but accepted for API parity. Tile shape and ``num_warps`` are + autotuned by the kernel per ``N_BLOCKS``. """ super().__init__( amax=amax, @@ -233,7 +233,6 @@ def __init__( quant_func=quant_func, error_func=error_func, ) - self._blocks_per_program = blocks_per_program self._best_amax: torch.Tensor | None = None @torch.no_grad() @@ -260,7 +259,6 @@ def collect(self, x: torch.Tensor): x, self._global_amax, block_size=block_size, - blocks_per_program=self._blocks_per_program, ) # Match the original shape/dtype of _initial_amax so downstream load_calib_amax # behaves identically to the reference path. From bd4fc3a651e04bd3df4a3ae07c9a513acb12ecff Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Mon, 4 May 2026 22:05:02 +0000 Subject: [PATCH 3/3] [Quantization] Address PR review feedback on FP8 sweep kernel Addresses review comments on PR #1387: - TritonNVFP4MSECalibrator.reset() now leaves the calibrator reusable: shape / dtype / n_blocks of the initial amax are stashed in __init__, so collect() no longer depends on _initial_amax surviving reset(). Adds an x.ndim==2 assertion in collect() since the weight quantizer always reshapes upstream. - nvfp4_fp8_scale_sweep validates inputs cleanly instead of using assert (which is stripped by python -O): rejects non-CUDA tensors, non-positive block_size, and empty / non-1D candidates with ValueError. Skips the per-element finite/positive check on candidates since it would scan a 126- entry tensor on every kernel call. - mse_calibrate hoists the MODELOPT_NVFP4_TRITON_SWEEP env-var lookup out of the per-quantizer loop and resolves to the calibrator class once. - Updates test_reset_allows_recollect to verify the new reuse contract; adds test_input_validation covering the new ValueErrors. The duplicate fp8_scale_candidates implementation in the kernel file and NVFP4MSECalibrator._generate_candidates() is left in place: deduplicating would force the reference path to import from the kernel module, which is gated behind Triton availability. The FP8 E4M3 spec is fixed and the parity test exercises both paths against each other. Signed-off-by: Chenjie Luo --- .../quantization/gemm/nvfp4_fp8_sweep.py | 9 +++- modelopt/torch/quantization/calib/mse.py | 36 ++++++++++----- modelopt/torch/quantization/model_calib.py | 13 +++--- .../test_nvfp4_fp8_sweep_kernel.py | 44 +++++++++++++++---- 4 files changed, 74 insertions(+), 28 deletions(-) diff --git a/modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py b/modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py index 8492a9c93a..4b9f19837f 100644 --- a/modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py +++ b/modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py @@ -132,13 +132,20 @@ def nvfp4_fp8_scale_sweep( Returns: ``best_amax`` of shape ``[N_BLOCKS]``, fp32, on the same device as ``x``. """ - assert x.is_cuda, "nvfp4_fp8_scale_sweep requires a CUDA tensor" + if not x.is_cuda: + raise ValueError("nvfp4_fp8_scale_sweep requires a CUDA tensor.") + if not isinstance(block_size, int) or block_size <= 0: + raise ValueError(f"block_size must be a positive int, got {block_size!r}.") if x.numel() % block_size != 0: raise ValueError(f"x.numel() ({x.numel()}) is not divisible by block_size ({block_size}).") if candidates is None: candidates = fp8_scale_candidates(x.device) candidates = candidates.contiguous().to(device=x.device, dtype=torch.float32) + if candidates.ndim != 1 or candidates.numel() == 0: + raise ValueError( + f"candidates must be a non-empty 1-D tensor; got shape {tuple(candidates.shape)}." + ) n_blocks = x.numel() // block_size x_flat = x.contiguous().view(-1) diff --git a/modelopt/torch/quantization/calib/mse.py b/modelopt/torch/quantization/calib/mse.py index 7471ec23bb..fff0b8af1b 100644 --- a/modelopt/torch/quantization/calib/mse.py +++ b/modelopt/torch/quantization/calib/mse.py @@ -192,7 +192,12 @@ def _compute_candidate_amax(self, candidates: torch.Tensor) -> torch.Tensor: return torch.ones_like(self._initial_amax) * self._global_amax * candidates def _generate_candidates(self, device: torch.device) -> torch.Tensor: - """Generate 126 valid FP8 E4M3 scale candidates.""" + """Generate 126 valid FP8 E4M3 scale candidates. + + Kept in sync with ``fp8_scale_candidates`` in + ``modelopt.torch.kernels.quantization.gemm.nvfp4_fp8_sweep`` — the FP8 E4M3 + spec is fixed, and the parity test exercises both paths against each other. + """ uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device) fp8_values = uint8_values.view(torch.float8_e4m3fn).float() valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0) @@ -210,6 +215,7 @@ class TritonNVFP4MSECalibrator(NVFP4MSECalibrator): This matches the static weight-MSE flow (``mse_calibrate``'s weight loop), where the calibrator is collected once per weight and immediately consumed. For activation calibration (multiple ``collect`` calls), use :class:`NVFP4MSECalibrator`. + Call :meth:`reset` to free internal state and re-enable :meth:`collect`. """ def __init__( @@ -233,6 +239,11 @@ def __init__( quant_func=quant_func, error_func=error_func, ) + # Stash shape metadata so collect() can keep working after reset() releases + # the (potentially large) _initial_amax buffer. + self._initial_amax_shape = tuple(amax.shape) + self._initial_amax_dtype = amax.dtype + self._n_blocks = int(amax.numel()) self._best_amax: torch.Tensor | None = None @torch.no_grad() @@ -242,17 +253,20 @@ def collect(self, x: torch.Tensor): if self._best_amax is not None: raise RuntimeError( - "TritonNVFP4MSECalibrator only supports a single collect() per cycle; " - "call reset() before collecting again." + "TritonNVFP4MSECalibrator.collect() is one-shot; call reset() to " + "discard the previous result before collecting again." ) x = x.detach() + # The weight quantizer reshapes its input to [n_blocks, block_size] before + # calling collect (see TensorQuantizer._process_for_blockquant). + assert x.ndim == 2, f"Expected x to be [n_blocks, block_size]; got shape {tuple(x.shape)}." block_size = x.shape[-1] n_blocks = x.numel() // block_size - if self._initial_amax.numel() != n_blocks: + if n_blocks != self._n_blocks: raise ValueError( - f"initial_amax.numel() ({self._initial_amax.numel()}) does not match " - f"the number of NVFP4 blocks ({n_blocks})." + f"initial amax.numel() ({self._n_blocks}) does not match the number " + f"of NVFP4 blocks in x ({n_blocks})." ) best_amax_flat = nvfp4_fp8_scale_sweep( @@ -260,10 +274,10 @@ def collect(self, x: torch.Tensor): self._global_amax, block_size=block_size, ) - # Match the original shape/dtype of _initial_amax so downstream load_calib_amax - # behaves identically to the reference path. - self._best_amax = best_amax_flat.reshape(self._initial_amax.shape).to( - self._initial_amax.dtype + # Match the original shape/dtype of the initial amax so downstream + # load_calib_amax behaves identically to the reference path. + self._best_amax = best_amax_flat.reshape(self._initial_amax_shape).to( + self._initial_amax_dtype ) @torch.no_grad() @@ -272,6 +286,6 @@ def compute_amax(self, verbose: bool = False): return self._best_amax def reset(self): - """Reset the stored best amax.""" + """Reset the stored best amax. Subsequent ``collect`` calls are allowed.""" self._best_amax = None super().reset() diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 62fadbb51a..cd86ff1c72 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -355,6 +355,11 @@ def mse_calibrate( weight_quantizers = [] seen_modules = set() + # Triton-fused FP8 sweep is on by default for NVFP4 static quant; set + # MODELOPT_NVFP4_TRITON_SWEEP=0 to fall back to the reference for debugging. + use_triton_fp8_sweep = os.environ.get("MODELOPT_NVFP4_TRITON_SWEEP", "1") != "0" + nvfp4_calibrator_cls = TritonNVFP4MSECalibrator if use_triton_fp8_sweep else NVFP4MSECalibrator + for name, module in list(model.named_modules()): if isinstance(module, TensorQuantizer) and not module._disabled: if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"): @@ -392,13 +397,7 @@ def mse_calibrate( continue if fp8_scale_sweep and is_nvfp4_static: - # Replace calibrator with the fused Triton sweep kernel by default - # (single-shot collect, ~7-20x faster for the weight-MSE phase). - # Set MODELOPT_NVFP4_TRITON_SWEEP=0 to fall back to the reference - # NVFP4MSECalibrator for debugging or numerics comparison. - use_triton = os.environ.get("MODELOPT_NVFP4_TRITON_SWEEP", "1") != "0" - cls = TritonNVFP4MSECalibrator if use_triton else NVFP4MSECalibrator - module._calibrator = cls( + module._calibrator = nvfp4_calibrator_cls( amax=initial_amax, axis=module._calibrator._axis, global_amax=module.global_amax, diff --git a/tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py b/tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py index f1ac5b7f24..c25867d832 100644 --- a/tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py +++ b/tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py @@ -140,18 +140,44 @@ def test_reset_allows_recollect(): cal.collect(x) first = cal.compute_amax().clone() - with pytest.raises(RuntimeError, match="single collect"): + # collect() is one-shot per cycle until reset() is called. + with pytest.raises(RuntimeError, match="one-shot"): cal.collect(x) cal.reset() - # After reset the calibrator's _initial_amax has been freed; reconstruct. - cal2 = TritonNVFP4MSECalibrator( - amax=per_block_amax, - axis=0, - global_amax=global_amax, - ) - cal2.collect(x) - assert torch.equal(first, cal2.compute_amax()) + # After reset, the same calibrator instance can be re-used. + cal.collect(x) + assert torch.equal(first, cal.compute_amax()) + + +@requires_triton +def test_input_validation(): + """``nvfp4_fp8_scale_sweep`` should reject malformed inputs cleanly.""" + from modelopt.torch.kernels.quantization.gemm import fp8_scale_candidates, nvfp4_fp8_scale_sweep + + device = "cuda" + x = torch.randn(64, BLOCK_SIZE, device=device) + g = x.abs().amax() + + # CPU tensor → ValueError (not bare AssertionError). + with pytest.raises(ValueError, match="CUDA"): + nvfp4_fp8_scale_sweep(x.cpu(), g.cpu()) + + # block_size <= 0. + with pytest.raises(ValueError, match="block_size"): + nvfp4_fp8_scale_sweep(x, g, block_size=0) + with pytest.raises(ValueError, match="block_size"): + nvfp4_fp8_scale_sweep(x, g, block_size=-1) + + # Non-divisible numel. + with pytest.raises(ValueError, match="not divisible"): + nvfp4_fp8_scale_sweep(x, g, block_size=15) + + # Empty / wrong-rank candidates. + with pytest.raises(ValueError, match="non-empty 1-D"): + nvfp4_fp8_scale_sweep(x, g, candidates=torch.empty(0, device=device)) + with pytest.raises(ValueError, match="non-empty 1-D"): + nvfp4_fp8_scale_sweep(x, g, candidates=fp8_scale_candidates(device).reshape(2, -1)) def _bench(fn, warmup=2, iters=5):