1 change: 1 addition & 0 deletions modelopt/torch/kernels/quantization/gemm/__init__.py
@@ -32,6 +32,7 @@
# fp4_kernel works on any CUDA GPU with triton
from .fp4_kernel import *
from .fp8_kernel import *
from .nvfp4_fp8_sweep import *

# fp4_kernel_hopper requires compute >= 8.9 (uses tl.float8e4nv)
if torch.cuda.get_device_capability() >= (8, 9):
142 changes: 142 additions & 0 deletions modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py
@@ -0,0 +1,142 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Fused Triton kernel for the NVFP4 weight-MSE FP8 scale sweep.

Replaces the 126-iteration Python sweep in :class:`NVFP4MSECalibrator` with a single
kernel that, for each NVFP4 block, evaluates all 126 valid FP8 E4M3 scale candidates
and emits the per-block ``best_amax`` directly.

The 126 candidates are constructed as ``valid_fp8_e4m3_value / 448`` (see
:func:`fp8_scale_candidates`). For these specific candidates, the FP8 round-trip on
the per-block scale is the identity, so the kernel can use
``scale = candidate * global_amax / 6.0`` without an explicit FP8 cast — making it
runnable on any CUDA GPU with Triton (no ``tl.float8e4nv`` requirement).
"""

import torch
import triton
import triton.language as tl

from .nvfp4_quant import nvfp4_scalar_quant

__all__ = ["fp8_scale_candidates", "nvfp4_fp8_scale_sweep"]


def fp8_scale_candidates(device: torch.device | str = "cpu") -> torch.Tensor:
"""Return the 126 valid finite positive FP8 E4M3 scale candidates / 448."""
Contributor comment (bot):
Minor duplication: this function reproduces the same logic as NVFP4MSECalibrator._generate_candidates() in calib/mse.py. Consider having one call the other (or extracting a shared utility) so the candidate generation stays in sync if the candidate set ever changes.
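A minimal sketch of the delegation this comment suggests; a hypothetical refactor, not part of this PR:

# In calib/mse.py: delegate candidate generation to the kernel-side helper
# so the two definitions cannot drift apart.
def _generate_candidates(self, device: torch.device) -> torch.Tensor:
    from modelopt.torch.kernels.quantization.gemm import fp8_scale_candidates

    return fp8_scale_candidates(device)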

uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device)
fp8_values = uint8_values.view(torch.float8_e4m3fn).float()
valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0)
return fp8_values[valid_mask] / 448.0
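The docstring makes two checkable claims: there are exactly 126 candidates, and the FP8 round-trip on the raw values they encode is the identity. A self-contained sketch in plain PyTorch (no Triton required) that verifies both:

import torch

# 128 non-negative E4M3 bit patterns, minus zero and NaN, leaves 126 values.
u = torch.arange(0, 128, dtype=torch.uint8)
f = u.view(torch.float8_e4m3fn).float()
f = f[torch.isfinite(f) & (f > 0)]
assert f.numel() == 126

# Each candidate times 448 is an exactly representable FP8 value, so casting
# to FP8 and back changes nothing; this is the identity the kernel relies on.
assert torch.equal(f.to(torch.float8_e4m3fn).float(), f)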


@triton.jit
def _fp8_scale_sweep_kernel(
x_ptr, # [N_BLOCKS * BLOCK_SIZE], any float dtype (loaded as fp32)
candidates_ptr, # [NUM_CANDIDATES] fp32
global_amax_ptr, # scalar fp32
best_amax_ptr, # [N_BLOCKS] fp32 output
N_BLOCKS,
BLOCK_SIZE: tl.constexpr,
NUM_CANDIDATES: tl.constexpr,
BLOCKS_PER_PROGRAM: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCKS_PER_PROGRAM
block_idx = block_start + tl.arange(0, BLOCKS_PER_PROGRAM)
block_mask = block_idx < N_BLOCKS

# Load weights for this tile: [BLOCKS_PER_PROGRAM, BLOCK_SIZE]
elem_offs = block_idx[:, None] * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)[None, :]
elem_mask = block_mask[:, None]
w = tl.load(x_ptr + elem_offs, mask=elem_mask, other=0.0).to(tl.float32)

global_amax = tl.load(global_amax_ptr).to(tl.float32)

best_loss = tl.full([BLOCKS_PER_PROGRAM], float("inf"), dtype=tl.float32)
best_idx = tl.zeros([BLOCKS_PER_PROGRAM], dtype=tl.int32)

# Loop over the 126 FP8 candidates (compile-time unrolled).
for k in tl.static_range(NUM_CANDIDATES):
c = tl.load(candidates_ptr + k).to(tl.float32)
# block_amax = global_amax * c by construction; the FP8 round on the resulting
# scale is the identity for our candidate set, so we can skip the FP8 cast.
scale = c * global_amax / 6.0
w_q = nvfp4_scalar_quant(w, scale, BLOCK_SIZE)
diff = w - w_q
loss = tl.sum(diff * diff, axis=1) # [BLOCKS_PER_PROGRAM]
is_better = loss < best_loss
Contributor comment:
This uses the minimum scale candidate for tie-breaking when multiple scales produce the same loss. We are discussing a better alternative that uses the median scale for tie-breaking; a reference sketch follows the kernel below. What are your thoughts?
Cc @jenchen13 @Fridah-nv

best_loss = tl.where(is_better, loss, best_loss)
best_idx = tl.where(is_better, k, best_idx)

# Map each block's winning candidate index back to its amax = global_amax * c[best].
best_c = tl.load(candidates_ptr + best_idx, mask=block_mask, other=0.0).to(tl.float32)
best_amax = global_amax * best_c
tl.store(best_amax_ptr + block_idx, best_amax, mask=block_mask)
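For the tie-breaking discussion above: a plain PyTorch reference of the median alternative, kept outside the kernel for clarity. The loss matrix, names, and exact-tie definition are assumptions for illustration, not part of this PR.

import torch

def argmin_with_median_ties(losses: torch.Tensor) -> torch.Tensor:
    # losses: [n_blocks, num_candidates], candidates sorted ascending.
    min_loss = losses.amin(dim=1, keepdim=True)
    tied = losses == min_loss  # exact ties; a tolerance could widen this set
    # Push non-tied slots to +inf, sort, then take the middle tied index.
    idx = torch.arange(losses.shape[1], device=losses.device, dtype=torch.float32)
    masked = torch.where(tied, idx.expand_as(tied), torch.full_like(losses, float("inf")))
    sorted_idx = masked.sort(dim=1).values
    mid = (tied.sum(dim=1) - 1) // 2
    return sorted_idx.gather(1, mid.unsqueeze(1)).squeeze(1).long()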


def nvfp4_fp8_scale_sweep(
x: torch.Tensor,
global_amax: torch.Tensor,
block_size: int = 16,
candidates: torch.Tensor | None = None,
blocks_per_program: int = 4,
) -> torch.Tensor:
"""Find the per-block FP8 scale that minimizes NVFP4 quantization MSE.

Equivalent to the 126-step sweep in :class:`NVFP4MSECalibrator`, but fused into
a single Triton kernel: every block's weight elements are loaded once, all 126
candidates are evaluated in registers, and the running argmin is kept inline.

Args:
x: Weight tensor on CUDA. Total element count must be divisible by
``block_size``; layout is treated as a flat ``[N_BLOCKS, BLOCK_SIZE]``.
global_amax: Scalar FP32 global amax (``= reduce_amax(per_block_amax)``).
block_size: NVFP4 block size (typically 16).
candidates: Optional precomputed candidate tensor of shape ``[126]`` (must
be the FP8 E4M3 valid values divided by 448). Built lazily if omitted.
blocks_per_program: Number of blocks each Triton program handles. Trades
launch overhead for register pressure; 4 is a reasonable default.

Returns:
``best_amax`` of shape ``[N_BLOCKS]``, fp32, on the same device as ``x``.
"""
assert x.is_cuda, "nvfp4_fp8_scale_sweep requires a CUDA tensor"
if x.numel() % block_size != 0:
raise ValueError(f"x.numel() ({x.numel()}) is not divisible by block_size ({block_size}).")

if candidates is None:
candidates = fp8_scale_candidates(x.device)
candidates = candidates.contiguous().to(device=x.device, dtype=torch.float32)

n_blocks = x.numel() // block_size
x_flat = x.contiguous().view(-1)
global_amax_f32 = global_amax.detach().to(device=x.device, dtype=torch.float32).reshape(1)
best_amax = torch.empty(n_blocks, dtype=torch.float32, device=x.device)

grid = (triton.cdiv(n_blocks, blocks_per_program),)
with torch.cuda.device(x.device):
_fp8_scale_sweep_kernel[grid](
x_flat,
candidates,
global_amax_f32,
best_amax,
n_blocks,
BLOCK_SIZE=block_size,
NUM_CANDIDATES=int(candidates.numel()),
BLOCKS_PER_PROGRAM=blocks_per_program,
)
return best_amax
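A minimal usage sketch of the new entry point; the weight shape is an assumption, and the signature is taken from the docstring above:

import torch
from modelopt.torch.kernels.quantization.gemm import nvfp4_fp8_scale_sweep

w = torch.randn(4096, 16, device="cuda")  # 4096 NVFP4 blocks of 16 elements
global_amax = w.abs().amax().reshape(1)   # scalar fp32 global amax
best_amax = nvfp4_fp8_scale_sweep(w, global_amax, block_size=16)
assert best_amax.shape == (4096,) and best_amax.dtype == torch.float32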
81 changes: 80 additions & 1 deletion modelopt/torch/quantization/calib/mse.py
@@ -24,7 +24,7 @@
from .. import utils as quant_utils
from .calibrator import _Calibrator

__all__ = ["MseCalibrator", "NVFP4MSECalibrator"]
__all__ = ["MseCalibrator", "NVFP4MSECalibrator", "TritonNVFP4MSECalibrator"]


class MseCalibrator(_Calibrator):
@@ -198,3 +198,82 @@ def _generate_candidates(self, device: torch.device) -> torch.Tensor:
valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0)
fp8_values = fp8_values[valid_mask]
return fp8_values / 448.0


class TritonNVFP4MSECalibrator(NVFP4MSECalibrator):
"""Triton-fused FP8 scale sweep calibrator for NVFP4 weight MSE.

Numerically equivalent to :class:`NVFP4MSECalibrator` but evaluates all 126
candidates in a single fused Triton kernel — one weight read instead of 126.

Limitation: a single ``collect()`` call is supported per ``compute_amax`` cycle.
This matches the static weight-MSE flow (``mse_calibrate``'s weight loop), where
the calibrator is collected once per weight and immediately consumed. For
activation calibration (multiple ``collect`` calls), use :class:`NVFP4MSECalibrator`.
"""

def __init__(
self,
amax: torch.Tensor,
global_amax: torch.Tensor,
axis: int | tuple | list | None = None,
quant_func: Callable | None = None,
error_func: Callable | None = None,
blocks_per_program: int = 4,
):
"""Initialize the Triton-fused NVFP4 MSE calibrator.

See :class:`NVFP4MSECalibrator`. ``quant_func``/``error_func`` are unused by
the kernel path but accepted for API parity.
"""
super().__init__(
amax=amax,
global_amax=global_amax,
axis=axis,
quant_func=quant_func,
error_func=error_func,
)
self._blocks_per_program = blocks_per_program
self._best_amax: torch.Tensor | None = None
Contributor comment (on lines +221 to +247):
Why do we need this? Why not inherit from the parent class?


@torch.no_grad()
def collect(self, x: torch.Tensor):
"""Run the fused FP8 sweep kernel and store the resulting per-block amax."""
from modelopt.torch.kernels.quantization.gemm import nvfp4_fp8_scale_sweep

if self._best_amax is not None:
raise RuntimeError(
"TritonNVFP4MSECalibrator only supports a single collect() per cycle; "
"call reset() before collecting again."
)

x = x.detach()
block_size = x.shape[-1]
n_blocks = x.numel() // block_size
if self._initial_amax.numel() != n_blocks:
raise ValueError(
Contributor comment (bot):
Nit: block_size = x.shape[-1] assumes the input tensor has already been reshaped to [n_blocks, block_size]. This is true for the current mse_calibrate weight flow, but could silently produce wrong results if someone uses this calibrator with a differently-shaped tensor. Consider adding a brief assertion or docstring note, e.g.:

assert x.ndim == 2, "Expected x to be [n_blocks, block_size] from the weight quantizer reshape"

f"initial_amax.numel() ({self._initial_amax.numel()}) does not match "
f"the number of NVFP4 blocks ({n_blocks})."
)

best_amax_flat = nvfp4_fp8_scale_sweep(
x,
self._global_amax,
block_size=block_size,
blocks_per_program=self._blocks_per_program,
)
# Match the original shape/dtype of _initial_amax so downstream load_calib_amax
# behaves identically to the reference path.
self._best_amax = best_amax_flat.reshape(self._initial_amax.shape).to(
self._initial_amax.dtype
)

@torch.no_grad()
def compute_amax(self, verbose: bool = False):
"""Return the per-block amax computed during ``collect``."""
return self._best_amax

def reset(self):
"""Reset the stored best amax."""
self._best_amax = None
super().reset()
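A single-cycle usage sketch matching the limitation described in the class docstring; tensor names are hypothetical and assume the parent constructor stores amax/global_amax as this diff implies:

# w: weight already reshaped to [n_blocks, 16] by the weight-quantizer flow.
per_block_amax = w.abs().amax(dim=-1)
calib = TritonNVFP4MSECalibrator(
    amax=per_block_amax,
    global_amax=per_block_amax.max().reshape(1),
)
calib.collect(w)                  # exactly one collect() per cycle
best_amax = calib.compute_amax()  # per-block amax chosen by the fused sweep
calib.reset()                     # required before collecting again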
12 changes: 9 additions & 3 deletions modelopt/torch/quantization/model_calib.py
@@ -16,6 +16,7 @@
"""Calibration utilities."""

import math
import os
import time
import warnings
from collections.abc import Callable
@@ -37,7 +38,7 @@
from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState
from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method

from .calib import MseCalibrator, NVFP4MSECalibrator, _Calibrator
from .calib import MseCalibrator, NVFP4MSECalibrator, TritonNVFP4MSECalibrator, _Calibrator
from .conversion import create_and_replace_svdquant_linear_on_the_fly, set_quantizer_by_cfg_context
from .nn import NVFP4StaticQuantizer, QuantModule, SequentialQuantizer, TensorQuantizer
from .utils import (
Expand Down Expand Up @@ -391,8 +392,13 @@ def mse_calibrate(
continue

if fp8_scale_sweep and is_nvfp4_static:
# Replace calibrator with NVFP4MSECalibrator
module._calibrator = NVFP4MSECalibrator(
# Replace calibrator with the fused Triton sweep kernel by default
# (single-shot collect, ~7-20x faster for the weight-MSE phase).
Contributor comment (bot):
The env var check os.environ.get("MODELOPT_NVFP4_TRITON_SWEEP", "1") != "0" is evaluated on every weight quantizer in the loop. Since it won't change mid-loop, consider hoisting it above the loop for clarity and minor efficiency. (A hoisted sketch follows this file's diff.)

# Set MODELOPT_NVFP4_TRITON_SWEEP=0 to fall back to the reference
# NVFP4MSECalibrator for debugging or numerics comparison.
use_triton = os.environ.get("MODELOPT_NVFP4_TRITON_SWEEP", "1") != "0"
cls = TritonNVFP4MSECalibrator if use_triton else NVFP4MSECalibrator
module._calibrator = cls(
amax=initial_amax,
axis=module._calibrator._axis,
global_amax=module.global_amax,
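A sketch of the hoisting the bot comment above suggests; the loop head is hypothetical and abbreviates the real mse_calibrate body:

# Evaluate the env var once, before iterating over the weight quantizers.
use_triton = os.environ.get("MODELOPT_NVFP4_TRITON_SWEEP", "1") != "0"
calibrator_cls = TritonNVFP4MSECalibrator if use_triton else NVFP4MSECalibrator

for name, module in model.named_modules():  # hypothetical loop head
    if fp8_scale_sweep and is_nvfp4_static:
        module._calibrator = calibrator_cls(
            amax=initial_amax,
            axis=module._calibrator._axis,
            global_amax=module.global_amax,
        )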