From 078fed3d8d8cb24c1eda31f0009edf327659b914 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Thu, 14 Dec 2023 15:30:32 -0800
Subject: [PATCH] use official HF for mixtral

---
 mixtral/README.md   | 24 ++++++++----------------
 mixtral/convert.py  | 44 ++++++++++++++++++++++++++++++++++++++------
 mixtral/mixtral.py  | 10 +++++++---
 mixtral/params.json |  1 +
 4 files changed, 54 insertions(+), 25 deletions(-)
 create mode 100644 mixtral/params.json

diff --git a/mixtral/README.md b/mixtral/README.md
index b56ee767..a90f7abf 100644
--- a/mixtral/README.md
+++ b/mixtral/README.md
@@ -17,36 +17,28 @@ brew install git-lfs
 ```
 Download the models from Hugging Face:
 ```
-git-lfs clone https://huggingface.co/someone13574/mixtral-8x7b-32kseqlen
-```
-
-After that's done, combine the files:
-```
-cd mixtral-8x7b-32kseqlen/
-cat consolidated.00.pth-split0 consolidated.00.pth-split1 consolidated.00.pth-split2 consolidated.00.pth-split3 consolidated.00.pth-split4 consolidated.00.pth-split5 consolidated.00.pth-split6 consolidated.00.pth-split7 consolidated.00.pth-split8 consolidated.00.pth-split9 consolidated.00.pth-split10 > consolidated.00.pth
+GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/mistralai/Mixtral-8x7B-v0.1/
+cd Mixtral-8x7B-v0.1/ && \
+  git lfs pull --include "consolidated.*.pt" && \
+  git lfs pull --include "tokenizer.model"
 ```
 
 Now from `mlx-exmaples/mixtral` convert and save the weights as NumPy arrays so
 MLX can read them:
 
 ```
-python convert.py --model_path mixtral-8x7b-32kseqlen/
+python convert.py --model_path Mixtral-8x7B-v0.1/
 ```
 
 The conversion script will save the converted weights in the same location.
 
-After that's done, if you want to clean some stuff up:
-
-```
-rm mixtral-8x7b-32kseqlen/*.pth*
-```
-
 ### Generate
 
 As easy as:
 
 ```
-python mixtral.py --model_path mixtral-8x7b-32kseqlen/
+python mixtral.py --model_path Mixtral-8x7B-v0.1/
 ```
 
-[^mixtral]: Refer to Mistral's [blog post](https://mistral.ai/news/mixtral-of-experts/) for more details.
+[^mixtral]: Refer to Mistral's [blog
+  post](https://mistral.ai/news/mixtral-of-experts/) for more details.
diff --git a/mixtral/convert.py b/mixtral/convert.py
index e67f4453..d6ba8030 100644
--- a/mixtral/convert.py
+++ b/mixtral/convert.py
@@ -1,23 +1,55 @@
 # Copyright © 2023 Apple Inc.
 
 import argparse
+import glob
+import json
 import numpy as np
 from pathlib import Path
 import torch
 
+def convert(k, v, config):
+    v = v.to(torch.float16).numpy()
+    if "block_sparse_moe" not in k:
+        return [(k, v)]
+    if "gate" in k:
+        return [(k.replace("block_sparse_moe", "feed_forward"), v)]
+
+    # From: layers.N.block_sparse_moe.w
+    # To: layers.N.experts.M.w
+    num_experts = args["moe"]["num_experts"]
+    key_path = k.split(".")
+    v = np.split(v, num_experts, axis=0)
+    if key_path[-1] == "w2":
+        v = [u.T for u in v]
+
+    w_name = key_path.pop()
+    key_path[-1] = "feed_forward.experts"
+    return [
+        (".".join(key_path + [str(e), w_name, "weight"]), u) for e, u in enumerate(v)
+    ]
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Convert Mixtral weights to MLX.")
     parser.add_argument(
         "--model_path",
         type=str,
-        default="mixtral-8x7b-32kseqlen/",
+        default="Mixtral-8x7B-v0.1/",
         help="The path to the Mixtral model. "
         "The MLX model weights will also be saved there.",
     )
     args = parser.parse_args()
     model_path = Path(args.model_path)
-    state = torch.load(str(model_path / "consolidated.00.pth"))
-    np.savez(
-        str(model_path / "weights.npz"),
-        **{k: v.to(torch.float16).numpy() for k, v in state.items()},
-    )
+
+    with open("params.json") as fid:
+        args = json.load(fid)
+
+    torch_files = glob.glob(str(model_path / "consolidated.*.pt"))
+    torch_files = sorted(torch_files, key=lambda tf: int(tf.split(".")[-2]))
+    for e, tf in enumerate(torch_files):
+        print(f"[INFO] Converting file {e + 1}/{len(torch_files)}")
+        state = torch.load(tf)
+        new_state = {}
+        for k, v in state.items():
+            new_state.update(convert(k, v, args))
+        np.savez(str(model_path / f"weights.{e}.npz"), **new_state)
diff --git a/mixtral/mixtral.py b/mixtral/mixtral.py
index 1a9be600..59848219 100644
--- a/mixtral/mixtral.py
+++ b/mixtral/mixtral.py
@@ -2,6 +2,7 @@
 
 import argparse
 from dataclasses import dataclass
+import glob
 import json
 import numpy as np
 from pathlib import Path
@@ -222,10 +223,13 @@ class Tokenizer:
 def load_model(folder: str, dtype=mx.float16):
     model_path = Path(folder)
     tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
-    with open(model_path / "params.json", "r") as f:
+    with open("params.json", "r") as f:
         config = json.loads(f.read())
         model_args = ModelArgs(**config)
-    weights = mx.load(str(model_path / "weights.npz"))
+    weight_files = glob.glob(str(model_path / "weights.*.npz"))
+    weights = {}
+    for wf in weight_files:
+        weights.update(mx.load(wf).items())
     weights = tree_unflatten(list(weights.items()))
     weights = tree_map(lambda p: p.astype(dtype), weights)
     model = Mixtral(model_args)
@@ -255,7 +259,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--model_path",
         type=str,
-        default="mixtral-8x7b-32kseqlen",
+        default="Mixtral-8x7B-v0.1",
         help="The path to the model weights, tokenizer, and config",
     )
     parser.add_argument(
diff --git a/mixtral/params.json b/mixtral/params.json
new file mode 100644
index 00000000..f1016aa8
--- /dev/null
+++ b/mixtral/params.json
@@ -0,0 +1 @@
+{"dim": 4096, "n_layers": 32, "head_dim": 128, "hidden_dim": 14336, "n_heads": 32, "n_kv_heads": 8, "norm_eps": 1e-05, "vocab_size": 32000, "moe": {"num_experts_per_tok": 2, "num_experts": 8}}
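
For reference, a minimal standalone sketch of the per-expert key renaming and tensor splitting that the new `convert` function in this patch applies to `block_sparse_moe` weights. The key `layers.0.block_sparse_moe.w2` and the tiny array shapes are illustrative stand-ins, not real Mixtral checkpoint sizes (those come from `params.json`):

```
# Toy walk-through of the renaming/splitting done by convert() in convert.py.
# Shapes are tiny stand-ins; only num_experts matches params.json.
import numpy as np

num_experts = 8  # params.json: config["moe"]["num_experts"]

k = "layers.0.block_sparse_moe.w2"                # hypothetical checkpoint key
v = np.zeros((num_experts * 4, 2), dtype=np.float16)

key_path = k.split(".")
parts = np.split(v, num_experts, axis=0)          # one slab per expert
if key_path[-1] == "w2":
    parts = [u.T for u in parts]                  # convert() transposes only w2

w_name = key_path.pop()                           # "w2"
key_path[-1] = "feed_forward.experts"
for e, u in enumerate(parts[:2]):
    print(".".join(key_path + [str(e), w_name, "weight"]), u.shape)
# layers.0.feed_forward.experts.0.w2.weight (2, 4)
# layers.0.feed_forward.experts.1.w2.weight (2, 4)
```

Each stacked expert slab becomes its own `feed_forward.experts.M.<w>.weight` entry, which is the layout the updated `load_model` in `mixtral.py` reads back from the sharded `weights.*.npz` files.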