Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 19 additions & 10 deletions nemo_deploy/llm/inference/inference_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,18 +229,27 @@ def setup_megatron_model_and_tokenizer_for_inference(
torch_distributed_init(dist_config)
model_config, mlm_args = load_model_config(checkpoint_path)

# MLA models require cache_mla_latents=True for the dynamic inference backend.
# The checkpoint may have saved it as False (training default), but inference
# with the dynamic engine always needs it enabled.
if hasattr(model_config, "cache_mla_latents"):
model_config.cache_mla_latents = True

# Convert attention_backend from string to enum if needed
if hasattr(model_config, "attention_backend") and isinstance(model_config.attention_backend, str):
if model_config.attention_backend == "AttnBackend.fused":
model_config.attention_backend = AttnBackend.fused
elif model_config.attention_backend == "AttnBackend.flash":
model_config.attention_backend = AttnBackend.flash
elif model_config.attention_backend == "AttnBackend.unfused":
model_config.attention_backend = AttnBackend.unfused
elif model_config.attention_backend == "AttnBackend.local":
model_config.attention_backend = AttnBackend.local
elif model_config.attention_backend == "AttnBackend.auto":
if hasattr(model_config, "attention_backend"):
if model_config.attention_backend is None:
# Deserialization of the AttnBackend enum failed (e.g. Hydra _target_ dict
# not reconstructed); fall back to auto so the engine can pick the best backend.
model_config.attention_backend = AttnBackend.auto
elif isinstance(model_config.attention_backend, str):
_str_to_attn_backend = {
"AttnBackend.fused": AttnBackend.fused,
"AttnBackend.flash": AttnBackend.flash,
"AttnBackend.unfused": AttnBackend.unfused,
"AttnBackend.local": AttnBackend.local,
"AttnBackend.auto": AttnBackend.auto,
}
model_config.attention_backend = _str_to_attn_backend.get(model_config.attention_backend, AttnBackend.auto)

if tensor_model_parallel_size is not None:
model_config.tensor_model_parallel_size = tensor_model_parallel_size
Expand Down
9 changes: 6 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -149,11 +149,14 @@ override-dependencies = [
"flash-linear-attention>=0.3.0,<0.4.dev0",
"patchelf; sys_platform=='never'",
"nvidia-resiliency-ext>=0.3.0,<0.6.0",
"transformer-engine[pytorch,core_cu13]>=2.12.0a0,<2.15.0; sys_platform != 'darwin'",
"transformer-engine-cu13>=2.12.0a0,<2.15.0; sys_platform != 'darwin'",
"transformer-engine-cu12; sys_platform == 'never'",
# The custom-built TE in the container already includes the torch extension natively.
# Installing transformer-engine-torch from PyPI creates a dist-info that triggers TE's
# sanity check requiring the base package to also be a PyPI wheel, which fails for
# source/custom builds. Since the .so is already present, skip the PyPI package.
"transformer-engine-torch; sys_platform == 'never'",
"mamba-ssm>=2.3.0,<2.4.0",
"transformers>=5.0.0",
"transformers==5.2.0",
"protobuf~=6.33.5",
"opencv-python-headless; sys_platform == 'never'",
"cryptography>=43.0.0,<47",
Expand Down
Loading
Loading