From 836f019d3bc8a1c6d0e7ab0e7a5f833d9df4db95 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 9 Sep 2025 12:52:56 -0700 Subject: [PATCH] fix matrix sdpa --- .../metal/kernels/steel/attn/kernels/steel_attention.h | 7 +++---- python/tests/test_fast_sdpa.py | 2 +- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h index be5067b88..7397039b5 100644 --- a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h @@ -171,7 +171,7 @@ template < VBlockLoader loader_v( V, params->V_strides[2], Vs, simd_group_id, simd_lane_id); - TransformScale ts(static_cast(params->scale * 1.44269504089)); + TransformScale ts(static_cast(params->scale * M_LOG2E_F)); // Prepare MMA tiles constexpr short kFragSize = 8; // MMAFrag size @@ -234,11 +234,10 @@ template < max_score[i] = Limits::finite_min; } - // TODO condition here is wrong if (has_sinks) { STEEL_PRAGMA_UNROLL for (short i = 0; i < kRowsPT; ++i) { - max_score[i] = static_cast(sinks[tidl.y]); + max_score[i] = M_LOG2E_F * static_cast(sinks[tidl.y]); sum_score[i] = 1; } } @@ -361,7 +360,7 @@ template < Stile.frag_at(i, j)[jj] = mfrag[jj] ? Stile.frag_at(i, j)[jj] : neg_inf; } else { - Stile.frag_at(i, j)[jj] += 1.44269504089 * selem_t(mfrag[jj]); + Stile.frag_at(i, j)[jj] += M_LOG2E_F * selem_t(mfrag[jj]); } } } diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py index 80377e09c..52ecc9be0 100644 --- a/python/tests/test_fast_sdpa.py +++ b/python/tests/test_fast_sdpa.py @@ -731,7 +731,7 @@ class TestSDPA(mlx_tests.MLXTestCase): mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, sinks=sinks) for T_kv in [128, 4096]: - for T_q in [1]: # , 128]: + for T_q in [1, 128]: for N_kv in [2, 8]: q = mx.random.normal(shape=(B, N_q, T_q, D)) k = mx.random.normal(shape=(B, N_kv, T_kv, D))