Add option to load customized mlx model
whisper/convert.py (new file, 54 lines)
@@ -0,0 +1,54 @@
+# Copyright © 2023 Apple Inc.
+
+import argparse
+import json
+from dataclasses import asdict
+from pathlib import Path
+
+import mlx.core as mx
+import numpy as np
+from mlx.utils import tree_flatten
+
+from whisper.load_models import load_torch_model, torch_to_mlx
+
+MODEL_DTYPES = {"float16", "float32"}
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Convert Whisper weights to MLX.")
+    parser.add_argument(
+        "--torch-name-or-path",
+        type=str,
+        default="tiny",
+        help="The name or path to the PyTorch model.",
+    )
+    parser.add_argument(
+        "--mlx-path",
+        type=str,
+        default="mlx_model",
+        help="The path to save the MLX model.",
+    )
+    parser.add_argument(
+        "--dtype",
+        type=str,
+        default="float16",
+        help="The dtype to save the MLX model.",
+    )
+    args = parser.parse_args()
+
+    assert args.dtype in MODEL_DTYPES, f"dtype {args.dtype} not found in {MODEL_DTYPES}"
+    dtype = getattr(mx, args.dtype)
+
+    model = torch_to_mlx(load_torch_model(args.torch_name_or_path), dtype)
+    config = asdict(model.dims)
+    weights = dict(tree_flatten(model.parameters()))
+
+    mlx_path = Path(args.mlx_path)
+    mlx_path.mkdir(parents=True, exist_ok=True)
+
+    # Save weights
+    np.savez(str(mlx_path / "weights.npz"), **weights)
+
+    # Save config.json with model_type
+    with open(mlx_path / "config.json", "w") as f:
+        config["model_type"] = "whisper"
+        json.dump(config, f, indent=4)
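The script is driven by the three flags above (--torch-name-or-path, --mlx-path, --dtype) and writes two files into the output folder: weights.npz with the flattened parameter tree and config.json with the model dimensions plus a model_type field. A minimal sketch for inspecting that output, assuming the default --mlx-path of mlx_model and that the conversion has already been run:

import json

import numpy as np

# weights.npz holds the parameter names produced by tree_flatten.
weights = np.load("mlx_model/weights.npz")
print(len(weights.files), "arrays, e.g.", weights.files[:3])

# config.json holds the ModelDimensions fields plus "model_type": "whisper".
with open("mlx_model/config.json") as f:
    config = json.load(f)
print(config["model_type"])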
whisper/load_models.py

@@ -1,14 +1,17 @@
 # Copyright © 2023 Apple Inc.

+import glob
 import hashlib
+import json
 import os
 import urllib
 import warnings
+from pathlib import Path
 from typing import List

 import mlx.core as mx
 import torch
-from mlx.utils import tree_map
+from mlx.utils import tree_map, tree_unflatten
 from tqdm import tqdm

 from . import torch_whisper, whisper
@@ -96,7 +99,7 @@ def available_models() -> List[str]:


 def load_torch_model(
-    name: str,
+    name_or_path: str,
     download_root: str = None,
 ) -> torch_whisper.Whisper:
     """
@@ -104,8 +107,8 @@ def load_torch_model(

     Parameters
     ----------
-    name : str
-        one of the official model names listed by `whisper.available_models()`
+    name_or_path : str
+        one of the official model names listed by `whisper.available_models()` or a local Pytorch checkpoint
     download_root: str
         path to download the model files; by default, it uses "~/.cache/whisper"

@@ -118,15 +121,14 @@ def load_torch_model(
     if download_root is None:
         download_root = os.path.join(os.path.expanduser("~"), ".cache/whisper")

-    if name in _MODELS:
-        checkpoint_file = _download(_MODELS[name], download_root)
-        alignment_heads = _ALIGNMENT_HEADS[name]
-    else:
-        raise RuntimeError(
-            f"Model {name} not found; available models = {available_models()}"
-        )
+    if name_or_path in _MODELS:
+        name_or_path = _download(_MODELS[name_or_path], download_root)
+    elif not Path(name_or_path).is_file():
+        raise RuntimeError(f"Model {name_or_path} is neither found in {available_models()} nor as a local path")
+
+    alignment_heads = _ALIGNMENT_HEADS.get(name_or_path)

-    with open(checkpoint_file, "rb") as fp:
+    with open(name_or_path, "rb") as fp:
         checkpoint = torch.load(fp)

     dims = torch_whisper.ModelDimensions(**checkpoint["dims"])
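With this change, load_torch_model accepts either an official model name or a path to a local PyTorch checkpoint. A minimal sketch of both call patterns; the local checkpoint path is hypothetical:

import mlx.core as mx

from whisper.load_models import load_torch_model, torch_to_mlx

# Official name: downloaded to ~/.cache/whisper (or download_root) as before.
torch_model = load_torch_model("tiny")

# Local checkpoint: any Whisper-format .pt file on disk (hypothetical path).
torch_model = load_torch_model("/path/to/finetuned-tiny.pt")

# Either result can then be converted to MLX.
mlx_model = torch_to_mlx(torch_model, mx.float16)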
@@ -191,8 +193,48 @@ def torch_to_mlx(


 def load_model(
-    name: str,
+    name_or_path: str,
     download_root: str = None,
     dtype: mx.Dtype = mx.float32,
 ) -> whisper.Whisper:
-    return torch_to_mlx(load_torch_model(name, download_root), dtype)
+    if name_or_path in _MODELS:
+        print(f"[INFO] Loading and converting {name_or_path} model")
+        return torch_to_mlx(load_torch_model(name_or_path, download_root), dtype)
+    elif not (glob.glob(f"{name_or_path}/weights*.npz") and glob.glob(f"{name_or_path}/config.json")):
+        raise ValueError(
+            f"{name_or_path} not found in {available_models()}. Ensure that weights*.npz and config.json files are"
+            " present in the specified path"
+        )
+
+    model_path = Path(name_or_path)
+
+    unsharded_weights_path = model_path / "weights.npz"
+    if unsharded_weights_path.is_file():
+        print(f"[INFO] Loading model from {unsharded_weights_path}")
+        weights = mx.load(str(unsharded_weights_path))
+    else:
+        sharded_weights_glob = str(model_path / "weights.*.npz")
+        weight_files = glob.glob(sharded_weights_glob)
+        print(f"[INFO] Loading model from {sharded_weights_glob}")
+
+        if len(weight_files) == 0:
+            raise FileNotFoundError("No weights found in {}".format(model_path))
+
+        weights = {}
+        for wf in weight_files:
+            weights.update(mx.load(wf).items())
+
+    with open(model_path / "config.json", "r") as f:
+        config = json.loads(f.read())
+        config.pop("model_type", None)
+        model_args = torch_whisper.ModelDimensions(**config)
+
+    model = whisper.Whisper(model_args, dtype)
+
+    weights = tree_unflatten(list(weights.items()))
+    weights = tree_map(lambda p: p.astype(dtype), weights)
+    model.update(weights)
+
+    mx.eval(model.parameters())
+
+    return model
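With the branch added above, the same entry point now serves both the official model names and converted local folders. A minimal sketch, assuming the mlx_model folder produced by convert.py exists:

import mlx.core as mx

from whisper.load_models import load_model

# Official name: the PyTorch checkpoint is downloaded and converted on the fly.
model = load_model("tiny", dtype=mx.float16)

# Local MLX folder: expects weights.npz (or weights.*.npz shards) and config.json.
model = load_model("mlx_model", dtype=mx.float16)
print(model.dims)  # ModelDimensions rebuilt from config.json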
whisper/transcribe.py

@@ -40,20 +40,20 @@ def _format_timestamp(seconds: float):

 class ModelHolder:
     model = None
-    model_name = None
+    model_name_or_path = None

     @classmethod
-    def get_model(cls, model: str, dtype: mx.Dtype):
-        if cls.model is None or model != cls.model_name:
-            cls.model = load_model(model, dtype=dtype)
-            cls.model_name = model
+    def get_model(cls, model_name_or_path: str, dtype: mx.Dtype):
+        if cls.model is None or model_name_or_path != cls.model_name_or_path:
+            cls.model = load_model(model_name_or_path, dtype=dtype)
+            cls.model_name_or_path = model_name_or_path
         return cls.model


 def transcribe(
     audio: Union[str, np.ndarray, mx.array],
     *,
-    model: str = "tiny",
+    model_name_or_path: str = "tiny",
     verbose: Optional[bool] = None,
     temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
     compression_ratio_threshold: Optional[float] = 2.4,
@@ -73,8 +73,9 @@ def transcribe(
     audio: Union[str, np.ndarray, mx.array]
         The path to the audio file to open, or the audio waveform

-    model: str
-        The Whisper model. Can be any of ["tiny", "base", "small", "medium", "large"].
+    model_name_or_path: str
+        The Whisper model. Can be any of ["tiny", "base", "small", "medium", "large"],
+        or a local folder in MLX format.
         Default is "tiny".

     verbose: bool
@@ -115,7 +116,7 @@ def transcribe(
     """

     dtype = mx.float16 if decode_options.get("fp16", True) else mx.float32
-    model = ModelHolder.get_model(model, dtype)
+    model = ModelHolder.get_model(model_name_or_path, dtype)

     # Pad 30-seconds of silence to the input audio, for slicing
     mel = log_mel_spectrogram(audio, n_mels=model.dims.n_mels, padding=N_SAMPLES)
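End to end, the renamed keyword lets a converted local model be passed straight to transcribe. A minimal sketch, assuming the whisper package from this repo is importable, the mlx_model folder from convert.py exists, and audio.wav is a local recording:

from whisper.transcribe import transcribe

# model_name_or_path accepts an official name ("tiny", ...) or a local MLX folder.
result = transcribe("audio.wav", model_name_or_path="mlx_model", fp16=True)

# As in the reference implementation, the result is expected to include the decoded text.
print(result["text"])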