Add option to load a customized MLX model

Author: bofenghuang
Date: 2023-12-26 19:05:49 +01:00
Parent: 1cdbf9e886
Commit: 43a68ee5e3

3 changed files with 120 additions and 23 deletions

whisper/convert.py (new file, 54 lines)

@@ -0,0 +1,54 @@
# Copyright © 2023 Apple Inc.

import argparse
import json
from dataclasses import asdict
from pathlib import Path

import mlx.core as mx
import numpy as np
from mlx.utils import tree_flatten
from whisper.load_models import load_torch_model, torch_to_mlx

MODEL_DTYPES = {"float16", "float32"}

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert Whisper weights to MLX.")
    parser.add_argument(
        "--torch-name-or-path",
        type=str,
        default="tiny",
        help="The name or path to the PyTorch model.",
    )
    parser.add_argument(
        "--mlx-path",
        type=str,
        default="mlx_model",
        help="The path to save the MLX model.",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="float16",
        help="The dtype to save the MLX model.",
    )
    args = parser.parse_args()

    assert args.dtype in MODEL_DTYPES, f"dtype {args.dtype} not found in {MODEL_DTYPES}"
    dtype = getattr(mx, args.dtype)

    model = torch_to_mlx(load_torch_model(args.torch_name_or_path), dtype)
    config = asdict(model.dims)
    weights = dict(tree_flatten(model.parameters()))

    mlx_path = Path(args.mlx_path)
    mlx_path.mkdir(parents=True, exist_ok=True)

    # Save weights
    np.savez(str(mlx_path / "weights.npz"), **weights)

    # Save config.json with model_type
    with open(mlx_path / "config.json", "w") as f:
        config["model_type"] = "whisper"
        json.dump(config, f, indent=4)
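For reference, a conversion run might look like this (the output directory is illustrative, and the script is assumed to be invoked from the whisper example root):

    python convert.py --torch-name-or-path tiny --mlx-path mlx_models/tiny --dtype float16

This writes weights.npz and config.json into mlx_models/tiny, a folder that the updated load_model below can consume directly.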

whisper/load_models.py

@@ -1,14 +1,17 @@
 # Copyright © 2023 Apple Inc.

+import glob
 import hashlib
+import json
 import os
 import urllib
 import warnings
+from pathlib import Path
 from typing import List

 import mlx.core as mx
 import torch
-from mlx.utils import tree_map
+from mlx.utils import tree_map, tree_unflatten
 from tqdm import tqdm

 from . import torch_whisper, whisper
@@ -96,7 +99,7 @@ def available_models() -> List[str]:
 def load_torch_model(
-    name: str,
+    name_or_path: str,
     download_root: str = None,
 ) -> torch_whisper.Whisper:
     """
@@ -104,8 +107,8 @@ def load_torch_model(
     Parameters
     ----------
-    name : str
-        one of the official model names listed by `whisper.available_models()`
+    name_or_path : str
+        one of the official model names listed by `whisper.available_models()`, or a local PyTorch checkpoint
     download_root: str
         path to download the model files; by default, it uses "~/.cache/whisper"
@@ -118,15 +121,14 @@ def load_torch_model(
     if download_root is None:
         download_root = os.path.join(os.path.expanduser("~"), ".cache/whisper")

-    if name in _MODELS:
-        checkpoint_file = _download(_MODELS[name], download_root)
-        alignment_heads = _ALIGNMENT_HEADS[name]
-    else:
-        raise RuntimeError(
-            f"Model {name} not found; available models = {available_models()}"
-        )
+    if name_or_path in _MODELS:
+        name_or_path = _download(_MODELS[name_or_path], download_root)
+    elif not Path(name_or_path).is_file():
+        raise RuntimeError(f"Model {name_or_path} is neither found in {available_models()} nor as a local path")

-    with open(checkpoint_file, "rb") as fp:
+    alignment_heads = _ALIGNMENT_HEADS.get(name_or_path)
+
+    with open(name_or_path, "rb") as fp:
         checkpoint = torch.load(fp)

     dims = torch_whisper.ModelDimensions(**checkpoint["dims"])
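With this change, load_torch_model accepts either an official model name or a path to a local checkpoint. A minimal usage sketch, with a hypothetical checkpoint path:

    from whisper.load_models import load_torch_model

    model = load_torch_model("tiny")  # official name, downloaded on demand
    model = load_torch_model("checkpoints/finetuned-tiny.pt")  # local PyTorch checkpoint (path assumed)

Note that _ALIGNMENT_HEADS.get() returns None for a local checkpoint, since alignment heads are only registered for the official model names.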
@@ -191,8 +193,48 @@ def torch_to_mlx(
 def load_model(
-    name: str,
+    name_or_path: str,
     download_root: str = None,
     dtype: mx.Dtype = mx.float32,
 ) -> whisper.Whisper:
-    return torch_to_mlx(load_torch_model(name, download_root), dtype)
+    if name_or_path in _MODELS:
+        print(f"[INFO] Loading and converting {name_or_path} model")
+        return torch_to_mlx(load_torch_model(name_or_path, download_root), dtype)
+    elif not (glob.glob(f"{name_or_path}/weights*.npz") and glob.glob(f"{name_or_path}/config.json")):
+        raise ValueError(
+            f"{name_or_path} not found in {available_models()}. Ensure that weights*.npz and config.json files are"
+            " present in the specified path"
+        )
+
+    model_path = Path(name_or_path)
+
+    unsharded_weights_path = model_path / "weights.npz"
+    if unsharded_weights_path.is_file():
+        print(f"[INFO] Loading model from {unsharded_weights_path}")
+        weights = mx.load(str(unsharded_weights_path))
+    else:
+        sharded_weights_glob = str(model_path / "weights.*.npz")
+        weight_files = glob.glob(sharded_weights_glob)
+        print(f"[INFO] Loading model from {sharded_weights_glob}")
+
+        if len(weight_files) == 0:
+            raise FileNotFoundError("No weights found in {}".format(model_path))
+
+        weights = {}
+        for wf in weight_files:
+            weights.update(mx.load(wf).items())
+
+    with open(model_path / "config.json", "r") as f:
+        config = json.loads(f.read())
+        config.pop("model_type", None)
+        model_args = torch_whisper.ModelDimensions(**config)
+
+    model = whisper.Whisper(model_args, dtype)
+
+    weights = tree_unflatten(list(weights.items()))
+    weights = tree_map(lambda p: p.astype(dtype), weights)
+
+    model.update(weights)
+    mx.eval(model.parameters())
+    return model
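A minimal sketch of the two loading paths this introduces; "mlx_model" is just convert.py's default output directory, not a requirement:

    import mlx.core as mx
    from whisper.load_models import load_model

    model = load_model("tiny", dtype=mx.float16)  # official name: convert from PyTorch on the fly
    model = load_model("mlx_model", dtype=mx.float16)  # local folder with weights*.npz and config.json

In the sharded branch (weights.*.npz), all shards are merged into one flat dict before tree_unflatten rebuilds the nested parameter tree.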

whisper/transcribe.py

@@ -40,20 +40,20 @@ def _format_timestamp(seconds: float):
 class ModelHolder:
     model = None
-    model_name = None
+    model_name_or_path = None

     @classmethod
-    def get_model(cls, model: str, dtype: mx.Dtype):
-        if cls.model is None or model != cls.model_name:
-            cls.model = load_model(model, dtype=dtype)
-            cls.model_name = model
+    def get_model(cls, model_name_or_path: str, dtype: mx.Dtype):
+        if cls.model is None or model_name_or_path != cls.model_name_or_path:
+            cls.model = load_model(model_name_or_path, dtype=dtype)
+            cls.model_name_or_path = model_name_or_path
         return cls.model


 def transcribe(
     audio: Union[str, np.ndarray, mx.array],
     *,
-    model: str = "tiny",
+    model_name_or_path: str = "tiny",
     verbose: Optional[bool] = None,
     temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
     compression_ratio_threshold: Optional[float] = 2.4,
@@ -73,8 +73,9 @@ def transcribe(
     audio: Union[str, np.ndarray, mx.array]
         The path to the audio file to open, or the audio waveform

-    model: str
-        The Whisper model. Can be any of ["tiny", "base", "small", "medium", "large"].
+    model_name_or_path: str
+        The Whisper model. Can be any of ["tiny", "base", "small", "medium", "large"],
+        or a local folder in MLX format.
         Default is "tiny".

     verbose: bool
@@ -115,7 +116,7 @@ def transcribe(
""" """
dtype = mx.float16 if decode_options.get("fp16", True) else mx.float32 dtype = mx.float16 if decode_options.get("fp16", True) else mx.float32
model = ModelHolder.get_model(model, dtype) model = ModelHolder.get_model(model_name_or_path, dtype)
# Pad 30-seconds of silence to the input audio, for slicing # Pad 30-seconds of silence to the input audio, for slicing
mel = log_mel_spectrogram(audio, n_mels=model.dims.n_mels, padding=N_SAMPLES) mel = log_mel_spectrogram(audio, n_mels=model.dims.n_mels, padding=N_SAMPLES)
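End to end, a caller passes either form through the renamed keyword. A sketch with a hypothetical audio file, assuming the usual whisper-style result dict with a "text" key:

    from whisper.transcribe import transcribe

    result = transcribe("audio.wav", model_name_or_path="mlx_model")
    print(result["text"])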