[mlx-lm] Use sdpa in llama / mistral model (#515)

* use sdpa

* update a few more models

* version

* fix stablelm type
Awni Hannun authored on 2024-03-07 17:41:23 -08:00, committed by GitHub
parent 7cdd1b69ac
commit 8b05bb6d18
7 changed files with 25 additions and 59 deletions


@@ -49,8 +49,6 @@ class Attention(nn.Module):
         self.n_kv_heads = n_kv_heads = args.num_key_value_heads
         self.head_dim = head_dim = args.head_dim
-        self.repeats = n_heads // n_kv_heads
         self.scale = head_dim**-0.5
         self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
@@ -79,10 +77,6 @@ class Attention(nn.Module):
         keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
 
-        if self.repeats > 1:
-            keys = mx.repeat(keys, self.repeats, axis=1)
-            values = mx.repeat(values, self.repeats, axis=1)
-
         if cache is not None:
             key_cache, value_cache = cache
             queries = self.rope(queries, offset=key_cache.shape[2])
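
Note on the removed mx.repeat block: the keys and values were only tiled up to n_heads because the hand-written matmul below required matching head counts. mx.fast.scaled_dot_product_attention handles grouped key/value heads directly, which is also why the self.repeats attribute in the first hunk goes away. A minimal sketch with made-up shapes (not taken from this repo):

import mlx.core as mx

# Illustrative GQA shapes: 8 query heads share 2 key/value heads.
B, L, n_heads, n_kv_heads, head_dim = 1, 16, 8, 2, 64

queries = mx.random.normal((B, n_heads, L, head_dim))
keys = mx.random.normal((B, n_kv_heads, L, head_dim))
values = mx.random.normal((B, n_kv_heads, L, head_dim))

# No mx.repeat(keys, n_heads // n_kv_heads, axis=1) needed; the fused kernel
# broadcasts the smaller number of key/value heads across the query heads.
out = mx.fast.scaled_dot_product_attention(
    queries, keys, values, scale=head_dim**-0.5, mask=None
)
print(out.shape)  # (1, 8, 16, 64)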
@@ -93,11 +87,11 @@ class Attention(nn.Module):
             queries = self.rope(queries)
             keys = self.rope(keys)
 
-        scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
-        if mask is not None:
-            scores += mask
-        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
-        output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
+        output = mx.fast.scaled_dot_product_attention(
+            queries, keys, values, scale=self.scale, mask=mask
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.o_proj(output), (keys, values)
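
For reference, a standalone sketch (illustrative shapes and tolerance, not from the diff) of what this hunk changes: the hand-written scale → matmul → mask → softmax → matmul chain and the single fused call compute the same result up to numerical tolerance.

import mlx.core as mx

# Made-up shapes for illustration only.
B, L, n_heads, head_dim = 1, 8, 4, 32
scale = head_dim**-0.5

q = mx.random.normal((B, n_heads, L, head_dim))
k = mx.random.normal((B, n_heads, L, head_dim))
v = mx.random.normal((B, n_heads, L, head_dim))

# Additive causal mask, broadcast over batch and heads.
mask = mx.triu(mx.full((L, L), float("-inf")), k=1)

# Manual attention path (as removed by this commit).
scores = (q * scale) @ k.transpose(0, 1, 3, 2)
scores = scores + mask
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
ref = scores @ v

# Fused replacement.
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)

print(mx.allclose(ref, out, atol=1e-4))  # expected: array(True, dtype=bool)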