From a8d41491472ffb67081f32ecc4853de7ba1c367c Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Thu, 14 Dec 2023 08:08:28 -0800
Subject: [PATCH] fix fp16 + nits

---
 phi2/model.py | 103 ++++++++++++++++++++++++++++++++++++++++--------------------
 1 file changed, 69 insertions(+), 34 deletions(-)

diff --git a/phi2/model.py b/phi2/model.py
index 5253a266..52bda27e 100644
--- a/phi2/model.py
+++ b/phi2/model.py
@@ -1,12 +1,14 @@
+import argparse
 from typing import Optional
 from dataclasses import dataclass
-from mlx.utils import tree_unflatten, tree_map
+from mlx.utils import tree_unflatten
 from transformers import AutoTokenizer
 
 import mlx.core as mx
 import mlx.nn as nn
 import math
 
+
 @dataclass
 class ModelArgs:
     max_sequence_length: int = 2048
@@ -17,17 +19,22 @@ class ModelArgs:
     rotary_dim: int = 32
 
 
+class LayerNorm(nn.LayerNorm):
+    def __call__(self, x: mx.array) -> mx.array:
+        return super().__call__(x.astype(mx.float32)).astype(x.dtype)
+
+
 class RoPEAttention(nn.Module):
-    def __init__(self, dims: int, num_heads: int, bias: bool = True):
+    def __init__(self, dims: int, num_heads: int, rotary_dim: int):
         super().__init__()
 
         self.num_heads = num_heads
 
-        self.rope = nn.RoPE(dims // num_heads, traditional=True)
-        self.query_proj = nn.Linear(dims, dims, bias=bias)
-        self.key_proj = nn.Linear(dims, dims, bias=bias)
-        self.value_proj = nn.Linear(dims, dims, bias=bias)
-        self.out_proj = nn.Linear(dims, dims, bias=bias)
+        self.rope = nn.RoPE(rotary_dim, traditional=False)
+        self.query_proj = nn.Linear(dims, dims)
+        self.key_proj = nn.Linear(dims, dims)
+        self.value_proj = nn.Linear(dims, dims)
+        self.out_proj = nn.Linear(dims, dims)
 
     def __call__(self, queries, keys, values, mask=None, cache=None):
         queries = self.query_proj(queries)
@@ -54,25 +61,28 @@ class RoPEAttention(nn.Module):
             queries = self.rope(queries)
             keys = self.rope(keys)
 
 
+        queries = queries.astype(mx.float32)
+        keys = keys.astype(mx.float32)
+
         # Finally perform the attention computation
         scale = math.sqrt(1 / queries.shape[-1])
         scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
         if mask is not None:
             scores = scores + mask
-        scores = mx.softmax(scores, axis=-1)
+        scores = mx.softmax(scores, axis=-1).astype(values.dtype)
         values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
 
-        # Note that we return the keys and values to possibly be used as a cache
         return self.out_proj(values_hat), (keys, values)
 
 
 class ParallelBlock(nn.Module):
-    def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None):
+    def __init__(self, config: ModelArgs):
         super().__init__()
-        mlp_dims = mlp_dims or dims * 4
-        self.self_attention = RoPEAttention(dims, num_heads, bias=True)
-        self.ln = nn.LayerNorm(dims)
+        dims = config.model_dim
+        mlp_dims = dims * 4
+        self.self_attention = RoPEAttention(dims, config.num_heads, config.rotary_dim)
+        self.ln = LayerNorm(dims)
         self.fc1 = nn.Linear(dims, mlp_dims)
         self.fc2 = nn.Linear(mlp_dims, dims)
         self.act = nn.GELU(approx="precise")
@@ -85,11 +95,9 @@ class ParallelBlock(nn.Module):
 
 
 class TransformerDecoder(nn.Module):
-    def __init__(
-        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
-    ):
+    def __init__(self, config: ModelArgs):
         super().__init__()
-        self.h = [ParallelBlock(dims, num_heads, mlp_dims) for i in range(num_layers)]
+        self.h = [ParallelBlock(config) for i in range(config.num_layers)]
 
     def __call__(self, x, mask, cache):
         if cache is None:
@@ -102,7 +110,7 @@ class TransformerDecoder(nn.Module):
 
 class OutputHead(nn.Module):
     def __init__(self, config: ModelArgs) -> None:
-        self.ln = nn.LayerNorm(config.model_dim)
+        self.ln = LayerNorm(config.model_dim)
         self.linear = nn.Linear(config.model_dim, config.num_vocab)
 
     def __call__(self, inputs):
@@ -112,11 +120,7 @@ class OutputHead(nn.Module):
 class Phi2(nn.Module):
     def __init__(self, config: ModelArgs):
         self.wte = nn.Embedding(config.num_vocab, config.model_dim)
-        self.transformer = TransformerDecoder(
-            num_layers=config.num_layers,
-            dims=config.model_dim,
-            num_heads=config.num_heads,
-        )
+        self.transformer = TransformerDecoder(config)
         self.lm_head = OutputHead(config)
 
     def __call__(
@@ -153,33 +157,64 @@ def generate(prompt: mx.array, model: Phi2, temp: Optional[float] = 0.0):
         yield y
 
 
-if __name__ == "__main__":
+def load_model():
     model = Phi2(ModelArgs())
 
-    weights = mx.load("weights/phi-2.npz")
+    weights = mx.load("weights.npz")
     weights = tree_unflatten(list(weights.items()))
-    weights = tree_map(lambda p: mx.array(p, mx.float32), weights)
-
     model.update(weights)
     tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
-    prompt = tokenizer("Write a detailed analogy between mathematics and a lighthouse.",
+    return model, tokenizer
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Phi-2 inference script")
+    parser.add_argument(
+        "--prompt",
+        help="The message to be processed by the model",
+        default="Write a detailed analogy between mathematics and a lighthouse.",
+    )
+    parser.add_argument(
+        "--max_tokens",
+        "-m",
+        type=int,
+        default=100,
+        help="Maximum number of tokens to generate",
+    )
+    parser.add_argument(
+        "--tokens_per_eval",
+        type=int,
+        default=1,
+        help="The number of tokens to generate before evaluating and printing",
+    )
+    parser.add_argument(
+        "--temp",
+        help="The sampling temperature.",
+        type=float,
+        default=0.0,
+    )
+    parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
+    args = parser.parse_args()
+
+    mx.random.seed(args.seed)
+
+    model, tokenizer = load_model()
+
+    prompt = tokenizer(
+        args.prompt,
         return_tensors="np",
         return_attention_mask=False,
     )["input_ids"]
 
     prompt = mx.array(prompt)
 
-    tokens_per_eval = 1
-    max_tokens = 100
-
     tokens = []
-    for token, _ in zip(generate(prompt, model), range(max_tokens)):
+    for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
         tokens.append(token)
 
-        if (len(tokens) % tokens_per_eval) == 0:
+        if (len(tokens) % args.tokens_per_eval) == 0:
             mx.eval(tokens)
             s = tokenizer.decode([t.item() for t in tokens])
             print(s, end="", flush=True)
 
             tokens = []
-
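
Note on the fp16 pattern in this patch: the weights and activations stay in
float16, while the numerically sensitive reductions (the LayerNorm statistics
and the attention softmax) run in float32 and are cast back afterwards. A
minimal, self-contained sketch of the same pattern, where the shape, the
random input, and the print are illustrative only:

    import mlx.core as mx
    import mlx.nn as nn

    class LayerNorm(nn.LayerNorm):
        # Upcast to float32 for the mean/variance reduction, then cast the
        # result back so the surrounding network stays in the input dtype.
        def __call__(self, x: mx.array) -> mx.array:
            return super().__call__(x.astype(mx.float32)).astype(x.dtype)

    x = mx.random.normal((1, 8, 64)).astype(mx.float16)
    y = LayerNorm(64)(x)
    print(y.dtype)  # float16: the float32 math is internal to the layer

The attention path applies the same idea in two steps: queries and keys are
upcast to float32 before the score matmul, and the softmax output is cast back
to values.dtype before it multiplies the float16 values.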
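
With the new argparse flags, the example is driven from the command line
instead of the old hard-coded constants. Assuming the converted weights.npz
sits in the working directory (which is what load_model() expects), an
invocation that spells out the defaults looks like:

    python phi2/model.py --prompt "Write a detailed analogy between mathematics and a lighthouse." \
        --max_tokens 100 --temp 0.0 --seed 0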