-
Notifications
You must be signed in to change notification settings - Fork 382
[Quantization] Fused Triton kernel for NVFP4 FP8 scale sweep search #1387
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,159 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """Fused Triton kernel for the NVFP4 weight-MSE FP8 scale sweep. | ||
|
|
||
| Replaces the 126-iteration Python sweep in :class:`NVFP4MSECalibrator` with a single | ||
| kernel that, for each NVFP4 block, evaluates all 126 valid FP8 E4M3 scale candidates | ||
| and emits the per-block ``best_amax`` directly. | ||
|
|
||
| The 126 candidates are constructed as ``valid_fp8_e4m3_value / 448`` (see | ||
| :func:`fp8_scale_candidates`). For these specific candidates, the FP8 round-trip on | ||
| the per-block scale is the identity, so the kernel can use | ||
| ``scale = candidate * global_amax / 6.0`` without an explicit FP8 cast — making it | ||
| runnable on any CUDA GPU with Triton (no ``tl.float8e4nv`` requirement). | ||
|
|
||
| Tile shape (``BLOCKS_PER_PROGRAM``) and ``num_warps`` are autotuned per ``N_BLOCKS``. | ||
| """ | ||
|
|
||
| import torch | ||
| import triton | ||
| import triton.language as tl | ||
|
|
||
| from .nvfp4_quant import fp4_round_magnitude | ||
|
|
||
| __all__ = ["fp8_scale_candidates", "nvfp4_fp8_scale_sweep"] | ||
|
|
||
|
|
||
| def fp8_scale_candidates(device: torch.device | str = "cpu") -> torch.Tensor: | ||
| """Return the 126 valid finite positive FP8 E4M3 scale candidates / 448.""" | ||
| uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device) | ||
| fp8_values = uint8_values.view(torch.float8_e4m3fn).float() | ||
| valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0) | ||
| return fp8_values[valid_mask] / 448.0 | ||
|
|
||
|
|
||
# Autotune search space for the sweep kernel, keyed on N_BLOCKS.
# Selected from a (BLOCKS_PER_PROGRAM, num_warps) sweep on B300:
#   BPP=16,nw=2: 6.06 ms   BPP=32,nw=4: 6.06 ms   BPP=64,nw=8: 5.08 ms
# The smaller-tile entries cover cases where N_BLOCKS is small enough that BPP=64
# would underfill the SMs.
_FP8_SWEEP_AUTOTUNE_CONFIGS = [
    triton.Config({"BLOCKS_PER_PROGRAM": 16}, num_warps=2),
    triton.Config({"BLOCKS_PER_PROGRAM": 32}, num_warps=4),
    triton.Config({"BLOCKS_PER_PROGRAM": 64}, num_warps=8),
]
|
|
||
|
|
||
@triton.autotune(configs=_FP8_SWEEP_AUTOTUNE_CONFIGS, key=["N_BLOCKS"])
@triton.jit
def _fp8_scale_sweep_kernel(
    x_ptr,  # [N_BLOCKS * BLOCK_SIZE], any float dtype (loaded as fp32)
    candidates_ptr,  # [NUM_CANDIDATES] fp32
    global_amax_ptr,  # scalar fp32
    best_amax_ptr,  # [N_BLOCKS] fp32 output
    N_BLOCKS,
    BLOCK_SIZE: tl.constexpr,
    NUM_CANDIDATES: tl.constexpr,
    BLOCKS_PER_PROGRAM: tl.constexpr,
):
    # Each program owns BLOCKS_PER_PROGRAM consecutive NVFP4 blocks; the tail
    # program may own fewer, handled by block_mask below.
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCKS_PER_PROGRAM
    block_idx = block_start + tl.arange(0, BLOCKS_PER_PROGRAM)
    block_mask = block_idx < N_BLOCKS

    # Load weights for this tile and pre-compute their absolute values once.
    # The squared error is sign-invariant since FP4 quant preserves sign:
    #   (w - w_q)^2 = (|w| - |w_q|)^2 = (|w| - q_mag * scale)^2
    # so we never need ``w`` itself again, dropping a tl.where + negation per element.
    # Out-of-range lanes load 0.0, which contributes 0 to the per-block loss.
    elem_offs = block_idx[:, None] * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)[None, :]
    elem_mask = block_mask[:, None]
    w_abs = tl.abs(tl.load(x_ptr + elem_offs, mask=elem_mask, other=0.0).to(tl.float32))

    global_amax = tl.load(global_amax_ptr).to(tl.float32)

    # Running argmin state: per-block best loss (init +inf) and winning candidate index.
    best_loss = tl.full([BLOCKS_PER_PROGRAM], float("inf"), dtype=tl.float32)
    best_idx = tl.zeros([BLOCKS_PER_PROGRAM], dtype=tl.int32)

    # Loop over the 126 FP8 candidates (compile-time unrolled).
    # Scales are guaranteed positive and finite (constructed from a positive candidate
    # times nonneg global_amax), so the degenerate-scale guard from nvfp4_scalar_quant is
    # unnecessary apart from the global_amax == 0 case handled below.
    for k in tl.static_range(NUM_CANDIDATES):
        c = tl.load(candidates_ptr + k).to(tl.float32)
        scale = c * global_amax / 6.0
        # Avoid divide-by-zero when global_amax == 0; the resulting err == w_abs² is
        # the same for every candidate, so any best_idx is fine.
        scale_safe = tl.where(scale == 0.0, 1.0, scale)
        q_mag = fp4_round_magnitude(w_abs / scale_safe)
        diff = w_abs - q_mag * scale
        loss = tl.sum(diff * diff, axis=1)  # [BLOCKS_PER_PROGRAM]
        is_better = loss < best_loss
        best_loss = tl.where(is_better, loss, best_loss)
        best_idx = tl.where(is_better, k, best_idx)

    # Map each block's winning candidate index back to its amax = global_amax * c[best].
    best_c = tl.load(candidates_ptr + best_idx, mask=block_mask, other=0.0).to(tl.float32)
    best_amax = global_amax * best_c
    tl.store(best_amax_ptr + block_idx, best_amax, mask=block_mask)
|
|
||
|
|
||
def nvfp4_fp8_scale_sweep(
    x: torch.Tensor,
    global_amax: torch.Tensor,
    block_size: int = 16,
    candidates: torch.Tensor | None = None,
) -> torch.Tensor:
    """Find the per-block FP8 scale that minimizes NVFP4 quantization MSE.

    Equivalent to the 126-step sweep in :class:`NVFP4MSECalibrator`, but fused into
    a single Triton kernel: every block's weight elements are loaded once, all 126
    candidates are evaluated in registers, and the running argmin is kept inline.

    Args:
        x: Weight tensor on CUDA. Total element count must be divisible by
            ``block_size``; layout is treated as a flat ``[N_BLOCKS, BLOCK_SIZE]``.
        global_amax: Scalar FP32 global amax (``= reduce_amax(per_block_amax)``).
        block_size: NVFP4 block size (typically 16).
        candidates: Optional precomputed candidate tensor of shape ``[126]`` (must
            be the FP8 E4M3 valid values divided by 448). Built lazily if omitted.

    Returns:
        ``best_amax`` of shape ``[N_BLOCKS]``, fp32, on the same device as ``x``.

    Raises:
        ValueError: If ``x`` is not a CUDA tensor, or its element count is not
            divisible by ``block_size``.
    """
    # Validate with real exceptions, not `assert`: asserts are stripped under
    # `python -O`, and a CPU tensor would otherwise fail opaquely inside the
    # Triton launch instead of at the API boundary.
    if not x.is_cuda:
        raise ValueError("nvfp4_fp8_scale_sweep requires a CUDA tensor")
    if x.numel() % block_size != 0:
        raise ValueError(f"x.numel() ({x.numel()}) is not divisible by block_size ({block_size}).")

    if candidates is None:
        candidates = fp8_scale_candidates(x.device)
    # Normalize caller-supplied candidates to contiguous fp32 on x's device.
    candidates = candidates.contiguous().to(device=x.device, dtype=torch.float32)

    n_blocks = x.numel() // block_size
    x_flat = x.contiguous().view(-1)
    global_amax_f32 = global_amax.detach().to(device=x.device, dtype=torch.float32).reshape(1)
    best_amax = torch.empty(n_blocks, dtype=torch.float32, device=x.device)

    # One program per tile; the autotuner picks BLOCKS_PER_PROGRAM (and num_warps)
    # per N_BLOCKS, so the grid is a function of the selected config.
    grid = lambda meta: (triton.cdiv(n_blocks, meta["BLOCKS_PER_PROGRAM"]),)
    with torch.cuda.device(x.device):
        _fp8_scale_sweep_kernel[grid](
            x_flat,
            candidates,
            global_amax_f32,
            best_amax,
            n_blocks,
            BLOCK_SIZE=block_size,
            NUM_CANDIDATES=int(candidates.numel()),
        )

    return best_amax
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,7 +24,7 @@ | |
| from .. import utils as quant_utils | ||
| from .calibrator import _Calibrator | ||
|
|
||
| __all__ = ["MseCalibrator", "NVFP4MSECalibrator"] | ||
| __all__ = ["MseCalibrator", "NVFP4MSECalibrator", "TritonNVFP4MSECalibrator"] | ||
|
|
||
|
|
||
| class MseCalibrator(_Calibrator): | ||
|
|
@@ -198,3 +198,80 @@ def _generate_candidates(self, device: torch.device) -> torch.Tensor: | |
| valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0) | ||
| fp8_values = fp8_values[valid_mask] | ||
| return fp8_values / 448.0 | ||
|
|
||
|
|
||
class TritonNVFP4MSECalibrator(NVFP4MSECalibrator):
    """Triton-fused FP8 scale sweep calibrator for NVFP4 weight MSE.

    Numerically equivalent to :class:`NVFP4MSECalibrator` but evaluates all 126
    candidates in a single fused Triton kernel — one weight read instead of 126.

    Limitation: a single ``collect()`` call is supported per ``compute_amax`` cycle.
    This matches the static weight-MSE flow (``mse_calibrate``'s weight loop), where
    the calibrator is collected once per weight and immediately consumed. For
    activation calibration (multiple ``collect`` calls), use :class:`NVFP4MSECalibrator`.
    """

    def __init__(
        self,
        amax: torch.Tensor,
        global_amax: torch.Tensor,
        axis: int | tuple | list | None = None,
        quant_func: Callable | None = None,
        error_func: Callable | None = None,
    ):
        """Initialize the Triton-fused NVFP4 MSE calibrator.

        See :class:`NVFP4MSECalibrator`. ``quant_func``/``error_func`` are unused by
        the kernel path but accepted for API parity. Tile shape and ``num_warps`` are
        autotuned by the kernel per ``N_BLOCKS``.
        """
        super().__init__(
            amax=amax,
            global_amax=global_amax,
            axis=axis,
            quant_func=quant_func,
            error_func=error_func,
        )
        # Per-block best amax produced by collect(); None until collect() runs.
        self._best_amax: torch.Tensor | None = None

    @torch.no_grad()
    def collect(self, x: torch.Tensor):
        """Run the fused FP8 sweep kernel and store the resulting per-block amax.

        Raises:
            RuntimeError: If called a second time without an intervening ``reset()``.
            ValueError: If ``x`` is not the expected 2-D ``[n_blocks, block_size]``
                view, or the block count disagrees with ``_initial_amax``.
        """
        # Local import keeps the Triton dependency off the module import path.
        from modelopt.torch.kernels.quantization.gemm import nvfp4_fp8_scale_sweep

        if self._best_amax is not None:
            raise RuntimeError(
                "TritonNVFP4MSECalibrator only supports a single collect() per cycle; "
                "call reset() before collecting again."
            )

        x = x.detach()
        # The weight quantizer hands us a pre-reshaped [n_blocks, block_size] view.
        # block_size is inferred from the last dim, so fail loudly on any other rank
        # rather than silently mis-grouping elements.
        if x.ndim != 2:
            raise ValueError(
                f"Expected x to be 2-D [n_blocks, block_size] from the weight "
                f"quantizer reshape, got {x.ndim}-D."
            )
        block_size = x.shape[-1]
        n_blocks = x.numel() // block_size
        if self._initial_amax.numel() != n_blocks:
            raise ValueError(
                f"initial_amax.numel() ({self._initial_amax.numel()}) does not match "
                f"the number of NVFP4 blocks ({n_blocks})."
            )

        best_amax_flat = nvfp4_fp8_scale_sweep(
            x,
            self._global_amax,
            block_size=block_size,
        )
        # Match the original shape/dtype of _initial_amax so downstream load_calib_amax
        # behaves identically to the reference path.
        self._best_amax = best_amax_flat.reshape(self._initial_amax.shape).to(
            self._initial_amax.dtype
        )

    @torch.no_grad()
    def compute_amax(self, verbose: bool = False):
        """Return the per-block amax computed during ``collect`` (None if not collected)."""
        return self._best_amax

    def reset(self):
        """Reset the stored best amax, then reset base-class state."""
        self._best_amax = None
        super().reset()
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,6 +16,7 @@ | |
| """Calibration utilities.""" | ||
|
|
||
| import math | ||
| import os | ||
| import time | ||
| import warnings | ||
| from collections.abc import Callable | ||
|
|
@@ -37,7 +38,7 @@ | |
| from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState | ||
| from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method | ||
|
|
||
| from .calib import MseCalibrator, NVFP4MSECalibrator, _Calibrator | ||
| from .calib import MseCalibrator, NVFP4MSECalibrator, TritonNVFP4MSECalibrator, _Calibrator | ||
| from .conversion import create_and_replace_svdquant_linear_on_the_fly, set_quantizer_by_cfg_context | ||
| from .nn import NVFP4StaticQuantizer, QuantModule, SequentialQuantizer, TensorQuantizer | ||
| from .utils import ( | ||
|
|
@@ -391,8 +392,13 @@ def mse_calibrate( | |
| continue | ||
|
|
||
| if fp8_scale_sweep and is_nvfp4_static: | ||
| # Replace calibrator with NVFP4MSECalibrator | ||
| module._calibrator = NVFP4MSECalibrator( | ||
| # Replace calibrator with the fused Triton sweep kernel by default | ||
| # (single-shot collect, ~7-20x faster for the weight-MSE phase). | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The env var check |
||
| # Set MODELOPT_NVFP4_TRITON_SWEEP=0 to fall back to the reference | ||
| # NVFP4MSECalibrator for debugging or numerics comparison. | ||
| use_triton = os.environ.get("MODELOPT_NVFP4_TRITON_SWEEP", "1") != "0" | ||
| cls = TritonNVFP4MSECalibrator if use_triton else NVFP4MSECalibrator | ||
| module._calibrator = cls( | ||
| amax=initial_amax, | ||
| axis=module._calibrator._axis, | ||
| global_amax=module.global_amax, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Minor duplication: this function reproduces the same logic as
`NVFP4MSECalibrator._generate_candidates()` in `calib/mse.py`. Consider having one call the other (or extracting a shared utility) so the candidate generation stays in sync if the candidate set ever changes.