diff --git a/README.md b/README.md index 3b67ffc8..adcdc45e 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,17 @@ -# mlx-examples +# MLX Examples -Examples using the MLX framework. +This repo contains a variety of standalone examples using the [MLX +framework](https://github.com/ml-explore/mlx). + +The [MNIST](mnist) example is a good starting point to learn how to use MLX. + +Some more useful examples include: + +- [Transformer language model](transformer_lm) training. +- Large scale text generation with [LLaMA](llama) or [Mistral](mistral) +- Parameter efficient fine-tuning with [LoRA](lora). +- Generating images with [Stable Diffusion](stable_diffusion). +- Speech recognition with [OpenAI's Whisper](whisper). ## Contributing diff --git a/mistral/.gitignore b/mistral/.gitignore new file mode 100644 index 00000000..dc4e84a2 --- /dev/null +++ b/mistral/.gitignore @@ -0,0 +1 @@ +mistral-7B-v0.1/ diff --git a/mistral/README.md b/mistral/README.md new file mode 100644 index 00000000..1bbb385d --- /dev/null +++ b/mistral/README.md @@ -0,0 +1,39 @@ +# Mistral + +An example of generating text with Mistral using MLX. + +Mistral 7B is one of the top large language models in its size class. It is also fully open source with a permissive license[^1]. + +### Setup + +Install the dependencies: + +``` +pip install -r requirements.txt +``` + +Next, download the model and tokenizer: + +``` +curl -O https://files.mistral-7b-v0-1.mistral.ai/mistral-7B-v0.1.tar +tar -xf mistral-7B-v0.1.tar +``` + +Then, convert the weights with: + +``` +python convert.py +``` + +### Run + +Once you've converted the weights to MLX format, you can generate text with +the Mistral model: + +``` +python mistral.py --prompt "It is a truth universally acknowledged," --temp 0 +``` + +Run `python mistral.py --help` for more details. + +[^1]: Refer to the [blog post](https://mistral.ai/news/announcing-mistral-7b/) and [github repository](https://github.com/mistralai/mistral-src) for more details. diff --git a/mistral/convert.py b/mistral/convert.py new file mode 100644 index 00000000..e170aed7 --- /dev/null +++ b/mistral/convert.py @@ -0,0 +1,27 @@ +# Copyright © 2023 Apple Inc. + +import argparse +import numpy as np +import torch + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert Mistral weights to MLX.") + parser.add_argument( + "--torch_model", + type=str, + default="mistral-7B-v0.1/consolidated.00.pth", + help="The path to the torch model weights", + ) + parser.add_argument( + "--mlx_model", + type=str, + default="mistral-7B-v0.1/mlx_mistral_7b.npz", + help="The path to store the mlx model weights", + ) + args = parser.parse_args() + + state = torch.load(args.torch_model) + np.savez( + args.mlx_model, **{k: v.to(torch.float16).numpy() for k, v in state.items()} + ) diff --git a/mistral/mistral.py b/mistral/mistral.py new file mode 100644 index 00000000..6a6447bc --- /dev/null +++ b/mistral/mistral.py @@ -0,0 +1,275 @@ +# Copyright © 2023 Apple Inc. + +import argparse +from dataclasses import dataclass +import json +from pathlib import Path +from typing import Optional, Tuple, List +from sentencepiece import SentencePieceProcessor + +import mlx.core as mx +import mlx.nn as nn +from mlx.utils import tree_map, tree_unflatten + + +@dataclass +class ModelArgs: + dim: int + n_layers: int + head_dim: int + hidden_dim: int + n_heads: int + n_kv_heads: int + norm_eps: float + vocab_size: int + + +class RMSNorm(nn.Module): + def __init__(self, dims: int, eps: float = 1e-5): + super().__init__() + self.weight = mx.ones((dims,)) + self.eps = eps + + def _norm(self, x): + return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps) + + def __call__(self, x): + output = self._norm(x.astype(mx.float32)).astype(x.dtype) + return self.weight * output + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + + self.n_heads: int = args.n_heads + self.n_kv_heads: int = args.n_kv_heads + + self.repeats = self.n_heads // self.n_kv_heads + + self.scale = self.args.head_dim**-0.5 + + self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False) + self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False) + self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False) + self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False) + self.rope = nn.RoPE(args.head_dim, traditional=True) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> mx.array: + B, L, D = x.shape + + queries, keys, values = self.wq(x), self.wk(x), self.wv(x) + + # Prepare the queries, keys and values for the attention computation + queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + def repeat(a): + a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2) + return a.reshape([B, self.n_heads, L, -1]) + + keys, values = map(repeat, (keys, values)) + + if cache is not None: + key_cache, value_cache = cache + queries = self.rope(queries, offset=key_cache.shape[2]) + keys = self.rope(keys, offset=key_cache.shape[2]) + keys = mx.concatenate([key_cache, keys], axis=2) + values = mx.concatenate([value_cache, values], axis=2) + else: + queries = self.rope(queries) + keys = self.rope(keys) + + scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2) + if mask is not None: + scores += mask + scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) + output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.wo(output), (keys, values) + + +class FeedForward(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + + self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False) + self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False) + self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False) + + def __call__(self, x) -> mx.array: + return self.w2(nn.silu(self.w1(x)) * self.w3(x)) + + +class TransformerBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.n_heads = args.n_heads + self.dim = args.dim + self.attention = Attention(args) + self.feed_forward = FeedForward(args=args) + self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.args = args + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> mx.array: + r, cache = self.attention(self.attention_norm(x), mask, cache) + h = x + r + r = self.feed_forward(self.ffn_norm(h)) + out = h + r + return out, cache + + +class Mistral(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + self.n_layers = args.n_layers + assert self.vocab_size > 0 + self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim) + self.layers = [TransformerBlock(args=args) for _ in range(args.n_layers)] + self.norm = RMSNorm(args.dim, eps=args.norm_eps) + self.output = nn.Linear(args.dim, args.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + h = self.tok_embeddings(inputs) + + mask = None + if h.shape[1] > 1: + mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) + mask = mask.astype(h.dtype) + + if cache is None: + cache = [None] * len(self.layers) + + for e, layer in enumerate(self.layers): + h, cache[e] = layer(h, mask, cache[e]) + + return self.output(self.norm(h)), cache + + +class Tokenizer: + def __init__(self, model_path: str): + assert Path(model_path).exists(), model_path + self._model = SentencePieceProcessor(model_file=model_path) + self._sep = "▁" + assert self._model.vocab_size() == self._model.get_piece_size() + + @property + def eos_id(self) -> int: + return self._model.eos_id() + + @property + def pad_id(self) -> int: + return self._model.pad_id() + + def encode(self, s: str) -> List[int]: + return [self._model.bos_id(), *self._model.encode(s)] + + def decode(self, t: List[int]) -> str: + out = self._model.decode(t) + if t and self._model.id_to_piece(t[0])[0] == self._sep: + return " " + out + return out + + +def load_model(folder: str, dtype=mx.float16): + model_path = Path(folder) + tokenizer = Tokenizer(str(model_path / "tokenizer.model")) + with open(model_path / "params.json", "r") as f: + config = json.loads(f.read()) + config.pop("sliding_window") + model_args = ModelArgs(**config) + weights = mx.load(str(model_path / "mlx_mistral_7b.npz")) + weights = tree_unflatten(list(weights.items())) + weights = tree_map(lambda p: p.astype(dtype), weights) + model = Mistral(model_args) + model.update(weights) + return model, tokenizer + + +def generate(prompt: mx.array, model: Mistral, temp: Optional[float] = 0.0): + def sample(logits): + if temp == 0: + return mx.argmax(logits, axis=-1) + else: + return mx.random.categorical(logits * (1 / temp)) + + logits, cache = model(prompt[None]) + y = sample(logits[:, -1, :]) + yield y + + while True: + logits, cache = model(y[:, None], cache) + y = sample(logits.squeeze(1)) + yield y + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Mistral inference script") + parser.add_argument( + "--model_path", + type=str, + default="mistral-7B-v0.1", + help="The path to the model weights and tokenizer", + ) + parser.add_argument( + "--prompt", + help="The message to be processed by the model", + default="In the beginning the Universe was created.", + ) + parser.add_argument( + "--max_tokens", + "-m", + type=int, + default=100, + help="Maximum number of tokens to generate", + ) + parser.add_argument( + "--temp", + help="The sampling temperature.", + type=float, + default=1.0, + ) + parser.add_argument("--seed", type=int, default=0, help="The PRNG seed") + + args = parser.parse_args() + + mx.random.seed(args.seed) + print("[INFO] Loading model from disk.") + model, tokenizer = load_model(args.model_path) + + print("[INFO] Starting generation...") + + print(args.prompt, end="", flush=True) + prompt = mx.array(tokenizer.encode(args.prompt)) + tokens = [] + for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)): + tokens.append(token) + + if (len(tokens) % 10) == 0: + mx.eval(tokens) + s = tokenizer.decode([t.item() for t in tokens]) + print(s, end="", flush=True) + tokens = [] + + mx.eval(tokens) + s = tokenizer.decode([t.item() for t in tokens]) + print(s, flush=True) + print("------") diff --git a/mistral/requirements.txt b/mistral/requirements.txt new file mode 100644 index 00000000..7111f1d4 --- /dev/null +++ b/mistral/requirements.txt @@ -0,0 +1,3 @@ +mlx +sentencepiece +torch diff --git a/mistral/test.py b/mistral/test.py new file mode 100644 index 00000000..bc52d49f --- /dev/null +++ b/mistral/test.py @@ -0,0 +1,118 @@ +# Copyright © 2023 Apple Inc. + +import unittest + +import mlx.core as mx +from mlx.utils import tree_map + +import mistral + + +class TestMistral(unittest.TestCase): + def test_model(self): + vocab_size = 100 + L = 32 + args = mistral.ModelArgs( + dim=128, + n_layers=2, + head_dim=32, + hidden_dim=256, + n_heads=4, + n_kv_heads=4, + norm_eps=1e-3, + vocab_size=vocab_size, + ) + + model = mistral.Mistral(args) + inputs = mx.random.randint(0, vocab_size, (L,)) + logits, cache = model(inputs[None]) + self.assertEqual(logits.shape, [1, L, vocab_size]) + self.assertEqual(logits.dtype, mx.float32) + self.assertEqual(len(cache), args.n_layers) + + params = tree_map(lambda p: p.astype(mx.float16), model.parameters()) + model.update(params) + logits, _ = model(inputs[None]) + self.assertEqual(logits.dtype, mx.float16) + + def test_generate(self): + model, tokenizer = mistral.load_model("mistral-7B-v0.1") + prompt = mx.array(tokenizer.encode("This is a test")) + tokens = [t for t, _ in zip(mistral.generate(prompt, model), range(30))] + mx.eval(tokens) + tokens = [t.item() for t in tokens] + expected = [ + 302, + 272, + 11843, + 11837, + 1587, + 28723, + 851, + 349, + 865, + 264, + 1369, + 28723, + 13, + 13, + 3381, + 456, + 654, + 264, + 1353, + 11843, + 28725, + 368, + 682, + 347, + 2240, + 767, + 298, + 511, + 28723, + 13, + ] + self.assertEqual(tokens, expected) + + def benchmark(self): + import time + + model, tokenizer = mistral.load_model("mistral-7B-v0.1") + prompt = mx.random.randint(0, model.vocab_size, (128,)) + + # warmup + for _ in range(2): + generator = mistral.generate(prompt, model) + mx.eval(next(generator)) + + tic = time.time() + its = 5 + for _ in range(its): + generator = mistral.generate(prompt, model) + mx.eval(next(generator)) + toc = time.time() + tps = its * prompt.size / (toc - tic) + print(f"Prompt processing: {tps:.2f} tokens per second") + + # warmup + for _ in range(2): + tokens = [t for t, _ in zip(mistral.generate(prompt, model), range(101))] + mx.eval(tokens) + + time_total = 0.0 + its = 2 + for _ in range(its): + generator = mistral.generate(prompt, model) + mx.eval(next(generator)) + tic = time.time() + tokens = [t for t, _ in zip(generator, range(100))] + mx.eval(tokens) + time_total += time.time() - tic + + tps = len(tokens) * its / time_total + print(f"Token generation: {tps:.3f} tokens per second") + + +if __name__ == "__main__": + unittest.main()