Merge pull request #107 from ml-explore/hf_mixtral

Use official HF for mixtral
Awni Hannun 2023-12-14 16:57:19 -08:00 committed by GitHub
commit a3ecda22fe
11 changed files with 106 additions and 32 deletions


@@ -315,7 +315,7 @@ def load_model(model_path):
    config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0]
    if config.get("vocab_size", -1) < 0:
        config["vocab_size"] = weights["output.weight"].shape[-1]
-   unused = ["multiple_of", "ffn_dim_multiplier", 'rope_theta']
+   unused = ["multiple_of", "ffn_dim_multiplier", "rope_theta"]
    for k in unused:
        if k in config:
            config.pop(k)
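
For context, a minimal runnable sketch (with made-up shapes and config values) of the config-inference pattern in the hunk above: missing entries are filled in from checkpoint tensor shapes, then keys the model config does not use are dropped.

```
import numpy as np

# Hypothetical stand-ins for a loaded checkpoint and a partial config.
weights = {
    "layers.0.feed_forward.w1.weight": np.zeros((14336, 4096), dtype=np.float16),
    "output.weight": np.zeros((4096, 32000), dtype=np.float16),
}
config = {"dim": 4096, "vocab_size": -1, "rope_theta": 10000.0}

config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0]
if config.get("vocab_size", -1) < 0:
    config["vocab_size"] = weights["output.weight"].shape[-1]
for k in ["multiple_of", "ffn_dim_multiplier", "rope_theta"]:
    config.pop(k, None)  # drop keys the model config does not use

print(config)  # {'dim': 4096, 'vocab_size': 32000, 'hidden_dim': 14336}
```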


@@ -2,6 +2,8 @@

Run the Mixtral[^mixtral] 8x7B mixture-of-experts (MoE) model in MLX on Apple silicon.

+This example also supports the instruction fine-tuned Mixtral model.[^instruct]
+
Note, for 16-bit precision this model needs a machine with substantial RAM (~100GB) to run.

### Setup

@@ -16,37 +18,56 @@ brew install git-lfs

Download the models from Hugging Face:

+For the base model use:
+
```
-git-lfs clone https://huggingface.co/someone13574/mixtral-8x7b-32kseqlen
+export MIXTRAL_MODEL=Mixtral-8x7B-v0.1
```

-After that's done, combine the files:
+For the instruction fine-tuned model use:
```
-cd mixtral-8x7b-32kseqlen/
-cat consolidated.00.pth-split0 consolidated.00.pth-split1 consolidated.00.pth-split2 consolidated.00.pth-split3 consolidated.00.pth-split4 consolidated.00.pth-split5 consolidated.00.pth-split6 consolidated.00.pth-split7 consolidated.00.pth-split8 consolidated.00.pth-split9 consolidated.00.pth-split10 > consolidated.00.pth
+export MIXTRAL_MODEL=Mixtral-8x7B-Instruct-v0.1
+```
+
+Then run:
+
+```
+GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/mistralai/${MIXTRAL_MODEL}/
+cd $MIXTRAL_MODEL/ && \
+git lfs pull --include "consolidated.*.pt" && \
+git lfs pull --include "tokenizer.model"
```

Now from `mlx-examples/mixtral` convert and save the weights as NumPy arrays so
MLX can read them:

```
-python convert.py --model_path mixtral-8x7b-32kseqlen/
+python convert.py --model_path $MIXTRAL_MODEL/
```

The conversion script will save the converted weights in the same location.

-After that's done, if you want to clean some stuff up:
-
-```
-rm mixtral-8x7b-32kseqlen/*.pth*
-```
-
### Generate

As easy as:

```
-python mixtral.py --model_path mixtral-8x7b-32kseqlen/
+python mixtral.py --model_path $MIXTRAL_MODEL/
```

-[^mixtral]: Refer to Mistral's [blog post](https://mistral.ai/news/mixtral-of-experts/) for more details.
+For more options including how to prompt the model, run:
+
+```
+python mixtral.py --help
+```
+
+For the Instruction model, make sure to follow the prompt format:
+
+```
+[INST] Instruction prompt [/INST]
+```
+
+[^mixtral]: Refer to Mistral's [blog post](https://mistral.ai/news/mixtral-of-experts/) and the [Hugging Face blog post](https://huggingface.co/blog/mixtral) for more details.
+[^instruct]: Refer to the [Hugging Face repo](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1) for more
+details.
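
As a quick illustration of the instruct prompt format described above, a tiny Python sketch (the helper name is made up; only the `[INST] ... [/INST]` wrapper comes from the README):

```
def build_instruct_prompt(instruction: str) -> str:
    # Wrap a raw instruction in the [INST] ... [/INST] format the
    # instruction fine-tuned Mixtral expects.
    return f"[INST] {instruction.strip()} [/INST]"

print(build_instruct_prompt("Write a haiku about Apple silicon."))
# [INST] Write a haiku about Apple silicon. [/INST]
```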


@@ -1,23 +1,55 @@
# Copyright © 2023 Apple Inc.

import argparse
+import glob
+import json
import numpy as np
from pathlib import Path
import torch

+
+def convert(k, v, config):
+    v = v.to(torch.float16).numpy()
+    if "block_sparse_moe" not in k:
+        return [(k, v)]
+    if "gate" in k:
+        return [(k.replace("block_sparse_moe", "feed_forward"), v)]
+
+    # From: layers.N.block_sparse_moe.w
+    # To:   layers.N.feed_forward.experts.M.w
+    num_experts = config["moe"]["num_experts"]
+    key_path = k.split(".")
+    v = np.split(v, num_experts, axis=0)
+    if key_path[-1] == "w2":
+        v = [u.T for u in v]
+
+    w_name = key_path.pop()
+    key_path[-1] = "feed_forward.experts"
+    return [
+        (".".join(key_path + [str(e), w_name, "weight"]), u) for e, u in enumerate(v)
+    ]
+

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert Mixtral weights to MLX.")
    parser.add_argument(
        "--model_path",
        type=str,
-        default="mixtral-8x7b-32kseqlen/",
+        default="Mixtral-8x7B-v0.1/",
        help="The path to the Mixtral model. The MLX model weights will also be saved there.",
    )
    args = parser.parse_args()
    model_path = Path(args.model_path)
-    state = torch.load(str(model_path / "consolidated.00.pth"))
-    np.savez(
-        str(model_path / "weights.npz"),
-        **{k: v.to(torch.float16).numpy() for k, v in state.items()},
-    )
+
+    with open("params.json") as fid:
+        args = json.load(fid)
+
+    torch_files = glob.glob(str(model_path / "consolidated.*.pt"))
+    torch_files = sorted(torch_files, key=lambda tf: int(tf.split(".")[-2]))
+    for e, tf in enumerate(torch_files):
+        print(f"[INFO] Converting file {e + 1}/{len(torch_files)}")
+        state = torch.load(tf)
+        new_state = {}
+        for k, v in state.items():
+            new_state.update(convert(k, v, args))
+        np.savez(str(model_path / f"weights.{e}.npz"), **new_state)
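
To make the expert-splitting step in the new `convert()` above concrete, here is a small self-contained sketch with toy shapes (4 experts, tiny dimensions, made-up values); it mirrors how a stacked `block_sparse_moe` weight is renamed and split into per-expert `feed_forward.experts.M` entries:

```
import numpy as np

num_experts, dim, hidden_dim = 4, 8, 16

# Toy stand-in for layers.0.block_sparse_moe.w1: all experts stacked on axis 0.
k = "layers.0.block_sparse_moe.w1"
v = np.zeros((num_experts * hidden_dim, dim), dtype=np.float16)

key_path = k.split(".")
splits = np.split(v, num_experts, axis=0)
if key_path[-1] == "w2":
    splits = [u.T for u in splits]  # w2 additionally gets transposed

w_name = key_path.pop()                # "w1"
key_path[-1] = "feed_forward.experts"  # replaces "block_sparse_moe"
for e, u in enumerate(splits):
    print(".".join(key_path + [str(e), w_name, "weight"]), u.shape)
# layers.0.feed_forward.experts.0.w1.weight (16, 8)
# ...
# layers.0.feed_forward.experts.3.w1.weight (16, 8)
```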


@@ -2,6 +2,7 @@
import argparse
from dataclasses import dataclass
+import glob
import json
import numpy as np
from pathlib import Path

@@ -222,10 +223,13 @@ class Tokenizer:
def load_model(folder: str, dtype=mx.float16):
    model_path = Path(folder)
    tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
-    with open(model_path / "params.json", "r") as f:
+    with open("params.json", "r") as f:
        config = json.loads(f.read())
        model_args = ModelArgs(**config)
-    weights = mx.load(str(model_path / "weights.npz"))
+    weight_files = glob.glob(str(model_path / "weights.*.npz"))
+    weights = {}
+    for wf in weight_files:
+        weights.update(mx.load(wf).items())
    weights = tree_unflatten(list(weights.items()))
    weights = tree_map(lambda p: p.astype(dtype), weights)
    model = Mixtral(model_args)
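
A minimal sketch of the sharded-weight loading pattern this hunk introduces, written with plain NumPy instead of `mx.load` (the `weights.N.npz` naming follows the convert script above):

```
import glob
import numpy as np

def load_sharded_weights(model_path: str) -> dict:
    """Merge every weights.*.npz shard into one flat {name: array} dict."""
    weights = {}
    for wf in sorted(glob.glob(f"{model_path}/weights.*.npz")):
        with np.load(wf) as shard:
            weights.update({name: shard[name] for name in shard.files})
    return weights
```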
@@ -255,7 +259,7 @@ if __name__ == "__main__":
    parser.add_argument(
        "--model_path",
        type=str,
-        default="mixtral-8x7b-32kseqlen",
+        default="Mixtral-8x7B-v0.1",
        help="The path to the model weights, tokenizer, and config",
    )
    parser.add_argument(

mixtral/params.json (new file)

@@ -0,0 +1 @@
+{"dim": 4096, "n_layers": 32, "head_dim": 128, "hidden_dim": 14336, "n_heads": 32, "n_kv_heads": 8, "norm_eps": 1e-05, "vocab_size": 32000, "moe": {"num_experts_per_tok": 2, "num_experts": 8}}


@@ -1,6 +1,7 @@
from transformers import AutoModelForCausalLM
import numpy as np

def replace_key(key: str) -> str:
    if "wte.weight" in key:
        key = "wte.weight"


@@ -65,7 +65,6 @@ class TestWhisper(unittest.TestCase):
        logits = mlx_model(mels, tokens)
        self.assertEqual(logits.dtype, mx.float16)

    def test_decode_lang(self):
        options = decoding.DecodingOptions(task="lang_id", fp16=False)
        result = decoding.decode(self.model, self.mels, options)


@@ -112,7 +112,7 @@ class DecodingOptions:
    max_initial_timestamp: Optional[float] = 1.0

    # implementation details
    fp16: bool = True  # use fp16 for most of the calculation

@dataclass(frozen=True)


@@ -44,7 +44,7 @@ _ALIGNMENT_HEADS = {
    "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
    "large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
    "large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
-    "large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00"
+    "large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
}

@@ -166,7 +166,8 @@ def convert(model, rules=None):
def torch_to_mlx(
-    torch_model: torch_whisper.Whisper, dtype: mx.Dtype = mx.float16,
+    torch_model: torch_whisper.Whisper,
+    dtype: mx.Dtype = mx.float16,
) -> whisper.Whisper:
    def convert_rblock(model, rules):
        children = dict(model.named_children())

@@ -194,6 +195,6 @@ def torch_to_mlx(
def load_model(
    name: str,
    download_root: str = None,
-    dtype : mx.Dtype = mx.float32,
+    dtype: mx.Dtype = mx.float32,
) -> whisper.Whisper:
    return torch_to_mlx(load_torch_model(name, download_root), dtype)


@@ -43,7 +43,7 @@ class ModelHolder:
    model_name = None

    @classmethod
-    def get_model(cls, model: str, dtype : mx.Dtype):
+    def get_model(cls, model: str, dtype: mx.Dtype):
        if cls.model is None or model != cls.model_name:
            cls.model = load_model(model, dtype=dtype)
            cls.model_name = model


@@ -37,6 +37,7 @@ def sinusoids(length, channels, max_timescale=10000):
    scaled_time = mx.arange(length)[:, None] * inv_timescales[None, :]
    return mx.concatenate([mx.sin(scaled_time), mx.cos(scaled_time)], axis=1)

class LayerNorm(nn.LayerNorm):
    def __call__(self, x: mx.array) -> mx.array:
        return super().__call__(x.astype(mx.float32)).astype(x.dtype)

@@ -123,7 +124,13 @@ class ResidualAttentionBlock(nn.Module):
class AudioEncoder(nn.Module):
    def __init__(
-        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, dtype: mx.Dtype = mx.float16,
+        self,
+        n_mels: int,
+        n_ctx: int,
+        n_state: int,
+        n_head: int,
+        n_layer: int,
+        dtype: mx.Dtype = mx.float16,
    ):
        super().__init__()
        self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1)

@@ -148,7 +155,13 @@ class AudioEncoder(nn.Module):
class TextDecoder(nn.Module):
    def __init__(
-        self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, dtype: mx.Dtype = mx.float16,
+        self,
+        n_vocab: int,
+        n_ctx: int,
+        n_state: int,
+        n_head: int,
+        n_layer: int,
+        dtype: mx.Dtype = mx.float16,
    ):
        super().__init__()

@@ -160,7 +173,9 @@ class TextDecoder(nn.Module):
            for _ in range(n_layer)
        ]
        self.ln = LayerNorm(n_state)
-        self._mask = nn.MultiHeadAttention.create_additive_causal_mask(n_ctx).astype(dtype)
+        self._mask = nn.MultiHeadAttention.create_additive_causal_mask(n_ctx).astype(
+            dtype
+        )

    def __call__(self, x, xa, kv_cache=None):
        """