Merge pull request #107 from ml-explore/hf_mixtral

Use official HF for mixtral
This commit is contained in:
Awni Hannun 2023-12-14 16:57:19 -08:00 committed by GitHub
commit a3ecda22fe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 106 additions and 32 deletions

View File

@ -315,7 +315,7 @@ def load_model(model_path):
config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0]
if config.get("vocab_size", -1) < 0:
config["vocab_size"] = weights["output.weight"].shape[-1]
unused = ["multiple_of", "ffn_dim_multiplier", 'rope_theta']
unused = ["multiple_of", "ffn_dim_multiplier", "rope_theta"]
for k in unused:
if k in config:
config.pop(k)

View File

@ -2,6 +2,8 @@
Run the Mixtral[^mixtral] 8x7B mixture-of-experts (MoE) model in MLX on Apple silicon.
This example also supports the instruction fine-tuned Mixtral model.[^instruct]
Note, for 16-bit precision this model needs a machine with substantial RAM (~100GB) to run.
### Setup
@ -16,37 +18,56 @@ brew install git-lfs
Download the models from Hugging Face:
For the base model use:
```
git-lfs clone https://huggingface.co/someone13574/mixtral-8x7b-32kseqlen
export MIXTRAL_MODEL=Mixtral-8x7B-v0.1
```
After that's done, combine the files:
For the instruction fine-tuned model use:
```
cd mixtral-8x7b-32kseqlen/
cat consolidated.00.pth-split0 consolidated.00.pth-split1 consolidated.00.pth-split2 consolidated.00.pth-split3 consolidated.00.pth-split4 consolidated.00.pth-split5 consolidated.00.pth-split6 consolidated.00.pth-split7 consolidated.00.pth-split8 consolidated.00.pth-split9 consolidated.00.pth-split10 > consolidated.00.pth
export MIXTRAL_MODEL=Mixtral-8x7B-Instruct-v0.1
```
Then run:
```
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/mistralai/${MIXTRAL_MODEL}/
cd $MIXTRAL_MODEL/ && \
git lfs pull --include "consolidated.*.pt" && \
git lfs pull --include "tokenizer.model"
```
Now from `mlx-exmaples/mixtral` convert and save the weights as NumPy arrays so
MLX can read them:
```
python convert.py --model_path mixtral-8x7b-32kseqlen/
python convert.py --model_path $MIXTRAL_MODEL/
```
The conversion script will save the converted weights in the same location.
After that's done, if you want to clean some stuff up:
```
rm mixtral-8x7b-32kseqlen/*.pth*
```
### Generate
As easy as:
```
python mixtral.py --model_path mixtral-8x7b-32kseqlen/
python mixtral.py --model_path $MIXTRAL_MODEL/
```
[^mixtral]: Refer to Mistral's [blog post](https://mistral.ai/news/mixtral-of-experts/) for more details.
For more options including how to prompt the model, run:
```
python mixtral.py --help
```
For the Instruction model, make sure to follow the prompt format:
```
[INST] Instruction prompt [/INST]
```
[^mixtral]: Refer to Mistral's [blog post](https://mistral.ai/news/mixtral-of-experts/) and the [Hugging Face blog post](https://huggingface.co/blog/mixtral) for more details.
[^instruc]: Refer to the [Hugging Face repo](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1) for more
details

View File

@ -1,23 +1,55 @@
# Copyright © 2023 Apple Inc.
import argparse
import glob
import json
import numpy as np
from pathlib import Path
import torch
def convert(k, v, config):
v = v.to(torch.float16).numpy()
if "block_sparse_moe" not in k:
return [(k, v)]
if "gate" in k:
return [(k.replace("block_sparse_moe", "feed_forward"), v)]
# From: layers.N.block_sparse_moe.w
# To: layers.N.experts.M.w
num_experts = args["moe"]["num_experts"]
key_path = k.split(".")
v = np.split(v, num_experts, axis=0)
if key_path[-1] == "w2":
v = [u.T for u in v]
w_name = key_path.pop()
key_path[-1] = "feed_forward.experts"
return [
(".".join(key_path + [str(e), w_name, "weight"]), u) for e, u in enumerate(v)
]
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert Mixtral weights to MLX.")
parser.add_argument(
"--model_path",
type=str,
default="mixtral-8x7b-32kseqlen/",
default="Mixtral-8x7B-v0.1/",
help="The path to the Mixtral model. The MLX model weights will also be saved there.",
)
args = parser.parse_args()
model_path = Path(args.model_path)
state = torch.load(str(model_path / "consolidated.00.pth"))
np.savez(
str(model_path / "weights.npz"),
**{k: v.to(torch.float16).numpy() for k, v in state.items()},
)
with open("params.json") as fid:
args = json.load(fid)
torch_files = glob.glob(str(model_path / "consolidated.*.pt"))
torch_files = sorted(torch_files, key=lambda tf: int(tf.split(".")[-2]))
for e, tf in enumerate(torch_files):
print(f"[INFO] Converting file {e + 1}/{len(torch_files)}")
state = torch.load(tf)
new_state = {}
for k, v in state.items():
new_state.update(convert(k, v, args))
np.savez(str(model_path / f"weights.{e}.npz"), **new_state)

View File

@ -2,6 +2,7 @@
import argparse
from dataclasses import dataclass
import glob
import json
import numpy as np
from pathlib import Path
@ -222,10 +223,13 @@ class Tokenizer:
def load_model(folder: str, dtype=mx.float16):
model_path = Path(folder)
tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
with open(model_path / "params.json", "r") as f:
with open("params.json", "r") as f:
config = json.loads(f.read())
model_args = ModelArgs(**config)
weights = mx.load(str(model_path / "weights.npz"))
weight_files = glob.glob(str(model_path / "weights.*.npz"))
weights = {}
for wf in weight_files:
weights.update(mx.load(wf).items())
weights = tree_unflatten(list(weights.items()))
weights = tree_map(lambda p: p.astype(dtype), weights)
model = Mixtral(model_args)
@ -255,7 +259,7 @@ if __name__ == "__main__":
parser.add_argument(
"--model_path",
type=str,
default="mixtral-8x7b-32kseqlen",
default="Mixtral-8x7B-v0.1",
help="The path to the model weights, tokenizer, and config",
)
parser.add_argument(

1
mixtral/params.json Normal file
View File

@ -0,0 +1 @@
{"dim": 4096, "n_layers": 32, "head_dim": 128, "hidden_dim": 14336, "n_heads": 32, "n_kv_heads": 8, "norm_eps": 1e-05, "vocab_size": 32000, "moe": {"num_experts_per_tok": 2, "num_experts": 8}}

View File

@ -1,6 +1,7 @@
from transformers import AutoModelForCausalLM
import numpy as np
def replace_key(key: str) -> str:
if "wte.weight" in key:
key = "wte.weight"

View File

@ -65,7 +65,6 @@ class TestWhisper(unittest.TestCase):
logits = mlx_model(mels, tokens)
self.assertEqual(logits.dtype, mx.float16)
def test_decode_lang(self):
options = decoding.DecodingOptions(task="lang_id", fp16=False)
result = decoding.decode(self.model, self.mels, options)

View File

@ -112,7 +112,7 @@ class DecodingOptions:
max_initial_timestamp: Optional[float] = 1.0
# implementation details
fp16: bool = True # use fp16 for most of the calculation
fp16: bool = True # use fp16 for most of the calculation
@dataclass(frozen=True)

View File

@ -44,7 +44,7 @@ _ALIGNMENT_HEADS = {
"large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
"large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
"large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
"large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00"
"large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
}
@ -166,7 +166,8 @@ def convert(model, rules=None):
def torch_to_mlx(
torch_model: torch_whisper.Whisper, dtype: mx.Dtype = mx.float16,
torch_model: torch_whisper.Whisper,
dtype: mx.Dtype = mx.float16,
) -> whisper.Whisper:
def convert_rblock(model, rules):
children = dict(model.named_children())
@ -194,6 +195,6 @@ def torch_to_mlx(
def load_model(
name: str,
download_root: str = None,
dtype : mx.Dtype = mx.float32,
dtype: mx.Dtype = mx.float32,
) -> whisper.Whisper:
return torch_to_mlx(load_torch_model(name, download_root), dtype)

View File

@ -43,7 +43,7 @@ class ModelHolder:
model_name = None
@classmethod
def get_model(cls, model: str, dtype : mx.Dtype):
def get_model(cls, model: str, dtype: mx.Dtype):
if cls.model is None or model != cls.model_name:
cls.model = load_model(model, dtype=dtype)
cls.model_name = model

View File

@ -37,6 +37,7 @@ def sinusoids(length, channels, max_timescale=10000):
scaled_time = mx.arange(length)[:, None] * inv_timescales[None, :]
return mx.concatenate([mx.sin(scaled_time), mx.cos(scaled_time)], axis=1)
class LayerNorm(nn.LayerNorm):
def __call__(self, x: mx.array) -> mx.array:
return super().__call__(x.astype(mx.float32)).astype(x.dtype)
@ -123,7 +124,13 @@ class ResidualAttentionBlock(nn.Module):
class AudioEncoder(nn.Module):
def __init__(
self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, dtype: mx.Dtype = mx.float16,
self,
n_mels: int,
n_ctx: int,
n_state: int,
n_head: int,
n_layer: int,
dtype: mx.Dtype = mx.float16,
):
super().__init__()
self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1)
@ -148,7 +155,13 @@ class AudioEncoder(nn.Module):
class TextDecoder(nn.Module):
def __init__(
self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, dtype: mx.Dtype = mx.float16,
self,
n_vocab: int,
n_ctx: int,
n_state: int,
n_head: int,
n_layer: int,
dtype: mx.Dtype = mx.float16,
):
super().__init__()
@ -160,7 +173,9 @@ class TextDecoder(nn.Module):
for _ in range(n_layer)
]
self.ln = LayerNorm(n_state)
self._mask = nn.MultiHeadAttention.create_additive_causal_mask(n_ctx).astype(dtype)
self._mask = nn.MultiHeadAttention.create_additive_causal_mask(n_ctx).astype(
dtype
)
def __call__(self, x, xa, kv_cache=None):
"""