Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,14 @@ For additional information and usage, run:
<td>nemo2riva --key tlt_encode --out &lt;path to save .riva model&gt; --onnx-opset 14 &lt;path to .nemo model&gt;</td>
</tr>
<tr>
<td>MagpieTTS Decoder</td>
<td>nemo2riva &lt;model.ckpt path&gt; --load_ckpt --model_config &lt;hparams_file&gt; --audio_codecpath &lt;codec .nemo ckpt&gt; --key tlt_encode --out magpie_decoder.riva --submodel decoder</td>
</tr>
<tr>
<td>MagpieTTS Encoder</td>
<td>nemo2riva &lt;model.ckpt path&gt; --load_ckpt --model_config &lt;hparams_file&gt; --audio_codecpath &lt;codec .nemo ckpt&gt; --key tlt_encode --out magpie_encoder.riva --submodel encoder</td>
</tr>
<tr>
<td rowspan="2">Voice Activity Detection</td>
<td>Segment VAD</td>
<td rowspan="2">nemo2riva --key tlt_encode --out &lt;path to save .riva model&gt; --onnx-opset 18 &lt;path to .nemo model&gt;</td>
Expand Down
6 changes: 5 additions & 1 deletion nemo2riva/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,12 @@ def get_args(argv):
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
description=f"Convert NeMo models to Riva EFF input format",
)
parser.add_argument("source", help="Source .nemo file")
parser.add_argument("source", help="Source .nemo or ckpt file")
parser.add_argument("--out", default=None, help="Location to write resulting Riva EFF input to")
parser.add_argument("--load_ckpt", action="store_true", help="Load using checkpoint instead of .nemo file")
parser.add_argument("--submodel", default="decoder", help="Submodel to export. Default is decoder for MagpieTTSModel.")
parser.add_argument("--model_config", default=None, help="Hparams file")
parser.add_argument("--audio_codecpath", default=None, help="Audiocodec path. Needed only for magpietts models.")
parser.add_argument("--validate", action="store_true", help="Validate using schemas")
parser.add_argument("--runtime-check", action="store_true", help="Runtime check of exported net result")
parser.add_argument("--schema", default=None, help="Schema file to use for validation")
Expand Down
41 changes: 37 additions & 4 deletions nemo2riva/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,36 @@ def retrieve_artifacts_as_dict(restore_path: str, obj: Optional["ModelPT"] = Non
logging.error(f"Could not retrieve the artifact {file_key} at {member.name}. Error occured:\n{tb}")
return artifacts

def retrieve_artifacts_as_dict_from_ckpt(ckpt: str, model_cfg: str, obj: Optional["ModelPT"] = None):
    """Build the NeMo artifact dict for a model that was loaded from a checkpoint.

    Only the model config file is packaged; it is stored under the fixed
    artifact name ``model_config.yaml``. The checkpoint itself is not added
    to the artifacts.

    Args:
        ckpt: Checkpoint location. Currently unused.
        model_cfg: Path to the model config (hparams) file. It is read as raw
            bytes, so it must be a filesystem path, not a parsed config.
        obj: ModelPT object (Optional, DEFAULT: None). Currently unused.

    Returns:
        dict mapping the artifact name to a dict with the keys ``conf_path``,
        ``path_type`` and ``content`` (the file's bytes).

    Raises:
        OSError: If ``model_cfg`` cannot be opened or read.
    """
    aname = "model_config.yaml"

    # The context manager guarantees the handle is closed even if read() fails.
    with open(model_cfg, "rb") as cfg_file:
        artifact_content = cfg_file.read()

    return {
        aname: {
            "conf_path": aname,
            "path_type": "TAR_PATH",
            "content": artifact_content,
        },
    }


def create_artifact(reg, key, do_encrypt, **af_dict):
# only works for plain content now - no encryption in Nemo
Expand All @@ -93,8 +123,11 @@ def create_artifact(reg, key, do_encrypt, **af_dict):
return af


def get_artifacts(restore_path: str, model=None, passphrase=None, **patch_kwargs):
artifacts = retrieve_artifacts_as_dict(obj=model, restore_path=restore_path)
def get_artifacts(restore_path: str, model=None, passphrase=None, model_cfg=None, from_ckpt=False, **patch_kwargs):
if not from_ckpt:
artifacts = retrieve_artifacts_as_dict(obj=model, restore_path=restore_path)
else:
artifacts = retrieve_artifacts_as_dict_from_ckpt(ckpt=restore_path, model_cfg=model_cfg, obj=model)

# NOTE: when servicemaker calls into get_artifacts, model is always None so this code section
# is never run.
Expand All @@ -119,7 +152,7 @@ def get_artifacts(restore_path: str, model=None, passphrase=None, **patch_kwargs
nemo_manifest = {'files': artifacts, 'metadata': {'format_version': 1}}
if 'model_config.yaml' in artifacts.keys():
nemo_manifest['has_nemo_config'] = True

nemo_files = nemo_manifest['files']
nemo_metadata = nemo_manifest['metadata']
reg = ArtifactRegistry(passphrase=passphrase)
Expand All @@ -134,4 +167,4 @@ def get_artifacts(restore_path: str, model=None, passphrase=None, **patch_kwargs
create_artifact(reg, key, False, content_callback=cb_override, **af_dict)

logging.info(f"Retrieved artifacts: {artifacts.keys()}")
return reg, nemo_manifest
return reg, nemo_manifest
41 changes: 37 additions & 4 deletions nemo2riva/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,16 @@
from nemo.core import ModelPT
from nemo.core.config.pytorch_lightning import TrainerConfig
from nemo.utils import logging
from omegaconf import OmegaConf
from omegaconf import OmegaConf, open_dict
from lightning.pytorch import Trainer


from nemo2riva.artifacts import get_artifacts
from nemo2riva.cookbook import export_model, save_archive
from nemo2riva.schema import get_import_config, get_subnet, validate_archive



def Nemo2Riva(args):
"""Convert a .nemo saved model into .riva Riva input format."""
nemo_in = args.source
Expand Down Expand Up @@ -48,8 +50,37 @@ def Nemo2Riva(args):

try:
with torch.inference_mode():
# Restore instance from .nemo file using generic model restore_from
model = ModelPT.restore_from(restore_path=nemo_in, trainer=trainer)
if args.load_ckpt:
if not args.model_config:
raise ValueError("Hparams file is required when loading from checkpoint")
model_cfg = OmegaConf.load(args.model_config)
ckpt = torch.load(nemo_in, weights_only=False)
if "state_dict" in ckpt.keys():
ckpt = ckpt["state_dict"]

if "cfg" in model_cfg:
model_cfg = model_cfg.cfg
with open_dict(model_cfg):
if model_cfg.target.split(".")[-1] == "MagpieTTSModel":
from nemo2riva.patches.tts.magpietts import update_config, update_ckpt
from nemo.collections.tts.models.magpietts import MagpieTTSModel
legacy_codebooks = False
if not args.audio_codecpath:
raise ValueError("Audio codec path is required when loading from checkpoint for MagpieTTSModel.")
model_cfg = update_config(model_cfg, args.audio_codecpath, legacy_codebooks)
state_dict = update_ckpt(ckpt)

model = MagpieTTSModel(cfg=model_cfg)
model.load_state_dict(state_dict)
model.cuda()
model.eval()
model = model.half()
else:
model = ModelPT(cfg=model_cfg)
model.load_state_dict(ckpt)
else:
# Restore instance from .nemo file using generic model restore_from
model = ModelPT.restore_from(restore_path=nemo_in, trainer=trainer)
except Exception as e:
logging.error(
"Failed to restore model from NeMo file : {}. Please make sure you have the latest NeMo package installed with [all] dependencies.".format(
Expand Down Expand Up @@ -78,9 +109,11 @@ def Nemo2Riva(args):
warnings.filterwarnings('ignore', category=UserWarning)
# TODO: revisit export_subnet cli arg
patch_kwargs = {"import_config" : cfg}
if model.__class__.__name__ == "MagpieTTSModel":
patch_kwargs['is_encoder'] = args.submodel == "encoder"
if args.export_subnet:
patch_kwargs['export_subnet'] = args.export_subnet
artifacts, manifest = get_artifacts(restore_path=nemo_in, model=model, passphrase=key, **patch_kwargs)
artifacts, manifest = get_artifacts(restore_path=nemo_in, model=model, passphrase=key, model_cfg=args.model_config, from_ckpt=args.load_ckpt, **patch_kwargs)

for export_cfg in cfg.exports:
subnet = get_subnet(model, export_cfg.export_subnet)
Expand Down
45 changes: 34 additions & 11 deletions nemo2riva/cookbook.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ def export_model(model, cfg, args, artifacts, metadata):
export_filename = cfg.export_file
export_file = os.path.join(tmpdir, export_filename)

if cfg.export_format in ["ONNX", "TS"]:
if cfg.export_format in ["ONNX", "TS"] and (model.__class__.__name__ == "MagpieTTSModel" and args.submodel == "encoder"):
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this be or ?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated in the latest commit

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@virajkarandikar good to resolve?

# Export the model, get the descriptions.
if not isinstance(model, Exportable):
if not isinstance(model, Exportable) and not model.__class__.__name__ == "MagpieTTSModel":
logging.error("Your NeMo model class ({}) is not Exportable.".format(metadata['obj_cls']))
sys.exit(1)

Expand All @@ -86,15 +86,38 @@ def export_model(model, cfg, args, artifacts, metadata):
if cfg.max_dim is not None:
in_args["max_dim"] = cfg.max_dim

input_example = model.input_module.input_example(**in_args)
_, descriptions = model.export(
export_file,
input_example=input_example,
check_trace=args.runtime_check,
onnx_opset_version=args.onnx_opset,
verbose=bool(args.verbose),
)
del model
if model.__class__.__name__ == "MagpieTTSModel" and cfg.export_format == "ONNX":
from nemo2riva.patches.tts.magpietts import EncoderOnnxModel
with torch.no_grad():
model.eval()
model = model.half()
encoder_model = EncoderOnnxModel(model)
input_args, dynamic_axes, output_names, input_names = encoder_model._prepare_for_export()

torch.onnx.export(encoder_model,
input_args,
export_file,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=17)

enc_gs = gs.import_onnx(onnx.load(export_file))
outputs = enc_gs.outputs
fix_outputs = [outputs[0]]
enc_gs.outputs = fix_outputs
onnx.save(gs.export_onnx(enc_gs), export_file)
del model, encoder_model
else:
input_example = model.input_module.input_example(**in_args)
_, descriptions = model.export(
export_file,
input_example=input_example,
check_trace=args.runtime_check,
onnx_opset_version=args.onnx_opset,
verbose=bool(args.verbose),
)
del model
if cfg.export_format == 'ONNX':
o_list = os.listdir(tmpdir)
save_as_external_data = len(o_list) > 1
Expand Down
3 changes: 2 additions & 1 deletion nemo2riva/patches/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
from nemo2riva.patches.ctc import set_decoder_num_classes
from nemo2riva.patches.ctc_bpe import bpe_check_inputs_and_version
from nemo2riva.patches.mtencdec import change_tokenizer_names
from nemo2riva.patches.tts import fastpitch_model_versioning, generate_vocab_mapping, radtts_model_versioning
from nemo2riva.patches.tts import fastpitch_model_versioning, generate_vocab_mapping, radtts_model_versioning, magpietts_model_versioning

patches = {
"EncDecCTCModel": [set_decoder_num_classes],
"EncDecCTCModelBPE": [bpe_check_inputs_and_version],
"MTEncDecModel": [change_tokenizer_names],
"FastPitchModel": [generate_vocab_mapping, fastpitch_model_versioning],
"RadTTSModel": [generate_vocab_mapping, radtts_model_versioning],
"MagpieTTSModel": [magpietts_model_versioning],
}
4 changes: 3 additions & 1 deletion nemo2riva/patches/tts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
from nemo2riva.patches.tts.fastpitch import fastpitch_model_versioning
from nemo2riva.patches.tts.general import generate_vocab_mapping
from nemo2riva.patches.tts.radtts import radtts_model_versioning
from nemo2riva.patches.tts.magpietts import magpietts_model_versioning

__all__ = [
fastpitch_model_versioning,
generate_vocab_mapping,
radtts_model_versioning
radtts_model_versioning,
magpietts_model_versioning
]
7 changes: 7 additions & 0 deletions nemo2riva/patches/tts/magpieTTS_README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Decoder
Command to export decoder:
`nemo2riva <model.ckpt path> --load_ckpt --model_config <hparams_file> --audio_codecpath <codec .nemo ckpt> --key tlt_encode --out magpie_decoder.riva --submodel decoder`

# Encoder
Command to export encoder:
`nemo2riva <model.ckpt path> --load_ckpt --model_config <hparams_file> --audio_codecpath <codec .nemo ckpt> --key tlt_encode --out magpie_encoder.riva --submodel encoder`
Loading