mirror of
				https://github.com/ml-explore/mlx-examples.git
				synced 2025-10-31 19:18:09 +08:00 
			
		
		
		
	Add llms subdir + update README (#145)
* add llms subdir + update README * nits * use same pre-commit as mlx * update readmes a bit * format
This commit is contained in:
		
							
								
								
									
										1
									
								
								llms/phi2/.gitignore
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								llms/phi2/.gitignore
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| @@ -0,0 +1 @@ | ||||
| weights.npz | ||||
							
								
								
									
										62
									
								
								llms/phi2/README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										62
									
								
								llms/phi2/README.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,62 @@ | ||||
| # Phi-2 | ||||
|  | ||||
| Phi-2 is a 2.7B parameter language model released by Microsoft with | ||||
| performance that rivals much larger models.[^1] It was trained on a mixture of | ||||
| GPT-4 outputs and clean web text. | ||||
|  | ||||
| Phi-2 efficiently runs on Apple silicon devices with 8GB of memory in 16-bit | ||||
| precision. | ||||
|  | ||||
| ## Setup  | ||||
|  | ||||
| Download and convert the model: | ||||
|  | ||||
| ```sh  | ||||
| python convert.py | ||||
| ``` | ||||
|  | ||||
| This will make the `weights.npz` file which MLX can read. | ||||
|  | ||||
| > [!TIP] Alternatively, you can also download a few converted checkpoints from | ||||
| > the [MLX Community](https://huggingface.co/mlx-community) organization on | ||||
| > Hugging Face and skip the conversion step. | ||||
|  | ||||
|  | ||||
| ## Generate  | ||||
|  | ||||
| To generate text with the default prompt: | ||||
|  | ||||
| ```sh | ||||
| python phi2.py | ||||
| ``` | ||||
|  | ||||
| Should give the output: | ||||
|  | ||||
| ``` | ||||
| Answer: Mathematics is like a lighthouse that guides us through the darkness of | ||||
| uncertainty. Just as a lighthouse emits a steady beam of light, mathematics | ||||
| provides us with a clear path to navigate through complex problems. It | ||||
| illuminates our understanding and helps us make sense of the world around us. | ||||
|  | ||||
| Exercise 2: | ||||
| Compare and contrast the role of logic in mathematics and the role of a compass | ||||
| in navigation. | ||||
|  | ||||
| Answer: Logic in mathematics is like a compass in navigation. It helps | ||||
| ``` | ||||
|  | ||||
| To use your own prompt: | ||||
|  | ||||
| ```sh | ||||
| python phi2.py --prompt <your prompt here> --max_tokens <max_tokens_to_generate> | ||||
| ``` | ||||
|  | ||||
| To see a list of options run: | ||||
|  | ||||
| ```sh | ||||
| python phi2.py --help | ||||
| ``` | ||||
|  | ||||
| [^1]: For more details on the model see the [blog post]( | ||||
| https://www.microsoft.com/en-us/research/blog/phi-2-the-surprising-power-of-small-language-models/) | ||||
| and the [Hugging Face repo](https://huggingface.co/microsoft/phi-2) | ||||
							
								
								
									
										24
									
								
								llms/phi2/convert.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										24
									
								
								llms/phi2/convert.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,24 @@ | ||||
| import numpy as np | ||||
| from transformers import AutoModelForCausalLM | ||||
|  | ||||
|  | ||||
| def replace_key(key: str) -> str: | ||||
|     if "wte.weight" in key: | ||||
|         key = "wte.weight" | ||||
|  | ||||
|     if ".mlp" in key: | ||||
|         key = key.replace(".mlp", "") | ||||
|     return key | ||||
|  | ||||
|  | ||||
| def convert(): | ||||
|     model = AutoModelForCausalLM.from_pretrained( | ||||
|         "microsoft/phi-2", torch_dtype="auto", trust_remote_code=True | ||||
|     ) | ||||
|     state_dict = model.state_dict() | ||||
|     weights = {replace_key(k): v.numpy() for k, v in state_dict.items()} | ||||
|     np.savez("weights.npz", **weights) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     convert() | ||||
							
								
								
									
										233
									
								
								llms/phi2/phi2.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										233
									
								
								llms/phi2/phi2.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,233 @@ | ||||
| import argparse | ||||
| import math | ||||
| from dataclasses import dataclass | ||||
| from pathlib import Path | ||||
| from typing import Optional | ||||
|  | ||||
| import mlx.core as mx | ||||
| import mlx.nn as nn | ||||
| from mlx.utils import tree_unflatten | ||||
| from transformers import AutoTokenizer | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class ModelArgs: | ||||
|     max_sequence_length: int = 2048 | ||||
|     num_vocab: int = 51200 | ||||
|     model_dim: int = 2560 | ||||
|     num_heads: int = 32 | ||||
|     num_layers: int = 32 | ||||
|     rotary_dim: int = 32 | ||||
|  | ||||
|  | ||||
| class LayerNorm(nn.LayerNorm): | ||||
|     def __call__(self, x: mx.array) -> mx.array: | ||||
|         return super().__call__(x.astype(mx.float32)).astype(x.dtype) | ||||
|  | ||||
|  | ||||
| class RoPEAttention(nn.Module): | ||||
|     def __init__(self, dims: int, num_heads: int, rotary_dim: int): | ||||
|         super().__init__() | ||||
|  | ||||
|         self.num_heads = num_heads | ||||
|  | ||||
|         self.rope = nn.RoPE(rotary_dim, traditional=False) | ||||
|         self.Wqkv = nn.Linear(dims, 3 * dims) | ||||
|         self.out_proj = nn.Linear(dims, dims) | ||||
|  | ||||
|     def __call__(self, x, mask=None, cache=None): | ||||
|         qkv = self.Wqkv(x) | ||||
|         queries, keys, values = mx.split(qkv, 3, axis=-1) | ||||
|  | ||||
|         # Extract some shapes | ||||
|         num_heads = self.num_heads | ||||
|         B, L, D = queries.shape | ||||
|  | ||||
|         # Prepare the queries, keys and values for the attention computation | ||||
|         queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) | ||||
|         keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) | ||||
|         values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) | ||||
|  | ||||
|         # Add RoPE to the queries and keys and combine them with the cache | ||||
|         if cache is not None: | ||||
|             key_cache, value_cache = cache | ||||
|             queries = self.rope(queries, offset=key_cache.shape[2]) | ||||
|             keys = self.rope(keys, offset=key_cache.shape[2]) | ||||
|             keys = mx.concatenate([key_cache, keys], axis=2) | ||||
|             values = mx.concatenate([value_cache, values], axis=2) | ||||
|         else: | ||||
|             queries = self.rope(queries) | ||||
|             keys = self.rope(keys) | ||||
|  | ||||
|         queries = queries.astype(mx.float32) | ||||
|         keys = keys.astype(mx.float32) | ||||
|  | ||||
|         # Finally perform the attention computation | ||||
|         scale = math.sqrt(1 / queries.shape[-1]) | ||||
|         scores = (queries * scale) @ keys.transpose(0, 1, 3, 2) | ||||
|         if mask is not None: | ||||
|             scores = scores + mask | ||||
|  | ||||
|         scores = mx.softmax(scores, axis=-1).astype(values.dtype) | ||||
|         values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) | ||||
|  | ||||
|         return self.out_proj(values_hat), (keys, values) | ||||
|  | ||||
|  | ||||
| class ParallelBlock(nn.Module): | ||||
|     def __init__(self, config: ModelArgs): | ||||
|         super().__init__() | ||||
|         dims = config.model_dim | ||||
|         mlp_dims = dims * 4 | ||||
|         self.mixer = RoPEAttention(dims, config.num_heads, config.rotary_dim) | ||||
|         self.ln = LayerNorm(dims) | ||||
|         self.fc1 = nn.Linear(dims, mlp_dims) | ||||
|         self.fc2 = nn.Linear(mlp_dims, dims) | ||||
|         self.act = nn.GELU(approx="precise") | ||||
|  | ||||
|     def __call__(self, x, mask, cache): | ||||
|         h = self.ln(x) | ||||
|         attn_h, cache = self.mixer(h, mask, cache) | ||||
|         ff_h = self.fc2(self.act(self.fc1(h))) | ||||
|         return attn_h + ff_h + x, cache | ||||
|  | ||||
|  | ||||
| class TransformerDecoder(nn.Module): | ||||
|     def __init__(self, config: ModelArgs): | ||||
|         super().__init__() | ||||
|         self.h = [ParallelBlock(config) for i in range(config.num_layers)] | ||||
|  | ||||
|     def __call__(self, x, mask, cache): | ||||
|         if cache is None: | ||||
|             cache = [None] * len(self.h) | ||||
|  | ||||
|         for e, layer in enumerate(self.h): | ||||
|             x, cache[e] = layer(x, mask, cache[e]) | ||||
|         return x, cache | ||||
|  | ||||
|  | ||||
| class OutputHead(nn.Module): | ||||
|     def __init__(self, config: ModelArgs) -> None: | ||||
|         self.ln = LayerNorm(config.model_dim) | ||||
|         self.linear = nn.Linear(config.model_dim, config.num_vocab) | ||||
|  | ||||
|     def __call__(self, inputs): | ||||
|         return self.linear(self.ln(inputs)) | ||||
|  | ||||
|  | ||||
| class Phi2(nn.Module): | ||||
|     def __init__(self, config: ModelArgs): | ||||
|         self.wte = nn.Embedding(config.num_vocab, config.model_dim) | ||||
|         self.transformer = TransformerDecoder(config) | ||||
|         self.lm_head = OutputHead(config) | ||||
|  | ||||
|     def __call__( | ||||
|         self, | ||||
|         inputs: mx.array, | ||||
|         mask: mx.array = None, | ||||
|         cache: mx.array = None, | ||||
|     ) -> tuple[mx.array, mx.array]: | ||||
|         x = self.wte(inputs) | ||||
|  | ||||
|         mask = None | ||||
|         if x.shape[1] > 1: | ||||
|             mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1]) | ||||
|             mask = mask.astype(x.dtype) | ||||
|  | ||||
|         y, cache = self.transformer(x, mask, cache) | ||||
|         return self.lm_head(y), cache | ||||
|  | ||||
|  | ||||
| def generate(prompt: mx.array, model: Phi2, temp: Optional[float] = 0.0): | ||||
|     def sample(logits): | ||||
|         if temp == 0: | ||||
|             return mx.argmax(logits, axis=-1) | ||||
|         else: | ||||
|             return mx.random.categorical(logits * (1 / temp)) | ||||
|  | ||||
|     logits, cache = model(prompt) | ||||
|     y = sample(logits[:, -1, :]) | ||||
|     yield y | ||||
|  | ||||
|     while True: | ||||
|         logits, cache = model(y[:, None], cache=cache) | ||||
|         y = sample(logits.squeeze(1)) | ||||
|         yield y | ||||
|  | ||||
|  | ||||
| def load_model(model_path: str): | ||||
|     model = Phi2(ModelArgs()) | ||||
|     model_path = Path(model_path) | ||||
|     weights = mx.load(str(model_path / "weights.npz")) | ||||
|     model.update(tree_unflatten(list(weights.items()))) | ||||
|     tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True) | ||||
|     return model, tokenizer | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser(description="Phi-2 inference script") | ||||
|     parser.add_argument( | ||||
|         "--model-path", | ||||
|         type=str, | ||||
|         default="phi-2", | ||||
|         help="The path to the model weights", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--prompt", | ||||
|         help="The message to be processed by the model", | ||||
|         default="Write a detailed analogy between mathematics and a lighthouse.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--max-tokens", | ||||
|         "-m", | ||||
|         type=int, | ||||
|         default=100, | ||||
|         help="Maximum number of tokens to generate", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--temp", | ||||
|         help="The sampling temperature.", | ||||
|         type=float, | ||||
|         default=0.0, | ||||
|     ) | ||||
|     parser.add_argument("--seed", type=int, default=0, help="The PRNG seed") | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     mx.random.seed(args.seed) | ||||
|  | ||||
|     model, tokenizer = load_model(args.model_path) | ||||
|  | ||||
|     prompt = tokenizer( | ||||
|         args.prompt, | ||||
|         return_tensors="np", | ||||
|         return_attention_mask=False, | ||||
|     )["input_ids"] | ||||
|  | ||||
|     prompt = mx.array(prompt) | ||||
|  | ||||
|     print("[INFO] Generating with Phi-2...", flush=True) | ||||
|     print(args.prompt, end="", flush=True) | ||||
|  | ||||
|     tokens = [] | ||||
|     for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)): | ||||
|         tokens.append(token) | ||||
|  | ||||
|         if (len(tokens) % 10) == 0: | ||||
|             mx.eval(tokens) | ||||
|             eos_index = next( | ||||
|                 (i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_token_id), | ||||
|                 None, | ||||
|             ) | ||||
|  | ||||
|             if eos_index is not None: | ||||
|                 tokens = tokens[:eos_index] | ||||
|  | ||||
|             s = tokenizer.decode([t.item() for t in tokens]) | ||||
|             print(s, end="", flush=True) | ||||
|             tokens = [] | ||||
|             if eos_index is not None: | ||||
|                 break | ||||
|  | ||||
|     mx.eval(tokens) | ||||
|     s = tokenizer.decode([t.item() for t in tokens]) | ||||
|     print(s, flush=True) | ||||
							
								
								
									
										5
									
								
								llms/phi2/requirements.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								llms/phi2/requirements.txt
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,5 @@ | ||||
| einops | ||||
| mlx | ||||
| numpy | ||||
| transformers>=4.35 | ||||
| torch | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun