use official HF for mixtral

Awni Hannun 2023-12-14 15:30:32 -08:00
parent 09fff84a85
commit 078fed3d8d
4 changed files with 54 additions and 25 deletions

mixtral/README.md

@@ -17,36 +17,28 @@ brew install git-lfs

 Download the models from Hugging Face:

 ```
-git-lfs clone https://huggingface.co/someone13574/mixtral-8x7b-32kseqlen
-```
-
-After that's done, combine the files:
-
-```
-cd mixtral-8x7b-32kseqlen/
-cat consolidated.00.pth-split0 consolidated.00.pth-split1 consolidated.00.pth-split2 consolidated.00.pth-split3 consolidated.00.pth-split4 consolidated.00.pth-split5 consolidated.00.pth-split6 consolidated.00.pth-split7 consolidated.00.pth-split8 consolidated.00.pth-split9 consolidated.00.pth-split10 > consolidated.00.pth
+GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/mistralai/Mixtral-8x7B-v0.1/
+cd Mixtral-8x7B-v0.1/ && \
+  git lfs pull --include "consolidated.*.pt" && \
+  git lfs pull --include "tokenizer.model"
 ```

 Now from `mlx-examples/mixtral` convert and save the weights as NumPy arrays so
 MLX can read them:

 ```
-python convert.py --model_path mixtral-8x7b-32kseqlen/
+python convert.py --model_path Mixtral-8x7B-v0.1/
 ```

 The conversion script will save the converted weights in the same location.

-After that's done, if you want to clean some stuff up:
-
-```
-rm mixtral-8x7b-32kseqlen/*.pth*
-```
-
 ### Generate

 As easy as:

 ```
-python mixtral.py --model_path mixtral-8x7b-32kseqlen/
+python mixtral.py --model_path Mixtral-8x7B-v0.1/
 ```

-[^mixtral]: Refer to Mistral's [blog post](https://mistral.ai/news/mixtral-of-experts/) for more details.
+[^mixtral]: Refer to Mistral's [blog
+post](https://mistral.ai/news/mixtral-of-experts/) for more details.
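
As an aside, the same selective download can be scripted without git-lfs. A minimal sketch using the `huggingface_hub` package (an assumption on my part; this is not part of the commit and the package must be installed separately):

```
# Illustrative alternative to the git-lfs workflow above: download only the
# raw weights and tokenizer from the official repo via huggingface_hub.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="mistralai/Mixtral-8x7B-v0.1",
    allow_patterns=["consolidated.*.pt", "tokenizer.model"],
    local_dir="Mixtral-8x7B-v0.1",
)
```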

mixtral/convert.py

@@ -1,23 +1,55 @@
 # Copyright © 2023 Apple Inc.

 import argparse
+import glob
+import json
 import numpy as np
 from pathlib import Path
 import torch

+
+def convert(k, v, config):
+    v = v.to(torch.float16).numpy()
+    if "block_sparse_moe" not in k:
+        return [(k, v)]
+    if "gate" in k:
+        return [(k.replace("block_sparse_moe", "feed_forward"), v)]
+
+    # From: layers.N.block_sparse_moe.w
+    # To:   layers.N.feed_forward.experts.M.w.weight
+    num_experts = config["moe"]["num_experts"]
+    key_path = k.split(".")
+    v = np.split(v, num_experts, axis=0)
+    if key_path[-1] == "w2":
+        v = [u.T for u in v]
+
+    w_name = key_path.pop()
+    key_path[-1] = "feed_forward.experts"
+    return [
+        (".".join(key_path + [str(e), w_name, "weight"]), u) for e, u in enumerate(v)
+    ]
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Convert Mixtral weights to MLX.")
     parser.add_argument(
         "--model_path",
         type=str,
-        default="mixtral-8x7b-32kseqlen/",
+        default="Mixtral-8x7B-v0.1/",
         help="The path to the Mixtral model. The MLX model weights will also be saved there.",
     )
     args = parser.parse_args()

     model_path = Path(args.model_path)
-    state = torch.load(str(model_path / "consolidated.00.pth"))
-    np.savez(
-        str(model_path / "weights.npz"),
-        **{k: v.to(torch.float16).numpy() for k, v in state.items()},
-    )
+
+    with open("params.json") as fid:
+        config = json.load(fid)
+
+    torch_files = glob.glob(str(model_path / "consolidated.*.pt"))
+    torch_files = sorted(torch_files, key=lambda tf: int(tf.split(".")[-2]))
+    for e, tf in enumerate(torch_files):
+        print(f"[INFO] Converting file {e + 1}/{len(torch_files)}")
+        state = torch.load(tf)
+        new_state = {}
+        for k, v in state.items():
+            new_state.update(convert(k, v, config))
+        np.savez(str(model_path / f"weights.{e}.npz"), **new_state)
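
For intuition about the `convert` function above, here is a toy, self-contained walkthrough of the expert-splitting step. The tensor name is real, but the shapes are scaled down (the actual model uses `dim=4096`, `hidden_dim=14336`, and 8 experts per `params.json`):

```
# Toy walkthrough (illustrative, not part of the commit) of how convert()
# splits one fused MoE tensor into per-expert weights.
import numpy as np

num_experts = 2          # real model: 8
dim, hidden_dim = 4, 6   # real model: 4096 and 14336

# Fused w1 for one layer: all experts stacked along the first axis.
k = "layers.0.block_sparse_moe.w1"
v = np.zeros((num_experts * hidden_dim, dim), dtype=np.float16)

key_path = k.split(".")
parts = np.split(v, num_experts, axis=0)  # one (hidden_dim, dim) block per expert
# (w2 tensors would additionally be transposed here, as in convert())
w_name = key_path.pop()                   # "w1"
key_path[-1] = "feed_forward.experts"
for e, u in enumerate(parts):
    print(".".join(key_path + [str(e), w_name, "weight"]), u.shape)
# layers.0.feed_forward.experts.0.w1.weight (6, 4)
# layers.0.feed_forward.experts.1.w1.weight (6, 4)
```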

mixtral/mixtral.py

@@ -2,6 +2,7 @@
 import argparse
 from dataclasses import dataclass
+import glob
 import json
 import numpy as np
 from pathlib import Path
@@ -222,10 +223,13 @@ class Tokenizer:

 def load_model(folder: str, dtype=mx.float16):
     model_path = Path(folder)
     tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
-    with open(model_path / "params.json", "r") as f:
+    with open("params.json", "r") as f:
         config = json.loads(f.read())
         model_args = ModelArgs(**config)
-    weights = mx.load(str(model_path / "weights.npz"))
+    weight_files = glob.glob(str(model_path / "weights.*.npz"))
+    weights = {}
+    for wf in weight_files:
+        weights.update(mx.load(wf).items())
     weights = tree_unflatten(list(weights.items()))
     weights = tree_map(lambda p: p.astype(dtype), weights)
     model = Mixtral(model_args)
@@ -255,7 +259,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--model_path",
         type=str,
-        default="mixtral-8x7b-32kseqlen",
+        default="Mixtral-8x7B-v0.1",
         help="The path to the model weights, tokenizer, and config",
     )
     parser.add_argument(
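
One property the new `load_model` relies on: merging shards with `dict.update` silently overwrites any key that appears in more than one `weights.N.npz` file, so the shards written by `convert.py` must have disjoint parameter names. A small illustrative check (`MODEL_DIR` is a hypothetical path, not from this commit):

```
# Sanity check (illustrative): verify no parameter name repeats across shards.
import glob
import numpy as np

MODEL_DIR = "Mixtral-8x7B-v0.1"  # hypothetical local path
seen = set()
for wf in sorted(glob.glob(f"{MODEL_DIR}/weights.*.npz")):
    with np.load(wf) as shard:
        keys = set(shard.files)  # parameter names stored in this shard
    dupes = seen & keys
    assert not dupes, f"duplicate parameters in {wf}: {sorted(dupes)[:3]}"
    seen |= keys
print(f"{len(seen)} unique parameter arrays across all shards")
```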

mixtral/params.json Normal file

@@ -0,0 +1 @@
+{"dim": 4096, "n_layers": 32, "head_dim": 128, "hidden_dim": 14336, "n_heads": 32, "n_kv_heads": 8, "norm_eps": 1e-05, "vocab_size": 32000, "moe": {"num_experts_per_tok": 2, "num_experts": 8}}