diff --git a/python/tests/test_transformers.py b/python/tests/test_transformers.py index 77fa4f63e..a85330b2e 100644 --- a/python/tests/test_transformers.py +++ b/python/tests/test_transformers.py @@ -977,13 +977,15 @@ def teardown_class(cls): @test_utils.only_on_linux @test_utils.on_available_devices @pytest.mark.parametrize( - "model_name,expected_transcription", + "model_name,expected_transcriptions", [ ( "facebook/wav2vec2-large-robust-ft-swbd-300h", [ "MISTER QUILTER IS THE APOSSEL OF THE MIDDLE CLASSES AND" " WE ARE GLAD TO WELCOME HIS GOSPEL", + "MISTER QUILTER IS THE APOSSTEL OF THE MIDDLE CLASSES AND" + " WE ARE GLAD TO WELCOME HIS GOSPEL", ], ), ], @@ -993,7 +995,7 @@ def test_transformers_wav2vec2( tmp_dir, device, model_name, - expected_transcription, + expected_transcriptions, ): import torch import transformers @@ -1046,7 +1048,7 @@ def test_transformers_wav2vec2( transcription = processor.decode(predicted_ids, output_word_offsets=True) transcription = transcription[0].replace(processor.tokenizer.unk_token, "") - assert transcription == expected_transcription[0] + assert transcription in expected_transcriptions class TestWav2Vec2Bert: @@ -1091,13 +1093,13 @@ def test_transformers_wav2vec2bert( ) device = "cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu" - cpu_threads = int(os.environ.get("OMP_NUM_THREADS", 0)) + # cpu_threads = int(os.environ.get("OMP_NUM_THREADS", 0)) model = ctranslate2.models.Wav2Vec2Bert( output_dir, device=device, device_index=[0], compute_type="int8", - intra_threads=cpu_threads, + intra_threads=1, inter_threads=1, )