mlx-examples/whisper/mlx_whisper/load_models.py
# Copyright © 2023 Apple Inc.

import json
from pathlib import Path

import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from mlx.utils import tree_unflatten

from . import whisper


def load_model(
    path_or_hf_repo: str,
    dtype: mx.Dtype = mx.float32,
) -> whisper.Whisper:
    # Resolve a local directory first; otherwise treat the argument as a
    # Hugging Face repo id and download a snapshot of it.
    model_path = Path(path_or_hf_repo)
    if not model_path.exists():
        model_path = Path(snapshot_download(repo_id=path_or_hf_repo))

    with open(str(model_path / "config.json"), "r") as f:
        config = json.loads(f.read())
        config.pop("model_type", None)
        quantization = config.pop("quantization", None)

    model_args = whisper.ModelDimensions(**config)

    weights = mx.load(str(model_path / "weights.npz"))

    model = whisper.Whisper(model_args, dtype)

    if quantization is not None:
        # Only quantize layers that have quantized weights in the checkpoint,
        # i.e. a corresponding "<path>.scales" entry.
        class_predicate = (
            lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
            and f"{p}.scales" in weights
        )
        nn.quantize(model, **quantization, class_predicate=class_predicate)

    weights = tree_unflatten(list(weights.items()))
    model.update(weights)
    mx.eval(model.parameters())
    return model
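
A minimal usage sketch of the loader above, assuming a converted model is available either as a local directory or as a Hugging Face repo id (the repo id and dtype below are placeholders, not prescribed by this file):

import mlx.core as mx

from mlx_whisper.load_models import load_model

# Load from a local directory containing weights.npz and config.json,
# or from a Hugging Face repo id (placeholder shown here).
model = load_model("mlx-community/whisper-tiny", dtype=mx.float16)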