# Copyright © 2023-2024 Apple Inc.

import argparse
import copy
import hashlib
import json
import os
import urllib.request
import warnings
from dataclasses import asdict
from pathlib import Path
from typing import List

import mlx.core as mx
import mlx.nn as nn
import numpy as np
import torch
from mlx.utils import tree_flatten, tree_map, tree_unflatten
from mlx_whisper import torch_whisper
from mlx_whisper.whisper import ModelDimensions, Whisper
from tqdm import tqdm

_VALID_DTYPES = {"float16", "float32"}

_MODELS = {
    "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
    "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
    "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
    "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
    "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
    "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
    "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
    "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
    "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
    "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
    "large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
    "large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
}
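
# Note: each URL above embeds the checkpoint's SHA256 as its second-to-last
# path segment; `_download` below relies on this (`url.split("/")[-2]`) to
# verify the downloaded file.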

# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention
# heads that are highly correlated to the word-level timing, i.e. the alignment
# between audio and text tokens.
_ALIGNMENT_HEADS = {
    "tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
    "tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
    "base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
    "base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<uideo<7CdDx8C{$lpwozW@*{yy^%8%%o6tWWo^%TXHN/",
    "small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
    "small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
    "medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
    "medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
    "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
    "large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
    "large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
    "large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
}
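
# A minimal sketch of how one of these blobs decodes into a boolean
# (n_text_layer, n_text_head) mask, mirroring `set_alignment_heads`; the (4, 6)
# shape is the tiny model's and is shown for illustration only:
#
#   import base64, gzip
#   import numpy as np
#
#   raw = gzip.decompress(base64.b85decode(_ALIGNMENT_HEADS["tiny"]))
#   mask = np.frombuffer(raw, dtype=bool).reshape(4, 6)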


def _download(url: str, root: str) -> str:
    os.makedirs(root, exist_ok=True)

    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, os.path.basename(url))

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        with open(download_target, "rb") as f:
            model_bytes = f.read()
        if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
            return download_target
        else:
            warnings.warn(
                f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
            )

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(
            total=int(source.info().get("Content-Length")),
            ncols=80,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
        ) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break
                output.write(buffer)
                loop.update(len(buffer))

    with open(download_target, "rb") as fid:
        model_bytes = fid.read()
    if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
        raise RuntimeError(
            "Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model."
        )

    return download_target


def available_models() -> List[str]:
    """Returns the names of available models"""
    return list(_MODELS.keys())


def hf_to_pt(weights, config):
    config = {
        "n_mels": config["num_mel_bins"],
        "n_audio_ctx": config["max_source_positions"],
        "n_audio_state": config["d_model"],
        "n_audio_head": config["encoder_attention_heads"],
        "n_audio_layer": config["encoder_layers"],
        "n_vocab": config["vocab_size"],
        "n_text_ctx": config["max_target_positions"],
        "n_text_state": config["d_model"],
        "n_text_head": config["decoder_attention_heads"],
        "n_text_layer": config["decoder_layers"],
    }

    def remap(k):
        k = k.replace("model.", "")
        k = k.replace(".layers", ".blocks")
        k = k.replace(".self_attn", ".attn")
        k = k.replace(".attn_layer_norm", ".attn_ln")
        k = k.replace(".encoder_attn.", ".cross_attn.")
        k = k.replace(".encoder_attn_layer_norm", ".cross_attn_ln")
        k = k.replace(".final_layer_norm", ".mlp_ln")
        k = k.replace(".q_proj", ".query")
        k = k.replace(".k_proj", ".key")
        k = k.replace(".v_proj", ".value")
        k = k.replace(".out_proj", ".out")
        k = k.replace(".fc1", ".mlp1")
        k = k.replace(".fc2", ".mlp2")
        k = k.replace("embed_positions.weight", "positional_embedding")
        k = k.replace("decoder.embed_tokens", "decoder.token_embedding")
        k = k.replace("encoder.layer_norm", "encoder.ln_post")
        k = k.replace("decoder.layer_norm", "decoder.ln")
        return k

    # token embeddings are shared with output projection
    weights.pop("proj_out.weight", None)
    weights = {remap(k): v for k, v in weights.items()}
    return weights, config
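
# For example, the Hugging Face key
#   "model.decoder.layers.0.encoder_attn.q_proj.weight"
# is remapped step by step to the OpenAI-style key
#   "decoder.blocks.0.cross_attn.query.weight"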


def load_torch_weights_and_config(
    name_or_path: str,
    download_root: str = None,
):
    if download_root is None:
        download_root = os.path.join(os.path.expanduser("~"), ".cache/whisper")

    # TODO: accept alignment_heads from a local PyTorch checkpoint
    alignment_heads = None
    if name_or_path in _MODELS:
        alignment_heads = _ALIGNMENT_HEADS[name_or_path]
        name_or_path = _download(_MODELS[name_or_path], download_root)
    elif not Path(name_or_path).exists():
        # Not an official model name or a local path, so try the Hugging Face Hub
        try:
            from huggingface_hub import snapshot_download

            name_or_path = snapshot_download(
                repo_id=name_or_path,
                allow_patterns=["*.json", "pytorch_model.bin", "*.txt"],
            )
        except Exception as e:
            raise RuntimeError(
                f"Model {name_or_path} is not found in {available_models()}, "
                "on Hugging Face, or as a local path."
            ) from e

    if name_or_path.endswith(".pt"):
        # Original OpenAI checkpoint
        checkpoint = torch.load(name_or_path, map_location="cpu")
        weights, config = checkpoint["model_state_dict"], checkpoint["dims"]
    else:
        # Hugging Face checkpoint directory
        name_or_path = Path(name_or_path)
        weights = torch.load(
            name_or_path / "pytorch_model.bin",
            map_location="cpu",
        )
        with open(name_or_path / "config.json", "r") as fp:
            config = json.load(fp)
        weights, config = hf_to_pt(weights, config)

    return weights, config, alignment_heads


def load_torch_model(
    name_or_path: str,
    download_root: str = None,
) -> torch_whisper.Whisper:
    """
    Load a Whisper ASR model

    Parameters
    ----------
    name_or_path : str
        one of the official model names listed by `whisper.available_models()`
        or a local PyTorch checkpoint in the original OpenAI format
    download_root : str
        path to download the model files; by default, it uses "~/.cache/whisper"

    Returns
    -------
    model : Whisper
        The Whisper ASR model instance
    """
    if download_root is None:
        download_root = os.path.join(os.path.expanduser("~"), ".cache/whisper")

    weights, config, alignment_heads = load_torch_weights_and_config(
        name_or_path, download_root
    )
    dims = torch_whisper.ModelDimensions(**config)
    model = torch_whisper.Whisper(dims)
    model.load_state_dict(weights)

    if alignment_heads is not None:
        model.set_alignment_heads(alignment_heads)

    return model


def convert(name_or_path: str, dtype: mx.Dtype = mx.float16):
    def remap(key, value):
        key = key.replace("mlp.0", "mlp1")
        key = key.replace("mlp.2", "mlp2")
        if "conv" in key and value.ndim == 3:
            # PyTorch Conv1d weights are (out_channels, in_channels, kernel_size);
            # MLX expects (out_channels, kernel_size, in_channels)
            value = value.swapaxes(1, 2)
        return key, mx.array(value.detach()).astype(dtype)

    weights, config, alignment_heads = load_torch_weights_and_config(name_or_path)
    # The encoder's sinusoidal positional embedding is computed at load time,
    # so it does not need to be stored
    weights.pop("encoder.positional_embedding", None)
    weights = dict(remap(k, v) for k, v in weights.items())

    model_dims = ModelDimensions(**config)
    model = Whisper(model_dims, dtype)
    model.load_weights(list(weights.items()), strict=False)

    if alignment_heads is not None:
        model.set_alignment_heads(alignment_heads)

    return model
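
# A minimal sketch of programmatic use ("tiny" is one of the official names in
# `_MODELS`; a Hugging Face repo id or a local checkpoint path works the same way):
#
#   model = convert("tiny", dtype=mx.float16)
#   print(model.dims)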


def upload_to_hub(path: str, name: str, torch_name_or_path: str):
    from huggingface_hub import HfApi, ModelCard, logging

    repo_id = f"mlx-community/{name}"

    text = f"""
---
library_name: mlx
---

# {name}
This model was converted to MLX format from [`{torch_name_or_path}`]().

## Use with mlx
```bash
git clone https://github.com/ml-explore/mlx-examples.git
cd mlx-examples/whisper/
pip install -r requirements.txt
```

```python
import whisper

whisper.transcribe("FILE_NAME")
```
"""
    card = ModelCard(text)
    card.save(os.path.join(path, "README.md"))

    logging.set_verbosity_info()
    api = HfApi()
    api.create_repo(repo_id=repo_id, exist_ok=True)
    api.upload_folder(
        folder_path=path,
        repo_id=repo_id,
        repo_type="model",
    )


def quantize(weights, config, args):
    quantized_config = copy.deepcopy(config)

    # Load the model:
    model = Whisper(ModelDimensions(**config))
    weights = tree_map(mx.array, weights)
    model.update(tree_unflatten(list(weights.items())))

    # Quantize the model:
    nn.quantize(model, args.q_group_size, args.q_bits)

    # Update the config:
    quantized_config["quantization"] = {
        "group_size": args.q_group_size,
        "bits": args.q_bits,
    }
    quantized_weights = dict(tree_flatten(model.parameters()))

    return quantized_weights, quantized_config


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert Whisper weights to MLX.")
    parser.add_argument(
        "--torch-name-or-path",
        type=str,
        default="tiny",
        help="The name or path to the PyTorch model.",
    )
    parser.add_argument(
        "--mlx-path",
        type=str,
        default="mlx_models",
        help="The path to save the MLX model.",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="float16",
        help="The dtype to save the MLX model.",
    )
    parser.add_argument(
        "-q",
        "--quantize",
        help="Generate a quantized model.",
        action="store_true",
    )
    parser.add_argument(
        "--q-group-size",
        help="Group size for quantization.",
        type=int,
        default=64,
    )
    parser.add_argument(
        "--q-bits",
        help="Bits per weight for quantization.",
        type=int,
        default=4,
    )
    parser.add_argument(
        "--upload-name",
        help="The name of the model to upload to the Hugging Face MLX Community.",
        type=str,
        default=None,
    )
    args = parser.parse_args()

    assert (
        args.dtype in _VALID_DTYPES
    ), f"dtype {args.dtype} not found in {_VALID_DTYPES}"
    dtype = getattr(mx, args.dtype)

    print("[INFO] Loading")
    model = convert(args.torch_name_or_path, dtype)
    config = asdict(model.dims)
    weights = dict(tree_flatten(model.parameters()))

    if args.quantize:
        print("[INFO] Quantizing")
        weights, config = quantize(weights, config, args)

    mlx_path = Path(args.mlx_path)
    mlx_path.mkdir(parents=True, exist_ok=True)

    # Save weights
    print("[INFO] Saving")
    np.savez(str(mlx_path / "weights.npz"), **weights)

    # Save config.json with model_type
    with open(str(mlx_path / "config.json"), "w") as f:
        config["model_type"] = "whisper"
        json.dump(config, f, indent=4)

    if args.upload_name is not None:
        upload_to_hub(mlx_path, args.upload_name, args.torch_name_or_path)
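
# Example invocation, assuming this script is saved as convert.py (the model
# name and output path are illustrative):
#
#   python convert.py --torch-name-or-path tiny --mlx-path mlx_models \
#       --dtype float16 -q --q-bits 4 --q-group-size 64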