From af705590ac9335105a5a026de4fc68ee6e747a9d Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Mon, 5 May 2025 13:13:03 -0700
Subject: [PATCH] fix batched vector sdpa (#2152)

---
 mlx/backend/metal/kernels/sdpa_vector.h      |  12 +-
 .../metal/scaled_dot_product_attention.cpp   | 103 ++++++++++--------
 python/tests/test_fast_sdpa.py               |  40 +++++++
 3 files changed, 105 insertions(+), 50 deletions(-)

diff --git a/mlx/backend/metal/kernels/sdpa_vector.h b/mlx/backend/metal/kernels/sdpa_vector.h
index c4c0f6456..8258e9c14 100644
--- a/mlx/backend/metal/kernels/sdpa_vector.h
+++ b/mlx/backend/metal/kernels/sdpa_vector.h
@@ -56,9 +56,9 @@ template
   const int head_idx = tid.x;
   const int q_seq_idx = tid.y;
   const int kv_head_idx = head_idx / gqa_factor;
-  const int o_offset = tpg.x * q_seq_idx + head_idx;
+  const int o_offset = head_idx * tpg.y + q_seq_idx;
   const int q_offset =
-      query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx;
+      query_transposed ? tpg.x * q_seq_idx + head_idx : o_offset;
   queries += q_offset * D + simd_lid * qk_per_thread;
   keys += kv_head_idx * k_head_stride + simd_gid * k_seq_stride +
       simd_lid * qk_per_thread;
@@ -213,9 +213,9 @@ template
   const int block_idx = tid.z;
   const int head_idx = tid.x;
   const int q_seq_idx = tid.y;
-  const int o_offset = tpg.x * q_seq_idx + head_idx;
+  const int o_offset = head_idx * tpg.y + q_seq_idx;
   const int q_offset =
-      query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx;
+      query_transposed ? tpg.x * q_seq_idx + head_idx : o_offset;
   const int kv_head_idx = head_idx / gqa_factor;
 
   queries += q_offset * D + simd_lid * qk_per_thread;
@@ -358,8 +358,8 @@ template
   // Adjust positions
   const int head_idx = tid.x;
   const int q_seq_idx = tid.y;
-  const int n_heads = tpg.x;
-  const int q_offset = n_heads * q_seq_idx + head_idx;
+  const int q_offset = head_idx * tpg.y + q_seq_idx;
+  ;
   partials += q_offset * blocks * D + simd_gid * D + simd_lid * elem_per_thread;
   sums += q_offset * blocks;
   maxs += q_offset * blocks;
diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp
index 845962d01..d75e6d87d 100644
--- a/mlx/backend/metal/scaled_dot_product_attention.cpp
+++ b/mlx/backend/metal/scaled_dot_product_attention.cpp
@@ -154,9 +154,9 @@ void sdpa_vector(
   int gqa_factor = q.shape(1) / k.shape(1);
   int N = k.shape(2);
   int B = q.shape(0) * q.shape(1);
-  size_t k_head_stride = k.strides()[1];
+  size_t k_head_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1);
   size_t k_seq_stride = k.strides()[2];
-  size_t v_head_stride = v.strides()[1];
+  size_t v_head_stride = v.shape(1) == 1 ? v.strides(0) : v.strides(1);
   size_t v_seq_stride = v.strides()[2];
 
   MTL::Size group_dims(1024, 1, 1);
@@ -199,11 +199,10 @@ void sdpa_vector(
   if (has_mask) {
     auto& m = *mask;
     compute_encoder.set_input_array(m, 11 + float_mask);
-    auto nd = m.ndim();
-    int32_t kv_seq_stride =
-        nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0;
-    int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0;
-    int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
+    int32_t kv_seq_stride = m.shape(3) > 1 ? m.strides(3) : 0;
+    int32_t q_seq_stride = m.shape(2) > 1 ? m.strides(2) : 0;
+    int32_t head_stride =
+        m.shape(1) > 1 ? m.strides(1) : (m.shape(0) > 1 ? m.strides(0) : 0);
     compute_encoder.set_bytes(kv_seq_stride, 13);
     compute_encoder.set_bytes(q_seq_stride, 14);
     compute_encoder.set_bytes(head_stride, 15);
@@ -238,9 +237,10 @@ void sdpa_vector_2pass(
   int N = k.shape(2);
   int blocks = 32;
   int B = q.shape(0) * q.shape(1);
-  size_t k_head_stride = k.strides()[1];
+
+  size_t k_head_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1);
   size_t k_seq_stride = k.strides()[2];
-  size_t v_head_stride = v.strides()[1];
+  size_t v_head_stride = v.shape(1) == 1 ? v.strides(0) : v.strides(1);
   size_t v_seq_stride = v.strides()[2];
   MTL::Size group_dims(8 * 32, 1, 1);
   MTL::Size grid_dims(B, q.shape(2), blocks);
@@ -302,11 +302,10 @@ void sdpa_vector_2pass(
   if (has_mask) {
     auto& m = *mask;
     compute_encoder.set_input_array(m, 13 + float_mask);
-    auto nd = m.ndim();
-    int32_t kv_seq_stride =
-        nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0;
-    int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0;
-    int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
+    int32_t kv_seq_stride = m.shape(3) > 1 ? m.strides(3) : 0;
+    int32_t q_seq_stride = m.shape(2) > 1 ? m.strides(2) : 0;
+    int32_t head_stride =
+        m.shape(1) > 1 ? m.strides(1) : (m.shape(0) > 1 ? m.strides(0) : 0);
     compute_encoder.set_bytes(kv_seq_stride, 15);
     compute_encoder.set_bytes(q_seq_stride, 16);
     compute_encoder.set_bytes(head_stride, 17);
@@ -368,18 +367,6 @@ void ScaledDotProductAttention::eval_gpu(
     }
   };
 
-  // Checks if arr is row contiguous or the sequence and head dimension are
-  // transposed
-  auto is_contiguous_or_head_seq_transposed = [](const array& arr) {
-    if (arr.flags().row_contiguous) {
-      return true;
-    }
-    auto& strides = arr.strides();
-    auto& shape = arr.shape();
-    return (strides[3] == 1) && (strides[2] == shape[3] * shape[1]) &&
-        (strides[1] == shape[3]) && (strides[0] == strides[2] * shape[2]);
-  };
-
   // Checks that the headdim dimension has stride 1.
   auto is_matrix_contiguous = [](const array& arr) {
     return arr.strides(-1) == 1;
@@ -387,30 +374,58 @@ void ScaledDotProductAttention::eval_gpu(
 
   // We are in vector mode ie single query
   if (q_pre.shape(2) <= 8) {
-    const auto& q = copy_unless(is_contiguous_or_head_seq_transposed, q_pre);
-    const auto& k = copy_unless(is_matrix_contiguous, k_pre);
-    const auto& v = copy_unless(is_matrix_contiguous, v_pre);
+    auto q_copy_unless = [](const array& arr) {
+      if (arr.flags().row_contiguous) {
+        return true;
+      }
+      auto& strides = arr.strides();
+      auto& shape = arr.shape();
+      if (shape[0] == 1 || shape[1] == 1) {
+        // If either the batch or head dimension is a singleton, the other can
+        // be transposed with the sequence dimension
+        auto bidx = shape[0] == 1 ? 1 : 0;
+        return (strides[3] == 1) && (strides[2] == shape[3] * shape[bidx]) &&
+            (strides[bidx] == shape[3]);
+      }
+      return false;
+    };
+
+    auto kv_copy_unless = [](const array& arr) {
+      // keys and values should be copied if:
+      // - the last dimension is not contiguous
+      // - the batch and head dim are not contiguous
+      auto& strides = arr.strides();
+      auto& shape = arr.shape();
+      if (strides.back() != 1) {
+        return false;
+      }
+      if (shape[0] == 1 || shape[1] == 1) {
+        return true;
+      }
+      return (strides[0] == strides[1] * shape[1]);
+    };
+
+    const auto& q = copy_unless(q_copy_unless, q_pre);
+    const auto& k = copy_unless(kv_copy_unless, k_pre);
+    const auto& v = copy_unless(kv_copy_unless, v_pre);
 
     // Donate the query if possible
-    if (q.is_donatable() && (q.shape(2) == 1 || !q.flags().row_contiguous) &&
-        q.size() == o.size()) {
+    if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) {
       o.copy_shared_buffer(q);
     } else {
-      if (o.shape(2) == 1) {
-        o.set_data(allocator::malloc(o.nbytes()));
-      } else {
-        auto strides = o.strides();
-        strides[2] = o.shape(1) * o.shape(3);
-        strides[1] = o.shape(3);
-        auto flags = q.flags();
-        flags.row_contiguous = q.shape(1) == 1;
-        o.set_data(
-            allocator::malloc(o.nbytes()), o.size(), std::move(strides), flags);
-      }
+      o.set_data(allocator::malloc(o.nbytes()));
     }
 
-    auto mask =
-        inputs.size() > 3 ? std::optional{inputs[3]} : std::nullopt;
+    auto mask_copy_unless = [&q](const array& arr) {
+      auto& strides = arr.strides();
+      auto& shape = arr.shape();
+      return arr.flags().row_contiguous || q.shape(0) == 1 || q.shape(1) == 1 ||
+          (strides[0] == strides[1] * shape[1]);
+    };
+
+    auto mask = inputs.size() > 3
+        ? std::optional{copy_unless(mask_copy_unless, inputs[3])}
+        : std::nullopt;
 
     // We route to the 2 pass fused attention if
     // - The device is large and the sequence length long
diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py
index d35a2b1da..8f55d41e3 100644
--- a/python/tests/test_fast_sdpa.py
+++ b/python/tests/test_fast_sdpa.py
@@ -473,6 +473,46 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
         out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
         self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
 
+    def test_sdpa_vector_batched(self):
+        D = 64
+        q = mx.random.normal(shape=(2, 1, 3, D))
+        k = mx.random.normal(shape=(2, 1, 3, D))
+        v = mx.random.normal(shape=(2, 1, 3, D))
+
+        out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0)
+        ref = mlx_ref_attn(q, k, v)
+        self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
+
+        q = mx.random.normal(shape=(2, 4, 3, D))
+        out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0)
+        ref = mlx_ref_attn(q, k, v)
+        self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
+
+        q = mx.random.normal(shape=(2, 3, 4, D)).swapaxes(1, 2)
+        out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0)
+        ref = mlx_ref_attn(q, k, v)
+        self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
+
+        k = mx.random.normal(shape=(2, 3, 1, D)).swapaxes(1, 2)
+        out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0)
+        ref = mlx_ref_attn(q, k, v)
+        self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
+
+        q = mx.random.normal(shape=(2, 4, 3, D))
+        k = mx.random.normal(shape=(2, 3, 2, D)).swapaxes(1, 2)
+        v = mx.random.normal(shape=(2, 2, 3, D))
+        out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0)
+        ref = mlx_ref_attn(q, k, v)
+        self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
+
+        q = mx.random.normal(shape=(2, 4, 3, D))
+        k = mx.random.normal(shape=(2, 1, 3, D))
+        v = mx.random.normal(shape=(2, 1, 3, D))
+        mask = 10 * mx.random.normal(shape=(1, 2, 3, 3)).swapaxes(0, 1)
+        out = mx.fast.scaled_dot_product_attention(q, k, v, mask=mask, scale=1.0)
+        ref = mlx_ref_attn(q, k, v, mask=mask)
+        self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
+
 
 class TestSDPA(mlx_tests.MLXTestCase):
     @property