From 0fe25eb58822ac40ac8d377ffe4603b00629b23b Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 9 Sep 2025 12:12:09 -0700 Subject: [PATCH] fix 2 pass --- mlx/backend/metal/kernels/sdpa_vector.h | 2 +- python/tests/test_fast_sdpa.py | 23 ++++++++++++----------- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/mlx/backend/metal/kernels/sdpa_vector.h b/mlx/backend/metal/kernels/sdpa_vector.h index 6b2926cb5..159d268a6 100644 --- a/mlx/backend/metal/kernels/sdpa_vector.h +++ b/mlx/backend/metal/kernels/sdpa_vector.h @@ -263,7 +263,7 @@ template U max_score = -INFINITY; U sum_exp_score = 0; - if (has_sinks && simd_gid == 0) { + if (has_sinks && block_idx == 0 && simd_gid == 0) { int q_head_idx = q_batch_head_idx % num_q_heads; max_score = static_cast(sinks[q_head_idx]); sum_exp_score = 1; diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py index 6d2ae908f..80377e09c 100644 --- a/python/tests/test_fast_sdpa.py +++ b/python/tests/test_fast_sdpa.py @@ -730,18 +730,19 @@ class TestSDPA(mlx_tests.MLXTestCase): with self.assertRaises(ValueError): mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, sinks=sinks) - for T_q in [1, 128]: - for N_kv in [2, 8]: - q = mx.random.normal(shape=(B, N_q, T_q, D)) - k = mx.random.normal(shape=(B, N_kv, T_kv, D)) - v = mx.random.normal(shape=(B, N_kv, T_kv, D)) - sinks = 10 * mx.random.normal(shape=(N_q,)) + for T_kv in [128, 4096]: + for T_q in [1]: # , 128]: + for N_kv in [2, 8]: + q = mx.random.normal(shape=(B, N_q, T_q, D)) + k = mx.random.normal(shape=(B, N_kv, T_kv, D)) + v = mx.random.normal(shape=(B, N_kv, T_kv, D)) + sinks = 10 * mx.random.normal(shape=(N_q,)) - expected = mlx_ref_attn(q, k, v, scale, sinks=sinks) - out = mx.fast.scaled_dot_product_attention( - q, k, v, scale=scale, sinks=sinks - ) - self.assertTrue(mx.allclose(out, expected, atol=1e-5)) + expected = mlx_ref_attn(q, k, v, scale, sinks=sinks) + out = mx.fast.scaled_dot_product_attention( + q, k, v, scale=scale, sinks=sinks + ) + self.assertTrue(mx.allclose(out, expected, atol=1e-5)) if __name__ == "__main__":