Fix PLaMo model to support Grouped Query Attention (#1037)

Shunta Saito 2024-10-13 07:26:50 +09:00 committed by GitHub
parent d8611dd69f
commit 7612c646f3


@@ -89,6 +89,9 @@ class Attention(nn.Module):
         queries = self.rotary_emb(queries)
         keys = self.rotary_emb(keys)
+        keys = mx.tile(keys, [1, self.config.n_shared_head, 1, 1])
+        values = mx.tile(values, [1, self.config.n_shared_head, 1, 1])
         output = mx.fast.scaled_dot_product_attention(
             queries,
             keys,
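
For context: in grouped query attention, several query heads share each key/value head, so the key/value tensors arrive with fewer heads than the queries. The fix repeats the key/value heads with mx.tile so the head counts match before calling mx.fast.scaled_dot_product_attention. Below is a minimal standalone sketch of that pattern; the batch size, sequence length, and head counts are illustrative assumptions, not values from the PLaMo config.

```python
import mlx.core as mx

# Hypothetical sizes for illustration (not taken from the PLaMo config).
B, L, D = 1, 8, 64          # batch, sequence length, head dimension
n_heads, n_kv_heads = 8, 2  # query heads vs. shared key/value heads
n_shared_head = n_heads // n_kv_heads  # each k/v head serves 4 query heads

queries = mx.random.normal((B, n_heads, L, D))
keys = mx.random.normal((B, n_kv_heads, L, D))
values = mx.random.normal((B, n_kv_heads, L, D))

# Repeat the key/value heads along the head axis so their count matches
# the number of query heads, mirroring the mx.tile calls in the fix.
keys = mx.tile(keys, [1, n_shared_head, 1, 1])
values = mx.tile(values, [1, n_shared_head, 1, 1])

output = mx.fast.scaled_dot_product_attention(
    queries, keys, values, scale=D ** -0.5
)
print(output.shape)  # (1, 8, 8, 64)
```

Note that mx.tile repeats the whole block of heads (k0, k1, k0, k1, ...) rather than interleaving each head with its own copies; this is the head layout the fix assumes for PLaMo's shared key/value heads.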