diff --git a/whisper/README.md b/whisper/README.md
index 7df1382f..50fc0764 100644
--- a/whisper/README.md
+++ b/whisper/README.md
@@ -6,7 +6,7 @@ parameters[^1].
 ### Setup
 
-First, install the dependencies.
+First, install the dependencies:
 
 ```
 pip install -r requirements.txt
 ```
@@ -19,6 +19,28 @@ Install [`ffmpeg`](https://ffmpeg.org/):
 ```
 brew install ffmpeg
 ```
 
+Next, download the Whisper PyTorch checkpoint and convert the weights to the MLX format. For example, to convert the `tiny` model, use:
+
+```
+python convert.py --torch-name-or-path tiny --mlx-path mlx_models/tiny
+```
+
+Note that you can also convert a local PyTorch checkpoint that is in the original OpenAI format.
+
+To generate a 4-bit quantized model, use `-q`. For a full list of options:
+
+```
+python convert.py --help
+```
+
+By default, the conversion script will create the directory `mlx_models/tiny` and save
+the converted `weights.npz` and `config.json` there.
+
+> [!TIP]
+> Alternatively, you can also download a few converted checkpoints from the
+> [MLX Community](https://huggingface.co/mlx-community) organization on Hugging
+> Face and skip the conversion step.
+
 ### Run
 
 Transcribe audio with:
diff --git a/whisper/benchmark.py b/whisper/benchmark.py
index 2b4c237f..2e69ec45 100644
--- a/whisper/benchmark.py
+++ b/whisper/benchmark.py
@@ -1,5 +1,7 @@
 # Copyright © 2023 Apple Inc.
 
 import argparse
+import os
+import subprocess
 import sys
 import time
@@ -12,6 +14,12 @@ audio_file = "whisper/assets/ls_test.flac"
 
 def parse_arguments():
     parser = argparse.ArgumentParser(description="Benchmark script.")
+    parser.add_argument(
+        "--mlx-dir",
+        type=str,
+        default="mlx_models",
+        help="The folder of MLX models",
+    )
     parser.add_argument(
         "--all",
         action="store_true",
@@ -57,8 +65,8 @@ def decode(model, mels):
     return decoding.decode(model, mels)
 
 
-def everything(model_name):
-    return transcribe(audio_file, model=model_name)
+def everything(model_path):
+    return transcribe(audio_file, model_path=model_path)
 
 
 if __name__ == "__main__":
@@ -76,6 +84,11 @@ if __name__ == "__main__":
     print(f"\nFeature time {feat_time:.3f}")
 
     for model_name in models:
+        model_path = f"{args.mlx_dir}/{model_name}"
+        if not os.path.exists(model_path):
+            print(f"\nDidn't find the MLX-format {model_name} model in the folder {args.mlx_dir}. Launching conversion.")
+            subprocess.run(f"python convert.py --torch-name-or-path {model_name} --mlx-path {model_path}", shell=True)
+
         print(f"\nModel: {model_name.upper()}")
         tokens = mx.array(
             [
@@ -110,12 +123,12 @@
             ],
             mx.int32,
         )[None]
-        model = load_models.load_model(f"{model_name}", dtype=mx.float16)
+        model = load_models.load_model(model_path, dtype=mx.float16)
         mels = feats(model.dims.n_mels)[None].astype(mx.float16)
         model_forward_time = timer(model_forward, model, mels, tokens)
         print(f"Model forward time {model_forward_time:.3f}")
         decode_time = timer(decode, model, mels)
         print(f"Decode time {decode_time:.3f}")
-        everything_time = timer(everything, model_name)
+        everything_time = timer(everything, model_path)
         print(f"Everything time {everything_time:.3f}")
 
         print(f"\n{'-----' * 10}\n")
diff --git a/whisper/convert.py b/whisper/convert.py
new file mode 100644
index 00000000..48cbebc5
--- /dev/null
+++ b/whisper/convert.py
@@ -0,0 +1,284 @@
+# Copyright © 2023 Apple Inc.
+
+import argparse
+import copy
+import hashlib
+import json
+import os
+import urllib
+import warnings
+from dataclasses import asdict
+from pathlib import Path
+from typing import List
+
+import mlx.core as mx
+import mlx.nn as nn
+import numpy as np
+import torch
+from mlx.utils import tree_flatten, tree_map, tree_unflatten
+from tqdm import tqdm
+
+from whisper import torch_whisper
+from whisper.whisper import ModelDimensions, Whisper
+
+_VALID_DTYPES = {"float16", "float32"}
+
+_MODELS = {
+    "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
+    "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
+    "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
+    "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
+    "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
+    "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
+    "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
+    "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
+    "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
+    "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
+    "large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
+    "large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
+}
+
+# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
+# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
+_ALIGNMENT_HEADS = {
+    "tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
+    "tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
+    "base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
+    "base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
+    "small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P%R7%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
+    "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
+    "large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
+    "large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
+}
+
+
+def _download(url: str, root: str) -> str:
+    os.makedirs(root, exist_ok=True)
+
+    expected_sha256 = url.split("/")[-2]
+    download_target = os.path.join(root, os.path.basename(url))
+
+    if os.path.exists(download_target) and not os.path.isfile(download_target):
+        raise RuntimeError(f"{download_target} exists and is not a regular file")
+
+    if os.path.isfile(download_target):
+        with open(download_target, "rb") as f:
+            model_bytes = f.read()
+        if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
+            return download_target
+        else:
+            warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
+
+    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
+        with tqdm(
+            total=int(source.info().get("Content-Length")),
+            ncols=80,
+            unit="iB",
+            unit_scale=True,
+            unit_divisor=1024,
+        ) as loop:
+            while True:
+                buffer = source.read(8192)
+                if not buffer:
+                    break
+
+                output.write(buffer)
+                loop.update(len(buffer))
+
+    model_bytes = open(download_target, "rb").read()
+    if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
+        raise RuntimeError(
+            "Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model."
+        )
+
+    return download_target
+
+
+def available_models() -> List[str]:
+    """Returns the names of available models"""
+    return list(_MODELS.keys())
+
+
+def load_torch_model(
+    name_or_path: str,
+    download_root: str = None,
+) -> torch_whisper.Whisper:
+    """
+    Load a Whisper ASR model
+
+    Parameters
+    ----------
+    name_or_path : str
+        one of the official model names listed by `whisper.available_models()`, or a path to a local PyTorch checkpoint in the original OpenAI format
+    download_root: str
+        path to download the model files; by default, it uses "~/.cache/whisper"
+
+    Returns
+    -------
+    model : Whisper
+        The Whisper ASR model instance
+    """
+
+    if download_root is None:
+        download_root = os.path.join(os.path.expanduser("~"), ".cache/whisper")
+
+    # TODO: accept alignment_heads for a local PyTorch checkpoint
+    alignment_heads = None
+    if name_or_path in _MODELS:
+        alignment_heads = _ALIGNMENT_HEADS[name_or_path]
+        name_or_path = _download(_MODELS[name_or_path], download_root)
+    elif not Path(name_or_path).is_file():
+        raise RuntimeError(f"Model {name_or_path} is neither found in {available_models()} nor as a local path")
+
+    with open(name_or_path, "rb") as fp:
+        checkpoint = torch.load(fp)
+
+    dims = torch_whisper.ModelDimensions(**checkpoint["dims"])
+    model = torch_whisper.Whisper(dims)
+    model.load_state_dict(checkpoint["model_state_dict"])
+
+    if alignment_heads is not None:
+        model.set_alignment_heads(alignment_heads)
+
+    return model
+
+
+def convert(model, rules=None):
+    params = {}
+    if rules is not None and type(model) in rules:
+        out = rules[type(model)](model, rules)
+        return out
+    if isinstance(model, torch.Tensor):
+        return mx.array(model.detach().numpy())
+    if isinstance(model, torch.nn.ModuleList):
+        return [convert(n, rules) for n in model.children()]
+    if isinstance(model, torch.nn.Conv1d):
+        return {
+            "weight": convert(model.weight).transpose(0, 2, 1),
+            "bias": convert(model.bias),
+        }
+    for k, n in model.named_children():
+        if k in rules:
+            params.update(rules[k](n, rules))
+        else:
+            params[k] = convert(n, rules)
+    for k, p in model.named_parameters(recurse=False):
+        params[k] = convert(p)
+    return params
+
+
+def torch_to_mlx(
+    torch_model: torch_whisper.Whisper,
+    dtype: mx.Dtype = mx.float16,
+) -> Whisper:
+    def convert_rblock(model, rules):
+        children = dict(model.named_children())
+        mlp = list(children.pop("mlp").children())
+        params = {
+            "mlp1": convert(mlp[0], rules),
+            "mlp2": convert(mlp[-1], rules),
+        }
+        for k, n in children.items():
+            params[k] = convert(n, rules)
+        return params
+
+    rules = {
+        torch_whisper.ResidualAttentionBlock: convert_rblock,
+    }
+
+    params = convert(torch_model, rules)
+
+    mlx_model = Whisper(torch_model.dims, dtype)
+    params = tree_map(lambda p: p.astype(dtype), params)
+    mlx_model.update(params)
+    return mlx_model
+
+
+def quantize(weights, config, args):
+    quantized_config = copy.deepcopy(config)
+
+    # Load the model:
+    model = Whisper(ModelDimensions(**config))
+    weights = tree_map(mx.array, weights)
+    model.update(tree_unflatten(list(weights.items())))
+
+    # Quantize the model:
+    nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
+
+    # Update the config:
+    quantized_config["quantization"] = {
+        "group_size": args.q_group_size,
+        "bits": args.q_bits,
+    }
+    quantized_weights = dict(tree_flatten(model.parameters()))
+
+    return quantized_weights, quantized_config
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Convert Whisper weights to MLX.")
+    parser.add_argument(
+        "--torch-name-or-path",
+        type=str,
+        default="tiny",
+        help="The name or path to the PyTorch model.",
+    )
+    parser.add_argument(
+        "--mlx-path",
+        type=str,
+        default="mlx_models/tiny",
+        help="The path to save the MLX model.",
+    )
+    parser.add_argument(
+        "--dtype",
+        type=str,
+        default="float16",
+        help="The dtype to save the MLX model.",
+    )
+    parser.add_argument(
+        "-q",
+        "--quantize",
+        help="Generate a quantized model.",
+        action="store_true",
+    )
+    parser.add_argument(
+        "--q_group_size",
+        help="Group size for quantization.",
+        type=int,
+        default=64,
+    )
+    parser.add_argument(
+        "--q_bits",
+        help="Bits per weight for quantization.",
+        type=int,
+        default=4,
+    )
+    args = parser.parse_args()
+
+    assert args.dtype in _VALID_DTYPES, f"dtype {args.dtype} not found in {_VALID_DTYPES}"
+    dtype = getattr(mx, args.dtype)
+
+    print("[INFO] Loading")
+    model = torch_to_mlx(load_torch_model(args.torch_name_or_path), dtype)
+    config = asdict(model.dims)
+    weights = dict(tree_flatten(model.parameters()))
+
+    if args.quantize:
+        print("[INFO] Quantizing")
+        weights, config = quantize(weights, config, args)
+
+    mlx_path = Path(args.mlx_path)
+    mlx_path.mkdir(parents=True, exist_ok=True)
+
+    # Save weights
+    print("[INFO] Saving")
+    np.savez(str(mlx_path / "weights.npz"), **weights)
+
+    # Save config.json with model_type
+    with open(str(mlx_path / "config.json"), "w") as f:
+        config["model_type"] = "whisper"
+        json.dump(config, f, indent=4)
diff --git a/whisper/test.py b/whisper/test.py
index 13a1f91a..3f81ce14 100644
--- a/whisper/test.py
+++ b/whisper/test.py
@@ -1,22 +1,68 @@
 # Copyright © 2023 Apple Inc.
 
+import json
 import os
 import subprocess
 import unittest
+from dataclasses import asdict
+from pathlib import Path
 
 import mlx.core as mx
 import numpy as np
 import torch
+from mlx.utils import tree_flatten
 
 import whisper
 import whisper.audio as audio
 import whisper.decoding as decoding
 import whisper.load_models as load_models
-import whisper.torch_whisper as torch_whisper
+from convert import load_torch_model, quantize, torch_to_mlx
+
+MODEL_NAME = "tiny"
+MLX_FP32_MODEL_PATH = "mlx_models/tiny_fp32"
+MLX_FP16_MODEL_PATH = "mlx_models/tiny_fp16"
+MLX_4BITS_MODEL_PATH = "mlx_models/tiny_quantized_4bits"
 
 TEST_AUDIO = "whisper/assets/ls_test.flac"
 
 
+def _save_model(save_dir, weights, config):
+    mlx_path = Path(save_dir)
+    mlx_path.mkdir(parents=True, exist_ok=True)
+
+    # Save weights
+    np.savez(str(mlx_path / "weights.npz"), **weights)
+
+    # Save config.json with model_type
+    with open(str(mlx_path / "config.json"), "w") as f:
+        config["model_type"] = "whisper"
+        json.dump(config, f, indent=4)
+
+    config.pop("model_type", None)
+
+
+def load_torch_and_mlx():
+    torch_model = load_torch_model(MODEL_NAME)
+
+    fp32_model = torch_to_mlx(torch_model, dtype=mx.float32)
+    config = asdict(fp32_model.dims)
+    weights = dict(tree_flatten(fp32_model.parameters()))
+    _save_model(MLX_FP32_MODEL_PATH, weights, config)
+
+    fp16_model = torch_to_mlx(torch_model, dtype=mx.float16)
+    config = asdict(fp16_model.dims)
+    weights = dict(tree_flatten(fp16_model.parameters()))
+    _save_model(MLX_FP16_MODEL_PATH, weights, config)
+
+    args = type("", (), {})()
+    args.q_group_size = 64
+    args.q_bits = 4
+    weights, config = quantize(weights, config, args)
+    _save_model(MLX_4BITS_MODEL_PATH, weights, config)
+
+    return torch_model, fp32_model, fp16_model
+
+
 def forward_torch(model, mels, tokens):
     mels = torch.Tensor(mels).to(torch.float32)
     tokens = torch.Tensor(tokens).to(torch.int32)
@@ -35,7 +81,7 @@ def forward_mlx(model, mels, tokens):
 class TestWhisper(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        cls.model = load_models.load_model("tiny", dtype=mx.float32)
+        _, cls.model, _ = load_torch_and_mlx()
         data = audio.load_audio(TEST_AUDIO)
         data = audio.pad_or_trim(data)
         cls.mels = audio.log_mel_spectrogram(data)
@@ -43,7 +89,7 @@ class TestWhisper(unittest.TestCase):
 
     def test_torch_mlx(self):
        np.random.seed(10)
 
-        torch_model = load_models.load_torch_model("tiny")
+        torch_model = load_torch_model(MODEL_NAME)
         dims = torch_model.dims
         mels = np.random.randn(1, dims.n_mels, 3_000)
@@ -51,19 +97,27 @@ class TestWhisper(unittest.TestCase):
 
         torch_logits = forward_torch(torch_model, mels, tokens)
 
-        mlx_model = load_models.torch_to_mlx(torch_model, mx.float32)
-        mlx_logits = forward_mlx(mlx_model, mels, tokens)
+        mlx_logits = forward_mlx(self.model, mels, tokens)
 
         self.assertTrue(np.allclose(torch_logits, mlx_logits, atol=1e-2, rtol=1e-2))
 
     def test_fp16(self):
-        mlx_model = load_models.load_model("tiny", dtype=mx.float16)
+        mlx_model = load_models.load_model(MLX_FP16_MODEL_PATH, mx.float16)
         dims = mlx_model.dims
         mels = mx.array(np.random.randn(1, 3_000, dims.n_mels), mx.float16)
         tokens = mx.array(np.random.randint(0, dims.n_vocab, (1, 20)), mx.int32)
         logits = mlx_model(mels, tokens)
         self.assertEqual(logits.dtype, mx.float16)
 
+    def test_quantized_4bits(self):
+        mlx_model = load_models.load_model(MLX_4BITS_MODEL_PATH, mx.float16)
+        dims = mlx_model.dims
+        mels = mx.array(np.random.randn(1, 3_000, dims.n_mels), mx.float16)
+        tokens = mx.array(np.random.randint(0, dims.n_vocab, (1, 20)), mx.int32)
+        logits = mlx_model(mels, tokens)
+        # Only check that the 4-bit model runs a forward pass; the quantized tiny model struggles with accurate transcription
+        self.assertEqual(logits.dtype, mx.float16)
+
     def test_decode_lang(self):
         options = decoding.DecodingOptions(task="lang_id", fp16=False)
         result = decoding.decode(self.model, self.mels, options)
@@ -135,7 +189,7 @@ class TestWhisper(unittest.TestCase):
         self.assertAlmostEqual(result.compression_ratio, 1.2359550561797752)
 
     def test_transcribe(self):
-        result = whisper.transcribe(TEST_AUDIO, fp16=False)
+        result = whisper.transcribe(TEST_AUDIO, model_path=MLX_FP32_MODEL_PATH, fp16=False)
         self.assertEqual(
             result["text"],
             (
@@ -154,7 +208,7 @@ class TestWhisper(unittest.TestCase):
             print("bash path_to_whisper_repo/whisper/assets/download_alice.sh")
             return
 
-        result = whisper.transcribe(audio_file, fp16=False)
+        result = whisper.transcribe(audio_file, model_path=MLX_FP32_MODEL_PATH, fp16=False)
         self.assertEqual(len(result["text"]), 10920)
         self.assertEqual(result["language"], "en")
         self.assertEqual(len(result["segments"]), 77)
diff --git a/whisper/whisper/load_models.py b/whisper/whisper/load_models.py
index 32e62119..c92d9042 100644
--- a/whisper/whisper/load_models.py
+++ b/whisper/whisper/load_models.py
@@ -1,198 +1,36 @@
 # Copyright © 2023 Apple Inc.
 
-import hashlib
-import os
-import urllib
-import warnings
-from typing import List
+import json
+from pathlib import Path
 
 import mlx.core as mx
-import torch
-from mlx.utils import tree_map
-from tqdm import tqdm
+import mlx.nn as nn
+from mlx.utils import tree_unflatten
 
-from . import torch_whisper, whisper
-
-_MODELS = {
-    "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
-    "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
-    "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
-    "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
-    "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
-    "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
-    "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
-    "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
-    "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
-    "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
-    "large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
-    "large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
-}
-
-# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
-# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
-_ALIGNMENT_HEADS = {
-    "tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
-    "tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
-    "base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
-    "base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
-    "small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P%R7%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
-    "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
-    "large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
-    "large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
-}
-
-
-def _download(url: str, root: str) -> str:
-    os.makedirs(root, exist_ok=True)
-
-    expected_sha256 = url.split("/")[-2]
-    download_target = os.path.join(root, os.path.basename(url))
-
-    if os.path.exists(download_target) and not os.path.isfile(download_target):
-        raise RuntimeError(f"{download_target} exists and is not a regular file")
-
-    if os.path.isfile(download_target):
-        with open(download_target, "rb") as f:
-            model_bytes = f.read()
-        if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
-            return download_target
-        else:
-            warnings.warn(
-                f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
-            )
-
-    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
-        with tqdm(
-            total=int(source.info().get("Content-Length")),
-            ncols=80,
-            unit="iB",
-            unit_scale=True,
-            unit_divisor=1024,
-        ) as loop:
-            while True:
-                buffer = source.read(8192)
-                if not buffer:
-                    break
-
-                output.write(buffer)
-                loop.update(len(buffer))
-
-    model_bytes = open(download_target, "rb").read()
-    if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
-        raise RuntimeError(
-            "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
-        )
-
-    return download_target
-
-
-def available_models() -> List[str]:
-    """Returns the names of available models"""
-    return list(_MODELS.keys())
-
-
-def load_torch_model(
-    name: str,
-    download_root: str = None,
-) -> torch_whisper.Whisper:
-    """
-    Load a Whisper ASR model
-
-    Parameters
-    ----------
-    name : str
-        one of the official model names listed by `whisper.available_models()`
-    download_root: str
-        path to download the model files; by default, it uses "~/.cache/whisper"
-
-    Returns
-    -------
-    model : Whisper
-        The Whisper ASR model instance
-    """
-
-    if download_root is None:
-        download_root = os.path.join(os.path.expanduser("~"), ".cache/whisper")
-
-    if name in _MODELS:
-        checkpoint_file = _download(_MODELS[name], download_root)
-        alignment_heads = _ALIGNMENT_HEADS[name]
-    else:
-        raise RuntimeError(
-            f"Model {name} not found; available models = {available_models()}"
-        )
-
-    with open(checkpoint_file, "rb") as fp:
-        checkpoint = torch.load(fp)
-
-    dims = torch_whisper.ModelDimensions(**checkpoint["dims"])
-    model = torch_whisper.Whisper(dims)
-    model.load_state_dict(checkpoint["model_state_dict"])
-
-    if alignment_heads is not None:
-        model.set_alignment_heads(alignment_heads)
-
-    return model
-
-
-def convert(model, rules=None):
-    params = {}
-    if rules is not None and type(model) in rules:
-        out = rules[type(model)](model, rules)
-        return out
-    if isinstance(model, torch.Tensor):
-        return mx.array(model.detach().numpy())
-    if isinstance(model, torch.nn.ModuleList):
-        return [convert(n, rules) for n in model.children()]
-    if isinstance(model, torch.nn.Conv1d):
-        return {
-            "weight": convert(model.weight).transpose(0, 2, 1),
-            "bias": convert(model.bias),
-        }
-    for k, n in model.named_children():
-        if k in rules:
-            params.update(rules[k](n, rules))
-        else:
-            params[k] = convert(n, rules)
-    for k, p in model.named_parameters(recurse=False):
-        params[k] = convert(p)
-    return params
-
-
-def torch_to_mlx(
-    torch_model: torch_whisper.Whisper,
-    dtype: mx.Dtype = mx.float16,
-) -> whisper.Whisper:
-    def convert_rblock(model, rules):
-        children = dict(model.named_children())
-        mlp = list(children.pop("mlp").children())
-        params = {
-            "mlp1": convert(mlp[0], rules),
-            "mlp2": convert(mlp[-1], rules),
-        }
-        for k, n in children.items():
-            params[k] = convert(n, rules)
-        return params
-
-    rules = {
-        torch_whisper.ResidualAttentionBlock: convert_rblock,
-    }
-
-    params = convert(torch_model, rules)
-
-    mlx_model = whisper.Whisper(torch_model.dims, dtype)
-    params = tree_map(lambda p: p.astype(dtype), params)
-    mlx_model.update(params)
-    return mlx_model
+from . import whisper
 
 
 def load_model(
-    name: str,
-    download_root: str = None,
+    folder: str,
     dtype: mx.Dtype = mx.float32,
 ) -> whisper.Whisper:
-    return torch_to_mlx(load_torch_model(name, download_root), dtype)
+    model_path = Path(folder)
+
+    with open(str(model_path / "config.json"), "r") as f:
+        config = json.loads(f.read())
+        config.pop("model_type", None)
+        quantization = config.pop("quantization", None)
+
+    model_args = whisper.ModelDimensions(**config)
+
+    weights = mx.load(str(model_path / "weights.npz"))
+    weights = tree_unflatten(list(weights.items()))
+
+    model = whisper.Whisper(model_args, dtype)
+
+    if quantization is not None:
+        nn.QuantizedLinear.quantize_module(model, **quantization)
+
+    model.update(weights)
+    mx.eval(model.parameters())
+    return model
diff --git a/whisper/whisper/transcribe.py b/whisper/whisper/transcribe.py
index 67232136..330aef42 100644
--- a/whisper/whisper/transcribe.py
+++ b/whisper/whisper/transcribe.py
@@ -40,20 +40,20 @@ def _format_timestamp(seconds: float):
 
 class ModelHolder:
     model = None
-    model_name = None
+    model_path = None
 
     @classmethod
-    def get_model(cls, model: str, dtype: mx.Dtype):
-        if cls.model is None or model != cls.model_name:
-            cls.model = load_model(model, dtype=dtype)
-            cls.model_name = model
+    def get_model(cls, model_path: str, dtype: mx.Dtype):
+        if cls.model is None or model_path != cls.model_path:
+            cls.model = load_model(model_path, dtype=dtype)
+            cls.model_path = model_path
         return cls.model
 
 
 def transcribe(
     audio: Union[str, np.ndarray, mx.array],
     *,
-    model: str = "tiny",
+    model_path: str = "mlx_models/tiny",
     verbose: Optional[bool] = None,
     temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
     compression_ratio_threshold: Optional[float] = 2.4,
@@ -73,9 +73,8 @@
     audio: Union[str, np.ndarray, mx.array]
         The path to the audio file to open, or the audio waveform
 
-    model: str
-        The Whisper model. Can be any of ["tiny", "base", "small", "medium", "large"].
-        Default is "tiny".
+    model_path: str
+        The path to the Whisper model that has been converted to MLX format.
 
     verbose: bool
         Whether to display the text being decoded to the console. If True, displays all the details,
@@ -115,7 +114,7 @@
     """
     dtype = mx.float16 if decode_options.get("fp16", True) else mx.float32
 
-    model = ModelHolder.get_model(model, dtype)
+    model = ModelHolder.get_model(model_path, dtype)
 
     # Pad 30-seconds of silence to the input audio, for slicing
     mel = log_mel_spectrogram(audio, n_mels=model.dims.n_mels, padding=N_SAMPLES)
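
For reference, a minimal usage sketch of the converted checkpoint with the new `model_path` API. It assumes `python convert.py --torch-name-or-path tiny --mlx-path mlx_models/tiny` has already been run, that `sample.wav` is any local audio file, and that the snippet is executed from the `whisper/` directory so the package is importable:

```python
# Sketch: exercise the converted checkpoint through the APIs changed in this diff.
# `mlx_models/tiny` and `sample.wav` are assumed paths, not part of the repo.
import mlx.core as mx

from whisper import load_models, transcribe

# Load the converted weights directly, e.g. to inspect the model dimensions.
model = load_models.load_model("mlx_models/tiny", dtype=mx.float16)
print(model.dims)

# Or let transcribe() load (and cache) the model from `model_path` internally.
result = transcribe("sample.wav", model_path="mlx_models/tiny")
print(result["text"])
```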