Mirror of https://github.com/ml-explore/mlx-examples.git
cleanup conversion to use single qkv matrix

Commit 8f60d60814 (parent 0c1c500714)
@@ -1,7 +1,7 @@
 # Phi-2
 
-Phi-2 is a 2.7B parameter language model released by Microsoft[^1] with
-performance that rivals much larger models. It was trained on a mixture of
+Phi-2 is a 2.7B parameter language model released by Microsoft with
+performance that rivals much larger models.[^1] It was trained on a mixture of
 GPT-4 outputs and clean web text.
 
 Phi-2 efficiently runs on Apple silicon devices with 8GB of memory in 16-bit
@@ -1,34 +1,5 @@
 from transformers import AutoModelForCausalLM
-
-import numpy
-
-
-def split_attention_matrix(state_dict, key) -> dict:
-    # "transformer.h.0.mixer"
-    _, model_dim = state_dict[key + ".weight"].shape
-    # (3 * model_dim, model_dim)
-    Wqkv_weight_key = key + ".weight"
-    Wq_weight = state_dict[Wqkv_weight_key][:model_dim, :]
-    Wk_weight = state_dict[Wqkv_weight_key][model_dim : 2 * model_dim, :]
-    Wv_weight = state_dict[Wqkv_weight_key][2 * model_dim :, :]
-
-    # (3 * model_dim)
-    Wqkv_bias_key = key + ".bias"
-    Wq_bias = state_dict[Wqkv_bias_key][:model_dim]
-    Wk_bias = state_dict[Wqkv_bias_key][model_dim : 2 * model_dim]
-    Wv_bias = state_dict[Wqkv_bias_key][2 * model_dim :]
-
-    out_key = key.replace("mixer.Wqkv", "self_attention")
-
-    return {
-        out_key + ".query_proj.weight": Wq_weight,
-        out_key + ".query_proj.bias": Wq_bias,
-        out_key + ".key_proj.weight": Wk_weight,
-        out_key + ".key_proj.bias": Wk_bias,
-        out_key + ".value_proj.weight": Wv_weight,
-        out_key + ".value_proj.bias": Wv_bias,
-    }
-
+import numpy as np
 
 def replace_key(key: str) -> str:
     if "wte.weight" in key:
@@ -36,10 +7,6 @@ def replace_key(key: str) -> str:
 
     if ".mlp" in key:
         key = key.replace(".mlp", "")
-
-    if ".mixer.out_proj" in key:
-        key = key.replace(".mixer", ".self_attention")
-
     return key
 
 
@@ -48,19 +15,8 @@ def convert():
         "microsoft/phi-2", torch_dtype="auto", trust_remote_code=True
     )
     state_dict = model.state_dict()
-    keys = list(state_dict.keys())
-
-    for key in keys:
-        if ".mixer.Wqkv.weight" not in key:
-            continue
-        key_stub = key.rstrip(".weight")
-        state_dict.update(split_attention_matrix(state_dict, key_stub))
-
-        del state_dict[key_stub + ".weight"]
-        del state_dict[key_stub + ".bias"]
-
     weights = {replace_key(k): v.numpy() for k, v in state_dict.items()}
-    numpy.savez("weights.npz", **weights)
+    np.savez("weights.npz", **weights)
 
 
 if __name__ == "__main__":
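The helper could be deleted because nothing is lost by keeping the checkpoint's fused weights: slicing the (3 * model_dim, model_dim) Wqkv matrix by rows at conversion time and splitting the fused projection's output at run time give the same queries, keys, and values. A standalone numpy sketch of that equivalence (written for this note, not code from the repository):

```python
import numpy as np

model_dim = 8
rng = np.random.default_rng(0)

# Fused parameters laid out as in the checkpoint: rows [0:d] -> Q,
# [d:2d] -> K, [2d:3d] -> V, with d = model_dim.
Wqkv = rng.standard_normal((3 * model_dim, model_dim)).astype(np.float32)
bqkv = rng.standard_normal(3 * model_dim).astype(np.float32)
x = rng.standard_normal((4, model_dim)).astype(np.float32)

# Old route: three separate projections built from row slices of the fused
# matrix, as the removed split_attention_matrix produced.
q_old = x @ Wqkv[:model_dim].T + bqkv[:model_dim]
k_old = x @ Wqkv[model_dim : 2 * model_dim].T + bqkv[model_dim : 2 * model_dim]
v_old = x @ Wqkv[2 * model_dim :].T + bqkv[2 * model_dim :]

# New route: one fused projection, then split the last axis into three parts.
q_new, k_new, v_new = np.split(x @ Wqkv.T + bqkv, 3, axis=-1)

assert np.allclose(q_old, q_new)
assert np.allclose(k_old, k_new)
assert np.allclose(v_old, v_new)
```

Row slices of the fused matrix correspond exactly to thirds of the fused projection's output, so the checkpoint weights can be saved as-is.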
phi2/phi2.py (15 changed lines)
@@ -31,15 +31,12 @@ class RoPEAttention(nn.Module):
         self.num_heads = num_heads
 
         self.rope = nn.RoPE(rotary_dim, traditional=False)
-        self.query_proj = nn.Linear(dims, dims)
-        self.key_proj = nn.Linear(dims, dims)
-        self.value_proj = nn.Linear(dims, dims)
+        self.Wqkv = nn.Linear(dims, 3 * dims)
         self.out_proj = nn.Linear(dims, dims)
 
-    def __call__(self, queries, keys, values, mask=None, cache=None):
-        queries = self.query_proj(queries)
-        keys = self.key_proj(keys)
-        values = self.value_proj(values)
+    def __call__(self, x, mask=None, cache=None):
+        qkv = self.Wqkv(x)
+        queries, keys, values = mx.split(qkv, 3, axis=-1)
 
         # Extract some shapes
         num_heads = self.num_heads
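With the fused layer, a single matmul now produces queries, keys, and values, and mx.split separates them along the last axis. A minimal self-contained sketch of the same pattern (illustrative only; the class name and the shapes in the driver are invented for this example):

```python
import mlx.core as mx
import mlx.nn as nn


class FusedQKV(nn.Module):
    """Toy module mirroring the fused-projection pattern in RoPEAttention."""

    def __init__(self, dims: int):
        super().__init__()
        # One linear layer emits queries, keys, and values in a single matmul.
        self.Wqkv = nn.Linear(dims, 3 * dims)

    def __call__(self, x):
        qkv = self.Wqkv(x)                                  # (..., 3 * dims)
        queries, keys, values = mx.split(qkv, 3, axis=-1)   # three (..., dims) arrays
        return queries, keys, values


# Hypothetical shapes: a batch of 2 sequences, 5 tokens each, model width 16.
q, k, v = FusedQKV(16)(mx.random.normal((2, 5, 16)))
print(q.shape, k.shape, v.shape)  # (2, 5, 16) each
```

Each split comes out with the original width dims, so the downstream reshaping into heads is unchanged from the three-projection version.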
@@ -81,7 +78,7 @@ class ParallelBlock(nn.Module):
         super().__init__()
         dims = config.model_dim
         mlp_dims = dims * 4
-        self.self_attention = RoPEAttention(dims, config.num_heads, config.rotary_dim)
+        self.mixer = RoPEAttention(dims, config.num_heads, config.rotary_dim)
         self.ln = LayerNorm(dims)
         self.fc1 = nn.Linear(dims, mlp_dims)
         self.fc2 = nn.Linear(mlp_dims, dims)
@@ -89,7 +86,7 @@ class ParallelBlock(nn.Module):
 
     def __call__(self, x, mask, cache):
         h = self.ln(x)
-        attn_h, cache = self.self_attention(h, h, h, mask, cache)
+        attn_h, cache = self.mixer(h, mask, cache)
         ff_h = self.fc2(self.act(self.fc1(h)))
         return attn_h + ff_h + x, cache
 
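Renaming the block's attention back to mixer, with its fused layer named Wqkv, lets the attention parameters keep their checkpoint names (the removed helper's comment points at the "transformer.h.0.mixer" stub, whose fused weight is "...mixer.Wqkv.weight"). That is why convert() above shrinks to a dictionary-comprehension rename plus np.savez: in the hunks shown, replace_key is left handling only the embedding ("wte.weight") and the ".mlp" prefix. A hypothetical stand-in for the ".mlp" case (replace_key_sketch is not the repository's function):

```python
# Sketch of the remaining renaming work after this commit (illustrative only).
def replace_key_sketch(key: str) -> str:
    if ".mlp" in key:
        # The MLX ParallelBlock holds fc1/fc2 directly, without an ".mlp" submodule.
        key = key.replace(".mlp", "")
    return key


print(replace_key_sketch("transformer.h.0.mlp.fc1.weight"))     # transformer.h.0.fc1.weight
print(replace_key_sketch("transformer.h.0.mixer.Wqkv.weight"))  # unchanged: already matches
```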
@@ -1,3 +1,4 @@
 einops
 mlx
+numpy
 transformers