From a4c6b0613d8c503f0bb07a3e57d7883f4cfb77c0 Mon Sep 17 00:00:00 2001 From: Jianbing Dong Date: Sat, 28 Feb 2026 06:49:45 +0000 Subject: [PATCH 1/2] fp8 refit opt Signed-off-by: Jianbing Dong --- .../vllm/quantization/fp8_train_utils.py | 92 +++++++++++++++++++ .../models/generation/vllm/vllm_backend.py | 50 +++++----- .../policy/workers/megatron_policy_worker.py | 38 +++++++- 3 files changed, 151 insertions(+), 29 deletions(-) diff --git a/nemo_rl/models/generation/vllm/quantization/fp8_train_utils.py b/nemo_rl/models/generation/vllm/quantization/fp8_train_utils.py index ac4db666cf..61621f04b9 100644 --- a/nemo_rl/models/generation/vllm/quantization/fp8_train_utils.py +++ b/nemo_rl/models/generation/vllm/quantization/fp8_train_utils.py @@ -12,6 +12,98 @@ # See the License for the specific language governing permissions and # limitations under the License. +import torch + +FP8_WEIGHT_BLOCK_SIZE = [128, 128] + + +def should_quantize_to_fp8(name: str, tensor: torch.Tensor) -> bool: + """Check whether a HuggingFace-named weight should be block-quantized to FP8. + + Matches the same set of parameters that vLLM quantizes (linear-layer + weights only). Embeddings, layernorms, biases, and lm_head are excluded. + """ + if tensor.dim() != 2: + return False + if not name.endswith(".weight"): + return False + lower = name.lower() + if any(kw in lower for kw in ("norm", "embed", "lm_head")): + return False + return True + + +def cast_tensor_to_fp8_blockwise( + data_hp: torch.Tensor, + weight_block_size: list[int], + use_pow2_scale: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + """Block-wise FP8 (E4M3) quantization — standalone, no vLLM dependencies. + + Args: + data_hp: 2-D high-precision weight tensor (any float dtype). + weight_block_size: [block_rows, block_cols], e.g. [128, 128]. + use_pow2_scale: If True, round scale factors to powers of two. + + Returns: + (fp8_data, descale) where fp8_data has dtype float8_e4m3fn and + descale is float32 with shape (blk_m, blk_n, 1). + """ + assert len(data_hp.shape) == 2, "Only 2-D input tensor is supported" + + block_size0, block_size1 = weight_block_size + shape_before_padding = data_hp.shape + + if data_hp.shape[0] % block_size0 != 0 or data_hp.shape[1] % block_size1 != 0: + pad0 = (block_size0 - data_hp.shape[0] % block_size0) % block_size0 + pad1 = (block_size1 - data_hp.shape[1] % block_size1) % block_size1 + data_hp = torch.nn.functional.pad( + data_hp, (0, pad1, 0, pad0), mode="constant", value=data_hp[-1, -1] + ) + + max_dtype = torch.finfo(torch.float8_e4m3fn).max + original_shape = data_hp.shape + blk_m = data_hp.shape[0] // block_size0 + blk_n = data_hp.shape[1] // block_size1 + + assert block_size0 == block_size1 + data_hp = data_hp.reshape(blk_m, block_size0, blk_n, block_size1) + data_hp = data_hp.permute(0, 2, 1, 3) + data_hp = data_hp.to(torch.float32).contiguous().flatten(start_dim=2) + + max_abs = torch.amax(torch.abs(data_hp), dim=-1, keepdim=True) + + if use_pow2_scale: + descale = max_abs / max_dtype + exponent = torch.ceil(torch.log2(descale)) + exponent = torch.clamp(exponent, min=-127, max=127) + 127 + exponent = exponent.to(torch.uint8) + scale_fp = torch.where( + exponent == 0, + 1.0, + torch.exp2(127 - exponent.to(torch.float32)), + ) + descale_fp = torch.reciprocal(scale_fp) + else: + scale_fp = max_dtype / max_abs + scale_fp = torch.where(max_abs == 0, 1.0, scale_fp) + scale_fp = torch.where(max_abs == torch.inf, 1.0, scale_fp) + descale_fp = torch.reciprocal(scale_fp) + + data_lp = torch.clamp(data_hp * scale_fp, min=-max_dtype, max=max_dtype) + fp_data = data_lp.to(torch.float8_e4m3fn) + + fp_data = ( + fp_data.reshape(blk_m, blk_n, block_size0, block_size1) + .permute(0, 2, 1, 3) + .reshape(original_shape) + ) + + if original_shape != shape_before_padding: + fp_data = fp_data[: shape_before_padding[0], : shape_before_padding[1]] + + return fp_data, descale_fp + def get_vllm_qkv_scale_names(layer_idx: int) -> dict[str, str]: """Get vLLM-compatible parameter names for Q/K/V FP8 scales. diff --git a/nemo_rl/models/generation/vllm/vllm_backend.py b/nemo_rl/models/generation/vllm/vllm_backend.py index 5d239fd902..6099d526fe 100644 --- a/nemo_rl/models/generation/vllm/vllm_backend.py +++ b/nemo_rl/models/generation/vllm/vllm_backend.py @@ -173,14 +173,9 @@ def update_weights_via_ipc_zmq(self) -> bool: assert offset == used_bytes, ( "Offset is not equal to used bytes, usually indicate inaccurate info like keys or cached dtype in state_dict_info" ) - # Load weights into the model - from nemo_rl.models.generation.vllm.quantization import fp8 - - if fp8.is_fp8_model(self.model_runner.vllm_config): - # the fp8 load_weights additionally casts bf16 weights into fp8 - fp8.load_weights(weights, self.model_runner) - else: - self.model_runner.model.load_weights(weights=weights) + # Load weights into the model (re-uses the shared helper that + # skips FP8 quantization when weights arrive pre-quantized). + self._load_model_weights(weights, self.model_runner) torch.cuda.current_stream().synchronize() @@ -207,6 +202,25 @@ def update_weights_via_ipc_zmq(self) -> bool: ) return False + @staticmethod + def _load_model_weights(weights, model_runner): + """Load model weights, skipping FP8 quantization when pre-quantized. + + If the incoming batch already contains ``_scale_inv`` entries (i.e. + weights were quantized on the training worker), we feed them straight + to ``model.load_weights`` and avoid a redundant quantization pass. + """ + from nemo_rl.models.generation.vllm.quantization import fp8 + + if fp8.is_fp8_model(model_runner.vllm_config): + pre_quantized = any(name.endswith("_scale_inv") for name, _ in weights) + if pre_quantized: + model_runner.model.load_weights(weights=weights) + else: + fp8.load_weights(weights, model_runner) + else: + model_runner.model.load_weights(weights=weights) + @wrap_with_nvtx_name( "vllm_internal_worker_extension/update_weights_from_collective" ) @@ -217,25 +231,7 @@ def update_weights_from_collective(self) -> bool: "Please call prepare_refit_info when initializing the worker." ) - def _load_model_weights(weights, model_runner): - """Load model weights. - - Args: - weights: List[(name, tensor)] - model_runner: vLLM ModelRunner - - Returns: - None - """ - from nemo_rl.models.generation.vllm.quantization import fp8 - - if fp8.is_fp8_model(model_runner.vllm_config): - # the fp8 load_weights additionally casts bf16 weights into fp8 - fp8.load_weights(weights, model_runner) - else: - model_runner.model.load_weights(weights=weights) - - load_model_weight_func = lambda x: _load_model_weights(x, self.model_runner) + load_model_weight_func = lambda x: self._load_model_weights(x, self.model_runner) try: packed_broadcast_consumer( diff --git a/nemo_rl/models/policy/workers/megatron_policy_worker.py b/nemo_rl/models/policy/workers/megatron_policy_worker.py index 5f1483ed9a..c8dc70da92 100644 --- a/nemo_rl/models/policy/workers/megatron_policy_worker.py +++ b/nemo_rl/models/policy/workers/megatron_policy_worker.py @@ -965,28 +965,62 @@ def calculate_size_in_bytes(param, tp_size, ep_size): ) return param_info + def _is_fp8_weights_enabled(self) -> bool: + """Check if the generation side uses FP8 weight quantization.""" + if ( + "generation" in self.cfg + and self.cfg["generation"] is not None + and self.cfg["generation"]["backend"] == "vllm" + ): + vllm_cfg = self.cfg["generation"].get("vllm_cfg", {}) + return vllm_cfg.get("precision") == "fp8" + return False + def _iter_params_with_optional_kv_scales( self, kv_scales: Optional[dict[str, float]] = None, ) -> Iterator[tuple[str, torch.Tensor]]: """Yield exported HF parameters and optionally append FP8 KV/Q scale tensors. + When FP8 weight quantization is enabled for generation, eligible weights + are quantized to FP8 on the training worker and yielded as + (name, fp8_data) + (name + "_scale_inv", scale) pairs. This reduces + the broadcast payload from bf16 to fp8. + This helper is used by both IPC-based streaming and collective broadcast so that the logic for adding KV scales stays consistent in one place. """ from nemo_rl.models.generation.vllm.quantization.fp8_train_utils import ( + FP8_WEIGHT_BLOCK_SIZE, + cast_tensor_to_fp8_blockwise, get_vllm_qkv_scale_names, + should_quantize_to_fp8, ) + use_fp8_weights = self._is_fp8_weights_enabled() + use_pow2_scale = False + if use_fp8_weights: + vllm_cfg = self.cfg["generation"].get("vllm_cfg", {}) + use_pow2_scale = vllm_cfg.get("pow2_weight_scaling_factors", False) + base_iter = self.megatron_bridge.export_hf_weights( [self.model], show_progress=False, conversion_tasks=self.refit_conversion_tasks, # used for metadata caching ) - # Yield the original parameters first. for name, tensor in base_iter: - yield name, tensor + if use_fp8_weights and should_quantize_to_fp8(name, tensor): + fp8_data, scale = cast_tensor_to_fp8_blockwise( + tensor.to(torch.float), + weight_block_size=FP8_WEIGHT_BLOCK_SIZE, + use_pow2_scale=use_pow2_scale, + ) + scale = torch.squeeze(scale, dim=-1) + yield name, fp8_data + yield name + "_scale_inv", scale + else: + yield name, tensor # Check whether FP8 KV cache is enabled. use_fp8_kv_cache = False From 96d673faadaadb4722c07964814db1ae295b46c8 Mon Sep 17 00:00:00 2001 From: Jianbing Dong Date: Fri, 10 Apr 2026 01:39:01 +0000 Subject: [PATCH 2/2] sync fp8 param name list once Signed-off-by: Jianbing Dong --- nemo_rl/algorithms/distillation.py | 10 +++++ nemo_rl/algorithms/grpo.py | 15 +++++++- nemo_rl/models/generation/interfaces.py | 7 ++++ .../generation/vllm/quantization/fp8.py | 20 ++++++++++ .../vllm/quantization/fp8_train_utils.py | 16 -------- .../models/generation/vllm/vllm_backend.py | 38 ++++++++++++------- .../models/generation/vllm/vllm_generation.py | 15 ++++++++ nemo_rl/models/generation/vllm/vllm_worker.py | 5 +++ .../generation/vllm/vllm_worker_async.py | 5 +++ nemo_rl/models/policy/interfaces.py | 4 ++ nemo_rl/models/policy/lm_policy.py | 7 ++++ .../policy/workers/megatron_policy_worker.py | 31 ++++++--------- 12 files changed, 124 insertions(+), 49 deletions(-) diff --git a/nemo_rl/algorithms/distillation.py b/nemo_rl/algorithms/distillation.py index 6fa9689d1a..64d879ded1 100644 --- a/nemo_rl/algorithms/distillation.py +++ b/nemo_rl/algorithms/distillation.py @@ -458,7 +458,17 @@ def setup( ) if student_generation is not None: + # First pass — bf16, to get HF-style param names state_dict_info = student_policy.prepare_refit_info() + + # Sync FP8 param names from generation to training (one-time). + fp8_param_names = student_generation.get_fp8_param_names( + list(state_dict_info.keys()) + ) + if fp8_param_names: + student_policy.set_fp8_param_names(fp8_param_names) + state_dict_info = student_policy.prepare_refit_info() + student_generation.prepare_refit_info(state_dict_info) # if it is not colocated inference, initialize collective communication for update weights diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 6772739655..253ca4fa5d 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -700,9 +700,22 @@ def initialize_generation_with_policy( ray.get(futures_train + futures_inference) worker_init_timing_metrics["collective_init_time_s"] = time.perf_counter() - t0 - # prepare refit info + # prepare refit info (first pass — bf16, no FP8 quantization yet) state_dict_info = policy.prepare_refit_info() + + # Sync FP8 param names from generation (vLLM) to training (Megatron). + # vLLM's model structure is the single source of truth for which weights + # are FP8-quantized. state_dict_info keys are HF-style names produced by + # export_hf_weights — these are the names vLLM's _is_fp8_weight() expects. if policy_generation is not None: + fp8_param_names = policy_generation.get_fp8_param_names( + list(state_dict_info.keys()) + ) + if fp8_param_names: + policy.set_fp8_param_names(fp8_param_names) + # Re-generate with FP8 quantization applied to eligible weights + state_dict_info = policy.prepare_refit_info() + policy_generation.prepare_refit_info(state_dict_info) # Calculate total setup time diff --git a/nemo_rl/models/generation/interfaces.py b/nemo_rl/models/generation/interfaces.py index 037b4880f5..f6c01bb6be 100644 --- a/nemo_rl/models/generation/interfaces.py +++ b/nemo_rl/models/generation/interfaces.py @@ -248,6 +248,13 @@ def requires_kv_scale_sync(self) -> bool: """Whether the generation backend requires KV cache scales synchronization.""" return False + def get_fp8_param_names(self, param_names: list[str]) -> set[str]: + """Classify which HF param names are FP8-quantized by the generation backend. + + Returns an empty set when FP8 is not enabled. + """ + return set() + def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None: """Prepare the info for refit.""" raise NotImplementedError diff --git a/nemo_rl/models/generation/vllm/quantization/fp8.py b/nemo_rl/models/generation/vllm/quantization/fp8.py index 9505f42524..dc7eb8a90f 100644 --- a/nemo_rl/models/generation/vllm/quantization/fp8.py +++ b/nemo_rl/models/generation/vllm/quantization/fp8.py @@ -347,6 +347,26 @@ def _is_fp8_weight(name, model): return name in fp8_state.fp8_param_names +def get_fp8_param_names(param_names: list[str], model_runner) -> set[str]: + """Classify which HF param names correspond to FP8-quantized weights. + + Uses vLLM's model structure (LinearBase/FusedMoE with FP8 dtype) as the + single source of truth. The result is cached in ``fp8_state`` so that + subsequent ``_is_fp8_weight`` calls during refit are free. + + Args: + param_names: HF-style parameter names (e.g. from export_hf_weights). + model_runner: vLLM ModelRunner with a loaded model. + + Returns: + Set of param names that should be block-quantized to FP8. + """ + model = model_runner.model + for name in param_names: + _is_fp8_weight(name, model) + return set(fp8_state.fp8_param_names) + + def load_weights(weights, model_runner): weights_quantized = [] model = model_runner.model diff --git a/nemo_rl/models/generation/vllm/quantization/fp8_train_utils.py b/nemo_rl/models/generation/vllm/quantization/fp8_train_utils.py index 61621f04b9..6640328d6c 100644 --- a/nemo_rl/models/generation/vllm/quantization/fp8_train_utils.py +++ b/nemo_rl/models/generation/vllm/quantization/fp8_train_utils.py @@ -17,22 +17,6 @@ FP8_WEIGHT_BLOCK_SIZE = [128, 128] -def should_quantize_to_fp8(name: str, tensor: torch.Tensor) -> bool: - """Check whether a HuggingFace-named weight should be block-quantized to FP8. - - Matches the same set of parameters that vLLM quantizes (linear-layer - weights only). Embeddings, layernorms, biases, and lm_head are excluded. - """ - if tensor.dim() != 2: - return False - if not name.endswith(".weight"): - return False - lower = name.lower() - if any(kw in lower for kw in ("norm", "embed", "lm_head")): - return False - return True - - def cast_tensor_to_fp8_blockwise( data_hp: torch.Tensor, weight_block_size: list[int], diff --git a/nemo_rl/models/generation/vllm/vllm_backend.py b/nemo_rl/models/generation/vllm/vllm_backend.py index 6099d526fe..00e990b7d8 100644 --- a/nemo_rl/models/generation/vllm/vllm_backend.py +++ b/nemo_rl/models/generation/vllm/vllm_backend.py @@ -84,6 +84,18 @@ def maybe_init_zmq(self): self.zmq_socket.setsockopt(zmq.LINGER, 0) self.zmq_socket.connect(self.get_zmq_address()) + def get_fp8_param_names(self, param_names: list[str]) -> set[str]: + """Classify which HF param names are FP8-quantized using vLLM's model. + + This is the authoritative source of truth — it inspects the actual vLLM + model structure rather than relying on name-matching heuristics. + """ + from nemo_rl.models.generation.vllm.quantization import fp8 + + if not fp8.is_fp8_model(self.model_runner.vllm_config): + return set() + return fp8.get_fp8_param_names(param_names, self.model_runner) + def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None: """Prepare state dict metadata for weight refitting and IPC streaming. @@ -92,6 +104,9 @@ def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None: e.g. {tensor_name: (shape, dtype)} """ self.state_dict_info = state_dict_info # pyrefly: ignore[implicitly-defined-attribute] This class does not define __init__ so assignments like this should be ignored + self._weights_pre_quantized = any( # pyrefly: ignore[implicitly-defined-attribute] + k.endswith("_scale_inv") for k in state_dict_info + ) def _maybe_process_fp8_kv_cache(self) -> None: """Process weights after loading for FP8 KV cache (static scales).""" @@ -175,7 +190,7 @@ def update_weights_via_ipc_zmq(self) -> bool: ) # Load weights into the model (re-uses the shared helper that # skips FP8 quantization when weights arrive pre-quantized). - self._load_model_weights(weights, self.model_runner) + self._load_model_weights(weights) torch.cuda.current_stream().synchronize() @@ -202,24 +217,21 @@ def update_weights_via_ipc_zmq(self) -> bool: ) return False - @staticmethod - def _load_model_weights(weights, model_runner): + def _load_model_weights(self, weights): """Load model weights, skipping FP8 quantization when pre-quantized. - If the incoming batch already contains ``_scale_inv`` entries (i.e. - weights were quantized on the training worker), we feed them straight - to ``model.load_weights`` and avoid a redundant quantization pass. + Uses the ``_weights_pre_quantized`` flag cached at ``prepare_refit_info`` + time instead of scanning every batch for ``_scale_inv`` entries. """ from nemo_rl.models.generation.vllm.quantization import fp8 - if fp8.is_fp8_model(model_runner.vllm_config): - pre_quantized = any(name.endswith("_scale_inv") for name, _ in weights) - if pre_quantized: - model_runner.model.load_weights(weights=weights) + if fp8.is_fp8_model(self.model_runner.vllm_config): + if self._weights_pre_quantized: + self.model_runner.model.load_weights(weights=weights) else: - fp8.load_weights(weights, model_runner) + fp8.load_weights(weights, self.model_runner) else: - model_runner.model.load_weights(weights=weights) + self.model_runner.model.load_weights(weights=weights) @wrap_with_nvtx_name( "vllm_internal_worker_extension/update_weights_from_collective" @@ -231,7 +243,7 @@ def update_weights_from_collective(self) -> bool: "Please call prepare_refit_info when initializing the worker." ) - load_model_weight_func = lambda x: self._load_model_weights(x, self.model_runner) + load_model_weight_func = lambda x: self._load_model_weights(x) try: packed_broadcast_consumer( diff --git a/nemo_rl/models/generation/vllm/vllm_generation.py b/nemo_rl/models/generation/vllm/vllm_generation.py index 6138dfdb43..b4492ec946 100644 --- a/nemo_rl/models/generation/vllm/vllm_generation.py +++ b/nemo_rl/models/generation/vllm/vllm_generation.py @@ -810,6 +810,21 @@ def shutdown(self) -> bool: print(f"Error during policy shutdown: {e}") return False + def get_fp8_param_names(self, param_names: list[str]) -> set[str]: + """Classify which HF param names are FP8-quantized using vLLM's model.""" + method_name = ( + "get_fp8_param_names_async" + if self.cfg["vllm_cfg"]["async_engine"] + else "get_fp8_param_names" + ) + futures = self.worker_group.run_all_workers_single_data( + method_name, + param_names=param_names, + run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"], + ) + results = ray.get(futures) + return results[0] + def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None: """Prepare the info for refit.""" # Choose the appropriate method based on async_engine setting diff --git a/nemo_rl/models/generation/vllm/vllm_worker.py b/nemo_rl/models/generation/vllm/vllm_worker.py index 2237a9efde..d8fc2e651b 100644 --- a/nemo_rl/models/generation/vllm/vllm_worker.py +++ b/nemo_rl/models/generation/vllm/vllm_worker.py @@ -798,6 +798,11 @@ def report_device_id(self) -> list[str]: ) return cast(list[str], list_of_worker_results) + def get_fp8_param_names(self, param_names: list[str]) -> set[str]: + """Classify which HF param names are FP8-quantized using vLLM's model.""" + results = self.llm.collective_rpc("get_fp8_param_names", args=(param_names,)) + return results[0] + def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None: """Prepare the info for refit.""" self.llm.collective_rpc("prepare_refit_info", args=(state_dict_info,)) diff --git a/nemo_rl/models/generation/vllm/vllm_worker_async.py b/nemo_rl/models/generation/vllm/vllm_worker_async.py index 0fd2b5c063..a3d50af5a9 100644 --- a/nemo_rl/models/generation/vllm/vllm_worker_async.py +++ b/nemo_rl/models/generation/vllm/vllm_worker_async.py @@ -997,6 +997,11 @@ async def report_device_id_async(self) -> list[str]: return cast(list[str], list_of_worker_results) + async def get_fp8_param_names_async(self, param_names: list[str]) -> set[str]: + """Async version of get_fp8_param_names.""" + results = await self.llm.collective_rpc("get_fp8_param_names", args=(param_names,)) + return results[0] + async def prepare_refit_info_async(self, state_dict_info: dict[str, Any]) -> None: """Async version of prepare_refit_info.""" await self.llm.collective_rpc("prepare_refit_info", args=(state_dict_info,)) diff --git a/nemo_rl/models/policy/interfaces.py b/nemo_rl/models/policy/interfaces.py index 464377c57a..fc39d456b9 100644 --- a/nemo_rl/models/policy/interfaces.py +++ b/nemo_rl/models/policy/interfaces.py @@ -178,6 +178,10 @@ def offload_before_refit(self) -> None: def offload_after_refit(self) -> None: pass + def set_fp8_param_names(self, fp8_param_names: set[str]) -> None: + """Cache the set of FP8-quantized param names obtained from the generation backend.""" + pass + @abstractmethod def prepare_refit_info(self) -> Optional[dict[str, Any]]: pass diff --git a/nemo_rl/models/policy/lm_policy.py b/nemo_rl/models/policy/lm_policy.py index 29f034b065..1c6d4d096f 100644 --- a/nemo_rl/models/policy/lm_policy.py +++ b/nemo_rl/models/policy/lm_policy.py @@ -728,6 +728,13 @@ def invalidate_kv_cache(self, *args: Any, **kwargs: Any) -> bool: # We don't need to do anything here return True + def set_fp8_param_names(self, fp8_param_names: set[str]) -> None: + """Broadcast the FP8 param name set to all workers.""" + futures = self.worker_group.run_all_workers_single_data( + "set_fp8_param_names", fp8_param_names=fp8_param_names + ) + ray.get(futures) + def prepare_refit_info(self) -> Optional[dict[str, Any]]: """Prepare the info for refit. diff --git a/nemo_rl/models/policy/workers/megatron_policy_worker.py b/nemo_rl/models/policy/workers/megatron_policy_worker.py index c8dc70da92..bd77cd0709 100644 --- a/nemo_rl/models/policy/workers/megatron_policy_worker.py +++ b/nemo_rl/models/policy/workers/megatron_policy_worker.py @@ -900,11 +900,16 @@ def generate( return BatchedDataDict.from_batches([out_dict]).to("cpu") + def set_fp8_param_names(self, fp8_param_names: set[str]) -> None: + """Cache the set of FP8-quantized param names obtained from vLLM.""" + self._fp8_param_names = fp8_param_names + @torch.no_grad() @wrap_with_nvtx_name("megatron_policy_worker/prepare_refit_info") def prepare_refit_info(self) -> None: """Prepare state dict metadata for weight refitting and IPC streaming.""" - self.refit_param_info_mcore = self._calculate_refit_param_info() + if not hasattr(self, "refit_param_info_mcore") or self.refit_param_info_mcore is None: + self.refit_param_info_mcore = self._calculate_refit_param_info() # Collect tensor metadata for refit / hf side info refit_param_info_hf = {} @@ -965,26 +970,15 @@ def calculate_size_in_bytes(param, tp_size, ep_size): ) return param_info - def _is_fp8_weights_enabled(self) -> bool: - """Check if the generation side uses FP8 weight quantization.""" - if ( - "generation" in self.cfg - and self.cfg["generation"] is not None - and self.cfg["generation"]["backend"] == "vllm" - ): - vllm_cfg = self.cfg["generation"].get("vllm_cfg", {}) - return vllm_cfg.get("precision") == "fp8" - return False - def _iter_params_with_optional_kv_scales( self, kv_scales: Optional[dict[str, float]] = None, ) -> Iterator[tuple[str, torch.Tensor]]: """Yield exported HF parameters and optionally append FP8 KV/Q scale tensors. - When FP8 weight quantization is enabled for generation, eligible weights - are quantized to FP8 on the training worker and yielded as - (name, fp8_data) + (name + "_scale_inv", scale) pairs. This reduces + When ``_fp8_param_names`` is populated (synced once from vLLM at init), + eligible weights are quantized to FP8 on the training worker and yielded + as (name, fp8_data) + (name + "_scale_inv", scale) pairs. This reduces the broadcast payload from bf16 to fp8. This helper is used by both IPC-based streaming and collective broadcast @@ -994,12 +988,11 @@ def _iter_params_with_optional_kv_scales( FP8_WEIGHT_BLOCK_SIZE, cast_tensor_to_fp8_blockwise, get_vllm_qkv_scale_names, - should_quantize_to_fp8, ) - use_fp8_weights = self._is_fp8_weights_enabled() + fp8_param_names = getattr(self, "_fp8_param_names", set()) use_pow2_scale = False - if use_fp8_weights: + if fp8_param_names: vllm_cfg = self.cfg["generation"].get("vllm_cfg", {}) use_pow2_scale = vllm_cfg.get("pow2_weight_scaling_factors", False) @@ -1010,7 +1003,7 @@ def _iter_params_with_optional_kv_scales( ) for name, tensor in base_iter: - if use_fp8_weights and should_quantize_to_fp8(name, tensor): + if name in fp8_param_names: fp8_data, scale = cast_tensor_to_fp8_blockwise( tensor.to(torch.float), weight_block_size=FP8_WEIGHT_BLOCK_SIZE,