Mirror of https://github.com/ml-explore/mlx.git, synced 2025-06-24 09:21:16 +08:00
use sdpa and exportable functions in transformer multi head attention (#1760)
This commit is contained in:
parent c7b0300af5
commit 657f466402
@@ -82,21 +82,15 @@ class MultiHeadAttention(Module):
         values = self.value_proj(values)
 
         num_heads = self.num_heads
-        B, L, D = queries.shape
-        _, S, _ = keys.shape
-        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
-        keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
-        values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
-
-        # Dimensions are [batch x num heads x sequence x hidden dim]
+        queries = mx.unflatten(queries, -1, (num_heads, -1)).transpose(0, 2, 1, 3)
+        keys = mx.unflatten(keys, -1, (num_heads, -1)).transpose(0, 2, 1, 3)
+        values = mx.unflatten(values, -1, (num_heads, -1)).transpose(0, 2, 1, 3)
         scale = math.sqrt(1 / queries.shape[-1])
-        scores = (queries * scale) @ keys
-        if mask is not None:
-            scores = scores + mask.astype(scores.dtype)
-        scores = mx.softmax(scores, axis=-1)
-        values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
-
-        return self.out_proj(values_hat)
+        output = mx.fast.scaled_dot_product_attention(
+            queries, keys, values, scale=scale, mask=mask
+        )
+        output = output.transpose(0, 2, 1, 3).flatten(-2, -1)
+        return self.out_proj(output)
 
     @staticmethod
     def create_additive_causal_mask(N: int, dtype: mx.Dtype = mx.float32):
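For context (not part of the commit), here is a minimal sketch checking that the fused mx.fast.scaled_dot_product_attention call matches the reshape/softmax path it replaces; the batch, sequence, and head sizes below are illustrative assumptions, not values from the commit.

import math
import mlx.core as mx

B, L, S, D, num_heads = 2, 5, 5, 32, 4  # illustrative sizes only

queries = mx.random.normal(shape=(B, L, D))
keys = mx.random.normal(shape=(B, S, D))
values = mx.random.normal(shape=(B, S, D))
scale = math.sqrt(1 / (D // num_heads))

# Old path: explicit reshape with captured shapes, matmul, and softmax.
q = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
k = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
v = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
scores = mx.softmax((q * scale) @ k, axis=-1)
ref = (scores @ v).transpose(0, 2, 1, 3).reshape(B, L, -1)

# New path: shape-agnostic unflatten/flatten around the fused SDPA kernel.
q2 = mx.unflatten(queries, -1, (num_heads, -1)).transpose(0, 2, 1, 3)
k2 = mx.unflatten(keys, -1, (num_heads, -1)).transpose(0, 2, 1, 3)
v2 = mx.unflatten(values, -1, (num_heads, -1)).transpose(0, 2, 1, 3)
out = mx.fast.scaled_dot_product_attention(q2, k2, v2, scale=scale, mask=None)
out = out.transpose(0, 2, 1, 3).flatten(-2, -1)

print(mx.allclose(ref, out, atol=1e-4))  # expected: True, up to numerics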
@@ -1835,6 +1835,12 @@ class TestLayers(mlx_tests.MLXTestCase):
         self.assertFalse(mx.any(mx.isnan(mask)))
         self.assertTrue(mask[0, -1].item() < 0)
 
+    def test_attention(self):
+        attn = nn.MultiHeadAttention(32, 4)
+        x = mx.random.normal(shape=(2, 5, 32))
+        out = attn(x, x, x)
+        self.assertEqual(out.shape, x.shape)
+
 
 if __name__ == "__main__":
     unittest.main()
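The "exportable functions" half of the commit title comes from replacing reshape(B, L, num_heads, -1), which bakes concrete Python shapes into the graph, with mx.unflatten/flatten, which depend only on the number of heads. Below is a hedged sketch of what that enables; it assumes MLX's mx.export_function / mx.import_function API with a shapeless export, and the file name attention.mlxfn and the forward wrapper are illustrative choices, not part of the commit.

import mlx.core as mx
import mlx.nn as nn

attn = nn.MultiHeadAttention(32, 4)

def forward(x):
    # Self-attention forward pass built on the updated layer.
    return attn(x, x, x)

example = mx.random.normal(shape=(2, 5, 32))
# shapeless=True traces the graph without baking in the example's batch or
# sequence length (assumed export API; availability depends on MLX version).
mx.export_function("attention.mlxfn", forward, example, shapeless=True)

imported = mx.import_function("attention.mlxfn")
(out,) = imported(mx.random.normal(shape=(4, 16, 32)))  # different B and L
print(out.shape)  # expected: (4, 16, 32)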