[Whisper] Load customized MLX model & Quantization (#191)

* Add option to load customized mlx model

* Add quantization

* Apply reviews

* Separate model conversion and loading

* Update test

* Fix benchmark

* Add notes about conversion

* Improve doc
bofeng huang 2023-12-29 19:22:15 +01:00 committed by GitHub
parent 1cdbf9e886
commit 581a5733a1
6 changed files with 421 additions and 211 deletions

whisper/README.md

@@ -6,7 +6,7 @@ parameters[^1].
### Setup
-First, install the dependencies.
+First, install the dependencies:
```
pip install -r requirements.txt
@@ -19,6 +19,28 @@ Install [`ffmpeg`](https://ffmpeg.org/):
brew install ffmpeg
```
Next, download the Whisper PyTorch checkpoint and convert the weights to the MLX format. For example, to convert the `tiny` model use:
```
python convert.py --torch-name-or-path tiny --mlx-path mlx_models/tiny
```
Note that you can also convert a local PyTorch checkpoint in the original OpenAI format.
To generate a 4-bit quantized model, use `-q`. For a full list of options:
```
python convert.py --help
```
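For example, to produce a 4-bit quantized `tiny` model (the output path here is only an illustration):
```
python convert.py --torch-name-or-path tiny --mlx-path mlx_models/tiny_q4 -q
```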
By default, the conversion script will make the directory `mlx_models/tiny` and save
the converted `weights.npz` and `config.json` there.
> [!TIP]
> Alternatively, you can download a few converted checkpoints from the
> [MLX Community](https://huggingface.co/mlx-community) organization on Hugging
> Face and skip the conversion step.
### Run
Transcribe audio with:

whisper/benchmark.py

@@ -1,5 +1,7 @@
# Copyright © 2023 Apple Inc.
import argparse
import os
import subprocess
import sys
import time
@@ -12,6 +14,12 @@ audio_file = "whisper/assets/ls_test.flac"
def parse_arguments():
parser = argparse.ArgumentParser(description="Benchmark script.")
parser.add_argument(
"--mlx-dir",
type=str,
default="mlx_models",
help="The folder of MLX models",
)
parser.add_argument(
"--all",
action="store_true",
@@ -57,8 +65,8 @@ def decode(model, mels):
return decoding.decode(model, mels)
-def everything(model_name):
-return transcribe(audio_file, model=model_name)
+def everything(model_path):
+return transcribe(audio_file, model_path=model_path)
if __name__ == "__main__":
@@ -76,6 +84,11 @@ if __name__ == "__main__":
print(f"\nFeature time {feat_time:.3f}")
for model_name in models:
model_path = f"{args.mlx_dir}/{model_name}"
if not os.path.exists(model_path):
print(f"\nDidn't find the MLX-format {model_name} model in the folder {args.mlx_dir}. Lauching conversion")
subprocess.run(f"python convert.py --torch-name-or-path {model_name} --mlx-path {model_path}", shell=True)
print(f"\nModel: {model_name.upper()}")
tokens = mx.array(
[
@@ -110,12 +123,12 @@
],
mx.int32,
)[None]
model = load_models.load_model(f"{model_name}", dtype=mx.float16)
model = load_models.load_model(model_path, dtype=mx.float16)
mels = feats(model.dims.n_mels)[None].astype(mx.float16)
model_forward_time = timer(model_forward, model, mels, tokens)
print(f"Model forward time {model_forward_time:.3f}")
decode_time = timer(decode, model, mels)
print(f"Decode time {decode_time:.3f}")
-everything_time = timer(everything, model_name)
+everything_time = timer(everything, model_path)
print(f"Everything time {everything_time:.3f}")
print(f"\n{'-----' * 10}\n")

whisper/convert.py (new file, 284 lines)

@@ -0,0 +1,284 @@
# Copyright © 2023 Apple Inc.
import argparse
import copy
import hashlib
import json
import os
import urllib
import warnings
from dataclasses import asdict
from pathlib import Path
from typing import List
import mlx.core as mx
import mlx.nn as nn
import numpy as np
import torch
from mlx.utils import tree_flatten, tree_map, tree_unflatten
from tqdm import tqdm
from whisper import torch_whisper
from whisper.whisper import ModelDimensions, Whisper
_VALID_DTYPES = {"float16", "float32"}
_MODELS = {
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
"base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
"base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
"small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
"small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
"medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
"medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
"large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
"large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
"large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
}
# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
_ALIGNMENT_HEADS = {
"tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
"tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
"base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
"base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
"small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
"small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
"medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
"medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
"large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
"large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
"large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
"large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
}
def _download(url: str, root: str) -> str:
os.makedirs(root, exist_ok=True)
expected_sha256 = url.split("/")[-2]
download_target = os.path.join(root, os.path.basename(url))
if os.path.exists(download_target) and not os.path.isfile(download_target):
raise RuntimeError(f"{download_target} exists and is not a regular file")
if os.path.isfile(download_target):
with open(download_target, "rb") as f:
model_bytes = f.read()
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
return download_target
else:
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
with tqdm(
total=int(source.info().get("Content-Length")),
ncols=80,
unit="iB",
unit_scale=True,
unit_divisor=1024,
) as loop:
while True:
buffer = source.read(8192)
if not buffer:
break
output.write(buffer)
loop.update(len(buffer))
model_bytes = open(download_target, "rb").read()
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
raise RuntimeError(
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
)
return download_target
def available_models() -> List[str]:
"""Returns the names of available models"""
return list(_MODELS.keys())
def load_torch_model(
name_or_path: str,
download_root: str = None,
) -> torch_whisper.Whisper:
"""
Load a Whisper ASR model
Parameters
----------
name_or_path : str
one of the official model names listed by `whisper.available_models()`, or a path to a local PyTorch checkpoint in the original OpenAI format
download_root: str
path to download the model files; by default, it uses "~/.cache/whisper"
Returns
-------
model : Whisper
The Whisper ASR model instance
"""
if download_root is None:
download_root = os.path.join(os.path.expanduser("~"), ".cache/whisper")
# TODO: accept alignment_heads for local PyTorch checkpoints
alignment_heads = None
if name_or_path in _MODELS:
alignment_heads = _ALIGNMENT_HEADS[name_or_path]
name_or_path = _download(_MODELS[name_or_path], download_root)
elif not Path(name_or_path).is_file():
raise RuntimeError(f"Model {name_or_path} is neither found in {available_models()} nor as a local path")
with open(name_or_path, "rb") as fp:
checkpoint = torch.load(fp)
dims = torch_whisper.ModelDimensions(**checkpoint["dims"])
model = torch_whisper.Whisper(dims)
model.load_state_dict(checkpoint["model_state_dict"])
if alignment_heads is not None:
model.set_alignment_heads(alignment_heads)
return model
def convert(model, rules=None):
params = {}
if rules is not None and type(model) in rules:
out = rules[type(model)](model, rules)
return out
if isinstance(model, torch.Tensor):
return mx.array(model.detach().numpy())
if isinstance(model, torch.nn.ModuleList):
return [convert(n, rules) for n in model.children()]
if isinstance(model, torch.nn.Conv1d):
return {
"weight": convert(model.weight).transpose(0, 2, 1),
"bias": convert(model.bias),
}
for k, n in model.named_children():
if k in rules:
params.update(rules[k](n, rules))
else:
params[k] = convert(n, rules)
for k, p in model.named_parameters(recurse=False):
params[k] = convert(p)
return params
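As a quick illustration of the recursion (a sketch, not part of the commit): applied to a leaf module with no children and no matching rule, `convert` falls through to `named_parameters` and returns a dict of MLX arrays:
```
import torch
from convert import convert

linear = torch.nn.Linear(4, 8)
# No rule matches and there are no child modules, so each torch.Tensor
# parameter is converted to an mx.array.
params = convert(linear, rules={})
print(sorted(params))  # ['bias', 'weight']
```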
def torch_to_mlx(
torch_model: torch_whisper.Whisper,
dtype: mx.Dtype = mx.float16,
) -> Whisper:
def convert_rblock(model, rules):
children = dict(model.named_children())
mlp = list(children.pop("mlp").children())
params = {
"mlp1": convert(mlp[0], rules),
"mlp2": convert(mlp[-1], rules),
}
for k, n in children.items():
params[k] = convert(n, rules)
return params
rules = {
torch_whisper.ResidualAttentionBlock: convert_rblock,
}
params = convert(torch_model, rules)
mlx_model = Whisper(torch_model.dims, dtype)
params = tree_map(lambda p: p.astype(dtype), params)
mlx_model.update(params)
return mlx_model
def quantize(weights, config, args):
quantized_config = copy.deepcopy(config)
# Load the model:
model = Whisper(ModelDimensions(**config))
weights = tree_map(mx.array, weights)
model.update(tree_unflatten(list(weights.items())))
# Quantize the model:
nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
# Update the config:
quantized_config["quantization"] = {
"group_size": args.q_group_size,
"bits": args.q_bits,
}
quantized_weights = dict(tree_flatten(model.parameters()))
return quantized_weights, quantized_config
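The test suite calls this helper with a minimal stand-in for the argparse namespace; a sketch of the same pattern (illustration, not part of the commit), assuming `weights` and `config` come from a converted model as in `__main__` below:
```
from types import SimpleNamespace

# quantize only reads args.q_group_size and args.q_bits
args = SimpleNamespace(q_group_size=64, q_bits=4)
weights, config = quantize(weights, config, args)
```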
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert Whisper weights to MLX.")
parser.add_argument(
"--torch-name-or-path",
type=str,
default="tiny",
help="The name or path to the PyTorch model.",
)
parser.add_argument(
"--mlx-path",
type=str,
default="mlx_models/tiny",
help="The path to save the MLX model.",
)
parser.add_argument(
"--dtype",
type=str,
default="float16",
help="The dtype to save the MLX model.",
)
parser.add_argument(
"-q",
"--quantize",
help="Generate a quantized model.",
action="store_true",
)
parser.add_argument(
"--q_group_size",
help="Group size for quantization.",
type=int,
default=64,
)
parser.add_argument(
"--q_bits",
help="Bits per weight for quantization.",
type=int,
default=4,
)
args = parser.parse_args()
assert args.dtype in _VALID_DTYPES, f"dtype {args.dtype} not found in {_VALID_DTYPES}"
dtype = getattr(mx, args.dtype)
print("[INFO] Loading")
model = torch_to_mlx(load_torch_model(args.torch_name_or_path), dtype)
config = asdict(model.dims)
weights = dict(tree_flatten(model.parameters()))
if args.quantize:
print("[INFO] Quantizing")
weights, config = quantize(weights, config, args)
mlx_path = Path(args.mlx_path)
mlx_path.mkdir(parents=True, exist_ok=True)
# Save weights
print("[INFO] Saving")
np.savez(str(mlx_path / "weights.npz"), **weights)
# Save config.json with model_type
with open(str(mlx_path / "config.json"), "w") as f:
config["model_type"] = "whisper"
json.dump(config, f, indent=4)

whisper/test.py

@@ -1,22 +1,68 @@
# Copyright © 2023 Apple Inc.
import json
import os
import subprocess
import unittest
from dataclasses import asdict
from pathlib import Path
import mlx.core as mx
import numpy as np
import torch
from mlx.utils import tree_flatten
import whisper
import whisper.audio as audio
import whisper.decoding as decoding
import whisper.load_models as load_models
import whisper.torch_whisper as torch_whisper
from convert import load_torch_model, quantize, torch_to_mlx
MODEL_NAME = "tiny"
MLX_FP32_MODEL_PATH = "mlx_models/tiny_fp32"
MLX_FP16_MODEL_PATH = "mlx_models/tiny_fp16"
MLX_4BITS_MODEL_PATH = "mlx_models/tiny_quantized_4bits"
TEST_AUDIO = "whisper/assets/ls_test.flac"
def _save_model(save_dir, weights, config):
mlx_path = Path(save_dir)
mlx_path.mkdir(parents=True, exist_ok=True)
# Save weights
np.savez(str(mlx_path / "weights.npz"), **weights)
# Save config.json with model_type
with open(str(mlx_path / "config.json"), "w") as f:
config["model_type"] = "whisper"
json.dump(config, f, indent=4)
config.pop("model_type", None)
def load_torch_and_mlx():
torch_model = load_torch_model(MODEL_NAME)
fp32_model = torch_to_mlx(torch_model, dtype=mx.float32)
config = asdict(fp32_model.dims)
weights = dict(tree_flatten(fp32_model.parameters()))
_save_model(MLX_FP32_MODEL_PATH, weights, config)
fp16_model = torch_to_mlx(torch_model, dtype=mx.float16)
config = asdict(fp16_model.dims)
weights = dict(tree_flatten(fp16_model.parameters()))
_save_model(MLX_FP16_MODEL_PATH, weights, config)
args = type("", (), {})()
args.q_group_size = 64
args.q_bits = 4
weights, config = quantize(weights, config, args)
_save_model(MLX_4BITS_MODEL_PATH, weights, config)
return torch_model, fp32_model, fp16_model
def forward_torch(model, mels, tokens):
mels = torch.Tensor(mels).to(torch.float32)
tokens = torch.Tensor(tokens).to(torch.int32)
@@ -35,7 +81,7 @@ def forward_mlx(model, mels, tokens):
class TestWhisper(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = load_models.load_model("tiny", dtype=mx.float32)
_, cls.model, _ = load_torch_and_mlx()
data = audio.load_audio(TEST_AUDIO)
data = audio.pad_or_trim(data)
cls.mels = audio.log_mel_spectrogram(data)
@@ -43,7 +89,7 @@ class TestWhisper(unittest.TestCase):
def test_torch_mlx(self):
np.random.seed(10)
torch_model = load_models.load_torch_model("tiny")
torch_model = load_torch_model(MODEL_NAME)
dims = torch_model.dims
mels = np.random.randn(1, dims.n_mels, 3_000)
@@ -51,19 +97,27 @@ class TestWhisper(unittest.TestCase):
torch_logits = forward_torch(torch_model, mels, tokens)
-mlx_model = load_models.torch_to_mlx(torch_model, mx.float32)
-mlx_logits = forward_mlx(mlx_model, mels, tokens)
+mlx_logits = forward_mlx(self.model, mels, tokens)
self.assertTrue(np.allclose(torch_logits, mlx_logits, atol=1e-2, rtol=1e-2))
def test_fp16(self):
mlx_model = load_models.load_model("tiny", dtype=mx.float16)
mlx_model = load_models.load_model(MLX_FP16_MODEL_PATH, mx.float16)
dims = mlx_model.dims
mels = mx.array(np.random.randn(1, 3_000, dims.n_mels), mx.float16)
tokens = mx.array(np.random.randint(0, dims.n_vocab, (1, 20)), mx.int32)
logits = mlx_model(mels, tokens)
self.assertEqual(logits.dtype, mx.float16)
def test_quantized_4bits(self):
mlx_model = load_models.load_model(MLX_4BITS_MODEL_PATH, mx.float16)
dims = mlx_model.dims
mels = mx.array(np.random.randn(1, 3_000, dims.n_mels), mx.float16)
tokens = mx.array(np.random.randint(0, dims.n_vocab, (1, 20)), mx.int32)
logits = mlx_model(mels, tokens)
# Only check that the 4-bit model can run a forward pass; the quantized tiny model is not accurate enough for transcription tests
self.assertEqual(logits.dtype, mx.float16)
def test_decode_lang(self):
options = decoding.DecodingOptions(task="lang_id", fp16=False)
result = decoding.decode(self.model, self.mels, options)
@@ -135,7 +189,7 @@ class TestWhisper(unittest.TestCase):
self.assertAlmostEqual(result.compression_ratio, 1.2359550561797752)
def test_transcribe(self):
-result = whisper.transcribe(TEST_AUDIO, fp16=False)
+result = whisper.transcribe(TEST_AUDIO, model_path=MLX_FP32_MODEL_PATH, fp16=False)
self.assertEqual(
result["text"],
(
@@ -154,7 +208,7 @@ class TestWhisper(unittest.TestCase):
print("bash path_to_whisper_repo/whisper/assets/download_alice.sh")
return
-result = whisper.transcribe(audio_file, fp16=False)
+result = whisper.transcribe(audio_file, model_path=MLX_FP32_MODEL_PATH, fp16=False)
self.assertEqual(len(result["text"]), 10920)
self.assertEqual(result["language"], "en")
self.assertEqual(len(result["segments"]), 77)

whisper/whisper/load_models.py

@@ -1,198 +1,36 @@
# Copyright © 2023 Apple Inc.
import hashlib
import os
import urllib
import warnings
from typing import List
import json
from pathlib import Path
import mlx.core as mx
import torch
from mlx.utils import tree_map
from tqdm import tqdm
import mlx.nn as nn
from mlx.utils import tree_unflatten
from . import torch_whisper, whisper
_MODELS = {
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
"base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
"base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
"small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
"small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
"medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
"medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
"large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
"large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
"large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
}
# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
_ALIGNMENT_HEADS = {
"tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
"tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
"base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
"base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
"small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
"small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
"medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
"medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
"large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
"large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
"large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
"large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
}
def _download(url: str, root: str) -> str:
os.makedirs(root, exist_ok=True)
expected_sha256 = url.split("/")[-2]
download_target = os.path.join(root, os.path.basename(url))
if os.path.exists(download_target) and not os.path.isfile(download_target):
raise RuntimeError(f"{download_target} exists and is not a regular file")
if os.path.isfile(download_target):
with open(download_target, "rb") as f:
model_bytes = f.read()
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
return download_target
else:
warnings.warn(
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
)
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
with tqdm(
total=int(source.info().get("Content-Length")),
ncols=80,
unit="iB",
unit_scale=True,
unit_divisor=1024,
) as loop:
while True:
buffer = source.read(8192)
if not buffer:
break
output.write(buffer)
loop.update(len(buffer))
model_bytes = open(download_target, "rb").read()
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
raise RuntimeError(
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
)
return download_target
def available_models() -> List[str]:
"""Returns the names of available models"""
return list(_MODELS.keys())
def load_torch_model(
name: str,
download_root: str = None,
) -> torch_whisper.Whisper:
"""
Load a Whisper ASR model
Parameters
----------
name : str
one of the official model names listed by `whisper.available_models()`
download_root: str
path to download the model files; by default, it uses "~/.cache/whisper"
Returns
-------
model : Whisper
The Whisper ASR model instance
"""
if download_root is None:
download_root = os.path.join(os.path.expanduser("~"), ".cache/whisper")
if name in _MODELS:
checkpoint_file = _download(_MODELS[name], download_root)
alignment_heads = _ALIGNMENT_HEADS[name]
else:
raise RuntimeError(
f"Model {name} not found; available models = {available_models()}"
)
with open(checkpoint_file, "rb") as fp:
checkpoint = torch.load(fp)
dims = torch_whisper.ModelDimensions(**checkpoint["dims"])
model = torch_whisper.Whisper(dims)
model.load_state_dict(checkpoint["model_state_dict"])
if alignment_heads is not None:
model.set_alignment_heads(alignment_heads)
return model
def convert(model, rules=None):
params = {}
if rules is not None and type(model) in rules:
out = rules[type(model)](model, rules)
return out
if isinstance(model, torch.Tensor):
return mx.array(model.detach().numpy())
if isinstance(model, torch.nn.ModuleList):
return [convert(n, rules) for n in model.children()]
if isinstance(model, torch.nn.Conv1d):
return {
"weight": convert(model.weight).transpose(0, 2, 1),
"bias": convert(model.bias),
}
for k, n in model.named_children():
if k in rules:
params.update(rules[k](n, rules))
else:
params[k] = convert(n, rules)
for k, p in model.named_parameters(recurse=False):
params[k] = convert(p)
return params
def torch_to_mlx(
torch_model: torch_whisper.Whisper,
dtype: mx.Dtype = mx.float16,
) -> whisper.Whisper:
def convert_rblock(model, rules):
children = dict(model.named_children())
mlp = list(children.pop("mlp").children())
params = {
"mlp1": convert(mlp[0], rules),
"mlp2": convert(mlp[-1], rules),
}
for k, n in children.items():
params[k] = convert(n, rules)
return params
rules = {
torch_whisper.ResidualAttentionBlock: convert_rblock,
}
params = convert(torch_model, rules)
mlx_model = whisper.Whisper(torch_model.dims, dtype)
params = tree_map(lambda p: p.astype(dtype), params)
mlx_model.update(params)
return mlx_model
from . import whisper
def load_model(
name: str,
download_root: str = None,
folder: str,
dtype: mx.Dtype = mx.float32,
) -> whisper.Whisper:
return torch_to_mlx(load_torch_model(name, download_root), dtype)
model_path = Path(folder)
with open(str(model_path / "config.json"), "r") as f:
config = json.loads(f.read())
config.pop("model_type", None)
quantization = config.pop("quantization", None)
model_args = whisper.ModelDimensions(**config)
weights = mx.load(str(model_path / "weights.npz"))
weights = tree_unflatten(list(weights.items()))
model = whisper.Whisper(model_args, dtype)
if quantization is not None:
nn.QuantizedLinear.quantize_module(model, **quantization)
model.update(weights)
mx.eval(model.parameters())
return model
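End to end, a folder produced by `convert.py` can be loaded directly or through the high-level API; a minimal sketch (not part of the diff), assuming the `tiny` model was converted to `mlx_models/tiny` as in the README:
```
import mlx.core as mx

import whisper
import whisper.load_models as load_models

# Load a previously converted (optionally quantized) model directly...
model = load_models.load_model("mlx_models/tiny", dtype=mx.float16)

# ...or let the high-level API load and cache it from the same folder.
result = whisper.transcribe("whisper/assets/ls_test.flac", model_path="mlx_models/tiny")
print(result["text"])
```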

whisper/whisper/transcribe.py

@@ -40,20 +40,20 @@ def _format_timestamp(seconds: float):
class ModelHolder:
model = None
-model_name = None
+model_path = None
@classmethod
-def get_model(cls, model: str, dtype: mx.Dtype):
-if cls.model is None or model != cls.model_name:
-cls.model = load_model(model, dtype=dtype)
-cls.model_name = model
+def get_model(cls, model_path: str, dtype: mx.Dtype):
+if cls.model is None or model_path != cls.model_path:
+cls.model = load_model(model_path, dtype=dtype)
+cls.model_path = model_path
return cls.model
def transcribe(
audio: Union[str, np.ndarray, mx.array],
*,
model: str = "tiny",
model_path: str = "mlx_models/tiny",
verbose: Optional[bool] = None,
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
compression_ratio_threshold: Optional[float] = 2.4,
@@ -73,9 +73,8 @@ def transcribe(
audio: Union[str, np.ndarray, mx.array]
The path to the audio file to open, or the audio waveform
-model: str
-The Whisper model. Can be any of ["tiny", "base", "small", "medium", "large"].
-Default is "tiny".
+model_path: str
+The path to the Whisper model that has been converted to MLX format.
verbose: bool
Whether to display the text being decoded to the console. If True, displays all the details,
@@ -115,7 +114,7 @@ """
"""
dtype = mx.float16 if decode_options.get("fp16", True) else mx.float32
-model = ModelHolder.get_model(model, dtype)
+model = ModelHolder.get_model(model_path, dtype)
# Pad 30-seconds of silence to the input audio, for slicing
mel = log_mel_spectrogram(audio, n_mels=model.dims.n_mels, padding=N_SAMPLES)
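One consequence of the `ModelHolder` cache above: consecutive `transcribe` calls with the same `model_path` reuse the already-loaded model instead of reloading it from disk; a sketch (audio file names are hypothetical):
```
import whisper

# The first call loads the model; the second reuses ModelHolder's cache
# because model_path is unchanged.
first = whisper.transcribe("clip1.flac", model_path="mlx_models/tiny")
second = whisper.transcribe("clip2.flac", model_path="mlx_models/tiny")
```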