2024-05-02 00:00:02 +08:00
|
|
|
# Copyright © 2023-2024 Apple Inc.
|
2023-12-30 02:22:15 +08:00
|
|
|
|
|
|
|
import argparse
|
|
|
|
import copy
|
|
|
|
import hashlib
|
|
|
|
import json
|
|
|
|
import os
|
|
|
|
import urllib
|
|
|
|
import warnings
|
|
|
|
from dataclasses import asdict
|
|
|
|
from pathlib import Path
|
|
|
|
from typing import List
|
|
|
|
|
|
|
|
import mlx.core as mx
|
|
|
|
import mlx.nn as nn
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
|
|
from mlx.utils import tree_flatten, tree_map, tree_unflatten
|
2024-05-02 00:00:02 +08:00
|
|
|
from mlx_whisper import torch_whisper
|
|
|
|
from mlx_whisper.whisper import ModelDimensions, Whisper
|
2023-12-30 02:22:15 +08:00
|
|
|
from tqdm import tqdm
|
|
|
|
|
|
|
|
_VALID_DTYPES = {"float16", "float32"}
|
|
|
|
|
|
|
|
_MODELS = {
|
|
|
|
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
|
|
|
|
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
|
|
|
|
"base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
|
|
|
|
"base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
|
|
|
|
"small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
|
|
|
|
"small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
|
|
|
|
"medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
|
|
|
|
"medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
|
|
|
|
"large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
|
|
|
|
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
|
|
|
|
"large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
|
|
|
|
"large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
|
2024-10-03 04:13:33 +08:00
|
|
|
"large-v3-turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
|
|
|
|
"turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
|
2023-12-30 02:22:15 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
|
|
|
|
# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
|
|
|
|
_ALIGNMENT_HEADS = {
|
|
|
|
"tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
|
|
|
|
"tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
|
|
|
|
"base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
|
|
|
|
"base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
|
|
|
|
"small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
|
|
|
|
"small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
|
|
|
|
"medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
|
|
|
|
"medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
|
|
|
|
"large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
|
|
|
|
"large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
|
|
|
|
"large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
|
|
|
|
"large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
|
2024-10-03 04:13:33 +08:00
|
|
|
"large-v3-turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
|
|
|
|
"turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
|
2023-12-30 02:22:15 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
def _sha256_of_file(path: str, chunk_size: int = 1 << 20) -> str:
    """Return the hex SHA256 digest of the file at ``path``, hashing it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as fid:
        for chunk in iter(lambda: fid.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def _download(url: str, root: str) -> str:
    """
    Download ``url`` into ``root`` unless a verified copy is already there.

    The expected SHA256 checksum is taken from the second-to-last path
    component of ``url`` and the local file name from its last component.

    Parameters
    ----------
    url : str
        URL of the checkpoint; must embed the checksum as described above.
    root : str
        Directory the file is stored in; created if it does not exist.

    Returns
    -------
    str
        Path of the local file whose checksum matches the expected value.

    Raises
    ------
    RuntimeError
        If the target path exists but is not a regular file, or if the
        checksum of the freshly downloaded file does not match.
    """
    # `import urllib` alone does not guarantee the `request` submodule is
    # loaded, so import it explicitly where it is used.
    import urllib.request

    os.makedirs(root, exist_ok=True)

    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, os.path.basename(url))

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        # Hash in chunks so a multi-gigabyte checkpoint is never held in memory.
        if _sha256_of_file(download_target) == expected_sha256:
            return download_target
        warnings.warn(
            f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
        )

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        # The server may omit Content-Length; tqdm accepts total=None then.
        content_length = source.info().get("Content-Length")
        with tqdm(
            total=int(content_length) if content_length is not None else None,
            ncols=80,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
        ) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    # Verify what actually landed on disk, not what was streamed.
    if _sha256_of_file(download_target) != expected_sha256:
        raise RuntimeError(
            "Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model."
        )

    return download_target
|
|
|
|
|
|
|
|
|
|
|
|
def available_models() -> List[str]:
    """Return the model names registered in the ``_MODELS`` table, as a list."""
    return list(_MODELS)
|
|
|
|
|
|
|
|
|
2024-08-10 02:11:58 +08:00
|
|
|
def hf_to_pt(weights, config):
    """
    Rename Hugging Face style Whisper weights and config to the ``n_*`` /
    ``blocks`` naming used by the rest of this file.

    Parameters
    ----------
    weights : dict
        Mapping of parameter name to value. ``"proj_out.weight"`` is removed
        from this mapping in place if present.
    config : dict
        Hugging Face config; must contain every source key listed below.

    Returns
    -------
    (dict, dict)
        The renamed weights and the translated dimension config.
    """
    # Target dimension name -> key to read from the Hugging Face config.
    dimension_sources = {
        "n_mels": "num_mel_bins",
        "n_audio_ctx": "max_source_positions",
        "n_audio_state": "d_model",
        "n_audio_head": "encoder_attention_heads",
        "n_audio_layer": "encoder_layers",
        "n_vocab": "vocab_size",
        "n_text_ctx": "max_target_positions",
        "n_text_state": "d_model",
        "n_text_head": "decoder_attention_heads",
        "n_text_layer": "decoder_layers",
    }
    dims = {target: config[source] for target, source in dimension_sources.items()}

    # Applied strictly in this order: later patterns rely on earlier rewrites
    # (e.g. ".self_attn" must become ".attn" before ".attn_layer_norm" matches).
    substitutions = (
        ("model.", ""),
        (".layers", ".blocks"),
        (".self_attn", ".attn"),
        (".attn_layer_norm", ".attn_ln"),
        (".encoder_attn.", ".cross_attn."),
        (".encoder_attn_layer_norm", ".cross_attn_ln"),
        (".final_layer_norm", ".mlp_ln"),
        (".q_proj", ".query"),
        (".k_proj", ".key"),
        (".v_proj", ".value"),
        (".out_proj", ".out"),
        (".fc1", ".mlp1"),
        (".fc2", ".mlp2"),
        ("embed_positions.weight", "positional_embedding"),
        ("decoder.embed_tokens", "decoder.token_embedding"),
        ("encoder.layer_norm", "encoder.ln_post"),
        ("decoder.layer_norm", "decoder.ln"),
    )

    def remap(name):
        for old, new in substitutions:
            name = name.replace(old, new)
        return name

    # token embeddings are shared with output projection
    weights.pop("proj_out.weight", None)
    renamed = {remap(name): value for name, value in weights.items()}
    return renamed, dims
|
|
|
|
|
|
|
|
|
|
|
|
def load_torch_weights_and_config(
    name_or_path: str,
    download_root: str = None,
):
    """
    Resolve ``name_or_path`` to a checkpoint and load its weights and config.

    Resolution order:
      1. a key of ``_MODELS``: downloaded via ``_download`` into ``download_root``;
      2. a path that does not exist locally: treated as a Hugging Face repo id
         and fetched with ``huggingface_hub.snapshot_download``;
      3. otherwise used as a local path as-is.

    A resolved path ending in ``.pt`` is loaded as a single checkpoint holding
    ``"model_state_dict"`` and ``"dims"``. Anything else is treated as a
    directory containing ``config.json`` plus either ``pytorch_model.bin`` or
    ``model.safetensors``, and is passed through ``hf_to_pt``.

    Returns
    -------
    (weights, config, alignment_heads)
        ``alignment_heads`` is ``None`` unless ``name_or_path`` was a key of
        ``_MODELS``.
    """
    if download_root is None:
        download_root = os.path.join(os.path.expanduser("~"), ".cache/whisper")

    # todo: accept alignment_heads of local Pytorch checkpoint
    alignment_heads = None
    location = name_or_path
    if name_or_path in _MODELS:
        alignment_heads = _ALIGNMENT_HEADS[name_or_path]
        location = _download(_MODELS[name_or_path], download_root)
    elif not Path(name_or_path).exists():
        # Try downloading from HF
        from huggingface_hub import snapshot_download

        location = snapshot_download(
            repo_id=name_or_path,
            allow_patterns=[
                "*.json",
                "pytorch_model.bin",
                "model.safetensors",
                "*.txt",
            ],
        )

    if location.endswith(".pt"):
        # NOTE(review): weights_only=False unpickles arbitrary objects; only
        # load checkpoints from sources you trust.
        checkpoint = torch.load(location, map_location="cpu", weights_only=False)
        return checkpoint["model_state_dict"], checkpoint["dims"], alignment_heads

    model_dir = Path(location)
    torch_file = model_dir / "pytorch_model.bin"
    if torch_file.is_file():
        hf_weights = torch.load(torch_file, map_location="cpu")
    else:
        hf_weights = mx.load(str(model_dir / "model.safetensors"))

    with open(model_dir / "config.json", "r") as fp:
        hf_config = json.load(fp)

    weights, config = hf_to_pt(hf_weights, hf_config)
    return weights, config, alignment_heads
|
|
|
|
|
|
|
|
|
2023-12-30 02:22:15 +08:00
|
|
|
def load_torch_model(
    name_or_path: str,
    download_root: str = None,
) -> torch_whisper.Whisper:
    """
    Build a PyTorch Whisper model and populate it with loaded weights

    Parameters
    ----------
    name_or_path : str
        anything accepted by `load_torch_weights_and_config`, e.g. one of the
        names returned by `available_models()` or a local checkpoint path
    download_root: str
        directory the model files are downloaded to; defaults to
        "~/.cache/whisper" when None

    Returns
    -------
    model : Whisper
        The `torch_whisper.Whisper` instance with its state dict loaded and,
        when alignment heads are known for the model, those heads set
    """
    cache_dir = download_root
    if cache_dir is None:
        cache_dir = os.path.join(os.path.expanduser("~"), ".cache/whisper")

    state_dict, dims_kwargs, heads = load_torch_weights_and_config(
        name_or_path, cache_dir
    )

    model = torch_whisper.Whisper(torch_whisper.ModelDimensions(**dims_kwargs))
    model.load_state_dict(state_dict)

    # Alignment heads are only available for some models; skip otherwise.
    if heads is not None:
        model.set_alignment_heads(heads)

    return model
|
|
|
|
|
|
|
|
|
2024-08-10 02:11:58 +08:00
|
|
|
def convert(name_or_path: str, dtype: mx.Dtype = mx.float16):
    """
    Load a Whisper checkpoint and return it as an MLX ``Whisper`` model.

    Parameters
    ----------
    name_or_path : str
        Forwarded unchanged to ``load_torch_weights_and_config``.
    dtype : mx.Dtype
        Every weight is cast to this dtype, and the model is built with it.

    Returns
    -------
    Whisper
        The MLX model with the converted weights loaded (non-strict) and
        alignment heads set when they are available.
    """

    def remap(key, value):
        # "mlp.0"/"mlp.2" -> "mlp1"/"mlp2" (the names the MLX model is loaded with).
        key = key.replace("mlp.0", "mlp1").replace("mlp.2", "mlp2")
        # 3-D conv weights get their last two axes swapped -- presumably to
        # match the MLX conv weight layout; TODO confirm.
        if "conv" in key and value.ndim == 3:
            value = value.swapaxes(1, 2)
        if isinstance(value, torch.Tensor):
            value = mx.array(value.detach())
        return key, value.astype(dtype)

    state, dims_kwargs, heads = load_torch_weights_and_config(name_or_path)
    state.pop("encoder.positional_embedding", None)

    converted = {}
    for name, tensor in state.items():
        new_name, new_tensor = remap(name, tensor)
        converted[new_name] = new_tensor

    model = Whisper(ModelDimensions(**dims_kwargs), dtype)
    model.load_weights(list(converted.items()), strict=False)

    if heads is not None:
        model.set_alignment_heads(heads)

    return model
|
2023-12-30 02:22:15 +08:00
|
|
|
|
|
|
|
|
2024-01-08 22:18:24 +08:00
|
|
|
def upload_to_hub(path: str, name: str, torch_name_or_path: str):
    """
    Write a model card into ``path`` and upload the folder to the Hugging Face Hub.

    Parameters
    ----------
    path : str
        Local folder holding the converted model; ``README.md`` is written
        into it before the upload.
    name : str
        Repository name; the folder is uploaded to ``mlx-community/<name>``.
    torch_name_or_path : str
        Name or path of the source model, mentioned in the model card.
    """
    import os

    from huggingface_hub import HfApi, ModelCard, logging

    repo_id = f"mlx-community/{name}"
    # The repo id is quoted in the usage snippet so the generated README
    # contains valid, copy-pasteable Python.
    text = f"""
---
library_name: mlx
---

# {name}
This model was converted to MLX format from [`{torch_name_or_path}`]().

## Use with mlx
```bash
pip install mlx-whisper
```

```python
import mlx_whisper

result = mlx_whisper.transcribe(
    "FILE_NAME",
    path_or_hf_repo="{repo_id}",
)
```
"""
    card = ModelCard(text)
    card.save(os.path.join(path, "README.md"))

    logging.set_verbosity_info()

    api = HfApi()
    # exist_ok=True lets repeated uploads reuse an existing repository.
    api.create_repo(repo_id=repo_id, exist_ok=True)
    api.upload_folder(
        folder_path=path,
        repo_id=repo_id,
        repo_type="model",
    )
|
|
|
|
|
|
|
|
|
2023-12-30 02:22:15 +08:00
|
|
|
def quantize(weights, config, args):
    """
    Quantize Whisper weights and record the settings in a copy of the config.

    Parameters
    ----------
    weights : dict
        Flat mapping of parameter name to value; each value is passed through
        ``mx.array`` before being loaded into the model.
    config : dict
        Keyword arguments for ``ModelDimensions``; not modified.
    args
        Object providing ``q_group_size`` and ``q_bits``.

    Returns
    -------
    (dict, dict)
        The flattened parameters of the quantized model, and a deep copy of
        ``config`` with a ``"quantization"`` entry added.
    """
    quantized_config = copy.deepcopy(config)

    # Rebuild the model and load the given weights into it.
    model = Whisper(ModelDimensions(**config))
    arrays = tree_map(mx.array, weights)
    model.update(tree_unflatten(list(arrays.items())))

    # Quantize the model in place.
    group_size, bits = args.q_group_size, args.q_bits
    nn.quantize(model, group_size, bits)

    # Record the quantization settings alongside the model dimensions.
    quantized_config["quantization"] = {"group_size": group_size, "bits": bits}

    return dict(tree_flatten(model.parameters())), quantized_config
|
|
|
|
|
|
|
|
|
|
|
|
# Command-line entry point: convert a Whisper checkpoint to MLX, optionally
# quantize it, save weights and config, and optionally upload the result.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert Whisper weights to MLX.")
    parser.add_argument(
        "--torch-name-or-path",
        type=str,
        default="tiny",
        help="The name or path to the PyTorch model.",
    )
    parser.add_argument(
        "--mlx-path",
        type=str,
        default="mlx_models",
        help="The path to save the MLX model.",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="float16",
        help="The dtype to save the MLX model.",
    )
    parser.add_argument(
        "-q",
        "--quantize",
        help="Generate a quantized model.",
        action="store_true",
    )
    parser.add_argument(
        "--q-group-size",
        help="Group size for quantization.",
        type=int,
        default=64,
    )
    parser.add_argument(
        "--q-bits",
        help="Bits per weight for quantization.",
        type=int,
        default=4,
    )
    parser.add_argument(
        "--upload-name",
        help="The name of model to upload to Hugging Face MLX Community",
        type=str,
        default=None,
    )

    args = parser.parse_args()

    # Validate explicitly: an `assert` would be stripped under `python -O`,
    # letting an unsupported dtype name through to `getattr(mx, ...)`.
    if args.dtype not in _VALID_DTYPES:
        parser.error(f"dtype {args.dtype} not found in {_VALID_DTYPES}")
    dtype = getattr(mx, args.dtype)

    print("[INFO] Loading")
    model = convert(args.torch_name_or_path, dtype)
    config = asdict(model.dims)
    weights = dict(tree_flatten(model.parameters()))

    if args.quantize:
        print("[INFO] Quantizing")
        weights, config = quantize(weights, config, args)

    mlx_path = Path(args.mlx_path)
    mlx_path.mkdir(parents=True, exist_ok=True)

    # Save weights
    print("[INFO] Saving")
    mx.save_safetensors(str(mlx_path / "weights.safetensors"), weights)

    # Save config.json with model_type
    with open(str(mlx_path / "config.json"), "w") as f:
        config["model_type"] = "whisper"
        json.dump(config, f, indent=4)

    if args.upload_name is not None:
        upload_to_hub(mlx_path, args.upload_name, args.torch_name_or_path)
|