mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
commit
0bf5d0e3bc
15
README.md
15
README.md
@ -1,6 +1,17 @@
|
||||
# mlx-examples
|
||||
# MLX Examples
|
||||
|
||||
Examples using the MLX framework.
|
||||
This repo contains a variety of standalone examples using the [MLX
|
||||
framework](https://github.com/ml-explore/mlx).
|
||||
|
||||
The [MNIST](mnist) example is a good starting point to learn how to use MLX.
|
||||
|
||||
Some more useful examples include:
|
||||
|
||||
- [Transformer language model](transformer_lm) training.
|
||||
- Large scale text generation with [LLaMA](llama) or [Mistral](mistral)
|
||||
- Parameter efficient fine-tuning with [LoRA](lora).
|
||||
- Generating images with [Stable Diffusion](stable_diffusion).
|
||||
- Speech recognition with [OpenAI's Whisper](whisper).
|
||||
|
||||
## Contributing
|
||||
|
||||
|
1
mistral/.gitignore
vendored
Normal file
1
mistral/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
mistral-7B-v0.1/
|
39
mistral/README.md
Normal file
39
mistral/README.md
Normal file
@ -0,0 +1,39 @@
|
||||
# Mistral
|
||||
|
||||
An example of generating text with Mistral using MLX.
|
||||
|
||||
Mistral 7B is one of the top large language models in its size class. It is also fully open source with a permissive license[^1].
|
||||
|
||||
### Setup
|
||||
|
||||
Install the dependencies:
|
||||
|
||||
```
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
Next, download the model and tokenizer:
|
||||
|
||||
```
|
||||
curl -O https://files.mistral-7b-v0-1.mistral.ai/mistral-7B-v0.1.tar
|
||||
tar -xf mistral-7B-v0.1.tar
|
||||
```
|
||||
|
||||
Then, convert the weights with:
|
||||
|
||||
```
|
||||
python convert.py
|
||||
```
|
||||
|
||||
### Run
|
||||
|
||||
Once you've converted the weights to MLX format, you can generate text with
|
||||
the Mistral model:
|
||||
|
||||
```
|
||||
python mistral.py --prompt "It is a truth universally acknowledged," --temp 0
|
||||
```
|
||||
|
||||
Run `python mistral.py --help` for more details.
|
||||
|
||||
[^1]: Refer to the [blog post](https://mistral.ai/news/announcing-mistral-7b/) and [github repository](https://github.com/mistralai/mistral-src) for more details.
|
27
mistral/convert.py
Normal file
27
mistral/convert.py
Normal file
@ -0,0 +1,27 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Convert Mistral weights to MLX.")
|
||||
parser.add_argument(
|
||||
"--torch_model",
|
||||
type=str,
|
||||
default="mistral-7B-v0.1/consolidated.00.pth",
|
||||
help="The path to the torch model weights",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mlx_model",
|
||||
type=str,
|
||||
default="mistral-7B-v0.1/mlx_mistral_7b.npz",
|
||||
help="The path to store the mlx model weights",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
state = torch.load(args.torch_model)
|
||||
np.savez(
|
||||
args.mlx_model, **{k: v.to(torch.float16).numpy() for k, v in state.items()}
|
||||
)
|
275
mistral/mistral.py
Normal file
275
mistral/mistral.py
Normal file
@ -0,0 +1,275 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import argparse
|
||||
from dataclasses import dataclass
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, List
|
||||
from sentencepiece import SentencePieceProcessor
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx.utils import tree_map, tree_unflatten
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs:
|
||||
dim: int
|
||||
n_layers: int
|
||||
head_dim: int
|
||||
hidden_dim: int
|
||||
n_heads: int
|
||||
n_kv_heads: int
|
||||
norm_eps: float
|
||||
vocab_size: int
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dims: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
self.weight = mx.ones((dims,))
|
||||
self.eps = eps
|
||||
|
||||
def _norm(self, x):
|
||||
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
|
||||
|
||||
def __call__(self, x):
|
||||
output = self._norm(x.astype(mx.float32)).astype(x.dtype)
|
||||
return self.weight * output
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
|
||||
self.n_heads: int = args.n_heads
|
||||
self.n_kv_heads: int = args.n_kv_heads
|
||||
|
||||
self.repeats = self.n_heads // self.n_kv_heads
|
||||
|
||||
self.scale = self.args.head_dim**-0.5
|
||||
|
||||
self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False)
|
||||
self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
|
||||
self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
|
||||
self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
|
||||
self.rope = nn.RoPE(args.head_dim, traditional=True)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
def repeat(a):
|
||||
a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
|
||||
return a.reshape([B, self.n_heads, L, -1])
|
||||
|
||||
keys, values = map(repeat, (keys, values))
|
||||
|
||||
if cache is not None:
|
||||
key_cache, value_cache = cache
|
||||
queries = self.rope(queries, offset=key_cache.shape[2])
|
||||
keys = self.rope(keys, offset=key_cache.shape[2])
|
||||
keys = mx.concatenate([key_cache, keys], axis=2)
|
||||
values = mx.concatenate([value_cache, values], axis=2)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
|
||||
if mask is not None:
|
||||
scores += mask
|
||||
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
|
||||
output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.wo(output), (keys, values)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False)
|
||||
self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False)
|
||||
self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.w2(nn.silu(self.w1(x)) * self.w3(x))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.n_heads = args.n_heads
|
||||
self.dim = args.dim
|
||||
self.attention = Attention(args)
|
||||
self.feed_forward = FeedForward(args=args)
|
||||
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
r, cache = self.attention(self.attention_norm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.feed_forward(self.ffn_norm(h))
|
||||
out = h + r
|
||||
return out, cache
|
||||
|
||||
|
||||
class Mistral(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.n_layers = args.n_layers
|
||||
assert self.vocab_size > 0
|
||||
self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
|
||||
self.layers = [TransformerBlock(args=args) for _ in range(args.n_layers)]
|
||||
self.norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
h = self.tok_embeddings(inputs)
|
||||
|
||||
mask = None
|
||||
if h.shape[1] > 1:
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
|
||||
mask = mask.astype(h.dtype)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for e, layer in enumerate(self.layers):
|
||||
h, cache[e] = layer(h, mask, cache[e])
|
||||
|
||||
return self.output(self.norm(h)), cache
|
||||
|
||||
|
||||
class Tokenizer:
|
||||
def __init__(self, model_path: str):
|
||||
assert Path(model_path).exists(), model_path
|
||||
self._model = SentencePieceProcessor(model_file=model_path)
|
||||
self._sep = "▁"
|
||||
assert self._model.vocab_size() == self._model.get_piece_size()
|
||||
|
||||
@property
|
||||
def eos_id(self) -> int:
|
||||
return self._model.eos_id()
|
||||
|
||||
@property
|
||||
def pad_id(self) -> int:
|
||||
return self._model.pad_id()
|
||||
|
||||
def encode(self, s: str) -> List[int]:
|
||||
return [self._model.bos_id(), *self._model.encode(s)]
|
||||
|
||||
def decode(self, t: List[int]) -> str:
|
||||
out = self._model.decode(t)
|
||||
if t and self._model.id_to_piece(t[0])[0] == self._sep:
|
||||
return " " + out
|
||||
return out
|
||||
|
||||
|
||||
def load_model(folder: str, dtype=mx.float16):
|
||||
model_path = Path(folder)
|
||||
tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
|
||||
with open(model_path / "params.json", "r") as f:
|
||||
config = json.loads(f.read())
|
||||
config.pop("sliding_window")
|
||||
model_args = ModelArgs(**config)
|
||||
weights = mx.load(str(model_path / "mlx_mistral_7b.npz"))
|
||||
weights = tree_unflatten(list(weights.items()))
|
||||
weights = tree_map(lambda p: p.astype(dtype), weights)
|
||||
model = Mistral(model_args)
|
||||
model.update(weights)
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def generate(prompt: mx.array, model: Mistral, temp: Optional[float] = 0.0):
|
||||
def sample(logits):
|
||||
if temp == 0:
|
||||
return mx.argmax(logits, axis=-1)
|
||||
else:
|
||||
return mx.random.categorical(logits * (1 / temp))
|
||||
|
||||
logits, cache = model(prompt[None])
|
||||
y = sample(logits[:, -1, :])
|
||||
yield y
|
||||
|
||||
while True:
|
||||
logits, cache = model(y[:, None], cache)
|
||||
y = sample(logits.squeeze(1))
|
||||
yield y
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Mistral inference script")
|
||||
parser.add_argument(
|
||||
"--model_path",
|
||||
type=str,
|
||||
default="mistral-7B-v0.1",
|
||||
help="The path to the model weights and tokenizer",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
help="The message to be processed by the model",
|
||||
default="In the beginning the Universe was created.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_tokens",
|
||||
"-m",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Maximum number of tokens to generate",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temp",
|
||||
help="The sampling temperature.",
|
||||
type=float,
|
||||
default=1.0,
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
mx.random.seed(args.seed)
|
||||
print("[INFO] Loading model from disk.")
|
||||
model, tokenizer = load_model(args.model_path)
|
||||
|
||||
print("[INFO] Starting generation...")
|
||||
|
||||
print(args.prompt, end="", flush=True)
|
||||
prompt = mx.array(tokenizer.encode(args.prompt))
|
||||
tokens = []
|
||||
for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
|
||||
tokens.append(token)
|
||||
|
||||
if (len(tokens) % 10) == 0:
|
||||
mx.eval(tokens)
|
||||
s = tokenizer.decode([t.item() for t in tokens])
|
||||
print(s, end="", flush=True)
|
||||
tokens = []
|
||||
|
||||
mx.eval(tokens)
|
||||
s = tokenizer.decode([t.item() for t in tokens])
|
||||
print(s, flush=True)
|
||||
print("------")
|
3
mistral/requirements.txt
Normal file
3
mistral/requirements.txt
Normal file
@ -0,0 +1,3 @@
|
||||
mlx
|
||||
sentencepiece
|
||||
torch
|
118
mistral/test.py
Normal file
118
mistral/test.py
Normal file
@ -0,0 +1,118 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx.utils import tree_map
|
||||
|
||||
import mistral
|
||||
|
||||
|
||||
class TestMistral(unittest.TestCase):
|
||||
def test_model(self):
|
||||
vocab_size = 100
|
||||
L = 32
|
||||
args = mistral.ModelArgs(
|
||||
dim=128,
|
||||
n_layers=2,
|
||||
head_dim=32,
|
||||
hidden_dim=256,
|
||||
n_heads=4,
|
||||
n_kv_heads=4,
|
||||
norm_eps=1e-3,
|
||||
vocab_size=vocab_size,
|
||||
)
|
||||
|
||||
model = mistral.Mistral(args)
|
||||
inputs = mx.random.randint(0, vocab_size, (L,))
|
||||
logits, cache = model(inputs[None])
|
||||
self.assertEqual(logits.shape, [1, L, vocab_size])
|
||||
self.assertEqual(logits.dtype, mx.float32)
|
||||
self.assertEqual(len(cache), args.n_layers)
|
||||
|
||||
params = tree_map(lambda p: p.astype(mx.float16), model.parameters())
|
||||
model.update(params)
|
||||
logits, _ = model(inputs[None])
|
||||
self.assertEqual(logits.dtype, mx.float16)
|
||||
|
||||
def test_generate(self):
|
||||
model, tokenizer = mistral.load_model("mistral-7B-v0.1")
|
||||
prompt = mx.array(tokenizer.encode("This is a test"))
|
||||
tokens = [t for t, _ in zip(mistral.generate(prompt, model), range(30))]
|
||||
mx.eval(tokens)
|
||||
tokens = [t.item() for t in tokens]
|
||||
expected = [
|
||||
302,
|
||||
272,
|
||||
11843,
|
||||
11837,
|
||||
1587,
|
||||
28723,
|
||||
851,
|
||||
349,
|
||||
865,
|
||||
264,
|
||||
1369,
|
||||
28723,
|
||||
13,
|
||||
13,
|
||||
3381,
|
||||
456,
|
||||
654,
|
||||
264,
|
||||
1353,
|
||||
11843,
|
||||
28725,
|
||||
368,
|
||||
682,
|
||||
347,
|
||||
2240,
|
||||
767,
|
||||
298,
|
||||
511,
|
||||
28723,
|
||||
13,
|
||||
]
|
||||
self.assertEqual(tokens, expected)
|
||||
|
||||
def benchmark(self):
|
||||
import time
|
||||
|
||||
model, tokenizer = mistral.load_model("mistral-7B-v0.1")
|
||||
prompt = mx.random.randint(0, model.vocab_size, (128,))
|
||||
|
||||
# warmup
|
||||
for _ in range(2):
|
||||
generator = mistral.generate(prompt, model)
|
||||
mx.eval(next(generator))
|
||||
|
||||
tic = time.time()
|
||||
its = 5
|
||||
for _ in range(its):
|
||||
generator = mistral.generate(prompt, model)
|
||||
mx.eval(next(generator))
|
||||
toc = time.time()
|
||||
tps = its * prompt.size / (toc - tic)
|
||||
print(f"Prompt processing: {tps:.2f} tokens per second")
|
||||
|
||||
# warmup
|
||||
for _ in range(2):
|
||||
tokens = [t for t, _ in zip(mistral.generate(prompt, model), range(101))]
|
||||
mx.eval(tokens)
|
||||
|
||||
time_total = 0.0
|
||||
its = 2
|
||||
for _ in range(its):
|
||||
generator = mistral.generate(prompt, model)
|
||||
mx.eval(next(generator))
|
||||
tic = time.time()
|
||||
tokens = [t for t, _ in zip(generator, range(100))]
|
||||
mx.eval(tokens)
|
||||
time_total += time.time() - tic
|
||||
|
||||
tps = len(tokens) * its / time_total
|
||||
print(f"Token generation: {tps:.3f} tokens per second")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue
Block a user