From 69700d84311a630802083a4ace52daf477c9ca9c Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Thu, 23 May 2024 01:47:06 +0200 Subject: [PATCH] Add support for Phi-3 Medium (#790) * update to support phi-3 medium * fuse qkv split --- llms/mlx_lm/models/phi3.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/llms/mlx_lm/models/phi3.py b/llms/mlx_lm/models/phi3.py index e7d38614..3282dff2 100644 --- a/llms/mlx_lm/models/phi3.py +++ b/llms/mlx_lm/models/phi3.py @@ -44,8 +44,9 @@ class Attention(nn.Module): dim = args.hidden_size self.n_heads = n_heads = args.num_attention_heads self.n_kv_heads = n_kv_heads = args.num_key_value_heads + self.num_hidden_layers = args.num_hidden_layers - head_dim = args.hidden_size // n_heads + self.head_dim = head_dim = args.hidden_size // n_heads self.scale = head_dim**-0.5 op_size = n_heads * head_dim + 2 * (n_kv_heads * head_dim) @@ -73,7 +74,10 @@ class Attention(nn.Module): B, L, D = x.shape qkv = self.qkv_proj(x) - queries, keys, values = mx.split(qkv, 3, axis=-1) + query_pos = self.n_heads * self.head_dim + queries, keys, values = mx.split( + qkv, [query_pos, query_pos + self.n_kv_heads * self.head_dim], axis=-1 + ) # Prepare the queries, keys and values for the attention computation queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)