From a466cc51917270308cc6376909dd8d7a3598cf85 Mon Sep 17 00:00:00 2001 From: Joe Barrow Date: Wed, 13 Dec 2023 22:22:56 -0500 Subject: [PATCH 1/8] phi-2 draft --- phi2/README.md | 24 +++++ phi2/__init__.py | 0 phi2/convert.py | 67 ++++++++++++ phi2/hf_model.py | 23 +++++ phi2/model.py | 232 ++++++++++++++++++++++++++++++++++++++++++ phi2/phi2_outputs.txt | 63 ++++++++++++ 6 files changed, 409 insertions(+) create mode 100644 phi2/README.md create mode 100644 phi2/__init__.py create mode 100644 phi2/convert.py create mode 100644 phi2/hf_model.py create mode 100644 phi2/model.py create mode 100644 phi2/phi2_outputs.txt diff --git a/phi2/README.md b/phi2/README.md new file mode 100644 index 00000000..c38f8a74 --- /dev/null +++ b/phi2/README.md @@ -0,0 +1,24 @@ +# Phi-2 + +Phi-2 is a 2.7B parameter model released by Microsoft and trained on a mixture of GPT-4 outputs and clean web-text. +Its performance theoretically rivals much, much stronger models. + +## Downloading and Converting Weights + +To download and convert the model: + +```sh +python phi2/convert.py +``` + +That will fill in `weights/phi-2.npz`. + +## Running the Model + +🚧 (Not yet done) To run the model: + +```sh +python phi2/generate.py +``` + +Layer-by-layer forward pass outputs are currently shown in the outputs.txt files. diff --git a/phi2/__init__.py b/phi2/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/phi2/convert.py b/phi2/convert.py new file mode 100644 index 00000000..cd2f77aa --- /dev/null +++ b/phi2/convert.py @@ -0,0 +1,67 @@ +from transformers import AutoModelForCausalLM + +import numpy + + +def split_attention_matrix(state_dict, key) -> dict: + # "transformer.h.0.mixer" + _, model_dim = state_dict[key + ".weight"].shape + # (3 * model_dim, model_dim) + Wqkv_weight_key = key + ".weight" + Wq_weight = state_dict[Wqkv_weight_key][:model_dim, :] + Wk_weight = state_dict[Wqkv_weight_key][model_dim : 2 * model_dim, :] + Wv_weight = state_dict[Wqkv_weight_key][2 * model_dim :, :] + + # (3 * model_dim) + Wqkv_bias_key = key + ".bias" + Wq_bias = state_dict[Wqkv_bias_key][:model_dim] + Wk_bias = state_dict[Wqkv_bias_key][model_dim : 2 * model_dim] + Wv_bias = state_dict[Wqkv_bias_key][2 * model_dim :] + + out_key = key.replace("mixer.Wqkv", "self_attention") + + return { + out_key + ".query_proj.weight": Wq_weight, + out_key + ".query_proj.bias": Wq_bias, + out_key + ".key_proj.weight": Wk_weight, + out_key + ".key_proj.bias": Wk_bias, + out_key + ".value_proj.weight": Wv_weight, + out_key + ".value_proj.bias": Wv_bias, + } + + +def replace_key(key: str) -> str: + if "wte.weight" in key: + key = "wte.weight" + + if ".mlp" in key: + key = key.replace(".mlp", "") + + if ".mixer.out_proj" in key: + key = key.replace(".mixer", ".self_attention") + + return key + + +def convert(): + model = AutoModelForCausalLM.from_pretrained( + "microsoft/phi-2", torch_dtype="auto", trust_remote_code=True + ) + state_dict = model.state_dict() + keys = list(state_dict.keys()) + + for key in keys: + if ".mixer.Wqkv.weight" not in key: + continue + key_stub = key.rstrip(".weight") + state_dict.update(split_attention_matrix(state_dict, key_stub)) + + del state_dict[key_stub + ".weight"] + del state_dict[key_stub + ".bias"] + + weights = {replace_key(k): v.numpy() for k, v in state_dict.items()} + numpy.savez("weights/phi-2.npz", **weights) + + +if __name__ == "__main__": + convert() diff --git a/phi2/hf_model.py b/phi2/hf_model.py new file mode 100644 index 00000000..d09ff108 --- /dev/null +++ b/phi2/hf_model.py @@ -0,0 +1,23 @@ +from transformers import AutoModelForCausalLM, AutoTokenizer + + +if __name__ == "__main__": + model = AutoModelForCausalLM.from_pretrained( + "microsoft/phi-2", torch_dtype="auto", trust_remote_code=True + ) + tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True) + + inputs = tokenizer( + '''def print_prime(n): + """ + Print all primes between 1 and n + """''', + return_tensors="pt", + return_attention_mask=False, + ) + + print(model(**inputs)) + + # outputs = model.generate(**inputs, max_length=200) + # text = tokenizer.batch_decode(outputs)[0] + # print(text) diff --git a/phi2/model.py b/phi2/model.py new file mode 100644 index 00000000..991bf193 --- /dev/null +++ b/phi2/model.py @@ -0,0 +1,232 @@ +from typing import Optional +from dataclasses import dataclass +from mlx.utils import tree_unflatten, tree_map +from transformers import AutoTokenizer + +import mlx.core as mx +import mlx.nn as nn +import math + + +@dataclass +class ModelArgs: + max_sequence_length: int = 2048 + num_vocab: int = 51200 + model_dim: int = 2560 + num_heads: int = 32 + num_layers: int = 32 + rotary_dim: int = 32 + + +class NewGELUActivation(nn.Module): + """ + Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see + the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415 + """ + + def __call__(self, input: mx.array) -> mx.array: + return ( + 0.5 + * input + * ( + 1.0 + + mx.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * (input**3))) + ) + ) + + +class RoPEAttention(nn.Module): + def __init__(self, dims: int, num_heads: int, bias: bool = True): + super().__init__() + + self.num_heads = num_heads + + self.rope = nn.RoPE(dims // num_heads, traditional=True) + self.query_proj = nn.Linear(dims, dims, bias=bias) + self.key_proj = nn.Linear(dims, dims, bias=bias) + self.value_proj = nn.Linear(dims, dims, bias=bias) + self.out_proj = nn.Linear(dims, dims, bias=bias) + + def __call__(self, queries, keys, values, mask=None, cache=None): + queries = self.query_proj(queries) + keys = self.key_proj(keys) + values = self.value_proj(values) + + # Extract some shapes + num_heads = self.num_heads + B, L, D = queries.shape + + # Prepare the queries, keys and values for the attention computation + queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) + + # Add RoPE to the queries and keys and combine them with the cache + if cache is not None: + key_cache, value_cache = cache + queries = self.rope(queries, offset=key_cache.shape[2]) + keys = self.rope(keys, offset=key_cache.shape[2]) + keys = mx.concatenate([key_cache, keys], axis=2) + values = mx.concatenate([value_cache, values], axis=2) + else: + queries = self.rope(queries) + keys = self.rope(keys) + + # Finally perform the attention computation + scale = math.sqrt(1 / queries.shape[-1]) + scores = (queries * scale) @ keys.transpose(0, 1, 3, 2) + if mask is not None: + scores = scores + mask + scores = mx.softmax(scores, axis=-1) + values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) + + # Note that we return the keys and values to possibly be used as a cache + return self.out_proj(values_hat), (keys, values) + + +class ParallelBlock(nn.Module): + def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None): + super().__init__() + mlp_dims = mlp_dims or dims * 4 + self.self_attention = RoPEAttention(dims, num_heads, bias=True) + self.ln = nn.LayerNorm(dims) + self.fc1 = nn.Linear(dims, mlp_dims) + self.fc2 = nn.Linear(mlp_dims, dims) + self.act = NewGELUActivation() + + def __call__(self, x, x_mask): + residual = x + hidden_states = self.ln(x) + attn_outputs, _ = self.self_attention( + hidden_states, hidden_states, hidden_states, x_mask + ) + ff_hidden_states = self.fc2(self.act(self.fc1(hidden_states))) + + hidden_states = attn_outputs + ff_hidden_states + residual + + return hidden_states + + +class TransformerDecoder(nn.Module): + def __init__( + self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None + ): + super().__init__() + self.h = [ParallelBlock(dims, num_heads, mlp_dims) for i in range(num_layers)] + + def __call__(self, x, x_mask): + for layer in self.h: + x = layer(x, x_mask) + return x + + +class Phi2(nn.Module): + def __init__(self, config: ModelArgs): + self.wte = nn.Embedding(config.num_vocab, config.model_dim) + self.transformer = TransformerDecoder( + num_layers=config.num_layers, + dims=config.model_dim, + num_heads=config.num_heads, + ) + + self.lm_head = LanguageModelingHead(config) + + def __call__( + self, + input_ids: mx.array, + attention_mask: mx.array = None, + ) -> tuple[mx.array, mx.array]: + x = self.wte(input_ids) + + if attention_mask is not None: + # convert 0's to -infs, 1's to 0's, and make it broadcastable + attention_mask = mx.log(attention_mask) + attention_mask = mx.expand_dims(attention_mask, (1, 2)) + else: + attention_mask = nn.MultiHeadAttention.create_additive_causal_mask( + x.shape[1] + ) + + y = self.transformer(x, attention_mask) + return self.lm_head(y) + + def generate(self, input_ids, temp=1.0): + cache = input_ids.tolist() + + # Make an additive causal mask. We will need that to process the prompt. + mask = nn.MultiHeadAttention.create_additive_causal_mask(input_ids.shape[1]) + mask = mask.astype(self.wte.weight.dtype) + + # First we process the prompt x the same way as in __call__ but + # save the caches in cache + x = self.wte(input_ids) + # for l in self.layers: + # x, c = l(x, mask=mask) + # cache.append(c) # <--- we store the per layer cache in a + # simple python list + x = self.transformer(x, mask) + y = self.lm_head(x[:, -1]) # <--- we only care about the last logits + # that generate the next token + y = mx.random.categorical(y * (1 / temp)) + + # y now has size [1] + # Since MLX is lazily evaluated nothing is computed yet. + # Calling y.item() would force the computation to happen at + # this point but we can also choose not to do that and let the + # user choose when to start the computation. + yield y + cache += [y.item()] + + # Now we parsed the prompt and generated the first token we + # need to feed it back into the model and loop to generate the + # rest. + while True: + # Unsqueezing the last dimension to add a sequence length + # dimension of 1 + x = self.wte(mx.array(cache)) + x = self.transformer(x, mask) + y = self.lm_head(x[:, -1]) + y = mx.random.categorical(y * (1 / temp)) + cache += [y[0].item()] + + yield y + + +class LanguageModelingHead(nn.Module): + def __init__(self, config: ModelArgs) -> None: + self.ln = nn.LayerNorm(config.model_dim) + self.linear = nn.Linear(config.model_dim, config.num_vocab) + + def __call__(self, inputs): + return self.linear(self.ln(inputs)) + + +if __name__ == "__main__": + model = Phi2(ModelArgs()) + + weights = mx.load("weights/phi-2.npz") + weights = tree_unflatten(list(weights.items())) + weights = tree_map(lambda p: mx.array(p), weights) + + model.update(weights) + + tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True) + tokens = tokenizer( + '''def print_prime(n): + """ + Print all primes between 1 and n + """''', + return_tensors="np", + return_attention_mask=False, + ) + + tokens = {key: mx.array(v) for key, v in tokens.items()} + + print( + '''def print_prime(n): + """ + Print all primes between 1 and n + """''' + ) + for output in model.generate(**tokens): + print(tokenizer.decode(output.item())) diff --git a/phi2/phi2_outputs.txt b/phi2/phi2_outputs.txt new file mode 100644 index 00000000..4f27e44b --- /dev/null +++ b/phi2/phi2_outputs.txt @@ -0,0 +1,63 @@ +(HF) Output of Embeddings + +tensor([[[-0.0353, 0.0045, 0.0208, ..., -0.0117, 0.0041, 0.0075], + [-0.0172, 0.0236, -0.0051, ..., 0.0141, 0.0115, 0.0058], + [-0.0148, 0.0043, -0.0252, ..., 0.0179, 0.0025, -0.0008], + ..., + [ 0.0003, 0.0051, 0.0002, ..., 0.0043, 0.0075, 0.0049], + [-0.0110, 0.0472, 0.0030, ..., 0.0098, -0.0075, 0.0146], + [-0.0085, -0.0219, -0.0016, ..., -0.0059, 0.0109, -0.0016]]], + device='cuda:0', dtype=torch.float16, grad_fn=) + +(MLX) Output of Embeddings + +array([[[-0.0352783, 0.00445175, 0.020813, ..., -0.0117188, 0.00411606, 0.00748444], + [-0.0171509, 0.0236053, -0.00508881, ..., 0.0141144, 0.0115204, 0.00582504], + [-0.0147858, 0.00426102, -0.0252075, ..., 0.0179443, 0.0024662, -0.00076437], + ..., + [0.000337124, 0.00508499, 0.000193119, ..., 0.00427628, 0.00753403, 0.00492477], + [-0.0110092, 0.0472107, 0.00295448, ..., 0.00982666, -0.00747681, 0.0145721], + [-0.00852203, -0.0218964, -0.00161839, ..., -0.00592422, 0.0108643, -0.00162697]]], dtype=float16) + +(HF) Output of First Attention Layer + +tensor([[[-0.2000, 0.4849, 0.9863, ..., -0.2209, 0.1355, 0.3469], + [ 0.4922, -0.3865, 0.8428, ..., 0.5894, -0.0069, -0.5278], + [ 0.0902, 0.1028, 0.6826, ..., 0.1394, -0.8145, -0.1880], + ..., + [ 0.2380, 0.0555, -0.3005, ..., 0.0372, -0.0895, 0.0255], + [ 0.2512, 0.1949, 0.3401, ..., 0.3625, -0.3103, -0.1064], + [-0.0905, 0.0665, 0.5210, ..., -0.0767, -0.2460, -0.1449]]], + device='cuda:0', dtype=torch.float16, grad_fn=) +torch.Size([1, 23, 2560]) + +(MLX) Output of First Attention Layer + +array([[[-0.199973, 0.485224, 0.987237, ..., -0.220847, 0.13511, 0.346074], + [0.44883, -0.271683, 0.877478, ..., 0.653217, -0.0929724, -0.711176], + [-0.233398, 5.7824e-05, 0.435001, ..., 0.0504494, -0.623998, -0.438785], + ..., + [0.123587, -0.237459, -0.447518, ..., 0.0653363, -0.0767153, -0.341505], + [0.187798, 0.331209, 0.0827338, ..., 0.529453, -0.582141, -0.165316], + [-0.413614, 0.134572, 0.685769, ..., 0.0796088, 0.0217719, -0.118885]]], dtype=float32) +[1, 23, 2560] + +(HF) Overall Output of Inputs: + +tensor([[[ 6.4688, 5.1016, 1.9658, ..., -2.9043, -2.9043, -2.9043], + [ 5.2188, 6.4414, 5.1914, ..., -0.1852, -0.1862, -0.1866], + [ 4.3516, 5.3281, 5.9922, ..., -0.3689, -0.3699, -0.3696], + ..., + [10.4141, 11.7031, 12.5859, ..., 0.7778, 0.7769, 0.7754], + [10.7188, 11.7891, 13.3125, ..., 1.6123, 1.6113, 1.6104], + [10.8047, 12.0234, 12.4375, ..., 0.2321, 0.2314, 0.2317]]], + +(MLX) Overall Output of Inputs: + +array([[[6.46632, 5.10102, 1.96306, ..., -2.90427, -2.90341, -2.90392], + [4.5092, 5.90938, 4.98036, ..., -0.411165, -0.412062, -0.412547], + [4.34246, 5.7794, 6.13245, ..., -0.40106, -0.402052, -0.401838], + ..., + [6.61827, 10.4022, 12.1672, ..., 0.602787, 0.602138, 0.600666], + [7.96546, 12.9569, 14.7947, ..., -0.347764, -0.348587, -0.34937], + [8.22272, 10.6631, 11.5968, ..., -1.12037, -1.12025, -1.12152]]], dtype=float32) \ No newline at end of file From 88d7b67e6e8dee7a2c128d69223ac0f551aab7a6 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 13 Dec 2023 22:26:33 -0800 Subject: [PATCH 2/8] add cache + generation, clean up some stuff --- phi2/.gitignore | 1 + phi2/convert.py | 2 +- phi2/model.py | 177 ++++++++++++++++-------------------------- phi2/requirements.txt | 3 + 4 files changed, 70 insertions(+), 113 deletions(-) create mode 100644 phi2/.gitignore create mode 100644 phi2/requirements.txt diff --git a/phi2/.gitignore b/phi2/.gitignore new file mode 100644 index 00000000..258ec872 --- /dev/null +++ b/phi2/.gitignore @@ -0,0 +1 @@ +weights.npz diff --git a/phi2/convert.py b/phi2/convert.py index cd2f77aa..3c821f69 100644 --- a/phi2/convert.py +++ b/phi2/convert.py @@ -60,7 +60,7 @@ def convert(): del state_dict[key_stub + ".bias"] weights = {replace_key(k): v.numpy() for k, v in state_dict.items()} - numpy.savez("weights/phi-2.npz", **weights) + numpy.savez("weights.npz", **weights) if __name__ == "__main__": diff --git a/phi2/model.py b/phi2/model.py index 991bf193..5253a266 100644 --- a/phi2/model.py +++ b/phi2/model.py @@ -7,7 +7,6 @@ import mlx.core as mx import mlx.nn as nn import math - @dataclass class ModelArgs: max_sequence_length: int = 2048 @@ -18,23 +17,6 @@ class ModelArgs: rotary_dim: int = 32 -class NewGELUActivation(nn.Module): - """ - Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see - the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415 - """ - - def __call__(self, input: mx.array) -> mx.array: - return ( - 0.5 - * input - * ( - 1.0 - + mx.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * (input**3))) - ) - ) - - class RoPEAttention(nn.Module): def __init__(self, dims: int, num_heads: int, bias: bool = True): super().__init__() @@ -77,6 +59,7 @@ class RoPEAttention(nn.Module): scores = (queries * scale) @ keys.transpose(0, 1, 3, 2) if mask is not None: scores = scores + mask + scores = mx.softmax(scores, axis=-1) values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) @@ -92,19 +75,13 @@ class ParallelBlock(nn.Module): self.ln = nn.LayerNorm(dims) self.fc1 = nn.Linear(dims, mlp_dims) self.fc2 = nn.Linear(mlp_dims, dims) - self.act = NewGELUActivation() + self.act = nn.GELU(approx="precise") - def __call__(self, x, x_mask): - residual = x - hidden_states = self.ln(x) - attn_outputs, _ = self.self_attention( - hidden_states, hidden_states, hidden_states, x_mask - ) - ff_hidden_states = self.fc2(self.act(self.fc1(hidden_states))) - - hidden_states = attn_outputs + ff_hidden_states + residual - - return hidden_states + def __call__(self, x, mask, cache): + h = self.ln(x) + attn_h, cache = self.self_attention(h, h, h, mask, cache) + ff_h = self.fc2(self.act(self.fc1(h))) + return attn_h + ff_h + x, cache class TransformerDecoder(nn.Module): @@ -114,10 +91,22 @@ class TransformerDecoder(nn.Module): super().__init__() self.h = [ParallelBlock(dims, num_heads, mlp_dims) for i in range(num_layers)] - def __call__(self, x, x_mask): - for layer in self.h: - x = layer(x, x_mask) - return x + def __call__(self, x, mask, cache): + if cache is None: + cache = [None] * len(self.h) + + for e, layer in enumerate(self.h): + x, cache[e] = layer(x, mask, cache[e]) + return x, cache + + +class OutputHead(nn.Module): + def __init__(self, config: ModelArgs) -> None: + self.ln = nn.LayerNorm(config.model_dim) + self.linear = nn.Linear(config.model_dim, config.num_vocab) + + def __call__(self, inputs): + return self.linear(self.ln(inputs)) class Phi2(nn.Module): @@ -128,77 +117,40 @@ class Phi2(nn.Module): dims=config.model_dim, num_heads=config.num_heads, ) - - self.lm_head = LanguageModelingHead(config) + self.lm_head = OutputHead(config) def __call__( self, - input_ids: mx.array, - attention_mask: mx.array = None, + inputs: mx.array, + mask: mx.array = None, + cache: mx.array = None, ) -> tuple[mx.array, mx.array]: - x = self.wte(input_ids) + x = self.wte(inputs) - if attention_mask is not None: - # convert 0's to -infs, 1's to 0's, and make it broadcastable - attention_mask = mx.log(attention_mask) - attention_mask = mx.expand_dims(attention_mask, (1, 2)) + mask = None + if x.shape[1] > 1: + mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1]) + mask = mask.astype(x.dtype) + + y, cache = self.transformer(x, mask, cache) + return self.lm_head(y), cache + + +def generate(prompt: mx.array, model: Phi2, temp: Optional[float] = 0.0): + def sample(logits): + if temp == 0: + return mx.argmax(logits, axis=-1) else: - attention_mask = nn.MultiHeadAttention.create_additive_causal_mask( - x.shape[1] - ) + return mx.random.categorical(logits * (1 / temp)) - y = self.transformer(x, attention_mask) - return self.lm_head(y) + logits, cache = model(prompt) + y = sample(logits[:, -1, :]) + yield y - def generate(self, input_ids, temp=1.0): - cache = input_ids.tolist() - - # Make an additive causal mask. We will need that to process the prompt. - mask = nn.MultiHeadAttention.create_additive_causal_mask(input_ids.shape[1]) - mask = mask.astype(self.wte.weight.dtype) - - # First we process the prompt x the same way as in __call__ but - # save the caches in cache - x = self.wte(input_ids) - # for l in self.layers: - # x, c = l(x, mask=mask) - # cache.append(c) # <--- we store the per layer cache in a - # simple python list - x = self.transformer(x, mask) - y = self.lm_head(x[:, -1]) # <--- we only care about the last logits - # that generate the next token - y = mx.random.categorical(y * (1 / temp)) - - # y now has size [1] - # Since MLX is lazily evaluated nothing is computed yet. - # Calling y.item() would force the computation to happen at - # this point but we can also choose not to do that and let the - # user choose when to start the computation. + while True: + logits, cache = model(y[:, None], cache=cache) + y = sample(logits.squeeze(1)) yield y - cache += [y.item()] - - # Now we parsed the prompt and generated the first token we - # need to feed it back into the model and loop to generate the - # rest. - while True: - # Unsqueezing the last dimension to add a sequence length - # dimension of 1 - x = self.wte(mx.array(cache)) - x = self.transformer(x, mask) - y = self.lm_head(x[:, -1]) - y = mx.random.categorical(y * (1 / temp)) - cache += [y[0].item()] - - yield y - - -class LanguageModelingHead(nn.Module): - def __init__(self, config: ModelArgs) -> None: - self.ln = nn.LayerNorm(config.model_dim) - self.linear = nn.Linear(config.model_dim, config.num_vocab) - - def __call__(self, inputs): - return self.linear(self.ln(inputs)) if __name__ == "__main__": @@ -206,27 +158,28 @@ if __name__ == "__main__": weights = mx.load("weights/phi-2.npz") weights = tree_unflatten(list(weights.items())) - weights = tree_map(lambda p: mx.array(p), weights) + weights = tree_map(lambda p: mx.array(p, mx.float32), weights) model.update(weights) tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True) - tokens = tokenizer( - '''def print_prime(n): - """ - Print all primes between 1 and n - """''', + prompt = tokenizer("Write a detailed analogy between mathematics and a lighthouse.", return_tensors="np", return_attention_mask=False, - ) + )["input_ids"] - tokens = {key: mx.array(v) for key, v in tokens.items()} + prompt = mx.array(prompt) + + tokens_per_eval = 1 + max_tokens = 100 + + tokens = [] + for token, _ in zip(generate(prompt, model), range(max_tokens)): + tokens.append(token) + + if (len(tokens) % tokens_per_eval) == 0: + mx.eval(tokens) + s = tokenizer.decode([t.item() for t in tokens]) + print(s, end="", flush=True) + tokens = [] - print( - '''def print_prime(n): - """ - Print all primes between 1 and n - """''' - ) - for output in model.generate(**tokens): - print(tokenizer.decode(output.item())) diff --git a/phi2/requirements.txt b/phi2/requirements.txt new file mode 100644 index 00000000..6a11f8d2 --- /dev/null +++ b/phi2/requirements.txt @@ -0,0 +1,3 @@ +einops +mlx +transformers From a8d41491472ffb67081f32ecc4853de7ba1c367c Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 14 Dec 2023 08:08:28 -0800 Subject: [PATCH 3/8] fix fp16 + nits --- phi2/model.py | 97 +++++++++++++++++++++++++++++++++------------------ 1 file changed, 63 insertions(+), 34 deletions(-) diff --git a/phi2/model.py b/phi2/model.py index 5253a266..52bda27e 100644 --- a/phi2/model.py +++ b/phi2/model.py @@ -1,12 +1,14 @@ +import argparse from typing import Optional from dataclasses import dataclass -from mlx.utils import tree_unflatten, tree_map +from mlx.utils import tree_unflatten from transformers import AutoTokenizer import mlx.core as mx import mlx.nn as nn import math + @dataclass class ModelArgs: max_sequence_length: int = 2048 @@ -17,17 +19,22 @@ class ModelArgs: rotary_dim: int = 32 +class LayerNorm(nn.LayerNorm): + def __call__(self, x: mx.array) -> mx.array: + return super().__call__(x.astype(mx.float32)).astype(x.dtype) + + class RoPEAttention(nn.Module): - def __init__(self, dims: int, num_heads: int, bias: bool = True): + def __init__(self, dims: int, num_heads: int, rotary_dim: int): super().__init__() self.num_heads = num_heads - self.rope = nn.RoPE(dims // num_heads, traditional=True) - self.query_proj = nn.Linear(dims, dims, bias=bias) - self.key_proj = nn.Linear(dims, dims, bias=bias) - self.value_proj = nn.Linear(dims, dims, bias=bias) - self.out_proj = nn.Linear(dims, dims, bias=bias) + self.rope = nn.RoPE(rotary_dim, traditional=False) + self.query_proj = nn.Linear(dims, dims) + self.key_proj = nn.Linear(dims, dims) + self.value_proj = nn.Linear(dims, dims) + self.out_proj = nn.Linear(dims, dims) def __call__(self, queries, keys, values, mask=None, cache=None): queries = self.query_proj(queries) @@ -54,25 +61,28 @@ class RoPEAttention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) + queries = queries.astype(mx.float32) + keys = keys.astype(mx.float32) + # Finally perform the attention computation scale = math.sqrt(1 / queries.shape[-1]) scores = (queries * scale) @ keys.transpose(0, 1, 3, 2) if mask is not None: scores = scores + mask - scores = mx.softmax(scores, axis=-1) + scores = mx.softmax(scores, axis=-1).astype(values.dtype) values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) - # Note that we return the keys and values to possibly be used as a cache return self.out_proj(values_hat), (keys, values) class ParallelBlock(nn.Module): - def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None): + def __init__(self, config: ModelArgs): super().__init__() - mlp_dims = mlp_dims or dims * 4 - self.self_attention = RoPEAttention(dims, num_heads, bias=True) - self.ln = nn.LayerNorm(dims) + dims = config.model_dim + mlp_dims = dims * 4 + self.self_attention = RoPEAttention(dims, config.num_heads, config.rotary_dim) + self.ln = LayerNorm(dims) self.fc1 = nn.Linear(dims, mlp_dims) self.fc2 = nn.Linear(mlp_dims, dims) self.act = nn.GELU(approx="precise") @@ -85,11 +95,9 @@ class ParallelBlock(nn.Module): class TransformerDecoder(nn.Module): - def __init__( - self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None - ): + def __init__(self, config: ModelArgs): super().__init__() - self.h = [ParallelBlock(dims, num_heads, mlp_dims) for i in range(num_layers)] + self.h = [ParallelBlock(config) for i in range(config.num_layers)] def __call__(self, x, mask, cache): if cache is None: @@ -102,7 +110,7 @@ class TransformerDecoder(nn.Module): class OutputHead(nn.Module): def __init__(self, config: ModelArgs) -> None: - self.ln = nn.LayerNorm(config.model_dim) + self.ln = LayerNorm(config.model_dim) self.linear = nn.Linear(config.model_dim, config.num_vocab) def __call__(self, inputs): @@ -112,11 +120,7 @@ class OutputHead(nn.Module): class Phi2(nn.Module): def __init__(self, config: ModelArgs): self.wte = nn.Embedding(config.num_vocab, config.model_dim) - self.transformer = TransformerDecoder( - num_layers=config.num_layers, - dims=config.model_dim, - num_heads=config.num_heads, - ) + self.transformer = TransformerDecoder(config) self.lm_head = OutputHead(config) def __call__( @@ -153,33 +157,58 @@ def generate(prompt: mx.array, model: Phi2, temp: Optional[float] = 0.0): yield y -if __name__ == "__main__": +def load_model(): model = Phi2(ModelArgs()) - weights = mx.load("weights/phi-2.npz") + weights = mx.load("weights.npz") weights = tree_unflatten(list(weights.items())) - weights = tree_map(lambda p: mx.array(p, mx.float32), weights) - model.update(weights) tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True) - prompt = tokenizer("Write a detailed analogy between mathematics and a lighthouse.", + return model, tokenizer + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Phi-2 inference script") + parser.add_argument( + "--prompt", + help="The message to be processed by the model", + default="Write a detailed analogy between mathematics and a lighthouse.", + ) + parser.add_argument( + "--max_tokens", + "-m", + type=int, + default=100, + help="Maximum number of tokens to generate", + ) + parser.add_argument( + "--temp", + help="The sampling temperature.", + type=float, + default=0.0, + ) + parser.add_argument("--seed", type=int, default=0, help="The PRNG seed") + args = parser.parse_args() + + mx.random.seed(args.seed) + + model, tokenizer = load_model() + + prompt = tokenizer( + args.prompt, return_tensors="np", return_attention_mask=False, )["input_ids"] prompt = mx.array(prompt) - tokens_per_eval = 1 - max_tokens = 100 - tokens = [] - for token, _ in zip(generate(prompt, model), range(max_tokens)): + for token, _ in zip(generate(prompt, model), range(args.max_tokens)): tokens.append(token) - if (len(tokens) % tokens_per_eval) == 0: + if (len(tokens) % args.tokens_per_eval) == 0: mx.eval(tokens) s = tokenizer.decode([t.item() for t in tokens]) print(s, end="", flush=True) tokens = [] - From 1613e608a90c80d96055acb0455258235cd31d3a Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 14 Dec 2023 08:18:01 -0800 Subject: [PATCH 4/8] fix args, update README, remove extra files --- phi2/README.md | 44 +++++++++++++++++++++++------- phi2/hf_model.py | 23 ---------------- phi2/model.py | 5 +++- phi2/phi2_outputs.txt | 63 ------------------------------------------- 4 files changed, 38 insertions(+), 97 deletions(-) delete mode 100644 phi2/hf_model.py delete mode 100644 phi2/phi2_outputs.txt diff --git a/phi2/README.md b/phi2/README.md index c38f8a74..46a7c589 100644 --- a/phi2/README.md +++ b/phi2/README.md @@ -1,24 +1,48 @@ # Phi-2 -Phi-2 is a 2.7B parameter model released by Microsoft and trained on a mixture of GPT-4 outputs and clean web-text. -Its performance theoretically rivals much, much stronger models. +Phi-2 is a 2.7B parameter model released by Microsoft[^1] and trained on a mixture +of GPT-4 outputs and clean web-text. Its performance rivals +much, much stronger models. -## Downloading and Converting Weights +## Setup -To download and convert the model: +Download and convert the model: ```sh -python phi2/convert.py +python convert.py ``` -That will fill in `weights/phi-2.npz`. +which will make a file `weights.npz`. -## Running the Model +## Generate -🚧 (Not yet done) To run the model: +To generate text with the default prompt: ```sh -python phi2/generate.py +python model.py ``` -Layer-by-layer forward pass outputs are currently shown in the outputs.txt files. +Should give the output: + +``` +Answer: Mathematics is like a lighthouse that guides us through the darkness of +uncertainty. Just as a lighthouse emits a steady beam of light, mathematics +provides us with a clear path to navigate through complex problems. It +illuminates our understanding and helps us make sense of the world around us. + +Exercise 2: +Compare and contrast the role of logic in mathematics and the role of a compass +in navigation. + +Answer: Logic in mathematics is like a compass in navigation. It helps +``` + +To use your own prompt: + +```sh +python model.py --prompt --max_tokens +``` + +[^1]: For more details on the model see the [blog post]( +https://www.microsoft.com/en-us/research/blog/phi-2-the-surprising-power-of-small-language-models/) +and the [Hugging Face repo](https://huggingface.co/microsoft/phi-2) diff --git a/phi2/hf_model.py b/phi2/hf_model.py deleted file mode 100644 index d09ff108..00000000 --- a/phi2/hf_model.py +++ /dev/null @@ -1,23 +0,0 @@ -from transformers import AutoModelForCausalLM, AutoTokenizer - - -if __name__ == "__main__": - model = AutoModelForCausalLM.from_pretrained( - "microsoft/phi-2", torch_dtype="auto", trust_remote_code=True - ) - tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True) - - inputs = tokenizer( - '''def print_prime(n): - """ - Print all primes between 1 and n - """''', - return_tensors="pt", - return_attention_mask=False, - ) - - print(model(**inputs)) - - # outputs = model.generate(**inputs, max_length=200) - # text = tokenizer.batch_decode(outputs)[0] - # print(text) diff --git a/phi2/model.py b/phi2/model.py index 52bda27e..a99d3d5d 100644 --- a/phi2/model.py +++ b/phi2/model.py @@ -203,11 +203,14 @@ if __name__ == "__main__": prompt = mx.array(prompt) + print("[INFO] Generating with Phi-2...", flush=True) + print(args.prompt, end="", flush=True) + tokens = [] for token, _ in zip(generate(prompt, model), range(args.max_tokens)): tokens.append(token) - if (len(tokens) % args.tokens_per_eval) == 0: + if (len(tokens) % 10) == 0: mx.eval(tokens) s = tokenizer.decode([t.item() for t in tokens]) print(s, end="", flush=True) diff --git a/phi2/phi2_outputs.txt b/phi2/phi2_outputs.txt deleted file mode 100644 index 4f27e44b..00000000 --- a/phi2/phi2_outputs.txt +++ /dev/null @@ -1,63 +0,0 @@ -(HF) Output of Embeddings - -tensor([[[-0.0353, 0.0045, 0.0208, ..., -0.0117, 0.0041, 0.0075], - [-0.0172, 0.0236, -0.0051, ..., 0.0141, 0.0115, 0.0058], - [-0.0148, 0.0043, -0.0252, ..., 0.0179, 0.0025, -0.0008], - ..., - [ 0.0003, 0.0051, 0.0002, ..., 0.0043, 0.0075, 0.0049], - [-0.0110, 0.0472, 0.0030, ..., 0.0098, -0.0075, 0.0146], - [-0.0085, -0.0219, -0.0016, ..., -0.0059, 0.0109, -0.0016]]], - device='cuda:0', dtype=torch.float16, grad_fn=) - -(MLX) Output of Embeddings - -array([[[-0.0352783, 0.00445175, 0.020813, ..., -0.0117188, 0.00411606, 0.00748444], - [-0.0171509, 0.0236053, -0.00508881, ..., 0.0141144, 0.0115204, 0.00582504], - [-0.0147858, 0.00426102, -0.0252075, ..., 0.0179443, 0.0024662, -0.00076437], - ..., - [0.000337124, 0.00508499, 0.000193119, ..., 0.00427628, 0.00753403, 0.00492477], - [-0.0110092, 0.0472107, 0.00295448, ..., 0.00982666, -0.00747681, 0.0145721], - [-0.00852203, -0.0218964, -0.00161839, ..., -0.00592422, 0.0108643, -0.00162697]]], dtype=float16) - -(HF) Output of First Attention Layer - -tensor([[[-0.2000, 0.4849, 0.9863, ..., -0.2209, 0.1355, 0.3469], - [ 0.4922, -0.3865, 0.8428, ..., 0.5894, -0.0069, -0.5278], - [ 0.0902, 0.1028, 0.6826, ..., 0.1394, -0.8145, -0.1880], - ..., - [ 0.2380, 0.0555, -0.3005, ..., 0.0372, -0.0895, 0.0255], - [ 0.2512, 0.1949, 0.3401, ..., 0.3625, -0.3103, -0.1064], - [-0.0905, 0.0665, 0.5210, ..., -0.0767, -0.2460, -0.1449]]], - device='cuda:0', dtype=torch.float16, grad_fn=) -torch.Size([1, 23, 2560]) - -(MLX) Output of First Attention Layer - -array([[[-0.199973, 0.485224, 0.987237, ..., -0.220847, 0.13511, 0.346074], - [0.44883, -0.271683, 0.877478, ..., 0.653217, -0.0929724, -0.711176], - [-0.233398, 5.7824e-05, 0.435001, ..., 0.0504494, -0.623998, -0.438785], - ..., - [0.123587, -0.237459, -0.447518, ..., 0.0653363, -0.0767153, -0.341505], - [0.187798, 0.331209, 0.0827338, ..., 0.529453, -0.582141, -0.165316], - [-0.413614, 0.134572, 0.685769, ..., 0.0796088, 0.0217719, -0.118885]]], dtype=float32) -[1, 23, 2560] - -(HF) Overall Output of Inputs: - -tensor([[[ 6.4688, 5.1016, 1.9658, ..., -2.9043, -2.9043, -2.9043], - [ 5.2188, 6.4414, 5.1914, ..., -0.1852, -0.1862, -0.1866], - [ 4.3516, 5.3281, 5.9922, ..., -0.3689, -0.3699, -0.3696], - ..., - [10.4141, 11.7031, 12.5859, ..., 0.7778, 0.7769, 0.7754], - [10.7188, 11.7891, 13.3125, ..., 1.6123, 1.6113, 1.6104], - [10.8047, 12.0234, 12.4375, ..., 0.2321, 0.2314, 0.2317]]], - -(MLX) Overall Output of Inputs: - -array([[[6.46632, 5.10102, 1.96306, ..., -2.90427, -2.90341, -2.90392], - [4.5092, 5.90938, 4.98036, ..., -0.411165, -0.412062, -0.412547], - [4.34246, 5.7794, 6.13245, ..., -0.40106, -0.402052, -0.401838], - ..., - [6.61827, 10.4022, 12.1672, ..., 0.602787, 0.602138, 0.600666], - [7.96546, 12.9569, 14.7947, ..., -0.347764, -0.348587, -0.34937], - [8.22272, 10.6631, 11.5968, ..., -1.12037, -1.12025, -1.12152]]], dtype=float32) \ No newline at end of file From 840c0c36c29baec53449100883183789310a2ae1 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 14 Dec 2023 08:27:44 -0800 Subject: [PATCH 5/8] don't drop last tokens --- phi2/model.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/phi2/model.py b/phi2/model.py index a99d3d5d..38199c6c 100644 --- a/phi2/model.py +++ b/phi2/model.py @@ -159,11 +159,8 @@ def generate(prompt: mx.array, model: Phi2, temp: Optional[float] = 0.0): def load_model(): model = Phi2(ModelArgs()) - weights = mx.load("weights.npz") - weights = tree_unflatten(list(weights.items())) - model.update(weights) - + model.update(tree_unflatten(list(weights.items()))) tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True) return model, tokenizer @@ -215,3 +212,7 @@ if __name__ == "__main__": s = tokenizer.decode([t.item() for t in tokens]) print(s, end="", flush=True) tokens = [] + + mx.eval(tokens) + s = tokenizer.decode([t.item() for t in tokens]) + print(s, flush=True) From 3d2a23184a3530fa277067148a811b759675e6d8 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 14 Dec 2023 08:34:24 -0800 Subject: [PATCH 6/8] change file name for consistency, update readme. --- phi2/README.md | 17 ++++++++++++----- phi2/{model.py => phi2.py} | 0 2 files changed, 12 insertions(+), 5 deletions(-) rename phi2/{model.py => phi2.py} (100%) diff --git a/phi2/README.md b/phi2/README.md index 46a7c589..aef47cd1 100644 --- a/phi2/README.md +++ b/phi2/README.md @@ -1,8 +1,9 @@ # Phi-2 Phi-2 is a 2.7B parameter model released by Microsoft[^1] and trained on a mixture -of GPT-4 outputs and clean web-text. Its performance rivals -much, much stronger models. +of GPT-4 outputs and clean web-text. Its performance rivals much larger models. + +Phi-2 efficiently runs on an Apple silicon device with 8 GB memory in 16-bit precision. ## Setup @@ -12,14 +13,14 @@ Download and convert the model: python convert.py ``` -which will make a file `weights.npz`. +This will make the `weights.npz` file which MLX can read. ## Generate To generate text with the default prompt: ```sh -python model.py +python phi2.py ``` Should give the output: @@ -40,7 +41,13 @@ Answer: Logic in mathematics is like a compass in navigation. It helps To use your own prompt: ```sh -python model.py --prompt --max_tokens +python phi2.py --prompt --max_tokens +``` + +To see a list of options run: + +```sh +python phi2.py --help ``` [^1]: For more details on the model see the [blog post]( diff --git a/phi2/model.py b/phi2/phi2.py similarity index 100% rename from phi2/model.py rename to phi2/phi2.py From 0c1c500714aef1e05d3b1e032dda48667216fdd3 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 14 Dec 2023 08:37:34 -0800 Subject: [PATCH 7/8] update readme --- phi2/README.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/phi2/README.md b/phi2/README.md index aef47cd1..198ac30c 100644 --- a/phi2/README.md +++ b/phi2/README.md @@ -1,9 +1,11 @@ # Phi-2 -Phi-2 is a 2.7B parameter model released by Microsoft[^1] and trained on a mixture -of GPT-4 outputs and clean web-text. Its performance rivals much larger models. +Phi-2 is a 2.7B parameter language model released by Microsoft[^1] with +performance that rivals much larger models. It was trained on a mixture of +GPT-4 outputs and clean web text. -Phi-2 efficiently runs on an Apple silicon device with 8 GB memory in 16-bit precision. +Phi-2 efficiently runs on Apple silicon devices with 8GB of memory in 16-bit +precision. ## Setup From 8f60d60814115659c1d9d6f911c7177a66e077e4 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 14 Dec 2023 09:19:44 -0800 Subject: [PATCH 8/8] cleanup conversion to use single qkv matrix --- phi2/README.md | 4 ++-- phi2/__init__.py | 0 phi2/convert.py | 48 ++----------------------------------------- phi2/phi2.py | 15 ++++++-------- phi2/requirements.txt | 1 + 5 files changed, 11 insertions(+), 57 deletions(-) delete mode 100644 phi2/__init__.py diff --git a/phi2/README.md b/phi2/README.md index 198ac30c..f5d80696 100644 --- a/phi2/README.md +++ b/phi2/README.md @@ -1,7 +1,7 @@ # Phi-2 -Phi-2 is a 2.7B parameter language model released by Microsoft[^1] with -performance that rivals much larger models. It was trained on a mixture of +Phi-2 is a 2.7B parameter language model released by Microsoft with +performance that rivals much larger models.[^1] It was trained on a mixture of GPT-4 outputs and clean web text. Phi-2 efficiently runs on Apple silicon devices with 8GB of memory in 16-bit diff --git a/phi2/__init__.py b/phi2/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/phi2/convert.py b/phi2/convert.py index 3c821f69..4c625a6e 100644 --- a/phi2/convert.py +++ b/phi2/convert.py @@ -1,34 +1,5 @@ from transformers import AutoModelForCausalLM - -import numpy - - -def split_attention_matrix(state_dict, key) -> dict: - # "transformer.h.0.mixer" - _, model_dim = state_dict[key + ".weight"].shape - # (3 * model_dim, model_dim) - Wqkv_weight_key = key + ".weight" - Wq_weight = state_dict[Wqkv_weight_key][:model_dim, :] - Wk_weight = state_dict[Wqkv_weight_key][model_dim : 2 * model_dim, :] - Wv_weight = state_dict[Wqkv_weight_key][2 * model_dim :, :] - - # (3 * model_dim) - Wqkv_bias_key = key + ".bias" - Wq_bias = state_dict[Wqkv_bias_key][:model_dim] - Wk_bias = state_dict[Wqkv_bias_key][model_dim : 2 * model_dim] - Wv_bias = state_dict[Wqkv_bias_key][2 * model_dim :] - - out_key = key.replace("mixer.Wqkv", "self_attention") - - return { - out_key + ".query_proj.weight": Wq_weight, - out_key + ".query_proj.bias": Wq_bias, - out_key + ".key_proj.weight": Wk_weight, - out_key + ".key_proj.bias": Wk_bias, - out_key + ".value_proj.weight": Wv_weight, - out_key + ".value_proj.bias": Wv_bias, - } - +import numpy as np def replace_key(key: str) -> str: if "wte.weight" in key: @@ -36,10 +7,6 @@ def replace_key(key: str) -> str: if ".mlp" in key: key = key.replace(".mlp", "") - - if ".mixer.out_proj" in key: - key = key.replace(".mixer", ".self_attention") - return key @@ -48,19 +15,8 @@ def convert(): "microsoft/phi-2", torch_dtype="auto", trust_remote_code=True ) state_dict = model.state_dict() - keys = list(state_dict.keys()) - - for key in keys: - if ".mixer.Wqkv.weight" not in key: - continue - key_stub = key.rstrip(".weight") - state_dict.update(split_attention_matrix(state_dict, key_stub)) - - del state_dict[key_stub + ".weight"] - del state_dict[key_stub + ".bias"] - weights = {replace_key(k): v.numpy() for k, v in state_dict.items()} - numpy.savez("weights.npz", **weights) + np.savez("weights.npz", **weights) if __name__ == "__main__": diff --git a/phi2/phi2.py b/phi2/phi2.py index 38199c6c..7973c33d 100644 --- a/phi2/phi2.py +++ b/phi2/phi2.py @@ -31,15 +31,12 @@ class RoPEAttention(nn.Module): self.num_heads = num_heads self.rope = nn.RoPE(rotary_dim, traditional=False) - self.query_proj = nn.Linear(dims, dims) - self.key_proj = nn.Linear(dims, dims) - self.value_proj = nn.Linear(dims, dims) + self.Wqkv = nn.Linear(dims, 3 * dims) self.out_proj = nn.Linear(dims, dims) - def __call__(self, queries, keys, values, mask=None, cache=None): - queries = self.query_proj(queries) - keys = self.key_proj(keys) - values = self.value_proj(values) + def __call__(self, x, mask=None, cache=None): + qkv = self.Wqkv(x) + queries, keys, values = mx.split(qkv, 3, axis=-1) # Extract some shapes num_heads = self.num_heads @@ -81,7 +78,7 @@ class ParallelBlock(nn.Module): super().__init__() dims = config.model_dim mlp_dims = dims * 4 - self.self_attention = RoPEAttention(dims, config.num_heads, config.rotary_dim) + self.mixer = RoPEAttention(dims, config.num_heads, config.rotary_dim) self.ln = LayerNorm(dims) self.fc1 = nn.Linear(dims, mlp_dims) self.fc2 = nn.Linear(mlp_dims, dims) @@ -89,7 +86,7 @@ class ParallelBlock(nn.Module): def __call__(self, x, mask, cache): h = self.ln(x) - attn_h, cache = self.self_attention(h, h, h, mask, cache) + attn_h, cache = self.mixer(h, mask, cache) ff_h = self.fc2(self.act(self.fc1(h))) return attn_h + ff_h + x, cache diff --git a/phi2/requirements.txt b/phi2/requirements.txt index 6a11f8d2..3e141ec3 100644 --- a/phi2/requirements.txt +++ b/phi2/requirements.txt @@ -1,3 +1,4 @@ einops mlx +numpy transformers