From 078fed3d8d8cb24c1eda31f0009edf327659b914 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Thu, 14 Dec 2023 15:30:32 -0800
Subject: [PATCH 1/3] use official HF for mixtral

---
 mixtral/README.md   | 24 ++++++++----------------
 mixtral/convert.py  | 44 ++++++++++++++++++++++++++++++++++++++------
 mixtral/mixtral.py  | 10 +++++++---
 mixtral/params.json |  1 +
 4 files changed, 54 insertions(+), 25 deletions(-)
 create mode 100644 mixtral/params.json

diff --git a/mixtral/README.md b/mixtral/README.md
index b56ee767..a90f7abf 100644
--- a/mixtral/README.md
+++ b/mixtral/README.md
@@ -17,36 +17,28 @@ brew install git-lfs
 Download the models from Hugging Face:
 
 ```
-git-lfs clone https://huggingface.co/someone13574/mixtral-8x7b-32kseqlen
-```
-
-After that's done, combine the files:
-```
-cd mixtral-8x7b-32kseqlen/
-cat consolidated.00.pth-split0 consolidated.00.pth-split1 consolidated.00.pth-split2 consolidated.00.pth-split3 consolidated.00.pth-split4 consolidated.00.pth-split5 consolidated.00.pth-split6 consolidated.00.pth-split7 consolidated.00.pth-split8 consolidated.00.pth-split9 consolidated.00.pth-split10 > consolidated.00.pth
+GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/mistralai/Mixtral-8x7B-v0.1/
+cd Mixtral-8x7B-v0.1/ && \
+  git lfs pull --include "consolidated.*.pt" && \
+  git lfs pull --include "tokenizer.model"
 ```
 
 Now from `mlx-examples/mixtral` convert and save the weights as NumPy arrays so
 MLX can read them:
 
 ```
-python convert.py --model_path mixtral-8x7b-32kseqlen/
+python convert.py --model_path Mixtral-8x7B-v0.1/
 ```
 
 The conversion script will save the converted weights in the same location.
 
-After that's done, if you want to clean some stuff up:
-
-```
-rm mixtral-8x7b-32kseqlen/*.pth*
-```
-
 ### Generate
 
 As easy as:
 
 ```
-python mixtral.py --model_path mixtral-8x7b-32kseqlen/
+python mixtral.py --model_path Mixtral-8x7B-v0.1/
 ```
 
-[^mixtral]: Refer to Mistral's [blog post](https://mistral.ai/news/mixtral-of-experts/) for more details.
+[^mixtral]: Refer to Mistral's [blog
+  post](https://mistral.ai/news/mixtral-of-experts/) for more details.
diff --git a/mixtral/convert.py b/mixtral/convert.py
index e67f4453..d6ba8030 100644
--- a/mixtral/convert.py
+++ b/mixtral/convert.py
@@ -1,23 +1,55 @@
 # Copyright © 2023 Apple Inc.
 
 import argparse
+import glob
+import json
 import numpy as np
 from pathlib import Path
 import torch
 
 
+def convert(k, v, config):
+    v = v.to(torch.float16).numpy()
+    if "block_sparse_moe" not in k:
+        return [(k, v)]
+    if "gate" in k:
+        return [(k.replace("block_sparse_moe", "feed_forward"), v)]
+
+    # From: layers.N.block_sparse_moe.w
+    # To:   layers.N.feed_forward.experts.M.w
+    num_experts = config["moe"]["num_experts"]
+    key_path = k.split(".")
+    v = np.split(v, num_experts, axis=0)
+    if key_path[-1] == "w2":
+        v = [u.T for u in v]
+
+    w_name = key_path.pop()
+    key_path[-1] = "feed_forward.experts"
+    return [
+        (".".join(key_path + [str(e), w_name, "weight"]), u) for e, u in enumerate(v)
+    ]
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Convert Mixtral weights to MLX.")
     parser.add_argument(
         "--model_path",
         type=str,
-        default="mixtral-8x7b-32kseqlen/",
+        default="Mixtral-8x7B-v0.1/",
         help="The path to the Mixtral model. The MLX model weights will also be saved there.",
     )
     args = parser.parse_args()
     model_path = Path(args.model_path)
-    state = torch.load(str(model_path / "consolidated.00.pth"))
-    np.savez(
-        str(model_path / "weights.npz"),
-        **{k: v.to(torch.float16).numpy() for k, v in state.items()},
-    )
+
+    with open("params.json") as fid:
+        args = json.load(fid)
+
+    torch_files = glob.glob(str(model_path / "consolidated.*.pt"))
+    torch_files = sorted(torch_files, key=lambda tf: int(tf.split(".")[-2]))
+    for e, tf in enumerate(torch_files):
+        print(f"[INFO] Converting file {e + 1}/{len(torch_files)}")
+        state = torch.load(tf)
+        new_state = {}
+        for k, v in state.items():
+            new_state.update(convert(k, v, args))
+        np.savez(str(model_path / f"weights.{e}.npz"), **new_state)
diff --git a/mixtral/mixtral.py b/mixtral/mixtral.py
index 1a9be600..59848219 100644
--- a/mixtral/mixtral.py
+++ b/mixtral/mixtral.py
@@ -2,6 +2,7 @@
 
 import argparse
 from dataclasses import dataclass
+import glob
 import json
 import numpy as np
 from pathlib import Path
@@ -222,10 +223,13 @@ class Tokenizer:
 def load_model(folder: str, dtype=mx.float16):
     model_path = Path(folder)
     tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
-    with open(model_path / "params.json", "r") as f:
+    with open("params.json", "r") as f:
         config = json.loads(f.read())
         model_args = ModelArgs(**config)
-    weights = mx.load(str(model_path / "weights.npz"))
+    weight_files = glob.glob(str(model_path / "weights.*.npz"))
+    weights = {}
+    for wf in weight_files:
+        weights.update(mx.load(wf).items())
     weights = tree_unflatten(list(weights.items()))
     weights = tree_map(lambda p: p.astype(dtype), weights)
     model = Mixtral(model_args)
@@ -255,7 +259,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--model_path",
         type=str,
-        default="mixtral-8x7b-32kseqlen",
+        default="Mixtral-8x7B-v0.1",
         help="The path to the model weights, tokenizer, and config",
     )
     parser.add_argument(
diff --git a/mixtral/params.json b/mixtral/params.json
new file mode 100644
index 00000000..f1016aa8
--- /dev/null
+++ b/mixtral/params.json
@@ -0,0 +1 @@
+{"dim": 4096, "n_layers": 32, "head_dim": 128, "hidden_dim": 14336, "n_heads": 32, "n_kv_heads": 8, "norm_eps": 1e-05, "vocab_size": 32000, "moe": {"num_experts_per_tok": 2, "num_experts": 8}}

From e434e7e5c2877535aea0aa6384fbe1f0f91f5646 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Thu, 14 Dec 2023 15:40:38 -0800
Subject: [PATCH 2/3] include instruct option

---
 mixtral/README.md | 35 +++++++++++++++++++++++++++++------
 1 file changed, 29 insertions(+), 6 deletions(-)

diff --git a/mixtral/README.md b/mixtral/README.md
index a90f7abf..3b0c50d0 100644
--- a/mixtral/README.md
+++ b/mixtral/README.md
@@ -2,6 +2,8 @@
 Run the Mixtral[^mixtral] 8x7B mixture-of-experts (MoE) model in MLX on Apple silicon.
 
+This example also supports the instruction fine-tuned Mixtral model.[^instruct]
+
 Note, for 16-bit precision this model needs a machine with substantial RAM (~100GB) to run.
 
 ### Setup
@@ -16,9 +18,23 @@ brew install git-lfs
 Download the models from Hugging Face:
 
+For the base model use:
+
 ```
-GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/mistralai/Mixtral-8x7B-v0.1/
-cd Mixtral-8x7B-v0.1/ && \
+export MIXTRAL_MODEL=Mixtral-8x7B-v0.1
+```
+
+For the instruction fine-tuned model use:
+
+```
+export MIXTRAL_MODEL=Mixtral-8x7B-Instruct-v0.1
+```
+
+Then run:
+
+```
+GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/mistralai/${MIXTRAL_MODEL}/
+cd $MIXTRAL_MODEL/ && \
   git lfs pull --include "consolidated.*.pt" && \
   git lfs pull --include "tokenizer.model"
 ```
@@ -27,7 +43,7 @@ Now from `mlx-examples/mixtral` convert and save the weights as NumPy arrays so
 MLX can read them:
 
 ```
-python convert.py --model_path Mixtral-8x7B-v0.1/
+python convert.py --model_path $MIXTRAL_MODEL/
 ```
 
 The conversion script will save the converted weights in the same location.
@@ -37,8 +53,15 @@ The conversion script will save the converted weights in the same location.
 As easy as:
 
 ```
-python mixtral.py --model_path Mixtral-8x7B-v0.1/
+python mixtral.py --model_path $MIXTRAL_MODEL/
 ```
 
-[^mixtral]: Refer to Mistral's [blog
-  post](https://mistral.ai/news/mixtral-of-experts/) for more details.
+For more options including how to prompt the model, run:
+
+```
+python mixtral.py --help
+```
+
+[^mixtral]: Refer to Mistral's [blog post](https://mistral.ai/news/mixtral-of-experts/) for more details.
+[^instruct]: Refer to the [Hugging Face repo](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1) for more
+details

From b863e7cca0405461c5239f503748ab2f62cef241 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Thu, 14 Dec 2023 16:56:50 -0800
Subject: [PATCH 3/3] format

---
 llama/llama.py                 |  2 +-
 mixtral/README.md              |  8 +++++++-
 phi2/convert.py                |  1 +
 whisper/test.py                |  1 -
 whisper/whisper/decoding.py    |  2 +-
 whisper/whisper/load_models.py |  7 ++++---
 whisper/whisper/transcribe.py  |  2 +-
 whisper/whisper/whisper.py     | 21 ++++++++++++++++++---
 8 files changed, 33 insertions(+), 11 deletions(-)

diff --git a/llama/llama.py b/llama/llama.py
index 9b8157b7..73eb39c5 100644
--- a/llama/llama.py
+++ b/llama/llama.py
@@ -315,7 +315,7 @@ def load_model(model_path):
         config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0]
     if config.get("vocab_size", -1) < 0:
         config["vocab_size"] = weights["output.weight"].shape[-1]
-    unused = ["multiple_of", "ffn_dim_multiplier", 'rope_theta']
+    unused = ["multiple_of", "ffn_dim_multiplier", "rope_theta"]
     for k in unused:
         if k in config:
             config.pop(k)
diff --git a/mixtral/README.md b/mixtral/README.md
index 3b0c50d0..9194979e 100644
--- a/mixtral/README.md
+++ b/mixtral/README.md
@@ -62,6 +62,12 @@ For more options including how to prompt the model, run:
 python mixtral.py --help
 ```
 
-[^mixtral]: Refer to Mistral's [blog post](https://mistral.ai/news/mixtral-of-experts/) for more details.
+For the Instruction model, make sure to follow the prompt format:
+
+```
+[INST] Instruction prompt [/INST]
+```
+
+[^mixtral]: Refer to Mistral's [blog post](https://mistral.ai/news/mixtral-of-experts/) and the [Hugging Face blog post](https://huggingface.co/blog/mixtral) for more details.
 [^instruct]: Refer to the [Hugging Face repo](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1) for more
 details
diff --git a/phi2/convert.py b/phi2/convert.py
index 4c625a6e..5aa07dce 100644
--- a/phi2/convert.py
+++ b/phi2/convert.py
@@ -1,6 +1,7 @@
 from transformers import AutoModelForCausalLM
 import numpy as np
 
+
 def replace_key(key: str) -> str:
     if "wte.weight" in key:
         key = "wte.weight"
diff --git a/whisper/test.py b/whisper/test.py
index 79f233ba..3e7630a9 100644
--- a/whisper/test.py
+++ b/whisper/test.py
@@ -65,7 +65,6 @@ class TestWhisper(unittest.TestCase):
         logits = mlx_model(mels, tokens)
         self.assertEqual(logits.dtype, mx.float16)
 
-
     def test_decode_lang(self):
         options = decoding.DecodingOptions(task="lang_id", fp16=False)
         result = decoding.decode(self.model, self.mels, options)
diff --git a/whisper/whisper/decoding.py b/whisper/whisper/decoding.py
index 7c7c4a93..d5025444 100644
--- a/whisper/whisper/decoding.py
+++ b/whisper/whisper/decoding.py
@@ -112,7 +112,7 @@ class DecodingOptions:
     max_initial_timestamp: Optional[float] = 1.0
 
     # implementation details
-    fp16: bool = True # use fp16 for most of the calculation
+    fp16: bool = True  # use fp16 for most of the calculation
 
 
 @dataclass(frozen=True)
diff --git a/whisper/whisper/load_models.py b/whisper/whisper/load_models.py
index 58cef9ac..ffdccf44 100644
--- a/whisper/whisper/load_models.py
+++ b/whisper/whisper/load_models.py
@@ -44,7 +44,7 @@ _ALIGNMENT_HEADS = {
     "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
     "large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
-    "large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00"
+    "large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
 }
@@ -166,7 +166,8 @@ def convert(model, rules=None):
 
 
 def torch_to_mlx(
-    torch_model: torch_whisper.Whisper, dtype: mx.Dtype = mx.float16,
+    torch_model: torch_whisper.Whisper,
+    dtype: mx.Dtype = mx.float16,
 ) -> whisper.Whisper:
     def convert_rblock(model, rules):
         children = dict(model.named_children())
@@ -194,6 +195,6 @@ def torch_to_mlx(
 def load_model(
     name: str,
     download_root: str = None,
-    dtype : mx.Dtype = mx.float32,
+    dtype: mx.Dtype = mx.float32,
 ) -> whisper.Whisper:
     return torch_to_mlx(load_torch_model(name, download_root), dtype)
diff --git a/whisper/whisper/transcribe.py b/whisper/whisper/transcribe.py
index 3172bdb3..06f3c9ea 100644
--- a/whisper/whisper/transcribe.py
+++ b/whisper/whisper/transcribe.py
@@ -43,7 +43,7 @@ class ModelHolder:
     model_name = None
 
     @classmethod
-    def get_model(cls, model: str, dtype : mx.Dtype):
+    def get_model(cls, model: str, dtype: mx.Dtype):
         if cls.model is None or model != cls.model_name:
             cls.model = load_model(model, dtype=dtype)
             cls.model_name = model
diff --git a/whisper/whisper/whisper.py b/whisper/whisper/whisper.py
index bca69946..8ee6d7d9 100644
--- a/whisper/whisper/whisper.py
+++ b/whisper/whisper/whisper.py
@@ -37,6 +37,7 @@ def sinusoids(length, channels, max_timescale=10000):
     scaled_time = mx.arange(length)[:, None] * inv_timescales[None, :]
     return mx.concatenate([mx.sin(scaled_time), mx.cos(scaled_time)], axis=1)
 
+
 class LayerNorm(nn.LayerNorm):
     def __call__(self, x: mx.array) -> mx.array:
         return super().__call__(x.astype(mx.float32)).astype(x.dtype)
@@ -123,7 +124,13 @@ class ResidualAttentionBlock(nn.Module):
 
 class AudioEncoder(nn.Module):
     def __init__(
-        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, dtype: mx.Dtype = mx.float16,
+        self,
+ n_mels: int, + n_ctx: int, + n_state: int, + n_head: int, + n_layer: int, + dtype: mx.Dtype = mx.float16, ): super().__init__() self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1) @@ -148,7 +155,13 @@ class AudioEncoder(nn.Module): class TextDecoder(nn.Module): def __init__( - self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, dtype: mx.Dtype = mx.float16, + self, + n_vocab: int, + n_ctx: int, + n_state: int, + n_head: int, + n_layer: int, + dtype: mx.Dtype = mx.float16, ): super().__init__() @@ -160,7 +173,9 @@ class TextDecoder(nn.Module): for _ in range(n_layer) ] self.ln = LayerNorm(n_state) - self._mask = nn.MultiHeadAttention.create_additive_causal_mask(n_ctx).astype(dtype) + self._mask = nn.MultiHeadAttention.create_additive_causal_mask(n_ctx).astype( + dtype + ) def __call__(self, x, xa, kv_cache=None): """