From 8f60d60814115659c1d9d6f911c7177a66e077e4 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Thu, 14 Dec 2023 09:19:44 -0800
Subject: [PATCH] cleanup conversion to use single qkv matrix

---
 phi2/README.md        |  4 ++--
 phi2/__init__.py      |  0
 phi2/convert.py       | 48 ++-----------------------------------------
 phi2/phi2.py          | 15 ++++++--------
 phi2/requirements.txt |  1 +
 5 files changed, 11 insertions(+), 57 deletions(-)
 delete mode 100644 phi2/__init__.py

diff --git a/phi2/README.md b/phi2/README.md
index 198ac30c..f5d80696 100644
--- a/phi2/README.md
+++ b/phi2/README.md
@@ -1,7 +1,7 @@
 # Phi-2
 
-Phi-2 is a 2.7B parameter language model released by Microsoft[^1] with
-performance that rivals much larger models. It was trained on a mixture of
+Phi-2 is a 2.7B parameter language model released by Microsoft with
+performance that rivals much larger models.[^1] It was trained on a mixture of
 GPT-4 outputs and clean web text.
 
 Phi-2 efficiently runs on Apple silicon devices with 8GB of memory in 16-bit
diff --git a/phi2/__init__.py b/phi2/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/phi2/convert.py b/phi2/convert.py
index 3c821f69..4c625a6e 100644
--- a/phi2/convert.py
+++ b/phi2/convert.py
@@ -1,34 +1,5 @@
 from transformers import AutoModelForCausalLM
-
-import numpy
-
-
-def split_attention_matrix(state_dict, key) -> dict:
-    # "transformer.h.0.mixer"
-    _, model_dim = state_dict[key + ".weight"].shape
-    # (3 * model_dim, model_dim)
-    Wqkv_weight_key = key + ".weight"
-    Wq_weight = state_dict[Wqkv_weight_key][:model_dim, :]
-    Wk_weight = state_dict[Wqkv_weight_key][model_dim : 2 * model_dim, :]
-    Wv_weight = state_dict[Wqkv_weight_key][2 * model_dim :, :]
-
-    # (3 * model_dim)
-    Wqkv_bias_key = key + ".bias"
-    Wq_bias = state_dict[Wqkv_bias_key][:model_dim]
-    Wk_bias = state_dict[Wqkv_bias_key][model_dim : 2 * model_dim]
-    Wv_bias = state_dict[Wqkv_bias_key][2 * model_dim :]
-
-    out_key = key.replace("mixer.Wqkv", "self_attention")
-
-    return {
-        out_key + ".query_proj.weight": Wq_weight,
-        out_key + ".query_proj.bias": Wq_bias,
-        out_key + ".key_proj.weight": Wk_weight,
-        out_key + ".key_proj.bias": Wk_bias,
-        out_key + ".value_proj.weight": Wv_weight,
-        out_key + ".value_proj.bias": Wv_bias,
-    }
-
+import numpy as np
 
 def replace_key(key: str) -> str:
     if "wte.weight" in key:
@@ -36,10 +7,6 @@ def replace_key(key: str) -> str:
 
     if ".mlp" in key:
         key = key.replace(".mlp", "")
-
-    if ".mixer.out_proj" in key:
-        key = key.replace(".mixer", ".self_attention")
-
     return key
 
 
@@ -48,19 +15,8 @@ def convert():
         "microsoft/phi-2", torch_dtype="auto", trust_remote_code=True
     )
     state_dict = model.state_dict()
-    keys = list(state_dict.keys())
-
-    for key in keys:
-        if ".mixer.Wqkv.weight" not in key:
-            continue
-        key_stub = key.rstrip(".weight")
-        state_dict.update(split_attention_matrix(state_dict, key_stub))
-
-        del state_dict[key_stub + ".weight"]
-        del state_dict[key_stub + ".bias"]
-
     weights = {replace_key(k): v.numpy() for k, v in state_dict.items()}
-    numpy.savez("weights.npz", **weights)
+    np.savez("weights.npz", **weights)
 
 
 if __name__ == "__main__":
diff --git a/phi2/phi2.py b/phi2/phi2.py
index 38199c6c..7973c33d 100644
--- a/phi2/phi2.py
+++ b/phi2/phi2.py
@@ -31,15 +31,12 @@ class RoPEAttention(nn.Module):
         self.num_heads = num_heads
 
         self.rope = nn.RoPE(rotary_dim, traditional=False)
-        self.query_proj = nn.Linear(dims, dims)
-        self.key_proj = nn.Linear(dims, dims)
-        self.value_proj = nn.Linear(dims, dims)
+        self.Wqkv = nn.Linear(dims, 3 * dims)
         self.out_proj = nn.Linear(dims, dims)
 
-    def __call__(self, queries, keys, values, mask=None, cache=None):
-        queries = self.query_proj(queries)
-        keys = self.key_proj(keys)
-        values = self.value_proj(values)
+    def __call__(self, x, mask=None, cache=None):
+        qkv = self.Wqkv(x)
+        queries, keys, values = mx.split(qkv, 3, axis=-1)
 
         # Extract some shapes
         num_heads = self.num_heads
@@ -81,7 +78,7 @@ class ParallelBlock(nn.Module):
         super().__init__()
         dims = config.model_dim
         mlp_dims = dims * 4
-        self.self_attention = RoPEAttention(dims, config.num_heads, config.rotary_dim)
+        self.mixer = RoPEAttention(dims, config.num_heads, config.rotary_dim)
         self.ln = LayerNorm(dims)
         self.fc1 = nn.Linear(dims, mlp_dims)
         self.fc2 = nn.Linear(mlp_dims, dims)
@@ -89,7 +86,7 @@ class ParallelBlock(nn.Module):
 
     def __call__(self, x, mask, cache):
         h = self.ln(x)
-        attn_h, cache = self.self_attention(h, h, h, mask, cache)
+        attn_h, cache = self.mixer(h, mask, cache)
         ff_h = self.fc2(self.act(self.fc1(h)))
         return attn_h + ff_h + x, cache
 
diff --git a/phi2/requirements.txt b/phi2/requirements.txt
index 6a11f8d2..3e141ec3 100644
--- a/phi2/requirements.txt
+++ b/phi2/requirements.txt
@@ -1,3 +1,4 @@
 einops
 mlx
+numpy
 transformers