Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions nemo2riva/cookbook.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def export_model(model, cfg, args, artifacts, metadata):
format_meta = {"has_pytorch_checkpoint": True, "runtime": "PyTorch"}
elif cfg.export_format == "NEMO":
format_meta = {"has_pytorch_checkpoint": True, "runtime": "Python"}
elif cfg.export_format == "STATE":
Comment thread
anand-nv marked this conversation as resolved.
format_meta = {"has_pytorch_checkpoint": False, "runtime": "Python"}
# TODO: use submodel sections
metadata.update(format_meta)
runtime = format_meta["runtime"]
Expand Down Expand Up @@ -140,6 +142,15 @@ def export_model(model, cfg, args, artifacts, metadata):

elif cfg.export_format == "NEMO":
model.save_to(export_file)
elif cfg.export_format == "STATE":
if not isinstance(model, Exportable):
logging.error("Your NeMo model class ({}) is not Exportable.".format(metadata['obj_cls']))
sys.exit(1)
model.freeze()
model_params = model.state_dict()
torch.save(model_params, export_file)



# Add exported file to the artifact registry

Expand Down
2 changes: 2 additions & 0 deletions nemo2riva/patches/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@

from nemo2riva.patches.ctc import set_decoder_num_classes
from nemo2riva.patches.ctc_bpe import bpe_check_inputs_and_version
from nemo2riva.patches.aed_canary import config_for_trtllm
from nemo2riva.patches.mtencdec import change_tokenizer_names
from nemo2riva.patches.tts import fastpitch_model_versioning, generate_vocab_mapping, radtts_model_versioning

patches = {
"EncDecCTCModel": [set_decoder_num_classes],
"EncDecCTCModelBPE": [bpe_check_inputs_and_version],
"EncDecMultiTaskModel": [config_for_trtllm],
"MTEncDecModel": [change_tokenizer_names],
"FastPitchModel": [generate_vocab_mapping, fastpitch_model_versioning],
"RadTTSModel": [generate_vocab_mapping, radtts_model_versioning],
Expand Down
78 changes: 78 additions & 0 deletions nemo2riva/patches/aed_canary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Comment thread
anand-nv marked this conversation as resolved.
Outdated
# SPDX-License-Identifier: MIT

import yaml
import json
Comment thread
anand-nv marked this conversation as resolved.
Outdated
import logging


def config_for_trtllm(model, artifacts, **kwargs):
    """Rewrite the archived ``model_config.yaml`` into the layout expected by the
    TRT-LLM Canary (AED) export pipeline.

    Only acts on ``EncDecMultiTaskModel`` instances; any other model class is left
    untouched.

    Args:
        model: The NeMo model being exported.
        artifacts: Artifact registry; ``artifacts['model_config.yaml']['content']``
            is parsed as YAML and replaced in place with the rewritten config,
            serialized as UTF-8 bytes.
        **kwargs: Unused; accepted for compatibility with the patch interface.

    Raises:
        KeyError: If the source config is missing one of the required sections.
    """
    if model.__class__.__name__ != 'EncDecMultiTaskModel':
        return

    model_config = yaml.safe_load(artifacts['model_config.yaml']['content'])

    # Top-level sections the TRT-LLM exporter requires verbatim.
    keys_required = [
        'beam_search',
        'encoder',
        'head',
        'model_defaults',
        'prompt_format',
        'sample_rate',
        'target',
        'preprocessor',
    ]
    # Some configs keep beam parameters under `decoding.beam` instead of a
    # top-level `beam_search`; fall back to conservative defaults when neither
    # is present in the `decoding` section.
    if 'beam_search' not in model_config and 'decoding' in model_config:
        model_config['beam_search'] = model_config['decoding'].get(
            'beam', {'beam_size': 1, 'len_pen': 0.0, 'max_generation_delta': 50}
        )

    config = {k: model_config[k] for k in keys_required}
    config['decoder'] = {
        'transf_decoder': model_config['transf_decoder'],
        'transf_encoder': model_config['transf_encoder'],
        'vocabulary': make_vocabulary_file(model, artifacts),
        'num_classes': model_config['head']['num_classes'],
        'feat_in': model_config['model_defaults']['asr_enc_hidden'],
        'n_layers': model_config['transf_decoder']['config_dict']['num_layers'],
    }
    # Override the original NeMo target with the TRT-LLM-specific one.
    config['target'] = 'trtllm.canary'

    # Passing `encoding=` makes safe_dump return UTF-8 bytes, which is the
    # format the artifact registry stores for file contents.
    artifacts['model_config.yaml']['content'] = yaml.safe_dump(config, encoding='utf-8')


def make_vocabulary_file(model, artifacts, **kwargs):
    """Build a vocabulary description dict from the model's multilingual tokenizer.

    Only acts on ``EncDecMultiTaskModel`` instances; returns ``None`` for any
    other model class.

    Args:
        model: The NeMo model being exported; its ``tokenizer`` is introspected.
        artifacts: Artifact registry (currently unused here; kept for patch-
            interface compatibility).
        **kwargs: Unused; accepted for compatibility with the patch interface.

    Returns:
        A dict with per-language token tables (``tokens``), per-language id
        ``offsets``, total vocabulary ``size``, and whichever of the special-token
        ids (``bos_id``, ``eos_id``, ``nospeech_id``, ``pad_id``) the tokenizer
        exposes — or ``None`` when the model class does not match.
    """
    if model.__class__.__name__ != 'EncDecMultiTaskModel':
        return None

    tokenizer = model.tokenizer
    tokenizer_vocab = {
        'tokens': {lang: {} for lang in tokenizer.langs},
        'offsets': tokenizer.token_id_offset,
        'size': tokenizer.vocab_size,
    }

    # Special-token ids are optional: record the ones the tokenizer provides and
    # warn (without failing) about the rest, since missing ids may affect accuracy.
    for special in ('bos_id', 'eos_id', 'nospeech_id', 'pad_id'):
        try:
            tokenizer_vocab[special] = getattr(tokenizer, special)
        except Exception:
            logging.warning("Tokenizer is missing %s. Could affect accuracy", special)

    # Map every token id to its surface form, grouped by language.
    for t_id in range(tokenizer.vocab_size):
        lang = tokenizer.ids_to_lang([t_id])
        tokenizer_vocab['tokens'][lang][t_id] = tokenizer.ids_to_tokens([t_id])[0]

    return tokenizer_vocab

21 changes: 18 additions & 3 deletions nemo2riva/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

schema_dict = None

supported_formats = ["ONNX", "CKPT", "TS", "NEMO"]
supported_formats = ["ONNX", "CKPT", "TS", "NEMO", "PYTORCH", "STATE"]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need to add separate "PYTORCH" as supported format or can we use "CKPT" for it?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PYTORCH in case we are just exporting part of the model or tensors.



@dataclass
Expand Down Expand Up @@ -48,15 +48,30 @@ def get_export_config(export_obj, args):
need_autocast = False
if export_obj:
conf.export_file = list(export_obj)[0]
attribs = export_obj[conf.export_file]
conf.export_subnet = attribs.get('export_subnet', None)

conf.is_onnx=attribs.get('onnx', False)



Comment thread
anand-nv marked this conversation as resolved.
Outdated
if not conf.is_onnx:
conf.states_only = attribs.get('states_only', False)
conf.is_torch = attribs.get('torch', False)

if conf.export_file.endswith('.onnx'):
conf.export_format = "ONNX"
elif conf.export_file.endswith('.ts'):
conf.export_format = "TS"
elif conf.export_file.endswith('.nemo'):
conf.export_format = "NEMO"
elif conf.is_torch:
if conf.states_only:
conf.export_format = "STATE"
else:
conf.export_format = "PYTORCH"
else:
conf.export_format = "CKPT"
attribs = export_obj[conf.export_file]
conf.autocast = attribs.get('autocast', False)
need_autocast = conf.autocast

Expand All @@ -66,7 +81,7 @@ def get_export_config(export_obj, args):
if conf.encryption and args.key is None:
raise Exception(f"{conf.export_file} requires encryption and no key was given")

conf.export_subnet = attribs.get('export_subnet', None)

Comment thread
anand-nv marked this conversation as resolved.
Outdated

if args.export_subnet:
if conf.export_subnet:
Expand Down
Empty file added nemo2riva/scripts/__init__.py
Empty file.
49 changes: 49 additions & 0 deletions nemo2riva/validation_schemas/asr-scr-exported-aedmodel.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Comment thread
anand-nv marked this conversation as resolved.
Outdated
# SPDX-License-Identifier: MIT

# Define required metadata fields expected in the archive (optional).
metadata:
- obj_cls: nemo.collections.asr.models.EncDecMultiTaskModel


# Define list of files that are expected (optional).
artifact_properties:
# List of files.
- model_config.yaml
- encoder.onnx:
export_subnet: encoder
onnx: True
- decoder.pt:
export_subnet: transf_decoder
states_only: True
torch: True
onnx: False
- log_softmax.pt:
export_subnet: log_softmax
states_only: True
torch: True
onnx: False


# Define list of files with expected content (optional).
# Functionality limited to yaml files (e.g. model_config.yaml).
artifact_content:
# List of files.
- model_config.yaml:
# List of sections.subsections. ... that are required.
# (Optional `: True` instructs to check the presence of the file in indicated as leaf in the archive)
- transf_decoder
- transf_encoder
- vocabulary
- num_classes
- feat_in
- n_layers
- target
- beam_search
- encoder
- head
- model_defaults
- prompt_format
- sample_rate
- preprocessor
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
# SPDX-License-Identifier: MIT

nemo_toolkit>=1.6.0
torch>=2.4.0
nvidia-eff>=0.6.4
nvidia-eff-tao-encryption>=0.1.8
nvidia-pyindex==1.0.6
onnx==1.14.1
onnx==1.16.1
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would the same versions as Riva container work here?

torch 2.5.0
onnx 1.17.0

will be good them same to avoid any possible discrepancies due to version mismatches b/w nemo and Riva.

onnxruntime==1.16.3
onnxruntime-gpu==1.16.3
onnx-graphsurgeon==0.3.27