cleanup conversion to use single qkv matrix

Awni Hannun 2023-12-14 09:19:44 -08:00
parent 0c1c500714
commit 8f60d60814
5 changed files with 11 additions and 57 deletions
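
The change is mechanical: instead of slicing Hugging Face's fused mixer.Wqkv weight into separate query/key/value projections at conversion time, the model keeps the single fused projection and splits the activations at runtime. A minimal sketch of the pattern in MLX (standalone, with a hypothetical FusedQKV class name; the real class is RoPEAttention in the diff below):

import mlx.core as mx
import mlx.nn as nn

class FusedQKV(nn.Module):
    def __init__(self, dims: int):
        super().__init__()
        # One matmul produces queries, keys, and values together, matching
        # the (3 * dims, dims) layout of the Hugging Face Wqkv weight.
        self.Wqkv = nn.Linear(dims, 3 * dims)

    def __call__(self, x):
        # Split the last axis back into three (..., dims) arrays.
        return mx.split(self.Wqkv(x), 3, axis=-1)

queries, keys, values = FusedQKV(64)(mx.random.normal((2, 8, 64)))

Keeping the checkpoint layout fused means convert.py no longer needs to know how the attention weights are packed.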

phi2/README.md

@@ -1,7 +1,7 @@
 # Phi-2

-Phi-2 is a 2.7B parameter language model released by Microsoft[^1] with
-performance that rivals much larger models. It was trained on a mixture of
+Phi-2 is a 2.7B parameter language model released by Microsoft with
+performance that rivals much larger models.[^1] It was trained on a mixture of
 GPT-4 outputs and clean web text.

 Phi-2 efficiently runs on Apple silicon devices with 8GB of memory in 16-bit

phi2/convert.py

@@ -1,34 +1,5 @@
 from transformers import AutoModelForCausalLM
-import numpy
-
-
-def split_attention_matrix(state_dict, key) -> dict:
-    # "transformer.h.0.mixer"
-    _, model_dim = state_dict[key + ".weight"].shape
-    # (3 * model_dim, model_dim)
-    Wqkv_weight_key = key + ".weight"
-    Wq_weight = state_dict[Wqkv_weight_key][:model_dim, :]
-    Wk_weight = state_dict[Wqkv_weight_key][model_dim : 2 * model_dim, :]
-    Wv_weight = state_dict[Wqkv_weight_key][2 * model_dim :, :]
-
-    # (3 * model_dim)
-    Wqkv_bias_key = key + ".bias"
-    Wq_bias = state_dict[Wqkv_bias_key][:model_dim]
-    Wk_bias = state_dict[Wqkv_bias_key][model_dim : 2 * model_dim]
-    Wv_bias = state_dict[Wqkv_bias_key][2 * model_dim :]
-
-    out_key = key.replace("mixer.Wqkv", "self_attention")
-
-    return {
-        out_key + ".query_proj.weight": Wq_weight,
-        out_key + ".query_proj.bias": Wq_bias,
-        out_key + ".key_proj.weight": Wk_weight,
-        out_key + ".key_proj.bias": Wk_bias,
-        out_key + ".value_proj.weight": Wv_weight,
-        out_key + ".value_proj.bias": Wv_bias,
-    }
+import numpy as np


 def replace_key(key: str) -> str:
     if "wte.weight" in key:
@@ -36,10 +7,6 @@ def replace_key(key: str) -> str:
     if ".mlp" in key:
         key = key.replace(".mlp", "")
-    if ".mixer.out_proj" in key:
-        key = key.replace(".mixer", ".self_attention")
-
     return key
@@ -48,19 +15,8 @@ def convert():
         "microsoft/phi-2", torch_dtype="auto", trust_remote_code=True
     )
     state_dict = model.state_dict()
-    keys = list(state_dict.keys())
-    for key in keys:
-        if ".mixer.Wqkv.weight" not in key:
-            continue
-        key_stub = key.rstrip(".weight")
-        state_dict.update(split_attention_matrix(state_dict, key_stub))
-
-        del state_dict[key_stub + ".weight"]
-        del state_dict[key_stub + ".bias"]
-
     weights = {replace_key(k): v.numpy() for k, v in state_dict.items()}
-    numpy.savez("weights.npz", **weights)
+    np.savez("weights.npz", **weights)


 if __name__ == "__main__":
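
Incidentally, the deleted loop computed its key stub with key.rstrip(".weight"), which only worked by accident: str.rstrip strips a trailing character set, not a suffix. A quick illustration ("gate.weight" is a made-up key; removesuffix needs Python 3.9+):

>>> "transformer.h.0.mixer.Wqkv.weight".rstrip(".weight")
'transformer.h.0.mixer.Wqkv'    # fine here, since "v" stops the strip
>>> "gate.weight".rstrip(".weight")
'ga'                            # broken: "t" and "e" are in the strip set
>>> "gate.weight".removesuffix(".weight")
'gate'                          # the suffix-safe spelling

Deleting the split loop removes this footgun along with the bookkeeping.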

phi2/phi2.py

@@ -31,15 +31,12 @@ class RoPEAttention(nn.Module):
         self.num_heads = num_heads

         self.rope = nn.RoPE(rotary_dim, traditional=False)
-        self.query_proj = nn.Linear(dims, dims)
-        self.key_proj = nn.Linear(dims, dims)
-        self.value_proj = nn.Linear(dims, dims)
+        self.Wqkv = nn.Linear(dims, 3 * dims)
         self.out_proj = nn.Linear(dims, dims)

-    def __call__(self, queries, keys, values, mask=None, cache=None):
-        queries = self.query_proj(queries)
-        keys = self.key_proj(keys)
-        values = self.value_proj(values)
+    def __call__(self, x, mask=None, cache=None):
+        qkv = self.Wqkv(x)
+        queries, keys, values = mx.split(qkv, 3, axis=-1)

         # Extract some shapes
         num_heads = self.num_heads
@@ -81,7 +78,7 @@ class ParallelBlock(nn.Module):
         super().__init__()
         dims = config.model_dim
         mlp_dims = dims * 4
-        self.self_attention = RoPEAttention(dims, config.num_heads, config.rotary_dim)
+        self.mixer = RoPEAttention(dims, config.num_heads, config.rotary_dim)
         self.ln = LayerNorm(dims)
         self.fc1 = nn.Linear(dims, mlp_dims)
         self.fc2 = nn.Linear(mlp_dims, dims)
@@ -89,7 +86,7 @@ class ParallelBlock(nn.Module):

     def __call__(self, x, mask, cache):
         h = self.ln(x)
-        attn_h, cache = self.self_attention(h, h, h, mask, cache)
+        attn_h, cache = self.mixer(h, mask, cache)
         ff_h = self.fc2(self.act(self.fc1(h)))
         return attn_h + ff_h + x, cache
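
Why splitting activations is equivalent to splitting weights: applying the fused projection to x and then splitting the last axis in thirds yields exactly the q/k/v that the three sliced projections would. A quick numerical check of that identity (NumPy stand-in with hypothetical sizes, not code from this repo):

import numpy as np

model_dim = 8
x = np.random.randn(2, model_dim)
Wqkv = np.random.randn(3 * model_dim, model_dim)  # (out, in), as stored in the checkpoint
bqkv = np.random.randn(3 * model_dim)

q, k, v = np.split(x @ Wqkv.T + bqkv, 3, axis=-1)

# The old convert.py sliced the weight instead; same result for the query:
q_ref = x @ Wqkv[:model_dim].T + bqkv[:model_dim]
assert np.allclose(q, q_ref)

This is why deleting split_attention_matrix loses nothing: the split simply moves from the weights (at conversion) to the activations (at inference).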

phi2/requirements.txt

@@ -1,3 +1,4 @@
+einops
 mlx
 numpy
 transformers
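
The einops addition is presumably for conversion, not inference: convert.py loads microsoft/phi-2 with trust_remote_code=True, and that remote modeling code imports einops.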