mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-10 21:37:45 +08:00
@@ -10,6 +10,7 @@ from pathlib import Path
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import torch
|
||||
from convert import load_torch_model, quantize, torch_to_mlx
|
||||
from mlx.utils import tree_flatten
|
||||
|
||||
import whisper
|
||||
@@ -17,8 +18,6 @@ import whisper.audio as audio
|
||||
import whisper.decoding as decoding
|
||||
import whisper.load_models as load_models
|
||||
|
||||
from convert import load_torch_model, quantize, torch_to_mlx
|
||||
|
||||
MODEL_NAME = "tiny"
|
||||
MLX_FP32_MODEL_PATH = "mlx_models/tiny_fp32"
|
||||
MLX_FP16_MODEL_PATH = "mlx_models/tiny_fp16"
|
||||
@@ -189,7 +188,9 @@ class TestWhisper(unittest.TestCase):
|
||||
self.assertAlmostEqual(result.compression_ratio, 1.2359550561797752)
|
||||
|
||||
def test_transcribe(self):
|
||||
result = whisper.transcribe(TEST_AUDIO, model_path=MLX_FP32_MODEL_PATH, fp16=False)
|
||||
result = whisper.transcribe(
|
||||
TEST_AUDIO, model_path=MLX_FP32_MODEL_PATH, fp16=False
|
||||
)
|
||||
self.assertEqual(
|
||||
result["text"],
|
||||
(
|
||||
@@ -208,7 +209,9 @@ class TestWhisper(unittest.TestCase):
|
||||
print("bash path_to_whisper_repo/whisper/assets/download_alice.sh")
|
||||
return
|
||||
|
||||
result = whisper.transcribe(audio_file, model_path=MLX_FP32_MODEL_PATH, fp16=False)
|
||||
result = whisper.transcribe(
|
||||
audio_file, model_path=MLX_FP32_MODEL_PATH, fp16=False
|
||||
)
|
||||
self.assertEqual(len(result["text"]), 10920)
|
||||
self.assertEqual(result["language"], "en")
|
||||
self.assertEqual(len(result["segments"]), 77)
|
||||
|
Reference in New Issue
Block a user