Merge pull request #75 from ml-explore/mixtral

Mixtral
Awni Hannun 2023-12-12 10:41:25 -08:00 committed by GitHub
commit 9b3f64d196
7 changed files with 398 additions and 14 deletions

mistral/README.md

@@ -2,7 +2,8 @@
 
 An example of generating text with Mistral using MLX.
 
-Mistral 7B is one of the top large language models in its size class. It is also fully open source with a permissive license[^1].
+Mistral 7B is one of the top large language models in its size class. It is
+also fully open source with a permissive license[^1].
 
 ### Setup
@@ -25,6 +26,8 @@ Then, convert the weights with:
 python convert.py
 ```
 
+The conversion script will save the converted weights in the same location.
+
 ### Run
 
 Once you've converted the weights to MLX format, you can generate text with
@@ -36,4 +39,6 @@ python mistral.py --prompt "It is a truth universally acknowledged," --temp 0
 
 Run `python mistral.py --help` for more details.
 
-[^1]: Refer to the [blog post](https://mistral.ai/news/announcing-mistral-7b/) and [github repository](https://github.com/mistralai/mistral-src) for more details.
+[^1]: Refer to the [blog post](https://mistral.ai/news/announcing-mistral-7b/)
+and [github repository](https://github.com/mistralai/mistral-src) for more
+details.

mistral/convert.py

@@ -2,26 +2,23 @@
 
 import argparse
 import numpy as np
 from pathlib import Path
 import torch
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Convert Mistral weights to MLX.")
     parser.add_argument(
-        "--torch_model",
+        "--model_path",
         type=str,
-        default="mistral-7B-v0.1/consolidated.00.pth",
-        help="The path to the torch model weights",
-    )
-    parser.add_argument(
-        "--mlx_model",
-        type=str,
-        default="mistral-7B-v0.1/mlx_mistral_7b.npz",
-        help="The path to store the mlx model weights",
+        default="mistral-7B-v0.1/",
+        help="The path to the Mistral model. The MLX weights will also be saved there.",
     )
     args = parser.parse_args()
 
-    state = torch.load(args.torch_model)
+    model_path = Path(args.model_path)
+    state = torch.load(str(model_path / "consolidated.00.pth"))
     np.savez(
-        args.mlx_model, **{k: v.to(torch.float16).numpy() for k, v in state.items()}
+        str(model_path / "weights.npz"),
+        **{k: v.to(torch.float16).numpy() for k, v in state.items()}
     )
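To spot-check a conversion, the saved archive can be opened directly with `mlx.core.load`, which returns a plain dict of arrays for a `.npz` file. A minimal sketch (the path assumes the default `--model_path` above):

```
import mlx.core as mx

# mx.load on a .npz archive returns a dict mapping parameter names to arrays.
weights = mx.load("mistral-7B-v0.1/weights.npz")
print(len(weights), "tensors")
for name in list(weights)[:3]:
    print(name, weights[name].shape, weights[name].dtype)
```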

mistral/mistral.py

@@ -196,7 +196,7 @@ def load_model(folder: str, dtype=mx.float16):
         config = json.loads(f.read())
         config.pop("sliding_window")
         model_args = ModelArgs(**config)
-    weights = mx.load(str(model_path / "mlx_mistral_7b.npz"))
+    weights = mx.load(str(model_path / "weights.npz"))
     weights = tree_unflatten(list(weights.items()))
     weights = tree_map(lambda p: p.astype(dtype), weights)
     model = Mistral(model_args)
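For context on the two `tree_*` helpers in this hunk: `tree_unflatten` rebuilds a nested parameter tree from the flat, dot-separated keys stored in the `.npz`, and `tree_map` applies a function to every leaf. A rough illustration with a made-up key (not taken from the real checkpoint):

```
import mlx.core as mx
from mlx.utils import tree_map, tree_unflatten

flat = [("layers.0.attention.wq.weight", mx.zeros((4, 4)))]
nested = tree_unflatten(flat)  # integer path components become list indices
nested = tree_map(lambda p: p.astype(mx.float16), nested)  # cast every leaf
```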

mixtral/README.md (new file, 52 lines)

@ -0,0 +1,52 @@
## Mixtral 8x7B
Run the Mixtral[^mixtral] 8x7B mixture-of-experts (MoE) model in MLX on Apple silicon.
Note: in 16-bit precision this model needs a machine with a substantial amount of RAM (roughly 100 GB) to run.
### Setup
Install [Git Large File
Storage](https://docs.github.com/en/repositories/working-with-files/managing-large-files/installing-git-large-file-storage).
For example with Homebrew:
```
brew install git-lfs
```
Download the model from Hugging Face:
```
git clone https://huggingface.co/someone13574/mixtral-8x7b-32kseqlen
```
After that's done, combine the files:
```
cd mixtral-8x7b-32kseqlen/
cat consolidated.00.pth-split0 consolidated.00.pth-split1 consolidated.00.pth-split2 consolidated.00.pth-split3 consolidated.00.pth-split4 consolidated.00.pth-split5 consolidated.00.pth-split6 consolidated.00.pth-split7 consolidated.00.pth-split8 consolidated.00.pth-split9 consolidated.00.pth-split10 > consolidated.00.pth
```
Now, from `mlx-examples/mixtral`, convert and save the weights as NumPy arrays so
MLX can read them:
```
python convert.py --model_path mixtral-8x7b-32kseqlen/
```
The conversion script will save the converted weights in the same location.
After that's done, you can optionally remove the original PyTorch checkpoint files to free up disk space:
```
rm mixtral-8x7b-32kseqlen/*.pth*
```
### Generate
As easy as:
```
python mixtral.py --model_path mixtral-8x7b-32kseqlen/
```
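The script also accepts a prompt, a token limit, and a sampling temperature; the values below are simply the defaults from the argument parser in `mixtral.py`:

```
python mixtral.py --model_path mixtral-8x7b-32kseqlen/ \
  --prompt "In the beginning the Universe was created." \
  --max_tokens 100 --temp 1.0
```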
[^mixtral]: Refer to Mistral's [blog post](https://mistral.ai/news/mixtral-of-experts/) for more details.

mixtral/convert.py (new file, 23 lines)

@ -0,0 +1,23 @@
# Copyright © 2023 Apple Inc.
import argparse
import numpy as np
from pathlib import Path
import torch
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert Mixtral weights to MLX.")
parser.add_argument(
"--model_path",
type=str,
default="mixtral-8x7b-32kseqlen/",
help="The path to the Mixtral model. The MLX model weights will also be saved there.",
)
args = parser.parse_args()
model_path = Path(args.model_path)
state = torch.load(str(model_path / "consolidated.00.pth"))
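# Cast every tensor to float16 and write a single weights.npz that mx.load can read directly.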
np.savez(
str(model_path / "weights.npz"),
**{k: v.to(torch.float16).numpy() for k, v in state.items()},
)

mixtral/mixtral.py (new file, 303 lines)

@ -0,0 +1,303 @@
# Copyright © 2023 Apple Inc.
import argparse
from dataclasses import dataclass
import json
import numpy as np
from pathlib import Path
from typing import Optional, Tuple, List
from sentencepiece import SentencePieceProcessor
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_map, tree_unflatten
@dataclass
class ModelArgs:
dim: int
n_layers: int
head_dim: int
hidden_dim: int
n_heads: int
n_kv_heads: int
norm_eps: float
vocab_size: int
moe: dict = None
class RMSNorm(nn.Module):
def __init__(self, dims: int, eps: float = 1e-5):
super().__init__()
self.weight = mx.ones((dims,))
self.eps = eps
def _norm(self, x):
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
def __call__(self, x):
output = self._norm(x.astype(mx.float32)).astype(x.dtype)
return self.weight * output
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.n_heads: int = args.n_heads
self.n_kv_heads: int = args.n_kv_heads
self.repeats = self.n_heads // self.n_kv_heads
self.scale = self.args.head_dim**-0.5
self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False)
self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
self.rope = nn.RoPE(args.head_dim, traditional=True)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
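# Grouped-query attention: each key/value head is shared by n_heads // n_kv_heads
# query heads, so tile the key/value heads up to n_heads before computing attention.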
def repeat(a):
a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
return a.reshape([B, self.n_heads, L, -1])
keys, values = map(repeat, (keys, values))
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
if mask is not None:
scores += mask
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.wo(output), (keys, values)
class FeedForward(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False)
self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False)
self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False)
def __call__(self, x) -> mx.array:
return self.w2(nn.silu(self.w1(x)) * self.w3(x))
class MOEFeedForward(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.num_experts = args.moe["num_experts"]
self.num_experts_per_tok = args.moe["num_experts_per_tok"]
self.experts = [FeedForward(args) for _ in range(self.num_experts)]
self.gate = nn.Linear(args.dim, self.num_experts, bias=False)
def __call__(self, x) -> mx.array:
ne = self.num_experts_per_tok
orig_shape = x.shape
x = x.reshape(-1, x.shape[-1])
gates = self.gate(x)
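# Top-k routing: argpartition gives the indices of the ne largest gate logits per
# token (unordered within the top-k); a softmax over just those logits gives the
# mixture weights.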
inds = mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]
scores = mx.softmax(mx.take_along_axis(gates, inds, axis=-1), axis=-1)
y = []
for xt, st, it in zip(x, scores, inds.tolist()):
yt = mx.concatenate([self.experts[e](xt)[:, None] for e in it], axis=-1)
yt = (yt * st).sum(axis=-1)
y.append(yt[None, :])
y = mx.concatenate(y)
return y.reshape(orig_shape)
class MOETransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_heads = args.n_heads
self.dim = args.dim
self.attention = Attention(args)
self.feed_forward = MOEFeedForward(args=args)
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
self.args = args
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
r, cache = self.attention(self.attention_norm(x), mask, cache)
h = x + r
r = self.feed_forward(self.ffn_norm(h))
out = h + r
return out, cache
class Mixtral(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.n_layers = args.n_layers
assert self.vocab_size > 0
self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
self.layers = [MOETransformerBlock(args=args) for _ in range(args.n_layers)]
self.norm = RMSNorm(args.dim, eps=args.norm_eps)
self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache=None,
):
h = self.tok_embeddings(inputs)
mask = None
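# A causal mask is only needed when processing several tokens at once (the prompt);
# single-token decoding steps attend to everything already in the cache.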
if h.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
mask = mask.astype(h.dtype)
if cache is None:
cache = [None] * len(self.layers)
for e, layer in enumerate(self.layers):
h, cache[e] = layer(h, mask, cache[e])
return self.output(self.norm(h)), cache
class Tokenizer:
def __init__(self, model_path: str):
assert Path(model_path).exists(), model_path
self._model = SentencePieceProcessor(model_file=model_path)
self._sep = "▁"
assert self._model.vocab_size() == self._model.get_piece_size()
@property
def eos_id(self) -> int:
return self._model.eos_id()
@property
def pad_id(self) -> int:
return self._model.pad_id()
def encode(self, s: str) -> List[int]:
return [self._model.bos_id(), *self._model.encode(s)]
def decode(self, t: List[int]) -> str:
out = self._model.decode(t)
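# SentencePiece marks word boundaries with the separator piece; if the first token
# started a new word, restore the leading space that decode() strips.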
if t and self._model.id_to_piece(t[0])[0] == self._sep:
return " " + out
return out
def load_model(folder: str, dtype=mx.float16):
model_path = Path(folder)
tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
with open(model_path / "params.json", "r") as f:
config = json.loads(f.read())
model_args = ModelArgs(**config)
weights = mx.load(str(model_path / "weights.npz"))
weights = tree_unflatten(list(weights.items()))
weights = tree_map(lambda p: p.astype(dtype), weights)
model = Mixtral(model_args)
model.update(weights)
return model, tokenizer
def generate(prompt: mx.array, model: Mixtral, temp: Optional[float] = 0.0):
def sample(logits):
if temp == 0:
return mx.argmax(logits, axis=-1)
else:
return mx.random.categorical(logits * (1 / temp))
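# Encode the whole prompt in a single forward pass, then decode autoregressively
# one token at a time, reusing the key/value cache.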
logits, cache = model(prompt[None])
y = sample(logits[:, -1, :])
yield y
while True:
logits, cache = model(y[:, None], cache)
y = sample(logits.squeeze(1))
yield y
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Mixtral inference script")
parser.add_argument(
"--model_path",
type=str,
default="mixtral-8x7b-32kseqlen",
help="The path to the model weights, tokenizer, and config",
)
parser.add_argument(
"--prompt",
help="The message to be processed by the model",
default="In the beginning the Universe was created.",
)
parser.add_argument(
"--max_tokens",
"-m",
type=int,
default=100,
help="Maximum number of tokens to generate",
)
parser.add_argument(
"--temp",
help="The sampling temperature.",
type=float,
default=1.0,
)
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
args = parser.parse_args()
mx.random.seed(args.seed)
print("[INFO] Loading model from disk.")
model, tokenizer = load_model(args.model_path)
print("[INFO] Starting generation...")
print(args.prompt, end="", flush=True)
prompt = mx.array(tokenizer.encode(args.prompt))
tokens = []
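# Evaluate and print in chunks of 10 tokens so the lazy computation graph is
# materialized periodically rather than after every single token.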
for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
tokens.append(token)
if (len(tokens) % 10) == 0:
mx.eval(tokens)
s = tokenizer.decode([t.item() for t in tokens])
print(s, end="", flush=True)
tokens = []
mx.eval(tokens)
s = tokenizer.decode([t.item() for t in tokens])
print(s, flush=True)
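For reference, a minimal sketch of driving the functions above from Python instead of the command line (the model path and prompt are assumptions; point it at wherever your converted weights and tokenizer live):

```
import mlx.core as mx

model, tokenizer = load_model("mixtral-8x7b-32kseqlen")
prompt = mx.array(tokenizer.encode("The capital of France is"))
tokens = []
for token, _ in zip(generate(prompt, model, temp=0.0), range(20)):
    tokens.append(token.item())  # .item() forces evaluation of the lazy array
print(tokenizer.decode(tokens))
```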

mixtral/requirements.txt (new file, 4 lines)

@ -0,0 +1,4 @@
mlx
sentencepiece
torch
numpy