Support Hugging Face models (#215)

* Support loading models directly from Hugging Face
This commit is contained in:
Awni Hannun
2024-01-03 15:13:26 -08:00
committed by GitHub
parent 1d09c4fecd
commit a5d6d0436c
16 changed files with 654 additions and 27 deletions

View File

@@ -10,6 +10,7 @@ from pathlib import Path
import mlx.core as mx
import numpy as np
import torch
+from convert import load_torch_model, quantize, torch_to_mlx
from mlx.utils import tree_flatten
import whisper
@@ -17,8 +18,6 @@ import whisper.audio as audio
import whisper.decoding as decoding
import whisper.load_models as load_models
-from convert import load_torch_model, quantize, torch_to_mlx
MODEL_NAME = "tiny"
MLX_FP32_MODEL_PATH = "mlx_models/tiny_fp32"
MLX_FP16_MODEL_PATH = "mlx_models/tiny_fp16"
@@ -189,7 +188,9 @@ class TestWhisper(unittest.TestCase):
self.assertAlmostEqual(result.compression_ratio, 1.2359550561797752)
def test_transcribe(self):
-result = whisper.transcribe(TEST_AUDIO, model_path=MLX_FP32_MODEL_PATH, fp16=False)
+result = whisper.transcribe(
+TEST_AUDIO, model_path=MLX_FP32_MODEL_PATH, fp16=False
+)
self.assertEqual(
result["text"],
(
@@ -208,7 +209,9 @@ class TestWhisper(unittest.TestCase):
print("bash path_to_whisper_repo/whisper/assets/download_alice.sh")
return
-result = whisper.transcribe(audio_file, model_path=MLX_FP32_MODEL_PATH, fp16=False)
+result = whisper.transcribe(
+audio_file, model_path=MLX_FP32_MODEL_PATH, fp16=False
+)
self.assertEqual(len(result["text"]), 10920)
self.assertEqual(result["language"], "en")
self.assertEqual(len(result["segments"]), 77)