diff --git a/mixtral/README.md b/mixtral/README.md new file mode 100644 index 00000000..6891f7b9 --- /dev/null +++ b/mixtral/README.md @@ -0,0 +1,52 @@ +## Mixtral 8x7b + +Run the Mixtral[^mixtral] 8x7B mixture-of-experts (MoE) model in MLX on Apple silicon. + +Note, this model needs a machine with substantial RAM (>= 128GB) to run in +16-bit precision. + +### Setup + +Install [Git Large File +Storage](https://docs.github.com/en/repositories/working-with-files/managing-large-files/installing-git-large-file-storage). +For example with Homebrew: + +``` +brew install git-lfs +``` + +Download the models from HugginFace: + +``` +git clone https://huggingface.co/someone13574/mixtral-8x7b-32kseqlen +``` + +After that's done, combine the files: +``` +cd mixtral-8x7b-32kseqlen/ +cat consolidated.00.pth-split0 consolidated.00.pth-split1 consolidated.00.pth-split2 consolidated.00.pth-split3 consolidated.00.pth-split4 consolidated.00.pth-split5 consolidated.00.pth-split6 consolidated.00.pth-split7 consolidated.00.pth-split8 consolidated.00.pth-split9 consolidated.00.pth-split10 > consolidated.00.pth +``` + +Now from `mlx-exmaples/mixtral` conver the weights to NumPy so MLX can read them: + +``` +python convert.py --model_path mixtral-8x7b-32kseqlen/ +``` + +The conversion script will save the new weights in the same location. + +After that's done, if you want to clean some stuff up: + +``` +rm mixtral-8x7b-32kseqlen/*.pth +``` + +### Generate + +As easy as: + +``` +python mixtral.py --model_path mixtral mixtral-8x7b-32kseqlen/ +``` + +[^mixtral] Refer to Mistral's [blog post](https://mistral.ai/news/mixtral-of-experts/) for more details. diff --git a/mixtral/convert.py b/mixtral/convert.py new file mode 100644 index 00000000..a1a423d0 --- /dev/null +++ b/mixtral/convert.py @@ -0,0 +1,23 @@ +# Copyright © 2023 Apple Inc. + +import argparse +import numpy as np +from pathlib import Path +import torch + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert Mixtral weights to MLX.") + parser.add_argument( + "--model_path", + type=str, + default="mixtral-8x7b-32kseqlen/", + help="The path to the Mixtral model. The MLX model weights will also be saved there.", + ) + args = parser.parse_args() + model_path = Path(args.model_path) + state = torch.load(str(model_path / "consolidated.00.pt")) + np.savez( + str(model_path / "weights.npz"), + **{k: v.to(torch.float16).numpy() for k, v in state.items()}, + ) diff --git a/mixtral/mixtral.py b/mixtral/mixtral.py new file mode 100644 index 00000000..a96d5f7d --- /dev/null +++ b/mixtral/mixtral.py @@ -0,0 +1,328 @@ +# Copyright © 2023 Apple Inc. + +import argparse +from dataclasses import dataclass +import json +import numpy as np +from pathlib import Path +from typing import Optional, Tuple, List +from sentencepiece import SentencePieceProcessor + +import mlx.core as mx +import mlx.nn as nn +from mlx.utils import tree_map, tree_unflatten + + +@dataclass +class ModelArgs: + dim: int + n_layers: int + head_dim: int + hidden_dim: int + n_heads: int + n_kv_heads: int + norm_eps: float + vocab_size: int + moe: dict = None + + +class RMSNorm(nn.Module): + def __init__(self, dims: int, eps: float = 1e-5): + super().__init__() + self.weight = mx.ones((dims,)) + self.eps = eps + + def _norm(self, x): + return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps) + + def __call__(self, x): + output = self._norm(x.astype(mx.float32)).astype(x.dtype) + return self.weight * output + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + + self.n_heads: int = args.n_heads + self.n_kv_heads: int = args.n_kv_heads + + self.repeats = self.n_heads // self.n_kv_heads + + self.scale = self.args.head_dim**-0.5 + + self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False) + self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False) + self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False) + self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False) + self.rope = nn.RoPE(args.head_dim, traditional=True) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> mx.array: + B, L, D = x.shape + + queries, keys, values = self.wq(x), self.wk(x), self.wv(x) + + # Prepare the queries, keys and values for the attention computation + queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + def repeat(a): + a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2) + return a.reshape([B, self.n_heads, L, -1]) + + keys, values = map(repeat, (keys, values)) + + if cache is not None: + key_cache, value_cache = cache + queries = self.rope(queries, offset=key_cache.shape[2]) + keys = self.rope(keys, offset=key_cache.shape[2]) + keys = mx.concatenate([key_cache, keys], axis=2) + values = mx.concatenate([value_cache, values], axis=2) + else: + queries = self.rope(queries) + keys = self.rope(keys) + + scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2) + if mask is not None: + scores += mask + scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) + output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.wo(output), (keys, values) + + +class FeedForward(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + + self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False) + self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False) + self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False) + + def __call__(self, x) -> mx.array: + return self.w2(nn.silu(self.w1(x)) * self.w3(x)) + + +class MOEFeedForward(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + + self.num_experts = args.moe["num_experts"] + self.num_experts_per_tok = args.moe["num_experts_per_tok"] + self.experts = [FeedForward(args) for _ in range(self.num_experts)] + self.gate = nn.Linear(args.dim, self.num_experts, bias=False) + + def __call__(self, x) -> mx.array: + + ne = self.num_experts_per_tok + orig_shape = x.shape + x = x.reshape(-1, x.shape[-1]) + + gates = self.gate(x) + inds = mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne] + scores = mx.softmax(mx.take_along_axis(gates, inds, axis=-1), axis=-1) + + # For batch: + if x.shape[0] > 1: + mx.eval(inds) + inds = np.array(inds) + y = mx.zeros((x.shape[0], ne, x.shape[-1])) + for e, expert in enumerate(self.experts): + idx1, idx2 = map(mx.array, np.where(inds == e)) + if idx1.size == 0: + continue + y[idx1, idx2] = expert(x[idx1]) + y = (y * scores[:, :, None]).sum(axis=1) + + # For single: + else: + ys = [self.experts[e](x)[:, :, None] for e in inds.squeeze().tolist()] + y = mx.concatenate(ys, axis=-1) + y = (y * scores[:, None, 0]).sum(axis=-1) + + return y.reshape(orig_shape) + + +class MOETransformerBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.n_heads = args.n_heads + self.dim = args.dim + self.attention = Attention(args) + self.feed_forward = MOEFeedForward(args=args) + self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.args = args + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> mx.array: + r, cache = self.attention(self.attention_norm(x), mask, cache) + h = x + r + r = self.feed_forward(self.ffn_norm(h)) + out = h + r + return out, cache + + +class Mixtral(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + self.n_layers = args.n_layers + assert self.vocab_size > 0 + self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim) + self.layers = [MOETransformerBlock(args=args) for _ in range(args.n_layers)] + self.norm = RMSNorm(args.dim, eps=args.norm_eps) + self.output = nn.Linear(args.dim, args.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + h = self.tok_embeddings(inputs) + + mask = None + if h.shape[1] > 1: + mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) + mask = mask.astype(h.dtype) + + if cache is None: + cache = [None] * len(self.layers) + + for e, layer in enumerate(self.layers): + h, cache[e] = layer(h, mask, cache[e]) + + return self.output(self.norm(h)), cache + + +class Tokenizer: + def __init__(self, model_path: str): + assert Path(model_path).exists(), model_path + self._model = SentencePieceProcessor(model_file=model_path) + self._sep = "▁" + assert self._model.vocab_size() == self._model.get_piece_size() + + @property + def eos_id(self) -> int: + return self._model.eos_id() + + @property + def pad_id(self) -> int: + return self._model.pad_id() + + def encode(self, s: str) -> List[int]: + return [self._model.bos_id(), *self._model.encode(s)] + + def decode(self, t: List[int]) -> str: + out = self._model.decode(t) + if t and self._model.id_to_piece(t[0])[0] == self._sep: + return " " + out + return out + + +def load_model(folder: str, dtype=mx.float16): + model_path = Path(folder) + tokenizer = Tokenizer(str(model_path / "tokenizer.model")) + with open(model_path / "params.json", "r") as f: + config = json.loads(f.read()) + model_args = ModelArgs(**config) + weights = mx.load(str(model_path / "weights.npz")) + weights = tree_unflatten(list(weights.items())) + weights = tree_map(lambda p: p.astype(dtype), weights) + model = Mixtral(model_args) + model.update(weights) + return model, tokenizer + + +def generate(prompt: mx.array, model: Mixtral, temp: Optional[float] = 0.0): + def sample(logits): + if temp == 0: + return mx.argmax(logits, axis=-1) + else: + return mx.random.categorical(logits * (1 / temp)) + + logits, cache = model(prompt[None]) + y = sample(logits[:, -1, :]) + yield y + + while True: + logits, cache = model(y[:, None], cache) + y = sample(logits.squeeze(1)) + yield y + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Mixtral inference script") + parser.add_argument( + "--model_path", + type=str, + default="mixtral-8x7b-32kseqlen", + help="The path to the model weights, tokenizer, and config", + ) + parser.add_argument( + "--prompt", + help="The message to be processed by the model", + default="In the beginning the Universe was created.", + ) + parser.add_argument( + "--max_tokens", + "-m", + type=int, + default=100, + help="Maximum number of tokens to generate", + ) + parser.add_argument( + "--temp", + help="The sampling temperature.", + type=float, + default=1.0, + ) + parser.add_argument("--seed", type=int, default=0, help="The PRNG seed") + + args = parser.parse_args() + + mx.random.seed(args.seed) + print("[INFO] Loading model from disk.") + model, tokenizer = load_model(args.model_path) + + print("[INFO] Starting generation...") + + print(args.prompt, end="", flush=True) + prompt = mx.array(tokenizer.encode(args.prompt)) + mx.eval(prompt) + tokens = [] + import time + tic = time.time() + p = True + for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)): + tokens.append(token) + if p: + mx.eval(tokens) + p = False + prompt_time = time.time() - tic + tic = time.time() + + if (len(tokens) % 10) == 0: + mx.eval(tokens) + s = tokenizer.decode([t.item() for t in tokens]) + print(s, end="", flush=True) + tokens = [] + + mx.eval(tokens) + tok_time = time.time() - tic + s = tokenizer.decode([t.item() for t in tokens]) + print(s, flush=True) + print("------") + print(f"Prompt time {prompt_time}") + print(f"Token time {tok_time}") diff --git a/mixtral/requirements.txt b/mixtral/requirements.txt new file mode 100644 index 00000000..d775b88f --- /dev/null +++ b/mixtral/requirements.txt @@ -0,0 +1,4 @@ +mlx +sentencepiece +torch +numpy