Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 17:31:18 +08:00)
Add support for Phi-3 Medium (#790)

* update to support phi-3 medium
* fuse qkv split

parent b044ce2acf
commit 69700d8431
@@ -44,8 +44,9 @@ class Attention(nn.Module):
         dim = args.hidden_size
         self.n_heads = n_heads = args.num_attention_heads
         self.n_kv_heads = n_kv_heads = args.num_key_value_heads
+        self.num_hidden_layers = args.num_hidden_layers

-        head_dim = args.hidden_size // n_heads
+        self.head_dim = head_dim = args.hidden_size // n_heads
         self.scale = head_dim**-0.5

         op_size = n_heads * head_dim + 2 * (n_kv_heads * head_dim)
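To make this hunk concrete: storing self.head_dim lets the forward pass compute the offsets of the query, key, and value blocks inside the fused qkv_proj output. Below is a rough worked example of the sizes involved; the dimensions are assumptions chosen for illustration (they roughly match the published Phi-3 Medium config) and are not taken from this diff.

# Assumed attention dimensions for illustration (not from the diff; check the model's config.json)
hidden_size = 5120
n_heads = 40        # query heads
n_kv_heads = 10     # grouped-query attention: fewer key/value heads

head_dim = hidden_size // n_heads                           # 128
op_size = n_heads * head_dim + 2 * (n_kv_heads * head_dim)  # 5120 + 2 * 1280 = 7680
print(head_dim, op_size)  # 128 7680
# Under these assumptions the fused qkv_proj maps hidden_size -> op_size:
# 5120 query features, then 1280 key features, then 1280 value features.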
@@ -73,7 +74,10 @@ class Attention(nn.Module):
         B, L, D = x.shape

         qkv = self.qkv_proj(x)
-        queries, keys, values = mx.split(qkv, 3, axis=-1)
+        query_pos = self.n_heads * self.head_dim
+        queries, keys, values = mx.split(
+            qkv, [query_pos, query_pos + self.n_kv_heads * self.head_dim], axis=-1
+        )

         # Prepare the queries, keys and values for the attention computation
         queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
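Why the split changes: with grouped-query attention the fused projection output is not three equal chunks. The query block is n_heads * head_dim wide while the key and value blocks are each only n_kv_heads * head_dim wide, so splitting into three equal parts is only correct when n_kv_heads == n_heads. Here is a minimal standalone sketch of the index-based split; the head counts are assumptions for illustration, not values taken from this diff.

import mlx.core as mx

# Illustrative head configuration (assumed; not part of this commit)
n_heads, n_kv_heads, head_dim = 40, 10, 128
B, L = 1, 4

# Fake fused QKV activation: queries | keys | values concatenated on the last axis
op_size = n_heads * head_dim + 2 * (n_kv_heads * head_dim)
qkv = mx.zeros((B, L, op_size))

# Old approach: three equal chunks -- only correct when n_kv_heads == n_heads
# queries, keys, values = mx.split(qkv, 3, axis=-1)

# New approach: split at the query/key and key/value boundaries
query_pos = n_heads * head_dim
queries, keys, values = mx.split(
    qkv, [query_pos, query_pos + n_kv_heads * head_dim], axis=-1
)
print(queries.shape, keys.shape, values.shape)
# (1, 4, 5120) (1, 4, 1280) (1, 4, 1280)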