mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-31 11:54:37 +08:00
Add option to load customized mlx model
This commit is contained in:
parent
1cdbf9e886
commit
43a68ee5e3
54
whisper/convert.py
Normal file
54
whisper/convert.py
Normal file
@ -0,0 +1,54 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from dataclasses import asdict
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
from mlx.utils import tree_flatten
|
||||
|
||||
from whisper.load_models import load_torch_model, torch_to_mlx
|
||||
|
||||
MODEL_DTYPES = {"float16", "float32"}
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Convert Mistral weights to MLX.")
|
||||
parser.add_argument(
|
||||
"--torch-name-or-path",
|
||||
type=str,
|
||||
default="tiny",
|
||||
help="The name or path to the PyTorch model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mlx-path",
|
||||
type=str,
|
||||
default="mlx_model",
|
||||
help="The path to save the MLX model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
type=str,
|
||||
default="float16",
|
||||
help="The dtype to save the MLX model.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
assert args.dtype in MODEL_DTYPES, f"dtype {args.dtype} not found in {MODEL_DTYPES}"
|
||||
dtype = getattr(mx, args.dtype)
|
||||
|
||||
model = torch_to_mlx(load_torch_model(args.torch_name_or_path), dtype)
|
||||
config = asdict(model.dims)
|
||||
weights = dict(tree_flatten(model.parameters()))
|
||||
|
||||
mlx_path = Path(args.mlx_path)
|
||||
mlx_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save weights
|
||||
np.savez(str(mlx_path / "weights.npz"), **weights)
|
||||
|
||||
# Save config.json with model_type
|
||||
with open(mlx_path / "config.json", "w") as f:
|
||||
config["model_type"] = "whisper"
|
||||
json.dump(config, f, indent=4)
|
@ -1,14 +1,17 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import glob
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import urllib
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import mlx.core as mx
|
||||
import torch
|
||||
from mlx.utils import tree_map
|
||||
from mlx.utils import tree_map, tree_unflatten
|
||||
from tqdm import tqdm
|
||||
|
||||
from . import torch_whisper, whisper
|
||||
@ -96,7 +99,7 @@ def available_models() -> List[str]:
|
||||
|
||||
|
||||
def load_torch_model(
|
||||
name: str,
|
||||
name_or_path: str,
|
||||
download_root: str = None,
|
||||
) -> torch_whisper.Whisper:
|
||||
"""
|
||||
@ -104,8 +107,8 @@ def load_torch_model(
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
one of the official model names listed by `whisper.available_models()`
|
||||
name_or_path : str
|
||||
one of the official model names listed by `whisper.available_models()` or a local Pytorch checkpoint
|
||||
download_root: str
|
||||
path to download the model files; by default, it uses "~/.cache/whisper"
|
||||
|
||||
@ -118,15 +121,14 @@ def load_torch_model(
|
||||
if download_root is None:
|
||||
download_root = os.path.join(os.path.expanduser("~"), ".cache/whisper")
|
||||
|
||||
if name in _MODELS:
|
||||
checkpoint_file = _download(_MODELS[name], download_root)
|
||||
alignment_heads = _ALIGNMENT_HEADS[name]
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Model {name} not found; available models = {available_models()}"
|
||||
)
|
||||
if name_or_path in _MODELS:
|
||||
name_or_path = _download(_MODELS[name_or_path], download_root)
|
||||
elif not Path(name_or_path).is_file():
|
||||
raise RuntimeError(f"Model {name_or_path} is neither found in {available_models()} nor as a local path")
|
||||
|
||||
with open(checkpoint_file, "rb") as fp:
|
||||
alignment_heads = _ALIGNMENT_HEADS.get(name_or_path)
|
||||
|
||||
with open(name_or_path, "rb") as fp:
|
||||
checkpoint = torch.load(fp)
|
||||
|
||||
dims = torch_whisper.ModelDimensions(**checkpoint["dims"])
|
||||
@ -191,8 +193,48 @@ def torch_to_mlx(
|
||||
|
||||
|
||||
def load_model(
|
||||
name: str,
|
||||
name_or_path: str,
|
||||
download_root: str = None,
|
||||
dtype: mx.Dtype = mx.float32,
|
||||
) -> whisper.Whisper:
|
||||
return torch_to_mlx(load_torch_model(name, download_root), dtype)
|
||||
if name_or_path in _MODELS:
|
||||
print(f"[INFO] Loading and converting {name_or_path} model")
|
||||
return torch_to_mlx(load_torch_model(name_or_path, download_root), dtype)
|
||||
elif not (glob.glob(f"{name_or_path}/weights*.npz") and glob.glob(f"{name_or_path}/config.json")):
|
||||
raise ValueError(
|
||||
f"{name_or_path} not found in {available_models()}. Ensure that weights*.npz and config.json files are"
|
||||
" present in the specified path"
|
||||
)
|
||||
|
||||
model_path = Path(name_or_path)
|
||||
|
||||
unsharded_weights_path = model_path / "weights.npz"
|
||||
if unsharded_weights_path.is_file():
|
||||
print(f"[INFO] Loading model from {unsharded_weights_path}")
|
||||
weights = mx.load(str(unsharded_weights_path))
|
||||
else:
|
||||
sharded_weights_glob = str(model_path / "weights.*.npz")
|
||||
weight_files = glob.glob(sharded_weights_glob)
|
||||
print(f"[INFO] Loading model from {sharded_weights_glob}")
|
||||
|
||||
if len(weight_files) == 0:
|
||||
raise FileNotFoundError("No weights found in {}".format(model_path))
|
||||
|
||||
weights = {}
|
||||
for wf in weight_files:
|
||||
weights.update(mx.load(wf).items())
|
||||
|
||||
with open(model_path / "config.json", "r") as f:
|
||||
config = json.loads(f.read())
|
||||
config.pop("model_type", None)
|
||||
model_args = torch_whisper.ModelDimensions(**config)
|
||||
|
||||
model = whisper.Whisper(model_args, dtype)
|
||||
|
||||
weights = tree_unflatten(list(weights.items()))
|
||||
weights = tree_map(lambda p: p.astype(dtype), weights)
|
||||
model.update(weights)
|
||||
|
||||
mx.eval(model.parameters())
|
||||
|
||||
return model
|
||||
|
@ -40,20 +40,20 @@ def _format_timestamp(seconds: float):
|
||||
|
||||
class ModelHolder:
|
||||
model = None
|
||||
model_name = None
|
||||
model_name_or_path = None
|
||||
|
||||
@classmethod
|
||||
def get_model(cls, model: str, dtype: mx.Dtype):
|
||||
if cls.model is None or model != cls.model_name:
|
||||
cls.model = load_model(model, dtype=dtype)
|
||||
cls.model_name = model
|
||||
def get_model(cls, model_name_or_path: str, dtype: mx.Dtype):
|
||||
if cls.model is None or model_name_or_path != cls.model_name_or_path:
|
||||
cls.model = load_model(model_name_or_path, dtype=dtype)
|
||||
cls.model_name_or_path = model_name_or_path
|
||||
return cls.model
|
||||
|
||||
|
||||
def transcribe(
|
||||
audio: Union[str, np.ndarray, mx.array],
|
||||
*,
|
||||
model: str = "tiny",
|
||||
model_name_or_path: str = "tiny",
|
||||
verbose: Optional[bool] = None,
|
||||
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
|
||||
compression_ratio_threshold: Optional[float] = 2.4,
|
||||
@ -73,8 +73,9 @@ def transcribe(
|
||||
audio: Union[str, np.ndarray, mx.array]
|
||||
The path to the audio file to open, or the audio waveform
|
||||
|
||||
model: str
|
||||
The Whisper model. Can be any of ["tiny", "base", "small", "medium", "large"].
|
||||
model_name_or_path: str
|
||||
The Whisper model. Can be any of ["tiny", "base", "small", "medium", "large"],
|
||||
or a local folder in MLX format.
|
||||
Default is "tiny".
|
||||
|
||||
verbose: bool
|
||||
@ -115,7 +116,7 @@ def transcribe(
|
||||
"""
|
||||
|
||||
dtype = mx.float16 if decode_options.get("fp16", True) else mx.float32
|
||||
model = ModelHolder.get_model(model, dtype)
|
||||
model = ModelHolder.get_model(model_name_or_path, dtype)
|
||||
|
||||
# Pad 30-seconds of silence to the input audio, for slicing
|
||||
mel = log_mel_spectrogram(audio, n_mels=model.dims.n_mels, padding=N_SAMPLES)
|
||||
|
Loading…
Reference in New Issue
Block a user