mlx-examples/whisper/test.py
dmdaksh 7d7e236061
- Removed unused Python imports (#683)
  - bert/model.py:10: tree_unflatten
  - bert/model.py:2: dataclass
  - bert/model.py:8: numpy
  - cifar/resnet.py:6: Any
  - clip/model.py:15: tree_flatten
  - clip/model.py:9: Union
  - gcn/main.py:8: download_cora
  - gcn/main.py:9: cross_entropy
  - llms/gguf_llm/models.py:12: tree_flatten, tree_unflatten
  - llms/gguf_llm/models.py:9: numpy
  - llms/mixtral/mixtral.py:12: tree_map
  - llms/mlx_lm/models/dbrx.py:2: Dict, Union
  - llms/mlx_lm/tuner/trainer.py:5: partial
  - llms/speculative_decoding/decoder.py:1: dataclass, field
  - llms/speculative_decoding/decoder.py:2: Optional
  - llms/speculative_decoding/decoder.py:5: mlx.nn
  - llms/speculative_decoding/decoder.py:6: numpy
  - llms/speculative_decoding/main.py:2: glob
  - llms/speculative_decoding/main.py:3: json
  - llms/speculative_decoding/main.py:5: Path
  - llms/speculative_decoding/main.py:8: mlx.nn
  - llms/speculative_decoding/model.py:6: tree_unflatten
  - llms/speculative_decoding/model.py:7: AutoTokenizer
  - llms/tests/test_lora.py:13: yaml_loader
  - lora/lora.py:14: tree_unflatten
  - lora/models.py:11: numpy
  - lora/models.py:3: glob
  - speechcommands/kwt.py:1: Any
  - speechcommands/main.py:7: mlx.data
  - stable_diffusion/stable_diffusion/model_io.py:4: partial
  - whisper/benchmark.py:5: sys
  - whisper/test.py:5: subprocess
  - whisper/whisper/audio.py:6: Optional
  - whisper/whisper/decoding.py:8: mlx.nn
2024-04-16 07:50:32 -07:00
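
The commit above removes imports that are bound but never referenced. As a rough
illustration (not part of the repository), here is a minimal AST-based sketch of
how such imports can be flagged; real linters (e.g. pyflakes' F401 check)
additionally handle __all__, re-exports, and string annotations. The helper name
and CLI are hypothetical. Note the list above is ordered lexicographically,
which is why bert/model.py:10 sorts before bert/model.py:2.

    # unused_imports.py -- illustrative sketch only
    import ast
    import sys


    def unused_imports(source: str) -> list:
        """Return "lineno: name" entries for imports never referenced."""
        tree = ast.parse(source)
        imported = {}  # bound name -> line number of the import
        used = set()
        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                for alias in node.names:
                    # "import mlx.core as mx" binds "mx"; "import json" binds "json"
                    imported[alias.asname or alias.name.split(".")[0]] = node.lineno
            elif isinstance(node, ast.ImportFrom):
                for alias in node.names:
                    imported[alias.asname or alias.name] = node.lineno
            elif isinstance(node, ast.Name):
                used.add(node.id)
        return sorted(
            f"{line}: {name}" for name, line in imported.items() if name not in used
        )


    if __name__ == "__main__":
        for path in sys.argv[1:]:
            with open(path) as f:
                for hit in unused_imports(f.read()):
                    print(f"{path}:{hit}")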


# Copyright © 2023 Apple Inc.

import json
import os
import unittest
from dataclasses import asdict
from pathlib import Path

import mlx.core as mx
import numpy as np
import torch
from convert import load_torch_model, quantize, torch_to_mlx
from mlx.utils import tree_flatten

import whisper
import whisper.audio as audio
import whisper.decoding as decoding
import whisper.load_models as load_models

MODEL_NAME = "tiny"
MLX_FP32_MODEL_PATH = "mlx_models/tiny_fp32"
MLX_FP16_MODEL_PATH = "mlx_models/tiny_fp16"
MLX_4BITS_MODEL_PATH = "mlx_models/tiny_quantized_4bits"
TEST_AUDIO = "whisper/assets/ls_test.flac"
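
# A usage sketch (commentary added here, not in the original test file): once
# load_torch_and_mlx() below has populated a converted model directory such as
# MLX_FP32_MODEL_PATH, it can be passed straight to whisper.transcribe, as
# test_transcribe further down does:
#
#     result = whisper.transcribe(
#         TEST_AUDIO, path_or_hf_repo=MLX_FP32_MODEL_PATH, fp16=False
#     )
#     print(result["text"])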

def _save_model(save_dir, weights, config):
    mlx_path = Path(save_dir)
    mlx_path.mkdir(parents=True, exist_ok=True)

    # Save weights
    np.savez(str(mlx_path / "weights.npz"), **weights)

    # Save config.json with model_type
    with open(str(mlx_path / "config.json"), "w") as f:
        config["model_type"] = "whisper"
        json.dump(config, f, indent=4)
    config.pop("model_type", None)


def load_torch_and_mlx():
    torch_model = load_torch_model(MODEL_NAME)

    fp32_model = torch_to_mlx(torch_model, dtype=mx.float32)
    config = asdict(fp32_model.dims)
    weights = dict(tree_flatten(fp32_model.parameters()))
    _save_model(MLX_FP32_MODEL_PATH, weights, config)

    fp16_model = torch_to_mlx(torch_model, dtype=mx.float16)
    config = asdict(fp16_model.dims)
    weights = dict(tree_flatten(fp16_model.parameters()))
    _save_model(MLX_FP16_MODEL_PATH, weights, config)

    # Quantize the fp16 weights to 4 bits with a group size of 64; a throwaway
    # anonymous object stands in for the parsed CLI args quantize() expects.
    args = type("", (), {})()
    args.q_group_size = 64
    args.q_bits = 4
    weights, config = quantize(weights, config, args)
    _save_model(MLX_4BITS_MODEL_PATH, weights, config)

    return torch_model, fp32_model, fp16_model


def forward_torch(model, mels, tokens):
    mels = torch.Tensor(mels).to(torch.float32)
    tokens = torch.Tensor(tokens).to(torch.int32)
    with torch.no_grad():
        logits = model.forward(mels, tokens)
    return logits.numpy()


def forward_mlx(model, mels, tokens):
    mels = mx.array(mels.transpose(0, 2, 1))
    tokens = mx.array(tokens, mx.int32)
    logits = model(mels, tokens)
    return np.array(logits)
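
# Layout note: the PyTorch reference model takes mel spectrograms as
# (batch, n_mels, n_frames), while the MLX port takes (batch, n_frames,
# n_mels); the transpose(0, 2, 1) in forward_mlx above converts between the
# two so that both models see the same input in test_torch_mlx.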


class TestWhisper(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        _, cls.model, _ = load_torch_and_mlx()
        data = audio.load_audio(TEST_AUDIO)
        data = audio.pad_or_trim(data)
        cls.mels = audio.log_mel_spectrogram(data)

    def test_torch_mlx(self):
        np.random.seed(10)

        torch_model = load_torch_model(MODEL_NAME)
        dims = torch_model.dims
        mels = np.random.randn(1, dims.n_mels, 3_000)
        tokens = np.random.randint(0, dims.n_vocab, (1, 20))

        torch_logits = forward_torch(torch_model, mels, tokens)
        mlx_logits = forward_mlx(self.model, mels, tokens)

        self.assertTrue(np.allclose(torch_logits, mlx_logits, atol=1e-2, rtol=1e-2))

    def test_fp16(self):
        mlx_model = load_models.load_model(MLX_FP16_MODEL_PATH, mx.float16)
        dims = mlx_model.dims
        mels = mx.array(np.random.randn(1, 3_000, dims.n_mels), mx.float16)
        tokens = mx.array(np.random.randint(0, dims.n_vocab, (1, 20)), mx.int32)
        logits = mlx_model(mels, tokens)
        self.assertEqual(logits.dtype, mx.float16)

    def test_quantized_4bits(self):
        mlx_model = load_models.load_model(MLX_4BITS_MODEL_PATH, mx.float16)
        dims = mlx_model.dims
        mels = mx.array(np.random.randn(1, 3_000, dims.n_mels), mx.float16)
        tokens = mx.array(np.random.randint(0, dims.n_vocab, (1, 20)), mx.int32)
        logits = mlx_model(mels, tokens)
        # Only check that the 4-bit model can run a forward pass; the quantized
        # tiny model is not accurate enough for transcription-level checks.
        self.assertEqual(logits.dtype, mx.float16)

    def test_decode_lang(self):
        options = decoding.DecodingOptions(task="lang_id", fp16=False)
        result = decoding.decode(self.model, self.mels, options)
        self.assertEqual(result.language, "en")
        self.assertEqual(len(result.language_probs), 99)
        self.assertAlmostEqual(
            result.language_probs["en"], 0.9947282671928406, places=5
        )

    def test_decode_greedy(self):
        result = decoding.decode(self.model, self.mels, fp16=False)
        self.assertEqual(result.language, "en")
        self.assertEqual(
            result.tokens,
            [
                50364, 1396, 264, 665, 5133, 23109, 25462, 264, 6582, 293,
                750, 632, 42841, 292, 370, 938, 294, 4054, 293, 12653, 356,
                50620, 50620, 23563, 322, 3312, 13, 50680,
            ],
        )
        self.assertEqual(
            result.text,
            (
                "Then the good soul openly sorted the boat and she "
                "had buoyed so long in secret and bravely stretched on alone."
            ),
        )
        self.assertAlmostEqual(result.avg_logprob, -0.4975455382774616, places=3)
        self.assertAlmostEqual(result.no_speech_prob, 0.009631240740418434, places=4)
        self.assertAlmostEqual(result.compression_ratio, 1.2359550561797752)

        # A near-zero temperature should give the same results
        result = decoding.decode(self.model, self.mels, temperature=1e-8, fp16=False)
        self.assertEqual(
            result.text,
            (
                "Then the good soul openly sorted the boat and she "
                "had buoyed so long in secret and bravely stretched on alone."
            ),
        )
        self.assertAlmostEqual(result.avg_logprob, -0.4975455382774616, places=3)
        self.assertAlmostEqual(result.no_speech_prob, 0.009631240740418434, places=4)
        self.assertAlmostEqual(result.compression_ratio, 1.2359550561797752)

    def test_transcribe(self):
        result = whisper.transcribe(
            TEST_AUDIO, path_or_hf_repo=MLX_FP32_MODEL_PATH, fp16=False
        )
        self.assertEqual(
            result["text"],
            (
                " Then the good soul openly sorted the boat and she "
                "had buoyed so long in secret and bravely stretched on alone."
            ),
        )

    def test_transcribe_alice(self):
        audio_file = os.path.join(
            os.path.expanduser("~"),
            ".cache/whisper/alice.mp3",
        )
        if not os.path.exists(audio_file):
            print("To run this test, download the Alice in Wonderland audiobook:")
            print("bash path_to_whisper_repo/whisper/assets/download_alice.sh")
            return

        result = whisper.transcribe(
            audio_file, path_or_hf_repo=MLX_FP32_MODEL_PATH, fp16=False
        )
        self.assertEqual(len(result["text"]), 10920)
        self.assertEqual(result["language"], "en")
        self.assertEqual(len(result["segments"]), 77)

        expected_5 = {
            "id": 5,
            "seek": 2800,
            "start": 40.0,
            "end": 46.0,
            "text": (
                " Oh my poor little feet, I wonder who will put on your shoes "
                "and stockings for you now tears."
            ),
            "tokens": [
                50964, 876, 452, 4716, 707, 3521, 11, 286, 2441, 567, 486,
                829, 322, 428, 6654, 293, 4127, 1109, 337, 291, 586, 10462,
                13, 51264,
            ],
            "temperature": 0.0,
            "avg_logprob": -0.19670599699020386,
            "compression_ratio": 1.5991379310344827,
            "no_speech_prob": 0.09746722131967545,
        }

        expected_73 = {
            "id": 73,
            "seek": 70700,
            "start": 707.0,
            "end": 715.0,
            "text": (
                " let us get to the shore, and then I'll tell you my history, "
                "and you'll understand why it is that I hate cats and dogs."
            ),
            "tokens": [
                50364, 718, 505, 483, 281, 264, 17805, 11, 293, 550, 286,
                603, 980, 291, 452, 2503, 11, 293, 291, 603, 1223, 983, 309,
                307, 300, 286, 4700, 11111, 293, 7197, 13, 50764,
            ],
            "temperature": 0.0,
            "avg_logprob": -0.1350895343440594,
            "compression_ratio": 1.6208333333333333,
            "no_speech_prob": 0.009053784422576427,
        }

        def check_segment(seg, expected):
            for k, v in expected.items():
                if isinstance(v, float):
                    self.assertAlmostEqual(seg[k], v, places=2)
                else:
                    self.assertEqual(seg[k], v)

        # Spot-check a couple of segments
        check_segment(result["segments"][5], expected_5)
        check_segment(result["segments"][73], expected_73)

    def test_transcribe_word_level_timestamps_confidence_scores(self):
        result = whisper.transcribe(
            TEST_AUDIO,
            path_or_hf_repo=MLX_FP16_MODEL_PATH,
            word_timestamps=True,
        )

        # Reference values predicted with openai-whisper
        expected_0 = [
            {"word": " Then", "start": 0.0, "end": 0.94, "probability": 0.855542778968811},
            {"word": " the", "start": 0.94, "end": 1.12, "probability": 0.6500106453895569},
            {"word": " good", "start": 1.12, "end": 1.32, "probability": 0.5503873825073242},
            {"word": " soul", "start": 1.32, "end": 1.56, "probability": 0.46757155656814575},
            {"word": " openly", "start": 1.56, "end": 2.0, "probability": 0.9840946793556213},
            {"word": " sorted", "start": 2.0, "end": 2.38, "probability": 0.24167272448539734},
            {"word": " the", "start": 2.38, "end": 2.58, "probability": 0.9875414967536926},
            {"word": " boat", "start": 2.58, "end": 2.8, "probability": 0.5856029391288757},
            {"word": " and", "start": 2.8, "end": 2.98, "probability": 0.913351833820343},
            {"word": " she", "start": 2.98, "end": 3.1, "probability": 0.9913808703422546},
            {"word": " had", "start": 3.1, "end": 3.32, "probability": 0.9952940344810486},
            {"word": " buoyed", "start": 3.32, "end": 3.58, "probability": 0.6411589980125427},
            {"word": " so", "start": 3.58, "end": 3.8, "probability": 0.9682658314704895},
            {"word": " long", "start": 3.8, "end": 4.06, "probability": 0.9953522682189941},
            {"word": " in", "start": 4.06, "end": 4.26, "probability": 0.6745936870574951},
            {"word": " secret", "start": 4.26, "end": 4.56, "probability": 0.9905064702033997},
            {"word": " and", "start": 4.56, "end": 4.9, "probability": 0.856008768081665},
            {"word": " bravely", "start": 4.9, "end": 5.28, "probability": 0.8477402329444885},
        ]

        def check_words(words, expected_words):
            for word, expected_word in zip(words, expected_words):
                for k, v in expected_word.items():
                    if isinstance(v, float):
                        self.assertAlmostEqual(word[k], v, places=1)
                    else:
                        self.assertEqual(word[k], v)

        # Spot-check the words of the first segment
        check_words(result["segments"][0]["words"], expected_0)


class TestAudio(unittest.TestCase):
    def test_load(self):
        data = audio.load_audio(TEST_AUDIO)
        data_8k = audio.load_audio(TEST_AUDIO, 8000)
        n = 106640
        self.assertEqual(data.shape, (n,))
        self.assertEqual(data.dtype, np.float32)
        # Resampling at half the default 16 kHz rate halves the sample count
        self.assertEqual(data_8k.shape, (n // 2,))

    def test_pad(self):
        data = audio.load_audio(TEST_AUDIO)
        data = audio.pad_or_trim(data, 20_000)
        self.assertEqual(data.shape, (20_000,))

    def test_mel_spec(self):
        mels = audio.log_mel_spectrogram(TEST_AUDIO)
        # The MLX port lays mels out as (n_frames, n_mels) with 80 mel bins
        self.assertEqual(mels.shape[-1], 80)
        self.assertEqual(mels.dtype, mx.float32)
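
# Running this file executes both test classes. Note (commentary, not in the
# original file): setUpClass converts the reference PyTorch "tiny" checkpoint
# via load_torch_model(), so the first run needs network access (unless the
# checkpoint is already cached) and some extra time.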

if __name__ == "__main__":
    unittest.main()