From 657f466402b554ff6074f52402782dd772dae90f Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Thu, 9 Jan 2025 13:11:55 -0800
Subject: [PATCH] use sdpa and exportable functions in transformer multi head
 attention (#1760)

---
 python/mlx/nn/layers/transformer.py | 22 ++++++++--------------
 python/tests/test_nn.py             |  6 ++++++
 2 files changed, 14 insertions(+), 14 deletions(-)

diff --git a/python/mlx/nn/layers/transformer.py b/python/mlx/nn/layers/transformer.py
index f3df7986c..d856f2554 100644
--- a/python/mlx/nn/layers/transformer.py
+++ b/python/mlx/nn/layers/transformer.py
@@ -82,21 +82,15 @@ class MultiHeadAttention(Module):
         values = self.value_proj(values)
 
         num_heads = self.num_heads
-        B, L, D = queries.shape
-        _, S, _ = keys.shape
-        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
-        keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
-        values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
-
-        # Dimensions are [batch x num heads x sequence x hidden dim]
+        queries = mx.unflatten(queries, -1, (num_heads, -1)).transpose(0, 2, 1, 3)
+        keys = mx.unflatten(keys, -1, (num_heads, -1)).transpose(0, 2, 1, 3)
+        values = mx.unflatten(values, -1, (num_heads, -1)).transpose(0, 2, 1, 3)
         scale = math.sqrt(1 / queries.shape[-1])
-        scores = (queries * scale) @ keys
-        if mask is not None:
-            scores = scores + mask.astype(scores.dtype)
-        scores = mx.softmax(scores, axis=-1)
-        values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
-
-        return self.out_proj(values_hat)
+        output = mx.fast.scaled_dot_product_attention(
+            queries, keys, values, scale=scale, mask=mask
+        )
+        output = output.transpose(0, 2, 1, 3).flatten(-2, -1)
+        return self.out_proj(output)
 
     @staticmethod
     def create_additive_causal_mask(N: int, dtype: mx.Dtype = mx.float32):
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index 7ca8ba272..5aa230175 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -1835,6 +1835,12 @@ class TestLayers(mlx_tests.MLXTestCase):
         self.assertFalse(mx.any(mx.isnan(mask)))
         self.assertTrue(mask[0, -1].item() < 0)
 
+    def test_attention(self):
+        attn = nn.MultiHeadAttention(32, 4)
+        x = mx.random.normal(shape=(2, 5, 32))
+        out = attn(x, x, x)
+        self.assertEqual(out.shape, x.shape)
+
 
 if __name__ == "__main__":
     unittest.main()
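
A quick sanity check of the equivalence this change relies on: for mask-free
inputs, mx.fast.scaled_dot_product_attention should reproduce the manual
scores/softmax path that the hunk above deletes. This is an illustrative
sketch, not part of the patch; the shapes and tolerances below are arbitrary.

    # Not part of the patch: compare the fused kernel against the removed
    # manual attention math on small random inputs (no mask).
    import math

    import mlx.core as mx

    B, H, L, D = 2, 4, 5, 8  # hypothetical small shapes for the comparison
    q = mx.random.normal(shape=(B, H, L, D))
    k = mx.random.normal(shape=(B, H, L, D))
    v = mx.random.normal(shape=(B, H, L, D))
    scale = math.sqrt(1 / D)

    # Reference path: explicit scaled scores followed by a softmax, as in the
    # code this commit removes.
    scores = (q * scale) @ k.transpose(0, 1, 3, 2)
    reference = mx.softmax(scores, axis=-1) @ v

    # Fused path now used by MultiHeadAttention.
    fused = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)

    print(mx.allclose(reference, fused, atol=1e-4, rtol=1e-4))  # expect True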