Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-10-31 02:48:07 +08:00

Commit: mistral

mistral/.gitignore (new file, vendored, 1 line)

@@ -0,0 +1 @@
mistral-7B-v0.1/

mistral/README.md (new file, 39 lines)

@@ -0,0 +1,39 @@
# Mistral

An example of generating text with Mistral using MLX.

Mistral 7B is one of the top large language models in its size class. It is also fully open source with a permissive license[^1].

### Setup

Install the dependencies:

```
pip install -r requirements.txt
```

Next, download the model and tokenizer:

```
curl -O https://files.mistral-7b-v0-1.mistral.ai/mistral-7B-v0.1.tar
tar -xf mistral-7B-v0.1.tar
```

Then, convert the weights to MLX format (by default this reads `mistral-7B-v0.1/consolidated.00.pth` and writes `mistral-7B-v0.1/mlx_mistral_7b.npz`):

```
python convert.py
```

### Run

Once you've converted the weights, you can interact with the Mistral model:

```
python mistral.py --prompt "hello"
```

Run `python mistral.py --help` for more details.

[^1]: Refer to the [blog post](https://mistral.ai/news/announcing-mistral-7b/) and [GitHub repository](https://github.com/mistralai/mistral-src) for more details.

mistral/convert.py (new file, 27 lines)

@@ -0,0 +1,27 @@
# Copyright © 2023 Apple Inc.

import argparse
import numpy as np
import torch


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert Mistral weights to MLX.")
    parser.add_argument(
        "--torch_model",
        type=str,
        default="mistral-7B-v0.1/consolidated.00.pth",
        help="The path to the torch model weights",
    )
    parser.add_argument(
        "--mlx_model",
        type=str,
        default="mistral-7B-v0.1/mlx_mistral_7b.npz",
        help="The path to store the mlx model weights",
    )
    args = parser.parse_args()

    state = torch.load(args.torch_model)
    np.savez(
        args.mlx_model, **{k: v.to(torch.float16).numpy() for k, v in state.items()}
    )
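
The converted file is a plain NumPy `.npz` archive, so it can be sanity-checked before loading the full model. A minimal sketch, assuming the default output path used by `convert.py`:

```
# Sketch: inspect the converted weights; assumes convert.py's default output path.
import numpy as np

weights = np.load("mistral-7B-v0.1/mlx_mistral_7b.npz")
for name in list(weights.files)[:5]:
    # Each entry should be a float16 array named after the original torch key.
    print(name, weights[name].shape, weights[name].dtype)
```
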
							
								
								
									
mistral/mistral.py (new file, 281 lines)

@@ -0,0 +1,281 @@
# Copyright © 2023 Apple Inc.

import argparse
from dataclasses import dataclass
import json
from pathlib import Path
from typing import Optional, Tuple, List
from sentencepiece import SentencePieceProcessor

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_map, tree_unflatten


@dataclass
class ModelArgs:
    dim: int
    n_layers: int
    head_dim: int
    hidden_dim: int
    n_heads: int
    n_kv_heads: int
    norm_eps: float
    vocab_size: int


class RMSNorm(nn.Module):
    def __init__(self, dims: int, eps: float = 1e-5):
        super().__init__()
        self.weight = mx.ones((dims,))
        self.eps = eps

    def _norm(self, x):
        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)

    def __call__(self, x):
        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
        return self.weight * output


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args

        self.n_heads: int = args.n_heads
        self.n_kv_heads: int = args.n_kv_heads

        self.repeats = self.n_heads // self.n_kv_heads

        self.scale = self.args.head_dim**-0.5

        self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
        self.rope = nn.RoPE(args.head_dim, traditional=True)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
        B, L, D = x.shape

        queries, keys, values = self.wq(x), self.wk(x), self.wv(x)

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        # Grouped-query attention: tile the shared key/value heads so that
        # every query head has a matching key/value head.
        def repeat(a):
            a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
            return a.reshape([B, self.n_heads, L, -1])

        keys, values = map(repeat, (keys, values))

        if cache is not None:
            key_cache, value_cache = cache
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = mx.concatenate([key_cache, keys], axis=2)
            values = mx.concatenate([value_cache, values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
        if mask is not None:
            scores += mask
        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
        output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)

        #        queries = queries.reshape(B, self.n_kv_heads, self.repeats, L, -1)
        #        scores = (queries * self.scale) @ mx.expand_dims(keys.transpose(0, 1, 3, 2), 2)
        #        if mask is not None:
        #            scores += mask
        #        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
        #        output = (scores @ mx.expand_dims(values, 2)).reshape(B, self.n_heads, L, -1)
        #        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.wo(output), (keys, values)


class FeedForward(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False)
        self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False)
        self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False)

    def __call__(self, x) -> mx.array:
        return self.w2(nn.silu(self.w1(x)) * self.w3(x))


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.attention = Attention(args)
        self.feed_forward = FeedForward(args=args)
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.args = args

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
        r, cache = self.attention(self.attention_norm(x), mask, cache)
        h = x + r
        r = self.feed_forward(self.ffn_norm(h))
        out = h + r
        return out, cache


class Mistral(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.n_layers = args.n_layers
        assert self.vocab_size > 0

        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)

        self.layers = [TransformerBlock(args=args) for _ in range(args.n_layers)]

        self.norm = RMSNorm(args.dim, eps=args.norm_eps)

        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        h = self.tok_embeddings(inputs)

        mask = None
        if h.shape[1] > 1:
            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
            mask = mask.astype(h.dtype)

        if cache is None:
            cache = [None] * len(self.layers)

        for e, layer in enumerate(self.layers):
            h, cache[e] = layer(h, mask, cache[e])

        return self.output(self.norm(h)), cache


class Tokenizer:
    def __init__(self, model_path: str):
        assert Path(model_path).exists(), model_path
        self._model = SentencePieceProcessor(model_file=model_path)
        self._sep = "▁"
        assert self._model.vocab_size() == self._model.get_piece_size()

    @property
    def eos_id(self) -> int:
        return self._model.eos_id()

    @property
    def pad_id(self) -> int:
        return self._model.pad_id()

    def encode(self, s: str) -> List[int]:
        return [self._model.bos_id(), *self._model.encode(s)]

    def decode(self, t: List[int]) -> str:
        out = self._model.decode(t)
        if t and self._model.id_to_piece(t[0])[0] == self._sep:
            return " " + out
        return out


def load_model(folder: str, dtype=mx.float16):
    model_path = Path(folder)
    tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
    with open(model_path / "params.json", "r") as f:
        config = json.loads(f.read())
        config.pop("sliding_window")
        model_args = ModelArgs(**config)
    weights = mx.load(str(model_path / "mlx_mistral_7b.npz"))
    weights = tree_unflatten(list(weights.items()))
    weights = tree_map(lambda p: p.astype(dtype), weights)
    model = Mistral(model_args)
    model.update(weights)
    return model, tokenizer


def generate(prompt: mx.array, model: Mistral, temp: Optional[float] = 0.0):
    def sample(logits):
        if temp == 0:
            return mx.argmax(logits, axis=-1)
        else:
            return mx.random.categorical(logits * (1 / temp))

    logits, cache = model(prompt[None])
    y = sample(logits[:, -1, :])
    yield y

    while True:
        logits, cache = model(y[:, None], cache)
        y = sample(logits.squeeze(1))
        yield y


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Mistral inference script")
    parser.add_argument(
        "--model_path",
        type=str,
        default="mistral-7B-v0.1",
        help="The path to the model weights and tokenizer",
    )
    parser.add_argument(
        "--prompt",
        help="The message to be processed by the model",
        default="In the beginning the Universe was created.",
    )
    parser.add_argument(
        "--max_tokens",
        "-m",
        type=int,
        default=100,
        help="Maximum number of tokens to generate",
    )
    parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")

    args = parser.parse_args()

    mx.random.seed(args.seed)
    print("[INFO] Loading model from disk.")
    model, tokenizer = load_model(args.model_path)

    print("[INFO] Starting generation...")

    print(args.prompt, end="", flush=True)
    prompt = mx.array(tokenizer.encode(args.prompt))
    tokens = []
    for token, _ in zip(generate(prompt, model), range(args.max_tokens)):
        tokens.append(token)

        if (len(tokens) % 10) == 0:
            mx.eval(tokens)
            s = tokenizer.decode([t.item() for t in tokens])
            print(s, end="", flush=True)
            tokens = []

    mx.eval(tokens)
    s = tokenizer.decode([t.item() for t in tokens])
    print(s, flush=True)
    print("------")
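
For reference, the model can also be driven from Python rather than through the CLI. A minimal sketch, mirroring what `test.py` does and assuming the converted weights live in `mistral-7B-v0.1/`:

```
# Sketch: programmatic generation with the mistral module; assumes the
# converted weights and tokenizer are in mistral-7B-v0.1/.
import mlx.core as mx
import mistral

model, tokenizer = mistral.load_model("mistral-7B-v0.1")
prompt = mx.array(tokenizer.encode("This is a test"))

# generate() yields one token at a time; zip with a range to cap the length.
tokens = [t for t, _ in zip(mistral.generate(prompt, model), range(30))]
mx.eval(tokens)
print(tokenizer.decode([t.item() for t in tokens]))
```
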
							
								
								
									
mistral/requirements.txt (new file, 3 lines)

@@ -0,0 +1,3 @@
mlx
sentencepiece
torch

mistral/test.py (new file, 118 lines)

@@ -0,0 +1,118 @@
# Copyright © 2023 Apple Inc.

import unittest

import mlx.core as mx
from mlx.utils import tree_map

import mistral


class TestMistral(unittest.TestCase):
    def test_model(self):
        vocab_size = 100
        L = 32
        args = mistral.ModelArgs(
            dim=128,
            n_layers=2,
            head_dim=32,
            hidden_dim=256,
            n_heads=4,
            n_kv_heads=4,
            norm_eps=1e-3,
            vocab_size=vocab_size,
        )

        model = mistral.Mistral(args)
        inputs = mx.random.randint(0, vocab_size, (L,))
        logits, cache = model(inputs[None])
        self.assertEqual(logits.shape, [1, L, vocab_size])
        self.assertEqual(logits.dtype, mx.float32)
        self.assertEqual(len(cache), args.n_layers)

        params = tree_map(lambda p: p.astype(mx.float16), model.parameters())
        model.update(params)
        logits, _ = model(inputs[None])
        self.assertEqual(logits.dtype, mx.float16)

    def test_generate(self):
        model, tokenizer = mistral.load_model("mistral-7B-v0.1")
        prompt = mx.array(tokenizer.encode("This is a test"))
        tokens = [t for t, _ in zip(mistral.generate(prompt, model), range(30))]
        mx.eval(tokens)
        tokens = [t.item() for t in tokens]
        expected = [
            302,
            272,
            11843,
            11837,
            1587,
            28723,
            851,
            349,
            865,
            264,
            1369,
            28723,
            13,
            13,
            3381,
            456,
            654,
            264,
            1353,
            11843,
            28725,
            368,
            682,
            347,
            2240,
            767,
            298,
            511,
            28723,
            13,
        ]
        self.assertEqual(tokens, expected)

    def benchmark(self):
        import time

        model, tokenizer = mistral.load_model("mistral-7B-v0.1")
        prompt = mx.random.randint(0, model.vocab_size, (128,))

        # warmup
        for _ in range(2):
            generator = mistral.generate(prompt, model)
            mx.eval(next(generator))

        tic = time.time()
        its = 5
        for _ in range(its):
            generator = mistral.generate(prompt, model)
            mx.eval(next(generator))
        toc = time.time()
        tps = its * prompt.size / (toc - tic)
        print(f"Prompt processing: {tps:.2f} tokens per second")

        # warmup
        for _ in range(2):
            tokens = [t for t, _ in zip(mistral.generate(prompt, model), range(101))]
            mx.eval(tokens)

        time_total = 0.0
        its = 2
        for _ in range(its):
            generator = mistral.generate(prompt, model)
            mx.eval(next(generator))
            tic = time.time()
            tokens = [t for t, _ in zip(generator, range(100))]
            mx.eval(tokens)
            time_total += time.time() - tic

        tps = len(tokens) * its / time_total
        print(f"Token generation: {tps:.3f} tokens per second")


if __name__ == "__main__":
    unittest.main()
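
Since the file ends with `unittest.main()`, running `python test.py` executes the tests directly; note that `test_generate` assumes the converted weights are present in `mistral-7B-v0.1/`, and `benchmark` is not collected automatically because its name does not start with `test_`.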
Author: Awni Hannun