# Copyright © 2023 Apple Inc.

import argparse
import math
import numpy as np
from sentencepiece import SentencePieceProcessor
import time

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten


class LlamaAttention(nn.Module):
    def __init__(self, dims: int, num_heads: int):
        super().__init__()

        self.num_heads = num_heads

        self.rope = nn.RoPE(dims // num_heads, traditional=True)
        self.query_proj = nn.Linear(dims, dims, bias=False)
        self.key_proj = nn.Linear(dims, dims, bias=False)
        self.value_proj = nn.Linear(dims, dims, bias=False)
        self.out_proj = nn.Linear(dims, dims, bias=False)

    def __call__(self, queries, keys, values, mask=None, cache=None):
        queries = self.query_proj(queries)
        keys = self.key_proj(keys)
        values = self.value_proj(values)

        # Extract some shapes
        num_heads = self.num_heads
        B, L, D = queries.shape

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)

        # Add RoPE to the queries and keys and combine them with the cache
        if cache is not None:
            key_cache, value_cache = cache
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = mx.concatenate([key_cache, keys], axis=2)
            values = mx.concatenate([value_cache, values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        # Finally perform the attention computation
        scale = math.sqrt(1 / queries.shape[-1])
        scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
        if mask is not None:
            scores = scores + mask
        scores = mx.softmax(scores, axis=-1)
        values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)

        # Note that we return the keys and values to possibly be used as a cache
        return self.out_proj(values_hat), (keys, values)


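# Illustrative sketch (not part of the original script): how the (keys, values)
# cache returned by LlamaAttention is threaded back in for one-token-at-a-time
# decoding. The sizes (dims=512, num_heads=8, a 16-token prompt) are arbitrary
# assumptions, and the causal mask is omitted since only the shapes matter here.
def _attention_cache_sketch():
    attn = LlamaAttention(dims=512, num_heads=8)
    prompt = mx.random.normal((1, 16, 512))
    # Process the prompt; the cache holds keys/values of shape (1, 8, 16, 64).
    out, cache = attn(prompt, prompt, prompt)
    # Decode a single new token; RoPE is offset by the cached length (16).
    step = mx.random.normal((1, 1, 512))
    out, cache = attn(step, step, step, cache=cache)
    return out.shape, cache[0].shape  # (1, 1, 512) and (1, 8, 17, 64)

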
class LlamaEncoderLayer(nn.Module):
    def __init__(self, dims: int, mlp_dims: int, num_heads: int):
        super().__init__()

        self.attention = LlamaAttention(dims, num_heads)

        self.norm1 = nn.RMSNorm(dims)
        self.norm2 = nn.RMSNorm(dims)

        self.linear1 = nn.Linear(dims, mlp_dims, bias=False)
        self.linear2 = nn.Linear(dims, mlp_dims, bias=False)
        self.linear3 = nn.Linear(mlp_dims, dims, bias=False)

    def __call__(self, x, mask=None, cache=None):
        y = self.norm1(x)
        y, cache = self.attention(y, y, y, mask, cache)
        x = x + y

        y = self.norm2(x)
        a = self.linear1(y)
        b = self.linear2(y)
        y = a * mx.sigmoid(a) * b
        y = self.linear3(y)
        x = x + y

        return x, cache


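# Aside (added for clarity): the feed-forward half of the layer above is a SwiGLU
# block; the gate path goes through SiLU (a * sigmoid(a)) and scales the second
# projection before the down-projection. A minimal standalone sketch, assuming
# gate/up layers mapping dims -> mlp_dims and a down layer mapping mlp_dims -> dims:
def _swiglu_sketch(y, gate_proj, up_proj, down_proj):
    a = gate_proj(y)
    return down_proj(a * mx.sigmoid(a) * up_proj(y))

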
class Llama(nn.Module):
    def __init__(
        self, num_layers: int, vocab_size: int, dims: int, mlp_dims: int, num_heads: int
    ):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, dims)
        self.layers = [
            LlamaEncoderLayer(dims, mlp_dims, num_heads) for _ in range(num_layers)
        ]
        self.norm = nn.RMSNorm(dims)
        self.out_proj = nn.Linear(dims, vocab_size, bias=False)

    def __call__(self, x):
        mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
        mask = mask.astype(self.embedding.weight.dtype)

        x = self.embedding(x)
        for l in self.layers:
            x, _ = l(x, mask)
        x = self.norm(x)
        return self.out_proj(x)

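    # Aside (added for clarity): create_additive_causal_mask(L) builds an L x L
    # additive mask with large negative values above the diagonal, so after the
    # softmax each position can only attend to itself and earlier positions.
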
    def generate(self, x, temp=1.0):
        cache = []

        # Make an additive causal mask. We will need that to process the prompt.
        mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
        mask = mask.astype(self.embedding.weight.dtype)

        # First we process the prompt x the same way as in __call__ but
        # save the caches in cache
        x = self.embedding(x)
        for l in self.layers:
            x, c = l(x, mask=mask)
            # We store the per layer cache in a simple python list
            cache.append(c)
        x = self.norm(x)
        # We only care about the last logits that generate the next token
        y = self.out_proj(x[:, -1])
        y = mx.random.categorical(y * (1 / temp))

        # y now has size [1]
        # Since MLX is lazily evaluated nothing is computed yet.
        # Calling y.item() would force the computation to happen at
        # this point but we can also choose not to do that and let the
        # user choose when to start the computation.
        yield y

        # Now that we have processed the prompt and generated the first
        # token, we need to feed it back into the model and loop to
        # generate the rest.
        while True:
            # Unsqueeze the last dimension to add a sequence length
            # dimension of 1
            x = y[:, None]

            x = self.embedding(x)
            for i in range(len(cache)):
                # We are overwriting the arrays in the cache list. When
                # the computation happens, MLX will discard the old cache
                # the moment it is no longer needed.
                x, cache[i] = self.layers[i](x, mask=None, cache=cache[i])
            x = self.norm(x)
            y = self.out_proj(x[:, -1])
            y = mx.random.categorical(y * (1 / temp))

            yield y


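# Illustrative sketch (not part of the original script): how the lazy `generate`
# generator above is meant to be consumed. The tiny model configuration below is
# an arbitrary assumption; nothing is computed until mx.eval() or .item() is
# called on the yielded tokens.
def _generate_sketch():
    model = Llama(num_layers=2, vocab_size=128, dims=64, mlp_dims=256, num_heads=4)
    prompt = mx.array([[1, 5, 9]])
    tokens = []
    for token in model.generate(prompt, temp=1.0):
        tokens.append(token)
        if len(tokens) == 8:
            break
    mx.eval(tokens)  # force the computation for all eight tokens at once
    return [t.item() for t in tokens]

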
def tic():
    return time.time()


def toc(msg, start):
    end = time.time()
    return f"[INFO] {msg}: {end - start:.3f} s"


def generate(args):

    input("Press enter to start generation")
    print("------")

    x = mx.array([[tokenizer.bos_id()] + tokenizer.encode(args.prompt)])
    skip = 0
    prompt_processing = None
    tokens = []
    start = tic()
    for token in model.generate(x, args.temp):
        tokens.append(token)

        if len(tokens) == 1:
            # Actually perform the computation to measure the prompt processing time
            mx.eval(token)
            prompt_processing = toc("Prompt processing", start)

        if len(tokens) >= args.num_tokens:
            break

        elif (len(tokens) % args.write_every) == 0:
            # It is perfectly ok to eval things we have already eval-ed.
            mx.eval(tokens)
            s = tokenizer.decode([t.item() for t in tokens])
            print(s[skip:], end="", flush=True)
            skip = len(s)

    mx.eval(tokens)
    full_gen = toc("Full generation", start)
    s = tokenizer.decode([t.item() for t in tokens])
    print(s[skip:], end="", flush=True)
    print()
    print("------")
    print(prompt_processing)
    print(full_gen)


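# Aside (added for clarity): in generate() above, `skip` tracks how many characters
# of the decoded string have already been printed, so each flush only writes the
# newly decoded suffix; re-running mx.eval on already-evaluated tokens is a no-op.

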
def few_shot_generate(args):
    def possible_end(s):
        word = "[Instruction]"
        for i in range(len(word) - 1, 0, -1):
            if s[-i:] == word[:i]:
                return 0
        if s[-len(word) :] == word:
            return 1
        return -1

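    # Aside (added for clarity): possible_end returns 1 once the decoded text ends
    # with the full "[Instruction]" marker (stop this answer), 0 while it ends with
    # a proper prefix of it (hold off printing), and -1 otherwise (safe to print).
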
    def generate(question):
        x = mx.array([[tokenizer.bos_id()] + tokenizer.encode(question)])
        skip = 0
        prompt_processing = None
        tokens = []
        start = tic()
        for token in model.generate(x, args.temp):
            tokens.append(token)

            if len(tokens) == 1:
                # Actually perform the computation to measure the prompt processing time
                mx.eval(token)
                prompt_processing = toc("Prompt processing", start)

            if len(tokens) >= args.num_tokens:
                break

            mx.eval(tokens)
            token_list = [t.item() for t in tokens]
            s = tokenizer.decode(token_list)

            end = possible_end(s)
            if end == 0:
                continue
            if end == 1:
                skip = len(s)
                break

            print(s[skip:], end="", flush=True)
            skip = len(s)
            if token_list[-1] == tokenizer.eos_id():
                break

        mx.eval(tokens)
        full_gen = toc("Full generation", start)
        s = tokenizer.decode([t.item() for t in tokens])
        print(s[skip:], end="", flush=True)

    prompt = open(args.prompt).read().strip()
    while True:
        question = input("Ask a question: ")
        generate(prompt.replace("{}", question))
        print()


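# Aside (added for clarity): with --few-shot, args.prompt is a path to a prompt
# template (as in `sample_prompt.txt`) containing a "{}" placeholder that is
# substituted with the user's question on every turn of the loop above.

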
def load_model(model_path):
    weights = mx.load(model_path)
    mlp_dims, dims = weights["layers.0.linear1.weight"].shape
    num_heads = dims // 128
    num_layers = max(int(l.split(".")[1]) for l in weights.keys() if "layers" in l) + 1
    vocab_size = weights["out_proj.weight"].shape[0]
    model = Llama(num_layers, vocab_size, dims, mlp_dims, num_heads)
    model.update(tree_unflatten(list(weights.items())))
    mx.eval(model.parameters())
    return model


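# Aside (added for clarity): load_model recovers the hyperparameters from the
# weight file itself. Linear weights are stored as (out_features, in_features),
# so linear1.weight yields (mlp_dims, dims) and the first axis of out_proj.weight
# is the vocabulary size; the head dimension is assumed to be 128 (as in the
# Llama family), and the layer count comes from the largest "layers.<i>." index.

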
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Llama inference script")
    parser.add_argument("model", help="The model file containing MLX weights")
    parser.add_argument("tokenizer", help="The sentencepiece tokenizer")
    parser.add_argument("prompt", help="The message to be processed by the model")
    parser.add_argument(
        "--few-shot",
        action="store_true",
        help="Read a few shot prompt from a file (as in `sample_prompt.txt`).",
    )
    parser.add_argument(
        "--num-tokens", "-n", type=int, default=100, help="How many tokens to generate"
    )
    parser.add_argument(
        "--write-every", type=int, default=1, help="After how many tokens to detokenize"
    )
    parser.add_argument(
        "--temp", type=float, default=0.8, help="The sampling temperature"
    )
    parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")

    args = parser.parse_args()

    mx.random.seed(args.seed)

    tokenizer = SentencePieceProcessor(model_file=args.tokenizer)
    print("[INFO] Loading model from disk.")
    model = load_model(args.model)
    if args.few_shot:
        few_shot_generate(args)
    else:
        generate(args)
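
# Example invocation (illustrative; the script name and the weight/tokenizer
# paths below are placeholders for files produced by a separate conversion step):
#   python llama.py llama-weights.npz tokenizer.model "hello" --num-tokens 50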