mlx-examples/llms/mixtral/mixtral.py
Param Thakkar 4c9f9f9be7
Made llama and mistral files mypy compatible (#1359)
* Made mypy compatible

* reformatted

* Added more fixes

* Added fixes to speculative-decoding

* Fixes

* fix circle

* revert some stuff

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-04-23 14:23:46 -07:00

299 lines
9.3 KiB
Python

# Copyright © 2023 Apple Inc.
import argparse
import glob
import json
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten
from sentencepiece import SentencePieceProcessor
@dataclass
class ModelArgs:
dim: int
n_layers: int
head_dim: int
hidden_dim: int
n_heads: int
n_kv_heads: int
norm_eps: float
vocab_size: int
moe: dict
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.n_heads: int = args.n_heads
self.n_kv_heads: int = args.n_kv_heads
self.repeats = self.n_heads // self.n_kv_heads
self.scale = self.args.head_dim**-0.5
self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False)
self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
self.rope = nn.RoPE(args.head_dim, traditional=True, base=1000000)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.wo(output), (keys, values)
class FeedForward(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False)
self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False)
self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False)
def __call__(self, x) -> mx.array:
return self.w2(nn.silu(self.w1(x)) * self.w3(x))
class MOEFeedForward(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.num_experts = args.moe["num_experts"]
self.num_experts_per_tok = args.moe["num_experts_per_tok"]
self.experts = [FeedForward(args) for _ in range(self.num_experts)]
self.gate = nn.Linear(args.dim, self.num_experts, bias=False)
def __call__(self, x) -> mx.array:
ne = self.num_experts_per_tok
orig_shape = x.shape
x = x.reshape(-1, x.shape[-1])
gates = self.gate(x)
inds = mx.argpartition(-gates, kth=ne - 1, axis=-1)[:, :ne]
scores = mx.softmax(
mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32),
axis=-1,
).astype(gates.dtype)
y = []
for xt, st, it in zip(x, scores, inds.tolist()):
yt = mx.concatenate([self.experts[e](xt)[:, None] for e in it], axis=-1)
yt = (yt * st).sum(axis=-1)
y.append(yt[None, :])
y = mx.concatenate(y)
return y.reshape(orig_shape)
class MOETransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_heads = args.n_heads
self.dim = args.dim
self.attention = Attention(args)
self.feed_forward = MOEFeedForward(args=args)
self.attention_norm = nn.RMSNorm(args.dim, eps=args.norm_eps)
self.ffn_norm = nn.RMSNorm(args.dim, eps=args.norm_eps)
self.args = args
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
r, cache = self.attention(self.attention_norm(x), mask, cache)
h = x + r
r = self.feed_forward(self.ffn_norm(h))
out = h + r
return out, cache
class Mixtral(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.n_layers = args.n_layers
assert self.vocab_size > 0
self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
self.layers = [MOETransformerBlock(args=args) for _ in range(args.n_layers)]
self.norm = nn.RMSNorm(args.dim, eps=args.norm_eps)
self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache=None,
):
h = self.tok_embeddings(inputs)
mask = None
T = h.shape[1]
if T > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
mask = mask.astype(h.dtype)
if cache is None:
cache = [None] * len(self.layers)
for e, layer in enumerate(self.layers):
h, cache[e] = layer(h, mask, cache[e])
return self.output(self.norm(h[:, T - 1 : T, :])), cache
class Tokenizer:
def __init__(self, model_path: str):
assert Path(model_path).exists(), model_path
self._model = SentencePieceProcessor(model_file=model_path)
self._sep = ""
assert self._model.vocab_size() == self._model.get_piece_size()
@property
def eos_id(self) -> int:
return self._model.eos_id()
@property
def pad_id(self) -> int:
return self._model.pad_id()
def encode(self, s: str) -> List[int]:
return [self._model.bos_id(), *self._model.encode(s)]
def decode(self, t: List[int]) -> str:
out = self._model.decode(t)
if t and self._model.id_to_piece(t[0])[0] == self._sep:
return " " + out
return out
def load_model(folder: str):
model_path = Path(folder)
tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
with open(model_path / "config.json", "r") as f:
config = json.loads(f.read())
config.pop("model_type", None)
quantization = config.pop("quantization", None)
model_args = ModelArgs(**config)
weight_files = glob.glob(str(model_path / "weights.*.npz"))
weights = {}
for wf in weight_files:
weights.update(mx.load(wf).items())
weights = tree_unflatten(list(weights.items()))
model = Mixtral(model_args)
if quantization is not None:
nn.quantize(model, **quantization)
model.update(weights)
return model, tokenizer
def generate(prompt: mx.array, model: Mixtral, temp: Optional[float] = 0.0):
def sample(logits):
if temp == 0:
return mx.argmax(logits, axis=-1)
else:
return mx.random.categorical(logits * (1 / temp))
logits, cache = model(prompt[None])
y = sample(logits[:, -1, :])
yield y
while True:
logits, cache = model(y[:, None], cache)
y = sample(logits.squeeze(1))
yield y
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Mixtral inference script")
parser.add_argument(
"--model-path",
type=str,
default="mlx_model",
help="The path to the model weights, tokenizer, and config",
)
parser.add_argument(
"--prompt",
help="The message to be processed by the model",
default="In the beginning the Universe was created.",
)
parser.add_argument(
"--max-tokens",
"-m",
type=int,
default=100,
help="Maximum number of tokens to generate",
)
parser.add_argument(
"--temp",
help="The sampling temperature.",
type=float,
default=0.0,
)
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
args = parser.parse_args()
mx.random.seed(args.seed)
print("[INFO] Loading model from disk.")
model, tokenizer = load_model(args.model_path)
print("[INFO] Starting generation...")
print(args.prompt, end="", flush=True)
prompt = mx.array(tokenizer.encode(args.prompt))
tokens = []
for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
tokens.append(token)
if (len(tokens) % 10) == 0:
mx.eval(tokens)
eos_index = next(
(i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_id), None
)
if eos_index is not None:
tokens = tokens[:eos_index]
s = tokenizer.decode([t.item() for t in tokens])
print(s, end="", flush=True)
tokens = []
if eos_index is not None:
break
mx.eval(tokens)
s = tokenizer.decode([t.item() for t in tokens])
print(s, flush=True)