From 738448c2d4d55585cbb178c8b54f55664a8cbe89 Mon Sep 17 00:00:00 2001
From: Yifan
Date: Mon, 25 Dec 2023 22:10:01 +0800
Subject: [PATCH 1/6] QWEN: Fix unsupported ScalarType BFloat16 (#187)

Fix unsupported ScalarType BFloat16.
---
 llms/qwen/convert.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llms/qwen/convert.py b/llms/qwen/convert.py
index 88135208..e91be263 100644
--- a/llms/qwen/convert.py
+++ b/llms/qwen/convert.py
@@ -60,7 +60,7 @@ def convert(args):
         args.model, trust_remote_code=True, torch_dtype=torch.float16
     )
     state_dict = model.state_dict()
-    weights = {replace_key(k): v.numpy() for k, v in state_dict.items()}
+    weights = {replace_key(k): (v.numpy() if v.dtype != torch.bfloat16 else v.to(torch.float32).numpy()) for k, v in state_dict.items()}
     config = model.config.to_dict()
 
     if args.quantize:

From 2bd20ef0e06a7705f5230b90078668612f22f468 Mon Sep 17 00:00:00 2001
From: Daniel Strobusch <1847260+dastrobu@users.noreply.github.com>
Date: Mon, 25 Dec 2023 20:19:43 +0100
Subject: [PATCH 2/6] shard llama model after conversion and unshard on loading (#174)

---
 llms/llama/convert.py | 24 ++++++++++++++++++++++--
 llms/llama/llama.py   | 20 ++++++++++++++++++--
 2 files changed, 40 insertions(+), 4 deletions(-)

diff --git a/llms/llama/convert.py b/llms/llama/convert.py
index dae337ee..6f5285c3 100644
--- a/llms/llama/convert.py
+++ b/llms/llama/convert.py
@@ -15,7 +15,6 @@
 import torch
 from llama import Llama, ModelArgs, sanitize_config
 from mlx.utils import tree_flatten, tree_map, tree_unflatten
-
 def llama(model_path):
     SHARD_FIRST = ["wv", "wq", "wk", "w1", "w3", "output"]
     SHARD_SECOND = ["tok_embeddings", "wo", "w2"]
@@ -140,6 +139,22 @@ def quantize(weights, config, args):
     return quantized_weights, quantized_config
 
 
+def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
+    max_file_size_bytes = max_file_size_gibibyte << 30
+    shards = []
+    shard, shard_size = {}, 0
+    for k, v in weights.items():
+        # TODO: simplify to v.nbytes as soon as mx.array exposes it
+        estimated_size = v.size * v.dtype.size if isinstance(v, mx.array) else v.nbytes
+        if shard_size + estimated_size > max_file_size_bytes:
+            shards.append(shard)
+            shard, shard_size = {}, 0
+        shard[k] = v
+        shard_size += estimated_size
+    shards.append(shard)
+    return shards
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
     parser.add_argument(
@@ -200,6 +215,11 @@ if __name__ == "__main__":
         str(torch_path / "tokenizer.model"),
         str(mlx_path / "tokenizer.model"),
     )
-    np.savez(str(mlx_path / "weights.npz"), **weights)
+    shards = make_shards(weights)
+    if len(shards) == 1:
+        np.savez(str(mlx_path / f"weights.npz"), **shards[0])
+    else:
+        for i, shard in enumerate(shards):
+            np.savez(str(mlx_path / f"weights.{i:02d}.npz"), **shard)
     with open(mlx_path / "config.json", "w") as fid:
         json.dump(params, fid, indent=4)

diff --git a/llms/llama/llama.py b/llms/llama/llama.py
index d684ed6d..97ec4101 100644
--- a/llms/llama/llama.py
+++ b/llms/llama/llama.py
@@ -3,6 +3,7 @@
 import argparse
 import json
 import time
+import glob
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Optional, Tuple
@@ -330,7 +331,23 @@ def sanitize_config(config, weights):
 
 def load_model(model_path):
     model_path = Path(model_path)
-    weights = mx.load(str(model_path / "weights.npz"))
+
+    unsharded_weights_path = Path(model_path / "weights.npz")
+    if unsharded_weights_path.is_file():
+        print("[INFO] Loading model from {}.".format(unsharded_weights_path))
+        weights = mx.load(str(unsharded_weights_path))
+    else:
+        sharded_weights_glob = str(model_path / "weights.*.npz")
+        weight_files = glob.glob(sharded_weights_glob)
+        print("[INFO] Loading model from {}.".format(sharded_weights_glob))
+
+        if len(weight_files) == 0:
+            raise FileNotFoundError("No weights found in {}".format(model_path))
+
+        weights = {}
+        for wf in weight_files:
+            weights.update(mx.load(wf).items())
+
     with open(model_path / "config.json", "r") as f:
         config = sanitize_config(json.loads(f.read()), weights)
         quantization = config.pop("quantization", None)
@@ -373,7 +390,6 @@ if __name__ == "__main__":
 
     mx.random.seed(args.seed)
 
-    print("[INFO] Loading model from disk.")
     model, tokenizer = load_model(args.model_path)
     if args.few_shot:
         few_shot_generate(args)

From a516f4635d048ad3c110dd4af37b7bb11faa38dd Mon Sep 17 00:00:00 2001
From: Sushant
Date: Tue, 26 Dec 2023 23:02:43 +0530
Subject: [PATCH 3/6] Fixed the return type for the __call__ method in Attention (#190)

---
 llms/llama/llama.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llms/llama/llama.py b/llms/llama/llama.py
index 97ec4101..1b44d650 100644
--- a/llms/llama/llama.py
+++ b/llms/llama/llama.py
@@ -67,7 +67,7 @@ class Attention(nn.Module):
         x: mx.array,
         mask: Optional[mx.array] = None,
         cache: Optional[Tuple[mx.array, mx.array]] = None,
-    ) -> mx.array:
+    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
         B, L, D = x.shape
 
         queries, keys, values = self.wq(x), self.wk(x), self.wv(x)

From 50fceb1a284bd6c09abc81cda24bf6d0a2505220 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ji=C5=99=C3=AD=20Morav=C4=8D=C3=ADk?=
Date: Wed, 27 Dec 2023 00:18:59 +0100
Subject: [PATCH 4/6] fix: Add numpy to CIFAR's requirements.txt (#192)

---
 cifar/requirements.txt | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/cifar/requirements.txt b/cifar/requirements.txt
index 6ff78a64..c4c2e575 100644
--- a/cifar/requirements.txt
+++ b/cifar/requirements.txt
@@ -1,2 +1,3 @@
 mlx
-mlx-data
\ No newline at end of file
+mlx-data
+numpy
\ No newline at end of file

From 78d207fe27987c981c2e660c74ee763a6be85fc5 Mon Sep 17 00:00:00 2001
From: Sunbir Gill
Date: Wed, 27 Dec 2023 16:11:10 -0500
Subject: [PATCH 5/6] Fix generate example in README (#197)

---
 llms/mixtral/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llms/mixtral/README.md b/llms/mixtral/README.md
index 49e50c91..e4673530 100644
--- a/llms/mixtral/README.md
+++ b/llms/mixtral/README.md
@@ -61,7 +61,7 @@ the converted `weights.npz`, `tokenizer.model`, and `config.json` there.
 As easy as:
 
 ```
-python mixtral.py --model-path $MIXTRAL_MODEL/
+python mixtral.py --model-path mlx_model
 ```
 
 For more options including how to prompt the model, run:

From e1e56a625b801109e82d413cfa6e7e6db0f6a250 Mon Sep 17 00:00:00 2001
From: bofeng huang
Date: Thu, 28 Dec 2023 20:29:39 +0100
Subject: [PATCH 6/6] Fix benchmark (#200)

---
 whisper/benchmark.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/whisper/benchmark.py b/whisper/benchmark.py
index 10025952..877bb4f0 100644
--- a/whisper/benchmark.py
+++ b/whisper/benchmark.py
@@ -41,8 +41,8 @@ def decode(model, mels):
     return decoding.decode(model, mels)
 
 
-def everything():
-    return transcribe(audio_file)
+def everything(model_name):
+    return transcribe(audio_file, model=model_name)
 
 
 if __name__ == "__main__":
@@ -99,6 +99,6 @@ if __name__ == "__main__":
     print(f"Model forward time {model_forward_time:.3f}")
     decode_time = timer(decode, model, mels)
     print(f"Decode time {decode_time:.3f}")
-    everything_time = timer(everything)
+    everything_time = timer(everything, model_name)
     print(f"Everything time {everything_time:.3f}")
    print(f"\n{'-----' * 10}\n")
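
For readers skimming the series, the sketch below replays the shard-and-reload flow that PATCH 2/6 adds to `convert.py` and `llama.py`, but as a small standalone script using plain NumPy arrays instead of `mx.array`. It is an illustration only: the `make_shards` helper mirrors the patch, while the `demo_mlx_model` directory, the toy `layers.*.w` arrays, and the final print are assumptions made up for this example.

```python
import glob
from pathlib import Path

import numpy as np


def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
    """Greedily pack weights into shards no larger than the given budget."""
    max_file_size_bytes = max_file_size_gibibyte << 30
    shards, shard, shard_size = [], {}, 0
    for k, v in weights.items():
        if shard_size + v.nbytes > max_file_size_bytes:
            shards.append(shard)
            shard, shard_size = {}, 0
        shard[k] = v
        shard_size += v.nbytes
    shards.append(shard)
    return shards


if __name__ == "__main__":
    out_dir = Path("demo_mlx_model")  # hypothetical output directory
    out_dir.mkdir(exist_ok=True)

    # Toy stand-ins for converted weights; real checkpoints are far larger.
    weights = {f"layers.{i}.w": np.ones((8, 8), dtype=np.float16) for i in range(4)}

    # Convert side: one file when everything fits, numbered shards otherwise.
    shards = make_shards(weights)
    if len(shards) == 1:
        np.savez(str(out_dir / "weights.npz"), **shards[0])
    else:
        for i, shard in enumerate(shards):
            np.savez(str(out_dir / f"weights.{i:02d}.npz"), **shard)

    # Load side: prefer the unsharded file, otherwise glob and merge shards.
    merged = {}
    unsharded = out_dir / "weights.npz"
    files = [str(unsharded)] if unsharded.is_file() else glob.glob(str(out_dir / "weights.*.npz"))
    for wf in files:
        data = np.load(wf)
        merged.update({k: data[k] for k in data.files})
    print(f"loaded {len(merged)} arrays from {len(files)} file(s)")
```

With the default 15 GiB budget the toy weights land in a single `weights.npz`, which is the same single-file case the unchanged `load_model` path already handled; the numbered-shard branch only kicks in for checkpoints larger than the budget.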