Add support for Phi-3 Medium (#790)

* update to support phi-3 medium

* fuse qkv split
This commit is contained in:
Prince Canuma 2024-05-23 01:47:06 +02:00 committed by GitHub
parent b044ce2acf
commit 69700d8431
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -44,8 +44,9 @@ class Attention(nn.Module):
         dim = args.hidden_size
         self.n_heads = n_heads = args.num_attention_heads
         self.n_kv_heads = n_kv_heads = args.num_key_value_heads
+        self.num_hidden_layers = args.num_hidden_layers
-        head_dim = args.hidden_size // n_heads
+        self.head_dim = head_dim = args.hidden_size // n_heads
         self.scale = head_dim**-0.5
         op_size = n_heads * head_dim + 2 * (n_kv_heads * head_dim)
@@ -73,7 +74,10 @@ class Attention(nn.Module):
         B, L, D = x.shape
         qkv = self.qkv_proj(x)
-        queries, keys, values = mx.split(qkv, 3, axis=-1)
+        query_pos = self.n_heads * self.head_dim
+        queries, keys, values = mx.split(
+            qkv, [query_pos, query_pos + self.n_kv_heads * self.head_dim], axis=-1
+        )
         # Prepare the queries, keys and values for the attention computation
         queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)