"""Tests for the whisper package.

Covers torch-vs-MLX forward-pass agreement, language identification, greedy
decoding, end-to-end transcription, and the audio loading helpers.
"""

import unittest

import mlx.core as mx
import numpy as np
import os
import subprocess
import torch

import whisper
import whisper.audio as audio
import whisper.load_models as load_models
import whisper.torch_whisper as torch_whisper
import whisper.decoding as decoding

TEST_AUDIO = "whisper/assets/ls_test.flac"


def forward_torch(model, mels, tokens):
    """Run ``model.forward`` in torch and return the logits as a NumPy array.

    ``mels`` is converted to a float32 tensor and ``tokens`` to an int32
    tensor; gradients are disabled for the call.
    """
    mels = torch.Tensor(mels).to(torch.float32)
    tokens = torch.Tensor(tokens).to(torch.int32)
    with torch.no_grad():
        logits = model.forward(mels, tokens)
    return logits.numpy()


def forward_mlx(model, mels, tokens):
    """Call the MLX ``model`` and return the logits as a NumPy array.

    ``mels`` has its last two axes swapped (``transpose(0, 2, 1)``) before
    being handed to the model; ``tokens`` is converted to an int32 MLX array.
    """
    mels = mx.array(mels.transpose(0, 2, 1))
    tokens = mx.array(tokens, mx.int32)
    logits = model(mels, tokens)
    return np.array(logits)


class TestWhisper(unittest.TestCase):
    """Model-level tests run against the "tiny" model and ``TEST_AUDIO``."""

    @classmethod
    def setUpClass(cls):
        """Load the "tiny" model and the log-mel spectrogram of ``TEST_AUDIO`` once."""
        cls.model = load_models.load_model("tiny")
        data = audio.load_audio(TEST_AUDIO)
        data = audio.pad_or_trim(data)
        cls.mels = audio.log_mel_spectrogram(data)

    def test_torch_mlx(self):
        """Torch logits and logits from the converted MLX model should agree."""
        # Fixed seed so the random inputs are reproducible between runs.
        np.random.seed(10)

        torch_model = load_models.load_torch_model("tiny")
        dims = torch_model.dims

        mels = np.random.randn(1, dims.n_mels, 3_000)
        tokens = np.random.randint(0, dims.n_vocab, (1, 20))

        torch_logits = forward_torch(torch_model, mels, tokens)

        mlx_model = load_models.torch_to_mlx(torch_model)
        mlx_logits = forward_mlx(mlx_model, mels, tokens)

        self.assertTrue(np.allclose(torch_logits, mlx_logits, atol=1e-2, rtol=1e-2))

    def test_decode_lang(self):
        """The ``lang_id`` task should report English for the test clip."""
        options = decoding.DecodingOptions(task="lang_id")
        result = decoding.decode(self.model, self.mels, options)
        self.assertEqual(result.language, "en")
        self.assertEqual(len(result.language_probs), 99)
        self.assertAlmostEqual(
            result.language_probs["en"], 0.9947282671928406, places=5
        )

    def _check_greedy_result(self, result):
        """Assert the text and summary statistics expected from greedy decoding."""
        self.assertEqual(
            result.text,
            (
                "Then the good soul openly sorted the boat and she "
                "had buoyed so long in secret and bravely stretched on alone."
            ),
        )
        self.assertAlmostEqual(result.avg_logprob, -0.4975455382774616, places=3)
        self.assertAlmostEqual(result.no_speech_prob, 0.009631240740418434, places=4)
        self.assertAlmostEqual(result.compression_ratio, 1.2359550561797752)

    def test_decode_greedy(self):
        """Default decoding and near-zero temperature should give the same output."""
        result = decoding.decode(self.model, self.mels)
        self.assertEqual(result.language, "en")
        self.assertEqual(
            result.tokens,
            [
                50364,
                1396,
                264,
                665,
                5133,
                23109,
                25462,
                264,
                6582,
                293,
                750,
                632,
                42841,
                292,
                370,
                938,
                294,
                4054,
                293,
                12653,
                356,
                50620,
                50620,
                23563,
                322,
                3312,
                13,
                50680,
            ],
        )
        self._check_greedy_result(result)

        # Small temp should give the same results
        result = decoding.decode(self.model, self.mels, temperature=1e-8)
        self._check_greedy_result(result)

    def test_transcribe(self):
        """``whisper.transcribe`` on the test clip should return the known text."""
        result = whisper.transcribe(TEST_AUDIO)
        self.assertEqual(
            result["text"],
            (
                " Then the good soul openly sorted the boat and she "
                "had buoyed so long in secret and bravely stretched on alone."
            ),
        )

    def test_transcribe_alice(self):
        """Transcribe a longer recording and spot-check two segments.

        Skipped when ``~/.cache/whisper/alice.mp3`` is not present.
        """
        audio_file = os.path.join(
            os.path.expanduser("~"),
            ".cache/whisper/alice.mp3",
        )
        if not os.path.exists(audio_file):
            # Report a skip instead of returning silently, so a missing
            # fixture is not counted as a passing test.
            self.skipTest(
                "To run this test download the alice in wonderland audiobook: "
                "bash path_to_whisper_repo/whisper/assets/download_alice.sh"
            )

        result = whisper.transcribe(audio_file)
        self.assertEqual(len(result["text"]), 10920)
        self.assertEqual(result["language"], "en")
        self.assertEqual(len(result["segments"]), 77)

        expected_5 = {
            "id": 5,
            "seek": 2800,
            "start": 40.0,
            "end": 46.0,
            "text": " Oh my poor little feet, I wonder who will put on your shoes and stockings for you now tears.",
            "tokens": [
                50964,
                876,
                452,
                4716,
                707,
                3521,
                11,
                286,
                2441,
                567,
                486,
                829,
                322,
                428,
                6654,
                293,
                4127,
                1109,
                337,
                291,
                586,
                10462,
                13,
                51264,
            ],
            "temperature": 0.0,
            "avg_logprob": -0.19670599699020386,
            "compression_ratio": 1.5991379310344827,
            "no_speech_prob": 0.09746722131967545,
        }
        expected_73 = {
            "id": 73,
            "seek": 70700,
            "start": 707.0,
            "end": 715.0,
            "text": " let us get to the shore, and then I'll tell you my history, and you'll understand why it is that I hate cats and dogs.",
            "tokens": [
                50364,
                718,
                505,
                483,
                281,
                264,
                17805,
                11,
                293,
                550,
                286,
                603,
                980,
                291,
                452,
                2503,
                11,
                293,
                291,
                603,
                1223,
                983,
                309,
                307,
                300,
                286,
                4700,
                11111,
                293,
                7197,
                13,
                50764,
            ],
            "temperature": 0.0,
            "avg_logprob": -0.1350895343440594,
            "compression_ratio": 1.6208333333333333,
            "no_speech_prob": 0.002246702555567026,
        }

        def check_segment(seg, expected):
            # Floats are compared approximately; everything else exactly.
            for k, v in expected.items():
                if isinstance(v, float):
                    self.assertAlmostEqual(seg[k], v, places=3)
                else:
                    self.assertEqual(seg[k], v)

        # Randomly check a couple of segments
        check_segment(result["segments"][5], expected_5)
        check_segment(result["segments"][73], expected_73)


class TestAudio(unittest.TestCase):
    """Tests for the audio loading, padding and spectrogram helpers.

    NOTE(review): these checks used ``assertTrue(value, expected)``, where the
    second argument is only the failure message, so the expected values below
    were never enforced. They are now compared with ``assertEqual``; confirm
    each expectation against the ``audio`` module if one fails.
    """

    def test_load(self):
        """Loaded audio should have the expected length and dtype."""
        data = audio.load_audio(TEST_AUDIO)
        data_8k = audio.load_audio(TEST_AUDIO, 8000)
        n = 106640
        # tuple(...) keeps the comparison independent of whether the array
        # type reports its shape as a list or a tuple.
        self.assertEqual(tuple(data.shape), (n,))
        self.assertEqual(data.dtype, np.float32)
        self.assertEqual(tuple(data_8k.shape), (n // 2,))

    def test_pad(self):
        """``pad_or_trim`` should return exactly the requested length."""
        data = audio.load_audio(TEST_AUDIO)
        data = audio.pad_or_trim(data, 20_000)
        self.assertEqual(tuple(data.shape), (20_000,))

    def test_mel_spec(self):
        """The log-mel spectrogram should have the expected shape and dtype."""
        mels = audio.log_mel_spectrogram(TEST_AUDIO)
        # NOTE(review): (80, 400) was never actually checked before; verify
        # it matches the layout log_mel_spectrogram really returns.
        self.assertEqual(tuple(mels.shape), (80, 400))
        self.assertEqual(mels.dtype, mx.float32)


if __name__ == "__main__":
    unittest.main()