Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-09-01 12:49:50 +08:00)
Add llms subdir + update README (#145)
* add llms subdir + update README
* nits
* use same pre-commit as mlx
* update readmes a bit
* format
llms/llama/README.md (new file, 57 lines)
@@ -0,0 +1,57 @@
|
||||
# Llama
|
||||
|
||||
An example of generating text with Llama (1 or 2) using MLX.
|
||||
|
||||
Llama is a set of open source language models from Meta AI Research[^1][^2]
|
||||
ranging from 7B to 70B parameters. This example also supports Meta's Llama Chat
|
||||
and Code Llama models, as well as the 1.1B TinyLlama models from SUTD.[^3]
|
||||
|
||||
### Setup
|
||||
|
||||
Install the dependencies:
|
||||
|
||||
```
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
Next, download and convert the model. If you do not have access to the model
|
||||
weights you will need to request access from Meta:
|
||||
|
||||
- [Request Llama v1](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform)
|
||||
- [Request Llama v2](https://ai.meta.com/resources/models-and-libraries/llama-downloads/)
|
||||
|
||||
> [!TIP]
> Alternatively, you can also download a few converted checkpoints from
|
||||
> the [MLX Community](https://huggingface.co/mlx-community) organization on
|
||||
> Hugging Face and skip the conversion step.
|
||||
|
||||
You can download the TinyLlama models directly from [Hugging
|
||||
Face](https://huggingface.co/TinyLlama).
|
||||
|
||||
Convert the weights with:
|
||||
|
||||
```
|
||||
python convert.py --model-path <path_to_torch_model>
|
||||
```
|
||||
|
||||
For TinyLlama, use:
|
||||
|
||||
```
|
||||
python convert.py --model-path <path_to_torch_model> --model-name tiny_llama
|
||||
```
|
||||
|
||||
The conversion script will save the converted weights in the same location.
|
||||
|
||||
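As a rough sketch of what to expect (the exact checkpoint file names depend on the model you downloaded), the model directory should now contain something like:

```
ls <path_to_torch_model>
# consolidated.00.pth  params.json  tokenizer.model   <- original files
# weights.npz                                         <- added by convert.py
```
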
### Run
|
||||
|
||||
Once you've converted the weights to MLX format, you can interact with the
|
||||
Llama model:
|
||||
|
||||
```
|
||||
python llama.py <path_to_model> <path_to_tokenizer.model> --prompt "hello"
|
||||
```
|
||||
|
||||
Run `python llama.py --help` for more details.
|
||||
|
||||
[^1]: For Llama v1 refer to the [arXiv paper](https://arxiv.org/abs/2302.13971) and [blog post](https://ai.meta.com/blog/large-language-model-llama-meta-ai/) for more details.
|
||||
[^2]: For Llama v2 refer to the [blog post](https://ai.meta.com/llama/) for more details.
|
||||
[^3]: For TinyLlama refer to the [GitHub repository](https://github.com/jzhang38/TinyLlama?tab=readme-ov-file) for more details.
|
llms/llama/convert.py (new file, 139 lines)
@@ -0,0 +1,139 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import collections
|
||||
import glob
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
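# Load Meta's sharded checkpoints (consolidated.*.pth) and merge the
# tensor-parallel shards, concatenating along axis 0 or 1 depending on
# the weight name.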
def llama(model_path):
|
||||
SHARD_FIRST = ["wv", "wq", "wk", "w1", "w3", "output"]
|
||||
SHARD_SECOND = ["tok_embeddings", "wo", "w2"]
|
||||
SHARD_WEIGHTS = set(SHARD_FIRST + SHARD_SECOND)
|
||||
|
||||
def shard_key(k):
|
||||
keys = k.split(".")
|
||||
if len(keys) < 2:
|
||||
return None
|
||||
return keys[-2]
|
||||
|
||||
def unshard(k, v):
|
||||
wn = shard_key(k)
|
||||
if wn not in SHARD_WEIGHTS:
|
||||
return v
|
||||
elif wn in SHARD_FIRST:
|
||||
axis = 0
|
||||
elif wn in SHARD_SECOND:
|
||||
axis = 1
|
||||
else:
|
||||
raise ValueError("Invalid weight name")
|
||||
return np.concatenate(v, axis=axis)
|
||||
|
||||
torch_files = glob.glob(str(model_path / "consolidated.*.pth"))
|
||||
weights = collections.defaultdict(list)
|
||||
for wf in torch_files:
|
||||
state = torch.load(wf, map_location=torch.device("cpu"))
|
||||
for k, v in state.items():
|
||||
v = v.to(torch.float16).numpy()
|
||||
if shard_key(k) in SHARD_WEIGHTS:
|
||||
weights[k].append(v)
|
||||
else:
|
||||
weights[k] = v
|
||||
|
||||
for k, v in weights.items():
|
||||
weights[k] = unshard(k, v)
|
||||
return weights, None
|
||||
|
||||
|
||||
def tiny_llama(model_path):
|
||||
try:
|
||||
import transformers
|
||||
except ImportError as e:
|
||||
print("The transformers package must be installed for this model conversion:")
|
||||
print("pip install transformers")
|
||||
import sys
|
||||
|
||||
        sys.exit(1)
|
||||
|
||||
model = transformers.AutoModelForCausalLM.from_pretrained(
|
||||
str(model_path)
|
||||
).state_dict()
|
||||
config = transformers.AutoConfig.from_pretrained(model_path)
|
||||
|
||||
# things to change
|
||||
# 1. there's no "model." in the weight names
|
||||
model = {k.replace("model.", ""): v for k, v in model.items()}
|
||||
|
||||
# 2. mlp is called feed_forward
|
||||
model = {k.replace("mlp", "feed_forward"): v for k, v in model.items()}
|
||||
|
||||
# 3. up_proj, down_proj, gate_proj
|
||||
model = {k.replace("down_proj", "w2"): v for k, v in model.items()}
|
||||
model = {k.replace("up_proj", "w3"): v for k, v in model.items()}
|
||||
model = {k.replace("gate_proj", "w1"): v for k, v in model.items()}
|
||||
|
||||
# 4. layernorms
|
||||
model = {
|
||||
k.replace("input_layernorm", "attention_norm"): v for k, v in model.items()
|
||||
}
|
||||
model = {
|
||||
k.replace("post_attention_layernorm", "ffn_norm"): v for k, v in model.items()
|
||||
}
|
||||
|
||||
# 5. lm head
|
||||
model = {k.replace("lm_head", "output"): v for k, v in model.items()}
|
||||
|
||||
# 6. token emb
|
||||
model = {k.replace("embed_tokens", "tok_embeddings"): v for k, v in model.items()}
|
||||
|
||||
# 7. attention
|
||||
model = {k.replace("self_attn", "attention"): v for k, v in model.items()}
|
||||
model = {k.replace("q_proj", "wq"): v for k, v in model.items()}
|
||||
model = {k.replace("k_proj", "wk"): v for k, v in model.items()}
|
||||
model = {k.replace("v_proj", "wv"): v for k, v in model.items()}
|
||||
model = {k.replace("o_proj", "wo"): v for k, v in model.items()}
|
||||
|
||||
params = {}
|
||||
params["dim"] = config.hidden_size
|
||||
params["hidden_dim"] = config.intermediate_size
|
||||
params["n_heads"] = config.num_attention_heads
|
||||
if hasattr(config, "num_key_value_heads"):
|
||||
params["n_kv_heads"] = config.num_key_value_heads
|
||||
params["n_layers"] = config.num_hidden_layers
|
||||
params["vocab_size"] = config.vocab_size
|
||||
params["norm_eps"] = config.rms_norm_eps
|
||||
params["rope_traditional"] = False
|
||||
weights = {k: v.to(torch.float16).numpy() for k, v in model.items()}
|
||||
|
||||
return weights, params
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
|
||||
parser.add_argument(
|
||||
"--model-path",
|
||||
help="Path to the model. The MLX weights will also be saved there.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-name",
|
||||
help=(
|
||||
"Name of the model to convert. Use 'llama' for models in the "
|
||||
"Llama family distributed by Meta including Llama 1, Llama 2, "
|
||||
"Coda Llama, and Llama chat."
|
||||
),
|
||||
choices=["tiny_llama", "llama"],
|
||||
default="llama",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
model_path = Path(args.model_path)
|
||||
weights, params = globals()[args.model_name](model_path)
|
||||
np.savez(str(model_path / "weights.npz"), **weights)
|
||||
if params is not None:
|
||||
with open(model_path / "params.json", "w") as fid:
|
||||
json.dump(params, fid, indent=4)
|
llms/llama/llama.py (new file, 390 lines)
@@ -0,0 +1,390 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx.utils import tree_unflatten
|
||||
from sentencepiece import SentencePieceProcessor
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs:
|
||||
dim: int
|
||||
n_layers: int
|
||||
head_dim: int
|
||||
hidden_dim: int
|
||||
n_heads: int
|
||||
n_kv_heads: int
|
||||
norm_eps: float
|
||||
vocab_size: int
|
||||
rope_theta: float
|
||||
rope_traditional: bool = True
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dims: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
self.weight = mx.ones((dims,))
|
||||
self.eps = eps
|
||||
|
||||
def _norm(self, x):
|
||||
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
|
||||
|
||||
def __call__(self, x):
|
||||
output = self._norm(x.astype(mx.float32)).astype(x.dtype)
|
||||
return self.weight * output
|
||||
|
||||
|
||||
class RoPE(nn.RoPE):
|
||||
def __init__(self, dims: int, traditional: bool = False, base: float = 10000):
|
||||
super().__init__(dims, traditional)
|
||||
self.base = base
|
||||
|
||||
def __call__(self, x, offset: int = 0):
|
||||
shape = x.shape
|
||||
x = mx.reshape(x, (-1, shape[-2], shape[-1]))
|
||||
N = x.shape[1] + offset
|
||||
costheta, sintheta = RoPE.create_cos_sin_theta(
|
||||
N, self.dims, offset=offset, base=self.base, dtype=x.dtype
|
||||
)
|
||||
|
||||
rope = (
|
||||
self._compute_traditional_rope if self.traditional else self._compute_rope
|
||||
)
|
||||
rx = rope(costheta, sintheta, x)
|
||||
|
||||
return mx.reshape(rx, shape)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
|
||||
self.n_heads: int = args.n_heads
|
||||
self.n_kv_heads: int = args.n_kv_heads
|
||||
|
||||
self.repeats = self.n_heads // self.n_kv_heads
|
||||
|
||||
self.scale = self.args.head_dim**-0.5
|
||||
|
||||
self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False)
|
||||
self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
|
||||
self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
|
||||
self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
|
||||
self.rope = RoPE(
|
||||
args.head_dim, traditional=args.rope_traditional, base=args.rope_theta
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
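        # Repeat the keys and values along the head axis so they match the
        # number of query heads (grouped-query attention).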
def repeat(a):
|
||||
a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
|
||||
return a.reshape([B, self.n_heads, L, -1])
|
||||
|
||||
keys, values = map(repeat, (keys, values))
|
||||
|
||||
if cache is not None:
|
||||
key_cache, value_cache = cache
|
||||
queries = self.rope(queries, offset=key_cache.shape[2])
|
||||
keys = self.rope(keys, offset=key_cache.shape[2])
|
||||
keys = mx.concatenate([key_cache, keys], axis=2)
|
||||
values = mx.concatenate([value_cache, values], axis=2)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
|
||||
if mask is not None:
|
||||
scores += mask
|
||||
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
|
||||
output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.wo(output), (keys, values)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False)
|
||||
self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False)
|
||||
self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.w2(nn.silu(self.w1(x)) * self.w3(x))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.n_heads = args.n_heads
|
||||
self.dim = args.dim
|
||||
self.attention = Attention(args)
|
||||
self.feed_forward = FeedForward(args=args)
|
||||
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
r, cache = self.attention(self.attention_norm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.feed_forward(self.ffn_norm(h))
|
||||
out = h + r
|
||||
return out, cache
|
||||
|
||||
|
||||
class Llama(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
|
||||
self.layers = [TransformerBlock(args=args) for _ in range(args.n_layers)]
|
||||
self.norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(self, x):
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
|
||||
mask = mask.astype(self.tok_embeddings.weight.dtype)
|
||||
|
||||
x = self.tok_embeddings(x)
|
||||
for l in self.layers:
|
||||
x, _ = l(x, mask)
|
||||
x = self.norm(x)
|
||||
return self.output(x)
|
||||
|
||||
def generate(self, x, temp=1.0):
|
||||
cache = []
|
||||
|
||||
# Make an additive causal mask. We will need that to process the prompt.
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
|
||||
mask = mask.astype(self.tok_embeddings.weight.dtype)
|
||||
|
||||
        # First we process the prompt x the same way as in __call__ but
|
||||
# save the caches in cache
|
||||
x = self.tok_embeddings(x)
|
||||
for l in self.layers:
|
||||
x, c = l(x, mask=mask)
|
||||
# We store the per layer cache in a simple python list
|
||||
cache.append(c)
|
||||
x = self.norm(x)
|
||||
# We only care about the last logits that generate the next token
|
||||
y = self.output(x[:, -1])
|
||||
y = mx.random.categorical(y * (1 / temp))
|
||||
|
||||
# y now has size [1]
|
||||
# Since MLX is lazily evaluated nothing is computed yet.
|
||||
# Calling y.item() would force the computation to happen at
|
||||
# this point but we can also choose not to do that and let the
|
||||
# user choose when to start the computation.
|
||||
yield y
|
||||
|
||||
        # Now that we have processed the prompt and generated the first token, we
|
||||
# need to feed it back into the model and loop to generate the
|
||||
# rest.
|
||||
while True:
|
||||
# Unsqueezing the last dimension to add a sequence length
|
||||
# dimension of 1
|
||||
x = y[:, None]
|
||||
|
||||
x = self.tok_embeddings(x)
|
||||
for i in range(len(cache)):
|
||||
# We are overwriting the arrays in the cache list. When
|
||||
                # the computation actually happens, MLX discards the old
|
||||
                # cache as soon as it is no longer needed.
|
||||
x, cache[i] = self.layers[i](x, mask=None, cache=cache[i])
|
||||
x = self.norm(x)
|
||||
y = self.output(x[:, -1])
|
||||
y = mx.random.categorical(y * (1 / temp))
|
||||
|
||||
yield y
|
||||
|
||||
|
||||
def tic():
|
||||
return time.time()
|
||||
|
||||
|
||||
def toc(msg, start):
|
||||
end = time.time()
|
||||
return f"[INFO] {msg}: {end - start:.3f} s"
|
||||
|
||||
|
||||
def generate(args):
|
||||
|
||||
input("Press enter to start generation")
|
||||
print("------")
|
||||
print(args.prompt)
|
||||
x = mx.array([[tokenizer.bos_id()] + tokenizer.encode(args.prompt)])
|
||||
skip = 0
|
||||
prompt_processing = None
|
||||
tokens = []
|
||||
start = tic()
|
||||
for token in model.generate(x, args.temp):
|
||||
tokens.append(token)
|
||||
|
||||
if len(tokens) == 1:
|
||||
# Actually perform the computation to measure the prompt processing time
|
||||
mx.eval(token)
|
||||
prompt_processing = toc("Prompt processing", start)
|
||||
|
||||
if len(tokens) >= args.max_tokens:
|
||||
break
|
||||
|
||||
elif (len(tokens) % args.write_every) == 0:
|
||||
# It is perfectly ok to eval things we have already eval-ed.
|
||||
mx.eval(tokens)
|
||||
s = tokenizer.decode([t.item() for t in tokens])
|
||||
print(s[skip:], end="", flush=True)
|
||||
skip = len(s)
|
||||
|
||||
mx.eval(tokens)
|
||||
full_gen = toc("Full generation", start)
|
||||
s = tokenizer.decode([t.item() for t in tokens])
|
||||
print(s[skip:], flush=True)
|
||||
print("------")
|
||||
print(prompt_processing)
|
||||
print(full_gen)
|
||||
|
||||
|
||||
def few_shot_generate(args):
|
||||
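    # Check whether the generated text ends with the "[Instruction]" marker
    # (returns 1), with a partial prefix of it (returns 0), or with neither
    # (returns -1), so generation can stop before the next few-shot example.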
def possible_end(s):
|
||||
word = "[Instruction]"
|
||||
for i in range(len(word) - 1, 0, -1):
|
||||
if s[-i:] == word[:i]:
|
||||
return 0
|
||||
if s[-len(word) :] == word:
|
||||
return 1
|
||||
return -1
|
||||
|
||||
def generate(question):
|
||||
x = mx.array([[tokenizer.bos_id()] + tokenizer.encode(question)])
|
||||
skip = 0
|
||||
prompt_processing = None
|
||||
tokens = []
|
||||
start = tic()
|
||||
for token in model.generate(x, args.temp):
|
||||
tokens.append(token)
|
||||
|
||||
if len(tokens) == 1:
|
||||
# Actually perform the computation to measure the prompt processing time
|
||||
mx.eval(token)
|
||||
prompt_processing = toc("Prompt processing", start)
|
||||
|
||||
if len(tokens) >= args.max_tokens:
|
||||
break
|
||||
|
||||
mx.eval(tokens)
|
||||
token_list = [t.item() for t in tokens]
|
||||
s = tokenizer.decode(token_list)
|
||||
|
||||
end = possible_end(s)
|
||||
if end == 0:
|
||||
continue
|
||||
if end == 1:
|
||||
skip = len(s)
|
||||
break
|
||||
|
||||
print(s[skip:], end="", flush=True)
|
||||
skip = len(s)
|
||||
if token_list[-1] == tokenizer.eos_id():
|
||||
break
|
||||
|
||||
mx.eval(tokens)
|
||||
full_gen = toc("Full generation", start)
|
||||
s = tokenizer.decode([t.item() for t in tokens])
|
||||
print(s[skip:], end="", flush=True)
|
||||
|
||||
print("[INFO] Loading few-shot examples from: {}".format(args.few_shot))
|
||||
prompt = open(args.few_shot).read().strip()
|
||||
while True:
|
||||
question = input("Ask a question: ")
|
||||
generate(prompt.replace("{}", question))
|
||||
print()
|
||||
|
||||
|
||||
def load_model(model_path):
|
||||
model_path = Path(model_path)
|
||||
weights = mx.load(str(model_path / "weights.npz"))
|
||||
with open(model_path / "params.json", "r") as f:
|
||||
config = json.loads(f.read())
|
||||
n_heads = config["n_heads"]
|
||||
if "n_kv_heads" not in config:
|
||||
config["n_kv_heads"] = n_heads
|
||||
if "head_dim" not in config:
|
||||
config["head_dim"] = config["dim"] // n_heads
|
||||
if "hidden_dim" not in config:
|
||||
config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0]
|
||||
if config.get("vocab_size", -1) < 0:
|
||||
config["vocab_size"] = weights["output.weight"].shape[-1]
|
||||
if "rope_theta" not in config:
|
||||
config["rope_theta"] = 10000
|
||||
unused = ["multiple_of", "ffn_dim_multiplier"]
|
||||
for k in unused:
|
||||
if k in config:
|
||||
config.pop(k)
|
||||
model = Llama(ModelArgs(**config))
|
||||
model.update(tree_unflatten(list(weights.items())))
|
||||
return model
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Llama inference script")
|
||||
parser.add_argument(
|
||||
"model", help="Path to the model directory containing the MLX weights"
|
||||
)
|
||||
parser.add_argument("tokenizer", help="The sentencepiece tokenizer")
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
help="The message to be processed by the model. Ignored when --few-shot is provided.",
|
||||
default="In the beginning the Universe was created.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--few-shot",
|
||||
help="Read a few shot prompt from a file (as in `sample_prompt.txt`).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-tokens", "-m", type=int, default=100, help="How many tokens to generate"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--write-every", type=int, default=1, help="After how many tokens to detokenize"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temp", type=float, default=0.8, help="The sampling temperature"
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
mx.random.seed(args.seed)
|
||||
|
||||
tokenizer = SentencePieceProcessor(model_file=args.tokenizer)
|
||||
print("[INFO] Loading model from disk.")
|
||||
model = load_model(args.model)
|
||||
if args.few_shot:
|
||||
few_shot_generate(args)
|
||||
else:
|
||||
generate(args)
|
llms/llama/requirements.txt (new file, 3 lines)
@@ -0,0 +1,3 @@
|
||||
mlx
|
||||
sentencepiece
|
||||
torch
|
llms/llama/sample_prompt.txt (new file, 23 lines)
@@ -0,0 +1,23 @@
|
||||
[Instruction] Give the list of U.S. states bordering Canada
|
||||
[Answer] OK, here is the list of U.S. states located on the border with Canada:
|
||||
- Alaska
|
||||
- Michigan
|
||||
- Maine
|
||||
- Minnesota
|
||||
- Montana
|
||||
- New York
|
||||
- Washington
|
||||
- North Dakota
|
||||
- Ohio
|
||||
- Vermont
|
||||
- New Hampshire
|
||||
- Idaho
|
||||
- Pennsylvania
|
||||
[Instruction] Write a paragraph about "functional analysis"
|
||||
[Answer] OK, here is a paragraph on the topic of functional analysis:
|
||||
Functional analysis is a branch of mathematical analysis, the core of which is formed by the study of vector spaces endowed with some kind of limit-related structure (for example, inner product, norm, or topology) and the linear functions defined on these spaces and suitably respecting these structures. The historical roots of functional analysis lie in the study of spaces of functions and the formulation of properties of transformations of functions such as the Fourier transform as transformations defining, for example, continuous or unitary operators between function spaces. This point of view turned out to be particularly useful for the study of differential and integral equations.
|
||||
[Instruction] I am starting a new dog walking business. Can you help me find 2 possible names for the business?
|
||||
[Answer] OK, here are two possible names for a new dog walking business:
|
||||
The first option is "Paws on Patrol", and the second option is "The Dog Whisperer".
|
||||
[Instruction] {}
|
||||
[Answer]
|
llms/mistral/.gitignore (new vendored file, 1 line)
@@ -0,0 +1 @@
|
||||
mistral-7B-v0.1/
|
llms/mistral/README.md (new file, 50 lines)
@@ -0,0 +1,50 @@
|
||||
# Mistral
|
||||
|
||||
An example of generating text with Mistral using MLX.
|
||||
|
||||
Mistral 7B is one of the top large language models in its size class. It is
|
||||
also fully open source with a permissive license[^1].
|
||||
|
||||
### Setup
|
||||
|
||||
Install the dependencies:
|
||||
|
||||
```
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
Next, download the model and tokenizer:
|
||||
|
||||
```
|
||||
curl -O https://files.mistral-7b-v0-1.mistral.ai/mistral-7B-v0.1.tar
|
||||
tar -xf mistral-7B-v0.1.tar
|
||||
```
|
||||
|
||||
Then, convert the weights with:
|
||||
|
||||
```
|
||||
python convert.py
|
||||
```
|
||||
|
||||
The conversion script will save the converted weights in the same location.
|
||||
|
||||
> [!TIP]
|
||||
> Alternatively, you can also download a few converted checkpoints from the
|
||||
> [MLX Community](https://huggingface.co/mlx-community) organization on Hugging
|
||||
> Face and skip the conversion step.
|
||||
|
||||
|
||||
### Run
|
||||
|
||||
Once you've converted the weights to MLX format, you can generate text with
|
||||
the Mistral model:
|
||||
|
||||
```
|
||||
python mistral.py --prompt "It is a truth universally acknowledged," --temp 0
|
||||
```
|
||||
|
||||
Run `python mistral.py --help` for more details.
|
||||
|
||||
[^1]: Refer to the [blog post](https://mistral.ai/news/announcing-mistral-7b/)
|
||||
and [GitHub repository](https://github.com/mistralai/mistral-src) for more
|
||||
details.
|
llms/mistral/convert.py (new file, 32 lines)
@@ -0,0 +1,32 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Convert Mistral weights to MLX.")
|
||||
parser.add_argument(
|
||||
"--model_path",
|
||||
type=str,
|
||||
default="mistral-7B-v0.1/",
|
||||
help="The path to the Mistral model. The MLX weights will also be saved there.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
model_path = Path(args.model_path)
|
||||
state = torch.load(str(model_path / "consolidated.00.pth"))
|
||||
np.savez(
|
||||
str(model_path / "weights.npz"),
|
||||
**{k: v.to(torch.float16).numpy() for k, v in state.items()}
|
||||
)
|
||||
|
||||
# Save config.json with model_type
|
||||
with open(model_path / "params.json", "r") as f:
|
||||
config = json.loads(f.read())
|
||||
config["model_type"] = "mistral"
|
||||
with open(model_path / "config.json", "w") as f:
|
||||
json.dump(config, f, indent=4)
|
llms/mistral/mistral.py (new file, 282 lines)
@@ -0,0 +1,282 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx.utils import tree_map, tree_unflatten
|
||||
from sentencepiece import SentencePieceProcessor
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs:
|
||||
dim: int
|
||||
n_layers: int
|
||||
head_dim: int
|
||||
hidden_dim: int
|
||||
n_heads: int
|
||||
n_kv_heads: int
|
||||
norm_eps: float
|
||||
vocab_size: int
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dims: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
self.weight = mx.ones((dims,))
|
||||
self.eps = eps
|
||||
|
||||
def _norm(self, x):
|
||||
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
|
||||
|
||||
def __call__(self, x):
|
||||
output = self._norm(x.astype(mx.float32)).astype(x.dtype)
|
||||
return self.weight * output
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
|
||||
self.n_heads: int = args.n_heads
|
||||
self.n_kv_heads: int = args.n_kv_heads
|
||||
|
||||
self.repeats = self.n_heads // self.n_kv_heads
|
||||
|
||||
self.scale = self.args.head_dim**-0.5
|
||||
|
||||
self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False)
|
||||
self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
|
||||
self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
|
||||
self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
|
||||
self.rope = nn.RoPE(args.head_dim, traditional=True)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
def repeat(a):
|
||||
a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
|
||||
return a.reshape([B, self.n_heads, L, -1])
|
||||
|
||||
keys, values = map(repeat, (keys, values))
|
||||
|
||||
if cache is not None:
|
||||
key_cache, value_cache = cache
|
||||
queries = self.rope(queries, offset=key_cache.shape[2])
|
||||
keys = self.rope(keys, offset=key_cache.shape[2])
|
||||
keys = mx.concatenate([key_cache, keys], axis=2)
|
||||
values = mx.concatenate([value_cache, values], axis=2)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
|
||||
if mask is not None:
|
||||
scores += mask
|
||||
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
|
||||
output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.wo(output), (keys, values)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False)
|
||||
self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False)
|
||||
self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.w2(nn.silu(self.w1(x)) * self.w3(x))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.n_heads = args.n_heads
|
||||
self.dim = args.dim
|
||||
self.attention = Attention(args)
|
||||
self.feed_forward = FeedForward(args=args)
|
||||
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
r, cache = self.attention(self.attention_norm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.feed_forward(self.ffn_norm(h))
|
||||
out = h + r
|
||||
return out, cache
|
||||
|
||||
|
||||
class Mistral(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.n_layers = args.n_layers
|
||||
assert self.vocab_size > 0
|
||||
self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
|
||||
self.layers = [TransformerBlock(args=args) for _ in range(args.n_layers)]
|
||||
self.norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
h = self.tok_embeddings(inputs)
|
||||
|
||||
mask = None
|
||||
if h.shape[1] > 1:
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
|
||||
mask = mask.astype(h.dtype)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for e, layer in enumerate(self.layers):
|
||||
h, cache[e] = layer(h, mask, cache[e])
|
||||
|
||||
return self.output(self.norm(h)), cache
|
||||
|
||||
|
||||
class Tokenizer:
|
||||
def __init__(self, model_path: str):
|
||||
assert Path(model_path).exists(), model_path
|
||||
self._model = SentencePieceProcessor(model_file=model_path)
|
||||
self._sep = "▁"
|
||||
assert self._model.vocab_size() == self._model.get_piece_size()
|
||||
|
||||
@property
|
||||
def eos_id(self) -> int:
|
||||
return self._model.eos_id()
|
||||
|
||||
@property
|
||||
def pad_id(self) -> int:
|
||||
return self._model.pad_id()
|
||||
|
||||
def encode(self, s: str) -> List[int]:
|
||||
return [self._model.bos_id(), *self._model.encode(s)]
|
||||
|
||||
def decode(self, t: List[int]) -> str:
|
||||
out = self._model.decode(t)
|
||||
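        # SentencePiece marks word boundaries with "▁"; if the first piece starts
        # with it, re-insert the leading space that decode() drops.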
if t and self._model.id_to_piece(t[0])[0] == self._sep:
|
||||
return " " + out
|
||||
return out
|
||||
|
||||
|
||||
def load_model(folder: str, dtype=mx.float16):
|
||||
model_path = Path(folder)
|
||||
tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
|
||||
with open(model_path / "config.json", "r") as f:
|
||||
config = json.loads(f.read())
|
||||
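    # Drop config fields that ModelArgs does not define (this example does not
    # implement sliding-window attention).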
config.pop("sliding_window", None)
|
||||
config.pop("model_type", None)
|
||||
model_args = ModelArgs(**config)
|
||||
weights = mx.load(str(model_path / "weights.npz"))
|
||||
weights = tree_unflatten(list(weights.items()))
|
||||
weights = tree_map(lambda p: p.astype(dtype), weights)
|
||||
model = Mistral(model_args)
|
||||
model.update(weights)
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def generate(prompt: mx.array, model: Mistral, temp: Optional[float] = 0.0):
|
||||
def sample(logits):
|
||||
if temp == 0:
|
||||
return mx.argmax(logits, axis=-1)
|
||||
else:
|
||||
return mx.random.categorical(logits * (1 / temp))
|
||||
|
||||
logits, cache = model(prompt[None])
|
||||
y = sample(logits[:, -1, :])
|
||||
yield y
|
||||
|
||||
while True:
|
||||
logits, cache = model(y[:, None], cache)
|
||||
y = sample(logits.squeeze(1))
|
||||
yield y
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Mistral inference script")
|
||||
parser.add_argument(
|
||||
"--model_path",
|
||||
type=str,
|
||||
default="mistral-7B-v0.1",
|
||||
help="The path to the model weights and tokenizer",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
help="The message to be processed by the model",
|
||||
default="In the beginning the Universe was created.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_tokens",
|
||||
"-m",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Maximum number of tokens to generate",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temp",
|
||||
help="The sampling temperature.",
|
||||
type=float,
|
||||
default=1.0,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokens_per_eval",
|
||||
help="The batch size of tokens to generate.",
|
||||
type=int,
|
||||
default=10,
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
mx.random.seed(args.seed)
|
||||
print("[INFO] Loading model from disk.")
|
||||
model, tokenizer = load_model(args.model_path)
|
||||
|
||||
print("[INFO] Starting generation...")
|
||||
|
||||
print(args.prompt, end="", flush=True)
|
||||
prompt = mx.array(tokenizer.encode(args.prompt))
|
||||
tokens = []
|
||||
for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
|
||||
tokens.append(token)
|
||||
|
||||
if (len(tokens) % args.tokens_per_eval) == 0:
|
||||
mx.eval(tokens)
|
||||
s = tokenizer.decode([t.item() for t in tokens])
|
||||
print(s, end="", flush=True)
|
||||
tokens = []
|
||||
|
||||
mx.eval(tokens)
|
||||
s = tokenizer.decode([t.item() for t in tokens])
|
||||
print(s, flush=True)
|
||||
print("------")
|
llms/mistral/requirements.txt (new file, 4 lines)
@@ -0,0 +1,4 @@
|
||||
mlx
|
||||
sentencepiece
|
||||
torch
|
||||
numpy
|
llms/mistral/test.py (new file, 117 lines)
@@ -0,0 +1,117 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import unittest
|
||||
|
||||
import mistral
|
||||
import mlx.core as mx
|
||||
from mlx.utils import tree_map
|
||||
|
||||
|
||||
class TestMistral(unittest.TestCase):
|
||||
def test_model(self):
|
||||
vocab_size = 100
|
||||
L = 32
|
||||
args = mistral.ModelArgs(
|
||||
dim=128,
|
||||
n_layers=2,
|
||||
head_dim=32,
|
||||
hidden_dim=256,
|
||||
n_heads=4,
|
||||
n_kv_heads=4,
|
||||
norm_eps=1e-3,
|
||||
vocab_size=vocab_size,
|
||||
)
|
||||
|
||||
model = mistral.Mistral(args)
|
||||
inputs = mx.random.randint(0, vocab_size, (L,))
|
||||
logits, cache = model(inputs[None])
|
||||
self.assertEqual(logits.shape, [1, L, vocab_size])
|
||||
self.assertEqual(logits.dtype, mx.float32)
|
||||
self.assertEqual(len(cache), args.n_layers)
|
||||
|
||||
params = tree_map(lambda p: p.astype(mx.float16), model.parameters())
|
||||
model.update(params)
|
||||
logits, _ = model(inputs[None])
|
||||
self.assertEqual(logits.dtype, mx.float16)
|
||||
|
||||
def test_generate(self):
|
||||
model, tokenizer = mistral.load_model("mistral-7B-v0.1")
|
||||
prompt = mx.array(tokenizer.encode("This is a test"))
|
||||
tokens = [t for t, _ in zip(mistral.generate(prompt, model), range(30))]
|
||||
mx.eval(tokens)
|
||||
tokens = [t.item() for t in tokens]
|
||||
expected = [
|
||||
302,
|
||||
272,
|
||||
11843,
|
||||
11837,
|
||||
1587,
|
||||
28723,
|
||||
851,
|
||||
349,
|
||||
865,
|
||||
264,
|
||||
1369,
|
||||
28723,
|
||||
13,
|
||||
13,
|
||||
3381,
|
||||
456,
|
||||
654,
|
||||
264,
|
||||
1353,
|
||||
11843,
|
||||
28725,
|
||||
368,
|
||||
682,
|
||||
347,
|
||||
2240,
|
||||
767,
|
||||
298,
|
||||
511,
|
||||
28723,
|
||||
13,
|
||||
]
|
||||
self.assertEqual(tokens, expected)
|
||||
|
||||
def benchmark(self):
|
||||
import time
|
||||
|
||||
model, tokenizer = mistral.load_model("mistral-7B-v0.1")
|
||||
prompt = mx.random.randint(0, model.vocab_size, (128,))
|
||||
|
||||
# warmup
|
||||
for _ in range(2):
|
||||
generator = mistral.generate(prompt, model)
|
||||
mx.eval(next(generator))
|
||||
|
||||
tic = time.time()
|
||||
its = 5
|
||||
for _ in range(its):
|
||||
generator = mistral.generate(prompt, model)
|
||||
mx.eval(next(generator))
|
||||
toc = time.time()
|
||||
tps = its * prompt.size / (toc - tic)
|
||||
print(f"Prompt processing: {tps:.2f} tokens per second")
|
||||
|
||||
# warmup
|
||||
for _ in range(2):
|
||||
tokens = [t for t, _ in zip(mistral.generate(prompt, model), range(101))]
|
||||
mx.eval(tokens)
|
||||
|
||||
time_total = 0.0
|
||||
its = 2
|
||||
for _ in range(its):
|
||||
generator = mistral.generate(prompt, model)
|
||||
mx.eval(next(generator))
|
||||
tic = time.time()
|
||||
tokens = [t for t, _ in zip(generator, range(100))]
|
||||
mx.eval(tokens)
|
||||
time_total += time.time() - tic
|
||||
|
||||
tps = len(tokens) * its / time_total
|
||||
print(f"Token generation: {tps:.3f} tokens per second")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
llms/mixtral/README.md (new file, 73 lines)
@@ -0,0 +1,73 @@
|
||||
## Mixtral 8x7B
|
||||
|
||||
Run the Mixtral[^mixtral] 8x7B mixture-of-experts (MoE) model in MLX on Apple silicon.
|
||||
|
||||
This example also supports the instruction fine-tuned Mixtral model.[^instruct]
|
||||
|
||||
Note that for 16-bit precision this model needs a machine with substantial RAM (~100GB) to run.
|
||||
|
||||
### Setup
|
||||
|
||||
Install [Git Large File
|
||||
Storage](https://docs.github.com/en/repositories/working-with-files/managing-large-files/installing-git-large-file-storage).
|
||||
For example with Homebrew:
|
||||
|
||||
```
|
||||
brew install git-lfs
|
||||
```
|
||||
|
||||
Download the models from Hugging Face:
|
||||
|
||||
For the base model use:
|
||||
|
||||
```
|
||||
export MIXTRAL_MODEL=Mixtral-8x7B-v0.1
|
||||
```
|
||||
|
||||
For the instruction fine-tuned model use:
|
||||
|
||||
```
|
||||
export MIXTRAL_MODEL=Mixtral-8x7B-Instruct-v0.1
|
||||
```
|
||||
|
||||
Then run:
|
||||
|
||||
```
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/mistralai/${MIXTRAL_MODEL}/
|
||||
cd $MIXTRAL_MODEL/ && \
|
||||
git lfs pull --include "consolidated.*.pt" && \
|
||||
git lfs pull --include "tokenizer.model"
|
||||
```
|
||||
|
||||
Now, from `mlx-examples/llms/mixtral`, convert and save the weights as NumPy arrays so
|
||||
MLX can read them:
|
||||
|
||||
```
|
||||
python convert.py --model_path $MIXTRAL_MODEL/
|
||||
```
|
||||
|
||||
The conversion script will save the converted weights in the same location, writing one `weights.N.npz` file per original `consolidated.N.pt` shard along with a `config.json`.
|
||||
|
||||
### Generate
|
||||
|
||||
As easy as:
|
||||
|
||||
```
|
||||
python mixtral.py --model_path $MIXTRAL_MODEL/
|
||||
```
|
||||
|
||||
For more options including how to prompt the model, run:
|
||||
|
||||
```
|
||||
python mixtral.py --help
|
||||
```
|
||||
|
||||
For the Instruction model, make sure to follow the prompt format:
|
||||
|
||||
```
|
||||
[INST] Instruction prompt [/INST]
|
||||
```
|
||||
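For example, a complete invocation of the instruct model might look like this (the prompt text itself is only an illustration):

```
python mixtral.py --model_path $MIXTRAL_MODEL/ \
  --prompt "[INST] Write a short poem about Apple silicon. [/INST]" \
  --max_tokens 128
```
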
|
||||
[^mixtral]: Refer to Mistral's [blog post](https://mistral.ai/news/mixtral-of-experts/) and the [Hugging Face blog post](https://huggingface.co/blog/mixtral) for more details.
|
||||
[^instruct]: Refer to the [Hugging Face repo](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1) for more
|
||||
details.
|
llms/mixtral/convert.py (new file, 59 lines)
@@ -0,0 +1,59 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def convert(k, v, config):
|
||||
v = v.to(torch.float16).numpy()
|
||||
if "block_sparse_moe" not in k:
|
||||
return [(k, v)]
|
||||
if "gate" in k:
|
||||
return [(k.replace("block_sparse_moe", "feed_forward"), v)]
|
||||
|
||||
# From: layers.N.block_sparse_moe.w
|
||||
    # To:   layers.N.feed_forward.experts.M.w.weight
|
||||
    num_experts = config["moe"]["num_experts"]
|
||||
key_path = k.split(".")
|
||||
v = np.split(v, num_experts, axis=0)
|
||||
if key_path[-1] == "w2":
|
||||
v = [u.T for u in v]
|
||||
|
||||
w_name = key_path.pop()
|
||||
key_path[-1] = "feed_forward.experts"
|
||||
return [
|
||||
(".".join(key_path + [str(e), w_name, "weight"]), u) for e, u in enumerate(v)
|
||||
]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Convert Mixtral weights to MLX.")
|
||||
parser.add_argument(
|
||||
"--model_path",
|
||||
type=str,
|
||||
default="Mixtral-8x7B-v0.1/",
|
||||
help="The path to the Mixtral model. The MLX model weights will also be saved there.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
model_path = Path(args.model_path)
|
||||
|
||||
with open("params.json") as fid:
|
||||
args = json.load(fid)
|
||||
args["model_type"] = "mixtral"
|
||||
with open(model_path / "config.json", "w") as f:
|
||||
json.dump(args, f, indent=4)
|
||||
|
||||
torch_files = glob.glob(str(model_path / "consolidated.*.pt"))
|
||||
torch_files = sorted(torch_files, key=lambda tf: int(tf.split(".")[-2]))
|
||||
for e, tf in enumerate(torch_files):
|
||||
print(f"[INFO] Converting file {e + 1}/{len(torch_files)}")
|
||||
state = torch.load(tf)
|
||||
new_state = {}
|
||||
for k, v in state.items():
|
||||
new_state.update(convert(k, v, args))
|
||||
np.savez(str(model_path / f"weights.{e}.npz"), **new_state)
|
llms/mixtral/mixtral.py (new file, 332 lines)
@@ -0,0 +1,332 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
from mlx.utils import tree_map, tree_unflatten
|
||||
from sentencepiece import SentencePieceProcessor
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs:
|
||||
dim: int
|
||||
n_layers: int
|
||||
head_dim: int
|
||||
hidden_dim: int
|
||||
n_heads: int
|
||||
n_kv_heads: int
|
||||
norm_eps: float
|
||||
vocab_size: int
|
||||
moe: dict = None
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dims: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
self.weight = mx.ones((dims,))
|
||||
self.eps = eps
|
||||
|
||||
def _norm(self, x):
|
||||
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
|
||||
|
||||
def __call__(self, x):
|
||||
output = self._norm(x.astype(mx.float32)).astype(x.dtype)
|
||||
return self.weight * output
|
||||
|
||||
|
||||
class RoPE(nn.RoPE):
|
||||
def __init__(self, dims: int, traditional: bool = False):
|
||||
super().__init__(dims, traditional)
|
||||
|
||||
def __call__(self, x, offset: int = 0):
|
||||
shape = x.shape
|
||||
x = mx.reshape(x, (-1, shape[-2], shape[-1]))
|
||||
N = x.shape[1] + offset
|
||||
costheta, sintheta = RoPE.create_cos_sin_theta(
|
||||
N, self.dims, offset=offset, base=1000000, dtype=x.dtype
|
||||
)
|
||||
|
||||
rope = (
|
||||
self._compute_traditional_rope if self.traditional else self._compute_rope
|
||||
)
|
||||
rx = rope(costheta, sintheta, x)
|
||||
|
||||
return mx.reshape(rx, shape)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
|
||||
self.n_heads: int = args.n_heads
|
||||
self.n_kv_heads: int = args.n_kv_heads
|
||||
|
||||
self.repeats = self.n_heads // self.n_kv_heads
|
||||
|
||||
self.scale = self.args.head_dim**-0.5
|
||||
|
||||
self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False)
|
||||
self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
|
||||
self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
|
||||
self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
|
||||
self.rope = RoPE(args.head_dim, traditional=True)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
def repeat(a):
|
||||
a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
|
||||
return a.reshape([B, self.n_heads, L, -1])
|
||||
|
||||
keys, values = map(repeat, (keys, values))
|
||||
|
||||
if cache is not None:
|
||||
key_cache, value_cache = cache
|
||||
queries = self.rope(queries, offset=key_cache.shape[2])
|
||||
keys = self.rope(keys, offset=key_cache.shape[2])
|
||||
keys = mx.concatenate([key_cache, keys], axis=2)
|
||||
values = mx.concatenate([value_cache, values], axis=2)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
|
||||
if mask is not None:
|
||||
scores += mask
|
||||
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
|
||||
output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.wo(output), (keys, values)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False)
|
||||
self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False)
|
||||
self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.w2(nn.silu(self.w1(x)) * self.w3(x))
|
||||
|
||||
|
||||
class MOEFeedForward(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
self.num_experts = args.moe["num_experts"]
|
||||
self.num_experts_per_tok = args.moe["num_experts_per_tok"]
|
||||
self.experts = [FeedForward(args) for _ in range(self.num_experts)]
|
||||
self.gate = nn.Linear(args.dim, self.num_experts, bias=False)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
ne = self.num_experts_per_tok
|
||||
orig_shape = x.shape
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
|
||||
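        # Score every expert for each token, keep only the top-`ne` experts, and
        # renormalize their gate scores with a softmax.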
gates = self.gate(x)
|
||||
inds = mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]
|
||||
scores = mx.softmax(
|
||||
mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32),
|
||||
axis=-1,
|
||||
).astype(gates.dtype)
|
||||
|
||||
y = []
|
||||
for xt, st, it in zip(x, scores, inds.tolist()):
|
||||
yt = mx.concatenate([self.experts[e](xt)[:, None] for e in it], axis=-1)
|
||||
yt = (yt * st).sum(axis=-1)
|
||||
y.append(yt[None, :])
|
||||
y = mx.concatenate(y)
|
||||
|
||||
return y.reshape(orig_shape)
|
||||
|
||||
|
||||
class MOETransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.n_heads = args.n_heads
|
||||
self.dim = args.dim
|
||||
self.attention = Attention(args)
|
||||
self.feed_forward = MOEFeedForward(args=args)
|
||||
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
r, cache = self.attention(self.attention_norm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.feed_forward(self.ffn_norm(h))
|
||||
out = h + r
|
||||
return out, cache
|
||||
|
||||
|
||||
class Mixtral(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.n_layers = args.n_layers
|
||||
assert self.vocab_size > 0
|
||||
self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
|
||||
self.layers = [MOETransformerBlock(args=args) for _ in range(args.n_layers)]
|
||||
self.norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
h = self.tok_embeddings(inputs)
|
||||
|
||||
mask = None
|
||||
T = h.shape[1]
|
||||
if T > 1:
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
|
||||
mask = mask.astype(h.dtype)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for e, layer in enumerate(self.layers):
|
||||
h, cache[e] = layer(h, mask, cache[e])
|
||||
|
||||
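        # Only the logits for the last position are needed to sample the next token.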
return self.output(self.norm(h[:, T - 1 : T, :])), cache
|
||||
|
||||
|
||||
class Tokenizer:
|
||||
def __init__(self, model_path: str):
|
||||
assert Path(model_path).exists(), model_path
|
||||
self._model = SentencePieceProcessor(model_file=model_path)
|
||||
self._sep = "▁"
|
||||
assert self._model.vocab_size() == self._model.get_piece_size()
|
||||
|
||||
@property
|
||||
def eos_id(self) -> int:
|
||||
return self._model.eos_id()
|
||||
|
||||
@property
|
||||
def pad_id(self) -> int:
|
||||
return self._model.pad_id()
|
||||
|
||||
def encode(self, s: str) -> List[int]:
|
||||
return [self._model.bos_id(), *self._model.encode(s)]
|
||||
|
||||
def decode(self, t: List[int]) -> str:
|
||||
out = self._model.decode(t)
|
||||
if t and self._model.id_to_piece(t[0])[0] == self._sep:
|
||||
return " " + out
|
||||
return out
|
||||
|
||||
|
||||
def load_model(folder: str, dtype=mx.float16):
|
||||
model_path = Path(folder)
|
||||
tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
|
||||
with open(model_path / "config.json", "r") as f:
|
||||
config = json.loads(f.read())
|
||||
config.pop("model_type", None)
|
||||
model_args = ModelArgs(**config)
|
||||
weight_files = glob.glob(str(model_path / "weights.*.npz"))
|
||||
weights = {}
|
||||
for wf in weight_files:
|
||||
weights.update(mx.load(wf).items())
|
||||
weights = tree_unflatten(list(weights.items()))
|
||||
weights = tree_map(lambda p: p.astype(dtype), weights)
|
||||
model = Mixtral(model_args)
|
||||
model.update(weights)
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def generate(prompt: mx.array, model: Mixtral, temp: Optional[float] = 0.0):
|
||||
def sample(logits):
|
||||
if temp == 0:
|
||||
return mx.argmax(logits, axis=-1)
|
||||
else:
|
||||
return mx.random.categorical(logits * (1 / temp))
|
||||
|
||||
logits, cache = model(prompt[None])
|
||||
y = sample(logits[:, -1, :])
|
||||
yield y
|
||||
|
||||
while True:
|
||||
logits, cache = model(y[:, None], cache)
|
||||
y = sample(logits.squeeze(1))
|
||||
yield y
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Mixtral inference script")
|
||||
parser.add_argument(
|
||||
"--model_path",
|
||||
type=str,
|
||||
default="Mixtral-8x7B-v0.1",
|
||||
help="The path to the model weights, tokenizer, and config",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
help="The message to be processed by the model",
|
||||
default="In the beginning the Universe was created.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_tokens",
|
||||
"-m",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Maximum number of tokens to generate",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temp",
|
||||
help="The sampling temperature.",
|
||||
type=float,
|
||||
default=0.0,
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
mx.random.seed(args.seed)
|
||||
print("[INFO] Loading model from disk.")
|
||||
model, tokenizer = load_model(args.model_path)
|
||||
|
||||
print("[INFO] Starting generation...")
|
||||
|
||||
print(args.prompt, end="", flush=True)
|
||||
prompt = mx.array(tokenizer.encode(args.prompt))
|
||||
tokens = []
|
||||
for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
|
||||
tokens.append(token)
|
||||
|
||||
if (len(tokens) % 10) == 0:
|
||||
mx.eval(tokens)
|
||||
s = tokenizer.decode([t.item() for t in tokens])
|
||||
print(s, end="", flush=True)
|
||||
tokens = []
|
||||
|
||||
mx.eval(tokens)
|
||||
s = tokenizer.decode([t.item() for t in tokens])
|
||||
print(s, flush=True)
|
llms/mixtral/params.json (new file, 1 line)
@@ -0,0 +1 @@
|
||||
{"dim": 4096, "n_layers": 32, "head_dim": 128, "hidden_dim": 14336, "n_heads": 32, "n_kv_heads": 8, "norm_eps": 1e-05, "vocab_size": 32000, "moe": {"num_experts_per_tok": 2, "num_experts": 8}}
|
llms/mixtral/requirements.txt (new file, 4 lines)
@@ -0,0 +1,4 @@
|
||||
mlx
|
||||
sentencepiece
|
||||
torch
|
||||
numpy
|
llms/phi2/.gitignore (new vendored file, 1 line)
@@ -0,0 +1 @@
|
||||
weights.npz
|
llms/phi2/README.md (new file, 62 lines)
@@ -0,0 +1,62 @@
|
||||
# Phi-2
|
||||
|
||||
Phi-2 is a 2.7B parameter language model released by Microsoft with
|
||||
performance that rivals much larger models.[^1] It was trained on a mixture of
|
||||
GPT-4 outputs and clean web text.
|
||||
|
||||
Phi-2 efficiently runs on Apple silicon devices with 8GB of memory in 16-bit
|
||||
precision.
|
||||
|
||||
## Setup
|
||||
|
||||
Download and convert the model:
|
||||
|
||||
```sh
|
||||
python convert.py
|
||||
```
|
||||
|
||||
This will make the `weights.npz` file which MLX can read.
|
||||
|
||||
> [!TIP]
> Alternatively, you can also download a few converted checkpoints from
|
||||
> the [MLX Community](https://huggingface.co/mlx-community) organization on
|
||||
> Hugging Face and skip the conversion step.
|
||||
|
||||
|
||||
## Generate
|
||||
|
||||
To generate text with the default prompt:
|
||||
|
||||
```sh
|
||||
python phi2.py
|
||||
```
|
||||
|
||||
Should give the output:
|
||||
|
||||
```
|
||||
Answer: Mathematics is like a lighthouse that guides us through the darkness of
|
||||
uncertainty. Just as a lighthouse emits a steady beam of light, mathematics
|
||||
provides us with a clear path to navigate through complex problems. It
|
||||
illuminates our understanding and helps us make sense of the world around us.
|
||||
|
||||
Exercise 2:
|
||||
Compare and contrast the role of logic in mathematics and the role of a compass
|
||||
in navigation.
|
||||
|
||||
Answer: Logic in mathematics is like a compass in navigation. It helps
|
||||
```
|
||||
|
||||
To use your own prompt:
|
||||
|
||||
```sh
|
||||
python phi2.py --prompt "<your prompt here>" --max_tokens <max_tokens_to_generate>
|
||||
```
|
||||
|
||||
To see a list of options run:
|
||||
|
||||
```sh
|
||||
python phi2.py --help
|
||||
```
|
||||
|
||||
[^1]: For more details on the model see the [blog post](
|
||||
https://www.microsoft.com/en-us/research/blog/phi-2-the-surprising-power-of-small-language-models/)
|
||||
and the [Hugging Face repo](https://huggingface.co/microsoft/phi-2)
|
llms/phi2/convert.py (new file, 24 lines)
@@ -0,0 +1,24 @@
|
||||
import numpy as np
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
|
||||
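# Map the Hugging Face parameter names onto the names used by phi2.py.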
def replace_key(key: str) -> str:
|
||||
if "wte.weight" in key:
|
||||
key = "wte.weight"
|
||||
|
||||
if ".mlp" in key:
|
||||
key = key.replace(".mlp", "")
|
||||
return key
|
||||
|
||||
|
||||
def convert():
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"microsoft/phi-2", torch_dtype="auto", trust_remote_code=True
|
||||
)
|
||||
state_dict = model.state_dict()
|
||||
weights = {replace_key(k): v.numpy() for k, v in state_dict.items()}
|
||||
np.savez("weights.npz", **weights)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
convert()
|
233
llms/phi2/phi2.py
Normal file
@@ -0,0 +1,233 @@
import argparse
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten
from transformers import AutoTokenizer


@dataclass
class ModelArgs:
    max_sequence_length: int = 2048
    num_vocab: int = 51200
    model_dim: int = 2560
    num_heads: int = 32
    num_layers: int = 32
    rotary_dim: int = 32


class LayerNorm(nn.LayerNorm):
    def __call__(self, x: mx.array) -> mx.array:
        return super().__call__(x.astype(mx.float32)).astype(x.dtype)


class RoPEAttention(nn.Module):
    def __init__(self, dims: int, num_heads: int, rotary_dim: int):
        super().__init__()

        self.num_heads = num_heads

        self.rope = nn.RoPE(rotary_dim, traditional=False)
        self.Wqkv = nn.Linear(dims, 3 * dims)
        self.out_proj = nn.Linear(dims, dims)

    def __call__(self, x, mask=None, cache=None):
        qkv = self.Wqkv(x)
        queries, keys, values = mx.split(qkv, 3, axis=-1)

        # Extract some shapes
        num_heads = self.num_heads
        B, L, D = queries.shape

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)

        # Add RoPE to the queries and keys and combine them with the cache
        if cache is not None:
            key_cache, value_cache = cache
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = mx.concatenate([key_cache, keys], axis=2)
            values = mx.concatenate([value_cache, values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        queries = queries.astype(mx.float32)
        keys = keys.astype(mx.float32)

        # Finally perform the attention computation
        scale = math.sqrt(1 / queries.shape[-1])
        scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
        if mask is not None:
            scores = scores + mask

        scores = mx.softmax(scores, axis=-1).astype(values.dtype)
        values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)

        return self.out_proj(values_hat), (keys, values)


class ParallelBlock(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        dims = config.model_dim
        mlp_dims = dims * 4
        self.mixer = RoPEAttention(dims, config.num_heads, config.rotary_dim)
        self.ln = LayerNorm(dims)
        self.fc1 = nn.Linear(dims, mlp_dims)
        self.fc2 = nn.Linear(mlp_dims, dims)
        self.act = nn.GELU(approx="precise")

    def __call__(self, x, mask, cache):
        # Attention and the MLP both read the same normalized input and
        # share a single residual connection.
        h = self.ln(x)
        attn_h, cache = self.mixer(h, mask, cache)
        ff_h = self.fc2(self.act(self.fc1(h)))
        return attn_h + ff_h + x, cache


class TransformerDecoder(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.h = [ParallelBlock(config) for i in range(config.num_layers)]

    def __call__(self, x, mask, cache):
        if cache is None:
            cache = [None] * len(self.h)

        for e, layer in enumerate(self.h):
            x, cache[e] = layer(x, mask, cache[e])
        return x, cache


class OutputHead(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        self.ln = LayerNorm(config.model_dim)
        self.linear = nn.Linear(config.model_dim, config.num_vocab)

    def __call__(self, inputs):
        return self.linear(self.ln(inputs))


class Phi2(nn.Module):
    def __init__(self, config: ModelArgs):
        self.wte = nn.Embedding(config.num_vocab, config.model_dim)
        self.transformer = TransformerDecoder(config)
        self.lm_head = OutputHead(config)

    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache: mx.array = None,
    ) -> tuple[mx.array, mx.array]:
        x = self.wte(inputs)

        # A causal mask is only needed when processing more than one token,
        # i.e. when encoding the prompt.
        mask = None
        if x.shape[1] > 1:
            mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
            mask = mask.astype(x.dtype)

        y, cache = self.transformer(x, mask, cache)
        return self.lm_head(y), cache


def generate(prompt: mx.array, model: Phi2, temp: Optional[float] = 0.0):
    def sample(logits):
        # Greedy decoding when temp == 0, temperature sampling otherwise.
        if temp == 0:
            return mx.argmax(logits, axis=-1)
        else:
            return mx.random.categorical(logits * (1 / temp))

    # Encode the prompt once and keep the key/value cache for decoding.
    logits, cache = model(prompt)
    y = sample(logits[:, -1, :])
    yield y

    # Decode one token at a time, feeding the cache back into the model.
    while True:
        logits, cache = model(y[:, None], cache=cache)
        y = sample(logits.squeeze(1))
        yield y


def load_model(model_path: str):
    model = Phi2(ModelArgs())
    model_path = Path(model_path)
    weights = mx.load(str(model_path / "weights.npz"))
    model.update(tree_unflatten(list(weights.items())))
    tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
    return model, tokenizer


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Phi-2 inference script")
    parser.add_argument(
        "--model-path",
        type=str,
        default="phi-2",
        help="The path to the model weights",
    )
    parser.add_argument(
        "--prompt",
        help="The message to be processed by the model",
        default="Write a detailed analogy between mathematics and a lighthouse.",
    )
    parser.add_argument(
        "--max-tokens",
        "-m",
        type=int,
        default=100,
        help="Maximum number of tokens to generate",
    )
    parser.add_argument(
        "--temp",
        help="The sampling temperature.",
        type=float,
        default=0.0,
    )
    parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
    args = parser.parse_args()

    mx.random.seed(args.seed)

    model, tokenizer = load_model(args.model_path)

    prompt = tokenizer(
        args.prompt,
        return_tensors="np",
        return_attention_mask=False,
    )["input_ids"]

    prompt = mx.array(prompt)

    print("[INFO] Generating with Phi-2...", flush=True)
    print(args.prompt, end="", flush=True)

    tokens = []
    for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
        tokens.append(token)

        if (len(tokens) % 10) == 0:
            # Every 10 tokens: evaluate the lazy arrays, trim at EOS if it
            # appeared, and stream the decoded text to stdout.
            mx.eval(tokens)
            eos_index = next(
                (i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_token_id),
                None,
            )

            if eos_index is not None:
                tokens = tokens[:eos_index]

            s = tokenizer.decode([t.item() for t in tokens])
            print(s, end="", flush=True)
            tokens = []
            if eos_index is not None:
                break

    mx.eval(tokens)
    s = tokenizer.decode([t.item() for t in tokens])
    print(s, flush=True)
5
llms/phi2/requirements.txt
Normal file
@@ -0,0 +1,5 @@
einops
mlx
numpy
transformers>=4.35
torch
@@ -1,8 +1,9 @@
 import argparse
-from transformers import AutoModelForCausalLM
+import json
+
 import numpy as np
 import torch
-import json
+from transformers import AutoModelForCausalLM
 
 
 def replace_key(key: str) -> str:
@@ -1,6 +1,7 @@
 import argparse
-from dataclasses import dataclass
+import json
+from dataclasses import dataclass
 
 import mlx.core as mx
 import mlx.nn as nn
 from mlx.utils import tree_unflatten