mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
Change gqa to use repeat instead of concatenate (#443)
This commit is contained in:

committed by
GitHub

parent
06ddb8414d
commit
f71e965d57
@@ -93,11 +93,8 @@ class Attention(nn.Module):
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
def repeat(a):
|
||||
a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
|
||||
return a.reshape([B, self.n_heads, L, -1])
|
||||
|
||||
keys, values = map(repeat, (keys, values))
|
||||
keys = mx.repeat(keys, self.repeats, axis=1)
|
||||
values = mx.repeat(values, self.repeats, axis=1)
|
||||
|
||||
if cache is not None:
|
||||
key_cache, value_cache = cache
|
||||
|
Reference in New Issue
Block a user