mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-16 02:08:55 +08:00
Make attention faster for a some models (#574)
* make attention faster for a couple models * remove unused generation flags * add comment on lora * include text files as well
This commit is contained in:
@@ -39,8 +39,6 @@ class MixtralAttention(nn.Module):
|
||||
self.num_key_value_heads = args.num_key_value_heads
|
||||
self.rope_theta = args.rope_theta
|
||||
|
||||
self.repeats = self.num_heads // self.num_key_value_heads
|
||||
|
||||
self.scale = self.head_dim**-0.5
|
||||
|
||||
self.q_proj = nn.Linear(
|
||||
@@ -79,10 +77,6 @@ class MixtralAttention(nn.Module):
|
||||
0, 2, 1, 3
|
||||
)
|
||||
|
||||
if self.repeats > 1:
|
||||
keys = mx.repeat(keys, self.repeats, axis=1)
|
||||
values = mx.repeat(values, self.repeats, axis=1)
|
||||
|
||||
if cache is not None:
|
||||
key_cache, value_cache = cache
|
||||
queries = self.rope(queries, offset=key_cache.shape[2])
|
||||
@@ -93,11 +87,10 @@ class MixtralAttention(nn.Module):
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
|
||||
if mask is not None:
|
||||
scores += mask
|
||||
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
|
||||
output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
output = mx.fast.scaled_dot_product_attention(
|
||||
queries, keys, values, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output), (keys, values)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user