FP8 Refit Optimization #2037

Open

Jianbing-D wants to merge 2 commits into NVIDIA-NeMo:main from Jianbing-D:grpo_fp8_refit_opt

Conversation


@Jianbing-D Jianbing-D commented Feb 28, 2026

What does this PR do ?

Quantize before weight transfer to accelerate refit in FP8 GRPO.
Describe here: https://nvbugspro.nvidia.com/bug/5863778

Weights are periodically synced from the Megatron training worker to the vLLM generation worker. FP8 refit quantizes weights on the training side before broadcast, reducing the network payload (BF16 -> FP8 plus compact scales).
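As a back-of-the-envelope illustration of the payload reduction (illustrative arithmetic, not a measurement from this PR): BF16 costs 2 bytes per element, while blockwise FP8 costs 1 byte per element plus one FP32 inverse scale per 128x128 block.

```python
def refit_payload_bytes(rows: int, cols: int, fp8: bool = True, block: int = 128) -> int:
    """Illustrative per-tensor transfer size; ignores message headers and alignment."""
    if not fp8:
        return rows * cols * 2                          # BF16: 2 bytes per element
    n_blocks = -(-rows // block) * -(-cols // block)    # ceil-div in each dimension
    return rows * cols + n_blocks * 4                   # FP8 data + FP32 scale_inv per block

# An 8192x8192 projection: 134,217,728 bytes (BF16) vs 67,125,248 bytes (FP8 + scales)
```

The scale overhead is tiny (16 KB against 64 MB in this example), so the transfer is cut roughly in half.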

Issues

List issues that this PR closes:

Usage

  • You can potentially add a usage example below
# Add a code snippet demonstrating how to use this

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • ...

Test results

Verified on llama3.1-8B.
https://wandb.ai/nv-default-onboard/nemo-rl/reports/-FP8-Refit-Optimization--VmlldzoxNjI2NTkwMg

Summary by CodeRabbit

Release Notes

  • New Features

    • Added FP8 quantization support for model weights using blockwise scaling to optimize memory usage.
    • Extended policy worker to enable automatic FP8 weight quantization during inference with configurable scaling options.
  • Refactor

    • Centralized weight loading logic to streamline FP8 and non-FP8 model handling.
    • Improved detection of pre-quantized weight batches for more efficient model loading.

Signed-off-by: Jianbing Dong <jianbingd@nvidia.com>
@terrykong terrykong requested a review from guyueh1 February 28, 2026 07:27
Collaborator

@terrykong terrykong left a comment


@guyueh1 to review

@guyueh1 guyueh1 added the Performance (Related to improving performance) label Mar 5, 2026
Contributor

guyueh1 commented Mar 19, 2026

@Jianbing-D can you change this to "ready to review" and fix the PR name?

@Jianbing-D Jianbing-D marked this pull request as ready for review March 20, 2026 09:07
@Jianbing-D Jianbing-D requested review from a team as code owners March 20, 2026 09:07
@Jianbing-D Jianbing-D changed the title fp8 refit opt FP8 Refit Optimization Mar 20, 2026
@Jianbing-D
Author

Hi @guyueh1, please review it.

Contributor

coderabbitai bot commented Mar 20, 2026

📝 Walkthrough

This PR adds FP8 weight quantization support to the generation backend. It introduces FP8 casting utilities for block-wise quantization, integrates them into the policy worker to quantize eligible weights during export, and refactors weight loading logic in the vLLM backend to handle pre-quantized batches.

Changes

  • FP8 Quantization Utilities — nemo_rl/models/generation/vllm/quantization/fp8_train_utils.py
    Added the FP8_WEIGHT_BLOCK_SIZE constant, a should_quantize_to_fp8() filter for HuggingFace-named weight tensors, and cast_tensor_to_fp8_blockwise() implementing block-wise E4M3 FP8 casting with optional pow2 scaling.
  • vLLM Backend Weight Loading — nemo_rl/models/generation/vllm/vllm_backend.py
    Centralized FP8 vs. non-FP8 weight loading into a _load_model_weights() static method; added pre-quantized batch detection via _scale_inv tensor names to route batches to fp8.load_weights() or direct model loading.
  • Policy Worker FP8 Integration — nemo_rl/models/policy/workers/megatron_policy_worker.py
    Added an _is_fp8_weights_enabled() detection method; extended _iter_params_with_optional_kv_scales() to conditionally quantize eligible weights to FP8 and yield both the quantized tensors and their inverse scale factors.
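The block-wise cast described in the walkthrough can be sketched as follows. This is a NumPy stand-in with hypothetical names — the real utility casts to torch.float8_e4m3fn, while here clipping to the E4M3 maximum of 448 stands in for the cast — so it illustrates the padding and per-block scaling steps, not the exact NeMo-RL implementation.

```python
import numpy as np

FP8_E4M3_MAX = 448.0   # largest finite magnitude representable in float8_e4m3fn
BLOCK = 128            # assumed value of FP8_WEIGHT_BLOCK_SIZE

def cast_blockwise_sketch(w: np.ndarray, block: int = BLOCK):
    """Return (quantized padded weight, scale_inv) with one scale per block."""
    rows, cols = w.shape
    pad0, pad1 = (-rows) % block, (-cols) % block
    wp = np.pad(w, ((0, pad0), (0, pad1)))                  # zero-pad to block multiples
    nb0, nb1 = wp.shape[0] // block, wp.shape[1] // block
    blocks = wp.reshape(nb0, block, nb1, block)
    max_abs = np.abs(blocks).max(axis=(1, 3))               # per-block amax
    safe_amax = np.where(max_abs == 0, 1.0, max_abs)        # avoid divide-by-zero
    scale = FP8_E4M3_MAX / safe_amax
    q = np.clip(blocks * scale[:, None, :, None], -FP8_E4M3_MAX, FP8_E4M3_MAX)
    return q.reshape(wp.shape), 1.0 / scale                 # scale_inv is what gets shipped
```

On the loading side, multiplying each block by its scale_inv recovers the weights (exactly in this stand-in; up to FP8 rounding error in the real cast).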

Sequence Diagram(s)

sequenceDiagram
    participant PW as MegatronPolicyWorker
    participant FP8 as fp8_train_utils
    participant Export as Export Process
    
    PW->>PW: _is_fp8_weights_enabled()
    alt FP8 Enabled
        PW->>PW: _iter_params_with_optional_kv_scales()
        loop For Each Parameter
            PW->>FP8: should_quantize_to_fp8(name, tensor)
            FP8-->>PW: bool (eligible 2D weights?)
            alt Eligible for Quantization
                PW->>FP8: cast_tensor_to_fp8_blockwise(tensor, block_size)
                FP8->>FP8: Pad to block multiples
                FP8->>FP8: Compute per-block scales
                FP8->>FP8: Cast to float8_e4m3fn
                FP8-->>PW: (fp8_data, scale_inv)
                PW->>Export: Yield (name, fp8_data)
                PW->>Export: Yield (name_scale_inv, scale)
            else Not Eligible
                PW->>Export: Yield (name, original_tensor)
            end
        end
    else FP8 Disabled
        PW->>Export: Yield parameters unchanged
    end
sequenceDiagram
    participant Backend as vllm_backend
    participant Config as vllm_config
    participant FP8Module as fp8 module
    participant Model as Model
    
    Backend->>Backend: update_weights_via_ipc_zmq(weights)
    Backend->>Backend: _load_model_weights(weights, model_runner)
    Backend->>Config: is_fp8_model(vllm_config)?
    alt FP8 Model
        Backend->>Backend: Detect pre-quantized (_scale_inv present)
        alt Pre-quantized Weights
            Backend->>Model: load_weights(weights)
            Model-->>Backend: Loaded
        else Non Pre-quantized
            Backend->>FP8Module: load_weights(weights, model_runner)
            FP8Module->>Model: Apply FP8 transformations
            FP8Module-->>Backend: Loaded
        end
    else Non-FP8 Model
        Backend->>Model: load_weights(weights)
        Model-->>Backend: Loaded
    end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)

  • Test Results For Major Changes — ⚠️ Warning. The PR description lacks test results, performance metrics, and convergence validation despite introducing major numeric-changing FP8 quantization features. Resolution: add a test results summary, performance metrics with hardware details, convergence validation, and a reference to the existing docs/fp8.md documentation.

✅ Passed checks (3 passed)

  • Title check — ✅ Passed. The title "FP8 Refit Optimization" directly aligns with the PR's main objective of quantizing weights before transfer to accelerate refit in FP8 GRPO.
  • Docstring Coverage — ✅ Passed. Docstring coverage is 100.00%, above the required threshold of 80.00%.
  • Description Check — ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.


Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (4)
nemo_rl/models/generation/vllm/vllm_backend.py (1)

234-234: Replace lambda with local function per style guidelines.

Static analysis flagged E731: assigning a lambda expression. Use a def statement instead.

♻️ Suggested fix
-        load_model_weight_func = lambda x: self._load_model_weights(x, self.model_runner)
+        def load_model_weight_func(x):
+            return self._load_model_weights(x, self.model_runner)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@nemo_rl/models/generation/vllm/vllm_backend.py` at line 234, Replace the
lambda assignment to load_model_weight_func with a local def function to satisfy
style rules: create a small local function (e.g., def
load_model_weight_func(path): return self._load_model_weights(path,
self.model_runner)) and use that function in place of the lambda; keep the same
name load_model_weight_func and ensure it calls self._load_model_weights with
the same arguments (path and self.model_runner) so behavior does not change.
nemo_rl/models/generation/vllm/quantization/fp8_train_utils.py (3)

57-62: Consider zero-padding instead of edge-value padding.

Using data_hp[-1, -1] for padding could propagate anomalous values (NaN, Inf, or outliers) if they happen to be at the tensor edge. Zero-padding is more conventional and predictable for quantization.

🔧 Suggested fix
         data_hp = torch.nn.functional.pad(
-            data_hp, (0, pad1, 0, pad0), mode="constant", value=data_hp[-1, -1]
+            data_hp, (0, pad1, 0, pad0), mode="constant", value=0.0
         )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@nemo_rl/models/generation/vllm/quantization/fp8_train_utils.py` around lines
57 - 62, Replace edge-value padding with zero-padding in the block-alignment
logic: when computing pad0/pad1 and calling torch.nn.functional.pad on data_hp
(the block-padding branch that checks data_hp.shape[...] % block_size...),
change the pad fill from value=data_hp[-1, -1] to value=0 so that padding uses
zeros instead of the tensor's edge element; keep the same pad tuple order and
mode="constant" so only the fill value changes.

69-69: Undocumented square-block constraint.

The function signature accepts weight_block_size: list[int] suggesting arbitrary block dimensions, but this assertion enforces square blocks. Consider documenting this restriction in the docstring or simplifying the signature if non-square blocks are never intended.

📝 Document the restriction in docstring
     Args:
         data_hp: 2-D high-precision weight tensor (any float dtype).
-        weight_block_size: [block_rows, block_cols], e.g. [128, 128].
+        weight_block_size: [block_rows, block_cols], e.g. [128, 128].
+            Note: block_rows must equal block_cols (square blocks only).
         use_pow2_scale: If True, round scale factors to powers of two.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@nemo_rl/models/generation/vllm/quantization/fp8_train_utils.py` at line 69,
The code currently asserts block_size0 == block_size1 (from the assert
block_size0 == block_size1) while the function accepts weight_block_size:
list[int], which implies non-square blocks are allowed; update the function’s
docstring (the function that takes weight_block_size) to state explicitly that
only square blocks are supported (i.e., weight_block_size must be [N, N]) or, if
non-square blocks are never intended, simplify the signature to accept a single
int for block_size and remove the list ambiguity; reference the
weight_block_size parameter and the assert block_size0 == block_size1 when
documenting the constraint.

52-52: Consider using explicit exception instead of assert for input validation.

Assertions can be disabled with Python's -O flag, which would bypass this validation in optimized runs. For production code, explicit exceptions are more reliable.

🔧 Suggested fix
-    assert len(data_hp.shape) == 2, "Only 2-D input tensor is supported"
+    if len(data_hp.shape) != 2:
+        raise ValueError("Only 2-D input tensor is supported")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@nemo_rl/models/generation/vllm/quantization/fp8_train_utils.py` at line 52,
Replace the runtime-only assert on data_hp with an explicit input validation
that always runs: check if len(data_hp.shape) != 2 and raise a ValueError with
the message "Only 2-D input tensor is supported"; update the check around the
data_hp usage in fp8_train_utils.py (the line currently doing `assert
len(data_hp.shape) == 2, ...`) to this explicit conditional to ensure validation
cannot be skipped under -O.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: d17786db-f935-4e85-9778-fd41cce69b35

📥 Commits

Reviewing files that changed from the base of the PR and between 4a7aa47 and a4c6b06.

📒 Files selected for processing (3)
  • nemo_rl/models/generation/vllm/quantization/fp8_train_utils.py
  • nemo_rl/models/generation/vllm/vllm_backend.py
  • nemo_rl/models/policy/workers/megatron_policy_worker.py

Comment on lines +87 to +91
    else:
        scale_fp = max_dtype / max_abs
        scale_fp = torch.where(max_abs == 0, 1.0, scale_fp)
        scale_fp = torch.where(max_abs == torch.inf, 1.0, scale_fp)
        descale_fp = torch.reciprocal(scale_fp)
Contributor


⚠️ Potential issue | 🟡 Minor

NaN values are not handled in scale computation.

The linear scale path handles max_abs == 0 and max_abs == inf, but NaN values would propagate silently. If any block contains NaN, both scale_fp and the resulting fp8_data would be NaN.

🛡️ Suggested fix to handle NaN
     else:
         scale_fp = max_dtype / max_abs
         scale_fp = torch.where(max_abs == 0, 1.0, scale_fp)
         scale_fp = torch.where(max_abs == torch.inf, 1.0, scale_fp)
+        scale_fp = torch.where(torch.isnan(max_abs), 1.0, scale_fp)
         descale_fp = torch.reciprocal(scale_fp)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@nemo_rl/models/generation/vllm/quantization/fp8_train_utils.py` around lines
87 - 91, The scale computation doesn't guard against NaN in max_abs which will
produce NaN scale_fp and downstream fp8_data; update the branch that computes
scale_fp/descale_fp to also replace NaNs (e.g., using torch.isnan or
torch.isfinite) with a safe fallback (1.0) before taking the reciprocal so that
scale_fp = max_dtype / max_abs is followed by handling max_abs == 0, max_abs ==
inf, and max_abs == NaN (set those scale entries to 1.0), then compute
descale_fp = torch.reciprocal(scale_fp); modify the existing scale_fp and
descale_fp logic where those symbols are defined to include the NaN check.
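The guarded scale computation, including the NaN fallback this comment proposes, can be demonstrated in NumPy (a stand-in for the torch.where chain in the excerpt; names mirror the reviewed code):

```python
import numpy as np

def safe_scale(max_abs: np.ndarray, max_dtype: float = 448.0) -> np.ndarray:
    """Per-block scale with fallbacks: a zero, inf, or NaN amax all map to 1.0,
    so the reciprocal (descale) is always finite."""
    with np.errstate(divide="ignore", invalid="ignore"):
        scale_fp = max_dtype / max_abs
    scale_fp = np.where(max_abs == 0, 1.0, scale_fp)
    scale_fp = np.where(np.isinf(max_abs), 1.0, scale_fp)
    scale_fp = np.where(np.isnan(max_abs), 1.0, scale_fp)   # the proposed NaN guard
    return scale_fp
```

Without the final np.where, a NaN amax would propagate into both the scale and the quantized data.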

@anwithk anwithk added this to the v0.6 Release milestone Mar 20, 2026
    if not name.endswith(".weight"):
        return False
    lower = name.lower()
    if any(kw in lower for kw in ("norm", "embed", "lm_head")):
Contributor


I think this is a bit too hacky; is it possible to obtain the list of param names to be quantized from the is_fp8_weight function on the vLLM side? That info could be synced once and reused for all subsequent steps.

Signed-off-by: Jianbing Dong <jianbingd@nvidia.com>
@Jianbing-D Jianbing-D requested a review from a team as a code owner April 10, 2026 01:39

copy-pr-bot bot commented Apr 10, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.


Labels

Performance — Related to improving performance

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants