| 
									
										
										
										
											2024-05-02 00:00:02 +08:00
										 |  |  | # Copyright © 2023-2024 Apple Inc. | 
					
						
							| 
									
										
										
										
											2023-12-29 19:22:15 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | import argparse | 
					
						
							|  |  |  | import copy | 
					
						
							|  |  |  | import hashlib | 
					
						
							|  |  |  | import json | 
					
						
							|  |  |  | import os | 
					
						
							|  |  |  | import urllib | 
					
						
							|  |  |  | import warnings | 
					
						
							|  |  |  | from dataclasses import asdict | 
					
						
							|  |  |  | from pathlib import Path | 
					
						
							|  |  |  | from typing import List | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import mlx.core as mx | 
					
						
							|  |  |  | import mlx.nn as nn | 
					
						
							|  |  |  | import numpy as np | 
					
						
							|  |  |  | import torch | 
					
						
							|  |  |  | from mlx.utils import tree_flatten, tree_map, tree_unflatten | 
					
						
							| 
									
										
										
										
											2024-05-02 00:00:02 +08:00
										 |  |  | from mlx_whisper import torch_whisper | 
					
						
							|  |  |  | from mlx_whisper.whisper import ModelDimensions, Whisper | 
					
						
							| 
									
										
										
										
											2023-12-29 19:22:15 +01:00
										 |  |  | from tqdm import tqdm | 
					
						
							|  |  |  | 
 | 
					
						
# Names of the dtypes accepted by the ``--dtype`` command-line option; the
# chosen name is looked up as an attribute of ``mlx.core``.
_VALID_DTYPES = {"float16", "float32"}
					
						
							|  |  |  | 
 | 
					
						
# Maps each supported model name to the URL of its PyTorch checkpoint.
# `_download` treats the second-to-last path component of the URL as the
# expected SHA256 digest of the file, so the two must stay in sync.
# "large" and "turbo" are aliases that point at the same checkpoints as
# "large-v3" and "large-v3-turbo" respectively.
_MODELS = {
    "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
    "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
    "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
    "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
    "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
    "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
    "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
    "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
    "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
    "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
    "large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
    "large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
    "large-v3-turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
    "turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
}
					
						
							|  |  |  | 
 | 
					
						
# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
# Keyed by the same model names as `_MODELS`; the raw bytes are handed to the
# model's `set_alignment_heads` without being decoded in this file.
_ALIGNMENT_HEADS = {
    "tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
    "tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
    "base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
    "base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
    "small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
    "small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
    "medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
    "medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
    "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
    "large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
    "large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
    "large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
    "large-v3-turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
    "turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
}
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def _download(url: str, root: str) -> str: | 
					
						
							|  |  |  |     os.makedirs(root, exist_ok=True) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     expected_sha256 = url.split("/")[-2] | 
					
						
							|  |  |  |     download_target = os.path.join(root, os.path.basename(url)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if os.path.exists(download_target) and not os.path.isfile(download_target): | 
					
						
							|  |  |  |         raise RuntimeError(f"{download_target} exists and is not a regular file") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if os.path.isfile(download_target): | 
					
						
							|  |  |  |         with open(download_target, "rb") as f: | 
					
						
							|  |  |  |             model_bytes = f.read() | 
					
						
							|  |  |  |         if hashlib.sha256(model_bytes).hexdigest() == expected_sha256: | 
					
						
							|  |  |  |             return download_target | 
					
						
							|  |  |  |         else: | 
					
						
							| 
									
										
										
										
											2024-01-03 15:13:26 -08:00
										 |  |  |             warnings.warn( | 
					
						
							|  |  |  |                 f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file" | 
					
						
							|  |  |  |             ) | 
					
						
							| 
									
										
										
										
											2023-12-29 19:22:15 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  |     with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: | 
					
						
							|  |  |  |         with tqdm( | 
					
						
							|  |  |  |             total=int(source.info().get("Content-Length")), | 
					
						
							|  |  |  |             ncols=80, | 
					
						
							|  |  |  |             unit="iB", | 
					
						
							|  |  |  |             unit_scale=True, | 
					
						
							|  |  |  |             unit_divisor=1024, | 
					
						
							|  |  |  |         ) as loop: | 
					
						
							|  |  |  |             while True: | 
					
						
							|  |  |  |                 buffer = source.read(8192) | 
					
						
							|  |  |  |                 if not buffer: | 
					
						
							|  |  |  |                     break | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 output.write(buffer) | 
					
						
							|  |  |  |                 loop.update(len(buffer)) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-30 13:13:58 -07:00
										 |  |  |     with open(download_target, "rb") as fid: | 
					
						
							|  |  |  |         model_bytes = fid.read() | 
					
						
							| 
									
										
										
										
											2023-12-29 19:22:15 +01:00
										 |  |  |     if hashlib.sha256(model_bytes).hexdigest() != expected_sha256: | 
					
						
							|  |  |  |         raise RuntimeError( | 
					
						
							|  |  |  |             "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model." | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return download_target | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def available_models() -> List[str]: | 
					
						
							|  |  |  |     """Returns the names of available models""" | 
					
						
							|  |  |  |     return list(_MODELS.keys()) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-09 11:11:58 -07:00
										 |  |  | def hf_to_pt(weights, config): | 
					
						
							|  |  |  |     config = { | 
					
						
							|  |  |  |         "n_mels": config["num_mel_bins"], | 
					
						
							|  |  |  |         "n_audio_ctx": config["max_source_positions"], | 
					
						
							|  |  |  |         "n_audio_state": config["d_model"], | 
					
						
							|  |  |  |         "n_audio_head": config["encoder_attention_heads"], | 
					
						
							|  |  |  |         "n_audio_layer": config["encoder_layers"], | 
					
						
							|  |  |  |         "n_vocab": config["vocab_size"], | 
					
						
							|  |  |  |         "n_text_ctx": config["max_target_positions"], | 
					
						
							|  |  |  |         "n_text_state": config["d_model"], | 
					
						
							|  |  |  |         "n_text_head": config["decoder_attention_heads"], | 
					
						
							|  |  |  |         "n_text_layer": config["decoder_layers"], | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def remap(k): | 
					
						
							|  |  |  |         k = k.replace("model.", "") | 
					
						
							|  |  |  |         k = k.replace(".layers", ".blocks") | 
					
						
							|  |  |  |         k = k.replace(".self_attn", ".attn") | 
					
						
							|  |  |  |         k = k.replace(".attn_layer_norm", ".attn_ln") | 
					
						
							|  |  |  |         k = k.replace(".encoder_attn.", ".cross_attn.") | 
					
						
							|  |  |  |         k = k.replace(".encoder_attn_layer_norm", ".cross_attn_ln") | 
					
						
							|  |  |  |         k = k.replace(".final_layer_norm", ".mlp_ln") | 
					
						
							|  |  |  |         k = k.replace(".q_proj", ".query") | 
					
						
							|  |  |  |         k = k.replace(".k_proj", ".key") | 
					
						
							|  |  |  |         k = k.replace(".v_proj", ".value") | 
					
						
							|  |  |  |         k = k.replace(".out_proj", ".out") | 
					
						
							|  |  |  |         k = k.replace(".fc1", ".mlp1") | 
					
						
							|  |  |  |         k = k.replace(".fc2", ".mlp2") | 
					
						
							|  |  |  |         k = k.replace("embed_positions.weight", "positional_embedding") | 
					
						
							|  |  |  |         k = k.replace("decoder.embed_tokens", "decoder.token_embedding") | 
					
						
							|  |  |  |         k = k.replace("encoder.layer_norm", "encoder.ln_post") | 
					
						
							|  |  |  |         k = k.replace("decoder.layer_norm", "decoder.ln") | 
					
						
							|  |  |  |         return k | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # token embeddings are shared with output projection | 
					
						
							|  |  |  |     weights.pop("proj_out.weight", None) | 
					
						
							|  |  |  |     weights = {remap(k): v for k, v in weights.items()} | 
					
						
							|  |  |  |     return weights, config | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def load_torch_weights_and_config( | 
					
						
							|  |  |  |     name_or_path: str, | 
					
						
							|  |  |  |     download_root: str = None, | 
					
						
							|  |  |  | ): | 
					
						
							|  |  |  |     if download_root is None: | 
					
						
							|  |  |  |         download_root = os.path.join(os.path.expanduser("~"), ".cache/whisper") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # todo: accept alignment_heads of local Pytorch checkpoint | 
					
						
							|  |  |  |     alignment_heads = None | 
					
						
							|  |  |  |     if name_or_path in _MODELS: | 
					
						
							|  |  |  |         alignment_heads = _ALIGNMENT_HEADS[name_or_path] | 
					
						
							|  |  |  |         name_or_path = _download(_MODELS[name_or_path], download_root) | 
					
						
							|  |  |  |     elif not Path(name_or_path).exists(): | 
					
						
							|  |  |  |         # Try downloading from HF | 
					
						
							|  |  |  |         from huggingface_hub import snapshot_download | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         name_or_path = snapshot_download( | 
					
						
							|  |  |  |             repo_id=name_or_path, | 
					
						
							| 
									
										
										
										
											2024-08-14 10:22:04 -07:00
										 |  |  |             allow_patterns=[ | 
					
						
							|  |  |  |                 "*.json", | 
					
						
							|  |  |  |                 "pytorch_model.bin", | 
					
						
							|  |  |  |                 "model.safetensors", | 
					
						
							|  |  |  |                 "*.txt", | 
					
						
							|  |  |  |             ], | 
					
						
							| 
									
										
										
										
											2024-08-09 11:11:58 -07:00
										 |  |  |         ) | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         raise RuntimeError( | 
					
						
							|  |  |  |             f"Model {name_or_path} is not found in {available_models()}," | 
					
						
							|  |  |  |             "on Hugging Face or as a local path." | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if name_or_path.endswith(".pt"): | 
					
						
							| 
									
										
										
										
											2024-11-01 10:52:28 -07:00
										 |  |  |         checkpoint = torch.load(name_or_path, map_location="cpu", weights_only=False) | 
					
						
							| 
									
										
										
										
											2024-08-09 11:11:58 -07:00
										 |  |  |         weights, config = checkpoint["model_state_dict"], checkpoint["dims"] | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         name_or_path = Path(name_or_path) | 
					
						
							| 
									
										
										
										
											2024-08-14 10:22:04 -07:00
										 |  |  |         pt_path = name_or_path / "pytorch_model.bin" | 
					
						
							|  |  |  |         if pt_path.is_file(): | 
					
						
							|  |  |  |             weights = torch.load(pt_path, map_location="cpu") | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             weights = mx.load(str(name_or_path / "model.safetensors")) | 
					
						
							| 
									
										
										
										
											2024-08-09 11:11:58 -07:00
										 |  |  |         with open(name_or_path / "config.json", "r") as fp: | 
					
						
							|  |  |  |             config = json.load(fp) | 
					
						
							|  |  |  |         weights, config = hf_to_pt(weights, config) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return weights, config, alignment_heads | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-29 19:22:15 +01:00
										 |  |  | def load_torch_model( | 
					
						
							|  |  |  |     name_or_path: str, | 
					
						
							|  |  |  |     download_root: str = None, | 
					
						
							|  |  |  | ) -> torch_whisper.Whisper: | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     Load a Whisper ASR model | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     Parameters | 
					
						
							|  |  |  |     ---------- | 
					
						
							|  |  |  |     name_or_path : str | 
					
						
							| 
									
										
										
										
											2024-08-09 11:11:58 -07:00
										 |  |  |         one of the official model names listed by `whisper.available_models()` or | 
					
						
							|  |  |  |         a local Pytorch checkpoint which is in the original OpenAI format | 
					
						
							| 
									
										
										
										
											2023-12-29 19:22:15 +01:00
										 |  |  |     download_root: str | 
					
						
							|  |  |  |         path to download the model files; by default, it uses "~/.cache/whisper" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     Returns | 
					
						
							|  |  |  |     ------- | 
					
						
							|  |  |  |     model : Whisper | 
					
						
							|  |  |  |         The Whisper ASR model instance | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if download_root is None: | 
					
						
							|  |  |  |         download_root = os.path.join(os.path.expanduser("~"), ".cache/whisper") | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-09 11:11:58 -07:00
										 |  |  |     weights, config, alignment_heads = load_torch_weights_and_config( | 
					
						
							|  |  |  |         name_or_path, download_root | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     dims = torch_whisper.ModelDimensions(**config) | 
					
						
							| 
									
										
										
										
											2023-12-29 19:22:15 +01:00
										 |  |  |     model = torch_whisper.Whisper(dims) | 
					
						
							| 
									
										
										
										
											2024-08-09 11:11:58 -07:00
										 |  |  |     model.load_state_dict(weights) | 
					
						
							| 
									
										
										
										
											2023-12-29 19:22:15 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  |     if alignment_heads is not None: | 
					
						
							|  |  |  |         model.set_alignment_heads(alignment_heads) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return model | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-09 11:11:58 -07:00
										 |  |  | def convert(name_or_path: str, dtype: mx.Dtype = mx.float16): | 
					
						
							|  |  |  |     def remap(key, value): | 
					
						
							|  |  |  |         key = key.replace("mlp.0", "mlp1") | 
					
						
							|  |  |  |         key = key.replace("mlp.2", "mlp2") | 
					
						
							|  |  |  |         if "conv" in key and value.ndim == 3: | 
					
						
							|  |  |  |             value = value.swapaxes(1, 2) | 
					
						
							| 
									
										
										
										
											2024-08-14 10:22:04 -07:00
										 |  |  |         if isinstance(value, torch.Tensor): | 
					
						
							|  |  |  |             value = mx.array(value.detach()) | 
					
						
							|  |  |  |         return key, value.astype(dtype) | 
					
						
							| 
									
										
										
										
											2023-12-29 19:22:15 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-09 11:11:58 -07:00
										 |  |  |     weights, config, alignment_heads = load_torch_weights_and_config(name_or_path) | 
					
						
							|  |  |  |     weights.pop("encoder.positional_embedding", None) | 
					
						
							|  |  |  |     weights = dict(remap(k, v) for k, v in weights.items()) | 
					
						
							| 
									
										
										
										
											2023-12-29 19:22:15 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-09 11:11:58 -07:00
										 |  |  |     model_dims = ModelDimensions(**config) | 
					
						
							|  |  |  |     model = Whisper(model_dims, dtype) | 
					
						
							|  |  |  |     model.load_weights(list(weights.items()), strict=False) | 
					
						
							| 
									
										
										
										
											2024-01-07 19:01:29 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-09 11:11:58 -07:00
										 |  |  |     if alignment_heads is not None: | 
					
						
							|  |  |  |         model.set_alignment_heads(alignment_heads) | 
					
						
							| 
									
										
										
										
											2024-01-07 19:01:29 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-09 11:11:58 -07:00
										 |  |  |     return model | 
					
						
							| 
									
										
										
										
											2023-12-29 19:22:15 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-08 19:48:24 +05:30
										 |  |  | def upload_to_hub(path: str, name: str, torch_name_or_path: str): | 
					
						
							|  |  |  |     import os | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     from huggingface_hub import HfApi, ModelCard, logging | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     repo_id = f"mlx-community/{name}" | 
					
						
							|  |  |  |     text = f"""
 | 
					
						
							|  |  |  | --- | 
					
						
							|  |  |  | library_name: mlx | 
					
						
							|  |  |  | --- | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # {name} | 
					
						
							|  |  |  | This model was converted to MLX format from [`{torch_name_or_path}`](). | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ## Use with mlx | 
					
						
							|  |  |  | ```bash | 
					
						
							| 
									
										
										
										
											2024-08-14 10:22:04 -07:00
										 |  |  | pip install mlx-whisper | 
					
						
							|  |  |  | ``` | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ```python | 
					
						
							|  |  |  | import mlx_whisper | 
					
						
							| 
									
										
										
										
											2024-01-08 19:48:24 +05:30
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-14 10:22:04 -07:00
										 |  |  | result = mlx_whisper.transcribe( | 
					
						
							|  |  |  |     "FILE_NAME", | 
					
						
							|  |  |  |     path_or_hf_repo={repo_id}, | 
					
						
							|  |  |  | ) | 
					
						
							| 
									
										
										
										
											2024-01-08 19:48:24 +05:30
										 |  |  | ``` | 
					
						
							|  |  |  | """
 | 
					
						
							|  |  |  |     card = ModelCard(text) | 
					
						
							|  |  |  |     card.save(os.path.join(path, "README.md")) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     logging.set_verbosity_info() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     api = HfApi() | 
					
						
							|  |  |  |     api.create_repo(repo_id=repo_id, exist_ok=True) | 
					
						
							|  |  |  |     api.upload_folder( | 
					
						
							|  |  |  |         folder_path=path, | 
					
						
							|  |  |  |         repo_id=repo_id, | 
					
						
							|  |  |  |         repo_type="model", | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-29 19:22:15 +01:00
										 |  |  | def quantize(weights, config, args): | 
					
						
							|  |  |  |     quantized_config = copy.deepcopy(config) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Load the model: | 
					
						
							|  |  |  |     model = Whisper(ModelDimensions(**config)) | 
					
						
							|  |  |  |     weights = tree_map(mx.array, weights) | 
					
						
							|  |  |  |     model.update(tree_unflatten(list(weights.items()))) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Quantize the model: | 
					
						
							| 
									
										
										
										
											2024-04-18 18:16:10 -07:00
										 |  |  |     nn.quantize(model, args.q_group_size, args.q_bits) | 
					
						
							| 
									
										
										
										
											2023-12-29 19:22:15 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  |     # Update the config: | 
					
						
							|  |  |  |     quantized_config["quantization"] = { | 
					
						
							|  |  |  |         "group_size": args.q_group_size, | 
					
						
							|  |  |  |         "bits": args.q_bits, | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     quantized_weights = dict(tree_flatten(model.parameters())) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return quantized_weights, quantized_config | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | if __name__ == "__main__": | 
					
						
							|  |  |  |     parser = argparse.ArgumentParser(description="Convert Whisper weights to MLX.") | 
					
						
							|  |  |  |     parser.add_argument( | 
					
						
							|  |  |  |         "--torch-name-or-path", | 
					
						
							|  |  |  |         type=str, | 
					
						
							|  |  |  |         default="tiny", | 
					
						
							|  |  |  |         help="The name or path to the PyTorch model.", | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     parser.add_argument( | 
					
						
							|  |  |  |         "--mlx-path", | 
					
						
							|  |  |  |         type=str, | 
					
						
							| 
									
										
										
										
											2024-01-08 19:48:24 +05:30
										 |  |  |         default="mlx_models", | 
					
						
							| 
									
										
										
										
											2023-12-29 19:22:15 +01:00
										 |  |  |         help="The path to save the MLX model.", | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     parser.add_argument( | 
					
						
							|  |  |  |         "--dtype", | 
					
						
							|  |  |  |         type=str, | 
					
						
							|  |  |  |         default="float16", | 
					
						
							|  |  |  |         help="The dtype to save the MLX model.", | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     parser.add_argument( | 
					
						
							|  |  |  |         "-q", | 
					
						
							|  |  |  |         "--quantize", | 
					
						
							|  |  |  |         help="Generate a quantized model.", | 
					
						
							|  |  |  |         action="store_true", | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     parser.add_argument( | 
					
						
							| 
									
										
										
										
											2024-08-09 11:11:58 -07:00
										 |  |  |         "--q-group-size", | 
					
						
							| 
									
										
										
										
											2023-12-29 19:22:15 +01:00
										 |  |  |         help="Group size for quantization.", | 
					
						
							|  |  |  |         type=int, | 
					
						
							|  |  |  |         default=64, | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     parser.add_argument( | 
					
						
							| 
									
										
										
										
											2024-08-09 11:11:58 -07:00
										 |  |  |         "--q-bits", | 
					
						
							| 
									
										
										
										
											2023-12-29 19:22:15 +01:00
										 |  |  |         help="Bits per weight for quantization.", | 
					
						
							|  |  |  |         type=int, | 
					
						
							|  |  |  |         default=4, | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2024-01-08 19:48:24 +05:30
										 |  |  |     parser.add_argument( | 
					
						
							|  |  |  |         "--upload-name", | 
					
						
							|  |  |  |         help="The name of model to upload to Hugging Face MLX Community", | 
					
						
							|  |  |  |         type=str, | 
					
						
							|  |  |  |         default=None, | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-29 19:22:15 +01:00
										 |  |  |     args = parser.parse_args() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-03 15:13:26 -08:00
										 |  |  |     assert ( | 
					
						
							|  |  |  |         args.dtype in _VALID_DTYPES | 
					
						
							|  |  |  |     ), f"dtype {args.dtype} not found in {_VALID_DTYPES}" | 
					
						
							| 
									
										
										
										
											2023-12-29 19:22:15 +01:00
										 |  |  |     dtype = getattr(mx, args.dtype) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     print("[INFO] Loading") | 
					
						
							| 
									
										
										
										
											2024-08-09 11:11:58 -07:00
										 |  |  |     model = convert(args.torch_name_or_path, dtype) | 
					
						
							| 
									
										
										
										
											2023-12-29 19:22:15 +01:00
										 |  |  |     config = asdict(model.dims) | 
					
						
							|  |  |  |     weights = dict(tree_flatten(model.parameters())) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if args.quantize: | 
					
						
							|  |  |  |         print("[INFO] Quantizing") | 
					
						
							|  |  |  |         weights, config = quantize(weights, config, args) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     mlx_path = Path(args.mlx_path) | 
					
						
							|  |  |  |     mlx_path.mkdir(parents=True, exist_ok=True) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Save weights | 
					
						
							|  |  |  |     print("[INFO] Saving") | 
					
						
							| 
									
										
										
										
											2024-11-01 10:52:28 -07:00
										 |  |  |     mx.save_safetensors(str(mlx_path / "weights.safetensors"), weights) | 
					
						
							| 
									
										
										
										
											2023-12-29 19:22:15 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  |     # Save config.json with model_type | 
					
						
							|  |  |  |     with open(str(mlx_path / "config.json"), "w") as f: | 
					
						
							|  |  |  |         config["model_type"] = "whisper" | 
					
						
							|  |  |  |         json.dump(config, f, indent=4) | 
					
						
							| 
									
										
										
										
											2024-01-08 19:48:24 +05:30
										 |  |  | 
 | 
					
						
							|  |  |  |     if args.upload_name is not None: | 
					
						
							|  |  |  |         upload_to_hub(mlx_path, args.upload_name, args.torch_name_or_path) |