diff --git a/nemo2riva/patches/__init__.py b/nemo2riva/patches/__init__.py index a70d4ce..6068179 100644 --- a/nemo2riva/patches/__init__.py +++ b/nemo2riva/patches/__init__.py @@ -5,6 +5,7 @@ from nemo2riva.patches.ctc_bpe import bpe_check_inputs_and_version from nemo2riva.patches.mtencdec import change_tokenizer_names from nemo2riva.patches.tts import fastpitch_model_versioning, generate_vocab_mapping, radtts_model_versioning +from nemo2riva.patches.frame_vad import patch_output_name patches = { "EncDecCTCModel": [set_decoder_num_classes], @@ -12,4 +13,5 @@ "MTEncDecModel": [change_tokenizer_names], "FastPitchModel": [generate_vocab_mapping, fastpitch_model_versioning], "RadTTSModel": [generate_vocab_mapping, radtts_model_versioning], + "EncDecFrameClassificationModel": [patch_output_name], } diff --git a/nemo2riva/patches/frame_vad.py b/nemo2riva/patches/frame_vad.py new file mode 100644 index 0000000..cc70d98 --- /dev/null +++ b/nemo2riva/patches/frame_vad.py @@ -0,0 +1,19 @@ +from nemo.core.neural_types.neural_type import NeuralType +from typing import Any, Dict, Optional +from nemo.core.neural_types.elements import LogitsType +from nemo.collections.asr.models import EncDecClassificationModel +from nemo.core.utils.neural_type_utils import get_io_names + +def patch_output_name(model, artifacts, **kwargs): + if model.__class__.__name__ == "EncDecFrameClassificationModel": + @property + def output_names(self): + return ["logits"] + model.__class__.output_names = output_names + + + + + + +