use sdpa and exportable functions in transformer multi head attention (#1760)

Awni Hannun
2025-01-09 13:11:55 -08:00
committed by GitHub
parent c7b0300af5
commit 657f466402
2 changed files with 14 additions and 14 deletions


@@ -1835,6 +1835,12 @@ class TestLayers(mlx_tests.MLXTestCase):
         self.assertFalse(mx.any(mx.isnan(mask)))
         self.assertTrue(mask[0, -1].item() < 0)
 
+    def test_attention(self):
+        attn = nn.MultiHeadAttention(32, 4)
+        x = mx.random.normal(shape=(2, 5, 32))
+        out = attn(x, x, x)
+        self.assertEqual(out.shape, x.shape)
+
 
 if __name__ == "__main__":
     unittest.main()
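
For context, the "sdpa" in the commit title refers to MLX's fused scaled-dot-product-attention kernel, `mx.fast.scaled_dot_product_attention`. Below is a minimal sketch of calling that kernel directly; the shapes and scale are illustrative assumptions, not values taken from this diff.

```python
import math
import mlx.core as mx

# Illustrative shapes (assumptions, not from this diff):
# batch=2, heads=4, sequence length=5, per-head dim=8.
B, H, L, D = 2, 4, 5, 8
q = mx.random.normal(shape=(B, H, L, D))
k = mx.random.normal(shape=(B, H, L, D))
v = mx.random.normal(shape=(B, H, L, D))

# Fused attention: softmax(q @ k^T * scale) @ v, computed per head.
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0 / math.sqrt(D))
print(out.shape)  # (2, 4, 5, 8)
```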