use official HF for mixtral

This commit is contained in:
Awni Hannun 2023-12-14 15:30:32 -08:00
parent 09fff84a85
commit 078fed3d8d
4 changed files with 54 additions and 25 deletions

View File

@ -17,36 +17,28 @@ brew install git-lfs
Download the models from Hugging Face:
```
git-lfs clone https://huggingface.co/someone13574/mixtral-8x7b-32kseqlen
```
After that's done, combine the files:
```
cd mixtral-8x7b-32kseqlen/
cat consolidated.00.pth-split0 consolidated.00.pth-split1 consolidated.00.pth-split2 consolidated.00.pth-split3 consolidated.00.pth-split4 consolidated.00.pth-split5 consolidated.00.pth-split6 consolidated.00.pth-split7 consolidated.00.pth-split8 consolidated.00.pth-split9 consolidated.00.pth-split10 > consolidated.00.pth
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/mistralai/Mixtral-8x7B-v0.1/
cd Mixtral-8x7B-v0.1/ && \
git lfs pull --include "consolidated.*.pt" && \
git lfs pull --include "tokenizer.model"
```
Now from `mlx-exmaples/mixtral` convert and save the weights as NumPy arrays so
MLX can read them:
```
python convert.py --model_path mixtral-8x7b-32kseqlen/
python convert.py --model_path Mixtral-8x7B-v0.1/
```
The conversion script will save the converted weights in the same location.
After that's done, if you want to clean some stuff up:
```
rm mixtral-8x7b-32kseqlen/*.pth*
```
### Generate
As easy as:
```
python mixtral.py --model_path mixtral-8x7b-32kseqlen/
python mixtral.py --model_path Mixtral-8x7B-v0.1/
```
[^mixtral]: Refer to Mistral's [blog post](https://mistral.ai/news/mixtral-of-experts/) for more details.
[^mixtral]: Refer to Mistral's [blog
post](https://mistral.ai/news/mixtral-of-experts/) for more details.

View File

@ -1,23 +1,55 @@
# Copyright © 2023 Apple Inc.
import argparse
import glob
import json
import numpy as np
from pathlib import Path
import torch
def convert(k, v, config):
v = v.to(torch.float16).numpy()
if "block_sparse_moe" not in k:
return [(k, v)]
if "gate" in k:
return [(k.replace("block_sparse_moe", "feed_forward"), v)]
# From: layers.N.block_sparse_moe.w
# To: layers.N.experts.M.w
num_experts = args["moe"]["num_experts"]
key_path = k.split(".")
v = np.split(v, num_experts, axis=0)
if key_path[-1] == "w2":
v = [u.T for u in v]
w_name = key_path.pop()
key_path[-1] = "feed_forward.experts"
return [
(".".join(key_path + [str(e), w_name, "weight"]), u) for e, u in enumerate(v)
]
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert Mixtral weights to MLX.")
parser.add_argument(
"--model_path",
type=str,
default="mixtral-8x7b-32kseqlen/",
default="Mixtral-8x7B-v0.1/",
help="The path to the Mixtral model. The MLX model weights will also be saved there.",
)
args = parser.parse_args()
model_path = Path(args.model_path)
state = torch.load(str(model_path / "consolidated.00.pth"))
np.savez(
str(model_path / "weights.npz"),
**{k: v.to(torch.float16).numpy() for k, v in state.items()},
)
with open("params.json") as fid:
args = json.load(fid)
torch_files = glob.glob(str(model_path / "consolidated.*.pt"))
torch_files = sorted(torch_files, key=lambda tf: int(tf.split(".")[-2]))
for e, tf in enumerate(torch_files):
print(f"[INFO] Converting file {e + 1}/{len(torch_files)}")
state = torch.load(tf)
new_state = {}
for k, v in state.items():
new_state.update(convert(k, v, args))
np.savez(str(model_path / f"weights.{e}.npz"), **new_state)

View File

@ -2,6 +2,7 @@
import argparse
from dataclasses import dataclass
import glob
import json
import numpy as np
from pathlib import Path
@ -222,10 +223,13 @@ class Tokenizer:
def load_model(folder: str, dtype=mx.float16):
model_path = Path(folder)
tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
with open(model_path / "params.json", "r") as f:
with open("params.json", "r") as f:
config = json.loads(f.read())
model_args = ModelArgs(**config)
weights = mx.load(str(model_path / "weights.npz"))
weight_files = glob.glob(str(model_path / "weights.*.npz"))
weights = {}
for wf in weight_files:
weights.update(mx.load(wf).items())
weights = tree_unflatten(list(weights.items()))
weights = tree_map(lambda p: p.astype(dtype), weights)
model = Mixtral(model_args)
@ -255,7 +259,7 @@ if __name__ == "__main__":
parser.add_argument(
"--model_path",
type=str,
default="mixtral-8x7b-32kseqlen",
default="Mixtral-8x7B-v0.1",
help="The path to the model weights, tokenizer, and config",
)
parser.add_argument(

1
mixtral/params.json Normal file
View File

@ -0,0 +1 @@
{"dim": 4096, "n_layers": 32, "head_dim": 128, "hidden_dim": 14336, "n_heads": 32, "n_kv_heads": 8, "norm_eps": 1e-05, "vocab_size": 32000, "moe": {"num_experts_per_tok": 2, "num_experts": 8}}