mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-10-23 22:18:06 +08:00
whisper default in fp16
This commit is contained in:
@@ -36,7 +36,7 @@ def forward_mlx(model, mels, tokens):
|
||||
class TestWhisper(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = load_models.load_model("tiny")
|
||||
cls.model = load_models.load_model("tiny", dtype=mx.float32)
|
||||
data = audio.load_audio(TEST_AUDIO)
|
||||
data = audio.pad_or_trim(data)
|
||||
cls.mels = audio.log_mel_spectrogram(data)
|
||||
@@ -52,13 +52,22 @@ class TestWhisper(unittest.TestCase):
|
||||
|
||||
torch_logits = forward_torch(torch_model, mels, tokens)
|
||||
|
||||
mlx_model = load_models.torch_to_mlx(torch_model)
|
||||
mlx_model = load_models.torch_to_mlx(torch_model, mx.float32)
|
||||
mlx_logits = forward_mlx(mlx_model, mels, tokens)
|
||||
|
||||
self.assertTrue(np.allclose(torch_logits, mlx_logits, atol=1e-2, rtol=1e-2))
|
||||
|
||||
def test_fp16(self):
|
||||
mlx_model = load_models.load_model("tiny", dtype=mx.float16)
|
||||
dims = mlx_model.dims
|
||||
mels = mx.array(np.random.randn(1, 3_000, dims.n_mels), mx.float16)
|
||||
tokens = mx.array(np.random.randint(0, dims.n_vocab, (1, 20)), mx.int32)
|
||||
logits = mlx_model(mels, tokens)
|
||||
self.assertEqual(logits.dtype, mx.float16)
|
||||
|
||||
|
||||
def test_decode_lang(self):
|
||||
options = decoding.DecodingOptions(task="lang_id")
|
||||
options = decoding.DecodingOptions(task="lang_id", fp16=False)
|
||||
result = decoding.decode(self.model, self.mels, options)
|
||||
self.assertEqual(result.language, "en")
|
||||
self.assertEqual(len(result.language_probs), 99)
|
||||
@@ -67,7 +76,7 @@ class TestWhisper(unittest.TestCase):
|
||||
)
|
||||
|
||||
def test_decode_greedy(self):
|
||||
result = decoding.decode(self.model, self.mels)
|
||||
result = decoding.decode(self.model, self.mels, fp16=False)
|
||||
self.assertEqual(result.language, "en")
|
||||
self.assertEqual(
|
||||
result.tokens,
|
||||
@@ -114,7 +123,7 @@ class TestWhisper(unittest.TestCase):
|
||||
self.assertAlmostEqual(result.compression_ratio, 1.2359550561797752)
|
||||
|
||||
# Small temp should give the same results
|
||||
result = decoding.decode(self.model, self.mels, temperature=1e-8)
|
||||
result = decoding.decode(self.model, self.mels, temperature=1e-8, fp16=False)
|
||||
|
||||
self.assertEqual(
|
||||
result.text,
|
||||
@@ -128,7 +137,7 @@ class TestWhisper(unittest.TestCase):
|
||||
self.assertAlmostEqual(result.compression_ratio, 1.2359550561797752)
|
||||
|
||||
def test_transcribe(self):
|
||||
result = whisper.transcribe(TEST_AUDIO)
|
||||
result = whisper.transcribe(TEST_AUDIO, fp16=False)
|
||||
self.assertEqual(
|
||||
result["text"],
|
||||
(
|
||||
@@ -147,7 +156,7 @@ class TestWhisper(unittest.TestCase):
|
||||
print("bash path_to_whisper_repo/whisper/assets/download_alice.sh")
|
||||
return
|
||||
|
||||
result = whisper.transcribe(audio_file)
|
||||
result = whisper.transcribe(audio_file, fp16=False)
|
||||
self.assertEqual(len(result["text"]), 10920)
|
||||
self.assertEqual(result["language"], "en")
|
||||
self.assertEqual(len(result["segments"]), 77)
|
||||
|
Reference in New Issue
Block a user