diff --git a/mistral/README.md b/mistral/README.md index 1bbb385d..c2406b6d 100644 --- a/mistral/README.md +++ b/mistral/README.md @@ -2,7 +2,8 @@ An example of generating text with Mistral using MLX. -Mistral 7B is one of the top large language models in its size class. It is also fully open source with a permissive license[^1]. +Mistral 7B is one of the top large language models in its size class. It is +also fully open source with a permissive license[^1]. ### Setup @@ -25,6 +26,8 @@ Then, convert the weights with: python convert.py ``` +The conversion script will save the converted weights in the same location. + ### Run Once you've converted the weights to MLX format, you can generate text with @@ -36,4 +39,6 @@ python mistral.py --prompt "It is a truth universally acknowledged," --temp 0 Run `python mistral.py --help` for more details. -[^1]: Refer to the [blog post](https://mistral.ai/news/announcing-mistral-7b/) and [github repository](https://github.com/mistralai/mistral-src) for more details. +[^1]: Refer to the [blog post](https://mistral.ai/news/announcing-mistral-7b/) +and [github repository](https://github.com/mistralai/mistral-src) for more +details. diff --git a/mistral/convert.py b/mistral/convert.py index e170aed7..0efaf489 100644 --- a/mistral/convert.py +++ b/mistral/convert.py @@ -2,26 +2,23 @@ import argparse import numpy as np +from pathlib import Path import torch if __name__ == "__main__": parser = argparse.ArgumentParser(description="Convert Mistral weights to MLX.") parser.add_argument( - "--torch_model", + "--model_path", type=str, - default="mistral-7B-v0.1/consolidated.00.pth", - help="The path to the torch model weights", - ) - parser.add_argument( - "--mlx_model", - type=str, - default="mistral-7B-v0.1/mlx_mistral_7b.npz", - help="The path to store the mlx model weights", + default="mistral-7B-v0.1/", + help="The path to the Mistral model. The MLX weights will also be saved there.", ) args = parser.parse_args() - state = torch.load(args.torch_model) + model_path = Path(args.model_path) + state = torch.load(str(model_path / "consolidated.00.pth")) np.savez( - args.mlx_model, **{k: v.to(torch.float16).numpy() for k, v in state.items()} + str(model_path / "weights.npz"), + **{k: v.to(torch.float16).numpy() for k, v in state.items()} ) diff --git a/mistral/mistral.py b/mistral/mistral.py index 767b5936..0c3976c1 100644 --- a/mistral/mistral.py +++ b/mistral/mistral.py @@ -196,7 +196,7 @@ def load_model(folder: str, dtype=mx.float16): config = json.loads(f.read()) config.pop("sliding_window") model_args = ModelArgs(**config) - weights = mx.load(str(model_path / "mlx_mistral_7b.npz")) + weights = mx.load(str(model_path / "weights.npz")) weights = tree_unflatten(list(weights.items())) weights = tree_map(lambda p: p.astype(dtype), weights) model = Mistral(model_args) diff --git a/mixtral/README.md b/mixtral/README.md new file mode 100644 index 00000000..811fcd06 --- /dev/null +++ b/mixtral/README.md @@ -0,0 +1,52 @@ +## Mixtral 8x7B + +Run the Mixtral[^mixtral] 8x7B mixture-of-experts (MoE) model in MLX on Apple silicon. + +Note, for 16-bit precision this model needs a machine with substantial RAM (~100GB) to run. + +### Setup + +Install [Git Large File +Storage](https://docs.github.com/en/repositories/working-with-files/managing-large-files/installing-git-large-file-storage). +For example with Homebrew: + +``` +brew install git-lfs +``` + +Download the models from HuggingFace: + +``` +git clone https://huggingface.co/someone13574/mixtral-8x7b-32kseqlen +``` + +After that's done, combine the files: +``` +cd mixtral-8x7b-32kseqlen/ +cat consolidated.00.pth-split0 consolidated.00.pth-split1 consolidated.00.pth-split2 consolidated.00.pth-split3 consolidated.00.pth-split4 consolidated.00.pth-split5 consolidated.00.pth-split6 consolidated.00.pth-split7 consolidated.00.pth-split8 consolidated.00.pth-split9 consolidated.00.pth-split10 > consolidated.00.pth +``` + +Now from `mlx-exmaples/mixtral` conver and save the weights as NumPy arrays so +MLX can read them: + +``` +python convert.py --model_path mixtral-8x7b-32kseqlen/ +``` + +The conversion script will save the converted weights in the same location. + +After that's done, if you want to clean some stuff up: + +``` +rm mixtral-8x7b-32kseqlen/*.pth* +``` + +### Generate + +As easy as: + +``` +python mixtral.py --model_path mixtral mixtral-8x7b-32kseqlen/ +``` + +[^mixtral]: Refer to Mistral's [blog post](https://mistral.ai/news/mixtral-of-experts/) for more details. diff --git a/mixtral/convert.py b/mixtral/convert.py new file mode 100644 index 00000000..a1a423d0 --- /dev/null +++ b/mixtral/convert.py @@ -0,0 +1,23 @@ +# Copyright © 2023 Apple Inc. + +import argparse +import numpy as np +from pathlib import Path +import torch + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert Mixtral weights to MLX.") + parser.add_argument( + "--model_path", + type=str, + default="mixtral-8x7b-32kseqlen/", + help="The path to the Mixtral model. The MLX model weights will also be saved there.", + ) + args = parser.parse_args() + model_path = Path(args.model_path) + state = torch.load(str(model_path / "consolidated.00.pt")) + np.savez( + str(model_path / "weights.npz"), + **{k: v.to(torch.float16).numpy() for k, v in state.items()}, + ) diff --git a/mixtral/mixtral.py b/mixtral/mixtral.py new file mode 100644 index 00000000..1a9be600 --- /dev/null +++ b/mixtral/mixtral.py @@ -0,0 +1,303 @@ +# Copyright © 2023 Apple Inc. + +import argparse +from dataclasses import dataclass +import json +import numpy as np +from pathlib import Path +from typing import Optional, Tuple, List +from sentencepiece import SentencePieceProcessor + +import mlx.core as mx +import mlx.nn as nn +from mlx.utils import tree_map, tree_unflatten + + +@dataclass +class ModelArgs: + dim: int + n_layers: int + head_dim: int + hidden_dim: int + n_heads: int + n_kv_heads: int + norm_eps: float + vocab_size: int + moe: dict = None + + +class RMSNorm(nn.Module): + def __init__(self, dims: int, eps: float = 1e-5): + super().__init__() + self.weight = mx.ones((dims,)) + self.eps = eps + + def _norm(self, x): + return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps) + + def __call__(self, x): + output = self._norm(x.astype(mx.float32)).astype(x.dtype) + return self.weight * output + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + + self.n_heads: int = args.n_heads + self.n_kv_heads: int = args.n_kv_heads + + self.repeats = self.n_heads // self.n_kv_heads + + self.scale = self.args.head_dim**-0.5 + + self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False) + self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False) + self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False) + self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False) + self.rope = nn.RoPE(args.head_dim, traditional=True) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> mx.array: + B, L, D = x.shape + + queries, keys, values = self.wq(x), self.wk(x), self.wv(x) + + # Prepare the queries, keys and values for the attention computation + queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + def repeat(a): + a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2) + return a.reshape([B, self.n_heads, L, -1]) + + keys, values = map(repeat, (keys, values)) + + if cache is not None: + key_cache, value_cache = cache + queries = self.rope(queries, offset=key_cache.shape[2]) + keys = self.rope(keys, offset=key_cache.shape[2]) + keys = mx.concatenate([key_cache, keys], axis=2) + values = mx.concatenate([value_cache, values], axis=2) + else: + queries = self.rope(queries) + keys = self.rope(keys) + + scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2) + if mask is not None: + scores += mask + scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) + output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.wo(output), (keys, values) + + +class FeedForward(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + + self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False) + self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False) + self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False) + + def __call__(self, x) -> mx.array: + return self.w2(nn.silu(self.w1(x)) * self.w3(x)) + + +class MOEFeedForward(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + + self.num_experts = args.moe["num_experts"] + self.num_experts_per_tok = args.moe["num_experts_per_tok"] + self.experts = [FeedForward(args) for _ in range(self.num_experts)] + self.gate = nn.Linear(args.dim, self.num_experts, bias=False) + + def __call__(self, x) -> mx.array: + ne = self.num_experts_per_tok + orig_shape = x.shape + x = x.reshape(-1, x.shape[-1]) + + gates = self.gate(x) + inds = mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne] + scores = mx.softmax(mx.take_along_axis(gates, inds, axis=-1), axis=-1) + + y = [] + for xt, st, it in zip(x, scores, inds.tolist()): + yt = mx.concatenate([self.experts[e](xt)[:, None] for e in it], axis=-1) + yt = (yt * st).sum(axis=-1) + y.append(yt[None, :]) + y = mx.concatenate(y) + + return y.reshape(orig_shape) + + +class MOETransformerBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.n_heads = args.n_heads + self.dim = args.dim + self.attention = Attention(args) + self.feed_forward = MOEFeedForward(args=args) + self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.args = args + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> mx.array: + r, cache = self.attention(self.attention_norm(x), mask, cache) + h = x + r + r = self.feed_forward(self.ffn_norm(h)) + out = h + r + return out, cache + + +class Mixtral(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + self.n_layers = args.n_layers + assert self.vocab_size > 0 + self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim) + self.layers = [MOETransformerBlock(args=args) for _ in range(args.n_layers)] + self.norm = RMSNorm(args.dim, eps=args.norm_eps) + self.output = nn.Linear(args.dim, args.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + h = self.tok_embeddings(inputs) + + mask = None + if h.shape[1] > 1: + mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) + mask = mask.astype(h.dtype) + + if cache is None: + cache = [None] * len(self.layers) + + for e, layer in enumerate(self.layers): + h, cache[e] = layer(h, mask, cache[e]) + + return self.output(self.norm(h)), cache + + +class Tokenizer: + def __init__(self, model_path: str): + assert Path(model_path).exists(), model_path + self._model = SentencePieceProcessor(model_file=model_path) + self._sep = "▁" + assert self._model.vocab_size() == self._model.get_piece_size() + + @property + def eos_id(self) -> int: + return self._model.eos_id() + + @property + def pad_id(self) -> int: + return self._model.pad_id() + + def encode(self, s: str) -> List[int]: + return [self._model.bos_id(), *self._model.encode(s)] + + def decode(self, t: List[int]) -> str: + out = self._model.decode(t) + if t and self._model.id_to_piece(t[0])[0] == self._sep: + return " " + out + return out + + +def load_model(folder: str, dtype=mx.float16): + model_path = Path(folder) + tokenizer = Tokenizer(str(model_path / "tokenizer.model")) + with open(model_path / "params.json", "r") as f: + config = json.loads(f.read()) + model_args = ModelArgs(**config) + weights = mx.load(str(model_path / "weights.npz")) + weights = tree_unflatten(list(weights.items())) + weights = tree_map(lambda p: p.astype(dtype), weights) + model = Mixtral(model_args) + model.update(weights) + return model, tokenizer + + +def generate(prompt: mx.array, model: Mixtral, temp: Optional[float] = 0.0): + def sample(logits): + if temp == 0: + return mx.argmax(logits, axis=-1) + else: + return mx.random.categorical(logits * (1 / temp)) + + logits, cache = model(prompt[None]) + y = sample(logits[:, -1, :]) + yield y + + while True: + logits, cache = model(y[:, None], cache) + y = sample(logits.squeeze(1)) + yield y + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Mixtral inference script") + parser.add_argument( + "--model_path", + type=str, + default="mixtral-8x7b-32kseqlen", + help="The path to the model weights, tokenizer, and config", + ) + parser.add_argument( + "--prompt", + help="The message to be processed by the model", + default="In the beginning the Universe was created.", + ) + parser.add_argument( + "--max_tokens", + "-m", + type=int, + default=100, + help="Maximum number of tokens to generate", + ) + parser.add_argument( + "--temp", + help="The sampling temperature.", + type=float, + default=1.0, + ) + parser.add_argument("--seed", type=int, default=0, help="The PRNG seed") + + args = parser.parse_args() + + mx.random.seed(args.seed) + print("[INFO] Loading model from disk.") + model, tokenizer = load_model(args.model_path) + + print("[INFO] Starting generation...") + + print(args.prompt, end="", flush=True) + prompt = mx.array(tokenizer.encode(args.prompt)) + tokens = [] + for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)): + tokens.append(token) + + if (len(tokens) % 10) == 0: + mx.eval(tokens) + s = tokenizer.decode([t.item() for t in tokens]) + print(s, end="", flush=True) + tokens = [] + + mx.eval(tokens) + s = tokenizer.decode([t.item() for t in tokens]) + print(s, flush=True) diff --git a/mixtral/requirements.txt b/mixtral/requirements.txt new file mode 100644 index 00000000..d775b88f --- /dev/null +++ b/mixtral/requirements.txt @@ -0,0 +1,4 @@ +mlx +sentencepiece +torch +numpy