Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-24 01:17:28 +08:00
Fix PLaMo model to support Grouped Query Attention (#1037)
commit 7612c646f3
parent d8611dd69f
@@ -89,6 +89,9 @@ class Attention(nn.Module):
         queries = self.rotary_emb(queries)
         keys = self.rotary_emb(keys)
 
+        keys = mx.tile(keys, [1, self.config.n_shared_head, 1, 1])
+        values = mx.tile(values, [1, self.config.n_shared_head, 1, 1])
+
         output = mx.fast.scaled_dot_product_attention(
             queries,
             keys,
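For context, the change repeats the shared key/value heads with mx.tile so that their head count matches the query heads before calling mx.fast.scaled_dot_product_attention. Below is a minimal, self-contained sketch of that pattern; the shapes, head counts, and variable names are illustrative assumptions, not PLaMo's actual configuration.

import mlx.core as mx

# Illustrative shapes (assumptions, not PLaMo's real config).
B, L, D = 1, 16, 64            # batch, sequence length, head dimension
n_heads, n_kv_heads = 8, 2     # query heads vs. shared key/value heads
n_shared_head = n_heads // n_kv_heads

queries = mx.random.normal((B, n_heads, L, D))
keys = mx.random.normal((B, n_kv_heads, L, D))
values = mx.random.normal((B, n_kv_heads, L, D))

# Repeat the key/value heads along the head axis so the attention kernel
# sees the same number of heads for queries, keys, and values.
keys = mx.tile(keys, [1, n_shared_head, 1, 1])
values = mx.tile(values, [1, n_shared_head, 1, 1])

output = mx.fast.scaled_dot_product_attention(
    queries, keys, values, scale=D ** -0.5
)
print(output.shape)  # (1, 8, 16, 64)

Note that mx.tile repeats the whole block of key/value heads rather than interleaving them, so the mapping from query heads to shared key/value heads must match how the model orders its query heads.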