Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions nemo2riva/patches/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
from nemo2riva.patches.ctc_bpe import bpe_check_inputs_and_version
from nemo2riva.patches.mtencdec import change_tokenizer_names
from nemo2riva.patches.tts import fastpitch_model_versioning, generate_vocab_mapping, radtts_model_versioning
from nemo2riva.patches.frame_vad import patch_output_name

patches = {
"EncDecCTCModel": [set_decoder_num_classes],
"EncDecCTCModelBPE": [bpe_check_inputs_and_version],
"MTEncDecModel": [change_tokenizer_names],
"FastPitchModel": [generate_vocab_mapping, fastpitch_model_versioning],
"RadTTSModel": [generate_vocab_mapping, radtts_model_versioning],
"EncDecFrameClassificationModel": [patch_output_name],
}
19 changes: 19 additions & 0 deletions nemo2riva/patches/frame_vad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from nemo.core.neural_types.neural_type import NeuralType
from typing import Any, Dict, Optional
from nemo.core.neural_types.elements import LogitsType
from nemo.collections.asr.models import EncDecClassificationModel
from nemo.core.utils.neural_type_utils import get_io_names

def patch_output_name(model, artifacts, **kwargs):
if model.__class__.__name__ == "EncDecFrameClassificationModel":
@property
def output_names(self):
return ["logits"]
model.__class__.output_names = output_names