Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions modelopt/torch/kernels/quantization/gemm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
# fp4_kernel works on any CUDA GPU with triton
from .fp4_kernel import *
from .fp8_kernel import *
from .nvfp4_fp8_sweep import *

# fp4_kernel_hopper requires compute >= 8.9 (uses tl.float8e4nv)
if torch.cuda.get_device_capability() >= (8, 9):
Expand Down
159 changes: 159 additions & 0 deletions modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Fused Triton kernel for the NVFP4 weight-MSE FP8 scale sweep.

Replaces the 126-iteration Python sweep in :class:`NVFP4MSECalibrator` with a single
kernel that, for each NVFP4 block, evaluates all 126 valid FP8 E4M3 scale candidates
and emits the per-block ``best_amax`` directly.

The 126 candidates are constructed as ``valid_fp8_e4m3_value / 448`` (see
:func:`fp8_scale_candidates`). For these specific candidates, the FP8 round-trip on
the per-block scale is the identity, so the kernel can use
``scale = candidate * global_amax / 6.0`` without an explicit FP8 cast — making it
runnable on any CUDA GPU with Triton (no ``tl.float8e4nv`` requirement).

Tile shape (``BLOCKS_PER_PROGRAM``) and ``num_warps`` are autotuned per ``N_BLOCKS``.
"""

import torch
import triton
import triton.language as tl

from .nvfp4_quant import fp4_round_magnitude

__all__ = ["fp8_scale_candidates", "nvfp4_fp8_scale_sweep"]


def fp8_scale_candidates(device: torch.device | str = "cpu") -> torch.Tensor:
"""Return the 126 valid finite positive FP8 E4M3 scale candidates / 448."""
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bot comment.

Minor duplication: this function reproduces the same logic as NVFP4MSECalibrator._generate_candidates() in calib/mse.py. Consider having one call the other (or extracting a shared utility) so the candidate generation stays in sync if the candidate set ever changes.

uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device)
fp8_values = uint8_values.view(torch.float8_e4m3fn).float()
valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0)
return fp8_values[valid_mask] / 448.0


# Autotune search space for the FP8 sweep kernel, selected from a
# (BLOCKS_PER_PROGRAM, num_warps) sweep on B300:
#   BPP=16,nw=2: 6.06 ms   BPP=32,nw=4: 6.06 ms   BPP=64,nw=8: 5.08 ms
# The smaller-tile entries cover cases where N_BLOCKS is small enough that BPP=64
# would underfill the SMs. The autotuner picks per N_BLOCKS (see the
# ``key=["N_BLOCKS"]`` on the kernel below).
_FP8_SWEEP_AUTOTUNE_CONFIGS = [
    triton.Config({"BLOCKS_PER_PROGRAM": 16}, num_warps=2),
    triton.Config({"BLOCKS_PER_PROGRAM": 32}, num_warps=4),
    triton.Config({"BLOCKS_PER_PROGRAM": 64}, num_warps=8),
]


@triton.autotune(configs=_FP8_SWEEP_AUTOTUNE_CONFIGS, key=["N_BLOCKS"])
@triton.jit
def _fp8_scale_sweep_kernel(
    x_ptr,  # [N_BLOCKS * BLOCK_SIZE], any float dtype (loaded as fp32)
    candidates_ptr,  # [NUM_CANDIDATES] fp32
    global_amax_ptr,  # scalar fp32
    best_amax_ptr,  # [N_BLOCKS] fp32 output
    N_BLOCKS,
    BLOCK_SIZE: tl.constexpr,
    NUM_CANDIDATES: tl.constexpr,
    BLOCKS_PER_PROGRAM: tl.constexpr,
):
    """Per-block argmin over the FP8 scale candidates for NVFP4 weight MSE.

    Each program owns a tile of ``BLOCKS_PER_PROGRAM`` consecutive NVFP4 blocks:
    it loads the tile's elements once, then for every candidate scale computes
    the FP4 round-trip squared error and keeps a running per-block argmin.
    Finally it writes ``best_amax[b] = global_amax * candidates[argmin_b]``.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCKS_PER_PROGRAM
    block_idx = block_start + tl.arange(0, BLOCKS_PER_PROGRAM)
    # Mask off out-of-range blocks in the final (possibly partial) tile.
    block_mask = block_idx < N_BLOCKS

    # Load weights for this tile and pre-compute their absolute values once.
    # The squared error is sign-invariant since FP4 quant preserves sign:
    #     (w - w_q)^2 = (|w| - |w_q|)^2 = (|w| - q_mag * scale)^2
    # so we never need ``w`` itself again, dropping a tl.where + negation per element.
    elem_offs = block_idx[:, None] * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)[None, :]
    elem_mask = block_mask[:, None]
    w_abs = tl.abs(tl.load(x_ptr + elem_offs, mask=elem_mask, other=0.0).to(tl.float32))

    global_amax = tl.load(global_amax_ptr).to(tl.float32)

    # Running per-block minimum loss and the candidate index that achieved it.
    best_loss = tl.full([BLOCKS_PER_PROGRAM], float("inf"), dtype=tl.float32)
    best_idx = tl.zeros([BLOCKS_PER_PROGRAM], dtype=tl.int32)

    # Loop over the 126 FP8 candidates (compile-time unrolled).
    # Scales are guaranteed positive and finite (constructed from a positive candidate
    # times nonneg global_amax), so the degenerate-scale guard from nvfp4_scalar_quant is
    # unnecessary apart from the global_amax == 0 case handled below.
    for k in tl.static_range(NUM_CANDIDATES):
        c = tl.load(candidates_ptr + k).to(tl.float32)
        # Candidate scale in NVFP4 terms: amax_candidate / 6.0 (FP4 max magnitude).
        scale = c * global_amax / 6.0
        # Avoid divide-by-zero when global_amax == 0; the resulting err == w_abs² is
        # the same for every candidate, so any best_idx is fine.
        scale_safe = tl.where(scale == 0.0, 1.0, scale)
        q_mag = fp4_round_magnitude(w_abs / scale_safe)
        # Note: diff uses the real ``scale`` (not scale_safe) so the error stays exact.
        diff = w_abs - q_mag * scale
        loss = tl.sum(diff * diff, axis=1)  # [BLOCKS_PER_PROGRAM]
        is_better = loss < best_loss
        best_loss = tl.where(is_better, loss, best_loss)
        best_idx = tl.where(is_better, k, best_idx)

    # Map each block's winning candidate index back to its amax = global_amax * c[best].
    best_c = tl.load(candidates_ptr + best_idx, mask=block_mask, other=0.0).to(tl.float32)
    best_amax = global_amax * best_c
    tl.store(best_amax_ptr + block_idx, best_amax, mask=block_mask)


def nvfp4_fp8_scale_sweep(
    x: torch.Tensor,
    global_amax: torch.Tensor,
    block_size: int = 16,
    candidates: torch.Tensor | None = None,
) -> torch.Tensor:
    """Find the per-block FP8 scale that minimizes NVFP4 quantization MSE.

    Equivalent to the 126-step sweep in :class:`NVFP4MSECalibrator`, but fused into
    a single Triton kernel: every block's weight elements are loaded once, all 126
    candidates are evaluated in registers, and the running argmin is kept inline.

    Args:
        x: Weight tensor on CUDA. Total element count must be divisible by
            ``block_size``; layout is treated as a flat ``[N_BLOCKS, BLOCK_SIZE]``.
        global_amax: Scalar FP32 global amax (``= reduce_amax(per_block_amax)``).
        block_size: NVFP4 block size (typically 16).
        candidates: Optional precomputed candidate tensor of shape ``[126]`` (must
            be the FP8 E4M3 valid values divided by 448). Built lazily if omitted.

    Returns:
        ``best_amax`` of shape ``[N_BLOCKS]``, fp32, on the same device as ``x``.
    """
    assert x.is_cuda, "nvfp4_fp8_scale_sweep requires a CUDA tensor"
    if x.numel() % block_size != 0:
        raise ValueError(f"x.numel() ({x.numel()}) is not divisible by block_size ({block_size}).")

    # Build the candidate table on demand, then force the layout/dtype the
    # kernel expects regardless of what the caller handed in.
    cand = fp8_scale_candidates(x.device) if candidates is None else candidates
    cand = cand.contiguous().to(device=x.device, dtype=torch.float32)

    num_blocks = x.numel() // block_size
    flat_weights = x.contiguous().view(-1)
    # The kernel reads global_amax through a pointer, so stage it as a 1-element
    # fp32 tensor on the same device.
    amax_scalar = global_amax.detach().to(device=x.device, dtype=torch.float32).reshape(1)
    out = torch.empty(num_blocks, dtype=torch.float32, device=x.device)

    def _grid(meta):
        # One program per tile of BLOCKS_PER_PROGRAM blocks (autotuned).
        return (triton.cdiv(num_blocks, meta["BLOCKS_PER_PROGRAM"]),)

    with torch.cuda.device(x.device):
        _fp8_scale_sweep_kernel[_grid](
            flat_weights,
            cand,
            amax_scalar,
            out,
            num_blocks,
            BLOCK_SIZE=block_size,
            NUM_CANDIDATES=int(cand.numel()),
        )
    return out
79 changes: 78 additions & 1 deletion modelopt/torch/quantization/calib/mse.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from .. import utils as quant_utils
from .calibrator import _Calibrator

__all__ = ["MseCalibrator", "NVFP4MSECalibrator"]
__all__ = ["MseCalibrator", "NVFP4MSECalibrator", "TritonNVFP4MSECalibrator"]


class MseCalibrator(_Calibrator):
Expand Down Expand Up @@ -198,3 +198,80 @@ def _generate_candidates(self, device: torch.device) -> torch.Tensor:
valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0)
fp8_values = fp8_values[valid_mask]
return fp8_values / 448.0


class TritonNVFP4MSECalibrator(NVFP4MSECalibrator):
    """Triton-fused FP8 scale sweep calibrator for NVFP4 weight MSE.

    Numerically equivalent to :class:`NVFP4MSECalibrator` but evaluates all 126
    candidates in a single fused Triton kernel — one weight read instead of 126.

    Limitations:
        * A single ``collect()`` call is supported per ``compute_amax`` cycle.
          This matches the static weight-MSE flow (``mse_calibrate``'s weight
          loop), where the calibrator is collected once per weight and
          immediately consumed. For activation calibration (multiple
          ``collect`` calls), use :class:`NVFP4MSECalibrator`.
        * ``collect()`` infers the NVFP4 block size from ``x.shape[-1]``, so the
          input must already be reshaped to ``[n_blocks, block_size]``. A wrong
          layout is only caught when the resulting block count disagrees with
          ``amax`` — see the note in ``collect``.
    """

    def __init__(
        self,
        amax: torch.Tensor,
        global_amax: torch.Tensor,
        axis: int | tuple | list | None = None,
        quant_func: Callable | None = None,
        error_func: Callable | None = None,
    ):
        """Initialize the Triton-fused NVFP4 MSE calibrator.

        See :class:`NVFP4MSECalibrator`. ``quant_func``/``error_func`` are unused by
        the kernel path but accepted for API parity. Tile shape and ``num_warps`` are
        autotuned by the kernel per ``N_BLOCKS``.
        """
        super().__init__(
            amax=amax,
            global_amax=global_amax,
            axis=axis,
            quant_func=quant_func,
            error_func=error_func,
        )
        # Per-block best amax produced by collect(); None until collect() runs.
        self._best_amax: torch.Tensor | None = None

    @torch.no_grad()
    def collect(self, x: torch.Tensor):
        """Run the fused FP8 sweep kernel and store the resulting per-block amax.

        Args:
            x: Weight tensor on CUDA. The trailing dimension is taken as the
                NVFP4 block size, i.e. ``x`` is expected as
                ``[n_blocks, block_size]``.

        Raises:
            RuntimeError: If called a second time before ``reset()``.
            ValueError: If the inferred block count does not match the stored
                per-block ``amax``.
        """
        # Local import keeps the Triton kernel dependency off the module import path.
        from modelopt.torch.kernels.quantization.gemm import nvfp4_fp8_scale_sweep

        if self._best_amax is not None:
            raise RuntimeError(
                "TritonNVFP4MSECalibrator only supports a single collect() per cycle; "
                "call reset() before collecting again."
            )

        x = x.detach()
        # NOTE(review): block size is inferred from the trailing dim — this assumes
        # the caller already reshaped x to [n_blocks, block_size] (true for the
        # current mse_calibrate weight flow); confirm for any new call site. A
        # mismatched layout is only detected below when the element counts differ.
        block_size = x.shape[-1]
        n_blocks = x.numel() // block_size
        if self._initial_amax.numel() != n_blocks:
            raise ValueError(
                f"initial_amax.numel() ({self._initial_amax.numel()}) does not match "
                f"the number of NVFP4 blocks ({n_blocks})."
            )

        best_amax_flat = nvfp4_fp8_scale_sweep(
            x,
            self._global_amax,
            block_size=block_size,
        )
        # Match the original shape/dtype of _initial_amax so downstream load_calib_amax
        # behaves identically to the reference path.
        self._best_amax = best_amax_flat.reshape(self._initial_amax.shape).to(
            self._initial_amax.dtype
        )

    @torch.no_grad()
    def compute_amax(self, verbose: bool = False):
        """Return the per-block amax computed during ``collect`` (None if not collected)."""
        return self._best_amax

    def reset(self):
        """Reset the stored best amax and the parent calibrator state."""
        self._best_amax = None
        super().reset()
12 changes: 9 additions & 3 deletions modelopt/torch/quantization/model_calib.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""Calibration utilities."""

import math
import os
import time
import warnings
from collections.abc import Callable
Expand All @@ -37,7 +38,7 @@
from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState
from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method

from .calib import MseCalibrator, NVFP4MSECalibrator, _Calibrator
from .calib import MseCalibrator, NVFP4MSECalibrator, TritonNVFP4MSECalibrator, _Calibrator
from .conversion import create_and_replace_svdquant_linear_on_the_fly, set_quantizer_by_cfg_context
from .nn import NVFP4StaticQuantizer, QuantModule, SequentialQuantizer, TensorQuantizer
from .utils import (
Expand Down Expand Up @@ -391,8 +392,13 @@ def mse_calibrate(
continue

if fp8_scale_sweep and is_nvfp4_static:
# Replace calibrator with NVFP4MSECalibrator
module._calibrator = NVFP4MSECalibrator(
# Replace calibrator with the fused Triton sweep kernel by default
# (single-shot collect, ~7-20x faster for the weight-MSE phase).
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bot comment.

The env var check os.environ.get("MODELOPT_NVFP4_TRITON_SWEEP", "1") != "0" is evaluated on every weight quantizer in the loop. Since it won't change mid-loop, consider hoisting it above the loop for clarity and minor efficiency.

# Set MODELOPT_NVFP4_TRITON_SWEEP=0 to fall back to the reference
# NVFP4MSECalibrator for debugging or numerics comparison.
use_triton = os.environ.get("MODELOPT_NVFP4_TRITON_SWEEP", "1") != "0"
cls = TritonNVFP4MSECalibrator if use_triton else NVFP4MSECalibrator
module._calibrator = cls(
amax=initial_amax,
axis=module._calibrator._axis,
global_amax=module.global_amax,
Expand Down
Loading
Loading