mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-21 20:58:08 +08:00
fix copies in sdpa (#2563)
This commit is contained in:
@@ -619,6 +619,17 @@ class TestSDPA(mlx_tests.MLXTestCase):
|
||||
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
|
||||
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
|
||||
|
||||
def test_sdpa_noncontiguous_inputs(self):
|
||||
mask = mx.ones(shape=(4, 1, 7, 7), dtype=mx.bool_)
|
||||
mx.random.seed(0)
|
||||
q = mx.random.normal(shape=(4, 7, 32, 64)).swapaxes(1, 2)
|
||||
|
||||
k = mx.random.normal(shape=(4, 7, 8, 64)).swapaxes(1, 2)
|
||||
v = mx.random.normal(shape=(4, 7, 8, 64)).swapaxes(1, 2)
|
||||
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
|
||||
ref = mlx_ref_attn(q, k, v, scale=1.0, mask=mask)
|
||||
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
|
||||
|
||||
def test_sdpa_promote_mask(self):
|
||||
mask = mx.array(2.0, mx.bfloat16)
|
||||
D = 64
|
||||
|
Reference in New Issue
Block a user