Change gqa to use repeat instead of concatenate (#443)

This commit is contained in:
Angelos Katharopoulos
2024-02-14 17:40:11 -08:00
committed by GitHub
parent 06ddb8414d
commit f71e965d57
8 changed files with 16 additions and 40 deletions

View File

@@ -86,12 +86,9 @@ class PhiAttention(nn.Module):
B, L, self.num_key_value_heads, self.head_dim
).transpose(0, 2, 1, 3)
def repeat(a):
a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
return a.reshape([B, self.num_heads, L, -1])
if self.repeats > 1:
keys, values = map(repeat, (keys, values))
keys = mx.repeat(keys, self.repeats, axis=1)
values = mx.repeat(values, self.repeats, axis=1)
# Add RoPE to the queries and keys and combine them with the cache
if cache is not None: