mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-25 01:41:19 +08:00
initial mixtral
This commit is contained in:
parent
13f1142eaa
commit
e42682dced
52
mixtral/README.md
Normal file
52
mixtral/README.md
Normal file
@ -0,0 +1,52 @@
|
||||
## Mixtral 8x7b
|
||||
|
||||
Run the Mixtral[^mixtral] 8x7B mixture-of-experts (MoE) model in MLX on Apple silicon.
|
||||
|
||||
Note, this model needs a machine with substantial RAM (>= 128GB) to run in
|
||||
16-bit precision.
|
||||
|
||||
### Setup
|
||||
|
||||
Install [Git Large File
|
||||
Storage](https://docs.github.com/en/repositories/working-with-files/managing-large-files/installing-git-large-file-storage).
|
||||
For example with Homebrew:
|
||||
|
||||
```
|
||||
brew install git-lfs
|
||||
```
|
||||
|
||||
Download the models from HugginFace:
|
||||
|
||||
```
|
||||
git clone https://huggingface.co/someone13574/mixtral-8x7b-32kseqlen
|
||||
```
|
||||
|
||||
After that's done, combine the files:
|
||||
```
|
||||
cd mixtral-8x7b-32kseqlen/
|
||||
cat consolidated.00.pth-split0 consolidated.00.pth-split1 consolidated.00.pth-split2 consolidated.00.pth-split3 consolidated.00.pth-split4 consolidated.00.pth-split5 consolidated.00.pth-split6 consolidated.00.pth-split7 consolidated.00.pth-split8 consolidated.00.pth-split9 consolidated.00.pth-split10 > consolidated.00.pth
|
||||
```
|
||||
|
||||
Now from `mlx-exmaples/mixtral` conver the weights to NumPy so MLX can read them:
|
||||
|
||||
```
|
||||
python convert.py --model_path mixtral-8x7b-32kseqlen/
|
||||
```
|
||||
|
||||
The conversion script will save the new weights in the same location.
|
||||
|
||||
After that's done, if you want to clean some stuff up:
|
||||
|
||||
```
|
||||
rm mixtral-8x7b-32kseqlen/*.pth
|
||||
```
|
||||
|
||||
### Generate
|
||||
|
||||
As easy as:
|
||||
|
||||
```
|
||||
python mixtral.py --model_path mixtral mixtral-8x7b-32kseqlen/
|
||||
```
|
||||
|
||||
[^mixtral] Refer to Mistral's [blog post](https://mistral.ai/news/mixtral-of-experts/) for more details.
|
23
mixtral/convert.py
Normal file
23
mixtral/convert.py
Normal file
@ -0,0 +1,23 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
import torch
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Convert Mixtral weights to MLX.")
|
||||
parser.add_argument(
|
||||
"--model_path",
|
||||
type=str,
|
||||
default="mixtral-8x7b-32kseqlen/",
|
||||
help="The path to the Mixtral model. The MLX model weights will also be saved there.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
model_path = Path(args.model_path)
|
||||
state = torch.load(str(model_path / "consolidated.00.pt"))
|
||||
np.savez(
|
||||
str(model_path / "weights.npz"),
|
||||
**{k: v.to(torch.float16).numpy() for k, v in state.items()},
|
||||
)
|
328
mixtral/mixtral.py
Normal file
328
mixtral/mixtral.py
Normal file
@ -0,0 +1,328 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import argparse
|
||||
from dataclasses import dataclass
|
||||
import json
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, List
|
||||
from sentencepiece import SentencePieceProcessor
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx.utils import tree_map, tree_unflatten
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs:
|
||||
dim: int
|
||||
n_layers: int
|
||||
head_dim: int
|
||||
hidden_dim: int
|
||||
n_heads: int
|
||||
n_kv_heads: int
|
||||
norm_eps: float
|
||||
vocab_size: int
|
||||
moe: dict = None
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dims: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
self.weight = mx.ones((dims,))
|
||||
self.eps = eps
|
||||
|
||||
def _norm(self, x):
|
||||
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
|
||||
|
||||
def __call__(self, x):
|
||||
output = self._norm(x.astype(mx.float32)).astype(x.dtype)
|
||||
return self.weight * output
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
|
||||
self.n_heads: int = args.n_heads
|
||||
self.n_kv_heads: int = args.n_kv_heads
|
||||
|
||||
self.repeats = self.n_heads // self.n_kv_heads
|
||||
|
||||
self.scale = self.args.head_dim**-0.5
|
||||
|
||||
self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False)
|
||||
self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
|
||||
self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
|
||||
self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
|
||||
self.rope = nn.RoPE(args.head_dim, traditional=True)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
def repeat(a):
|
||||
a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
|
||||
return a.reshape([B, self.n_heads, L, -1])
|
||||
|
||||
keys, values = map(repeat, (keys, values))
|
||||
|
||||
if cache is not None:
|
||||
key_cache, value_cache = cache
|
||||
queries = self.rope(queries, offset=key_cache.shape[2])
|
||||
keys = self.rope(keys, offset=key_cache.shape[2])
|
||||
keys = mx.concatenate([key_cache, keys], axis=2)
|
||||
values = mx.concatenate([value_cache, values], axis=2)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
|
||||
if mask is not None:
|
||||
scores += mask
|
||||
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
|
||||
output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.wo(output), (keys, values)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False)
|
||||
self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False)
|
||||
self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.w2(nn.silu(self.w1(x)) * self.w3(x))
|
||||
|
||||
|
||||
class MOEFeedForward(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
self.num_experts = args.moe["num_experts"]
|
||||
self.num_experts_per_tok = args.moe["num_experts_per_tok"]
|
||||
self.experts = [FeedForward(args) for _ in range(self.num_experts)]
|
||||
self.gate = nn.Linear(args.dim, self.num_experts, bias=False)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
|
||||
ne = self.num_experts_per_tok
|
||||
orig_shape = x.shape
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
|
||||
gates = self.gate(x)
|
||||
inds = mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]
|
||||
scores = mx.softmax(mx.take_along_axis(gates, inds, axis=-1), axis=-1)
|
||||
|
||||
# For batch:
|
||||
if x.shape[0] > 1:
|
||||
mx.eval(inds)
|
||||
inds = np.array(inds)
|
||||
y = mx.zeros((x.shape[0], ne, x.shape[-1]))
|
||||
for e, expert in enumerate(self.experts):
|
||||
idx1, idx2 = map(mx.array, np.where(inds == e))
|
||||
if idx1.size == 0:
|
||||
continue
|
||||
y[idx1, idx2] = expert(x[idx1])
|
||||
y = (y * scores[:, :, None]).sum(axis=1)
|
||||
|
||||
# For single:
|
||||
else:
|
||||
ys = [self.experts[e](x)[:, :, None] for e in inds.squeeze().tolist()]
|
||||
y = mx.concatenate(ys, axis=-1)
|
||||
y = (y * scores[:, None, 0]).sum(axis=-1)
|
||||
|
||||
return y.reshape(orig_shape)
|
||||
|
||||
|
||||
class MOETransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.n_heads = args.n_heads
|
||||
self.dim = args.dim
|
||||
self.attention = Attention(args)
|
||||
self.feed_forward = MOEFeedForward(args=args)
|
||||
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
r, cache = self.attention(self.attention_norm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.feed_forward(self.ffn_norm(h))
|
||||
out = h + r
|
||||
return out, cache
|
||||
|
||||
|
||||
class Mixtral(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.n_layers = args.n_layers
|
||||
assert self.vocab_size > 0
|
||||
self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
|
||||
self.layers = [MOETransformerBlock(args=args) for _ in range(args.n_layers)]
|
||||
self.norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
h = self.tok_embeddings(inputs)
|
||||
|
||||
mask = None
|
||||
if h.shape[1] > 1:
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
|
||||
mask = mask.astype(h.dtype)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for e, layer in enumerate(self.layers):
|
||||
h, cache[e] = layer(h, mask, cache[e])
|
||||
|
||||
return self.output(self.norm(h)), cache
|
||||
|
||||
|
||||
class Tokenizer:
|
||||
def __init__(self, model_path: str):
|
||||
assert Path(model_path).exists(), model_path
|
||||
self._model = SentencePieceProcessor(model_file=model_path)
|
||||
self._sep = "▁"
|
||||
assert self._model.vocab_size() == self._model.get_piece_size()
|
||||
|
||||
@property
|
||||
def eos_id(self) -> int:
|
||||
return self._model.eos_id()
|
||||
|
||||
@property
|
||||
def pad_id(self) -> int:
|
||||
return self._model.pad_id()
|
||||
|
||||
def encode(self, s: str) -> List[int]:
|
||||
return [self._model.bos_id(), *self._model.encode(s)]
|
||||
|
||||
def decode(self, t: List[int]) -> str:
|
||||
out = self._model.decode(t)
|
||||
if t and self._model.id_to_piece(t[0])[0] == self._sep:
|
||||
return " " + out
|
||||
return out
|
||||
|
||||
|
||||
def load_model(folder: str, dtype=mx.float16):
|
||||
model_path = Path(folder)
|
||||
tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
|
||||
with open(model_path / "params.json", "r") as f:
|
||||
config = json.loads(f.read())
|
||||
model_args = ModelArgs(**config)
|
||||
weights = mx.load(str(model_path / "weights.npz"))
|
||||
weights = tree_unflatten(list(weights.items()))
|
||||
weights = tree_map(lambda p: p.astype(dtype), weights)
|
||||
model = Mixtral(model_args)
|
||||
model.update(weights)
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def generate(prompt: mx.array, model: Mixtral, temp: Optional[float] = 0.0):
|
||||
def sample(logits):
|
||||
if temp == 0:
|
||||
return mx.argmax(logits, axis=-1)
|
||||
else:
|
||||
return mx.random.categorical(logits * (1 / temp))
|
||||
|
||||
logits, cache = model(prompt[None])
|
||||
y = sample(logits[:, -1, :])
|
||||
yield y
|
||||
|
||||
while True:
|
||||
logits, cache = model(y[:, None], cache)
|
||||
y = sample(logits.squeeze(1))
|
||||
yield y
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Mixtral inference script")
|
||||
parser.add_argument(
|
||||
"--model_path",
|
||||
type=str,
|
||||
default="mixtral-8x7b-32kseqlen",
|
||||
help="The path to the model weights, tokenizer, and config",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
help="The message to be processed by the model",
|
||||
default="In the beginning the Universe was created.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_tokens",
|
||||
"-m",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Maximum number of tokens to generate",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temp",
|
||||
help="The sampling temperature.",
|
||||
type=float,
|
||||
default=1.0,
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
mx.random.seed(args.seed)
|
||||
print("[INFO] Loading model from disk.")
|
||||
model, tokenizer = load_model(args.model_path)
|
||||
|
||||
print("[INFO] Starting generation...")
|
||||
|
||||
print(args.prompt, end="", flush=True)
|
||||
prompt = mx.array(tokenizer.encode(args.prompt))
|
||||
mx.eval(prompt)
|
||||
tokens = []
|
||||
import time
|
||||
tic = time.time()
|
||||
p = True
|
||||
for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
|
||||
tokens.append(token)
|
||||
if p:
|
||||
mx.eval(tokens)
|
||||
p = False
|
||||
prompt_time = time.time() - tic
|
||||
tic = time.time()
|
||||
|
||||
if (len(tokens) % 10) == 0:
|
||||
mx.eval(tokens)
|
||||
s = tokenizer.decode([t.item() for t in tokens])
|
||||
print(s, end="", flush=True)
|
||||
tokens = []
|
||||
|
||||
mx.eval(tokens)
|
||||
tok_time = time.time() - tic
|
||||
s = tokenizer.decode([t.item() for t in tokens])
|
||||
print(s, flush=True)
|
||||
print("------")
|
||||
print(f"Prompt time {prompt_time}")
|
||||
print(f"Token time {tok_time}")
|
4
mixtral/requirements.txt
Normal file
4
mixtral/requirements.txt
Normal file
@ -0,0 +1,4 @@
|
||||
mlx
|
||||
sentencepiece
|
||||
torch
|
||||
numpy
|
Loading…
Reference in New Issue
Block a user