Change gqa to use repeat instead of concatenate (#443)

This commit is contained in:
Angelos Katharopoulos
2024-02-14 17:40:11 -08:00
committed by GitHub
parent 06ddb8414d
commit f71e965d57
8 changed files with 16 additions and 40 deletions

View File

@@ -93,11 +93,8 @@ class Attention(nn.Module):
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
def repeat(a):
a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
return a.reshape([B, self.n_heads, L, -1])
keys, values = map(repeat, (keys, values))
keys = mx.repeat(keys, self.repeats, axis=1)
values = mx.repeat(values, self.repeats, axis=1)
if cache is not None:
key_cache, value_cache = cache