Support Hugging Face models (#215)

* support hf direct models

Author: Awni Hannun
Date: 2024-01-03 15:13:26 -08:00
Committed by: GitHub
Parent: 1d09c4fecd
Commit: a5d6d0436c
16 changed files with 654 additions and 27 deletions

View File

@@ -86,8 +86,13 @@ if __name__ == "__main__":
     for model_name in models:
         model_path = f"{args.mlx_dir}/{model_name}"
         if not os.path.exists(model_path):
-            print(f"\nDidn't find the MLX-format {model_name} model in the folder {args.mlx_dir}. Lauching conversion")
-            subprocess.run(f"python convert.py --torch-name-or-path {model_name} --mlx-path {model_path}", shell=True)
+            print(
+                f"\nDidn't find the MLX-format {model_name} model in the folder {args.mlx_dir}. Lauching conversion"
+            )
+            subprocess.run(
+                f"python convert.py --torch-name-or-path {model_name} --mlx-path {model_path}",
+                shell=True,
+            )
         print(f"\nModel: {model_name.upper()}")
         tokens = mx.array(
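
Note: the hunk above only reflows an existing check: when no MLX-format model directory is found, the benchmark shells out to the repo's convert.py. A self-contained sketch of that check-and-convert pattern, with the model name and output directory as placeholders:

```python
# Sketch of the check-and-convert pattern from the hunk above; the model name
# and output directory are placeholders, not values taken from benchmark args.
import os
import subprocess

model_name = "tiny"
model_path = f"mlx_models/{model_name}"

if not os.path.exists(model_path):
    # Mirror the subprocess call in the diff: run the conversion script when
    # the converted model is missing locally.
    subprocess.run(
        f"python convert.py --torch-name-or-path {model_name} --mlx-path {model_path}",
        shell=True,
    )
```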

View File

@@ -71,7 +71,9 @@ def _download(url: str, root: str) -> str:
         if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
             return download_target
         else:
-            warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
+            warnings.warn(
+                f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
+            )
 
     with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
         with tqdm(
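
This hunk reflows the warning in _download's cache check: a previously downloaded file is reused only if its SHA256 digest matches, otherwise the file is fetched again. A minimal, self-contained sketch of that verification step (the helper name, expected digest, and path are placeholders; the actual re-download is omitted):

```python
# Sketch of the cache-verification step shown above; is_cached_and_valid is a
# hypothetical helper, and expected_sha256 / download_target are placeholders.
import hashlib
import warnings
from pathlib import Path

def is_cached_and_valid(download_target: str, expected_sha256: str) -> bool:
    if not Path(download_target).is_file():
        return False
    model_bytes = Path(download_target).read_bytes()
    if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
        return True
    warnings.warn(
        f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
    )
    return False
```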
@@ -132,7 +134,9 @@ def load_torch_model(
         alignment_heads = _ALIGNMENT_HEADS[name_or_path]
         name_or_path = _download(_MODELS[name_or_path], download_root)
     elif not Path(name_or_path).is_file():
-        raise RuntimeError(f"Model {name_or_path} is neither found in {available_models()} nor as a local path")
+        raise RuntimeError(
+            f"Model {name_or_path} is neither found in {available_models()} nor as a local path"
+        )
 
     with open(name_or_path, "rb") as fp:
         checkpoint = torch.load(fp)
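
The RuntimeError hunk documents how load_torch_model resolves its argument: a known model name is downloaded, an existing local file is used directly, and anything else is rejected. A condensed, self-contained sketch of that resolution logic; _MODELS, _download, and available_models stand in for the real table and helpers in convert.py, and resolve_checkpoint is a hypothetical name:

```python
# Condensed sketch of the resolution logic around the hunk above. The real
# load_torch_model also handles alignment heads, dtypes, and checkpoint loading.
from pathlib import Path

_MODELS = {"tiny": "https://example.com/tiny.pt"}  # placeholder URL, not the real table

def available_models() -> list:
    return sorted(_MODELS.keys())

def _download(url: str, root: str) -> str:
    # Stub: the real helper streams the file and verifies its SHA256
    # (see the earlier hunk); here we only compute where it would land.
    return str(Path(root) / Path(url).name)

def resolve_checkpoint(name_or_path: str, download_root: str = ".") -> str:
    if name_or_path in _MODELS:
        return _download(_MODELS[name_or_path], download_root)
    if not Path(name_or_path).is_file():
        raise RuntimeError(
            f"Model {name_or_path} is neither found in {available_models()} nor as a local path"
        )
    return name_or_path
```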
@@ -259,7 +263,9 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()
 
-    assert args.dtype in _VALID_DTYPES, f"dtype {args.dtype} not found in {_VALID_DTYPES}"
+    assert (
+        args.dtype in _VALID_DTYPES
+    ), f"dtype {args.dtype} not found in {_VALID_DTYPES}"
     dtype = getattr(mx, args.dtype)
 
     print("[INFO] Loading")
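
The last hunk in this file reformats the dtype guard: the CLI string must name an allowed dtype and is then resolved to the corresponding mx dtype with getattr. A small sketch of the same idea; the contents of _VALID_DTYPES here are an assumption, not copied from convert.py:

```python
# Sketch of the dtype guard from the hunk above. The members of _VALID_DTYPES
# in convert.py may differ; these are assumed for illustration only.
import mlx.core as mx

_VALID_DTYPES = {"float16", "float32"}  # assumed set of allowed names

def resolve_dtype(name: str):
    assert name in _VALID_DTYPES, f"dtype {name} not found in {_VALID_DTYPES}"
    return getattr(mx, name)  # e.g. mx.float16

print(resolve_dtype("float16"))
```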

View File

@@ -10,6 +10,7 @@ from pathlib import Path
 import mlx.core as mx
 import numpy as np
 import torch
+from convert import load_torch_model, quantize, torch_to_mlx
 from mlx.utils import tree_flatten
 
 import whisper
@@ -17,8 +18,6 @@ import whisper.audio as audio
 import whisper.decoding as decoding
 import whisper.load_models as load_models
 
-from convert import load_torch_model, quantize, torch_to_mlx
-
 MODEL_NAME = "tiny"
 MLX_FP32_MODEL_PATH = "mlx_models/tiny_fp32"
 MLX_FP16_MODEL_PATH = "mlx_models/tiny_fp16"
@@ -189,7 +188,9 @@ class TestWhisper(unittest.TestCase):
         self.assertAlmostEqual(result.compression_ratio, 1.2359550561797752)
 
     def test_transcribe(self):
-        result = whisper.transcribe(TEST_AUDIO, model_path=MLX_FP32_MODEL_PATH, fp16=False)
+        result = whisper.transcribe(
+            TEST_AUDIO, model_path=MLX_FP32_MODEL_PATH, fp16=False
+        )
         self.assertEqual(
             result["text"],
             (
@@ -208,7 +209,9 @@ class TestWhisper(unittest.TestCase):
             print("bash path_to_whisper_repo/whisper/assets/download_alice.sh")
             return
 
-        result = whisper.transcribe(audio_file, model_path=MLX_FP32_MODEL_PATH, fp16=False)
+        result = whisper.transcribe(
+            audio_file, model_path=MLX_FP32_MODEL_PATH, fp16=False
+        )
         self.assertEqual(len(result["text"]), 10920)
         self.assertEqual(result["language"], "en")
         self.assertEqual(len(result["segments"]), 77)
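
The tests above exercise transcription from a locally converted MLX model directory. A closing usage sketch based only on the calls and result keys seen in these hunks; the audio file and model directory are placeholders for whatever has been converted locally:

```python
# Usage sketch mirroring the transcribe calls in the tests above. "speech.wav"
# and the model directory are placeholders, not files shipped with the repo.
import whisper

result = whisper.transcribe(
    "speech.wav", model_path="mlx_models/tiny_fp32", fp16=False
)
print(result["language"])        # e.g. "en"
print(result["text"])            # full transcription text
print(len(result["segments"]))   # number of decoded segments
```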