mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 01:17:28 +08:00

- bert/model.py:10: tree_unflatten - bert/model.py:2: dataclass - bert/model.py:8: numpy - cifar/resnet.py:6: Any - clip/model.py:15: tree_flatten - clip/model.py:9: Union - gcn/main.py:8: download_cora - gcn/main.py:9: cross_entropy - llms/gguf_llm/models.py:12: tree_flatten, tree_unflatten - llms/gguf_llm/models.py:9: numpy - llms/mixtral/mixtral.py:12: tree_map - llms/mlx_lm/models/dbrx.py:2: Dict, Union - llms/mlx_lm/tuner/trainer.py:5: partial - llms/speculative_decoding/decoder.py:1: dataclass, field - llms/speculative_decoding/decoder.py:2: Optional - llms/speculative_decoding/decoder.py:5: mlx.nn - llms/speculative_decoding/decoder.py:6: numpy - llms/speculative_decoding/main.py:2: glob - llms/speculative_decoding/main.py:3: json - llms/speculative_decoding/main.py:5: Path - llms/speculative_decoding/main.py:8: mlx.nn - llms/speculative_decoding/model.py:6: tree_unflatten - llms/speculative_decoding/model.py:7: AutoTokenizer - llms/tests/test_lora.py:13: yaml_loader - lora/lora.py:14: tree_unflatten - lora/models.py:11: numpy - lora/models.py:3: glob - speechcommands/kwt.py:1: Any - speechcommands/main.py:7: mlx.data - stable_diffusion/stable_diffusion/model_io.py:4: partial - whisper/benchmark.py:5: sys - whisper/test.py:5: subprocess - whisper/whisper/audio.py:6: Optional - whisper/whisper/decoding.py:8: mlx.nn
305 lines
9.5 KiB
Python
305 lines
9.5 KiB
Python
# Copyright © 2023 Apple Inc.
|
|
|
|
import argparse
|
|
import glob
|
|
import json
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import List, Optional, Tuple
|
|
|
|
import mlx.core as mx
|
|
import mlx.nn as nn
|
|
from mlx.utils import tree_unflatten
|
|
from sentencepiece import SentencePieceProcessor
|
|
|
|
|
|
@dataclass
class ModelArgs:
    """Hyper-parameters of a Mixtral model, as loaded from `config.json`."""

    dim: int  # model (embedding) width
    n_layers: int  # number of transformer blocks
    head_dim: int  # per-head width of queries, keys and values
    hidden_dim: int  # inner width of each expert's feed-forward network
    n_heads: int  # number of query heads
    n_kv_heads: int  # number of key/value heads (may be fewer than n_heads)
    norm_eps: float  # epsilon used by the RMSNorm layers
    vocab_size: int  # size of the token embedding / output projection
    # Mixture-of-experts settings. The model reads the "num_experts" and
    # "num_experts_per_tok" keys from this dict; it is required by Mixtral.
    moe: Optional[dict] = None
|
|
|
|
|
|
class Attention(nn.Module):
    """Multi-head self-attention with rotary embeddings and an optional KV cache.

    Keys and values use `n_kv_heads` heads, which may be fewer than the
    `n_heads` query heads; they are handed to
    `mx.fast.scaled_dot_product_attention` without being repeated.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args

        self.n_heads: int = args.n_heads
        self.n_kv_heads: int = args.n_kv_heads

        # Query heads per key/value head (not used by __call__ below).
        self.repeats = self.n_heads // self.n_kv_heads

        self.scale = self.args.head_dim**-0.5

        self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
        # The rotary base is fixed at 1e6 here rather than read from ModelArgs.
        self.rope = nn.RoPE(args.head_dim, traditional=True, base=1000000)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
        """Attend over the L new positions in `x` plus any cached positions.

        Args:
            x: input of shape (B, L, dim).
            mask: optional additive mask passed straight to the attention
                kernel; it must broadcast against (query_len, key_len).
            cache: optional (keys, values) from earlier calls, each laid out as
                (B, n_kv_heads, S, head_dim); S is used as the RoPE offset.

        Returns:
            (output of shape (B, L, dim), (keys, values)) where keys/values
            include the cached positions followed by the new ones (post-RoPE),
            ready to be passed back in as `cache`.
        """
        B, L, D = x.shape

        queries, keys, values = self.wq(x), self.wk(x), self.wv(x)

        # (B, L, heads * head_dim) -> (B, heads, L, head_dim)
        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            key_cache, value_cache = cache
            # New positions start right after the cached ones.
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = mx.concatenate([key_cache, keys], axis=2)
            values = mx.concatenate([value_cache, values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        output = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )
        # (B, n_heads, L, head_dim) -> (B, L, n_heads * head_dim)
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.wo(output), (keys, values)
|
|
|
|
|
|
class FeedForward(nn.Module):
    """Gated feed-forward network computing w2(silu(w1(x)) * w3(x))."""

    def __init__(self, args: ModelArgs):
        super().__init__()

        model_dim, inner_dim = args.dim, args.hidden_dim
        self.w1 = nn.Linear(model_dim, inner_dim, bias=False)  # gate projection
        self.w2 = nn.Linear(inner_dim, model_dim, bias=False)  # output projection
        self.w3 = nn.Linear(model_dim, inner_dim, bias=False)  # up projection

    def __call__(self, x) -> mx.array:
        gate = nn.silu(self.w1(x))
        up = self.w3(x)
        return self.w2(gate * up)
|
|
|
|
|
|
class MOEFeedForward(nn.Module):
    """Sparse mixture-of-experts feed-forward layer.

    A linear router (`gate`) scores every expert for each token; the
    `num_experts_per_tok` highest-scoring experts run on that token and their
    outputs are combined using a softmax over the selected router scores.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()

        self.num_experts = args.moe["num_experts"]
        self.num_experts_per_tok = args.moe["num_experts_per_tok"]
        self.experts = [FeedForward(args) for _ in range(self.num_experts)]
        self.gate = nn.Linear(args.dim, self.num_experts, bias=False)

    def __call__(self, x) -> mx.array:
        ne = self.num_experts_per_tok
        orig_shape = x.shape
        # Flatten all leading axes so that each row is one token.
        x = x.reshape(-1, x.shape[-1])

        gates = self.gate(x)
        # Indices of the top-`ne` experts per token (unordered within the top-k).
        inds = mx.argpartition(-gates, kth=ne - 1, axis=-1)[:, :ne]
        # Softmax over the selected experts' logits only, computed in float32.
        scores = mx.softmax(
            mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32),
            axis=-1,
        ).astype(gates.dtype)

        # Group (token, slot) pairs by expert so each expert runs once on a
        # batch of its tokens rather than once per (token, expert) pair.
        token_ids = [[] for _ in range(self.num_experts)]
        slot_ids = [[] for _ in range(self.num_experts)]
        for t, row in enumerate(inds.tolist()):
            for k, e in enumerate(row):
                token_ids[e].append(t)
                slot_ids[e].append(k)

        # y[t, k] holds the output of the k-th selected expert for token t.
        y = mx.zeros((x.shape[0], ne, x.shape[-1]), x.dtype)
        for expert, tok, slot in zip(self.experts, token_ids, slot_ids):
            if not tok:
                continue
            tok = mx.array(tok)
            y[tok, mx.array(slot)] = expert(x[tok])

        # Weight each selected expert's output by its score and sum over slots.
        y = (y * scores[:, :, None]).sum(axis=1)
        return y.reshape(orig_shape)
|
|
|
|
|
|
class MOETransformerBlock(nn.Module):
    """Pre-norm transformer block: attention then mixture-of-experts FFN.

    Computes h = x + attention(norm(x)) and out = h + moe(norm(h)).
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.attention = Attention(args)
        self.feed_forward = MOEFeedForward(args=args)
        self.attention_norm = nn.RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = nn.RMSNorm(args.dim, eps=args.norm_eps)
        self.args = args

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
        """Apply the block to `x`.

        Returns:
            (output with the same shape as `x`, updated (keys, values) cache
            as returned by the attention layer).
        """
        r, cache = self.attention(self.attention_norm(x), mask, cache)
        h = x + r
        r = self.feed_forward(self.ffn_norm(h))
        out = h + r
        return out, cache
|
|
|
|
|
|
class Mixtral(nn.Module):
    """Mixtral decoder: token embedding, MoE transformer blocks, LM head."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.n_layers = args.n_layers
        assert self.vocab_size > 0
        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
        self.layers = [MOETransformerBlock(args=args) for _ in range(args.n_layers)]
        self.norm = nn.RMSNorm(args.dim, eps=args.norm_eps)
        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        """Run the model on token ids `inputs` of shape (B, T).

        Args:
            inputs: token ids; the second axis is the sequence axis.
            cache: optional list with one (keys, values) entry per layer, as
                returned by a previous call. A list passed in is updated in
                place.

        Returns:
            (logits for the last position only, shape (B, 1, vocab_size),
            per-layer cache list).
        """
        h = self.tok_embeddings(inputs)

        mask = None
        T = h.shape[1]
        if T > 1:
            # Number of positions already in the KV cache (0 for a fresh
            # prompt). Cached keys are laid out (B, n_kv_heads, S, head_dim).
            offset = 0
            if cache is not None and cache[0] is not None:
                offset = cache[0][0].shape[2]
            if offset == 0:
                mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
            else:
                # Keys span offset + T positions. Query i is at absolute
                # position offset + i and may attend to keys at positions
                # <= offset + i.
                q_pos = mx.arange(offset, offset + T)
                k_pos = mx.arange(offset + T)
                mask = (q_pos[:, None] < k_pos[None]).astype(mx.float32) * -1e9
            mask = mask.astype(h.dtype)

        if cache is None:
            cache = [None] * len(self.layers)

        for e, layer in enumerate(self.layers):
            h, cache[e] = layer(h, mask, cache[e])

        return self.output(self.norm(h[:, T - 1 : T, :])), cache
|
|
|
|
|
|
class Tokenizer:
    """Thin wrapper around a SentencePiece model file.

    `encode` prepends the BOS id; `decode` restores the leading space that
    SentencePiece drops when the first piece begins with the word-boundary
    marker.
    """

    def __init__(self, model_path: str):
        assert Path(model_path).exists(), model_path
        self._model = SentencePieceProcessor(model_file=model_path)
        # SentencePiece word-boundary marker (U+2581).
        self._sep = "▁"
        assert self._model.vocab_size() == self._model.get_piece_size()

    @property
    def eos_id(self) -> int:
        """Id of the end-of-sequence token."""
        return self._model.eos_id()

    @property
    def pad_id(self) -> int:
        """Id of the padding token."""
        return self._model.pad_id()

    def encode(self, s: str) -> List[int]:
        """Return the token ids of `s`, preceded by the BOS id."""
        ids = [self._model.bos_id()]
        ids.extend(self._model.encode(s))
        return ids

    def decode(self, t: List[int]) -> str:
        """Return the text for ids `t`, keeping a leading space if present."""
        text = self._model.decode(t)
        if not t:
            return text
        first_piece = self._model.id_to_piece(t[0])
        return " " + text if first_piece[0] == self._sep else text
|
|
|
|
|
|
def load_model(folder: str):
    """Load a converted Mixtral model and its tokenizer from `folder`.

    The folder must contain `tokenizer.model`, `config.json` and one or more
    `weights.*.npz` shards. If the config has a "quantization" entry, the
    model is quantized (router gates excluded) before the weights are loaded.

    Returns:
        (model, tokenizer)

    Raises:
        FileNotFoundError: if no `weights.*.npz` shards are found.
    """
    model_path = Path(folder)
    tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
    with open(model_path / "config.json", "r") as f:
        config = json.load(f)
    config.pop("model_type", None)
    quantization = config.pop("quantization", None)
    model_args = ModelArgs(**config)

    # Sorted so the order in which shards are merged is deterministic.
    weight_files = sorted(glob.glob(str(model_path / "weights.*.npz")))
    if not weight_files:
        # Without this check the model would silently keep random weights.
        raise FileNotFoundError(f"No weights.*.npz files found in {model_path}")
    weights = {}
    for wf in weight_files:
        weights.update(mx.load(wf).items())
    weights = tree_unflatten(list(weights.items()))

    model = Mixtral(model_args)
    if quantization is not None:
        # TODO: Quantize gate matrices when < 32 tiles supported
        # The MoE router is a Linear(dim, num_experts); it is recognized by
        # its output dimension and left unquantized.
        num_experts = model_args.moe["num_experts"]
        quantization["linear_class_predicate"] = (
            lambda m: isinstance(m, nn.Linear) and m.weight.shape[0] != num_experts
        )
        nn.QuantizedLinear.quantize_module(model, **quantization)

    model.update(weights)
    return model, tokenizer
|
|
|
|
|
|
def generate(prompt: mx.array, model: Mixtral, temp: Optional[float] = 0.0):
    """Yield generated tokens one at a time, forever.

    Args:
        prompt: 1-D array of prompt token ids (a batch axis is added here).
        model: the Mixtral model; its KV cache is threaded between steps.
        temp: sampling temperature. 0 or None selects greedy (argmax)
            decoding; otherwise tokens are sampled from logits / temp.

    Yields:
        Arrays holding the next token id for the single batch row.
    """

    def sample(logits):
        if not temp:
            return mx.argmax(logits, axis=-1)
        return mx.random.categorical(logits * (1 / temp))

    logits, cache = model(prompt[None])
    y = sample(logits[:, -1, :])
    yield y

    while True:
        logits, cache = model(y[:, None], cache)
        y = sample(logits.squeeze(1))
        yield y
|
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Mixtral inference script")
    parser.add_argument(
        "--model-path",
        type=str,
        default="mlx_model",
        help="The path to the model weights, tokenizer, and config",
    )
    parser.add_argument(
        "--prompt",
        help="The message to be processed by the model",
        default="In the beginning the Universe was created.",
    )
    parser.add_argument(
        "--max-tokens",
        "-m",
        type=int,
        default=100,
        help="Maximum number of tokens to generate",
    )
    parser.add_argument(
        "--temp",
        help="The sampling temperature.",
        type=float,
        default=0.0,
    )
    parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")

    args = parser.parse_args()

    mx.random.seed(args.seed)
    print("[INFO] Loading model from disk.")
    model, tokenizer = load_model(args.model_path)

    print("[INFO] Starting generation...")

    print(args.prompt, end="", flush=True)
    prompt = mx.array(tokenizer.encode(args.prompt))

    def print_tokens(tokens) -> bool:
        """Evaluate `tokens` and print their text up to (excluding) any EOS.

        Returns True if an EOS token was found, i.e. generation should stop.
        """
        mx.eval(tokens)
        ids = [t.item() for t in tokens]
        found_eos = tokenizer.eos_id in ids
        if found_eos:
            ids = ids[: ids.index(tokenizer.eos_id)]
        print(tokenizer.decode(ids), end="", flush=True)
        return found_eos

    tokens = []
    # range() comes first so zip stops without asking the generator for a
    # token beyond max_tokens (which would compute an unused decoding step).
    for _, token in zip(range(args.max_tokens), generate(prompt, model, args.temp)):
        tokens.append(token)
        # Evaluate and print every 10 tokens.
        if len(tokens) % 10 == 0:
            if print_tokens(tokens):
                break
            tokens = []
    else:
        # Reached max_tokens: the final partial chunk still needs the EOS
        # check so nothing after EOS is printed.
        print_tokens(tokens)
    print(flush=True)
|