From 7612c646f3957a5d588dddb115c49e88a4c78058 Mon Sep 17 00:00:00 2001 From: Shunta Saito Date: Sun, 13 Oct 2024 07:26:50 +0900 Subject: [PATCH] Fix PLaMo model to support Grouped Query Attention (#1037) --- llms/mlx_lm/models/plamo.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/llms/mlx_lm/models/plamo.py b/llms/mlx_lm/models/plamo.py index 090922ae..b0fd1a6c 100644 --- a/llms/mlx_lm/models/plamo.py +++ b/llms/mlx_lm/models/plamo.py @@ -89,6 +89,9 @@ class Attention(nn.Module): queries = self.rotary_emb(queries) keys = self.rotary_emb(keys) + keys = mx.tile(keys, [1, self.config.n_shared_head, 1, 1]) + values = mx.tile(values, [1, self.config.n_shared_head, 1, 1]) + output = mx.fast.scaled_dot_product_attention( queries, keys,