mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-15 17:58:54 +08:00
Change gqa to use repeat instead of concatenate (#443)
This commit is contained in:
committed by
GitHub
parent
06ddb8414d
commit
f71e965d57
@@ -86,12 +86,9 @@ class PhiAttention(nn.Module):
|
||||
B, L, self.num_key_value_heads, self.head_dim
|
||||
).transpose(0, 2, 1, 3)
|
||||
|
||||
def repeat(a):
|
||||
a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
|
||||
return a.reshape([B, self.num_heads, L, -1])
|
||||
|
||||
if self.repeats > 1:
|
||||
keys, values = map(repeat, (keys, values))
|
||||
keys = mx.repeat(keys, self.repeats, axis=1)
|
||||
values = mx.repeat(values, self.repeats, axis=1)
|
||||
|
||||
# Add RoPE to the queries and keys and combine them with the cache
|
||||
if cache is not None:
|
||||
|
||||
Reference in New Issue
Block a user