Support Mixed precision & Static MSE PTQ in MCore export; Nemotron Super v3 NVFP4 recipe #1363
Status: Open

jenchen13 wants to merge 15 commits into main from jennifchen/super_nvfp4_recipe
Changes from 1 commit
Commits (15, all by jenchen13):
- 2a0c852: add nemotron super 4 nvfp4 recipe
- 9f96df3: remove latent moe fp8
- 9282cdb: fix docstring
- cfaf055: fix MSE moe calibration and add stride for fp8 scale sweep
- f796197: fix MLM naming in recipe and add stride recipe
- 81d9d87: fix merge conflict
- acf6892: amax recipe
- 372820e: fix config
- c635b8a: add amax recipe
- a829722: mixed precision export for megatron
- 5e32bd1: fix mcore ckpt resume for static quantizers and MSE export
- ef37456: remove duplicate recipe
- b5f4b56: fix docstring
- b5c5331: fix max calib recipe
- 5de5541: cleanup recipes
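The commits above add static MSE calibration and mixed-precision export. For orientation, here is a minimal sketch of what a mixed-precision PTQ config of this shape looks like when expressed directly against ModelOpt's `modelopt.torch.quantization` API. The wildcard-keyed `quant_cfg` dict is ModelOpt's standard pattern format; the `{"method": "mse"}` algorithm spec mirrors this recipe's schema and is an assumption about what the PR wires up, and `calib_dataloader` is a stand-in name:

```python
# Sketch only, not this repo's recipe loader: a mixed-precision quant_cfg in
# modelopt.torch.quantization style, mirroring the recipe in this PR.
import modelopt.torch.quantization as mtq

config = {
    "quant_cfg": {
        "*": {"enable": False},  # default: leave everything in BF16
        # Routed experts: NVFP4 (e2m1 weights, e4m3 block scales, blocks of 16).
        "*mixer.experts.*weight_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
            "enable": True,
        },
        # Shared experts: FP8 per-tensor (axis=None -> one scale per tensor).
        "*mixer.shared_experts.*weight_quantizer": {
            "num_bits": (4, 3),
            "axis": None,
            "enable": True,
        },
    },
    # Static MSE calibration per this PR's recipe schema (assumption).
    "algorithm": {"method": "mse"},
}

def forward_loop(model):
    # Feed calibration batches through the model; calib_dataloader is a
    # placeholder for whatever calibration data pipeline is in use.
    for batch in calib_dataloader:
        model(**batch)

# Given a loaded model:
# model = mtq.quantize(model, config, forward_loop)
```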
New file: modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml (+124, -0)
```yaml
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Mirrors the published nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 hf_quant_config.json:
# - MoE routed experts (mixer.experts.<N>.{up,down}_proj): NVFP4 W4A4 weight MSE, group_size 16
# - MoE shared experts (mixer.shared_experts.{up,down}_proj): FP8 per-tensor
# - Mamba mixer linears (mixer.{in,out}_proj): FP8 per-tensor
# - KV cache: FP8
# - Attention linears ({q,k,v}_proj): BF16 (not quantized)
# - MTP head, lm_head, output, mamba conv1d: BF16 (not quantized)
# - Latent MoE (fc1_latent_proj, fc2_latent_proj): BF16 (not quantized)
# - SSM cache: FP32 (can be set to FP16 in vLLM)
#
# Calibration: weight MSE with an FP8-scale sweep over the 128 e4m3 scale values
# (NVFP4 weights use static block scales selected by MSE; FP8 per-tensor scales
# are likewise chosen via MSE search instead of plain amax).
metadata:
  recipe_type: ptq
  description: Super NVFP4 mixed precision — sparse MoE experts NVFP4 (W4A4, group_size 16);
    shared experts, mamba in/out_proj, and attention o_proj/fc1_latent_proj/fc2_latent_proj
    FP8 per-tensor; FP8 KV cache; lm_head/MTP/SSM stay BF16/FP16. Weight-MSE calibration
    with FP8 scale sweep.
quantize:
  algorithm:
    method: mse
    fp8_scale_sweep: true
  quant_cfg:
    - quantizer_name: '*'
      enable: false

    # MoE routed experts -> NVFP4 W4A4, block_size 16, e4m3 scale.
    # Weights use static block scales (chosen by MSE); activations stay dynamic.
    - quantizer_name: '*mixer.experts.*weight_quantizer'
      enable: true
      cfg:
        block_sizes:
          -1: 16
          type: static
          scale_bits: e4m3
        num_bits: e2m1
    - quantizer_name: '*mixer.experts.*input_quantizer'
      enable: true
      cfg:
        block_sizes:
          -1: 16
          type: dynamic
          scale_bits: e4m3
        num_bits: e2m1

    # MoE shared experts -> FP8 per-tensor.
    - quantizer_name: '*mixer.shared_experts.*weight_quantizer'
      enable: true
      cfg:
        num_bits: e4m3
        axis:
    - quantizer_name: '*mixer.shared_experts.*input_quantizer'
      enable: true
      cfg:
        num_bits: e4m3
        axis:

    # Mamba mixer linears -> FP8 per-tensor.
    - quantizer_name: '*mixer.in_proj*weight_quantizer'
      enable: true
      cfg:
        num_bits: e4m3
        axis:
    - quantizer_name: '*mixer.in_proj*input_quantizer'
      enable: true
      cfg:
        num_bits: e4m3
        axis:
    - quantizer_name: '*mixer.out_proj*weight_quantizer'
      enable: true
      cfg:
        num_bits: e4m3
        axis:
    - quantizer_name: '*mixer.out_proj*input_quantizer'
      enable: true
      cfg:
        num_bits: e4m3
        axis:

    # Latent MoE down/up projections -> FP8 per-tensor.
    # NOTE: only 3 layers (1, 3, and 5) quantize the latent MoE to FP8.
    - quantizer_name: '*mixer.fc1_latent_proj*weight_quantizer'
      enable: true
      cfg:
        num_bits: e4m3
        axis:
    - quantizer_name: '*mixer.fc1_latent_proj*input_quantizer'
      enable: true
      cfg:
        num_bits: e4m3
        axis:
    - quantizer_name: '*mixer.fc2_latent_proj*weight_quantizer'
      enable: true
      cfg:
        num_bits: e4m3
        axis:
    - quantizer_name: '*mixer.fc2_latent_proj*input_quantizer'
      enable: true
      cfg:
        num_bits: e4m3
        axis:

    # KV cache -> FP8.
    - quantizer_name: '*[kv]_bmm_quantizer'
      enable: true
      cfg:
        num_bits: e4m3

    # Stay BF16: lm_head, output projection, MoE routers/gates, MTP head.
    # SSM state / mamba conv1d stay FP16.
```
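To make the routed-experts entry concrete: NVFP4 stores e2m1 (4-bit) values with one e4m3 scale per block of 16 along the last axis. Below is a self-contained conceptual fake-quantizer, assuming amax-based block scales for simplicity (the recipe instead selects static block scales by MSE, as sketched next); this illustrates the format, not ModelOpt's implementation:

```python
# Conceptual NVFP4 fake-quantization: blocks of 16 along the last dim,
# one e4m3 scale per block, e2m1 (4-bit) values. Illustration only.
import torch

# The positive magnitudes representable in e2m1 (max is 6.0).
E2M1_LEVELS = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def nvfp4_fake_quant(w: torch.Tensor, block: int = 16) -> torch.Tensor:
    orig_shape = w.shape
    wb = w.reshape(-1, block)                      # [num_blocks, 16]
    # Per-block scale so the block max maps to the largest e2m1 level.
    scale = wb.abs().amax(dim=-1, keepdim=True) / 6.0
    scale = scale.to(torch.float8_e4m3fn).float()  # store the scale itself in e4m3
    scale = torch.where(scale == 0, torch.ones_like(scale), scale)
    x = wb / scale
    # Round each magnitude to the nearest e2m1 level; keep the sign.
    idx = (x.abs().unsqueeze(-1) - E2M1_LEVELS).abs().argmin(dim=-1)
    q = E2M1_LEVELS[idx] * x.sign()
    return (q * scale).reshape(orig_shape)

w = torch.randn(8, 32)
print((nvfp4_fake_quant(w) - w).pow(2).mean())  # quantization MSE
```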
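And for the calibration comment ("weight MSE with FP8-scale sweep"): the idea is to try many candidate scales and keep whichever minimizes fake-quantization MSE, rather than deriving the scale from amax alone. The recipe sweeps the representable e4m3 scale values; the geometric grid below is a simplified stand-in, again an illustration rather than the ModelOpt implementation:

```python
# Simplified MSE scale search for an FP8 per-tensor quantizer: try candidate
# scales and keep the one minimizing fake-quantization MSE.
import torch

def fp8_fake_quant(w: torch.Tensor, scale: float) -> torch.Tensor:
    # Quantize to FP8 e4m3 with a per-tensor scale, then dequantize.
    q = (w / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
    return q.float() * scale

def mse_scale_search(w: torch.Tensor, num_candidates: int = 128) -> float:
    amax_scale = w.abs().max().item() / 448.0  # 448 = largest finite e4m3 value
    best_scale, best_err = amax_scale, float("inf")
    # Sweep 0.01x..1x of the amax-derived scale on a geometric grid.
    for f in torch.logspace(-2, 0, steps=num_candidates):
        s = amax_scale * f.item()
        err = (fp8_fake_quant(w, s) - w).pow(2).mean().item()
        if err < best_err:
            best_scale, best_err = s, err
    return best_scale

w = torch.randn(1024, 1024)
print(mse_scale_search(w))
```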