Awni Hannun 2023-12-05 11:02:52 -08:00
parent 1900564f59
commit b7840a4721
6 changed files with 469 additions and 0 deletions

mistral/.gitignore vendored Normal file

@@ -0,0 +1 @@
mistral-7B-v0.1/

mistral/README.md Normal file

@@ -0,0 +1,39 @@
# Mistral

An example of generating text with Mistral using MLX.

Mistral 7B is one of the top large language models in its size class. It is also fully open source with a permissive license[^1].

### Setup

Install the dependencies:
```
pip install -r requirements.txt
```

Next, download the model and tokenizer.

```
curl -O https://files.mistral-7b-v0-1.mistral.ai/mistral-7B-v0.1.tar
tar -xf mistral-7B-v0.1.tar
```

Then, convert the weights with:

```
python convert.py --torch_model mistral-7B-v0.1/consolidated.00.pth --mlx_model mistral-7B-v0.1/mlx_mistral_7b.npz
```

Those paths are also the script's defaults, so after extracting the archive into this directory a plain `python convert.py` does the same thing.

### Run

Once you've converted the weights to MLX format, you can interact with the
Mistral model:

```
python mistral.py --model_path mistral-7B-v0.1 --prompt "hello"
```

Run `python mistral.py --help` for more details.
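
The options `mistral.py` accepts are `--model_path`, `--prompt`, `--max_tokens` (`-m`), and `--seed`. For example, to generate 50 tokens with a fixed seed:

```
python mistral.py --model_path mistral-7B-v0.1 \
  --prompt "In the beginning the Universe was created." \
  --max_tokens 50 --seed 0
```

If you would rather drive generation from Python, here is a minimal sketch built on the `load_model` and `generate` functions in `mistral.py` (it assumes the converted weights, `params.json`, and `tokenizer.model` all live in `mistral-7B-v0.1/`):

```
import mlx.core as mx
from mistral import load_model, generate

model, tokenizer = load_model("mistral-7B-v0.1")
prompt = mx.array(tokenizer.encode("hello"))

# Take the first 20 generated tokens, then materialize and decode them.
tokens = []
for token, _ in zip(generate(prompt, model), range(20)):
    tokens.append(token)
mx.eval(tokens)
print(tokenizer.decode([t.item() for t in tokens]))
```
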
[^1]: Refer to the [blog post](https://mistral.ai/news/announcing-mistral-7b/) and [github repository](https://github.com/mistralai/mistral-src) for more details.

mistral/convert.py Normal file

@@ -0,0 +1,27 @@
# Copyright © 2023 Apple Inc.
import argparse
import numpy as np
import torch
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert Mistral weights to MLX.")
parser.add_argument(
"--torch_model",
type=str,
default="mistral-7B-v0.1/consolidated.00.pth",
help="The path to the torch model weights",
)
parser.add_argument(
"--mlx_model",
type=str,
default="mistral-7B-v0.1/mlx_mistral_7b.npz",
help="The path to store the mlx model weights",
)
args = parser.parse_args()
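    # Load the PyTorch checkpoint and store every tensor as float16 in a single
    # .npz archive, which mistral.py later reads back with mx.load.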
state = torch.load(args.torch_model)
np.savez(
args.mlx_model, **{k: v.to(torch.float16).numpy() for k, v in state.items()}
)

mistral/mistral.py Normal file

@@ -0,0 +1,281 @@
# Copyright © 2023 Apple Inc.
import argparse
from dataclasses import dataclass
import json
from pathlib import Path
from typing import Optional, Tuple, List
from sentencepiece import SentencePieceProcessor
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_map, tree_unflatten
@dataclass
class ModelArgs:
dim: int
n_layers: int
head_dim: int
hidden_dim: int
n_heads: int
n_kv_heads: int
norm_eps: float
vocab_size: int
class RMSNorm(nn.Module):
def __init__(self, dims: int, eps: float = 1e-5):
super().__init__()
self.weight = mx.ones((dims,))
self.eps = eps
def _norm(self, x):
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
def __call__(self, x):
output = self._norm(x.astype(mx.float32)).astype(x.dtype)
return self.weight * output
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.n_heads: int = args.n_heads
self.n_kv_heads: int = args.n_kv_heads
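        # Grouped-query attention: each of the n_kv_heads key/value heads is
        # shared by n_heads // n_kv_heads query heads.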
self.repeats = self.n_heads // self.n_kv_heads
self.scale = self.args.head_dim**-0.5
self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False)
self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
self.rope = nn.RoPE(args.head_dim, traditional=True)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
B, L, D = x.shape
queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
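        # Tile the key/value heads along the head axis so they line up
        # one-to-one with the query heads.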
def repeat(a):
a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
return a.reshape([B, self.n_heads, L, -1])
keys, values = map(repeat, (keys, values))
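        # With a cache, rotate the new positions past the cached length, then
        # append the new keys/values to the cache along the sequence axis.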
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
if mask is not None:
scores += mask
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
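        # The commented-out block below is an alternative grouped-query
        # formulation that broadcasts the un-repeated keys/values instead of
        # tiling them; it is kept for reference and not used.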
# queries = queries.reshape(B, self.n_kv_heads, self.repeats, L, -1)
# scores = (queries * self.scale) @ mx.expand_dims(keys.transpose(0, 1, 3, 2), 2)
# if mask is not None:
# scores += mask
# scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
# output = (scores @ mx.expand_dims(values, 2)).reshape(B, self.n_heads, L, -1)
# output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.wo(output), (keys, values)
class FeedForward(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False)
self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False)
self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False)
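    # SwiGLU-style feed-forward: w2(silu(w1(x)) * w3(x)).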
def __call__(self, x) -> mx.array:
return self.w2(nn.silu(self.w1(x)) * self.w3(x))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_heads = args.n_heads
self.dim = args.dim
self.attention = Attention(args)
self.feed_forward = FeedForward(args=args)
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
self.args = args
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
r, cache = self.attention(self.attention_norm(x), mask, cache)
h = x + r
r = self.feed_forward(self.ffn_norm(h))
out = h + r
return out, cache
class Mistral(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.n_layers = args.n_layers
assert self.vocab_size > 0
self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
self.layers = [TransformerBlock(args=args) for _ in range(args.n_layers)]
self.norm = RMSNorm(args.dim, eps=args.norm_eps)
self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache=None,
):
h = self.tok_embeddings(inputs)
mask = None
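        # Build a causal mask only when processing a multi-token prompt; single
        # token decoding steps attend to everything already in the cache.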
if h.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
mask = mask.astype(h.dtype)
if cache is None:
cache = [None] * len(self.layers)
for e, layer in enumerate(self.layers):
h, cache[e] = layer(h, mask, cache[e])
return self.output(self.norm(h)), cache
class Tokenizer:
def __init__(self, model_path: str):
assert Path(model_path).exists(), model_path
self._model = SentencePieceProcessor(model_file=model_path)
        self._sep = "▁"  # SentencePiece's meta symbol for a leading space
assert self._model.vocab_size() == self._model.get_piece_size()
@property
def eos_id(self) -> int:
return self._model.eos_id()
@property
def pad_id(self) -> int:
return self._model.pad_id()
def encode(self, s: str) -> List[int]:
return [self._model.bos_id(), *self._model.encode(s)]
def decode(self, t: List[int]) -> str:
out = self._model.decode(t)
if t and self._model.id_to_piece(t[0])[0] == self._sep:
return " " + out
return out
def load_model(folder: str, dtype=mx.float16):
model_path = Path(folder)
tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
with open(model_path / "params.json", "r") as f:
config = json.loads(f.read())
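    # params.json contains a sliding_window entry that this implementation does
    # not use, so drop it before constructing ModelArgs.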
config.pop("sliding_window")
model_args = ModelArgs(**config)
weights = mx.load(str(model_path / "mlx_mistral_7b.npz"))
weights = tree_unflatten(list(weights.items()))
weights = tree_map(lambda p: p.astype(dtype), weights)
model = Mistral(model_args)
model.update(weights)
return model, tokenizer
def generate(prompt: mx.array, model: Mistral, temp: Optional[float] = 0.0):
def sample(logits):
if temp == 0:
return mx.argmax(logits, axis=-1)
else:
return mx.random.categorical(logits * (1 / temp))
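    # Process the whole prompt once to build the key/value cache, then generate
    # one token at a time while reusing that cache.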
logits, cache = model(prompt[None])
y = sample(logits[:, -1, :])
yield y
while True:
logits, cache = model(y[:, None], cache)
y = sample(logits.squeeze(1))
yield y
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Mistral inference script")
parser.add_argument(
"--model_path",
type=str,
default="mistral-7B-v0.1",
help="The path to the model weights and tokenizer",
)
parser.add_argument(
"--prompt",
help="The message to be processed by the model",
default="In the beginning the Universe was created.",
)
parser.add_argument(
"--max_tokens",
"-m",
type=int,
default=100,
help="Maximum number of tokens to generate",
)
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
args = parser.parse_args()
mx.random.seed(args.seed)
print("[INFO] Loading model from disk.")
model, tokenizer = load_model(args.model_path)
print("[INFO] Starting generation...")
print(args.prompt, end="", flush=True)
prompt = mx.array(tokenizer.encode(args.prompt))
tokens = []
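    # Generate up to max_tokens tokens, evaluating and printing them in batches
    # of 10 as they become available.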
for token, _ in zip(generate(prompt, model), range(args.max_tokens)):
tokens.append(token)
if (len(tokens) % 10) == 0:
mx.eval(tokens)
s = tokenizer.decode([t.item() for t in tokens])
print(s, end="", flush=True)
tokens = []
mx.eval(tokens)
s = tokenizer.decode([t.item() for t in tokens])
print(s, flush=True)
print("------")

mistral/requirements.txt Normal file

@@ -0,0 +1,3 @@
mlx
sentencepiece
torch

mistral/test.py Normal file

@@ -0,0 +1,118 @@
# Copyright © 2023 Apple Inc.
import unittest
import mlx.core as mx
from mlx.utils import tree_map
import mistral
class TestMistral(unittest.TestCase):
def test_model(self):
vocab_size = 100
L = 32
args = mistral.ModelArgs(
dim=128,
n_layers=2,
head_dim=32,
hidden_dim=256,
n_heads=4,
n_kv_heads=4,
norm_eps=1e-3,
vocab_size=vocab_size,
)
model = mistral.Mistral(args)
inputs = mx.random.randint(0, vocab_size, (L,))
logits, cache = model(inputs[None])
self.assertEqual(logits.shape, [1, L, vocab_size])
self.assertEqual(logits.dtype, mx.float32)
self.assertEqual(len(cache), args.n_layers)
params = tree_map(lambda p: p.astype(mx.float16), model.parameters())
model.update(params)
logits, _ = model(inputs[None])
self.assertEqual(logits.dtype, mx.float16)
def test_generate(self):
model, tokenizer = mistral.load_model("mistral-7B-v0.1")
prompt = mx.array(tokenizer.encode("This is a test"))
tokens = [t for t, _ in zip(mistral.generate(prompt, model), range(30))]
mx.eval(tokens)
tokens = [t.item() for t in tokens]
expected = [
302,
272,
11843,
11837,
1587,
28723,
851,
349,
865,
264,
1369,
28723,
13,
13,
3381,
456,
654,
264,
1353,
11843,
28725,
368,
682,
347,
2240,
767,
298,
511,
28723,
13,
]
self.assertEqual(tokens, expected)
def benchmark(self):
import time
model, tokenizer = mistral.load_model("mistral-7B-v0.1")
prompt = mx.random.randint(0, model.vocab_size, (128,))
# warmup
for _ in range(2):
generator = mistral.generate(prompt, model)
mx.eval(next(generator))
tic = time.time()
its = 5
for _ in range(its):
generator = mistral.generate(prompt, model)
mx.eval(next(generator))
toc = time.time()
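        # Each iteration processes the full prompt, so throughput is
        # iterations * prompt length / elapsed time.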
tps = its * prompt.size / (toc - tic)
print(f"Prompt processing: {tps:.2f} tokens per second")
# warmup
for _ in range(2):
tokens = [t for t, _ in zip(mistral.generate(prompt, model), range(101))]
mx.eval(tokens)
time_total = 0.0
its = 2
for _ in range(its):
generator = mistral.generate(prompt, model)
mx.eval(next(generator))
tic = time.time()
tokens = [t for t, _ in zip(generator, range(100))]
mx.eval(tokens)
time_total += time.time() - tic
tps = len(tokens) * its / time_total
print(f"Token generation: {tps:.3f} tokens per second")
if __name__ == "__main__":
unittest.main()