mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-04 15:04:40 +08:00
fix batched vector sdpa (#2152)
This commit is contained in:
@@ -473,6 +473,46 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
|
||||
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
|
||||
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
|
||||
|
||||
def test_sdpa_vector_batched(self):
    """Check the fused attention op against ``mlx_ref_attn`` on batched inputs.

    Each case draws fresh random ``q``/``k``/``v`` (and once a mask) with a
    leading axis of size 2, runs ``mx.fast.scaled_dot_product_attention``
    with ``scale=1.0`` and requires the result to match the reference
    within ``atol=rtol=1e-4``. Several operands are built through
    ``swapaxes`` (presumably to exercise non-contiguous layouts -- confirm
    against the kernel change this test accompanies).
    """
    last_dim = 64

    def check(queries, keys, values, **ref_kwargs):
        # The fused op always receives an explicit ``mask`` (None when no
        # mask was supplied); the reference is only handed a mask keyword
        # when the caller provided one.
        fused = mx.fast.scaled_dot_product_attention(
            queries, keys, values, mask=ref_kwargs.get("mask"), scale=1.0
        )
        expected = mlx_ref_attn(queries, keys, values, **ref_kwargs)
        self.assertTrue(mx.allclose(expected, fused, atol=1e-4, rtol=1e-4))

    # q, k and v all share the shape (2, 1, 3, last_dim).
    base_shape = (2, 1, 3, last_dim)
    q = mx.random.normal(shape=base_shape)
    k = mx.random.normal(shape=base_shape)
    v = mx.random.normal(shape=base_shape)
    check(q, k, v)

    # Only q changes: axis 1 grows from 1 to 4.
    q = mx.random.normal(shape=(2, 4, 3, last_dim))
    check(q, k, v)

    # q with the same logical shape, produced by swapping axes 1 and 2.
    q = mx.random.normal(shape=(2, 3, 4, last_dim)).swapaxes(1, 2)
    check(q, k, v)

    # k is now also produced through an axis swap.
    k = mx.random.normal(shape=(2, 3, 1, last_dim)).swapaxes(1, 2)
    check(q, k, v)

    # k and v have size 2 on axis 1 while q has size 4; k comes from a swap.
    q = mx.random.normal(shape=(2, 4, 3, last_dim))
    k = mx.random.normal(shape=(2, 3, 2, last_dim)).swapaxes(1, 2)
    v = mx.random.normal(shape=(2, 2, 3, last_dim))
    check(q, k, v)

    # Additive mask of logical shape (2, 1, 3, 3), obtained by swapping the
    # two leading axes and scaled by 10.
    q = mx.random.normal(shape=(2, 4, 3, last_dim))
    k = mx.random.normal(shape=(2, 1, 3, last_dim))
    v = mx.random.normal(shape=(2, 1, 3, last_dim))
    mask = 10 * mx.random.normal(shape=(1, 2, 3, 3)).swapaxes(0, 1)
    check(q, k, v, mask=mask)
|
||||
|
||||
|
||||
class TestSDPA(mlx_tests.MLXTestCase):
|
||||
@property
|
||||
|
Reference in New Issue
Block a user