diff --git a/examples/llm_ptq/scripts/huggingface_example.sh b/examples/llm_ptq/scripts/huggingface_example.sh
index 6ca99c7f96..693506929d 100755
--- a/examples/llm_ptq/scripts/huggingface_example.sh
+++ b/examples/llm_ptq/scripts/huggingface_example.sh
@@ -49,18 +49,7 @@ dense | sparsegpt) ;;
     ;;
 esac
 
-#Iterate over list of qformats provided and check if they are valid
-IFS=","
-for qformat in $QFORMAT; do
-    case $qformat in
-    fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | nvfp4_mse | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8 | nvfp4_experts_only | nvfp4_mlp_only | nvfp4_omlp_only | nvfp4_svdquant | mxfp8 | nvfp4_local_hessian) ;;
-    *)
-        echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, nvfp4_mse, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8, nvfp4_experts_only, nvfp4_mlp_only, nvfp4_omlp_only, nvfp4_svdquant, mxfp8, nvfp4_local_hessian]" >&2
-        exit 1
-        ;;
-    esac
-done
-IFS=" "
+# Quant format / recipe validation is delegated to hf_ptq.py.
 
 script_dir="$(dirname "$(readlink -f "$0")")"
 
@@ -72,7 +61,14 @@ fi
 
 QFORMAT_MODIFIED="${QFORMAT//,/_}"
 
-MODEL_NAME=$(basename $MODEL_PATH | sed 's/[^0-9a-zA-Z\-]/_/g')_${QFORMAT_MODIFIED}${KV_CACHE_QUANT:+_kv_${KV_CACHE_QUANT}}
+# When using --recipe, build the model name from the recipe basename (without
+# directory or .yaml suffix) so each recipe gets its own SAVE_PATH.
+if [ -n "$RECIPE" ]; then
+    RECIPE_TAG=$(basename "$RECIPE" .yaml | sed 's/[^0-9a-zA-Z\-]/_/g')
+    MODEL_NAME=$(basename $MODEL_PATH | sed 's/[^0-9a-zA-Z\-]/_/g')_recipe_${RECIPE_TAG}
+else
+    MODEL_NAME=$(basename $MODEL_PATH | sed 's/[^0-9a-zA-Z\-]/_/g')_${QFORMAT_MODIFIED}${KV_CACHE_QUANT:+_kv_${KV_CACHE_QUANT}}
+fi
 
 SAVE_PATH=${ROOT_SAVE_PATH}/saved_models_${MODEL_NAME}
 
@@ -177,11 +173,16 @@ if [[ $TASKS =~ "quant" ]] || [[ ! -d "$SAVE_PATH" ]] || [[ ! $(ls -A $SAVE_PATH
     if [[ "$MODEL_CONFIG_EXIST" == false ]]; then
         echo "Quantizing original model..."
 
+        if [ -n "$RECIPE" ]; then
+            QUANT_SPEC_ARGS="--recipe=$RECIPE"
+        else
+            QUANT_SPEC_ARGS="--qformat=${QFORMAT// /,}"
+        fi
         python hf_ptq.py \
             --pyt_ckpt_path=$MODEL_PATH \
             --export_path=$SAVE_PATH \
             --sparsity_fmt=$SPARSITY_FMT \
-            --qformat="${QFORMAT// /,}" \
+            $QUANT_SPEC_ARGS \
             --calib_size=$CALIB_SIZE \
             --batch_size=$CALIB_BATCH_SIZE \
             --inference_tensor_parallel=$TP \
@@ -203,7 +204,7 @@ if [[ $TASKS =~ "quant" ]] || [[ ! -d "$SAVE_PATH" ]] || [[ ! $(ls -A $SAVE_PATH
         exit 0
     fi
 
-    if [[ "$QFORMAT" == *"nvfp4"* ]] || [[ "$KV_CACHE_QUANT" == *"nvfp4"* ]]; then
+    if [[ "$QFORMAT" == *"nvfp4"* ]] || [[ "$KV_CACHE_QUANT" == *"nvfp4"* ]] || [[ "$RECIPE" == *"nvfp4"* ]]; then
         cuda_major=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader -i 0 | cut -d. -f1)
 
         if [ "$cuda_major" -lt 10 ]; then
@@ -212,6 +213,11 @@ if [[ $TASKS =~ "quant" ]] || [[ ! -d "$SAVE_PATH" ]] || [[ ! $(ls -A $SAVE_PATH
         fi
     fi
+    if [ -n "$RECIPE" ]; then
+        echo "Recipe $RECIPE used. Please deploy with TensorRT-LLM directly. Checkpoint export_path: $SAVE_PATH"
+        exit 0
+    fi
+
     if [[ ! " fp8 nvfp4 bf16 fp16 " =~ " ${QFORMAT} " ]]; then
         echo "Quant $QFORMAT specified. Please read TensorRT-LLM quantization support matrix https://nvidia.github.io/TensorRT-LLM/features/quantization.html#quantization-in-tensorrt-llm and use TensorRT-LLM for deployment.
Checkpoint export_path: $SAVE_PATH" exit 0 diff --git a/examples/llm_ptq/scripts/parser.sh b/examples/llm_ptq/scripts/parser.sh index 3817c1dee7..2a9a28b356 100644 --- a/examples/llm_ptq/scripts/parser.sh +++ b/examples/llm_ptq/scripts/parser.sh @@ -20,6 +20,7 @@ parse_options() { # Default values MODEL_PATH="" QFORMAT="" + RECIPE="" KV_CACHE_QUANT="" TP=1 PP=1 @@ -37,13 +38,14 @@ parse_options() { CAST_MXFP4_TO_NVFP4=false # Parse command-line options - ARGS=$(getopt -o "" -l "model:,quant:,kv_cache_quant:,tp:,pp:,sparsity:,awq_block_size:,calib:,calib_batch_size:,auto_quantize_bits:,output:,batch:,tasks:,lm_eval_tasks:,lm_eval_limit:,simple_eval_tasks:,trust_remote_code,use_seq_device_map,gpu_max_mem_percentage:,kv_cache_free_gpu_memory_fraction:,low_memory_mode,no-verbose,calib_dataset:,calib_seq:,auto_quantize_method:,auto_quantize_score_size:,auto_quantize_checkpoint:,moe_calib_experts_ratio:,cast_mxfp4_to_nvfp4" -n "$0" -- "$@") + ARGS=$(getopt -o "" -l "model:,quant:,recipe:,kv_cache_quant:,tp:,pp:,sparsity:,awq_block_size:,calib:,calib_batch_size:,auto_quantize_bits:,output:,batch:,tasks:,lm_eval_tasks:,lm_eval_limit:,simple_eval_tasks:,trust_remote_code,use_seq_device_map,gpu_max_mem_percentage:,kv_cache_free_gpu_memory_fraction:,low_memory_mode,no-verbose,calib_dataset:,calib_seq:,auto_quantize_method:,auto_quantize_score_size:,auto_quantize_checkpoint:,moe_calib_experts_ratio:,cast_mxfp4_to_nvfp4" -n "$0" -- "$@") eval set -- "$ARGS" while true; do case "$1" in --model ) MODEL_PATH="$2"; shift 2;; --quant ) QFORMAT="$2"; shift 2;; + --recipe ) RECIPE="$2"; shift 2;; --kv_cache_quant ) KV_CACHE_QUANT="$2"; shift 2;; --tp ) TP="$2"; shift 2;; --pp ) PP="$2"; shift 2;; @@ -99,12 +101,19 @@ parse_options() { fi # Verify required options are provided - if [ -z "$MODEL_PATH" ] || [ -z "$QFORMAT" ] || [ -z "$TASKS" ]; then - echo "Usage: $0 --model= --quant= --tasks=" + if [ -z "$MODEL_PATH" ] || [ -z "$TASKS" ] || ([ -z "$QFORMAT" ] && [ -z "$RECIPE" ]); then + echo "Usage: $0 --model= (--quant= | --recipe=) --tasks=" echo "Optional args: --sparsity= --awq_block_size= --calib=" exit 1 fi + # --quant and --recipe are mutually exclusive: --recipe is a full PTQ spec, while + # --quant selects a built-in qformat preset. Pick exactly one. + if [ -n "$QFORMAT" ] && [ -n "$RECIPE" ]; then + echo "Cannot specify both --quant and --recipe; pick one." >&2 + exit 1 + fi + VALID_TASKS=("quant" "mmlu" "lm_eval" "livecodebench" "simple_eval") for task in $(echo "$TASKS" | tr ',' ' '); do @@ -135,6 +144,7 @@ parse_options() { echo "=================" echo "model: $MODEL_PATH" echo "quant: $QFORMAT" + echo "recipe: $RECIPE" echo "tp (TensorRT-LLM Checkpoint only): $TP" echo "pp (TensorRT-LLM Checkpoint only): $PP" echo "sparsity: $SPARSITY_FMT" diff --git a/modelopt_recipes/general/ptq/nvfp4_experts_only_mse-fp8_cast_kv.yaml b/modelopt_recipes/general/ptq/nvfp4_experts_only_mse-fp8_cast_kv.yaml new file mode 100644 index 0000000000..749a875b36 --- /dev/null +++ b/modelopt_recipes/general/ptq/nvfp4_experts_only_mse-fp8_cast_kv.yaml @@ -0,0 +1,103 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+metadata:
+  recipe_type: ptq
+  description: >-
+    NVFP4 W4A4 for expert layers only with MSE-search FP8 scale calibration on
+    expert weights, FP8 KV cache with constant amax (skips KV calibration; amax
+    hardcoded to FP8 E4M3 max 448.0). Expert weight quantizers are static
+    (per-block amax fixed by the MSE FP8-scale sweep); input quantizers remain
+    dynamic.
+quantize:
+  algorithm:
+    method: mse
+    fp8_scale_sweep: true
+    # Max calibration is fast and does not typically need checkpointing.
+    # layerwise=false required for VLMs where the decoder layers are nested under
+    # `model.language_model.layers` (layerwise_calibrate can't find them otherwise).
+    layerwise: false
+  quant_cfg:
+    - quantizer_name: '*'
+      enable: false
+    - quantizer_name: '*mlp.experts*weight_quantizer'
+      enable: true
+      cfg:
+        block_sizes:
+          -1: 16
+          type: static
+          scale_bits: e4m3
+        num_bits: e2m1
+    - quantizer_name: '*mlp.experts*input_quantizer'
+      enable: true
+      cfg:
+        block_sizes:
+          -1: 16
+          type: dynamic
+          scale_bits: e4m3
+        num_bits: e2m1
+    - quantizer_name: '*block_sparse_moe*weight_quantizer'
+      enable: true
+      cfg:
+        block_sizes:
+          -1: 16
+          type: static
+          scale_bits: e4m3
+        num_bits: e2m1
+    - quantizer_name: '*block_sparse_moe*input_quantizer'
+      enable: true
+      cfg:
+        block_sizes:
+          -1: 16
+          type: dynamic
+          scale_bits: e4m3
+        num_bits: e2m1
+    - quantizer_name: '*[kv]_bmm_quantizer'
+      enable: true
+      cfg:
+        num_bits: e4m3
+        use_constant_amax: true
+    - quantizer_name: '*block_sparse_moe.gate*'
+      enable: false
+    - quantizer_name: '*linear_attn.conv1d*'
+      enable: false
+    - quantizer_name: '*lm_head*'
+      enable: false
+    - quantizer_name: '*mixer.conv1d*'
+      enable: false
+    - quantizer_name: '*mlp.gate.*'
+      enable: false
+    - quantizer_name: '*mlp.shared_expert_gate.*'
+      enable: false
+    - quantizer_name: '*output_layer*'
+      enable: false
+    - quantizer_name: '*proj_out.*'
+      enable: false
+    - quantizer_name: '*router*'
+      enable: false
+    - quantizer_name: 'output.*'
+      enable: false
+    - parent_class: 'nn.BatchNorm1d'
+      quantizer_name: '*'
+      enable: false
+    - parent_class: 'nn.BatchNorm2d'
+      quantizer_name: '*'
+      enable: false
+    - parent_class: 'nn.BatchNorm3d'
+      quantizer_name: '*'
+      enable: false
+    - parent_class: 'nn.LeakyReLU'
+      quantizer_name: '*'
+      enable: false
diff --git a/modelopt_recipes/general/ptq/nvfp4_mlp_only_mse-fp8_cast_kv.yaml b/modelopt_recipes/general/ptq/nvfp4_mlp_only_mse-fp8_cast_kv.yaml
new file mode 100644
index 0000000000..bca46608d5
--- /dev/null
+++ b/modelopt_recipes/general/ptq/nvfp4_mlp_only_mse-fp8_cast_kv.yaml
@@ -0,0 +1,103 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+metadata:
+  recipe_type: ptq
+  description: >-
+    NVFP4 W4A4 for all MLP layers (dense + MoE) with MSE-search FP8 scale
+    calibration on MLP weights, FP8 KV cache with constant amax (skips KV
+    calibration; amax hardcoded to FP8 E4M3 max 448.0). MLP weight quantizers
+    are static (per-block amax fixed by the MSE FP8-scale sweep); input
+    quantizers remain dynamic.
+quantize:
+  algorithm:
+    method: mse
+    fp8_scale_sweep: true
+    # Max calibration is fast and does not typically need checkpointing.
+    # layerwise=false required for VLMs where the decoder layers are nested under
+    # `model.language_model.layers` (layerwise_calibrate can't find them otherwise).
+    layerwise: false
+  quant_cfg:
+    - quantizer_name: '*'
+      enable: false
+    - quantizer_name: '*mlp*weight_quantizer'
+      enable: true
+      cfg:
+        block_sizes:
+          -1: 16
+          type: static
+          scale_bits: e4m3
+        num_bits: e2m1
+    - quantizer_name: '*mlp*input_quantizer'
+      enable: true
+      cfg:
+        block_sizes:
+          -1: 16
+          type: dynamic
+          scale_bits: e4m3
+        num_bits: e2m1
+    - quantizer_name: '*block_sparse_moe*weight_quantizer'
+      enable: true
+      cfg:
+        block_sizes:
+          -1: 16
+          type: static
+          scale_bits: e4m3
+        num_bits: e2m1
+    - quantizer_name: '*block_sparse_moe*input_quantizer'
+      enable: true
+      cfg:
+        block_sizes:
+          -1: 16
+          type: dynamic
+          scale_bits: e4m3
+        num_bits: e2m1
+    - quantizer_name: '*[kv]_bmm_quantizer'
+      enable: true
+      cfg:
+        num_bits: e4m3
+        use_constant_amax: true
+    - quantizer_name: '*block_sparse_moe.gate*'
+      enable: false
+    - quantizer_name: '*linear_attn.conv1d*'
+      enable: false
+    - quantizer_name: '*lm_head*'
+      enable: false
+    - quantizer_name: '*mixer.conv1d*'
+      enable: false
+    - quantizer_name: '*mlp.gate.*'
+      enable: false
+    - quantizer_name: '*mlp.shared_expert_gate.*'
+      enable: false
+    - quantizer_name: '*output_layer*'
+      enable: false
+    - quantizer_name: '*proj_out.*'
+      enable: false
+    - quantizer_name: '*router*'
+      enable: false
+    - quantizer_name: 'output.*'
+      enable: false
+    - parent_class: 'nn.BatchNorm1d'
+      quantizer_name: '*'
+      enable: false
+    - parent_class: 'nn.BatchNorm2d'
+      quantizer_name: '*'
+      enable: false
+    - parent_class: 'nn.BatchNorm3d'
+      quantizer_name: '*'
+      enable: false
+    - parent_class: 'nn.LeakyReLU'
+      quantizer_name: '*'
+      enable: false
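A hedged usage sketch of the two invocation paths this change supports; the model directory below is a placeholder and the recipe path is one of the YAML files added above, so adjust both to your setup:

    # Built-in preset path (existing CLI); qformat validation now happens inside hf_ptq.py
    bash examples/llm_ptq/scripts/huggingface_example.sh \
        --model=/path/to/hf_model --quant=nvfp4 --tasks=quant

    # Recipe path (new); mutually exclusive with --quant, and SAVE_PATH is derived
    # from the recipe basename instead of the qformat string
    bash examples/llm_ptq/scripts/huggingface_example.sh \
        --model=/path/to/hf_model \
        --recipe=modelopt_recipes/general/ptq/nvfp4_mlp_only_mse-fp8_cast_kv.yaml \
        --tasks=quant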