Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions nemo_rl/algorithms/distillation.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,17 @@ def setup(
)

if student_generation is not None:
# First pass — bf16, to get HF-style param names
state_dict_info = student_policy.prepare_refit_info()

# Sync FP8 param names from generation to training (one-time).
fp8_param_names = student_generation.get_fp8_param_names(
list(state_dict_info.keys())
)
if fp8_param_names:
student_policy.set_fp8_param_names(fp8_param_names)
state_dict_info = student_policy.prepare_refit_info()

student_generation.prepare_refit_info(state_dict_info)

# if it is not colocated inference, initialize collective communication for update weights
Expand Down
15 changes: 14 additions & 1 deletion nemo_rl/algorithms/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,9 +700,22 @@ def initialize_generation_with_policy(
ray.get(futures_train + futures_inference)
worker_init_timing_metrics["collective_init_time_s"] = time.perf_counter() - t0

# prepare refit info
# prepare refit info (first pass — bf16, no FP8 quantization yet)
state_dict_info = policy.prepare_refit_info()

# Sync FP8 param names from generation (vLLM) to training (Megatron).
# vLLM's model structure is the single source of truth for which weights
# are FP8-quantized. state_dict_info keys are HF-style names produced by
# export_hf_weights — these are the names vLLM's _is_fp8_weight() expects.
if policy_generation is not None:
fp8_param_names = policy_generation.get_fp8_param_names(
list(state_dict_info.keys())
)
if fp8_param_names:
policy.set_fp8_param_names(fp8_param_names)
# Re-generate with FP8 quantization applied to eligible weights
state_dict_info = policy.prepare_refit_info()

policy_generation.prepare_refit_info(state_dict_info)

# Calculate total setup time
Expand Down
7 changes: 7 additions & 0 deletions nemo_rl/models/generation/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,13 @@ def requires_kv_scale_sync(self) -> bool:
"""Whether the generation backend requires KV cache scales synchronization."""
return False

def get_fp8_param_names(self, param_names: list[str]) -> set[str]:
    """Report which of *param_names* the generation backend keeps in FP8.

    The base implementation assumes no FP8 support and therefore reports
    nothing; FP8-capable backends override this to inspect their model.
    """
    del param_names  # unused by the default (non-FP8) implementation
    return set()

def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None:
    """Prepare the info for refit.

    Args:
        state_dict_info: Mapping of parameter name to tensor metadata
            (e.g. ``{tensor_name: (shape, dtype)}``) used during weight refit.

    Raises:
        NotImplementedError: Always; concrete generation backends must
            override this method.
    """
    raise NotImplementedError
Expand Down
20 changes: 20 additions & 0 deletions nemo_rl/models/generation/vllm/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,26 @@ def _is_fp8_weight(name, model):
return name in fp8_state.fp8_param_names


def get_fp8_param_names(param_names: list[str], model_runner) -> set[str]:
    """Classify which HF param names correspond to FP8-quantized weights.

    Uses vLLM's model structure (LinearBase/FusedMoE with FP8 dtype) as the
    single source of truth. The result is cached in ``fp8_state`` so that
    subsequent ``_is_fp8_weight`` calls during refit are free.

    Args:
        param_names: HF-style parameter names (e.g. from export_hf_weights).
        model_runner: vLLM ModelRunner with a loaded model.

    Returns:
        Set of param names that should be block-quantized to FP8.
    """
    vllm_model = model_runner.model
    for candidate in param_names:
        # Called for its side effect: each query records its verdict in
        # fp8_state.fp8_param_names.
        _is_fp8_weight(candidate, vllm_model)
    return set(fp8_state.fp8_param_names)


def load_weights(weights, model_runner):
weights_quantized = []
model = model_runner.model
Expand Down
76 changes: 76 additions & 0 deletions nemo_rl/models/generation/vllm/quantization/fp8_train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,82 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

# Quantization tile size ([block_rows, block_cols]) for block-wise FP8 weights.
FP8_WEIGHT_BLOCK_SIZE = [128, 128]


def cast_tensor_to_fp8_blockwise(
data_hp: torch.Tensor,
weight_block_size: list[int],
use_pow2_scale: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Block-wise FP8 (E4M3) quantization — standalone, no vLLM dependencies.

Args:
data_hp: 2-D high-precision weight tensor (any float dtype).
weight_block_size: [block_rows, block_cols], e.g. [128, 128].
use_pow2_scale: If True, round scale factors to powers of two.

Returns:
(fp8_data, descale) where fp8_data has dtype float8_e4m3fn and
descale is float32 with shape (blk_m, blk_n, 1).
"""
assert len(data_hp.shape) == 2, "Only 2-D input tensor is supported"

block_size0, block_size1 = weight_block_size
shape_before_padding = data_hp.shape

if data_hp.shape[0] % block_size0 != 0 or data_hp.shape[1] % block_size1 != 0:
pad0 = (block_size0 - data_hp.shape[0] % block_size0) % block_size0
pad1 = (block_size1 - data_hp.shape[1] % block_size1) % block_size1
data_hp = torch.nn.functional.pad(
data_hp, (0, pad1, 0, pad0), mode="constant", value=data_hp[-1, -1]
)

max_dtype = torch.finfo(torch.float8_e4m3fn).max
original_shape = data_hp.shape
blk_m = data_hp.shape[0] // block_size0
blk_n = data_hp.shape[1] // block_size1

assert block_size0 == block_size1
data_hp = data_hp.reshape(blk_m, block_size0, blk_n, block_size1)
data_hp = data_hp.permute(0, 2, 1, 3)
data_hp = data_hp.to(torch.float32).contiguous().flatten(start_dim=2)

max_abs = torch.amax(torch.abs(data_hp), dim=-1, keepdim=True)

if use_pow2_scale:
descale = max_abs / max_dtype
exponent = torch.ceil(torch.log2(descale))
exponent = torch.clamp(exponent, min=-127, max=127) + 127
exponent = exponent.to(torch.uint8)
scale_fp = torch.where(
exponent == 0,
1.0,
torch.exp2(127 - exponent.to(torch.float32)),
)
descale_fp = torch.reciprocal(scale_fp)
else:
scale_fp = max_dtype / max_abs
scale_fp = torch.where(max_abs == 0, 1.0, scale_fp)
scale_fp = torch.where(max_abs == torch.inf, 1.0, scale_fp)
descale_fp = torch.reciprocal(scale_fp)
Comment on lines +71 to +75
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

NaN values are not handled in scale computation.

The linear scale path handles max_abs == 0 and max_abs == inf, but NaN values would propagate silently. If any block contains NaN, both scale_fp and the resulting fp8_data would be NaN.

🛡️ Suggested fix to handle NaN
     else:
         scale_fp = max_dtype / max_abs
         scale_fp = torch.where(max_abs == 0, 1.0, scale_fp)
         scale_fp = torch.where(max_abs == torch.inf, 1.0, scale_fp)
+        scale_fp = torch.where(torch.isnan(max_abs), 1.0, scale_fp)
         descale_fp = torch.reciprocal(scale_fp)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@nemo_rl/models/generation/vllm/quantization/fp8_train_utils.py` around lines
87 - 91, The scale computation doesn't guard against NaN in max_abs which will
produce NaN scale_fp and downstream fp8_data; update the branch that computes
scale_fp/descale_fp to also replace NaNs (e.g., using torch.isnan or
torch.isfinite) with a safe fallback (1.0) before taking the reciprocal so that
scale_fp = max_dtype / max_abs is followed by handling max_abs == 0, max_abs ==
inf, and max_abs == NaN (set those scale entries to 1.0), then compute
descale_fp = torch.reciprocal(scale_fp); modify the existing scale_fp and
descale_fp logic where those symbols are defined to include the NaN check.


data_lp = torch.clamp(data_hp * scale_fp, min=-max_dtype, max=max_dtype)
fp_data = data_lp.to(torch.float8_e4m3fn)

fp_data = (
fp_data.reshape(blk_m, blk_n, block_size0, block_size1)
.permute(0, 2, 1, 3)
.reshape(original_shape)
)

if original_shape != shape_before_padding:
fp_data = fp_data[: shape_before_padding[0], : shape_before_padding[1]]

return fp_data, descale_fp


def get_vllm_qkv_scale_names(layer_idx: int) -> dict[str, str]:
"""Get vLLM-compatible parameter names for Q/K/V FP8 scales.
Expand Down
62 changes: 35 additions & 27 deletions nemo_rl/models/generation/vllm/vllm_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,18 @@ def maybe_init_zmq(self):
self.zmq_socket.setsockopt(zmq.LINGER, 0)
self.zmq_socket.connect(self.get_zmq_address())

def get_fp8_param_names(self, param_names: list[str]) -> set[str]:
    """Classify which HF param names are FP8-quantized using vLLM's model.

    Inspects the actual vLLM model structure (the authoritative source of
    truth) rather than relying on name-matching heuristics.
    """
    from nemo_rl.models.generation.vllm.quantization import fp8

    if fp8.is_fp8_model(self.model_runner.vllm_config):
        return fp8.get_fp8_param_names(param_names, self.model_runner)
    return set()

def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None:
"""Prepare state dict metadata for weight refitting and IPC streaming.

Expand All @@ -92,6 +104,9 @@ def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None:
e.g. {tensor_name: (shape, dtype)}
"""
self.state_dict_info = state_dict_info # pyrefly: ignore[implicitly-defined-attribute] This class does not define __init__ so assignments like this should be ignored
self._weights_pre_quantized = any( # pyrefly: ignore[implicitly-defined-attribute]
k.endswith("_scale_inv") for k in state_dict_info
)

def _maybe_process_fp8_kv_cache(self) -> None:
"""Process weights after loading for FP8 KV cache (static scales)."""
Expand Down Expand Up @@ -173,14 +188,9 @@ def update_weights_via_ipc_zmq(self) -> bool:
assert offset == used_bytes, (
"Offset is not equal to used bytes, usually indicate inaccurate info like keys or cached dtype in state_dict_info"
)
# Load weights into the model
from nemo_rl.models.generation.vllm.quantization import fp8

if fp8.is_fp8_model(self.model_runner.vllm_config):
# the fp8 load_weights additionally casts bf16 weights into fp8
fp8.load_weights(weights, self.model_runner)
else:
self.model_runner.model.load_weights(weights=weights)
# Load weights into the model (re-uses the shared helper that
# skips FP8 quantization when weights arrive pre-quantized).
self._load_model_weights(weights)

torch.cuda.current_stream().synchronize()

Expand All @@ -207,6 +217,22 @@ def update_weights_via_ipc_zmq(self) -> bool:
)
return False

def _load_model_weights(self, weights):
    """Load model weights, skipping FP8 quantization when pre-quantized.

    Uses the ``_weights_pre_quantized`` flag cached at ``prepare_refit_info``
    time instead of scanning every batch for ``_scale_inv`` entries.

    Args:
        weights: Iterable of (name, tensor) pairs to load into the model.
    """
    from nemo_rl.models.generation.vllm.quantization import fp8

    if not fp8.is_fp8_model(self.model_runner.vllm_config):
        self.model_runner.model.load_weights(weights=weights)
        return

    # This class has no __init__, so the flag only exists once
    # prepare_refit_info has run; default to quantize-on-load otherwise.
    if getattr(self, "_weights_pre_quantized", False):
        # Stream already carries fp8 tensors + their *_scale_inv scales.
        self.model_runner.model.load_weights(weights=weights)
    else:
        # fp8.load_weights additionally casts bf16 weights into fp8.
        fp8.load_weights(weights, self.model_runner)

@wrap_with_nvtx_name(
"vllm_internal_worker_extension/update_weights_from_collective"
)
Expand All @@ -217,25 +243,7 @@ def update_weights_from_collective(self) -> bool:
"Please call prepare_refit_info when initializing the worker."
)

def _load_model_weights(weights, model_runner):
"""Load model weights.

Args:
weights: List[(name, tensor)]
model_runner: vLLM ModelRunner

Returns:
None
"""
from nemo_rl.models.generation.vllm.quantization import fp8

if fp8.is_fp8_model(model_runner.vllm_config):
# the fp8 load_weights additionally casts bf16 weights into fp8
fp8.load_weights(weights, model_runner)
else:
model_runner.model.load_weights(weights=weights)

load_model_weight_func = lambda x: _load_model_weights(x, self.model_runner)
load_model_weight_func = lambda x: self._load_model_weights(x)

try:
packed_broadcast_consumer(
Expand Down
15 changes: 15 additions & 0 deletions nemo_rl/models/generation/vllm/vllm_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,6 +810,21 @@ def shutdown(self) -> bool:
print(f"Error during policy shutdown: {e}")
return False

def get_fp8_param_names(self, param_names: list[str]) -> set[str]:
    """Classify which HF param names are FP8-quantized using vLLM's model."""
    # Async engines expose the async variant of the worker method.
    if self.cfg["vllm_cfg"]["async_engine"]:
        method_name = "get_fp8_param_names_async"
    else:
        method_name = "get_fp8_param_names"
    futures = self.worker_group.run_all_workers_single_data(
        method_name,
        param_names=param_names,
        run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"],
    )
    # Every queried rank sees the same model structure; take rank 0's answer.
    return ray.get(futures)[0]

def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None:
"""Prepare the info for refit."""
# Choose the appropriate method based on async_engine setting
Expand Down
5 changes: 5 additions & 0 deletions nemo_rl/models/generation/vllm/vllm_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,11 @@ def report_device_id(self) -> list[str]:
)
return cast(list[str], list_of_worker_results)

def get_fp8_param_names(self, param_names: list[str]) -> set[str]:
    """Classify which HF param names are FP8-quantized using vLLM's model."""
    per_worker = self.llm.collective_rpc("get_fp8_param_names", args=(param_names,))
    # All workers hold the same model structure; the first reply suffices.
    return per_worker[0]

def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None:
"""Prepare the info for refit."""
self.llm.collective_rpc("prepare_refit_info", args=(state_dict_info,))
Expand Down
5 changes: 5 additions & 0 deletions nemo_rl/models/generation/vllm/vllm_worker_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -997,6 +997,11 @@ async def report_device_id_async(self) -> list[str]:

return cast(list[str], list_of_worker_results)

async def get_fp8_param_names_async(self, param_names: list[str]) -> set[str]:
    """Async version of get_fp8_param_names."""
    per_worker = await self.llm.collective_rpc(
        "get_fp8_param_names", args=(param_names,)
    )
    # All workers report identically; return the first worker's result.
    return per_worker[0]

async def prepare_refit_info_async(self, state_dict_info: dict[str, Any]) -> None:
"""Async version of prepare_refit_info."""
await self.llm.collective_rpc("prepare_refit_info", args=(state_dict_info,))
Expand Down
4 changes: 4 additions & 0 deletions nemo_rl/models/policy/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,10 @@ def offload_before_refit(self) -> None:
def offload_after_refit(self) -> None:
pass

def set_fp8_param_names(self, fp8_param_names: set[str]) -> None:
    """Cache the set of FP8-quantized param names obtained from the generation backend.

    Default is a no-op; policies that support FP8 weight export override
    this to store the names for use during refit.
    """
    pass

@abstractmethod
def prepare_refit_info(self) -> Optional[dict[str, Any]]:
pass
Expand Down
7 changes: 7 additions & 0 deletions nemo_rl/models/policy/lm_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,6 +728,13 @@ def invalidate_kv_cache(self, *args: Any, **kwargs: Any) -> bool:
# We don't need to do anything here
return True

def set_fp8_param_names(self, fp8_param_names: set[str]) -> None:
    """Broadcast the FP8 param name set to all workers."""
    # Block until every worker has cached the names before returning.
    ray.get(
        self.worker_group.run_all_workers_single_data(
            "set_fp8_param_names", fp8_param_names=fp8_param_names
        )
    )

def prepare_refit_info(self) -> Optional[dict[str, Any]]:
"""Prepare the info for refit.

Expand Down
33 changes: 30 additions & 3 deletions nemo_rl/models/policy/workers/megatron_policy_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -900,11 +900,16 @@ def generate(

return BatchedDataDict.from_batches([out_dict]).to("cpu")

def set_fp8_param_names(self, fp8_param_names: set[str]) -> None:
    """Cache the set of FP8-quantized param names obtained from vLLM."""
    # Consumed later during weight export to decide which weights are
    # block-quantized before being streamed to the generation backend.
    self._fp8_param_names = fp8_param_names

@torch.no_grad()
@wrap_with_nvtx_name("megatron_policy_worker/prepare_refit_info")
def prepare_refit_info(self) -> None:
"""Prepare state dict metadata for weight refitting and IPC streaming."""
self.refit_param_info_mcore = self._calculate_refit_param_info()
if not hasattr(self, "refit_param_info_mcore") or self.refit_param_info_mcore is None:
self.refit_param_info_mcore = self._calculate_refit_param_info()

# Collect tensor metadata for refit / hf side info
refit_param_info_hf = {}
Expand Down Expand Up @@ -971,22 +976,44 @@ def _iter_params_with_optional_kv_scales(
) -> Iterator[tuple[str, torch.Tensor]]:
"""Yield exported HF parameters and optionally append FP8 KV/Q scale tensors.

When ``_fp8_param_names`` is populated (synced once from vLLM at init),
eligible weights are quantized to FP8 on the training worker and yielded
as (name, fp8_data) + (name + "_scale_inv", scale) pairs. This reduces
the broadcast payload from bf16 to fp8.

This helper is used by both IPC-based streaming and collective broadcast
so that the logic for adding KV scales stays consistent in one place.
"""
from nemo_rl.models.generation.vllm.quantization.fp8_train_utils import (
FP8_WEIGHT_BLOCK_SIZE,
cast_tensor_to_fp8_blockwise,
get_vllm_qkv_scale_names,
)

fp8_param_names = getattr(self, "_fp8_param_names", set())
use_pow2_scale = False
if fp8_param_names:
vllm_cfg = self.cfg["generation"].get("vllm_cfg", {})
use_pow2_scale = vllm_cfg.get("pow2_weight_scaling_factors", False)

base_iter = self.megatron_bridge.export_hf_weights(
[self.model],
show_progress=False,
conversion_tasks=self.refit_conversion_tasks, # used for metadata caching
)

# Yield the original parameters first.
for name, tensor in base_iter:
yield name, tensor
if name in fp8_param_names:
fp8_data, scale = cast_tensor_to_fp8_blockwise(
tensor.to(torch.float),
weight_block_size=FP8_WEIGHT_BLOCK_SIZE,
use_pow2_scale=use_pow2_scale,
)
scale = torch.squeeze(scale, dim=-1)
yield name, fp8_data
yield name + "_scale_inv", scale
else:
yield name, tensor

# Check whether FP8 KV cache is enabled.
use_fp8_kv_cache = False
Expand Down
Loading