Add option to load customized mlx model

This commit is contained in:
bofenghuang 2023-12-26 19:05:49 +01:00
parent 1cdbf9e886
commit 43a68ee5e3
3 changed files with 120 additions and 23 deletions

54
whisper/convert.py Normal file
View File

@ -0,0 +1,54 @@
# Copyright © 2023 Apple Inc.
import argparse
import json
from dataclasses import asdict
from pathlib import Path
import mlx.core as mx
import numpy as np
from mlx.utils import tree_flatten
from whisper.load_models import load_torch_model, torch_to_mlx
MODEL_DTYPES = {"float16", "float32"}
if __name__ == "__main__":
    # Command-line entry point: load a PyTorch Whisper checkpoint, convert it
    # to MLX in the requested dtype, and write `weights.npz` + `config.json`
    # into the output folder.
    parser = argparse.ArgumentParser(description="Convert Whisper weights to MLX.")
    parser.add_argument(
        "--torch-name-or-path",
        type=str,
        default="tiny",
        help="The name or path to the PyTorch model.",
    )
    parser.add_argument(
        "--mlx-path",
        type=str,
        default="mlx_model",
        help="The path to save the MLX model.",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="float16",
        # Let argparse reject unsupported dtypes with a usage error; an
        # `assert` would be stripped under `python -O`.
        choices=sorted(MODEL_DTYPES),
        help="The dtype to save the MLX model.",
    )
    args = parser.parse_args()

    # `args.dtype` is restricted to MODEL_DTYPES, each of which names an
    # attribute of `mlx.core`.
    dtype = getattr(mx, args.dtype)

    model = torch_to_mlx(load_torch_model(args.torch_name_or_path), dtype)
    config = asdict(model.dims)
    weights = dict(tree_flatten(model.parameters()))

    mlx_path = Path(args.mlx_path)
    mlx_path.mkdir(parents=True, exist_ok=True)

    # Save weights
    np.savez(str(mlx_path / "weights.npz"), **weights)

    # Save config.json with model_type
    config["model_type"] = "whisper"
    with open(mlx_path / "config.json", "w") as f:
        json.dump(config, f, indent=4)

View File

@ -1,14 +1,17 @@
# Copyright © 2023 Apple Inc.
import glob
import hashlib
import json
import os
import urllib
import warnings
from pathlib import Path
from typing import List
import mlx.core as mx
import torch
from mlx.utils import tree_map
from mlx.utils import tree_map, tree_unflatten
from tqdm import tqdm
from . import torch_whisper, whisper
@ -96,7 +99,7 @@ def available_models() -> List[str]:
def load_torch_model(
name: str,
name_or_path: str,
download_root: str = None,
) -> torch_whisper.Whisper:
"""
@ -104,8 +107,8 @@ def load_torch_model(
Parameters
----------
name : str
one of the official model names listed by `whisper.available_models()`
name_or_path : str
one of the official model names listed by `whisper.available_models()` or a local Pytorch checkpoint
download_root: str
path to download the model files; by default, it uses "~/.cache/whisper"
@ -118,15 +121,14 @@ def load_torch_model(
if download_root is None:
download_root = os.path.join(os.path.expanduser("~"), ".cache/whisper")
if name in _MODELS:
checkpoint_file = _download(_MODELS[name], download_root)
alignment_heads = _ALIGNMENT_HEADS[name]
else:
raise RuntimeError(
f"Model {name} not found; available models = {available_models()}"
)
if name_or_path in _MODELS:
name_or_path = _download(_MODELS[name_or_path], download_root)
elif not Path(name_or_path).is_file():
raise RuntimeError(f"Model {name_or_path} is neither found in {available_models()} nor as a local path")
with open(checkpoint_file, "rb") as fp:
alignment_heads = _ALIGNMENT_HEADS.get(name_or_path)
with open(name_or_path, "rb") as fp:
checkpoint = torch.load(fp)
dims = torch_whisper.ModelDimensions(**checkpoint["dims"])
@ -191,8 +193,48 @@ def torch_to_mlx(
def load_model(
    name_or_path: str,
    download_root: str = None,
    dtype: mx.Dtype = mx.float32,
) -> whisper.Whisper:
    """
    Load a Whisper model as an MLX module

    Parameters
    ----------
    name_or_path : str
        one of the official model names in `_MODELS`, or a local folder containing
        `config.json` and either `weights.npz` or sharded `weights.*.npz` files
    download_root: str
        forwarded to `load_torch_model` when an official model name is given;
        unused when loading from a local folder
    dtype : mx.Dtype
        dtype the model is built with and its weights are cast to

    Returns
    -------
    model : whisper.Whisper
        The MLX Whisper model instance

    Raises
    ------
    ValueError
        if `name_or_path` is not an official model name and the folder does not
        contain both a `weights*.npz` file and `config.json`
    FileNotFoundError
        if the folder holds neither `weights.npz` nor any `weights.*.npz` shard
    """
    if name_or_path in _MODELS:
        print(f"[INFO] Loading and converting {name_or_path} model")
        return torch_to_mlx(load_torch_model(name_or_path, download_root), dtype)

    model_path = Path(name_or_path)
    config_path = model_path / "config.json"

    # Path.glob only interprets its pattern argument, so a folder whose name
    # contains glob metacharacters (e.g. "[", "*") is still matched literally.
    has_weights = any(model_path.glob("weights*.npz"))
    if not (has_weights and config_path.is_file()):
        raise ValueError(
            f"{name_or_path} not found in {available_models()}. Ensure that weights*.npz and config.json files are"
            " present in the specified path"
        )

    unsharded_weights_path = model_path / "weights.npz"
    if unsharded_weights_path.is_file():
        print(f"[INFO] Loading model from {unsharded_weights_path}")
        weights = mx.load(str(unsharded_weights_path))
    else:
        sharded_weights_glob = str(model_path / "weights.*.npz")
        # Sorted so shards are always merged in the same order.
        weight_files = sorted(model_path.glob("weights.*.npz"))
        print(f"[INFO] Loading model from {sharded_weights_glob}")

        if not weight_files:
            raise FileNotFoundError(f"No weights found in {model_path}")

        weights = {}
        for wf in weight_files:
            weights.update(mx.load(str(wf)))

    with open(config_path, "r") as f:
        config = json.load(f)
    # "model_type" is metadata written next to the dimensions; it is not a
    # ModelDimensions field, so drop it before unpacking.
    config.pop("model_type", None)

    model_args = torch_whisper.ModelDimensions(**config)
    model = whisper.Whisper(model_args, dtype)

    # The saved weights use flat keys; rebuild the nested parameter tree and
    # cast every array to the requested dtype before loading.
    weights = tree_unflatten(list(weights.items()))
    weights = tree_map(lambda p: p.astype(dtype), weights)

    model.update(weights)
    mx.eval(model.parameters())
    return model

View File

@ -40,20 +40,20 @@ def _format_timestamp(seconds: float):
class ModelHolder:
    """
    Class-level cache of the most recently loaded model, so repeated calls
    with the same model and dtype do not reload it.
    """

    # The cached model and the (name_or_path, dtype) it was loaded with.
    model = None
    model_name_or_path = None
    model_dtype = None

    @classmethod
    def get_model(cls, model_name_or_path: str, dtype: mx.Dtype):
        """
        Return the model for `model_name_or_path` in `dtype`, loading it with
        `load_model` only when nothing is cached or the request differs from
        what is cached.
        """
        needs_reload = (
            cls.model is None
            or model_name_or_path != cls.model_name_or_path
            # Without this check, a request for the same model in another
            # dtype would get the model loaded with the previous dtype.
            or dtype != cls.model_dtype
        )
        if needs_reload:
            cls.model = load_model(model_name_or_path, dtype=dtype)
            cls.model_name_or_path = model_name_or_path
            cls.model_dtype = dtype
        return cls.model
def transcribe(
audio: Union[str, np.ndarray, mx.array],
*,
model: str = "tiny",
model_name_or_path: str = "tiny",
verbose: Optional[bool] = None,
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
compression_ratio_threshold: Optional[float] = 2.4,
@ -73,8 +73,9 @@ def transcribe(
audio: Union[str, np.ndarray, mx.array]
The path to the audio file to open, or the audio waveform
model: str
The Whisper model. Can be any of ["tiny", "base", "small", "medium", "large"].
model_name_or_path: str
The Whisper model. Can be any of ["tiny", "base", "small", "medium", "large"],
or a local folder in MLX format.
Default is "tiny".
verbose: bool
@ -115,7 +116,7 @@ def transcribe(
"""
dtype = mx.float16 if decode_options.get("fp16", True) else mx.float32
model = ModelHolder.get_model(model, dtype)
model = ModelHolder.get_model(model_name_or_path, dtype)
# Pad 30-seconds of silence to the input audio, for slicing
mel = log_mel_spectrogram(audio, n_mels=model.dims.n_mels, padding=N_SAMPLES)