Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-09-02 22:04:53 +08:00)
Add the Llama and Stable Diffusion examples
llama/README.md (new file, 37 lines)
@@ -0,0 +1,37 @@
# LLaMA

An example of generating text with LLaMA using MLX.

LLaMA is a set of open source language models from Meta AI Research[^1] ranging from 7B to 65B parameters.

### Setup

Install the dependencies:

```
pip install -r requirements.txt
```

Next, download and convert the model. If you do not have access to the model
weights you will need to [request
access](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform)
from Meta.

Convert the weights with:

```
python convert.py <path_to_torch_weights> mlx_llama_weights.npz
```

### Run

Once you've converted the weights to MLX format, you can interact with the
LLaMA model:

```
python llama.py mlx_llama_weights.npz tokenizer.model "hello"
```

Run `python llama.py --help` for more details.
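For example, to generate 200 tokens at a lower sampling temperature (the weight and tokenizer file names are simply the ones used above):

```
python llama.py mlx_llama_weights.npz tokenizer.model "hello" --num-tokens 200 --temp 0.7
```

With `--few-shot`, the positional prompt argument is instead treated as the path to a prompt template such as `sample_prompt.txt`, and the script answers questions interactively.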
[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2302.13971) and [blog post](https://ai.meta.com/blog/large-language-model-llama-meta-ai/) for more details.
llama/convert.py (new file, 46 lines)
@@ -0,0 +1,46 @@
import argparse
from itertools import starmap

import numpy as np
import torch


def map_torch_to_mlx(key, value):
    if "tok_embedding" in key:
        key = "embedding.weight"

    elif "norm" in key:
        key = key.replace("attention_norm", "norm1").replace("ffn_norm", "norm2")

    elif "wq" in key or "wk" in key or "wv" in key or "wo" in key:
        key = key.replace("wq", "query_proj")
        key = key.replace("wk", "key_proj")
        key = key.replace("wv", "value_proj")
        key = key.replace("wo", "out_proj")

    elif "w1" in key or "w2" in key or "w3" in key:
        # The FFN is a separate submodule in PyTorch
        key = key.replace("feed_forward.w1", "linear1")
        key = key.replace("feed_forward.w3", "linear2")
        key = key.replace("feed_forward.w2", "linear3")

    elif "output" in key:
        key = key.replace("output", "out_proj")

    elif "rope" in key:
        return None, None

    return key, value.numpy()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
    parser.add_argument("torch_weights")
    parser.add_argument("output_file")
    args = parser.parse_args()

    state = torch.load(args.torch_weights)
    np.savez(
        args.output_file,
        **{k: v for k, v in starmap(map_torch_to_mlx, state.items()) if k is not None}
    )
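As a quick sanity check after conversion, here is a minimal sketch (assuming the `mlx_llama_weights.npz` output name used in the README) that lists a few of the remapped parameters with NumPy:

```
import numpy as np

# Open the converted archive and print some of the remapped keys, e.g.
# "layers.0.attention.query_proj.weight", along with their shapes.
weights = np.load("mlx_llama_weights.npz")
for name in sorted(weights.files)[:10]:
    print(name, weights[name].shape)
```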
llama/llama.py (new file, 304 lines)
@@ -0,0 +1,304 @@
import argparse
import math
import time

import numpy as np
from sentencepiece import SentencePieceProcessor

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten


class LlamaAttention(nn.Module):
    def __init__(self, dims: int, num_heads: int):
        super().__init__()

        self.num_heads = num_heads

        self.rope = nn.RoPE(dims // num_heads, traditional=True)
        self.query_proj = nn.Linear(dims, dims, bias=False)
        self.key_proj = nn.Linear(dims, dims, bias=False)
        self.value_proj = nn.Linear(dims, dims, bias=False)
        self.out_proj = nn.Linear(dims, dims, bias=False)

    def __call__(self, queries, keys, values, mask=None, cache=None):
        queries = self.query_proj(queries)
        keys = self.key_proj(keys)
        values = self.value_proj(values)

        # Extract some shapes
        num_heads = self.num_heads
        B, L, D = queries.shape

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)

        # Add RoPE to the queries and keys and combine them with the cache
        if cache is not None:
            key_cache, value_cache = cache
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = mx.concatenate([key_cache, keys], axis=2)
            values = mx.concatenate([value_cache, values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        # Finally perform the attention computation
        scale = math.sqrt(1 / queries.shape[-1])
        scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
        if mask is not None:
            scores = scores + mask
        scores = mx.softmax(scores, axis=-1)
        values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)

        # Note that we return the keys and values to possibly be used as a cache
        return self.out_proj(values_hat), (keys, values)


class LlamaEncoderLayer(nn.Module):
    def __init__(self, dims: int, mlp_dims: int, num_heads: int):
        super().__init__()

        self.attention = LlamaAttention(dims, num_heads)

        self.norm1 = nn.RMSNorm(dims)
        self.norm2 = nn.RMSNorm(dims)

        self.linear1 = nn.Linear(dims, mlp_dims, bias=False)
        self.linear2 = nn.Linear(dims, mlp_dims, bias=False)
        self.linear3 = nn.Linear(mlp_dims, dims, bias=False)

    def __call__(self, x, mask=None, cache=None):
        y = self.norm1(x)
        y, cache = self.attention(y, y, y, mask, cache)
        x = x + y

        y = self.norm2(x)
        a = self.linear1(y)
        b = self.linear2(y)
        # SwiGLU feed-forward: SiLU(a) gated elementwise by b
        y = a * mx.sigmoid(a) * b
        y = self.linear3(y)
        x = x + y

        return x, cache


class Llama(nn.Module):
    def __init__(
        self, num_layers: int, vocab_size: int, dims: int, mlp_dims: int, num_heads: int
    ):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, dims)
        self.layers = [
            LlamaEncoderLayer(dims, mlp_dims, num_heads) for _ in range(num_layers)
        ]
        self.norm = nn.RMSNorm(dims)
        self.out_proj = nn.Linear(dims, vocab_size, bias=False)

    def __call__(self, x):
        mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
        mask = mask.astype(self.embedding.weight.dtype)

        x = self.embedding(x)
        for l in self.layers:
            x, _ = l(x, mask)
        x = self.norm(x)
        return self.out_proj(x)

    def generate(self, x, temp=1.0):
        cache = []

        # Make an additive causal mask. We will need that to process the prompt.
        mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
        mask = mask.astype(self.embedding.weight.dtype)

        # First we process the prompt x the same way as in __call__ but
        # save the per-layer caches in `cache`
        x = self.embedding(x)
        for l in self.layers:
            x, c = l(x, mask=mask)
            # We store the per layer cache in a simple python list
            cache.append(c)
        x = self.norm(x)
        # We only care about the last logits that generate the next token
        y = self.out_proj(x[:, -1])
        y = mx.random.categorical(y * (1 / temp))

        # y now has size [1]
        # Since MLX is lazily evaluated nothing is computed yet.
        # Calling y.item() would force the computation to happen at
        # this point but we can also choose not to do that and let the
        # user choose when to start the computation.
        yield y

        # Now that we have parsed the prompt and generated the first token,
        # we need to feed it back into the model and loop to generate the
        # rest.
        while True:
            # Unsqueeze the last dimension to add a sequence length
            # dimension of 1
            x = y[:, None]

            x = self.embedding(x)
            for i in range(len(cache)):
                # We are overwriting the arrays in the cache list. When
                # the computation happens, MLX will discard the old
                # cache as soon as it is no longer needed.
                x, cache[i] = self.layers[i](x, mask=None, cache=cache[i])
            x = self.norm(x)
            y = self.out_proj(x[:, -1])
            y = mx.random.categorical(y * (1 / temp))

            yield y


def tic():
    return time.time()


def toc(msg, start):
    end = time.time()
    return f"[INFO] {msg}: {end - start:.3f} s"


def generate(args):
    input("Press enter to start generation")
    print("------")

    x = mx.array([[tokenizer.bos_id()] + tokenizer.encode(args.prompt)])
    skip = 0
    prompt_processing = None
    tokens = []
    start = tic()
    for token in model.generate(x, args.temp):
        tokens.append(token)

        if len(tokens) == 1:
            # Actually perform the computation to measure the prompt processing time
            mx.eval(token)
            prompt_processing = toc("Prompt processing", start)

        if len(tokens) >= args.num_tokens:
            break

        elif (len(tokens) % args.write_every) == 0:
            # It is perfectly ok to eval things we have already eval-ed.
            mx.eval(tokens)
            s = tokenizer.decode([t.item() for t in tokens])
            print(s[skip:], end="", flush=True)
            skip = len(s)

    mx.eval(tokens)
    full_gen = toc("Full generation", start)
    s = tokenizer.decode([t.item() for t in tokens])
    print(s[skip:], end="", flush=True)
    print()
    print("------")
    print(prompt_processing)
    print(full_gen)


def few_shot_generate(args):
    def possible_end(s):
        word = "[Instruction]"
        for i in range(len(word) - 1, 0, -1):
            if s[-i:] == word[:i]:
                return 0
        if s[-len(word) :] == word:
            return 1
        return -1

    def generate(question):
        x = mx.array([[tokenizer.bos_id()] + tokenizer.encode(question)])
        skip = 0
        prompt_processing = None
        tokens = []
        start = tic()
        for token in model.generate(x, args.temp):
            tokens.append(token)

            if len(tokens) == 1:
                # Actually perform the computation to measure the prompt processing time
                mx.eval(token)
                prompt_processing = toc("Prompt processing", start)

            if len(tokens) >= args.num_tokens:
                break

            mx.eval(tokens)
            token_list = [t.item() for t in tokens]
            s = tokenizer.decode(token_list)

            end = possible_end(s)
            if end == 0:
                continue
            if end == 1:
                skip = len(s)
                break

            print(s[skip:], end="", flush=True)
            skip = len(s)
            if token_list[-1] == tokenizer.eos_id():
                break

        mx.eval(tokens)
        full_gen = toc("Full generation", start)
        s = tokenizer.decode([t.item() for t in tokens])
        print(s[skip:], end="", flush=True)

    prompt = open(args.prompt).read().strip()
    while True:
        question = input("Ask a question: ")
        generate(prompt.replace("{}", question))
        print()


def load_model(model_path):
    weights = mx.load(model_path)
    mlp_dims, dims = weights["layers.0.linear1.weight"].shape
    num_heads = dims // 128
    num_layers = max(int(l.split(".")[1]) for l in weights.keys() if "layers" in l) + 1
    # out_proj.weight has shape (vocab_size, dims)
    vocab_size = weights["out_proj.weight"].shape[0]
    model = Llama(num_layers, vocab_size, dims, mlp_dims, num_heads)
    model.update(tree_unflatten(list(weights.items())))
    mx.eval(model.parameters())
    return model


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Llama inference script")
    parser.add_argument("model", help="The model file containing MLX weights")
    parser.add_argument("tokenizer", help="The SentencePiece tokenizer")
    parser.add_argument("prompt", help="The message to be processed by the model")
    parser.add_argument(
        "--few-shot",
        action="store_true",
        help="Read a few-shot prompt from a file (as in `sample_prompt.txt`).",
    )
    parser.add_argument(
        "--num-tokens", "-n", type=int, default=100, help="How many tokens to generate"
    )
    parser.add_argument(
        "--write-every", type=int, default=1, help="After how many tokens to detokenize"
    )
    parser.add_argument(
        "--temp", type=float, default=0.8, help="The sampling temperature"
    )
    parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")

    args = parser.parse_args()

    mx.random.seed(args.seed)

    tokenizer = SentencePieceProcessor(model_file=args.tokenizer)
    print("[INFO] Loading model from disk.")
    model = load_model(args.model)
    if args.few_shot:
        few_shot_generate(args)
    else:
        generate(args)
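For reference, a toy-scale sketch (hypothetical dimensions, not a real LLaMA configuration) of how `Llama.generate` is driven: tokens are yielded lazily and only materialize once `mx.eval` is called. Run it from the `llama/` directory so `llama.py` is importable.

```
import mlx.core as mx
from llama import Llama

# Tiny hypothetical configuration, just to exercise the generator.
model = Llama(num_layers=2, vocab_size=128, dims=64, mlp_dims=128, num_heads=4)
prompt = mx.array([[1, 5, 9]])  # a batch of one short, made-up token sequence

tokens = []
for token in model.generate(prompt, temp=0.8):
    tokens.append(token)
    if len(tokens) == 8:
        break
mx.eval(tokens)  # force the lazy computation
print([t.item() for t in tokens])
```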
llama/requirements.txt (new file, 2 lines)
@@ -0,0 +1,2 @@
sentencepiece
torch
llama/sample_prompt.txt (new file, 23 lines)
@@ -0,0 +1,23 @@
[Instruction] Give the list of U.S. states bordering Canada
[Answer] OK, here is the list of U.S. states located on the border with Canada:
- Alaska
- Michigan
- Maine
- Minnesota
- Montana
- New York
- Washington
- North Dakota
- Ohio
- Vermont
- New Hampshire
- Idaho
- Pennsylvania
[Instruction] Write a paragraph about "functional analysis"
[Answer] OK, here is a paragraph on the topic of functional analysis:
Functional analysis is a branch of mathematical analysis, the core of which is formed by the study of vector spaces endowed with some kind of limit-related structure (for example, inner product, norm, or topology) and the linear functions defined on these spaces and suitably respecting these structures. The historical roots of functional analysis lie in the study of spaces of functions and the formulation of properties of transformations of functions such as the Fourier transform as transformations defining, for example, continuous or unitary operators between function spaces. This point of view turned out to be particularly useful for the study of differential and integral equations.
[Instruction] I am starting a new dog walking business. Can you help me find 2 possible names for the business?
[Answer] OK, here are two possible names for a new dog walking business:
The first option is "Paws on Patrol", and the second option is "The Dog Whisperer".
[Instruction] {}
[Answer]
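The `{}` placeholder in the final `[Instruction]` block is what `few_shot_generate` fills in with the user's question before encoding the prompt; a minimal sketch of that substitution (the question text is just a made-up example):

```
# Fill the few-shot template the same way llama.py does before tokenizing it.
template = open("sample_prompt.txt").read().strip()
question = "What is the capital of France?"  # hypothetical user input
prompt = template.replace("{}", question)
print(prompt.splitlines()[-2])  # the filled-in [Instruction] line
```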