mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
@@ -86,8 +86,13 @@ if __name__ == "__main__":
|
||||
for model_name in models:
|
||||
model_path = f"{args.mlx_dir}/{model_name}"
|
||||
if not os.path.exists(model_path):
|
||||
print(f"\nDidn't find the MLX-format {model_name} model in the folder {args.mlx_dir}. Lauching conversion")
|
||||
subprocess.run(f"python convert.py --torch-name-or-path {model_name} --mlx-path {model_path}", shell=True)
|
||||
print(
|
||||
f"\nDidn't find the MLX-format {model_name} model in the folder {args.mlx_dir}. Lauching conversion"
|
||||
)
|
||||
subprocess.run(
|
||||
f"python convert.py --torch-name-or-path {model_name} --mlx-path {model_path}",
|
||||
shell=True,
|
||||
)
|
||||
|
||||
print(f"\nModel: {model_name.upper()}")
|
||||
tokens = mx.array(
|
||||
|
@@ -71,7 +71,9 @@ def _download(url: str, root: str) -> str:
|
||||
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
|
||||
return download_target
|
||||
else:
|
||||
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
|
||||
warnings.warn(
|
||||
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
|
||||
)
|
||||
|
||||
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
||||
with tqdm(
|
||||
@@ -132,7 +134,9 @@ def load_torch_model(
|
||||
alignment_heads = _ALIGNMENT_HEADS[name_or_path]
|
||||
name_or_path = _download(_MODELS[name_or_path], download_root)
|
||||
elif not Path(name_or_path).is_file():
|
||||
raise RuntimeError(f"Model {name_or_path} is neither found in {available_models()} nor as a local path")
|
||||
raise RuntimeError(
|
||||
f"Model {name_or_path} is neither found in {available_models()} nor as a local path"
|
||||
)
|
||||
|
||||
with open(name_or_path, "rb") as fp:
|
||||
checkpoint = torch.load(fp)
|
||||
@@ -259,7 +263,9 @@ if __name__ == "__main__":
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
assert args.dtype in _VALID_DTYPES, f"dtype {args.dtype} not found in {_VALID_DTYPES}"
|
||||
assert (
|
||||
args.dtype in _VALID_DTYPES
|
||||
), f"dtype {args.dtype} not found in {_VALID_DTYPES}"
|
||||
dtype = getattr(mx, args.dtype)
|
||||
|
||||
print("[INFO] Loading")
|
||||
|
@@ -10,6 +10,7 @@ from pathlib import Path
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import torch
|
||||
from convert import load_torch_model, quantize, torch_to_mlx
|
||||
from mlx.utils import tree_flatten
|
||||
|
||||
import whisper
|
||||
@@ -17,8 +18,6 @@ import whisper.audio as audio
|
||||
import whisper.decoding as decoding
|
||||
import whisper.load_models as load_models
|
||||
|
||||
from convert import load_torch_model, quantize, torch_to_mlx
|
||||
|
||||
MODEL_NAME = "tiny"
|
||||
MLX_FP32_MODEL_PATH = "mlx_models/tiny_fp32"
|
||||
MLX_FP16_MODEL_PATH = "mlx_models/tiny_fp16"
|
||||
@@ -189,7 +188,9 @@ class TestWhisper(unittest.TestCase):
|
||||
self.assertAlmostEqual(result.compression_ratio, 1.2359550561797752)
|
||||
|
||||
def test_transcribe(self):
|
||||
result = whisper.transcribe(TEST_AUDIO, model_path=MLX_FP32_MODEL_PATH, fp16=False)
|
||||
result = whisper.transcribe(
|
||||
TEST_AUDIO, model_path=MLX_FP32_MODEL_PATH, fp16=False
|
||||
)
|
||||
self.assertEqual(
|
||||
result["text"],
|
||||
(
|
||||
@@ -208,7 +209,9 @@ class TestWhisper(unittest.TestCase):
|
||||
print("bash path_to_whisper_repo/whisper/assets/download_alice.sh")
|
||||
return
|
||||
|
||||
result = whisper.transcribe(audio_file, model_path=MLX_FP32_MODEL_PATH, fp16=False)
|
||||
result = whisper.transcribe(
|
||||
audio_file, model_path=MLX_FP32_MODEL_PATH, fp16=False
|
||||
)
|
||||
self.assertEqual(len(result["text"]), 10920)
|
||||
self.assertEqual(result["language"], "en")
|
||||
self.assertEqual(len(result["segments"]), 77)
|
||||
|
Reference in New Issue
Block a user