-
Notifications
You must be signed in to change notification settings - Fork 337
FP8 Refit Optimization #2037
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
Jianbing-D
wants to merge
2
commits into
NVIDIA-NeMo:main
Choose a base branch
from
Jianbing-D:grpo_fp8_refit_opt
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
FP8 Refit Optimization #2037
Changes from 1 commit
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,6 +12,98 @@ | |
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import torch | ||
|
|
||
| FP8_WEIGHT_BLOCK_SIZE = [128, 128] | ||
|
|
||
|
|
||
| def should_quantize_to_fp8(name: str, tensor: torch.Tensor) -> bool: | ||
| """Check whether a HuggingFace-named weight should be block-quantized to FP8. | ||
|
|
||
| Matches the same set of parameters that vLLM quantizes (linear-layer | ||
| weights only). Embeddings, layernorms, biases, and lm_head are excluded. | ||
| """ | ||
| if tensor.dim() != 2: | ||
| return False | ||
| if not name.endswith(".weight"): | ||
| return False | ||
| lower = name.lower() | ||
| if any(kw in lower for kw in ("norm", "embed", "lm_head")): | ||
| return False | ||
| return True | ||
|
|
||
|
|
||
| def cast_tensor_to_fp8_blockwise( | ||
| data_hp: torch.Tensor, | ||
| weight_block_size: list[int], | ||
| use_pow2_scale: bool = False, | ||
| ) -> tuple[torch.Tensor, torch.Tensor]: | ||
| """Block-wise FP8 (E4M3) quantization — standalone, no vLLM dependencies. | ||
|
|
||
| Args: | ||
| data_hp: 2-D high-precision weight tensor (any float dtype). | ||
| weight_block_size: [block_rows, block_cols], e.g. [128, 128]. | ||
| use_pow2_scale: If True, round scale factors to powers of two. | ||
|
|
||
| Returns: | ||
| (fp8_data, descale) where fp8_data has dtype float8_e4m3fn and | ||
| descale is float32 with shape (blk_m, blk_n, 1). | ||
| """ | ||
| assert len(data_hp.shape) == 2, "Only 2-D input tensor is supported" | ||
|
|
||
| block_size0, block_size1 = weight_block_size | ||
| shape_before_padding = data_hp.shape | ||
|
|
||
| if data_hp.shape[0] % block_size0 != 0 or data_hp.shape[1] % block_size1 != 0: | ||
| pad0 = (block_size0 - data_hp.shape[0] % block_size0) % block_size0 | ||
| pad1 = (block_size1 - data_hp.shape[1] % block_size1) % block_size1 | ||
| data_hp = torch.nn.functional.pad( | ||
| data_hp, (0, pad1, 0, pad0), mode="constant", value=data_hp[-1, -1] | ||
| ) | ||
|
|
||
| max_dtype = torch.finfo(torch.float8_e4m3fn).max | ||
| original_shape = data_hp.shape | ||
| blk_m = data_hp.shape[0] // block_size0 | ||
| blk_n = data_hp.shape[1] // block_size1 | ||
|
|
||
| assert block_size0 == block_size1 | ||
| data_hp = data_hp.reshape(blk_m, block_size0, blk_n, block_size1) | ||
| data_hp = data_hp.permute(0, 2, 1, 3) | ||
| data_hp = data_hp.to(torch.float32).contiguous().flatten(start_dim=2) | ||
|
|
||
| max_abs = torch.amax(torch.abs(data_hp), dim=-1, keepdim=True) | ||
|
|
||
| if use_pow2_scale: | ||
| descale = max_abs / max_dtype | ||
| exponent = torch.ceil(torch.log2(descale)) | ||
| exponent = torch.clamp(exponent, min=-127, max=127) + 127 | ||
| exponent = exponent.to(torch.uint8) | ||
| scale_fp = torch.where( | ||
| exponent == 0, | ||
| 1.0, | ||
| torch.exp2(127 - exponent.to(torch.float32)), | ||
| ) | ||
| descale_fp = torch.reciprocal(scale_fp) | ||
| else: | ||
| scale_fp = max_dtype / max_abs | ||
| scale_fp = torch.where(max_abs == 0, 1.0, scale_fp) | ||
| scale_fp = torch.where(max_abs == torch.inf, 1.0, scale_fp) | ||
| descale_fp = torch.reciprocal(scale_fp) | ||
|
Comment on lines
+71
to
+75
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. NaN values are not handled in scale computation. The linear scale path handles 🛡️ Suggested fix to handle NaN else:
scale_fp = max_dtype / max_abs
scale_fp = torch.where(max_abs == 0, 1.0, scale_fp)
scale_fp = torch.where(max_abs == torch.inf, 1.0, scale_fp)
+ scale_fp = torch.where(torch.isnan(max_abs), 1.0, scale_fp)
descale_fp = torch.reciprocal(scale_fp)🤖 Prompt for AI Agents |
||
|
|
||
| data_lp = torch.clamp(data_hp * scale_fp, min=-max_dtype, max=max_dtype) | ||
| fp_data = data_lp.to(torch.float8_e4m3fn) | ||
|
|
||
| fp_data = ( | ||
| fp_data.reshape(blk_m, blk_n, block_size0, block_size1) | ||
| .permute(0, 2, 1, 3) | ||
| .reshape(original_shape) | ||
| ) | ||
|
|
||
| if original_shape != shape_before_padding: | ||
| fp_data = fp_data[: shape_before_padding[0], : shape_before_padding[1]] | ||
|
|
||
| return fp_data, descale_fp | ||
|
|
||
|
|
||
| def get_vllm_qkv_scale_names(layer_idx: int) -> dict[str, str]: | ||
| """Get vLLM-compatible parameter names for Q/K/V FP8 scales. | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is a bit too hacky; is it possible to obtain the list of param names to-be-quantized from the is_fp8_weight function in vllm side? This info can be synced one time and reused for all consequent steps