fix copies in sdpa (#2563)

Author: Awni Hannun
Date: 2025-09-02 11:00:36 -07:00
Committed by: GitHub
Parent: 04cbb4191c
Commit: b61a65e313
2 changed files with 12 additions and 1 deletion


@@ -394,7 +394,7 @@ void ScaledDotProductAttention::eval_gpu(
   // Define some copy functions to ensure the layout of the inputs is as
   // expected.
-  copies.reserve(3);
+  copies.reserve(inputs.size());
   auto copy_unless = [&copies, &s](
       auto predicate, const array& arr) -> const array& {
     if (!predicate(arr)) {
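
The one-line change above is more than cosmetic. Judging from the snippet, copy_unless pushes temporary contiguous copies into the copies vector and returns references into it. reserve(3) covers q, k, and v, but the op can also carry a mask, so a fourth push_back could exceed the reserved capacity, reallocate the vector, and leave every previously returned reference dangling. Reserving inputs.size() up front rules that out. Below is a minimal, self-contained C++ sketch of the hazard; it is not the MLX code, and the Input alias and make_contiguous lambda are illustrative stand-ins:

#include <string>
#include <vector>

// Illustrative stand-in for an input array that may need a contiguous copy.
using Input = std::string;

int main() {
  std::vector<Input> copies;
  copies.reserve(3); // old bound: enough for q, k, v -- not for q, k, v, mask

  // Mimics the copy_unless pattern: push a copy, then return a reference
  // into the vector that owns it.
  auto make_contiguous = [&copies](const Input& in) -> const Input& {
    copies.push_back(in + "_copy");
    return copies.back();
  };

  const Input& q = make_contiguous("q");
  const Input& k = make_contiguous("k");
  const Input& v = make_contiguous("v");
  // The 4th push_back exceeds the reserved capacity, so the vector may
  // reallocate; q, k, and v then dangle (undefined behavior if read later).
  const Input& m = make_contiguous("mask");

  // The fix reserves capacity for every input up front, as in
  // copies.reserve(inputs.size()), so no push_back can trigger a
  // reallocation and every reference handed out stays valid.
  (void)q; (void)k; (void)v; (void)m;
  return 0;
}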


@@ -619,6 +619,17 @@ class TestSDPA(mlx_tests.MLXTestCase):
         out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
         self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
 
+    def test_sdpa_noncontiguous_inputs(self):
+        mask = mx.ones(shape=(4, 1, 7, 7), dtype=mx.bool_)
+        mx.random.seed(0)
+        q = mx.random.normal(shape=(4, 7, 32, 64)).swapaxes(1, 2)
+        k = mx.random.normal(shape=(4, 7, 8, 64)).swapaxes(1, 2)
+        v = mx.random.normal(shape=(4, 7, 8, 64)).swapaxes(1, 2)
+
+        out = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
+        ref = mlx_ref_attn(q, k, v, scale=1.0, mask=mask)
+        self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
+
     def test_sdpa_promote_mask(self):
         mask = mx.array(2.0, mx.bfloat16)
         D = 64
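
The new test drives the fixed path from Python: swapaxes(1, 2) returns a strided view rather than a contiguous array, so q, k, and v can each force a copy inside eval_gpu, and together with the boolean mask the op carries four inputs, one more than the old reserve(3) assumed. The output is checked against mlx_ref_attn (the reference attention used elsewhere in this test file) at 1e-4 absolute and relative tolerance.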