From 21f28ccd550fad8fae7704bea5624f0cdc478408 Mon Sep 17 00:00:00 2001
From: bofenghuang
Date: Fri, 29 Dec 2023 16:36:51 +0100
Subject: [PATCH] Separate model conversion and loading

---
 whisper/convert.py             | 189 +++++++++++++++++++++++++++++-
 whisper/whisper/load_models.py | 203 +--------------------------------
 whisper/whisper/transcribe.py  |  20 ++--
 3 files changed, 198 insertions(+), 214 deletions(-)

diff --git a/whisper/convert.py b/whisper/convert.py
index d0897b50..3ec9323c 100644
--- a/whisper/convert.py
+++ b/whisper/convert.py
@@ -2,19 +2,200 @@
 
 import argparse
 import copy
+import hashlib
 import json
+import os
+import urllib
+import warnings
 from dataclasses import asdict
 from pathlib import Path
+from typing import List
 
 import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
+import torch
 from mlx.utils import tree_flatten, tree_map, tree_unflatten
+from tqdm import tqdm
 
-from whisper.load_models import load_torch_model, torch_to_mlx
+from whisper import torch_whisper
 from whisper.whisper import ModelDimensions, Whisper
 
-MODEL_DTYPES = {"float16", "float32"}
+_VALID_DTYPES = {"float16", "float32"}
+
+_MODELS = {
+    "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
+    "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
+    "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
+    "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
+    "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
+    "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
+    "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
+    "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
+    "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
+    "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
+    "large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
+    "large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
+}
+
+# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
+# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
+_ALIGNMENT_HEADS = {
+    "tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
+    "tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
+    "base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
+    "base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<uNrDu-{00",
+    "small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
+    "small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
+    "medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
+    "medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
+    "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
+    "large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
+    "large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
+    "large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
+}
+
+
+def _download(url: str, root: str) -> str:
+    os.makedirs(root, exist_ok=True)
+
+    expected_sha256 = url.split("/")[-2]
+    download_target = os.path.join(root, os.path.basename(url))
+
+    if os.path.exists(download_target) and not os.path.isfile(download_target):
+        raise RuntimeError(f"{download_target} exists and is not a regular file")
+
+    if os.path.isfile(download_target):
+        with open(download_target, "rb") as f:
+            model_bytes = f.read()
+        if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
+            return download_target
+        else:
+            warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
+
+    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
+        with tqdm(
+            total=int(source.info().get("Content-Length")),
+            ncols=80,
+            unit="iB",
+            unit_scale=True,
+            unit_divisor=1024,
+        ) as loop:
+            while True:
+                buffer = source.read(8192)
+                if not buffer:
+                    break
+
+                output.write(buffer)
+                loop.update(len(buffer))
+
+    model_bytes = open(download_target, "rb").read()
+    if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
+        raise RuntimeError(
+            "Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model."
+        )
+
+    return download_target
+
+
+def available_models() -> List[str]:
+    """Returns the names of available models"""
+    return list(_MODELS.keys())
+
+
+def load_torch_model(
+    name_or_path: str,
+    download_root: str = None,
+) -> torch_whisper.Whisper:
+    """
+    Load a Whisper ASR model
+
+    Parameters
+    ----------
+    name_or_path : str
+        one of the official model names listed by `available_models()`, or a local PyTorch checkpoint in OpenAI's format
+    download_root: str
+        path to download the model files; by default, it uses "~/.cache/whisper"
+
+    Returns
+    -------
+    model : Whisper
+        The Whisper ASR model instance
+    """
+
+    if download_root is None:
+        download_root = os.path.join(os.path.expanduser("~"), ".cache/whisper")
+
+    # TODO: accept alignment_heads of a local PyTorch checkpoint
+    alignment_heads = None
+    if name_or_path in _MODELS:
+        alignment_heads = _ALIGNMENT_HEADS[name_or_path]
+        name_or_path = _download(_MODELS[name_or_path], download_root)
+    elif not Path(name_or_path).is_file():
+        raise RuntimeError(f"Model {name_or_path} is neither found in {available_models()} nor as a local path")
+
+    with open(name_or_path, "rb") as fp:
+        checkpoint = torch.load(fp)
+
+    dims = torch_whisper.ModelDimensions(**checkpoint["dims"])
+    model = torch_whisper.Whisper(dims)
+    model.load_state_dict(checkpoint["model_state_dict"])
+
+    if alignment_heads is not None:
+        model.set_alignment_heads(alignment_heads)
+
+    return model
+
+
+def convert(model, rules=None):
+    params = {}
+    if rules is not None and type(model) in rules:
+        out = rules[type(model)](model, rules)
+        return out
+    if isinstance(model, torch.Tensor):
+        return mx.array(model.detach().numpy())
+    if isinstance(model, torch.nn.ModuleList):
+        return [convert(n, rules) for n in model.children()]
+    if isinstance(model, torch.nn.Conv1d):
+        return {
+            "weight": convert(model.weight).transpose(0, 2, 1),
+            "bias": convert(model.bias),
+        }
+    for k, n in model.named_children():
+        if k in rules:
+            params.update(rules[k](n, rules))
+        else:
+            params[k] = convert(n, rules)
+    for k, p in model.named_parameters(recurse=False):
+        params[k] = convert(p)
+    return params
+
+
+def torch_to_mlx(
+    torch_model: torch_whisper.Whisper,
+    dtype: mx.Dtype = mx.float16,
+) -> Whisper:
+    def convert_rblock(model, rules):
+        children = dict(model.named_children())
+        mlp = list(children.pop("mlp").children())
+        params = {
+            "mlp1": convert(mlp[0], rules),
+            "mlp2": convert(mlp[-1], rules),
+        }
+        for k, n in children.items():
+            params[k] = convert(n, rules)
+        return params
+
+    rules = {
+        torch_whisper.ResidualAttentionBlock: convert_rblock,
+    }
+
+    params = convert(torch_model, rules)
+
+    mlx_model = Whisper(torch_model.dims, dtype)
+    params = tree_map(lambda p: p.astype(dtype), params)
+    mlx_model.update(params)
+    return mlx_model
 
 
 def quantize(weights, config, args):
@@ -49,7 +230,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--mlx-path",
         type=str,
-        default="mlx_model",
+        default="mlx_models/tiny",
         help="The path to save the MLX model.",
     )
     parser.add_argument(
@@ -78,7 +259,7 @@
     )
     args = parser.parse_args()
 
-    assert args.dtype in MODEL_DTYPES, f"dtype {args.dtype} not found in {MODEL_DTYPES}"
+    assert args.dtype in _VALID_DTYPES, f"dtype {args.dtype} not found in {_VALID_DTYPES}"
     dtype = getattr(mx, args.dtype)
 
     print("[INFO] Loading")
diff --git a/whisper/whisper/load_models.py b/whisper/whisper/load_models.py
index 05adbe49..c92d9042 100644
--- a/whisper/whisper/load_models.py
+++ b/whisper/whisper/load_models.py
@@ -1,213 +1,20 @@
 # Copyright © 2023 Apple Inc.
 
-import glob
-import hashlib
 import json
-import os
-import urllib
-import warnings
 from pathlib import Path
-from typing import List
 
 import mlx.core as mx
 import mlx.nn as nn
-import torch
-from mlx.utils import tree_map, tree_unflatten
-from tqdm import tqdm
+from mlx.utils import tree_unflatten
 
-from . import torch_whisper, whisper
-
-_MODELS = {
-    "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
-    "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
-    "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
-    "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
-    "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
-    "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
-    "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
-    "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
-    "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
-    "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
-    "large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
-    "large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
-}
-
-# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
-# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
-_ALIGNMENT_HEADS = {
-    "tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
-    "tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
-    "base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
-    "base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<uNrDu-{00",
-    "small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
-    "small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
-    "medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
-    "medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
-    "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
-    "large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
-    "large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
-    "large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
-}
-
-
-def _download(url: str, root: str) -> str:
-    os.makedirs(root, exist_ok=True)
-
-    expected_sha256 = url.split("/")[-2]
-    download_target = os.path.join(root, os.path.basename(url))
-
-    if os.path.exists(download_target) and not os.path.isfile(download_target):
-        raise RuntimeError(f"{download_target} exists and is not a regular file")
-
-    if os.path.isfile(download_target):
-        with open(download_target, "rb") as f:
-            model_bytes = f.read()
-        if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
-            return download_target
-        else:
-            warnings.warn(
-                f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
-            )
-
-    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
-        with tqdm(
-            total=int(source.info().get("Content-Length")),
-            ncols=80,
-            unit="iB",
-            unit_scale=True,
-            unit_divisor=1024,
-        ) as loop:
-            while True:
-                buffer = source.read(8192)
-                if not buffer:
-                    break
-
-                output.write(buffer)
-                loop.update(len(buffer))
-
-    model_bytes = open(download_target, "rb").read()
-    if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
-        raise RuntimeError(
-            "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
-        )
-
-    return download_target
-
-
-def available_models() -> List[str]:
-    """Returns the names of available models"""
-    return list(_MODELS.keys())
-
-
-def load_torch_model(
-    name_or_path: str,
-    download_root: str = None,
-) -> torch_whisper.Whisper:
-    """
-    Load a Whisper ASR model
-
-    Parameters
-    ----------
-    name_or_path : str
-        one of the official model names listed by `whisper.available_models()` or a local Pytorch checkpoint
-    download_root: str
-        path to download the model files; by default, it uses "~/.cache/whisper"
-
-    Returns
-    -------
-    model : Whisper
-        The Whisper ASR model instance
-    """
-
-    if download_root is None:
-        download_root = os.path.join(os.path.expanduser("~"), ".cache/whisper")
-
-    if name_or_path in _MODELS:
-        name_or_path = _download(_MODELS[name_or_path], download_root)
-    elif not Path(name_or_path).is_file():
-        raise RuntimeError(f"Model {name_or_path} is neither found in {available_models()} nor as a local path")
-
-    alignment_heads = _ALIGNMENT_HEADS.get(name_or_path)
-
-    with open(name_or_path, "rb") as fp:
-        checkpoint = torch.load(fp)
-
-    dims = torch_whisper.ModelDimensions(**checkpoint["dims"])
-    model = torch_whisper.Whisper(dims)
-    model.load_state_dict(checkpoint["model_state_dict"])
-
-    if alignment_heads is not None:
-        model.set_alignment_heads(alignment_heads)
-
-    return model
-
-
-def convert(model, rules=None):
-    params = {}
-    if rules is not None and type(model) in rules:
-        out = rules[type(model)](model, rules)
-        return out
-    if isinstance(model, torch.Tensor):
-        return mx.array(model.detach().numpy())
-    if isinstance(model, torch.nn.ModuleList):
-        return [convert(n, rules) for n in model.children()]
-    if isinstance(model, torch.nn.Conv1d):
-        return {
-            "weight": convert(model.weight).transpose(0, 2, 1),
-            "bias": convert(model.bias),
-        }
-    for k, n in model.named_children():
-        if k in rules:
-            params.update(rules[k](n, rules))
-        else:
-            params[k] = convert(n, rules)
-    for k, p in model.named_parameters(recurse=False):
-        params[k] = convert(p)
-    return params
-
-
-def torch_to_mlx(
-    torch_model: torch_whisper.Whisper,
-    dtype: mx.Dtype = mx.float16,
-) -> whisper.Whisper:
-    def convert_rblock(model, rules):
-        children = dict(model.named_children())
-        mlp = list(children.pop("mlp").children())
-        params = {
-            "mlp1": convert(mlp[0], rules),
-            "mlp2": convert(mlp[-1], rules),
-        }
-        for k, n in children.items():
-            params[k] = convert(n, rules)
-        return params
-
-    rules = {
-        torch_whisper.ResidualAttentionBlock: convert_rblock,
-    }
-
-    params = convert(torch_model, rules)
-
-    mlx_model = whisper.Whisper(torch_model.dims, dtype)
-    params = tree_map(lambda p: p.astype(dtype), params)
-    mlx_model.update(params)
-    return mlx_model
+from . import whisper
 
 
 def load_model(
-    name_or_path: str,
-    download_root: str = None,
+    folder: str,
     dtype: mx.Dtype = mx.float32,
 ) -> whisper.Whisper:
-    if name_or_path in _MODELS:
-        print(f"[INFO] Loading and converting {name_or_path} model")
-        return torch_to_mlx(load_torch_model(name_or_path, download_root), dtype)
-    elif not (glob.glob(f"{name_or_path}/weights.npz") and glob.glob(f"{name_or_path}/config.json")):
-        raise ValueError(
-            f"{name_or_path} not found in {available_models()}. Ensure that weights.npz and config.json files are"
-            " present in the specified path"
-        )
-
-    model_path = Path(name_or_path)
+    model_path = Path(folder)
 
     with open(str(model_path / "config.json"), "r") as f:
         config = json.loads(f.read())
@@ -225,7 +32,5 @@ def load_model(
         nn.QuantizedLinear.quantize_module(model, **quantization)
 
     model.update(weights)
-    mx.eval(model.parameters())
-
     return model
diff --git a/whisper/whisper/transcribe.py b/whisper/whisper/transcribe.py
index 03e72963..330aef42 100644
--- a/whisper/whisper/transcribe.py
+++ b/whisper/whisper/transcribe.py
@@ -40,20 +40,20 @@ def _format_timestamp(seconds: float):
 
 class ModelHolder:
     model = None
-    model_name_or_path = None
+    model_path = None
 
     @classmethod
-    def get_model(cls, model_name_or_path: str, dtype: mx.Dtype):
-        if cls.model is None or model_name_or_path != cls.model_name_or_path:
-            cls.model = load_model(model_name_or_path, dtype=dtype)
-            cls.model_name_or_path = model_name_or_path
+    def get_model(cls, model_path: str, dtype: mx.Dtype):
+        if cls.model is None or model_path != cls.model_path:
+            cls.model = load_model(model_path, dtype=dtype)
+            cls.model_path = model_path
         return cls.model
 
 
 def transcribe(
     audio: Union[str, np.ndarray, mx.array],
     *,
-    model_name_or_path: str = "tiny",
+    model_path: str = "mlx_models/tiny",
     verbose: Optional[bool] = None,
     temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
     compression_ratio_threshold: Optional[float] = 2.4,
@@ -73,10 +73,8 @@ def transcribe(
     audio: Union[str, np.ndarray, mx.array]
         The path to the audio file to open, or the audio waveform
 
-    model_name_or_path: str
-        The Whisper model. Can be any of ["tiny", "base", "small", "medium", "large"],
-        or a local folder in MLX format.
-        Default is "tiny".
+    model_path: str
+        The path to the Whisper model that has been converted to MLX format.
 
     verbose: bool
         Whether to display the text being decoded to the console. If True, displays all the details,
@@ -116,7 +114,7 @@ def transcribe(
     """
 
     dtype = mx.float16 if decode_options.get("fp16", True) else mx.float32
-    model = ModelHolder.get_model(model_name_or_path, dtype)
+    model = ModelHolder.get_model(model_path, dtype)
 
     # Pad 30-seconds of silence to the input audio, for slicing
     mel = log_mel_spectrogram(audio, n_mels=model.dims.n_mels, padding=N_SAMPLES)
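
Note on usage after this patch (editorial sketch, not part of the diff): conversion and loading are now two separate steps. The snippet below is illustrative only; it assumes convert.py has already been run so that a converted model (weights.npz plus config.json) sits in mlx_models/tiny, that the package's __init__ still re-exports transcribe, and that audio.wav is a placeholder file name.

    import mlx.core as mx
    import whisper  # the local MLX whisper package in this repo
    from whisper.load_models import load_model

    # load_model() now only reads an already-converted MLX folder; it no longer
    # downloads or converts OpenAI checkpoints on the fly.
    model = load_model("mlx_models/tiny", dtype=mx.float16)

    # transcribe() likewise points at the converted folder via model_path
    # instead of taking an OpenAI model name such as "tiny".
    result = whisper.transcribe("audio.wav", model_path="mlx_models/tiny")
    print(result["text"])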