2 changes: 2 additions & 0 deletions CHANGELOG.rst
@@ -6,6 +6,8 @@ Changelog

**New Features**

- Add NVFP4 W4A16 weight-only quantization (``nvfp4_w4a16``): FP4 weights with group_size=16, BF16 activations, no calibration forward pass required. Use ``mtq.NVFP4_W4A16_CFG`` or ``--qformat nvfp4_w4a16`` in ``hf_ptq.py``. Exported checkpoints can be served on vLLM after conversion to compressed-tensors format.
- Register ``nn.Embedding`` with ``QuantModuleRegistry`` (weight-only wrapper) and extend the unified HF exporter to pack quantized embedding weights. This enables NVFP4 quantization of ``lm_head`` and the input token embedding on hybrid SSM+Attention models such as Nemotron-H, where those two tables are a sizeable fraction of the parameters and leaving them in bf16 wastes most of the compression. Nemotron-H-specific enablement and the ``--exclude_modules`` CLI flag are wired up in ``examples/llm_ptq/hf_ptq.py``.
- Support the full Transformer Engine spec for Minitron pruning (``mcore_minitron``), so the custom ModelOpt spec is no longer needed. This does not change how the pruning workflow is used, but it makes pruning slightly faster and may produce a slightly different pruned model because of different kernels and numerics.
- Add Puzzletron, a new algorithm for heterogeneous pruning of LLMs and VLMs. See `examples/puzzletron/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/puzzletron>`_ for more details.
- Add an iterator interface using ``CalibrationDataReader`` to the ONNX quantization workflow.
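A minimal usage sketch of the new ``nvfp4_w4a16`` format described in the changelog entry above. This is a sketch only: it assumes the standard ``mtq.quantize`` entry point, and omitting the calibration forward loop follows from "no calibration forward pass required".

import modelopt.torch.quantization as mtq

# `model` is an already-loaded Hugging Face model (e.g. AutoModelForCausalLM.from_pretrained(...)).
# NVFP4 W4A16: FP4 weights with group_size=16, BF16 activations.
# Weight-only quantization, so no calibration forward_loop is passed here.
model = mtq.quantize(model, mtq.NVFP4_W4A16_CFG)

The same format is reachable from the CLI via ``--qformat nvfp4_w4a16`` in ``examples/llm_ptq/hf_ptq.py``.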
160 changes: 160 additions & 0 deletions examples/llm_ptq/example_utils.py
@@ -540,6 +540,147 @@ def _resolve_file(filename):
module.__dict__.pop("weight", None)


def _maybe_patch_transformers_nemotron_h_mixer_types() -> None:
"""Patch transformers' Nemotron-H implementation for ``-`` (MLP) blocks.

transformers 5.5.x ships a Nemotron-H port that is incomplete in three places:

1. ``NemotronHConfig._pattern_to_list`` maps ``M→mamba``, ``E→moe``, ``*→attention``
but forgets ``-→mlp``, so merely loading the config of Nemotron-H-v2
(whose ``hybrid_override_pattern`` contains ``-``) raises ``KeyError: '-'``.
2. ``PreTrainedConfig.validate_layer_type`` checks ``layer_types`` (aliased to
``layers_block_type`` via ``attribute_map``) against a hard-coded
``ALLOWED_LAYER_TYPES`` tuple that doesn't include ``"mlp"``, so once (1) is
fixed the validator rejects the config.
3. ``MIXER_TYPES`` in ``modeling_nemotron_h`` registers ``mamba``/``attention``/``moe``
but omits ``mlp`` even though ``NemotronHMLP`` is defined in the same module.
``NemotronHBlock`` instantiates mixers as ``cls(config, layer_idx=...)``, which
``NemotronHMLP.__init__`` doesn't accept, so we register a thin adapter.

All patches are idempotent.
"""
# Extend ALLOWED_LAYER_TYPES so `validate_layer_type` accepts "mlp".
try:
cu = __import__("transformers.configuration_utils", fromlist=["ALLOWED_LAYER_TYPES"])
except ImportError:
cu = None
if cu is not None:
allowed = getattr(cu, "ALLOWED_LAYER_TYPES", None)
if isinstance(allowed, tuple) and "mlp" not in allowed:
cu.ALLOWED_LAYER_TYPES = (*allowed, "mlp")

# 1) MIXER_TYPES (modeling)
try:
mod = __import__(
"transformers.models.nemotron_h.modeling_nemotron_h",
fromlist=["MIXER_TYPES", "NemotronHMLP"],
)
except ImportError:
mod = None

if mod is not None:
mixer_types = getattr(mod, "MIXER_TYPES", None)
nemotron_h_mlp = getattr(mod, "NemotronHMLP", None)
if (
isinstance(mixer_types, dict)
and nemotron_h_mlp is not None
and "mlp" not in mixer_types
):
# ``nemotron_h_mlp`` is resolved at runtime, so use ``types.new_class`` rather
# than a literal ``class`` statement (keeps mypy happy about dynamic bases).
import types as _types

def _mlp_adapter_init(self, config, layer_idx=None, **kwargs):
nemotron_h_mlp.__init__(self, config, **kwargs)

_mlp_adapter_cls = _types.new_class(
"_NemotronHMLPMixerAdapter",
(nemotron_h_mlp,),
{},
lambda ns: ns.update({"__init__": _mlp_adapter_init}),
)
mixer_types["mlp"] = _mlp_adapter_cls

# ``NemotronHModel.forward`` builds an inline ``block_type_to_mask`` dict that
# only knows about ``{"mamba", "attention", "moe"}`` and KeyErrors on "mlp".
# ``NemotronHBlock.forward`` routes "mlp"/"moe" through the same ``else`` branch
# that ignores the attention mask, so aliasing the MLP block's ``block_type``
# to ``"moe"`` after __init__ makes the mask lookup resolve to ``None`` without
# affecting mixer dispatch (the mixer instance was already built from
# ``layers_block_type[layer_idx] == "mlp"`` via MIXER_TYPES).
nemotron_h_block = getattr(mod, "NemotronHBlock", None)
if nemotron_h_block is not None and not getattr(
nemotron_h_block, "_modelopt_mlp_mask_patched", False
):
_orig_init = nemotron_h_block.__init__

def _patched_init(self, config, layer_idx):
_orig_init(self, config, layer_idx)
if getattr(self, "block_type", None) == "mlp":
self.block_type = "moe"

nemotron_h_block.__init__ = _patched_init
nemotron_h_block._modelopt_mlp_mask_patched = True

# 2) NemotronHConfig._pattern_to_list + validate_layers_block_type (configuration).
try:
cfg_mod = __import__(
"transformers.models.nemotron_h.configuration_nemotron_h",
fromlist=["NemotronHConfig"],
)
except ImportError:
return
cfg_cls = getattr(cfg_mod, "NemotronHConfig", None)
if cfg_cls is None or getattr(cfg_cls, "_modelopt_mlp_patched", False):
return

_orig_pattern_to_list = cfg_cls._pattern_to_list

def _patched_pattern_to_list(pattern: str) -> list:
mapping = {"M": "mamba", "E": "moe", "*": "attention", "-": "mlp"}
try:
return [mapping[ch] for ch in pattern]
except KeyError:
# Fall back to the stock implementation for any char we didn't add —
# this lets future transformers releases keep any additional mappings.
return _orig_pattern_to_list(pattern)

# Assign via ``staticmethod()`` so the attribute is unbound on the class (matches
# the original definition) — using the ``@staticmethod`` decorator on a nested
# function trips mypy's "staticmethod used with a non-method" check.
cfg_cls._pattern_to_list = staticmethod(_patched_pattern_to_list)

# Allow "mlp" alongside {"mamba", "attention", "moe"} in validate_layers_block_type.
# huggingface_hub's @strict_dataclass collects class validators into
# ``cls.__class_validators__`` at class-creation time, so we have to replace the
# entry in that list (not just overwrite the method attribute).
def _patched_validate_layers_block_type(self):
if not isinstance(self.layers_block_type, list):
raise ValueError(
f"`layers_block_type` must be a list of strings. "
f"Got type: {type(self.layers_block_type)}"
)
valid_types = {"mamba", "attention", "moe", "mlp"}
if not all(block_type in valid_types for block_type in self.layers_block_type):
invalid = set(self.layers_block_type) - valid_types
raise ValueError(
f"`layers_block_type` contains invalid types: {invalid}. "
f"Must be one of: {valid_types}"
)

_patched_validate_layers_block_type.__name__ = "validate_layers_block_type"
cfg_cls.validate_layers_block_type = staticmethod(_patched_validate_layers_block_type)
class_validators = list(getattr(cfg_cls, "__class_validators__", []))
for i, v in enumerate(class_validators):
if getattr(v, "__name__", None) == "validate_layers_block_type":
class_validators[i] = _patched_validate_layers_block_type
break
else:
class_validators.append(_patched_validate_layers_block_type)
cfg_cls.__class_validators__ = class_validators
cfg_cls._modelopt_mlp_patched = True


def get_model(
ckpt_path,
device="cuda",
@@ -548,6 +689,9 @@ def get_model(
use_seq_device_map=False,
attn_implementation=None,
):
# Needs to run before AutoConfig.from_pretrained so the Nemotron-H config can parse
# the "-" (MLP) character in hybrid_override_pattern.
_maybe_patch_transformers_nemotron_h_mixer_types()
print(f"Initializing model from {ckpt_path}")

device_map = "auto"
@@ -706,6 +850,22 @@ def has_pack_quantized_config(config):
if device == "cuda" and not is_model_on_gpu(model):
print("Warning: Some parameters are not on a GPU. Calibration can be slow or hit OOM")

# Some model cards ship a generation_config.json that sets sampling hyperparameters
# (top_p, temperature) without ``do_sample=True`` (e.g. NVIDIA-Nemotron-3-Nano-4B-BF16).
# transformers 5.x strictly validates this on save_pretrained, so the export step
# fails with "GenerationConfig is invalid". Normalize by enabling do_sample whenever
# a sampling hyperparameter is set — this is only metadata, not behavior during
# calibration or export.
gen_cfg = getattr(model, "generation_config", None)
if gen_cfg is not None and not getattr(gen_cfg, "do_sample", False):
has_sampling_hyperparam = (
getattr(gen_cfg, "top_p", None) not in (None, 1.0)
or getattr(gen_cfg, "top_k", None) not in (None, 0, 50)
or getattr(gen_cfg, "temperature", None) not in (None, 1.0)
)
if has_sampling_hyperparam:
gen_cfg.do_sample = True
Contributor

⚠️ Potential issue | 🟠 Major

Don't mutate the live generation_config in get_model().

The mutation persists on the returned model object, and both the before-PTQ and after-PTQ preview calls (full_model.generate() at lines 922 and 980 in hf_ptq.py) use that same model instance. For checkpoints with sampling hyperparameters, this makes the previews non-deterministic instead of deterministic, undermining PTQ smoke test comparisons. Normalize a copy during export instead.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `examples/llm_ptq/example_utils.py` around lines 853-867, the current code mutates the live `model.generation_config` (`gen_cfg`), which makes the same model instance returned by `get_model()` non-deterministic in the generate previews. Instead, create a copy of the generation config (e.g. via `copy.deepcopy`, or by constructing a new `GenerationConfig` from its dict) and set `do_sample` on that copy, leaving `model.generation_config` unchanged. Update the export/normalization logic to use the copy so the `full_model.generate()` previews remain deterministic and only the exported metadata contains the normalized setting.
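A minimal sketch of that suggested direction, assuming the export path can accept a normalized copy; the helper name and the commented call site are hypothetical, not part of this PR:

import copy

def normalized_generation_config(model):
    """Return a copy of model.generation_config with do_sample enabled whenever
    sampling hyperparameters are set, leaving the live config on the model untouched."""
    gen_cfg = getattr(model, "generation_config", None)
    if gen_cfg is None:
        return None
    gen_cfg = copy.deepcopy(gen_cfg)
    if not getattr(gen_cfg, "do_sample", False):
        has_sampling_hyperparam = (
            getattr(gen_cfg, "top_p", None) not in (None, 1.0)
            or getattr(gen_cfg, "top_k", None) not in (None, 0, 50)
            or getattr(gen_cfg, "temperature", None) not in (None, 1.0)
        )
        if has_sampling_hyperparam:
            gen_cfg.do_sample = True
    return gen_cfg

# Hypothetical export-time usage: save the normalized copy alongside the checkpoint
# instead of mutating model.generation_config inside get_model().
# normalized_generation_config(full_model).save_pretrained(export_path)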


return model


101 changes: 101 additions & 0 deletions examples/llm_ptq/hf_ptq.py
@@ -107,6 +107,7 @@ def _set_kv_cache_constant_amax(quant_cfg: list) -> None:
"int4_awq": mtq.INT4_AWQ_CFG,
"w4a8_awq": mtq.W4A8_AWQ_BETA_CFG,
"nvfp4": mtq.NVFP4_DEFAULT_CFG,
"nvfp4_w4a16": mtq.NVFP4_W4A16_CFG,
"nvfp4_awq": mtq.NVFP4_AWQ_LITE_CFG,
"nvfp4_mse": mtq.NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG,
"fp8_pb_wo": mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG,
@@ -593,6 +594,59 @@ def sparsity_main(
mts.export(full_model)


def _enable_lm_head_and_embedding_quantization(
Contributor

Can we define this in the modelopt_recipe if everything in modelopt_recipes/models can be captured with our YAML recipe system?

quant_cfg: dict[str, Any],
weight_quantizer_cfg: dict[str, Any],
) -> None:
"""Re-enable quantization of ``lm_head`` and the input embedding table.

ModelOpt's default PTQ recipes exclude ``*lm_head*`` and never touch ``nn.Embedding``
because most LLM deployment runtimes keep those layers at full precision. For Nemotron-H
(and similar SSM+Attention hybrids) the embedding and lm_head are a large fraction of the
total parameters — quantizing them recovers most of the promised memory savings. This
helper appends two entries to the cfg list that override earlier ``*lm_head*`` disables
and explicitly target the embedding weight quantizer.

Args:
quant_cfg: the primary quant_cfg dict (``{"quant_cfg": [...], "algorithm": ...}``).
weight_quantizer_cfg: the weight-quantizer attribute dict to apply (e.g. ``_nvfp4_cfg``).
"""
# Ordering matters: these entries must come AFTER the _default_disabled_quantizer_cfg
# entries (which set *lm_head* → disabled) so they take effect.
quant_cfg["quant_cfg"].append(
{"quantizer_name": "*lm_head*weight_quantizer", "cfg": copy.deepcopy(weight_quantizer_cfg)}
)
# nn.Embedding quantizers only exist once `quant_embedding.py` registers the class.
# Nemotron-H's backbone attribute name differs between the remote-code ("backbone.embeddings")
# and transformers built-in ("model.embeddings") paths; both are weight-only vocab
# embeddings here. The broad "*embeddings*" wildcard covers both and does not match
# any other layer in a Nemotron-H model (no positional/rotary embeddings exist).
quant_cfg["quant_cfg"].append(
{
"quantizer_name": "*embeddings*weight_quantizer",
"cfg": copy.deepcopy(weight_quantizer_cfg),
}
)
# Also keep the standard HF "embed_tokens" naming in case future Nemotron-H variants
# rename the attribute.
quant_cfg["quant_cfg"].append(
{
"quantizer_name": "*embed_tokens*weight_quantizer",
"cfg": copy.deepcopy(weight_quantizer_cfg),
}
)


def _extract_weight_quantizer_cfg(quant_cfg: dict[str, Any]) -> dict[str, Any] | None:
"""Return the first ``*weight_quantizer`` cfg dict from an ordered quant_cfg list."""
for entry in quant_cfg.get("quant_cfg", []):
if entry.get("quantizer_name") == "*weight_quantizer" and isinstance(
entry.get("cfg"), dict
):
return entry["cfg"]
return None


def mono_quantize(
args: argparse.Namespace,
quant_cfg: dict[str, Any],
@@ -629,6 +683,24 @@ def mono_quantize(
) # Nemotron-Parse specific
print("Quantization will only be applied to the decoder (text generation) component")

# For Nemotron-H (Mamba-2 + MLP + Attention hybrid, e.g. NVIDIA-Nemotron-3-Nano-4B),
# extend quantization coverage to the lm_head and the input token embedding. On this
# architecture those two 131072x3136 tables account for ~21% of parameters, so leaving
# them at bf16 wastes most of the NVFP4 memory benefit.
if model_type == "nemotron_h":
weight_quantizer_cfg = _extract_weight_quantizer_cfg(quant_cfg)
if weight_quantizer_cfg is not None:
print(
"Nemotron-H detected: extending quantization to lm_head and input embedding "
"(backbone.embeddings)."
)
_enable_lm_head_and_embedding_quantization(quant_cfg, weight_quantizer_cfg)
else:
warnings.warn(
"Nemotron-H detected but quant_cfg has no wildcard '*weight_quantizer' entry; "
"skipping lm_head/embedding extension (model-specific or non-standard recipe)."
)

if not model_is_already_quantized or calibration_only:
# quantize the model

@@ -781,6 +853,12 @@ def export_quantized(
extra_state_dict=mtp_state_dict,
)

if args.qformat == "nvfp4_w4a16":
warnings.warn(
"TensorRT-LLM and SGLang do not support this format. "
"To serve on vLLM, convert the NVFP4 W4A16 checkpoint to compressed-tensors format."

Hi @ajrasane, should we point users to how they can convert? Do we have a helper in ModelOpt we should point them to?

Contributor Author

@hychiang-git, are you planning to merge your conversion script to modelopt?

)

# Restore default padding and export the tokenizer as well.
if tokenizer is not None:
tokenizer.padding_side = default_padding_side
@@ -1106,6 +1184,18 @@ def quantize_main(
quant_cfg["quant_cfg"].append({"quantizer_name": pattern, "enable": False})
print(f"Excluding MTP layer from quantization: {pattern}")

# Apply user-requested per-module exclusions (--exclude_modules).
if args.exclude_modules:
quant_cfg = copy.deepcopy(quant_cfg)
for mod in args.exclude_modules:
quant_cfg["quant_cfg"].append(
{"quantizer_name": f"*{mod}*.weight_quantizer", "enable": False}
)
quant_cfg["quant_cfg"].append(
{"quantizer_name": f"*{mod}*.input_quantizer", "enable": False}
)
print(f"Excluding module from quantization: {mod}")

# Use constant amax for KV quantizers when a cast format is selected.
if args.kv_cache_qformat in _KV_CAST_FORMATS:
quant_cfg = copy.deepcopy(quant_cfg)
@@ -1304,6 +1394,17 @@ def parse_args() -> argparse.Namespace:
default=False,
action="store_true",
)
parser.add_argument(
"--exclude_modules",
nargs="+",
default=[],
metavar="MODULE",
help=(
"Module name patterns to exclude from quantization "
"(e.g. lm_head backbone.layers.0.mixer). "
"Appends a disable rule for each pattern's weight and input quantizers."
),
)
parser.add_argument(
"--low_memory_mode",
help=(
14 changes: 12 additions & 2 deletions examples/llm_ptq/scripts/huggingface_example.sh
@@ -53,9 +53,9 @@ esac
IFS=","
for qformat in $QFORMAT; do
case $qformat in
fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | nvfp4_mse | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8 | nvfp4_experts_only | nvfp4_mlp_only | nvfp4_omlp_only | nvfp4_svdquant | mxfp8 | nvfp4_local_hessian) ;;
fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | nvfp4_mse | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8 | nvfp4_experts_only | nvfp4_mlp_only | nvfp4_omlp_only | nvfp4_svdquant | mxfp8 | nvfp4_local_hessian | nvfp4_w4a16) ;;
*)
echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, nvfp4_mse, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8, nvfp4_experts_only, nvfp4_mlp_only, nvfp4_omlp_only, nvfp4_svdquant, mxfp8, nvfp4_local_hessian]" >&2
echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, nvfp4_mse, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8, nvfp4_experts_only, nvfp4_mlp_only, nvfp4_omlp_only, nvfp4_svdquant, mxfp8, nvfp4_local_hessian, nvfp4_w4a16]" >&2
exit 1
;;
esac
@@ -127,6 +127,10 @@ if $TRUST_REMOTE_CODE; then
PTQ_ARGS+=" --trust_remote_code "
fi

if [ -n "${EXCLUDE_MODULES:-}" ]; then
PTQ_ARGS+=" --exclude_modules ${EXCLUDE_MODULES} "
fi

if $USE_SEQ_DEVICE_MAP; then
PTQ_ARGS+=" --use_seq_device_map "
fi
@@ -199,6 +203,12 @@ if [[ $TASKS =~ "quant" ]] || [[ ! -d "$SAVE_PATH" ]] || [[ ! $(ls -A $SAVE_PATH
exit 0
fi

if [ "$QFORMAT" = "nvfp4_w4a16" ]; then
echo "nvfp4_w4a16 checkpoint exported to $SAVE_PATH"
echo "To serve on vLLM, convert to compressed-tensors"
exit 0
fi

if [[ "$QFORMAT" == *"nvfp4"* ]] || [[ "$KV_CACHE_QUANT" == *"nvfp4"* ]]; then
cuda_major=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader -i 0 | cut -d. -f1)

13 changes: 13 additions & 0 deletions modelopt/torch/export/convert_hf_config.py
@@ -57,6 +57,11 @@ def _quant_algo_to_group_config(quant_algo: str, group_size: int | None = None)
return {
"weights": {"dynamic": False, "num_bits": 4, "type": "int", "group_size": gs},
}
elif quant_algo == "NVFP4_W4A16":
gs = group_size or 16
return {
"weights": {"dynamic": False, "num_bits": 4, "type": "float", "group_size": gs},
}
elif quant_algo in ("NVFP4_AWQ", "W4A8_AWQ"):
gs = group_size or 128
return {
@@ -183,6 +188,14 @@ def convert_hf_quant_config_format(input_config: dict[str, Any]) -> dict[str, An
"targets": ["Linear"],
}
new_config["config_groups"] = {"group_0": config_group_details}
elif quant_algo_value == "NVFP4_W4A16":
# Weight-only FP4
group_size = original_quantization_details.get("group_size", 16)
config_group_details = {
"weights": {"dynamic": False, "num_bits": 4, "type": "float", "group_size": group_size},
"targets": ["Linear"],
}
new_config["config_groups"] = {"group_0": config_group_details}
elif quant_algo_value == "MIXED_PRECISION":
quantized_layers = original_quantization_details.get("quantized_layers", {})

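For the vLLM-conversion question raised in the review thread above, a sketch of the fragment the new ``NVFP4_W4A16`` branch emits. The values are copied from that branch, with ``group_size`` defaulting to 16; the surrounding compressed-tensors config fields are omitted here.

# config_groups fragment produced by convert_hf_quant_config_format for NVFP4_W4A16.
config_groups = {
    "group_0": {
        "weights": {"dynamic": False, "num_bits": 4, "type": "float", "group_size": 16},
        "targets": ["Linear"],
    }
}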