mlx-examples/whisper/test.py
2023-12-12 07:37:35 -08:00

282 lines
8.4 KiB
Python

# Copyright © 2023 Apple Inc.
import unittest
import mlx.core as mx
import numpy as np
import os
import subprocess
import torch
import whisper
import whisper.audio as audio
import whisper.load_models as load_models
import whisper.torch_whisper as torch_whisper
import whisper.decoding as decoding
# Audio sample loaded by the tests in this module (given as a relative path).
TEST_AUDIO = "whisper/assets/ls_test.flac"
def forward_torch(model, mels, tokens):
    """Run the PyTorch reference model and return its output as a NumPy array.

    Args:
        model: Object exposing ``forward(mels, tokens)`` that returns a
            torch tensor.
        mels: Array-like input features; converted to a float32 tensor.
        tokens: Array-like token ids; converted to an int32 tensor.

    Returns:
        The result of ``model.forward(mels, tokens)`` converted with
        ``.numpy()``.
    """
    # Convert straight to the target dtypes. ``torch.Tensor(tokens)`` would
    # first build a float32 tensor, and float32 cannot represent every
    # integer above 2**24, so large ids could be silently rounded.
    mels = torch.as_tensor(mels, dtype=torch.float32)
    tokens = torch.as_tensor(tokens, dtype=torch.int32)
    # Inference only: no autograd graph is needed for the comparison.
    with torch.no_grad():
        logits = model.forward(mels, tokens)
    return logits.numpy()
def forward_mlx(model, mels, tokens):
    """Call the MLX model and hand back its output as a NumPy array.

    Args:
        model: Callable invoked as ``model(mels, tokens)``.
        mels: Array whose last two axes are swapped via
            ``transpose(0, 2, 1)`` before being wrapped in an ``mx.array``.
        tokens: Token ids, wrapped in an ``mx.array`` of dtype ``mx.int32``.

    Returns:
        ``np.array`` built from the model output.
    """
    # Reorder the axes to (0, 2, 1) before handing the features to MLX.
    mlx_mels = mx.array(mels.transpose(0, 2, 1))
    mlx_tokens = mx.array(tokens, mx.int32)
    return np.array(model(mlx_mels, mlx_tokens))
class TestWhisper(unittest.TestCase):
    """Regression tests for the "tiny" Whisper model running on MLX.

    The float32 model and the log-mel features of ``TEST_AUDIO`` are prepared
    once in ``setUpClass`` and shared by the decoding tests.
    """

    @classmethod
    def setUpClass(cls):
        """Load the float32 "tiny" model and compute features for TEST_AUDIO."""
        cls.model = load_models.load_model("tiny", dtype=mx.float32)
        data = audio.load_audio(TEST_AUDIO)
        data = audio.pad_or_trim(data)
        cls.mels = audio.log_mel_spectrogram(data)

    def test_torch_mlx(self):
        """The converted MLX model matches the PyTorch model on random input."""
        # Fixed seed so the random inputs are reproducible between runs.
        np.random.seed(10)

        torch_model = load_models.load_torch_model("tiny")
        dims = torch_model.dims

        mels = np.random.randn(1, dims.n_mels, 3_000)
        tokens = np.random.randint(0, dims.n_vocab, (1, 20))

        torch_logits = forward_torch(torch_model, mels, tokens)

        mlx_model = load_models.torch_to_mlx(torch_model, mx.float32)
        mlx_logits = forward_mlx(mlx_model, mels, tokens)

        self.assertTrue(np.allclose(torch_logits, mlx_logits, atol=1e-2, rtol=1e-2))

    def test_fp16(self):
        """A model loaded as float16 produces float16 logits."""
        mlx_model = load_models.load_model("tiny", dtype=mx.float16)
        dims = mlx_model.dims
        mels = mx.array(np.random.randn(1, 3_000, dims.n_mels), mx.float16)
        tokens = mx.array(np.random.randint(0, dims.n_vocab, (1, 20)), mx.int32)
        logits = mlx_model(mels, tokens)
        self.assertEqual(logits.dtype, mx.float16)

    def test_decode_lang(self):
        """The ``lang_id`` task reports English for the test audio."""
        options = decoding.DecodingOptions(task="lang_id", fp16=False)
        result = decoding.decode(self.model, self.mels, options)
        self.assertEqual(result.language, "en")
        self.assertEqual(len(result.language_probs), 99)
        self.assertAlmostEqual(
            result.language_probs["en"], 0.9947282671928406, places=5
        )

    def test_decode_greedy(self):
        """Default decoding reproduces the recorded tokens, text and scores."""
        expected_text = (
            "Then the good soul openly sorted the boat and she "
            "had buoyed so long in secret and bravely stretched on alone."
        )

        result = decoding.decode(self.model, self.mels, fp16=False)
        self.assertEqual(result.language, "en")
        self.assertEqual(
            result.tokens,
            [
                50364,
                1396,
                264,
                665,
                5133,
                23109,
                25462,
                264,
                6582,
                293,
                750,
                632,
                42841,
                292,
                370,
                938,
                294,
                4054,
                293,
                12653,
                356,
                50620,
                50620,
                23563,
                322,
                3312,
                13,
                50680,
            ],
        )
        self.assertEqual(result.text, expected_text)
        self.assertAlmostEqual(result.avg_logprob, -0.4975455382774616, places=3)
        self.assertAlmostEqual(result.no_speech_prob, 0.009631240740418434, places=4)
        self.assertAlmostEqual(result.compression_ratio, 1.2359550561797752)

        # Small temp should give the same results
        result = decoding.decode(self.model, self.mels, temperature=1e-8, fp16=False)

        self.assertEqual(result.text, expected_text)
        self.assertAlmostEqual(result.avg_logprob, -0.4975455382774616, places=3)
        self.assertAlmostEqual(result.no_speech_prob, 0.009631240740418434, places=4)
        self.assertAlmostEqual(result.compression_ratio, 1.2359550561797752)

    def test_transcribe(self):
        """``whisper.transcribe`` returns the recorded text for the test audio."""
        result = whisper.transcribe(TEST_AUDIO, fp16=False)
        self.assertEqual(
            result["text"],
            (
                " Then the good soul openly sorted the boat and she "
                "had buoyed so long in secret and bravely stretched on alone."
            ),
        )

    def test_transcribe_alice(self):
        """Transcribe the optional audiobook and spot-check two segments.

        The audio file is looked up under ``~/.cache/whisper/alice.mp3``.
        When it is absent the test is reported as skipped rather than
        silently passing.
        """
        audio_file = os.path.join(
            os.path.expanduser("~"),
            ".cache/whisper/alice.mp3",
        )
        if not os.path.exists(audio_file):
            # skipTest raises, so nothing below runs and the runner records
            # the test as skipped together with the download hint.
            self.skipTest(
                "To run this test download the alice in wonderland audiobook: "
                "bash path_to_whisper_repo/whisper/assets/download_alice.sh"
            )

        result = whisper.transcribe(audio_file, fp16=False)
        self.assertEqual(len(result["text"]), 10920)
        self.assertEqual(result["language"], "en")
        self.assertEqual(len(result["segments"]), 77)

        expected_5 = {
            "id": 5,
            "seek": 2800,
            "start": 40.0,
            "end": 46.0,
            "text": " Oh my poor little feet, I wonder who will put on your shoes and stockings for you now tears.",
            "tokens": [
                50964,
                876,
                452,
                4716,
                707,
                3521,
                11,
                286,
                2441,
                567,
                486,
                829,
                322,
                428,
                6654,
                293,
                4127,
                1109,
                337,
                291,
                586,
                10462,
                13,
                51264,
            ],
            "temperature": 0.0,
            "avg_logprob": -0.19670599699020386,
            "compression_ratio": 1.5991379310344827,
            "no_speech_prob": 0.09746722131967545,
        }
        expected_73 = {
            "id": 73,
            "seek": 70700,
            "start": 707.0,
            "end": 715.0,
            "text": " let us get to the shore, and then I'll tell you my history, and you'll understand why it is that I hate cats and dogs.",
            "tokens": [
                50364,
                718,
                505,
                483,
                281,
                264,
                17805,
                11,
                293,
                550,
                286,
                603,
                980,
                291,
                452,
                2503,
                11,
                293,
                291,
                603,
                1223,
                983,
                309,
                307,
                300,
                286,
                4700,
                11111,
                293,
                7197,
                13,
                50764,
            ],
            "temperature": 0.0,
            "avg_logprob": -0.1350895343440594,
            "compression_ratio": 1.6208333333333333,
            "no_speech_prob": 0.002246702555567026,
        }

        def check_segment(seg, expected):
            """Compare one segment field by field; floats to 3 decimal places."""
            for k, v in expected.items():
                if isinstance(v, float):
                    self.assertAlmostEqual(seg[k], v, places=3)
                else:
                    self.assertEqual(seg[k], v)

        # Randomly check a couple of segments
        check_segment(result["segments"][5], expected_5)
        check_segment(result["segments"][73], expected_73)
class TestAudio(unittest.TestCase):
    """Shape and dtype checks for the audio loading and feature helpers.

    NOTE(review): these checks used to be written as
    ``assertTrue(value, expected)``. The second argument of ``assertTrue`` is
    only the failure message, so the expected values were never compared.
    They are now enforced with ``assertEqual``; the expected values are kept
    from the original and should be confirmed against a real run.
    """

    def test_load(self):
        """Loading at the default rate and at 8000 yields the expected lengths."""
        data = audio.load_audio(TEST_AUDIO)
        data_8k = audio.load_audio(TEST_AUDIO, 8000)
        n = 106640
        # tuple(...) makes the comparison independent of whether ``shape`` is
        # reported as a tuple or as a list.
        self.assertEqual(tuple(data.shape), (n,))
        self.assertEqual(data.dtype, np.float32)
        self.assertEqual(tuple(data_8k.shape), (n // 2,))

    def test_pad(self):
        """``pad_or_trim`` with an explicit length returns that many samples."""
        data = audio.load_audio(TEST_AUDIO)
        data = audio.pad_or_trim(data, 20_000)
        self.assertEqual(tuple(data.shape), (20_000,))

    def test_mel_spec(self):
        """``log_mel_spectrogram`` accepts a path and returns float32 features."""
        mels = audio.log_mel_spectrogram(TEST_AUDIO)
        # NOTE(review): other tests in this file feed the MLX model features
        # laid out as (frames, n_mels); the (80, 400) expectation inherited
        # from the original may not match that layout -- confirm and adjust.
        self.assertEqual(tuple(mels.shape), (80, 400))
        self.assertEqual(mels.dtype, mx.float32)
# Run the test cases in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()