diff --git a/llms/mlx_lm/models/phi3.py b/llms/mlx_lm/models/phi3.py
index e7d38614..3282dff2 100644
--- a/llms/mlx_lm/models/phi3.py
+++ b/llms/mlx_lm/models/phi3.py
@@ -44,8 +44,9 @@ class Attention(nn.Module):
         dim = args.hidden_size
         self.n_heads = n_heads = args.num_attention_heads
         self.n_kv_heads = n_kv_heads = args.num_key_value_heads
+        self.num_hidden_layers = args.num_hidden_layers
 
-        head_dim = args.hidden_size // n_heads
+        self.head_dim = head_dim = args.hidden_size // n_heads
         self.scale = head_dim**-0.5
 
         op_size = n_heads * head_dim + 2 * (n_kv_heads * head_dim)
@@ -73,7 +74,10 @@ class Attention(nn.Module):
         B, L, D = x.shape
 
         qkv = self.qkv_proj(x)
-        queries, keys, values = mx.split(qkv, 3, axis=-1)
+        query_pos = self.n_heads * self.head_dim
+        queries, keys, values = mx.split(
+            qkv, [query_pos, query_pos + self.n_kv_heads * self.head_dim], axis=-1
+        )
 
         # Prepare the queries, keys and values for the attention computation
         queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
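
Note on the split change: dividing the fused `qkv_proj` output into three equal chunks is only valid when the number of key/value heads equals the number of query heads. Computing explicit split points from `n_heads`, `n_kv_heads`, and `head_dim` presumably also covers grouped-query configurations where `num_key_value_heads` is smaller. Below is a minimal standalone sketch of the same split; the head counts and batch shape are assumptions chosen for illustration, not Phi-3's actual configuration.

```python
import mlx.core as mx

# Illustrative sizes (assumed for this sketch, not the real Phi-3 config):
# 8 query heads, 2 key/value heads, head_dim of 4.
n_heads, n_kv_heads, head_dim = 8, 2, 4
B, L = 1, 3

# Fused qkv output: queries occupy n_heads * head_dim columns,
# keys and values occupy n_kv_heads * head_dim columns each.
qkv = mx.zeros((B, L, (n_heads + 2 * n_kv_heads) * head_dim))

query_pos = n_heads * head_dim
queries, keys, values = mx.split(
    qkv, [query_pos, query_pos + n_kv_heads * head_dim], axis=-1
)

print(queries.shape)  # (1, 3, 32)
print(keys.shape)     # (1, 3, 8)
print(values.shape)   # (1, 3, 8)
```

When `n_kv_heads` equals `n_heads`, the explicit boundaries coincide with an even three-way split, so the behavior for existing checkpoints is unchanged.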