[Whisper] Load customized MLX model & Quantization (#191)
* Add option to load customized mlx model
* Add quantization
* Apply reviews
* Separate model conversion and loading
* Update test
* Fix benchmark
* Add notes about conversion
* Improve doc
parent 1cdbf9e886
commit 581a5733a1
whisper/README.md

@@ -6,7 +6,7 @@ parameters[^1].
 
 ### Setup
 
-First, install the dependencies.
+First, install the dependencies:
 
 ```
 pip install -r requirements.txt
@@ -19,6 +19,28 @@ Install [`ffmpeg`](https://ffmpeg.org/):
 brew install ffmpeg
 ```
 
+Next, download the Whisper PyTorch checkpoint and convert the weights to the MLX format. For example, to convert the `tiny` model use:
+
+```
+python convert.py --torch-name-or-path tiny --mlx-path mlx_models/tiny
+```
+
+Note you can also convert a local PyTorch checkpoint which is in the original OpenAI format.
+
+To generate a 4-bit quantized model, use `-q`. For a full list of options:
+
+```
+python convert.py --help
+```
+
+By default, the conversion script will make the directory `mlx_models/tiny` and save
+the converted `weights.npz` and `config.json` there.
+
+> [!TIP]
+> Alternatively, you can also download a few converted checkpoints from the
+> [MLX Community](https://huggingface.co/mlx-community) organization on Hugging
+> Face and skip the conversion step.
+
 ### Run
 
 Transcribe audio with:
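For context, the workflow the updated README describes ends in a Python call like the following sketch. The audio path is a placeholder, and it assumes the conversion command above already produced `mlx_models/tiny`:

```python
# Illustrative only: transcribe with a converted checkpoint via the new
# model_path keyword this commit introduces (replacing the old `model` name).
from whisper import transcribe

result = transcribe("speech.wav", model_path="mlx_models/tiny")
print(result["text"])
```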
whisper/benchmark.py

@@ -1,5 +1,7 @@
 # Copyright © 2023 Apple Inc.
 
 import argparse
+import os
+import subprocess
 import sys
 import time
@@ -12,6 +14,12 @@ audio_file = "whisper/assets/ls_test.flac"
 
 def parse_arguments():
     parser = argparse.ArgumentParser(description="Benchmark script.")
+    parser.add_argument(
+        "--mlx-dir",
+        type=str,
+        default="mlx_models",
+        help="The folder of MLX models",
+    )
     parser.add_argument(
         "--all",
         action="store_true",
@@ -57,8 +65,8 @@ def decode(model, mels):
     return decoding.decode(model, mels)
 
 
-def everything(model_name):
-    return transcribe(audio_file, model=model_name)
+def everything(model_path):
+    return transcribe(audio_file, model_path=model_path)
 
 
 if __name__ == "__main__":
@@ -76,6 +84,11 @@ if __name__ == "__main__":
     print(f"\nFeature time {feat_time:.3f}")
 
     for model_name in models:
+        model_path = f"{args.mlx_dir}/{model_name}"
+        if not os.path.exists(model_path):
+            print(f"\nDidn't find the MLX-format {model_name} model in the folder {args.mlx_dir}. Launching conversion")
+            subprocess.run(f"python convert.py --torch-name-or-path {model_name} --mlx-path {model_path}", shell=True)
+
         print(f"\nModel: {model_name.upper()}")
         tokens = mx.array(
             [
@@ -110,12 +123,12 @@ if __name__ == "__main__":
             ],
            mx.int32,
         )[None]
-        model = load_models.load_model(f"{model_name}", dtype=mx.float16)
+        model = load_models.load_model(model_path, dtype=mx.float16)
         mels = feats(model.dims.n_mels)[None].astype(mx.float16)
         model_forward_time = timer(model_forward, model, mels, tokens)
         print(f"Model forward time {model_forward_time:.3f}")
         decode_time = timer(decode, model, mels)
         print(f"Decode time {decode_time:.3f}")
-        everything_time = timer(everything, model_name)
+        everything_time = timer(everything, model_path)
         print(f"Everything time {everything_time:.3f}")
         print(f"\n{'-----' * 10}\n")
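With this change the benchmark is self-bootstrapping: before timing a model it checks `--mlx-dir` and shells out to `convert.py` for anything missing. A condensed sketch of that guard (the helper name is illustrative, not part of the commit):

```python
# Hypothetical helper mirroring benchmark.py's new guard: convert on demand,
# then return the folder that load_models.load_model() expects.
import os
import subprocess

def ensure_converted(model_name: str, mlx_dir: str = "mlx_models") -> str:
    model_path = f"{mlx_dir}/{model_name}"
    if not os.path.exists(model_path):
        subprocess.run(
            f"python convert.py --torch-name-or-path {model_name} "
            f"--mlx-path {model_path}",
            shell=True,
        )
    return model_path
```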
whisper/convert.py (new file, 284 lines)

# Copyright © 2023 Apple Inc.

import argparse
import copy
import hashlib
import json
import os
import urllib
import warnings
from dataclasses import asdict
from pathlib import Path
from typing import List

import mlx.core as mx
import mlx.nn as nn
import numpy as np
import torch
from mlx.utils import tree_flatten, tree_map, tree_unflatten
from tqdm import tqdm

from whisper import torch_whisper
from whisper.whisper import ModelDimensions, Whisper

_VALID_DTYPES = {"float16", "float32"}

_MODELS = {
    "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
    "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
    "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
    "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
    "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
    "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
    "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
    "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
    "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
    "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
    "large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
    "large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
}

# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
_ALIGNMENT_HEADS = {
    "tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
    "tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
    "base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
    "base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
    "small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
    "small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
    "medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
    "medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
    "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
    "large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
    "large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
    "large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
}


def _download(url: str, root: str) -> str:
    os.makedirs(root, exist_ok=True)

    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, os.path.basename(url))

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        with open(download_target, "rb") as f:
            model_bytes = f.read()
        if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
            return download_target
        else:
            warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(
            total=int(source.info().get("Content-Length")),
            ncols=80,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
        ) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    model_bytes = open(download_target, "rb").read()
    if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
        raise RuntimeError(
            "Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model."
        )

    return download_target


def available_models() -> List[str]:
    """Returns the names of available models"""
    return list(_MODELS.keys())


def load_torch_model(
    name_or_path: str,
    download_root: str = None,
) -> torch_whisper.Whisper:
    """
    Load a Whisper ASR model

    Parameters
    ----------
    name_or_path : str
        one of the official model names listed by `whisper.available_models()` or a local PyTorch checkpoint which is in the original OpenAI format
    download_root: str
        path to download the model files; by default, it uses "~/.cache/whisper"

    Returns
    -------
    model : Whisper
        The Whisper ASR model instance
    """

    if download_root is None:
        download_root = os.path.join(os.path.expanduser("~"), ".cache/whisper")

    # todo: accept alignment_heads of local PyTorch checkpoint
    alignment_heads = None
    if name_or_path in _MODELS:
        alignment_heads = _ALIGNMENT_HEADS[name_or_path]
        name_or_path = _download(_MODELS[name_or_path], download_root)
    elif not Path(name_or_path).is_file():
        raise RuntimeError(f"Model {name_or_path} is neither found in {available_models()} nor as a local path")

    with open(name_or_path, "rb") as fp:
        checkpoint = torch.load(fp)

    dims = torch_whisper.ModelDimensions(**checkpoint["dims"])
    model = torch_whisper.Whisper(dims)
    model.load_state_dict(checkpoint["model_state_dict"])

    if alignment_heads is not None:
        model.set_alignment_heads(alignment_heads)

    return model


def convert(model, rules=None):
    params = {}
    if rules is not None and type(model) in rules:
        out = rules[type(model)](model, rules)
        return out
    if isinstance(model, torch.Tensor):
        return mx.array(model.detach().numpy())
    if isinstance(model, torch.nn.ModuleList):
        return [convert(n, rules) for n in model.children()]
    if isinstance(model, torch.nn.Conv1d):
        return {
            "weight": convert(model.weight).transpose(0, 2, 1),
            "bias": convert(model.bias),
        }
    for k, n in model.named_children():
        if k in rules:
            params.update(rules[k](n, rules))
        else:
            params[k] = convert(n, rules)
    for k, p in model.named_parameters(recurse=False):
        params[k] = convert(p)
    return params


def torch_to_mlx(
    torch_model: torch_whisper.Whisper,
    dtype: mx.Dtype = mx.float16,
) -> Whisper:
    def convert_rblock(model, rules):
        children = dict(model.named_children())
        mlp = list(children.pop("mlp").children())
        params = {
            "mlp1": convert(mlp[0], rules),
            "mlp2": convert(mlp[-1], rules),
        }
        for k, n in children.items():
            params[k] = convert(n, rules)
        return params

    rules = {
        torch_whisper.ResidualAttentionBlock: convert_rblock,
    }

    params = convert(torch_model, rules)

    mlx_model = Whisper(torch_model.dims, dtype)
    params = tree_map(lambda p: p.astype(dtype), params)
    mlx_model.update(params)
    return mlx_model


def quantize(weights, config, args):
    quantized_config = copy.deepcopy(config)

    # Load the model:
    model = Whisper(ModelDimensions(**config))
    weights = tree_map(mx.array, weights)
    model.update(tree_unflatten(list(weights.items())))

    # Quantize the model:
    nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)

    # Update the config:
    quantized_config["quantization"] = {
        "group_size": args.q_group_size,
        "bits": args.q_bits,
    }
    quantized_weights = dict(tree_flatten(model.parameters()))

    return quantized_weights, quantized_config


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert Whisper weights to MLX.")
    parser.add_argument(
        "--torch-name-or-path",
        type=str,
        default="tiny",
        help="The name or path to the PyTorch model.",
    )
    parser.add_argument(
        "--mlx-path",
        type=str,
        default="mlx_models/tiny",
        help="The path to save the MLX model.",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="float16",
        help="The dtype to save the MLX model.",
    )
    parser.add_argument(
        "-q",
        "--quantize",
        help="Generate a quantized model.",
        action="store_true",
    )
    parser.add_argument(
        "--q_group_size",
        help="Group size for quantization.",
        type=int,
        default=64,
    )
    parser.add_argument(
        "--q_bits",
        help="Bits per weight for quantization.",
        type=int,
        default=4,
    )
    args = parser.parse_args()

    assert args.dtype in _VALID_DTYPES, f"dtype {args.dtype} not found in {_VALID_DTYPES}"
    dtype = getattr(mx, args.dtype)

    print("[INFO] Loading")
    model = torch_to_mlx(load_torch_model(args.torch_name_or_path), dtype)
    config = asdict(model.dims)
    weights = dict(tree_flatten(model.parameters()))

    if args.quantize:
        print("[INFO] Quantizing")
        weights, config = quantize(weights, config, args)

    mlx_path = Path(args.mlx_path)
    mlx_path.mkdir(parents=True, exist_ok=True)

    # Save weights
    print("[INFO] Saving")
    np.savez(str(mlx_path / "weights.npz"), **weights)

    # Save config.json with model_type
    with open(str(mlx_path / "config.json"), "w") as f:
        config["model_type"] = "whisper"
        json.dump(config, f, indent=4)
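In short, `convert.py` maps the PyTorch module tree to MLX arrays, optionally quantizes the linear layers, and writes `weights.npz` plus `config.json`. A minimal programmatic sketch of the same flow, assuming it runs from the `whisper/` example directory where `convert` is importable:

```python
# Sketch: what `python convert.py --torch-name-or-path tiny` does internally.
from dataclasses import asdict

import mlx.core as mx
from mlx.utils import tree_flatten

from convert import load_torch_model, torch_to_mlx

model = torch_to_mlx(load_torch_model("tiny"), dtype=mx.float16)
config = asdict(model.dims)                       # saved as config.json
weights = dict(tree_flatten(model.parameters()))  # saved as weights.npz
```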
whisper/test.py

@@ -1,22 +1,68 @@
 # Copyright © 2023 Apple Inc.
 
+import json
 import os
 import subprocess
 import unittest
+from dataclasses import asdict
+from pathlib import Path
 
 import mlx.core as mx
 import numpy as np
 import torch
+from mlx.utils import tree_flatten
 
 import whisper
 import whisper.audio as audio
 import whisper.decoding as decoding
 import whisper.load_models as load_models
-import whisper.torch_whisper as torch_whisper
 
+from convert import load_torch_model, quantize, torch_to_mlx
+
+MODEL_NAME = "tiny"
+MLX_FP32_MODEL_PATH = "mlx_models/tiny_fp32"
+MLX_FP16_MODEL_PATH = "mlx_models/tiny_fp16"
+MLX_4BITS_MODEL_PATH = "mlx_models/tiny_quantized_4bits"
 TEST_AUDIO = "whisper/assets/ls_test.flac"
 
 
+def _save_model(save_dir, weights, config):
+    mlx_path = Path(save_dir)
+    mlx_path.mkdir(parents=True, exist_ok=True)
+
+    # Save weights
+    np.savez(str(mlx_path / "weights.npz"), **weights)
+
+    # Save config.json with model_type
+    with open(str(mlx_path / "config.json"), "w") as f:
+        config["model_type"] = "whisper"
+        json.dump(config, f, indent=4)
+
+    config.pop("model_type", None)
+
+
+def load_torch_and_mlx():
+    torch_model = load_torch_model(MODEL_NAME)
+
+    fp32_model = torch_to_mlx(torch_model, dtype=mx.float32)
+    config = asdict(fp32_model.dims)
+    weights = dict(tree_flatten(fp32_model.parameters()))
+    _save_model(MLX_FP32_MODEL_PATH, weights, config)
+
+    fp16_model = torch_to_mlx(torch_model, dtype=mx.float16)
+    config = asdict(fp16_model.dims)
+    weights = dict(tree_flatten(fp16_model.parameters()))
+    _save_model(MLX_FP16_MODEL_PATH, weights, config)
+
+    args = type("", (), {})()
+    args.q_group_size = 64
+    args.q_bits = 4
+    weights, config = quantize(weights, config, args)
+    _save_model(MLX_4BITS_MODEL_PATH, weights, config)
+
+    return torch_model, fp32_model, fp16_model
+
+
 def forward_torch(model, mels, tokens):
     mels = torch.Tensor(mels).to(torch.float32)
     tokens = torch.Tensor(tokens).to(torch.int32)
@@ -35,7 +81,7 @@ def forward_mlx(model, mels, tokens):
 class TestWhisper(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        cls.model = load_models.load_model("tiny", dtype=mx.float32)
+        _, cls.model, _ = load_torch_and_mlx()
         data = audio.load_audio(TEST_AUDIO)
         data = audio.pad_or_trim(data)
         cls.mels = audio.log_mel_spectrogram(data)
@@ -43,7 +89,7 @@ class TestWhisper(unittest.TestCase):
     def test_torch_mlx(self):
         np.random.seed(10)
 
-        torch_model = load_models.load_torch_model("tiny")
+        torch_model = load_torch_model(MODEL_NAME)
         dims = torch_model.dims
 
         mels = np.random.randn(1, dims.n_mels, 3_000)
@@ -51,19 +97,27 @@ class TestWhisper(unittest.TestCase):
 
         torch_logits = forward_torch(torch_model, mels, tokens)
 
-        mlx_model = load_models.torch_to_mlx(torch_model, mx.float32)
-        mlx_logits = forward_mlx(mlx_model, mels, tokens)
+        mlx_logits = forward_mlx(self.model, mels, tokens)
 
         self.assertTrue(np.allclose(torch_logits, mlx_logits, atol=1e-2, rtol=1e-2))
 
     def test_fp16(self):
-        mlx_model = load_models.load_model("tiny", dtype=mx.float16)
+        mlx_model = load_models.load_model(MLX_FP16_MODEL_PATH, mx.float16)
         dims = mlx_model.dims
         mels = mx.array(np.random.randn(1, 3_000, dims.n_mels), mx.float16)
         tokens = mx.array(np.random.randint(0, dims.n_vocab, (1, 20)), mx.int32)
         logits = mlx_model(mels, tokens)
         self.assertEqual(logits.dtype, mx.float16)
 
+    def test_quantized_4bits(self):
+        mlx_model = load_models.load_model(MLX_4BITS_MODEL_PATH, mx.float16)
+        dims = mlx_model.dims
+        mels = mx.array(np.random.randn(1, 3_000, dims.n_mels), mx.float16)
+        tokens = mx.array(np.random.randint(0, dims.n_vocab, (1, 20)), mx.int32)
+        logits = mlx_model(mels, tokens)
+        # Here, we just test if 4-bit models can forward, as the quantized tiny models struggle with accurate transcription
+        self.assertEqual(logits.dtype, mx.float16)
+
     def test_decode_lang(self):
         options = decoding.DecodingOptions(task="lang_id", fp16=False)
         result = decoding.decode(self.model, self.mels, options)
@@ -135,7 +189,7 @@ class TestWhisper(unittest.TestCase):
         self.assertAlmostEqual(result.compression_ratio, 1.2359550561797752)
 
     def test_transcribe(self):
-        result = whisper.transcribe(TEST_AUDIO, fp16=False)
+        result = whisper.transcribe(TEST_AUDIO, model_path=MLX_FP32_MODEL_PATH, fp16=False)
         self.assertEqual(
             result["text"],
             (
@@ -154,7 +208,7 @@ class TestWhisper(unittest.TestCase):
             print("bash path_to_whisper_repo/whisper/assets/download_alice.sh")
             return
 
-        result = whisper.transcribe(audio_file, model_path=MLX_FP32_MODEL_PATH, fp16=False)
+        result = whisper.transcribe(audio_file, model_path=MLX_FP32_MODEL_PATH, fp16=False)
         self.assertEqual(len(result["text"]), 10920)
         self.assertEqual(result["language"], "en")
         self.assertEqual(len(result["segments"]), 77)
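One small idiom worth noting: `load_torch_and_mlx` fakes the argparse namespace that `quantize()` expects with `type("", (), {})()`. An equivalent, more explicit spelling (not used in the commit, shown only for clarity):

```python
# Equivalent stand-in for the parsed CLI args that quantize() reads.
from types import SimpleNamespace

args = SimpleNamespace(q_group_size=64, q_bits=4)
```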
whisper/whisper/load_models.py

@@ -1,198 +1,36 @@
 # Copyright © 2023 Apple Inc.
 
-import hashlib
-import os
-import urllib
-import warnings
-from typing import List
+import json
+from pathlib import Path
 
 import mlx.core as mx
-import torch
-from mlx.utils import tree_map
-from tqdm import tqdm
+import mlx.nn as nn
+from mlx.utils import tree_unflatten
 
-from . import torch_whisper, whisper
+from . import whisper
 
-_MODELS = {
-    "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
-    "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
-    "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
-    "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
-    "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
-    "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
-    "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
-    "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
-    "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
-    "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
-    "large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
-    "large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
-}
-
-# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
-# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
-_ALIGNMENT_HEADS = {
-    "tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
-    "tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
-    "base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
-    "base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
-    "small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
-    "small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
-    "medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
-    "medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
-    "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
-    "large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
-    "large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
-    "large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
-}
-
-
-def _download(url: str, root: str) -> str:
-    os.makedirs(root, exist_ok=True)
-
-    expected_sha256 = url.split("/")[-2]
-    download_target = os.path.join(root, os.path.basename(url))
-
-    if os.path.exists(download_target) and not os.path.isfile(download_target):
-        raise RuntimeError(f"{download_target} exists and is not a regular file")
-
-    if os.path.isfile(download_target):
-        with open(download_target, "rb") as f:
-            model_bytes = f.read()
-        if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
-            return download_target
-        else:
-            warnings.warn(
-                f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
-            )
-
-    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
-        with tqdm(
-            total=int(source.info().get("Content-Length")),
-            ncols=80,
-            unit="iB",
-            unit_scale=True,
-            unit_divisor=1024,
-        ) as loop:
-            while True:
-                buffer = source.read(8192)
-                if not buffer:
-                    break
-
-                output.write(buffer)
-                loop.update(len(buffer))
-
-    model_bytes = open(download_target, "rb").read()
-    if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
-        raise RuntimeError(
-            "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
-        )
-
-    return download_target
-
-
-def available_models() -> List[str]:
-    """Returns the names of available models"""
-    return list(_MODELS.keys())
-
-
-def load_torch_model(
-    name: str,
-    download_root: str = None,
-) -> torch_whisper.Whisper:
-    """
-    Load a Whisper ASR model
-
-    Parameters
-    ----------
-    name : str
-        one of the official model names listed by `whisper.available_models()`
-    download_root: str
-        path to download the model files; by default, it uses "~/.cache/whisper"
-
-    Returns
-    -------
-    model : Whisper
-        The Whisper ASR model instance
-    """
-
-    if download_root is None:
-        download_root = os.path.join(os.path.expanduser("~"), ".cache/whisper")
-
-    if name in _MODELS:
-        checkpoint_file = _download(_MODELS[name], download_root)
-        alignment_heads = _ALIGNMENT_HEADS[name]
-    else:
-        raise RuntimeError(
-            f"Model {name} not found; available models = {available_models()}"
-        )
-
-    with open(checkpoint_file, "rb") as fp:
-        checkpoint = torch.load(fp)
-
-    dims = torch_whisper.ModelDimensions(**checkpoint["dims"])
-    model = torch_whisper.Whisper(dims)
-    model.load_state_dict(checkpoint["model_state_dict"])
-
-    if alignment_heads is not None:
-        model.set_alignment_heads(alignment_heads)
-
-    return model
-
-
-def convert(model, rules=None):
-    params = {}
-    if rules is not None and type(model) in rules:
-        out = rules[type(model)](model, rules)
-        return out
-    if isinstance(model, torch.Tensor):
-        return mx.array(model.detach().numpy())
-    if isinstance(model, torch.nn.ModuleList):
-        return [convert(n, rules) for n in model.children()]
-    if isinstance(model, torch.nn.Conv1d):
-        return {
-            "weight": convert(model.weight).transpose(0, 2, 1),
-            "bias": convert(model.bias),
-        }
-    for k, n in model.named_children():
-        if k in rules:
-            params.update(rules[k](n, rules))
-        else:
-            params[k] = convert(n, rules)
-    for k, p in model.named_parameters(recurse=False):
-        params[k] = convert(p)
-    return params
-
-
-def torch_to_mlx(
-    torch_model: torch_whisper.Whisper,
-    dtype: mx.Dtype = mx.float16,
-) -> whisper.Whisper:
-    def convert_rblock(model, rules):
-        children = dict(model.named_children())
-        mlp = list(children.pop("mlp").children())
-        params = {
-            "mlp1": convert(mlp[0], rules),
-            "mlp2": convert(mlp[-1], rules),
-        }
-        for k, n in children.items():
-            params[k] = convert(n, rules)
-        return params
-
-    rules = {
-        torch_whisper.ResidualAttentionBlock: convert_rblock,
-    }
-
-    params = convert(torch_model, rules)
-
-    mlx_model = whisper.Whisper(torch_model.dims, dtype)
-    params = tree_map(lambda p: p.astype(dtype), params)
-    mlx_model.update(params)
-    return mlx_model
 
 
 def load_model(
-    name: str,
-    download_root: str = None,
+    folder: str,
     dtype: mx.Dtype = mx.float32,
 ) -> whisper.Whisper:
-    return torch_to_mlx(load_torch_model(name, download_root), dtype)
+    model_path = Path(folder)
+
+    with open(str(model_path / "config.json"), "r") as f:
+        config = json.loads(f.read())
+        config.pop("model_type", None)
+        quantization = config.pop("quantization", None)
+
+    model_args = whisper.ModelDimensions(**config)
+
+    weights = mx.load(str(model_path / "weights.npz"))
+    weights = tree_unflatten(list(weights.items()))
+
+    model = whisper.Whisper(model_args, dtype)
+
+    if quantization is not None:
+        nn.QuantizedLinear.quantize_module(model, **quantization)
+
+    model.update(weights)
+    mx.eval(model.parameters())
+    return model
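`load_model` now takes the folder produced by `convert.py` rather than a model name, and re-applies any quantization recorded in `config.json` before loading the weights. A usage sketch, assuming `mlx_models/tiny` was produced by the conversion step:

```python
# Sketch: load a converted (possibly quantized) checkpoint from disk.
import mlx.core as mx

from whisper.load_models import load_model

model = load_model("mlx_models/tiny", dtype=mx.float16)
print(model.dims)  # ModelDimensions parsed from config.json
```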
whisper/whisper/transcribe.py

@@ -40,20 +40,20 @@ def _format_timestamp(seconds: float):
 
 class ModelHolder:
     model = None
-    model_name = None
+    model_path = None
 
     @classmethod
-    def get_model(cls, model: str, dtype: mx.Dtype):
-        if cls.model is None or model != cls.model_name:
-            cls.model = load_model(model, dtype=dtype)
-            cls.model_name = model
+    def get_model(cls, model_path: str, dtype: mx.Dtype):
+        if cls.model is None or model_path != cls.model_path:
+            cls.model = load_model(model_path, dtype=dtype)
+            cls.model_path = model_path
         return cls.model
 
 
 def transcribe(
     audio: Union[str, np.ndarray, mx.array],
     *,
-    model: str = "tiny",
+    model_path: str = "mlx_models/tiny",
     verbose: Optional[bool] = None,
     temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
     compression_ratio_threshold: Optional[float] = 2.4,
@@ -73,9 +73,8 @@ def transcribe(
     audio: Union[str, np.ndarray, mx.array]
         The path to the audio file to open, or the audio waveform
 
-    model: str
-        The Whisper model. Can be any of ["tiny", "base", "small", "medium", "large"].
-        Default is "tiny".
+    model_path: str
+        The path to the Whisper model that has been converted to MLX format.
 
     verbose: bool
         Whether to display the text being decoded to the console. If True, displays all the details,
@@ -115,7 +114,7 @@ def transcribe(
     """
 
     dtype = mx.float16 if decode_options.get("fp16", True) else mx.float32
-    model = ModelHolder.get_model(model, dtype)
+    model = ModelHolder.get_model(model_path, dtype)
 
     # Pad 30-seconds of silence to the input audio, for slicing
     mel = log_mel_spectrogram(audio, n_mels=model.dims.n_mels, padding=N_SAMPLES)