Make attention faster for a couple models (#574)

* make attention faster for a couple models

* remove unused generation flags

* add comment on lora

* include text files as well
Awni Hannun committed 2024-03-14 21:35:54 -07:00 (committed by GitHub)
parent 3f3741d229
commit e4b19bb9e1
6 changed files with 35 additions and 56 deletions


@@ -39,8 +39,6 @@ class MixtralAttention(nn.Module):
         self.num_key_value_heads = args.num_key_value_heads
         self.rope_theta = args.rope_theta
-        self.repeats = self.num_heads // self.num_key_value_heads
-
         self.scale = self.head_dim**-0.5
         self.q_proj = nn.Linear(
@@ -79,10 +77,6 @@ class MixtralAttention(nn.Module):
             0, 2, 1, 3
         )
-        if self.repeats > 1:
-            keys = mx.repeat(keys, self.repeats, axis=1)
-            values = mx.repeat(values, self.repeats, axis=1)
-
         if cache is not None:
             key_cache, value_cache = cache
             queries = self.rope(queries, offset=key_cache.shape[2])
@@ -93,11 +87,10 @@ class MixtralAttention(nn.Module):
             queries = self.rope(queries)
             keys = self.rope(keys)

-        scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
-        if mask is not None:
-            scores += mask
-        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
-        output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
+        output = mx.fast.scaled_dot_product_attention(
+            queries, keys, values, scale=self.scale, mask=mask
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.o_proj(output), (keys, values)
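
For reference, the change replaces the unfused attention math (scaled matmul, mask add, softmax, second matmul) with MLX's fused mx.fast.scaled_dot_product_attention, which also broadcasts grouped-query key/value heads internally, so the explicit mx.repeat of keys and values is dropped. Below is a minimal standalone sketch of the fused call; the batch, sequence length, and head sizes are illustrative (Mixtral-like GQA head counts), not taken from the commit.

import mlx.core as mx

# Illustrative shapes only: 32 query heads sharing 8 key/value heads.
B, L, n_heads, n_kv_heads, head_dim = 1, 16, 32, 8, 64

queries = mx.random.normal((B, n_heads, L, head_dim))
keys = mx.random.normal((B, n_kv_heads, L, head_dim))
values = mx.random.normal((B, n_kv_heads, L, head_dim))

# The fused kernel broadcasts the 8 KV heads across the 32 query heads,
# which is why the mx.repeat calls are removed in the diff above.
out = mx.fast.scaled_dot_product_attention(
    queries, keys, values, scale=head_dim**-0.5, mask=None
)
print(out.shape)  # (1, 32, 16, 64)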