mlx-examples/whisper/mlx_whisper/load_models.py

47 lines
1.2 KiB
Python
Raw Permalink Normal View History

2023-12-01 03:08:53 +08:00
# Copyright © 2023 Apple Inc.
import json
from pathlib import Path
2023-11-30 00:17:26 +08:00
import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from mlx.utils import tree_unflatten
from . import whisper
2023-11-30 00:17:26 +08:00
def load_model(
    path_or_hf_repo: str,
    dtype: mx.Dtype = mx.float32,
) -> whisper.Whisper:
    """Load a converted Whisper model from a local directory or the Hugging Face Hub.

    Args:
        path_or_hf_repo: A local directory containing ``config.json`` and either
            ``weights.safetensors`` or ``weights.npz``. If no such path exists,
            the string is treated as a Hugging Face repo id and resolved with
            ``snapshot_download``.
        dtype: Forwarded to the ``whisper.Whisper`` constructor.

    Returns:
        A ``whisper.Whisper`` model with the checkpoint weights applied (and
        quantized layers set up when the config has a ``quantization`` entry).
    """
    model_path = Path(path_or_hf_repo)
    if not model_path.exists():
        # Not a local path: treat it as a Hub repo id and use the local
        # directory returned by snapshot_download.
        model_path = Path(snapshot_download(repo_id=path_or_hf_repo))

    # JSON text is UTF-8 by specification; read it explicitly as UTF-8 instead
    # of relying on the platform's locale-dependent default encoding.
    with open(model_path / "config.json", "r", encoding="utf-8") as f:
        config = json.load(f)
    # Remove entries that must not be forwarded to ModelDimensions:
    # "model_type" is discarded, and "quantization" (if present) supplies the
    # keyword arguments for nn.quantize below.
    config.pop("model_type", None)
    quantization = config.pop("quantization", None)

    model_args = whisper.ModelDimensions(**config)

    # Prefer the safetensors checkpoint; fall back to the .npz format.
    wf = model_path / "weights.safetensors"
    if not wf.exists():
        wf = model_path / "weights.npz"
    # Flat mapping of dotted parameter paths to arrays. Kept under its own name
    # so the quantization predicate below never sees the nested tree.
    flat_weights = mx.load(str(wf))

    model = whisper.Whisper(model_args, dtype)

    if quantization is not None:
        # Quantize only the Linear/Embedding layers whose "<path>.scales" entry
        # exists in the checkpoint, so the module structure matches the stored
        # quantized parameters before they are loaded.
        class_predicate = (
            lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
            and f"{p}.scales" in flat_weights
        )
        nn.quantize(model, **quantization, class_predicate=class_predicate)

    # Rebuild the nested parameter tree from the flat dotted keys and load it.
    model.update(tree_unflatten(list(flat_weights.items())))
    # NOTE(review): the loaded arrays keep the dtype they were saved with;
    # `dtype` is only passed to the Whisper constructor and no cast is applied
    # to the checkpoint weights here — confirm this is intended.
    # Force evaluation of the (lazy) MLX parameters so loading completes here.
    mx.eval(model.parameters())
    return model