diff --git a/phi2/README.md b/phi2/README.md
new file mode 100644
index 00000000..c38f8a74
--- /dev/null
+++ b/phi2/README.md
@@ -0,0 +1,24 @@
+# Phi-2
+
+Phi-2 is a 2.7B parameter model released by Microsoft and trained on a mixture of GPT-4 outputs and clean web text.
+Despite its small size, its performance rivals that of models many times larger.
+
+## Downloading and Converting Weights
+
+To download the weights from Hugging Face and convert them to the format expected by `model.py`, run:
+
+```sh
+python phi2/convert.py
+```
+
+This writes the converted weights to `weights/phi-2.npz` (create the `weights/` directory first if it does not exist).
+
+## Running the Model
+
+🚧 Text generation is not finished yet. Once it is, the model will be run with:
+
+```sh
+python phi2/generate.py
+```
+
+For now, layer-by-layer forward-pass outputs from the MLX port and the Hugging Face reference are recorded in `phi2/phi2_outputs.txt`.
diff --git a/phi2/__init__.py b/phi2/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/phi2/convert.py b/phi2/convert.py
new file mode 100644
index 00000000..cd2f77aa
--- /dev/null
+++ b/phi2/convert.py
@@ -0,0 +1,67 @@
+from transformers import AutoModelForCausalLM
+
+import numpy
+
+
+def split_attention_matrix(state_dict, key) -> dict:
+    # key is e.g. "transformer.h.0.mixer.Wqkv"
+    _, model_dim = state_dict[key + ".weight"].shape
+    # (3 * model_dim, model_dim)
+    Wqkv_weight_key = key + ".weight"
+    Wq_weight = state_dict[Wqkv_weight_key][:model_dim, :]
+    Wk_weight = state_dict[Wqkv_weight_key][model_dim : 2 * model_dim, :]
+    Wv_weight = state_dict[Wqkv_weight_key][2 * model_dim :, :]
+
+    # (3 * model_dim)
+    Wqkv_bias_key = key + ".bias"
+    Wq_bias = state_dict[Wqkv_bias_key][:model_dim]
+    Wk_bias = state_dict[Wqkv_bias_key][model_dim : 2 * model_dim]
+    Wv_bias = state_dict[Wqkv_bias_key][2 * model_dim :]
+
+    out_key = key.replace("mixer.Wqkv", "self_attention")
+
+    return {
+        out_key + ".query_proj.weight": Wq_weight,
+        out_key + ".query_proj.bias": Wq_bias,
+        out_key + ".key_proj.weight": Wk_weight,
+        out_key + ".key_proj.bias": Wk_bias,
+        out_key + ".value_proj.weight": Wv_weight,
+        out_key + ".value_proj.bias": Wv_bias,
+    }
+
+
+def replace_key(key: str) -> str:
+    if "wte.weight" in key:
+        key = "wte.weight"
+
+    if ".mlp" in key:
+        key = key.replace(".mlp", "")
+
+    if ".mixer.out_proj" in key:
+        key = key.replace(".mixer", ".self_attention")
+
+    return key
+
+
+def convert():
+    model = AutoModelForCausalLM.from_pretrained(
+        "microsoft/phi-2", torch_dtype="auto", trust_remote_code=True
+    )
+    state_dict = model.state_dict()
+    keys = list(state_dict.keys())
+
+    for key in keys:
+        if ".mixer.Wqkv.weight" not in key:
+            continue
+        key_stub = key[: -len(".weight")]
+        state_dict.update(split_attention_matrix(state_dict, key_stub))
+
+        del state_dict[key_stub + ".weight"]
+        del state_dict[key_stub + ".bias"]
+
+    weights = {replace_key(k): v.numpy() for k, v in state_dict.items()}
+    numpy.savez("weights/phi-2.npz", **weights)
+
+
+if __name__ == "__main__":
+    convert()
diff --git a/phi2/hf_model.py b/phi2/hf_model.py
new file mode 100644
index 00000000..d09ff108
--- /dev/null
+++ b/phi2/hf_model.py
@@ -0,0 +1,23 @@
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+
+if __name__ == "__main__":
+    model = AutoModelForCausalLM.from_pretrained(
+        "microsoft/phi-2", torch_dtype="auto", trust_remote_code=True
+    )
+    tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
+
+    inputs = tokenizer(
+        '''def print_prime(n):
+   """
+   Print all primes between 1 and n
+   """''',
+        return_tensors="pt",
+        return_attention_mask=False,
+    )
+
+    print(model(**inputs))
+
+    # outputs = model.generate(**inputs, max_length=200)
+    # text = tokenizer.batch_decode(outputs)[0]
+    # print(text)
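`convert.py` splits each fused `mixer.Wqkv` matrix into separate query/key/value projections and renames the remaining keys to match `model.py`. As an optional sanity check of the converted archive (not part of this diff; the key names and the square shapes follow from `split_attention_matrix` and `ModelArgs`):

```python
# check_weights.py (hypothetical helper, not included in this PR)
import numpy as np

weights = np.load("weights/phi-2.npz")

# The fused Wqkv matrix of the first block should now be three square projections.
for name in (
    "transformer.h.0.self_attention.query_proj.weight",
    "transformer.h.0.self_attention.key_proj.weight",
    "transformer.h.0.self_attention.value_proj.weight",
):
    assert name in weights.files, f"missing {name}"
    print(name, weights[name].shape)  # expected: (2560, 2560), i.e. (model_dim, model_dim)

print("total arrays:", len(weights.files))
```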
diff --git a/phi2/model.py b/phi2/model.py
new file mode 100644
index 00000000..991bf193
--- /dev/null
+++ b/phi2/model.py
@@ -0,0 +1,232 @@
+from typing import Optional
+from dataclasses import dataclass
+from mlx.utils import tree_unflatten, tree_map
+from transformers import AutoTokenizer
+
+import mlx.core as mx
+import mlx.nn as nn
+import math
+
+
+@dataclass
+class ModelArgs:
+    max_sequence_length: int = 2048
+    num_vocab: int = 51200
+    model_dim: int = 2560
+    num_heads: int = 32
+    num_layers: int = 32
+    rotary_dim: int = 32  # not used yet: RoPE below is applied to the full head dimension
+
+
+class NewGELUActivation(nn.Module):
+    """
+    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
+    the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
+    """
+
+    def __call__(self, input: mx.array) -> mx.array:
+        return (
+            0.5
+            * input
+            * (
+                1.0
+                + mx.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * (input**3)))
+            )
+        )
+
+
+class RoPEAttention(nn.Module):
+    def __init__(self, dims: int, num_heads: int, bias: bool = True):
+        super().__init__()
+
+        self.num_heads = num_heads
+
+        self.rope = nn.RoPE(dims // num_heads, traditional=True)
+        self.query_proj = nn.Linear(dims, dims, bias=bias)
+        self.key_proj = nn.Linear(dims, dims, bias=bias)
+        self.value_proj = nn.Linear(dims, dims, bias=bias)
+        self.out_proj = nn.Linear(dims, dims, bias=bias)
+
+    def __call__(self, queries, keys, values, mask=None, cache=None):
+        queries = self.query_proj(queries)
+        keys = self.key_proj(keys)
+        values = self.value_proj(values)
+
+        # Extract some shapes
+        num_heads = self.num_heads
+        B, L, D = queries.shape
+
+        # Prepare the queries, keys and values for the attention computation
+        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
+        keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
+        values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
+
+        # Add RoPE to the queries and keys and combine them with the cache
+        if cache is not None:
+            key_cache, value_cache = cache
+            queries = self.rope(queries, offset=key_cache.shape[2])
+            keys = self.rope(keys, offset=key_cache.shape[2])
+            keys = mx.concatenate([key_cache, keys], axis=2)
+            values = mx.concatenate([value_cache, values], axis=2)
+        else:
+            queries = self.rope(queries)
+            keys = self.rope(keys)
+
+        # Finally perform the attention computation
+        scale = math.sqrt(1 / queries.shape[-1])
+        scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
+        if mask is not None:
+            scores = scores + mask
+        scores = mx.softmax(scores, axis=-1)
+        values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
+
+        # Note that we return the keys and values to possibly be used as a cache
+        return self.out_proj(values_hat), (keys, values)
+
+
+class ParallelBlock(nn.Module):
+    def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None):
+        super().__init__()
+        mlp_dims = mlp_dims or dims * 4
+        self.self_attention = RoPEAttention(dims, num_heads, bias=True)
+        self.ln = nn.LayerNorm(dims)
+        self.fc1 = nn.Linear(dims, mlp_dims)
+        self.fc2 = nn.Linear(mlp_dims, dims)
+        self.act = NewGELUActivation()
+
+    def __call__(self, x, x_mask):
+        residual = x
+        hidden_states = self.ln(x)
+        attn_outputs, _ = self.self_attention(
+            hidden_states, hidden_states, hidden_states, x_mask
+        )
+        ff_hidden_states = self.fc2(self.act(self.fc1(hidden_states)))
+
+        hidden_states = attn_outputs + ff_hidden_states + residual
+
+        return hidden_states
+
+
+class TransformerDecoder(nn.Module):
+    def __init__(
+        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
+    ):
+        super().__init__()
+        self.h = [ParallelBlock(dims, num_heads, mlp_dims) for i in range(num_layers)]
+
+    def __call__(self, x, x_mask):
+        for layer in self.h:
+            x = layer(x, x_mask)
+        return x
+
+
+class Phi2(nn.Module):
+    def __init__(self, config: ModelArgs):
+        self.wte = nn.Embedding(config.num_vocab, config.model_dim)
+        self.transformer = TransformerDecoder(
+            num_layers=config.num_layers,
+            dims=config.model_dim,
+            num_heads=config.num_heads,
+        )
+
+        self.lm_head = LanguageModelingHead(config)
+
+    def __call__(
+        self,
+        input_ids: mx.array,
+        attention_mask: mx.array = None,
+    ) -> mx.array:
+        x = self.wte(input_ids)
+
+        if attention_mask is not None:
+            # convert 0's to -infs, 1's to 0's, and make it broadcastable
+            attention_mask = mx.log(attention_mask)
+            attention_mask = mx.expand_dims(attention_mask, (1, 2))
+        else:
+            attention_mask = nn.MultiHeadAttention.create_additive_causal_mask(
+                x.shape[1]
+            )
+
+        y = self.transformer(x, attention_mask)
+        return self.lm_head(y)
+
+    def generate(self, input_ids, temp=1.0):
+        cache = input_ids[0].tolist()  # running list of token ids (prompt + generated)
+
+        # Make an additive causal mask. We will need that to process the prompt.
+        mask = nn.MultiHeadAttention.create_additive_causal_mask(input_ids.shape[1])
+        mask = mask.astype(self.wte.weight.dtype)
+
+        # First we process the prompt the same way as in __call__. A real
+        # per-layer key/value cache (sketched below) is not implemented yet.
+        x = self.wte(input_ids)
+        # for l in self.layers:
+        #     x, c = l(x, mask=mask)
+        #     cache.append(c)  # <--- we store the per layer cache in a
+        #                            simple python list
+        x = self.transformer(x, mask)
+        y = self.lm_head(x[:, -1])  # <--- we only care about the last logits
+        #                                  that generate the next token
+        y = mx.random.categorical(y * (1 / temp))
+
+        # y now has size [1]
+        # Since MLX is lazily evaluated nothing is computed yet.
+        # Calling y.item() would force the computation to happen at
+        # this point but we can also choose not to do that and let the
+        # user choose when to start the computation.
+        yield y
+        cache += [y.item()]
+
+        # Now that the prompt is processed and the first token is generated,
+        # we feed the growing sequence back into the model in a loop to
+        # generate the rest.
+        while True:
+            # Re-run the whole sequence so far (no key/value cache yet); rebuild the mask
+            x = self.wte(mx.array(cache)[None])
+            mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
+            x = self.transformer(x, mask.astype(x.dtype))
+            y = self.lm_head(x[:, -1])
+            y = mx.random.categorical(y * (1 / temp))
+            cache += [y[0].item()]
+
+            yield y
+
+
+class LanguageModelingHead(nn.Module):
+    def __init__(self, config: ModelArgs) -> None:
+        self.ln = nn.LayerNorm(config.model_dim)
+        self.linear = nn.Linear(config.model_dim, config.num_vocab)
+
+    def __call__(self, inputs):
+        return self.linear(self.ln(inputs))
+
+
+if __name__ == "__main__":
+    model = Phi2(ModelArgs())
+
+    weights = mx.load("weights/phi-2.npz")
+    weights = tree_unflatten(list(weights.items()))
+    weights = tree_map(lambda p: mx.array(p), weights)
+
+    model.update(weights)
+
+    tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
+    tokens = tokenizer(
+        '''def print_prime(n):
+   """
+   Print all primes between 1 and n
+   """''',
+        return_tensors="np",
+        return_attention_mask=False,
+    )
+
+    tokens = {key: mx.array(v) for key, v in tokens.items()}
+
+    print(
+        '''def print_prime(n):
+   """
+   Print all primes between 1 and n
+   """'''
+    )
+    for output in model.generate(**tokens):
+        print(tokenizer.decode(output.item()))
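The file below records layer-by-layer outputs from the Hugging Face reference (`hf_model.py`) next to this MLX port. A minimal sketch of how the MLX-side embedding output can be reproduced for that comparison (assuming the package layout above, converted weights in `weights/phi-2.npz`, and running from the repository root):

```python
# compare_embeddings.py (illustrative only, not part of this PR)
import mlx.core as mx
from mlx.utils import tree_unflatten
from transformers import AutoTokenizer

from phi2.model import ModelArgs, Phi2

model = Phi2(ModelArgs())
model.update(tree_unflatten(list(mx.load("weights/phi-2.npz").items())))

tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
tokens = tokenizer(
    '''def print_prime(n):
   """
   Print all primes between 1 and n
   """''',
    return_tensors="np",
    return_attention_mask=False,
)

# Embedding lookup only; intermediate activations can be captured the same way
# by calling the relevant sub-modules (e.g. model.transformer.h[0]) directly.
x = model.wte(mx.array(tokens["input_ids"]))
print(x)  # compare with the "(HF) / (MLX) Output of Embeddings" blocks below
```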
diff --git a/phi2/phi2_outputs.txt b/phi2/phi2_outputs.txt
new file mode 100644
index 00000000..4f27e44b
--- /dev/null
+++ b/phi2/phi2_outputs.txt
@@ -0,0 +1,63 @@
+(HF) Output of Embeddings
+
+tensor([[[-0.0353,  0.0045,  0.0208,  ..., -0.0117,  0.0041,  0.0075],
+         [-0.0172,  0.0236, -0.0051,  ...,  0.0141,  0.0115,  0.0058],
+         [-0.0148,  0.0043, -0.0252,  ...,  0.0179,  0.0025, -0.0008],
+         ...,
+         [ 0.0003,  0.0051,  0.0002,  ...,  0.0043,  0.0075,  0.0049],
+         [-0.0110,  0.0472,  0.0030,  ...,  0.0098, -0.0075,  0.0146],
+         [-0.0085, -0.0219, -0.0016,  ..., -0.0059,  0.0109, -0.0016]]],
+       device='cuda:0', dtype=torch.float16, grad_fn=)
+
+(MLX) Output of Embeddings
+
+array([[[-0.0352783, 0.00445175, 0.020813, ..., -0.0117188, 0.00411606, 0.00748444],
+        [-0.0171509, 0.0236053, -0.00508881, ..., 0.0141144, 0.0115204, 0.00582504],
+        [-0.0147858, 0.00426102, -0.0252075, ..., 0.0179443, 0.0024662, -0.00076437],
+        ...,
+        [0.000337124, 0.00508499, 0.000193119, ..., 0.00427628, 0.00753403, 0.00492477],
+        [-0.0110092, 0.0472107, 0.00295448, ..., 0.00982666, -0.00747681, 0.0145721],
+        [-0.00852203, -0.0218964, -0.00161839, ..., -0.00592422, 0.0108643, -0.00162697]]], dtype=float16)
+
+(HF) Output of First Attention Layer
+
+tensor([[[-0.2000,  0.4849,  0.9863,  ..., -0.2209,  0.1355,  0.3469],
+         [ 0.4922, -0.3865,  0.8428,  ...,  0.5894, -0.0069, -0.5278],
+         [ 0.0902,  0.1028,  0.6826,  ...,  0.1394, -0.8145, -0.1880],
+         ...,
+         [ 0.2380,  0.0555, -0.3005,  ...,  0.0372, -0.0895,  0.0255],
+         [ 0.2512,  0.1949,  0.3401,  ...,  0.3625, -0.3103, -0.1064],
+         [-0.0905,  0.0665,  0.5210,  ..., -0.0767, -0.2460, -0.1449]]],
+       device='cuda:0', dtype=torch.float16, grad_fn=)
+torch.Size([1, 23, 2560])
+
+(MLX) Output of First Attention Layer
+
+array([[[-0.199973, 0.485224, 0.987237, ..., -0.220847, 0.13511, 0.346074],
+        [0.44883, -0.271683, 0.877478, ..., 0.653217, -0.0929724, -0.711176],
+        [-0.233398, 5.7824e-05, 0.435001, ..., 0.0504494, -0.623998, -0.438785],
+        ...,
+        [0.123587, -0.237459, -0.447518, ..., 0.0653363, -0.0767153, -0.341505],
+        [0.187798, 0.331209, 0.0827338, ..., 0.529453, -0.582141, -0.165316],
+        [-0.413614, 0.134572, 0.685769, ..., 0.0796088, 0.0217719, -0.118885]]], dtype=float32)
+[1, 23, 2560]
+
+(HF) Overall Output of Inputs:
+
+tensor([[[ 6.4688,  5.1016,  1.9658,  ..., -2.9043, -2.9043, -2.9043],
+         [ 5.2188,  6.4414,  5.1914,  ..., -0.1852, -0.1862, -0.1866],
+         [ 4.3516,  5.3281,  5.9922,  ..., -0.3689, -0.3699, -0.3696],
+         ...,
+         [10.4141, 11.7031, 12.5859,  ...,  0.7778,  0.7769,  0.7754],
+         [10.7188, 11.7891, 13.3125,  ...,  1.6123,  1.6113,  1.6104],
+         [10.8047, 12.0234, 12.4375,  ...,  0.2321,  0.2314,  0.2317]]],
+
+(MLX) Overall Output of Inputs:
+
+array([[[6.46632, 5.10102, 1.96306, ..., -2.90427, -2.90341, -2.90392],
+        [4.5092, 5.90938, 4.98036, ..., -0.411165, -0.412062, -0.412547],
+        [4.34246, 5.7794, 6.13245, ..., -0.40106, -0.402052, -0.401838],
+        ...,
+        [6.61827, 10.4022, 12.1672, ..., 0.602787, 0.602138, 0.600666],
+        [7.96546, 12.9569, 14.7947, ..., -0.347764, -0.348587, -0.34937],
+        [8.22272, 10.6631, 11.5968, ..., -1.12037, -1.12025, -1.12152]]], dtype=float32)
\ No newline at end of file
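In the numbers above, the embedding outputs agree to within float16 precision, and the attention and logit outputs match at the first sequence position but drift at later positions, which hints at a positional-encoding difference (note that `rotary_dim` is not yet used in `model.py`). A small helper like the following (illustrative, not part of the diff) makes the comparison quantitative once both sides are exported to NumPy, e.g. with `.detach().cpu().float().numpy()` for the PyTorch tensors and `np.array(x)` for the MLX arrays:

```python
import numpy as np


def report_diff(name: str, hf_out: np.ndarray, mlx_out: np.ndarray) -> None:
    """Print the largest absolute elementwise difference between two layer outputs."""
    a = hf_out.astype(np.float32)
    b = mlx_out.astype(np.float32)
    print(f"{name}: max abs diff = {np.max(np.abs(a - b)):.4f}")


# e.g. report_diff("embeddings", hf_embeddings, mlx_embeddings)
#      report_diff("first attention layer", hf_attn, mlx_attn)
```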