diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5240881b..512e1eaf 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/psf/black-pre-commit-mirror - rev: 22.10.0 + rev: 23.12.1 hooks: - id: black - repo: https://github.com/pycqa/isort diff --git a/gcn/main.py b/gcn/main.py index c3c8795a..7d041b66 100644 --- a/gcn/main.py +++ b/gcn/main.py @@ -32,7 +32,6 @@ def forward_fn(gcn, x, adj, y, train_mask, weight_decay): def main(args): - # Data loading x, y, adj = load_data(args) train_mask, val_mask, test_mask = train_val_test_mask() @@ -55,7 +54,6 @@ def main(args): # Training loop for epoch in range(args.epochs): - # Loss (loss, y_hat), grads = loss_and_grad_fn( gcn, x, adj, y, train_mask, args.weight_decay ) @@ -96,7 +94,6 @@ def main(args): if __name__ == "__main__": - parser = ArgumentParser() parser.add_argument("--nodes_path", type=str, default="cora/cora.content") parser.add_argument("--edges_path", type=str, default="cora/cora.cites") diff --git a/llms/deepseek-coder/convert.py b/llms/deepseek-coder/convert.py index 9ffa52c3..d3e18ec7 100644 --- a/llms/deepseek-coder/convert.py +++ b/llms/deepseek-coder/convert.py @@ -44,7 +44,9 @@ def convert(args): config = model.config.to_dict() state_dict = model.state_dict() - tokenizer = AutoTokenizer.from_pretrained(str(hf_path), trust_remote_code=True, use_fast=False) + tokenizer = AutoTokenizer.from_pretrained( + str(hf_path), trust_remote_code=True, use_fast=False + ) # things to change # 1. there's no "model." in the weight names @@ -84,7 +86,9 @@ def convert(args): weights = {k: v.numpy() for k, v in state_dict.items()} - config["rope_scaling_factor"] = config["rope_scaling"]["factor"] if config["rope_scaling"] is not None else 1.0 + config["rope_scaling_factor"] = ( + config["rope_scaling"]["factor"] if config["rope_scaling"] is not None else 1.0 + ) keep_keys = set( [ "vocab_size", @@ -96,7 +100,7 @@ def convert(args): "rms_norm_eps", "intermediate_size", "rope_scaling_factor", - "rope_theta" + "rope_theta", ] ) for k in list(config.keys()): diff --git a/llms/deepseek-coder/deepseek_coder.py b/llms/deepseek-coder/deepseek_coder.py index 9b8a8a3e..6c878b55 100644 --- a/llms/deepseek-coder/deepseek_coder.py +++ b/llms/deepseek-coder/deepseek_coder.py @@ -285,7 +285,11 @@ if __name__ == "__main__": model, tokenizer = load_model(args.model_path) - prompt = tokenizer(args.prompt, return_tensors="np", return_attention_mask=False,)[ + prompt = tokenizer( + args.prompt, + return_tensors="np", + return_attention_mask=False, + )[ "input_ids" ][0] diff --git a/llms/hf_llm/README.md b/llms/hf_llm/README.md new file mode 100644 index 00000000..b2e2667f --- /dev/null +++ b/llms/hf_llm/README.md @@ -0,0 +1,75 @@ +## Generate Text with MLX and :hugs: Hugging Face + +This is an example of large language model text generation that can pull models from +the Hugging Face Hub. + +### Setup + +Install the dependencies: + +``` +pip install -r requirements.txt +``` + +### Run + +``` +python generate.py --model <model_path> --prompt "hello" +``` + +For example: + +``` +python generate.py --model mistralai/Mistral-7B-v0.1 --prompt "hello" +``` + +will download the Mistral 7B model and generate text using the given prompt. + +The `<model_path>` should be either a path to a local directory or a Hugging +Face repo with weights stored in `safetensors` format. If you use a repo from +the Hugging Face Hub, then the model will be downloaded and cached the first +time you run it.
See the [Models](#models) section for a full list of supported models. + +Run `python generate.py --help` to see all the options. + + +### Models + +The example supports Hugging Face format Mistral and Llama-style models. If the +model you want to run is not supported, file an +[issue](https://github.com/ml-explore/mlx-examples/issues/new) or better yet, +submit a pull request. + +Here are a few examples of Hugging Face models which work with this example: + +- [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) +- [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf) +- [TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T](https://huggingface.co/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T) + +Most +[Mistral](https://huggingface.co/models?library=transformers,safetensors&other=mistral&sort=trending) +and +[Llama](https://huggingface.co/models?library=transformers,safetensors&other=llama&sort=trending) +style models should work out of the box. + +### Convert new models + +You can convert (change the data type or quantize) models using the +`convert.py` script. This script takes a Hugging Face repo as input and outputs +a model directory (which you can optionally also upload to Hugging Face). + +For example, to make a 4-bit quantized model, run: + +``` +python convert.py --hf-path <hf_path> -q +``` + +For more options run: + +``` +python convert.py --help +``` + +You can upload new models to the [Hugging Face MLX +Community](https://huggingface.co/mlx-community) by specifying `--upload-name` +to `convert.py`. diff --git a/llms/hf_llm/convert.py b/llms/hf_llm/convert.py new file mode 100644 index 00000000..2bc48fe2 --- /dev/null +++ b/llms/hf_llm/convert.py @@ -0,0 +1,174 @@ +# Copyright © 2023 Apple Inc.
+ +import argparse +import copy +import glob +import json +from pathlib import Path + +import mlx.core as mx +import mlx.nn as nn +import transformers +from huggingface_hub import snapshot_download +from mlx.utils import tree_flatten +from models import Model, ModelArgs + + +def fetch_from_hub(hf_path: str): + model_path = snapshot_download( + repo_id=hf_path, + allow_patterns=["*.json", "*.safetensors", "tokenizer.model"], + ) + weight_files = glob.glob(f"{model_path}/*.safetensors") + if len(weight_files) == 0: + raise FileNotFoundError("No safetensors found in {}".format(model_path)) + + weights = {} + for wf in weight_files: + weights.update(mx.load(wf).items()) + + config = transformers.AutoConfig.from_pretrained(hf_path) + tokenizer = transformers.AutoTokenizer.from_pretrained( + hf_path, + ) + return weights, config.to_dict(), tokenizer + + +def quantize(weights, config, args): + quantized_config = copy.deepcopy(config) + + # Load the model: + model = Model(ModelArgs.from_dict(config)) + model.load_weights(list(weights.items())) + + # Quantize the model: + nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits) + + # Update the config: + quantized_config["quantization"] = { + "group_size": args.q_group_size, + "bits": args.q_bits, + } + quantized_weights = dict(tree_flatten(model.parameters())) + + return quantized_weights, quantized_config + + +def make_shards(weights: dict, max_file_size_gibibyte: int = 15): + max_file_size_bytes = max_file_size_gibibyte << 30 + shards = [] + shard, shard_size = {}, 0 + for k, v in weights.items(): + estimated_size = v.size * v.dtype.size + if shard_size + estimated_size > max_file_size_bytes: + shards.append(shard) + shard, shard_size = {}, 0 + shard[k] = v + shard_size += estimated_size + shards.append(shard) + return shards + + +def upload_to_hub(path: str, name: str): + import os + + from huggingface_hub import HfApi, ModelCard, logging + + repo_id = f"mlx-community/{name}" + + card = ModelCard.load(hf_path) + card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"] + card.text = f""" +# {name} +This model was converted to MLX format from [`{hf_path}`](). +Refer to the [original model card](https://huggingface.co/{hf_path}) for more details on the model. 
+## Use with mlx +```bash +pip install mlx +git clone https://github.com/ml-explore/mlx-examples.git +cd mlx-examples/llms/hf_llm +python generate.py --model {repo_id} --prompt "My name is" +``` +""" + card.save(os.path.join(path, "README.md")) + + logging.set_verbosity_info() + + api = HfApi() + api.create_repo(repo_id=repo_id, exist_ok=True) + api.upload_folder( + folder_path=path, + repo_id=repo_id, + repo_type="model", + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Convert Hugging Face model to MLX format" + ) + parser.add_argument( + "--hf-path", + type=str, + help="Path to the Hugging Face model.", + ) + parser.add_argument( + "--mlx-path", + type=str, + default="mlx_model", + help="Path to save the MLX model.", + ) + parser.add_argument( + "-q", + "--quantize", + help="Generate a quantized model.", + action="store_true", + ) + parser.add_argument( + "--q-group-size", + help="Group size for quantization.", + type=int, + default=64, + ) + parser.add_argument( + "--q-bits", + help="Bits per weight for quantization.", + type=int, + default=4, + ) + parser.add_argument( + "--dtype", + help="Type to save the parameters, ignored if -q is given.", + type=str, + choices=["float16", "bfloat16", "float32"], + default="float16", + ) + parser.add_argument( + "--upload-name", + help="The name of model to upload to Hugging Face MLX Community", + type=str, + default=None, + ) + + args = parser.parse_args() + + print("[INFO] Loading") + weights, config, tokenizer = fetch_from_hub(args.hf_path) + if args.quantize: + print("[INFO] Quantizing") + weights, config = quantize(weights, config, args) + if not args.quantize: + dtype = getattr(mx, args.dtype) + weights = {k: v.astype(dtype) for k, v in weights.items()} + + mlx_path = Path(args.mlx_path) + mlx_path.mkdir(parents=True, exist_ok=True) + shards = make_shards(weights) + for i, shard in enumerate(shards): + mx.save_safetensors(str(mlx_path / f"weights.{i:02d}.safetensors"), shard) + tokenizer.save_pretrained(mlx_path) + with open(mlx_path / "config.json", "w") as fid: + json.dump(config, fid, indent=4) + + if args.upload_name is not None: + upload_to_hub(mlx_path, args.upload_name) diff --git a/llms/hf_llm/generate.py b/llms/hf_llm/generate.py new file mode 100644 index 00000000..e3b1136d --- /dev/null +++ b/llms/hf_llm/generate.py @@ -0,0 +1,86 @@ +# Copyright © 2023 Apple Inc. 
+ +import argparse +import time + +import mlx.core as mx +import models +import transformers + + +def generate( + model: models.Model, + tokenizer: transformers.AutoTokenizer, + prompt: str, + max_tokens: int, + temp: float = 0.0, +): + prompt = tokenizer( + prompt, + return_tensors="np", + return_attention_mask=False, + )[ + "input_ids" + ][0] + prompt = mx.array(prompt) + + tic = time.time() + tokens = [] + skip = 0 + for token, n in zip( + models.generate(prompt, model, temp), + range(max_tokens), + ): + if token == tokenizer.eos_token_id: + break + + if n == 0: + prompt_time = time.time() - tic + tic = time.time() + + tokens.append(token.item()) + # if (n + 1) % 10 == 0: + s = tokenizer.decode(tokens) + print(s[skip:], end="", flush=True) + skip = len(s) + print(tokenizer.decode(tokens)[skip:], flush=True) + gen_time = time.time() - tic + print("=" * 10) + prompt_tps = prompt.size / prompt_time + gen_tps = (len(tokens) - 1) / gen_time + print(f"Prompt: {prompt_tps:.3f} tokens-per-sec") + print(f"Generation: {gen_tps:.3f} tokens-per-sec") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="inference script") + parser.add_argument( + "--model", + type=str, + default="mlx_model", + help="The path to the local model directory or Hugging Face repo.", + ) + parser.add_argument( + "--prompt", + help="The message to be processed by the model", + default="In the beginning the Universe was created.", + ) + parser.add_argument( + "--max-tokens", + "-m", + type=int, + default=100, + help="Maximum number of tokens to generate", + ) + parser.add_argument( + "--temp", + help="The sampling temperature.", + type=float, + default=0.0, + ) + parser.add_argument("--seed", type=int, default=0, help="The PRNG seed") + + args = parser.parse_args() + mx.random.seed(args.seed) + model, tokenizer = models.load(args.model) + generate(model, tokenizer, args.prompt, args.max_tokens, args.temp) diff --git a/llms/hf_llm/models.py b/llms/hf_llm/models.py new file mode 100644 index 00000000..c19fb397 --- /dev/null +++ b/llms/hf_llm/models.py @@ -0,0 +1,255 @@ +# Copyright © 2023 Apple Inc.
+ +import glob +import inspect +import json +from dataclasses import dataclass +from pathlib import Path +from typing import Optional, Tuple + +import mlx.core as mx +import mlx.nn as nn +from huggingface_hub import snapshot_download +from mlx.utils import tree_unflatten +from transformers import AutoTokenizer + + +@dataclass +class ModelArgs: + hidden_size: int + num_hidden_layers: int + intermediate_size: int + num_attention_heads: int + rms_norm_eps: float + vocab_size: int + num_key_value_heads: int = None + rope_theta: float = 10000 + rope_traditional: bool = False + model_type: str = None + + def __post_init__(self): + if self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads + + @classmethod + def from_dict(cls, params): + return cls( + **{ + k: v + for k, v in params.items() + if k in inspect.signature(cls).parameters + } + ) + + +class RMSNorm(nn.Module): + def __init__(self, dims: int, eps: float = 1e-5): + super().__init__() + self.weight = mx.ones((dims,)) + self.eps = eps + + def _norm(self, x): + return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps) + + def __call__(self, x): + output = self._norm(x.astype(mx.float32)).astype(x.dtype) + return self.weight * output + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + + dim = args.hidden_size + self.n_heads = n_heads = args.num_attention_heads + self.n_kv_heads = n_kv_heads = args.num_key_value_heads + + self.repeats = n_heads // n_kv_heads + + head_dim = args.hidden_size // n_heads + self.scale = head_dim**-0.5 + + self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False) + self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) + self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) + self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) + self.rope = nn.RoPE( + head_dim, traditional=args.rope_traditional, base=args.rope_theta + ) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> mx.array: + B, L, D = x.shape + + queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) + + # Prepare the queries, keys and values for the attention computation + queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + def repeat(a): + a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2) + return a.reshape([B, self.n_heads, L, -1]) + + if self.repeats > 1: + keys, values = map(repeat, (keys, values)) + + if cache is not None: + key_cache, value_cache = cache + queries = self.rope(queries, offset=key_cache.shape[2]) + keys = self.rope(keys, offset=key_cache.shape[2]) + keys = mx.concatenate([key_cache, keys], axis=2) + values = mx.concatenate([value_cache, values], axis=2) + else: + queries = self.rope(queries) + keys = self.rope(keys) + + scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2) + if mask is not None: + scores += mask + scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) + output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output), (keys, values) + + +class MLP(nn.Module): + def __init__(self, dim, hidden_dim): + super().__init__() + self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) + self.down_proj = nn.Linear(hidden_dim, dim, bias=False) + self.up_proj = nn.Linear(dim, 
hidden_dim, bias=False) + + def __call__(self, x) -> mx.array: + return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) + + +class TransformerBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.num_attention_heads = args.num_attention_heads + self.hidden_size = args.hidden_size + self.self_attn = Attention(args) + self.mlp = MLP(args.hidden_size, args.intermediate_size) + self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.args = args + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> mx.array: + r, cache = self.self_attn(self.input_layernorm(x), mask, cache) + h = x + r + r = self.mlp(self.post_attention_layernorm(h)) + out = h + r + return out, cache + + +class LlamaModel(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + self.num_hidden_layers = args.num_hidden_layers + assert self.vocab_size > 0 + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ + TransformerBlock(args=args) for _ in range(args.num_hidden_layers) + ] + self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + h = self.embed_tokens(inputs) + + mask = None + if h.shape[1] > 1: + mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) + mask = mask.astype(h.dtype) + + if cache is None: + cache = [None] * len(self.layers) + + for e, layer in enumerate(self.layers): + h, cache[e] = layer(h, mask, cache[e]) + + return self.norm(h), cache + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.model = LlamaModel(args) + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + out, cache = self.model(inputs, cache) + return self.lm_head(out), cache + + +def load(path_or_hf_repo: str): + # If the path exists, it will try to load model form it + # otherwise download and cache from the hf_repo and cache + model_path = Path(path_or_hf_repo) + if not model_path.exists(): + model_path = Path( + snapshot_download( + repo_id=path_or_hf_repo, + allow_patterns=["*.json", "*.safetensors", "tokenizer.model"], + ) + ) + + with open(model_path / "config.json", "r") as f: + config = json.loads(f.read()) + quantization = config.get("quantization", None) + model_args = ModelArgs.from_dict(config) + + weight_files = glob.glob(str(model_path / "*.safetensors")) + if len(weight_files) == 0: + raise FileNotFoundError("No safetensors found in {}".format(model_path)) + + weights = {} + for wf in weight_files: + weights.update(mx.load(wf).items()) + + model = Model(model_args) + if quantization is not None: + nn.QuantizedLinear.quantize_module(model, **quantization) + + model.load_weights(list(weights.items())) + + mx.eval(model.parameters()) + tokenizer = AutoTokenizer.from_pretrained( + model_path, + ) + return model, tokenizer + + +def generate(prompt: mx.array, model: Model, temp: float = 0.0): + def sample(logits): + if temp == 0: + return mx.argmax(logits, axis=-1) + else: + return mx.random.categorical(logits * (1 / temp)) + + y = prompt + cache = None + while True: + logits, cache = model(y[None], cache=cache) + logits = logits[:, -1, :] + y = sample(logits) + yield y diff --git 
a/llms/hf_llm/requirements.txt b/llms/hf_llm/requirements.txt new file mode 100644 index 00000000..ccb54860 --- /dev/null +++ b/llms/hf_llm/requirements.txt @@ -0,0 +1,3 @@ +mlx>=0.0.7 +numpy +transformers diff --git a/llms/llama/llama.py b/llms/llama/llama.py index 1b44d650..fb12282b 100644 --- a/llms/llama/llama.py +++ b/llms/llama/llama.py @@ -1,9 +1,9 @@ # Copyright © 2023 Apple Inc. import argparse +import glob import json import time -import glob from dataclasses import dataclass from pathlib import Path from typing import Optional, Tuple diff --git a/llms/speculative_decoding/decoder.py b/llms/speculative_decoding/decoder.py index 838edd91..f75fa359 100644 --- a/llms/speculative_decoding/decoder.py +++ b/llms/speculative_decoding/decoder.py @@ -27,7 +27,11 @@ class Tokenizer: def encode(self, s: str) -> mx.array: return mx.array( - self._tokenizer(s, return_tensors="np", return_attention_mask=False,)[ + self._tokenizer( + s, + return_tensors="np", + return_attention_mask=False, + )[ "input_ids" ].squeeze(0) ) diff --git a/stable_diffusion/stable_diffusion/__init__.py b/stable_diffusion/stable_diffusion/__init__.py index 1d72be41..21e22a14 100644 --- a/stable_diffusion/stable_diffusion/__init__.py +++ b/stable_diffusion/stable_diffusion/__init__.py @@ -79,9 +79,13 @@ class StableDiffusion: return x_t_prev - def _denoising_loop(self, x_T, T, conditioning, num_steps: int = 50, cfg_weight: float = 7.5): + def _denoising_loop( + self, x_T, T, conditioning, num_steps: int = 50, cfg_weight: float = 7.5 + ): x_t = x_T - for t, t_prev in self.sampler.timesteps(num_steps, start_time=T, dtype=self.dtype): + for t, t_prev in self.sampler.timesteps( + num_steps, start_time=T, dtype=self.dtype + ): x_t = self._denoising_step(x_t, t, t_prev, conditioning, cfg_weight) yield x_t @@ -100,7 +104,9 @@ class StableDiffusion: mx.random.seed(seed) # Get the text conditioning - conditioning = self._get_text_conditioning(text, n_images, cfg_weight, negative_text) + conditioning = self._get_text_conditioning( + text, n_images, cfg_weight, negative_text + ) # Create the latent variables x_T = self.sampler.sample_prior( @@ -108,7 +114,9 @@ class StableDiffusion: ) # Perform the denoising loop - yield from self._denoising_loop(x_T, self.sampler.max_time, conditioning, num_steps, cfg_weight) + yield from self._denoising_loop( + x_T, self.sampler.max_time, conditioning, num_steps, cfg_weight + ) def generate_latents_from_image( self, @@ -130,16 +138,20 @@ class StableDiffusion: num_steps = int(num_steps * strength) # Get the text conditioning - conditioning = self._get_text_conditioning(text, n_images, cfg_weight, negative_text) + conditioning = self._get_text_conditioning( + text, n_images, cfg_weight, negative_text + ) # Get the latents from the input image and add noise according to the # start time. 
x_0, _ = self.autoencoder.encode(image[None]) - x_0 = mx.broadcast_to(x_0, [n_images] + x_0.shape[1:]) + x_0 = mx.broadcast_to(x_0, [n_images] + x_0.shape[1:]) x_T = self.sampler.add_noise(x_0, mx.array(start_step)) # Perform the denoising loop - yield from self._denoising_loop(x_T, start_step, conditioning, num_steps, cfg_weight) + yield from self._denoising_loop( + x_T, start_step, conditioning, num_steps, cfg_weight + ) def decode(self, x_t): x = self.autoencoder.decode(x_t) diff --git a/stable_diffusion/stable_diffusion/unet.py b/stable_diffusion/stable_diffusion/unet.py index c1a31210..d58f35fb 100644 --- a/stable_diffusion/stable_diffusion/unet.py +++ b/stable_diffusion/stable_diffusion/unet.py @@ -381,7 +381,6 @@ class UNetModel(nn.Module): ) def __call__(self, x, timestep, encoder_x, attn_mask=None, encoder_attn_mask=None): - # Compute the time embeddings temb = self.timesteps(timestep).astype(x.dtype) temb = self.time_embedding(temb) diff --git a/whisper/benchmark.py b/whisper/benchmark.py index 2e69ec45..c5ff3e2a 100644 --- a/whisper/benchmark.py +++ b/whisper/benchmark.py @@ -86,8 +86,13 @@ if __name__ == "__main__": for model_name in models: model_path = f"{args.mlx_dir}/{model_name}" if not os.path.exists(model_path): - print(f"\nDidn't find the MLX-format {model_name} model in the folder {args.mlx_dir}. Lauching conversion") - subprocess.run(f"python convert.py --torch-name-or-path {model_name} --mlx-path {model_path}", shell=True) + print( + f"\nDidn't find the MLX-format {model_name} model in the folder {args.mlx_dir}. Lauching conversion" + ) + subprocess.run( + f"python convert.py --torch-name-or-path {model_name} --mlx-path {model_path}", + shell=True, + ) print(f"\nModel: {model_name.upper()}") tokens = mx.array( diff --git a/whisper/convert.py b/whisper/convert.py index 48cbebc5..15a12855 100644 --- a/whisper/convert.py +++ b/whisper/convert.py @@ -71,7 +71,9 @@ def _download(url: str, root: str) -> str: if hashlib.sha256(model_bytes).hexdigest() == expected_sha256: return download_target else: - warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") + warnings.warn( + f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file" + ) with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: with tqdm( @@ -132,7 +134,9 @@ def load_torch_model( alignment_heads = _ALIGNMENT_HEADS[name_or_path] name_or_path = _download(_MODELS[name_or_path], download_root) elif not Path(name_or_path).is_file(): - raise RuntimeError(f"Model {name_or_path} is neither found in {available_models()} nor as a local path") + raise RuntimeError( + f"Model {name_or_path} is neither found in {available_models()} nor as a local path" + ) with open(name_or_path, "rb") as fp: checkpoint = torch.load(fp) @@ -259,7 +263,9 @@ if __name__ == "__main__": ) args = parser.parse_args() - assert args.dtype in _VALID_DTYPES, f"dtype {args.dtype} not found in {_VALID_DTYPES}" + assert ( + args.dtype in _VALID_DTYPES + ), f"dtype {args.dtype} not found in {_VALID_DTYPES}" dtype = getattr(mx, args.dtype) print("[INFO] Loading") diff --git a/whisper/test.py b/whisper/test.py index 3f81ce14..48a09152 100644 --- a/whisper/test.py +++ b/whisper/test.py @@ -10,6 +10,7 @@ from pathlib import Path import mlx.core as mx import numpy as np import torch +from convert import load_torch_model, quantize, torch_to_mlx from mlx.utils import tree_flatten import whisper @@ -17,8 +18,6 @@ import whisper.audio as audio import 
whisper.decoding as decoding import whisper.load_models as load_models -from convert import load_torch_model, quantize, torch_to_mlx - MODEL_NAME = "tiny" MLX_FP32_MODEL_PATH = "mlx_models/tiny_fp32" MLX_FP16_MODEL_PATH = "mlx_models/tiny_fp16" @@ -189,7 +188,9 @@ class TestWhisper(unittest.TestCase): self.assertAlmostEqual(result.compression_ratio, 1.2359550561797752) def test_transcribe(self): - result = whisper.transcribe(TEST_AUDIO, model_path=MLX_FP32_MODEL_PATH, fp16=False) + result = whisper.transcribe( + TEST_AUDIO, model_path=MLX_FP32_MODEL_PATH, fp16=False + ) self.assertEqual( result["text"], ( @@ -208,7 +209,9 @@ class TestWhisper(unittest.TestCase): print("bash path_to_whisper_repo/whisper/assets/download_alice.sh") return - result = whisper.transcribe(audio_file, model_path=MLX_FP32_MODEL_PATH, fp16=False) + result = whisper.transcribe( + audio_file, model_path=MLX_FP32_MODEL_PATH, fp16=False + ) self.assertEqual(len(result["text"]), 10920) self.assertEqual(result["language"], "en") self.assertEqual(len(result["segments"]), 77)
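Beyond the `generate.py` CLI covered in the README above, the `models.py` added in this diff exposes `load()` and a streaming `generate()` generator directly. Below is a minimal sketch of driving them from Python; it assumes the snippet is run from `llms/hf_llm` and that `mlx_model` is a converted model directory (the default `--mlx-path` output of `convert.py`) or a compatible Hugging Face repo id.

```python
# Minimal sketch (assumptions: run from llms/hf_llm; "mlx_model" is a converted
# model directory produced by convert.py, or a compatible Hugging Face repo id).
import mlx.core as mx

import models

model, tokenizer = models.load("mlx_model")
prompt = mx.array(tokenizer.encode("In the beginning"))

tokens = []
# models.generate yields one token at a time; cap the output at 100 tokens.
for token, _ in zip(models.generate(prompt, model, temp=0.0), range(100)):
    if token.item() == tokenizer.eos_token_id:
        break
    tokens.append(token.item())

print(tokenizer.decode(tokens))
```

This mirrors what `generate.py` does internally, minus the timing and incremental printing.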