Add llms subdir + update README (#145)

* add llms subdir + update README

* nits

* use same pre-commit as mlx

* update readmes a bit

* format
Awni Hannun
2023-12-20 10:22:25 -08:00
committed by GitHub
parent 3c64f1a1dc
commit b3f23a7191
62 changed files with 164 additions and 146 deletions

1
llms/mistral/.gitignore vendored Normal file

@@ -0,0 +1 @@
mistral-7B-v0.1/

50
llms/mistral/README.md Normal file

@@ -0,0 +1,50 @@
# Mistral

An example of generating text with Mistral using MLX.

Mistral 7B is one of the top large language models in its size class. It is
also fully open source with a permissive license[^1].

### Setup

Install the dependencies:

```
pip install -r requirements.txt
```
Next, download the model and tokenizer:

```
curl -O https://files.mistral-7b-v0-1.mistral.ai/mistral-7B-v0.1.tar
tar -xf mistral-7B-v0.1.tar
```
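
A quick way to sanity-check the download (a minimal sketch, not part of the example) is to confirm the extracted folder contains the raw checkpoint, the params file, and the tokenizer that the scripts below expect:

```
# Minimal check, assuming the tar was extracted in the current directory.
from pathlib import Path

model_dir = Path("mistral-7B-v0.1")
for name in ("consolidated.00.pth", "params.json", "tokenizer.model"):
    print(name, "ok" if (model_dir / name).exists() else "missing")
```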
Then, convert the weights with:

```
python convert.py
```

The conversion script will save the converted weights in the same location.
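
To confirm the conversion succeeded, the saved `weights.npz` can be opened with MLX directly. This is just an illustrative check; the key name assumes the standard Mistral checkpoint layout:

```
# Sketch of a post-conversion sanity check (not part of the example).
import mlx.core as mx

weights = mx.load("mistral-7B-v0.1/weights.npz")
print(f"{len(weights)} arrays saved")
# "tok_embeddings.weight" assumes the standard Mistral key names.
print(weights["tok_embeddings.weight"].shape, weights["tok_embeddings.weight"].dtype)
```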
> [!TIP]
> Alternatively, you can also download a few converted checkpoints from the
> [MLX Community](https://huggingface.co/mlx-community) organization on Hugging
> Face and skip the conversion step.
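
For example, a converted checkpoint could be fetched with the `huggingface_hub` package (an extra dependency not listed in `requirements.txt`); the repository id below is a placeholder, so substitute the MLX Community repo you actually want:

```
# Hedged sketch only: huggingface_hub is an extra dependency and the repo id
# here is a placeholder, not confirmed by this commit.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="mlx-community/Mistral-7B-v0.1",  # hypothetical repo id
    local_dir="mistral-7B-v0.1",
)
```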
### Run

Once you've converted the weights to MLX format, you can generate text with
the Mistral model:

```
python mistral.py --prompt "It is a truth universally acknowledged," --temp 0
```

Run `python mistral.py --help` for more details.
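
The model can also be used programmatically. The following is a minimal sketch of the same loop `mistral.py` runs, using `load_model` and `generate` from this example and assuming the converted weights live in `mistral-7B-v0.1/`:

```
# Minimal programmatic sketch using the functions defined in mistral.py.
import mlx.core as mx

import mistral

model, tokenizer = mistral.load_model("mistral-7B-v0.1")
prompt = mx.array(tokenizer.encode("It is a truth universally acknowledged,"))

tokens = []
for token, _ in zip(mistral.generate(prompt, model, temp=0.0), range(50)):
    tokens.append(token)
mx.eval(tokens)
print(tokenizer.decode([t.item() for t in tokens]))
```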
[^1]: Refer to the [blog post](https://mistral.ai/news/announcing-mistral-7b/)
and [GitHub repository](https://github.com/mistralai/mistral-src) for more
details.

32
llms/mistral/convert.py Normal file

@@ -0,0 +1,32 @@
# Copyright © 2023 Apple Inc.

import argparse
import json
from pathlib import Path

import numpy as np
import torch

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert Mistral weights to MLX.")
    parser.add_argument(
        "--model_path",
        type=str,
        default="mistral-7B-v0.1/",
        help="The path to the Mistral model. The MLX weights will also be saved there.",
    )
    args = parser.parse_args()

    model_path = Path(args.model_path)
    state = torch.load(str(model_path / "consolidated.00.pth"))
    np.savez(
        str(model_path / "weights.npz"),
        **{k: v.to(torch.float16).numpy() for k, v in state.items()}
    )

    # Save config.json with model_type
    with open(model_path / "params.json", "r") as f:
        config = json.loads(f.read())
    config["model_type"] = "mistral"
    with open(model_path / "config.json", "w") as f:
        json.dump(config, f, indent=4)

282
llms/mistral/mistral.py Normal file

@@ -0,0 +1,282 @@
# Copyright © 2023 Apple Inc.

import argparse
import json
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_map, tree_unflatten
from sentencepiece import SentencePieceProcessor


@dataclass
class ModelArgs:
    dim: int
    n_layers: int
    head_dim: int
    hidden_dim: int
    n_heads: int
    n_kv_heads: int
    norm_eps: float
    vocab_size: int


class RMSNorm(nn.Module):
    def __init__(self, dims: int, eps: float = 1e-5):
        super().__init__()
        self.weight = mx.ones((dims,))
        self.eps = eps

    def _norm(self, x):
        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)

    def __call__(self, x):
        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
        return self.weight * output


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args

        self.n_heads: int = args.n_heads
        self.n_kv_heads: int = args.n_kv_heads

        self.repeats = self.n_heads // self.n_kv_heads

        self.scale = self.args.head_dim**-0.5

        self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
        self.rope = nn.RoPE(args.head_dim, traditional=True)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        B, L, D = x.shape

        queries, keys, values = self.wq(x), self.wk(x), self.wv(x)

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        def repeat(a):
            a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
            return a.reshape([B, self.n_heads, L, -1])

        keys, values = map(repeat, (keys, values))

        if cache is not None:
            key_cache, value_cache = cache
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = mx.concatenate([key_cache, keys], axis=2)
            values = mx.concatenate([value_cache, values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
        if mask is not None:
            scores += mask
        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
        output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.wo(output), (keys, values)


class FeedForward(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False)
        self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False)
        self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False)

    def __call__(self, x) -> mx.array:
        return self.w2(nn.silu(self.w1(x)) * self.w3(x))


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.attention = Attention(args)
        self.feed_forward = FeedForward(args=args)
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.args = args

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        r, cache = self.attention(self.attention_norm(x), mask, cache)
        h = x + r
        r = self.feed_forward(self.ffn_norm(h))
        out = h + r
        return out, cache


class Mistral(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.n_layers = args.n_layers
        assert self.vocab_size > 0
        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
        self.layers = [TransformerBlock(args=args) for _ in range(args.n_layers)]
        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        h = self.tok_embeddings(inputs)

        mask = None
        if h.shape[1] > 1:
            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
            mask = mask.astype(h.dtype)

        if cache is None:
            cache = [None] * len(self.layers)

        for e, layer in enumerate(self.layers):
            h, cache[e] = layer(h, mask, cache[e])

        return self.output(self.norm(h)), cache


class Tokenizer:
    def __init__(self, model_path: str):
        assert Path(model_path).exists(), model_path
        self._model = SentencePieceProcessor(model_file=model_path)
        self._sep = "▁"
        assert self._model.vocab_size() == self._model.get_piece_size()

    @property
    def eos_id(self) -> int:
        return self._model.eos_id()

    @property
    def pad_id(self) -> int:
        return self._model.pad_id()

    def encode(self, s: str) -> List[int]:
        return [self._model.bos_id(), *self._model.encode(s)]

    def decode(self, t: List[int]) -> str:
        out = self._model.decode(t)
        if t and self._model.id_to_piece(t[0])[0] == self._sep:
            return " " + out
        return out


def load_model(folder: str, dtype=mx.float16):
    model_path = Path(folder)
    tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
    with open(model_path / "config.json", "r") as f:
        config = json.loads(f.read())
        config.pop("sliding_window", None)
        config.pop("model_type", None)
        model_args = ModelArgs(**config)
    weights = mx.load(str(model_path / "weights.npz"))
    weights = tree_unflatten(list(weights.items()))
    weights = tree_map(lambda p: p.astype(dtype), weights)
    model = Mistral(model_args)
    model.update(weights)
    return model, tokenizer


def generate(prompt: mx.array, model: Mistral, temp: Optional[float] = 0.0):
    def sample(logits):
        if temp == 0:
            return mx.argmax(logits, axis=-1)
        else:
            return mx.random.categorical(logits * (1 / temp))

    logits, cache = model(prompt[None])
    y = sample(logits[:, -1, :])
    yield y

    while True:
        logits, cache = model(y[:, None], cache)
        y = sample(logits.squeeze(1))
        yield y


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Mistral inference script")
    parser.add_argument(
        "--model_path",
        type=str,
        default="mistral-7B-v0.1",
        help="The path to the model weights and tokenizer",
    )
    parser.add_argument(
        "--prompt",
        help="The message to be processed by the model",
        default="In the beginning the Universe was created.",
    )
    parser.add_argument(
        "--max_tokens",
        "-m",
        type=int,
        default=100,
        help="Maximum number of tokens to generate",
    )
    parser.add_argument(
        "--temp",
        help="The sampling temperature.",
        type=float,
        default=1.0,
    )
    parser.add_argument(
        "--tokens_per_eval",
        help="The batch size of tokens to generate.",
        type=int,
        default=10,
    )
    parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")

    args = parser.parse_args()

    mx.random.seed(args.seed)

    print("[INFO] Loading model from disk.")
    model, tokenizer = load_model(args.model_path)

    print("[INFO] Starting generation...")
    print(args.prompt, end="", flush=True)

    prompt = mx.array(tokenizer.encode(args.prompt))
    tokens = []
    for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
        tokens.append(token)

        if (len(tokens) % args.tokens_per_eval) == 0:
            mx.eval(tokens)
            s = tokenizer.decode([t.item() for t in tokens])
            print(s, end="", flush=True)
            tokens = []

    mx.eval(tokens)
    s = tokenizer.decode([t.item() for t in tokens])
    print(s, flush=True)
    print("------")

4
llms/mistral/requirements.txt Normal file

@@ -0,0 +1,4 @@
mlx
sentencepiece
torch
numpy

117
llms/mistral/test.py Normal file

@@ -0,0 +1,117 @@
# Copyright © 2023 Apple Inc.

import unittest

import mistral
import mlx.core as mx
from mlx.utils import tree_map


class TestMistral(unittest.TestCase):
    def test_model(self):
        vocab_size = 100
        L = 32
        args = mistral.ModelArgs(
            dim=128,
            n_layers=2,
            head_dim=32,
            hidden_dim=256,
            n_heads=4,
            n_kv_heads=4,
            norm_eps=1e-3,
            vocab_size=vocab_size,
        )

        model = mistral.Mistral(args)
        inputs = mx.random.randint(0, vocab_size, (L,))
        logits, cache = model(inputs[None])
        self.assertEqual(logits.shape, [1, L, vocab_size])
        self.assertEqual(logits.dtype, mx.float32)
        self.assertEqual(len(cache), args.n_layers)

        params = tree_map(lambda p: p.astype(mx.float16), model.parameters())
        model.update(params)
        logits, _ = model(inputs[None])
        self.assertEqual(logits.dtype, mx.float16)

    def test_generate(self):
        model, tokenizer = mistral.load_model("mistral-7B-v0.1")
        prompt = mx.array(tokenizer.encode("This is a test"))
        tokens = [t for t, _ in zip(mistral.generate(prompt, model), range(30))]
        mx.eval(tokens)
        tokens = [t.item() for t in tokens]
        expected = [
            302,
            272,
            11843,
            11837,
            1587,
            28723,
            851,
            349,
            865,
            264,
            1369,
            28723,
            13,
            13,
            3381,
            456,
            654,
            264,
            1353,
            11843,
            28725,
            368,
            682,
            347,
            2240,
            767,
            298,
            511,
            28723,
            13,
        ]
        self.assertEqual(tokens, expected)

    def benchmark(self):
        import time

        model, tokenizer = mistral.load_model("mistral-7B-v0.1")
        prompt = mx.random.randint(0, model.vocab_size, (128,))

        # warmup
        for _ in range(2):
            generator = mistral.generate(prompt, model)
            mx.eval(next(generator))

        tic = time.time()
        its = 5
        for _ in range(its):
            generator = mistral.generate(prompt, model)
            mx.eval(next(generator))
        toc = time.time()
        tps = its * prompt.size / (toc - tic)
        print(f"Prompt processing: {tps:.2f} tokens per second")

        # warmup
        for _ in range(2):
            tokens = [t for t, _ in zip(mistral.generate(prompt, model), range(101))]
            mx.eval(tokens)

        time_total = 0.0
        its = 2
        for _ in range(its):
            generator = mistral.generate(prompt, model)
            mx.eval(next(generator))
            tic = time.time()
            tokens = [t for t, _ in zip(generator, range(100))]
            mx.eval(tokens)
            time_total += time.time() - tic

        tps = len(tokens) * its / time_total
        print(f"Token generation: {tps:.3f} tokens per second")


if __name__ == "__main__":
    unittest.main()
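
Note that `unittest` discovery only runs methods whose names start with `test`, so `benchmark` above is skipped by `python test.py`. A hypothetical one-off driver like the following could invoke it directly (it assumes the full converted `mistral-7B-v0.1` checkpoint is on disk):

```
# Hypothetical driver, not part of the commit: run the benchmark method of
# TestMistral directly instead of relying on unittest discovery.
import unittest

from test import TestMistral  # the local test.py added in this commit

suite = unittest.TestSuite()
suite.addTest(TestMistral("benchmark"))
unittest.TextTestRunner(verbosity=2).run(suite)
```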