Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions nemo_rl/models/generation/vllm/quantization/fp8_train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,98 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

FP8_WEIGHT_BLOCK_SIZE = [128, 128]


def should_quantize_to_fp8(name: str, tensor: torch.Tensor) -> bool:
    """Check whether a HuggingFace-named weight should be block-quantized to FP8.

    Matches the same set of parameters that vLLM quantizes (linear-layer
    weights only). Embeddings, layernorms, biases, and lm_head are excluded.

    Args:
        name: HuggingFace-style parameter name (e.g. ``model.layers.0.mlp.up_proj.weight``).
        tensor: The parameter tensor; only 2-D ``.weight`` tensors qualify.

    Returns:
        True when the parameter is an FP8-quantizable linear weight.
    """
    # Only 2-D ``.weight`` tensors are candidates (biases and 1-D params are not).
    if tensor.dim() != 2 or not name.endswith(".weight"):
        return False
    # Exclude non-linear layers by name, mirroring vLLM's quantization set.
    lowered = name.lower()
    return all(keyword not in lowered for keyword in ("norm", "embed", "lm_head"))


def cast_tensor_to_fp8_blockwise(
data_hp: torch.Tensor,
weight_block_size: list[int],
use_pow2_scale: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Block-wise FP8 (E4M3) quantization — standalone, no vLLM dependencies.

Args:
data_hp: 2-D high-precision weight tensor (any float dtype).
weight_block_size: [block_rows, block_cols], e.g. [128, 128].
use_pow2_scale: If True, round scale factors to powers of two.

Returns:
(fp8_data, descale) where fp8_data has dtype float8_e4m3fn and
descale is float32 with shape (blk_m, blk_n, 1).
"""
assert len(data_hp.shape) == 2, "Only 2-D input tensor is supported"

block_size0, block_size1 = weight_block_size
shape_before_padding = data_hp.shape

if data_hp.shape[0] % block_size0 != 0 or data_hp.shape[1] % block_size1 != 0:
pad0 = (block_size0 - data_hp.shape[0] % block_size0) % block_size0
pad1 = (block_size1 - data_hp.shape[1] % block_size1) % block_size1
data_hp = torch.nn.functional.pad(
data_hp, (0, pad1, 0, pad0), mode="constant", value=data_hp[-1, -1]
)

max_dtype = torch.finfo(torch.float8_e4m3fn).max
original_shape = data_hp.shape
blk_m = data_hp.shape[0] // block_size0
blk_n = data_hp.shape[1] // block_size1

assert block_size0 == block_size1
data_hp = data_hp.reshape(blk_m, block_size0, blk_n, block_size1)
data_hp = data_hp.permute(0, 2, 1, 3)
data_hp = data_hp.to(torch.float32).contiguous().flatten(start_dim=2)

max_abs = torch.amax(torch.abs(data_hp), dim=-1, keepdim=True)

if use_pow2_scale:
descale = max_abs / max_dtype
exponent = torch.ceil(torch.log2(descale))
exponent = torch.clamp(exponent, min=-127, max=127) + 127
exponent = exponent.to(torch.uint8)
scale_fp = torch.where(
exponent == 0,
1.0,
torch.exp2(127 - exponent.to(torch.float32)),
)
descale_fp = torch.reciprocal(scale_fp)
else:
scale_fp = max_dtype / max_abs
scale_fp = torch.where(max_abs == 0, 1.0, scale_fp)
scale_fp = torch.where(max_abs == torch.inf, 1.0, scale_fp)
descale_fp = torch.reciprocal(scale_fp)
Comment on lines +71 to +75
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

NaN values are not handled in scale computation.

The linear scale path handles max_abs == 0 and max_abs == inf, but NaN values would propagate silently. If any block contains NaN, both scale_fp and the resulting fp8_data would be NaN.

🛡️ Suggested fix to handle NaN
     else:
         scale_fp = max_dtype / max_abs
         scale_fp = torch.where(max_abs == 0, 1.0, scale_fp)
         scale_fp = torch.where(max_abs == torch.inf, 1.0, scale_fp)
+        scale_fp = torch.where(torch.isnan(max_abs), 1.0, scale_fp)
         descale_fp = torch.reciprocal(scale_fp)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@nemo_rl/models/generation/vllm/quantization/fp8_train_utils.py` around lines
87 - 91, The scale computation doesn't guard against NaN in max_abs which will
produce NaN scale_fp and downstream fp8_data; update the branch that computes
scale_fp/descale_fp to also replace NaNs (e.g., using torch.isnan or
torch.isfinite) with a safe fallback (1.0) before taking the reciprocal so that
scale_fp = max_dtype / max_abs is followed by handling max_abs == 0, max_abs ==
inf, and max_abs == NaN (set those scale entries to 1.0), then compute
descale_fp = torch.reciprocal(scale_fp); modify the existing scale_fp and
descale_fp logic where those symbols are defined to include the NaN check.


data_lp = torch.clamp(data_hp * scale_fp, min=-max_dtype, max=max_dtype)
fp_data = data_lp.to(torch.float8_e4m3fn)

fp_data = (
fp_data.reshape(blk_m, blk_n, block_size0, block_size1)
.permute(0, 2, 1, 3)
.reshape(original_shape)
)

if original_shape != shape_before_padding:
fp_data = fp_data[: shape_before_padding[0], : shape_before_padding[1]]

return fp_data, descale_fp


def get_vllm_qkv_scale_names(layer_idx: int) -> dict[str, str]:
"""Get vLLM-compatible parameter names for Q/K/V FP8 scales.
Expand Down
50 changes: 23 additions & 27 deletions nemo_rl/models/generation/vllm/vllm_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,14 +173,9 @@ def update_weights_via_ipc_zmq(self) -> bool:
assert offset == used_bytes, (
"Offset is not equal to used bytes, usually indicate inaccurate info like keys or cached dtype in state_dict_info"
)
# Load weights into the model
from nemo_rl.models.generation.vllm.quantization import fp8

if fp8.is_fp8_model(self.model_runner.vllm_config):
# the fp8 load_weights additionally casts bf16 weights into fp8
fp8.load_weights(weights, self.model_runner)
else:
self.model_runner.model.load_weights(weights=weights)
# Load weights into the model (re-uses the shared helper that
# skips FP8 quantization when weights arrive pre-quantized).
self._load_model_weights(weights, self.model_runner)

torch.cuda.current_stream().synchronize()

Expand All @@ -207,6 +202,25 @@ def update_weights_via_ipc_zmq(self) -> bool:
)
return False

@staticmethod
def _load_model_weights(weights, model_runner):
    """Load model weights, skipping FP8 quantization when pre-quantized.

    If the incoming batch already contains ``_scale_inv`` entries (i.e.
    weights were quantized on the training worker), we feed them straight
    to ``model.load_weights`` and avoid a redundant quantization pass.
    """
    from nemo_rl.models.generation.vllm.quantization import fp8

    if not fp8.is_fp8_model(model_runner.vllm_config):
        # Non-FP8 models always take the plain loading path.
        model_runner.model.load_weights(weights=weights)
        return

    has_precomputed_scales = any(
        name.endswith("_scale_inv") for name, _ in weights
    )
    if has_precomputed_scales:
        model_runner.model.load_weights(weights=weights)
    else:
        # fp8.load_weights additionally casts bf16 weights into fp8.
        fp8.load_weights(weights, model_runner)

@wrap_with_nvtx_name(
"vllm_internal_worker_extension/update_weights_from_collective"
)
Expand All @@ -217,25 +231,7 @@ def update_weights_from_collective(self) -> bool:
"Please call prepare_refit_info when initializing the worker."
)

def _load_model_weights(weights, model_runner):
"""Load model weights.

Args:
weights: List[(name, tensor)]
model_runner: vLLM ModelRunner

Returns:
None
"""
from nemo_rl.models.generation.vllm.quantization import fp8

if fp8.is_fp8_model(model_runner.vllm_config):
# the fp8 load_weights additionally casts bf16 weights into fp8
fp8.load_weights(weights, model_runner)
else:
model_runner.model.load_weights(weights=weights)

load_model_weight_func = lambda x: _load_model_weights(x, self.model_runner)
load_model_weight_func = lambda x: self._load_model_weights(x, self.model_runner)

try:
packed_broadcast_consumer(
Expand Down
38 changes: 36 additions & 2 deletions nemo_rl/models/policy/workers/megatron_policy_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,28 +965,62 @@ def calculate_size_in_bytes(param, tp_size, ep_size):
)
return param_info

def _is_fp8_weights_enabled(self) -> bool:
"""Check if the generation side uses FP8 weight quantization."""
if (
"generation" in self.cfg
and self.cfg["generation"] is not None
and self.cfg["generation"]["backend"] == "vllm"
):
vllm_cfg = self.cfg["generation"].get("vllm_cfg", {})
return vllm_cfg.get("precision") == "fp8"
return False

def _iter_params_with_optional_kv_scales(
self,
kv_scales: Optional[dict[str, float]] = None,
) -> Iterator[tuple[str, torch.Tensor]]:
"""Yield exported HF parameters and optionally append FP8 KV/Q scale tensors.

When FP8 weight quantization is enabled for generation, eligible weights
are quantized to FP8 on the training worker and yielded as
(name, fp8_data) + (name + "_scale_inv", scale) pairs. This reduces
the broadcast payload from bf16 to fp8.

This helper is used by both IPC-based streaming and collective broadcast
so that the logic for adding KV scales stays consistent in one place.
"""
from nemo_rl.models.generation.vllm.quantization.fp8_train_utils import (
FP8_WEIGHT_BLOCK_SIZE,
cast_tensor_to_fp8_blockwise,
get_vllm_qkv_scale_names,
should_quantize_to_fp8,
)

use_fp8_weights = self._is_fp8_weights_enabled()
use_pow2_scale = False
if use_fp8_weights:
vllm_cfg = self.cfg["generation"].get("vllm_cfg", {})
use_pow2_scale = vllm_cfg.get("pow2_weight_scaling_factors", False)

base_iter = self.megatron_bridge.export_hf_weights(
[self.model],
show_progress=False,
conversion_tasks=self.refit_conversion_tasks, # used for metadata caching
)

# Yield the original parameters first.
for name, tensor in base_iter:
yield name, tensor
if use_fp8_weights and should_quantize_to_fp8(name, tensor):
fp8_data, scale = cast_tensor_to_fp8_blockwise(
tensor.to(torch.float),
weight_block_size=FP8_WEIGHT_BLOCK_SIZE,
use_pow2_scale=use_pow2_scale,
)
scale = torch.squeeze(scale, dim=-1)
yield name, fp8_data
yield name + "_scale_inv", scale
else:
yield name, tensor

# Check whether FP8 KV cache is enabled.
use_fp8_kv_cache = False
Expand Down
Loading