From 43a68ee5e30365d4a53998ff4afc452be2bb195a Mon Sep 17 00:00:00 2001 From: bofenghuang Date: Tue, 26 Dec 2023 19:05:49 +0100 Subject: [PATCH] Add option to load customized mlx model --- whisper/convert.py | 54 ++++++++++++++++++++++++++ whisper/whisper/load_models.py | 70 +++++++++++++++++++++++++++------- whisper/whisper/transcribe.py | 19 ++++----- 3 files changed, 120 insertions(+), 23 deletions(-) create mode 100644 whisper/convert.py diff --git a/whisper/convert.py b/whisper/convert.py new file mode 100644 index 00000000..6cad98b8 --- /dev/null +++ b/whisper/convert.py @@ -0,0 +1,54 @@ +# Copyright © 2023 Apple Inc. + +import argparse +import json +from dataclasses import asdict +from pathlib import Path + +import mlx.core as mx +import numpy as np +from mlx.utils import tree_flatten + +from whisper.load_models import load_torch_model, torch_to_mlx + +MODEL_DTYPES = {"float16", "float32"} + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert Mistral weights to MLX.") + parser.add_argument( + "--torch-name-or-path", + type=str, + default="tiny", + help="The name or path to the PyTorch model.", + ) + parser.add_argument( + "--mlx-path", + type=str, + default="mlx_model", + help="The path to save the MLX model.", + ) + parser.add_argument( + "--dtype", + type=str, + default="float16", + help="The dtype to save the MLX model.", + ) + args = parser.parse_args() + + assert args.dtype in MODEL_DTYPES, f"dtype {args.dtype} not found in {MODEL_DTYPES}" + dtype = getattr(mx, args.dtype) + + model = torch_to_mlx(load_torch_model(args.torch_name_or_path), dtype) + config = asdict(model.dims) + weights = dict(tree_flatten(model.parameters())) + + mlx_path = Path(args.mlx_path) + mlx_path.mkdir(parents=True, exist_ok=True) + + # Save weights + np.savez(str(mlx_path / "weights.npz"), **weights) + + # Save config.json with model_type + with open(mlx_path / "config.json", "w") as f: + config["model_type"] = "whisper" + json.dump(config, f, indent=4) diff --git a/whisper/whisper/load_models.py b/whisper/whisper/load_models.py index 32e62119..0aa29f11 100644 --- a/whisper/whisper/load_models.py +++ b/whisper/whisper/load_models.py @@ -1,14 +1,17 @@ # Copyright © 2023 Apple Inc. +import glob import hashlib +import json import os import urllib import warnings +from pathlib import Path from typing import List import mlx.core as mx import torch -from mlx.utils import tree_map +from mlx.utils import tree_map, tree_unflatten from tqdm import tqdm from . import torch_whisper, whisper @@ -96,7 +99,7 @@ def available_models() -> List[str]: def load_torch_model( - name: str, + name_or_path: str, download_root: str = None, ) -> torch_whisper.Whisper: """ @@ -104,8 +107,8 @@ def load_torch_model( Parameters ---------- - name : str - one of the official model names listed by `whisper.available_models()` + name_or_path : str + one of the official model names listed by `whisper.available_models()` or a local Pytorch checkpoint download_root: str path to download the model files; by default, it uses "~/.cache/whisper" @@ -118,15 +121,14 @@ def load_torch_model( if download_root is None: download_root = os.path.join(os.path.expanduser("~"), ".cache/whisper") - if name in _MODELS: - checkpoint_file = _download(_MODELS[name], download_root) - alignment_heads = _ALIGNMENT_HEADS[name] - else: - raise RuntimeError( - f"Model {name} not found; available models = {available_models()}" - ) + if name_or_path in _MODELS: + name_or_path = _download(_MODELS[name_or_path], download_root) + elif not Path(name_or_path).is_file(): + raise RuntimeError(f"Model {name_or_path} is neither found in {available_models()} nor as a local path") - with open(checkpoint_file, "rb") as fp: + alignment_heads = _ALIGNMENT_HEADS.get(name_or_path) + + with open(name_or_path, "rb") as fp: checkpoint = torch.load(fp) dims = torch_whisper.ModelDimensions(**checkpoint["dims"]) @@ -191,8 +193,48 @@ def torch_to_mlx( def load_model( - name: str, + name_or_path: str, download_root: str = None, dtype: mx.Dtype = mx.float32, ) -> whisper.Whisper: - return torch_to_mlx(load_torch_model(name, download_root), dtype) + if name_or_path in _MODELS: + print(f"[INFO] Loading and converting {name_or_path} model") + return torch_to_mlx(load_torch_model(name_or_path, download_root), dtype) + elif not (glob.glob(f"{name_or_path}/weights*.npz") and glob.glob(f"{name_or_path}/config.json")): + raise ValueError( + f"{name_or_path} not found in {available_models()}. Ensure that weights*.npz and config.json files are" + " present in the specified path" + ) + + model_path = Path(name_or_path) + + unsharded_weights_path = model_path / "weights.npz" + if unsharded_weights_path.is_file(): + print(f"[INFO] Loading model from {unsharded_weights_path}") + weights = mx.load(str(unsharded_weights_path)) + else: + sharded_weights_glob = str(model_path / "weights.*.npz") + weight_files = glob.glob(sharded_weights_glob) + print(f"[INFO] Loading model from {sharded_weights_glob}") + + if len(weight_files) == 0: + raise FileNotFoundError("No weights found in {}".format(model_path)) + + weights = {} + for wf in weight_files: + weights.update(mx.load(wf).items()) + + with open(model_path / "config.json", "r") as f: + config = json.loads(f.read()) + config.pop("model_type", None) + model_args = torch_whisper.ModelDimensions(**config) + + model = whisper.Whisper(model_args, dtype) + + weights = tree_unflatten(list(weights.items())) + weights = tree_map(lambda p: p.astype(dtype), weights) + model.update(weights) + + mx.eval(model.parameters()) + + return model diff --git a/whisper/whisper/transcribe.py b/whisper/whisper/transcribe.py index 67232136..03e72963 100644 --- a/whisper/whisper/transcribe.py +++ b/whisper/whisper/transcribe.py @@ -40,20 +40,20 @@ def _format_timestamp(seconds: float): class ModelHolder: model = None - model_name = None + model_name_or_path = None @classmethod - def get_model(cls, model: str, dtype: mx.Dtype): - if cls.model is None or model != cls.model_name: - cls.model = load_model(model, dtype=dtype) - cls.model_name = model + def get_model(cls, model_name_or_path: str, dtype: mx.Dtype): + if cls.model is None or model_name_or_path != cls.model_name_or_path: + cls.model = load_model(model_name_or_path, dtype=dtype) + cls.model_name_or_path = model_name_or_path return cls.model def transcribe( audio: Union[str, np.ndarray, mx.array], *, - model: str = "tiny", + model_name_or_path: str = "tiny", verbose: Optional[bool] = None, temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), compression_ratio_threshold: Optional[float] = 2.4, @@ -73,8 +73,9 @@ def transcribe( audio: Union[str, np.ndarray, mx.array] The path to the audio file to open, or the audio waveform - model: str - The Whisper model. Can be any of ["tiny", "base", "small", "medium", "large"]. + model_name_or_path: str + The Whisper model. Can be any of ["tiny", "base", "small", "medium", "large"], + or a local folder in MLX format. Default is "tiny". verbose: bool @@ -115,7 +116,7 @@ def transcribe( """ dtype = mx.float16 if decode_options.get("fp16", True) else mx.float32 - model = ModelHolder.get_model(model, dtype) + model = ModelHolder.get_model(model_name_or_path, dtype) # Pad 30-seconds of silence to the input audio, for slicing mel = log_mel_spectrogram(audio, n_mels=model.dims.n_mels, padding=N_SAMPLES)