From e613d0eaf017ea37071542c95d90ec8fda698a46 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 4 Mar 2025 10:59:04 -0800 Subject: [PATCH] SDPA support for small batch (over sequence) queries (#1922) * batch query sdpa * batch sdpa for query --- mlx/backend/metal/kernels/sdpa_vector.h | 59 ++++++++++----- mlx/backend/metal/rope.cpp | 12 ++- .../metal/scaled_dot_product_attention.cpp | 75 +++++++++++++------ mlx/fast.cpp | 3 +- python/tests/test_fast_sdpa.py | 55 ++++++++++++++ 5 files changed, 159 insertions(+), 45 deletions(-) diff --git a/mlx/backend/metal/kernels/sdpa_vector.h b/mlx/backend/metal/kernels/sdpa_vector.h index f5fe88f4a..1c3d23fc4 100644 --- a/mlx/backend/metal/kernels/sdpa_vector.h +++ b/mlx/backend/metal/kernels/sdpa_vector.h @@ -5,6 +5,7 @@ using namespace metal; constant bool has_mask [[function_constant(20)]]; +constant bool query_transposed [[function_constant(21)]]; template [[kernel]] void sdpa_vector( @@ -18,9 +19,11 @@ template const constant size_t& v_stride, const constant float& scale, const device bool* mask [[function_constant(has_mask)]], - const constant int& mask_seq_stride [[function_constant(has_mask)]], + const constant int& mask_kv_seq_stride [[function_constant(has_mask)]], + const constant int& mask_q_seq_stride [[function_constant(has_mask)]], const constant int& mask_head_stride [[function_constant(has_mask)]], uint3 tid [[threadgroup_position_in_grid]], + uint3 tpg [[threadgroups_per_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { constexpr int BN = 32; @@ -41,15 +44,21 @@ template threadgroup U sum_exp_scores[BN]; // Adjust positions - const int head_idx = tid.y; + const int head_idx = tid.x; + const int q_seq_idx = tid.y; const int kv_head_idx = head_idx / gqa_factor; - queries += head_idx * D + simd_lid * qk_per_thread; + const int o_offset = tpg.x * q_seq_idx + head_idx; + const int q_offset = + query_transposed ? 
o_offset : head_idx * tpg.y + q_seq_idx; + queries += q_offset * D + simd_lid * qk_per_thread; keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * qk_per_thread; values += kv_head_idx * v_stride + simd_gid * V + simd_lid * v_per_thread; if (has_mask) { - mask += head_idx * mask_head_stride + simd_gid * mask_seq_stride; + mask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride + + q_seq_idx * mask_q_seq_stride; } - out += head_idx * V + simd_gid * v_per_thread; + + out += o_offset * V + simd_gid * v_per_thread; // Read the query and 0 the output accumulator for (int i = 0; i < qk_per_thread; i++) { @@ -95,7 +104,7 @@ template keys += inner_k_stride; values += inner_v_stride; if (has_mask) { - mask += BN * mask_seq_stride; + mask += BN * mask_kv_seq_stride; } } @@ -142,9 +151,11 @@ template const constant size_t& v_stride, const constant float& scale, const device bool* mask [[function_constant(has_mask)]], - const constant int& mask_seq_stride [[function_constant(has_mask)]], + const constant int& mask_kv_seq_stride [[function_constant(has_mask)]], + const constant int& mask_q_seq_stride [[function_constant(has_mask)]], const constant int& mask_head_stride [[function_constant(has_mask)]], uint3 tid [[threadgroup_position_in_grid]], + uint3 tpg [[threadgroups_per_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { constexpr int BN = 8; @@ -167,20 +178,26 @@ template // Adjust positions const int block_idx = tid.z; - const int head_idx = tid.y; + const int head_idx = tid.x; + const int q_seq_idx = tid.y; + const int o_offset = tpg.x * q_seq_idx + head_idx; + const int q_offset = + query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx; const int kv_head_idx = head_idx / gqa_factor; - queries += head_idx * D + simd_lid * qk_per_thread; + + queries += q_offset * D + simd_lid * qk_per_thread; keys += kv_head_idx * k_stride + (block_idx * BN + simd_gid) * D + simd_lid * qk_per_thread; values += kv_head_idx * v_stride + (block_idx * BN + simd_gid) * V + simd_lid * v_per_thread; - out += head_idx * blocks * V + block_idx * V + simd_lid * v_per_thread; + out += o_offset * blocks * V + block_idx * V + simd_lid * v_per_thread; if (has_mask) { mask += head_idx * mask_head_stride + - (block_idx * BN + simd_gid) * mask_seq_stride; + (block_idx * BN + simd_gid) * mask_kv_seq_stride + + q_seq_idx * mask_q_seq_stride; } - sums += head_idx * blocks + block_idx; - maxs += head_idx * blocks + block_idx; + sums += o_offset * blocks + block_idx; + maxs += o_offset * blocks + block_idx; // Read the query and 0 the output accumulator for (int i = 0; i < qk_per_thread; i++) { @@ -226,7 +243,7 @@ template keys += blocks * inner_k_stride; values += blocks * inner_v_stride; if (has_mask) { - mask += BN * blocks * mask_seq_stride; + mask += BN * blocks * mask_kv_seq_stride; } } @@ -275,6 +292,7 @@ template const device float* maxs [[buffer(2)]], device T* out [[buffer(3)]], uint3 tid [[threadgroup_position_in_grid]], + uint3 tpg [[threadgroups_per_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { constexpr int BN = 32; @@ -288,11 +306,14 @@ template threadgroup U outputs[BN * BD]; // Adjust positions - const int head_idx = tid.y; - partials += head_idx * blocks * D + simd_gid * D + simd_lid * elem_per_thread; - sums += head_idx * blocks; - maxs += head_idx * blocks; - out += head_idx * D + simd_gid * elem_per_thread; + const int head_idx = tid.x; + const int q_seq_idx = tid.y; + const int 
n_heads = tpg.x; + const int q_offset = n_heads * q_seq_idx + head_idx; + partials += q_offset * blocks * D + simd_gid * D + simd_lid * elem_per_thread; + sums += q_offset * blocks; + maxs += q_offset * blocks; + out += q_offset * D + simd_gid * elem_per_thread; // First everybody reads the max and sum_exp U max_score = maxs[simd_lid]; diff --git a/mlx/backend/metal/rope.cpp b/mlx/backend/metal/rope.cpp index 1ca3597e5..c6da9278a 100644 --- a/mlx/backend/metal/rope.cpp +++ b/mlx/backend/metal/rope.cpp @@ -25,6 +25,10 @@ void RoPE::eval_gpu( size_t out_strides[3]; bool donated = false; int ndim = in.ndim(); + int dispatch_ndim = in.ndim(); + while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) { + dispatch_ndim--; + } size_t mat_size = in.shape(-2) * in.shape(-1); if (dims_ < in.shape(-1)) { donated = true; @@ -44,12 +48,12 @@ void RoPE::eval_gpu( strides[0] = mat_size; strides[1] = in.strides()[ndim - 2]; strides[2] = in.strides()[ndim - 1]; - } else if (ndim == 3) { + } else if (dispatch_ndim == 3) { // Handle non-contiguous 3D inputs out.set_data(allocator::malloc_or_wait(out.nbytes())); - strides[0] = in.strides()[0]; - strides[1] = in.strides()[1]; - strides[2] = in.strides()[2]; + strides[0] = in.strides()[ndim - 3]; + strides[1] = in.strides()[ndim - 2]; + strides[2] = in.strides()[ndim - 1]; } else { // Copy non-contiguous > 3D inputs into the output and treat // input as donated diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index 47bf7f22a..a349fd031 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -134,14 +134,17 @@ void sdpa_vector( size_t k_stride = k.strides()[1]; size_t v_stride = v.strides()[1]; MTL::Size group_dims(1024, 1, 1); - MTL::Size grid_dims(1, B, 1); + MTL::Size grid_dims(B, q.shape(2), 1); bool has_mask = mask.has_value(); + bool query_transposed = !q.flags().row_contiguous; metal::MTLFCList func_consts = { {&has_mask, MTL::DataType::DataTypeBool, 20}, + {&query_transposed, MTL::DataType::DataTypeBool, 21}, }; std::string hash_name = kname; hash_name += has_mask ? "_mask" : "_nomask"; + hash_name += query_transposed ? "_qt" : "_qnt"; // Get the kernel auto& compute_encoder = d.get_command_encoder(s.index); @@ -161,10 +164,14 @@ void sdpa_vector( if (has_mask) { auto& m = *mask; compute_encoder.set_input_array(m, 9); - int32_t seq_stride = m.ndim() >= 1 ? m.strides().back() : 0; - int32_t head_stride = m.ndim() >= 3 ? *(m.strides().end() - 3) : 0; - compute_encoder.set_bytes(seq_stride, 10); - compute_encoder.set_bytes(head_stride, 11); + auto nd = m.ndim(); + int32_t kv_seq_stride = + nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0; + int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0; + int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? 
m.strides()[nd - 3] : 0; + compute_encoder.set_bytes(kv_seq_stride, 10); + compute_encoder.set_bytes(q_seq_stride, 11); + compute_encoder.set_bytes(head_stride, 12); } // Launch @@ -198,7 +205,7 @@ void sdpa_vector_2pass( auto k_stride = k.strides()[1]; auto v_stride = v.strides()[1]; MTL::Size group_dims(8 * 32, 1, 1); - MTL::Size grid_dims(1, B, blocks); + MTL::Size grid_dims(B, q.shape(2), blocks); // Allocate the intermediates Shape intermediate_shape; @@ -219,11 +226,14 @@ void sdpa_vector_2pass( d.add_temporary(maxs, s.index); bool has_mask = mask.has_value(); + bool query_transposed = !q.flags().row_contiguous; metal::MTLFCList func_consts = { {&has_mask, MTL::DataType::DataTypeBool, 20}, + {&query_transposed, MTL::DataType::DataTypeBool, 21}, }; std::string hash_name = kname; hash_name += has_mask ? "_mask" : "_nomask"; + hash_name += query_transposed ? "_qt" : "_qnt"; // Get the kernel auto& compute_encoder = d.get_command_encoder(s.index); @@ -246,10 +256,14 @@ void sdpa_vector_2pass( if (has_mask) { auto& m = *mask; compute_encoder.set_input_array(m, 11); - int32_t seq_stride = m.ndim() >= 1 ? m.strides().back() : 0; - int32_t head_stride = m.ndim() >= 3 ? *(m.strides().end() - 3) : 0; - compute_encoder.set_bytes(seq_stride, 12); - compute_encoder.set_bytes(head_stride, 13); + auto nd = m.ndim(); + int32_t kv_seq_stride = + nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0; + int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0; + int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0; + compute_encoder.set_bytes(kv_seq_stride, 12); + compute_encoder.set_bytes(q_seq_stride, 13); + compute_encoder.set_bytes(head_stride, 14); } // Launch @@ -274,7 +288,7 @@ void sdpa_vector_2pass( // Launch group_dims = MTL::Size(1024, 1, 1); - grid_dims = MTL::Size(1, B, 1); + grid_dims = MTL::Size(B, q.shape(2), 1); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } @@ -301,16 +315,23 @@ void ScaledDotProductAttention::eval_gpu( if (!predicate(arr)) { array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); copy_gpu(arr, arr_copy, CopyType::General, s); - copies.push_back(arr_copy); + copies.push_back(std::move(arr_copy)); return copies.back(); } else { return arr; } }; - // Checks if arr is fully row contiguous - auto is_contiguous = [](const array& arr) { - return arr.flags().row_contiguous; + // Checks if arr is row contiguous or the sequence and head dimension are + // transposed + auto is_contiguous_or_head_seq_transposed = [](const array& arr) { + if (arr.flags().row_contiguous) { + return true; + } + auto& strides = arr.strides(); + auto& shape = arr.shape(); + return (strides[3] == 1) && (strides[2] == shape[3] * shape[1]) && + (strides[1] == shape[3]) && (strides[0] == strides[2] * shape[2]); }; // Returns true if the array is row contiguous except the sequence length @@ -328,18 +349,30 @@ void ScaledDotProductAttention::eval_gpu( }; // We are in vector mode ie single query - if (q_pre.shape(2) == 1) { - const auto& q = copy_unless(is_contiguous, q_pre); - // 1, heads, seq_len, head_dim - // mask [1, query_heads, 1, seq_len] + if (q_pre.shape(2) <= 8) { + const auto& q = copy_unless(is_contiguous_or_head_seq_transposed, q_pre); const auto& k = copy_unless(is_contiguous_except_seq_len, k_pre); const auto& v = copy_unless(is_contiguous_except_seq_len, v_pre); // Donate the query if possible - if (q.is_donatable() && q.size() == o.size()) { + if (q.is_donatable() && (q.shape(2) == 1 || !q.flags().row_contiguous) && + q.size() == 
o.size()) { o.move_shared_buffer(q); } else { - o.set_data(allocator::malloc_or_wait(o.nbytes())); + if (o.shape(2) == 1) { + o.set_data(allocator::malloc_or_wait(o.nbytes())); + } else { + auto strides = o.strides(); + strides[2] = o.shape(1) * o.shape(3); + strides[1] = o.shape(3); + auto flags = q.flags(); + flags.row_contiguous = q.shape(1) == 1; + o.set_data( + allocator::malloc_or_wait(o.nbytes()), + o.size(), + std::move(strides), + flags); + } } auto mask = diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 82e1f0569..8ef6a8469 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -715,7 +715,8 @@ array scaled_dot_product_attention( const bool supports_sdpa_full = query_sequence_length >= threshold && !mask && sdpa_full_supported_head_dim && stream.device == Device::gpu; - const bool supports_sdpa_vector = query_sequence_length == 1 && + const bool supports_sdpa_vector = (query_sequence_length <= 8) && + (query_sequence_length <= k.shape(-2)) && (!mask || mask->dtype() == bool_) && sdpa_vector_supported_head_dim && stream.device == Device::gpu; diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py index 348ba4c88..9baee4fb1 100644 --- a/python/tests/test_fast_sdpa.py +++ b/python/tests/test_fast_sdpa.py @@ -262,6 +262,61 @@ class TestFastSDPA(mlx_tests.MLXTestCase): ) self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) + def test_fast_sdpa_few_query(self): + D = 64 + L = 43 + Lq = 4 + Nq = 8 + Nkv = 1 + scale = 1.0 + mx.random.seed(0) + q = 5e-1 * mx.random.normal(shape=(1, Lq, Nq, D)) + q = q.swapaxes(1, 2) + k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D)) + v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D)) + + masks = [ + mx.array(True), + mx.array([True] * (L - 10) + [False] * 10), + mx.random.uniform(shape=(Nq, 1, L)) > 0.2, + mx.random.uniform(shape=(L, 1, Nq)).T > 0.2, + ] + for m in masks: + ref = mlx_primitives_sdpa(q, k, v, scale, mask=m) + out = mx.fast.scaled_dot_product_attention( + q, + k, + v, + scale=scale, + mask=m, + ) + self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) + + return + L = 4096 + scale = 1.0 + mx.random.seed(0) + q = 5e-1 * mx.random.normal(shape=(1, Nq, Lq, D)) + k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D)) + v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D)) + + masks = [ + mx.array(True), + mx.array([True] * (L - 10) + [False] * 10), + mx.random.uniform(shape=(Nq, 1, L)) > 0.2, + mx.random.uniform(shape=(L, 1, Nq)).T > 0.2, + ] + for m in masks: + ref = mlx_primitives_sdpa(q, k, v, scale, mask=m) + out = mx.fast.scaled_dot_product_attention( + q, + k, + v, + scale=scale, + mask=m, + ) + self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) + @unittest.skip("Different head and value dims is not enabled") def test_fast_sdpa_vector_value_dims(self): D = 192
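
A minimal usage sketch (illustrative, not part of the patch) of the path this change enables: after this commit, mx.fast.scaled_dot_product_attention can dispatch to the fast vector kernels for up to 8 query positions per head (e.g. a few speculative-decoding tokens attending to a longer KV cache), including with a boolean mask. The reference attention below and the concrete shapes are assumptions that loosely mirror the new test_fast_sdpa_few_query test; they are not code from the repository.

import mlx.core as mx

def sdpa_reference(q, k, v, scale, mask=None):
    # Plain softmax attention, used only as a correctness reference.
    # Expand grouped KV heads to match the query heads (GQA).
    n_q_heads, n_kv_heads = q.shape[-3], k.shape[-3]
    if n_q_heads != n_kv_heads:
        k = mx.repeat(k, n_q_heads // n_kv_heads, axis=-3)
        v = mx.repeat(v, n_q_heads // n_kv_heads, axis=-3)
    scores = (scale * q) @ k.swapaxes(-1, -2)
    if mask is not None:
        # Boolean mask: True means "attend", False means "mask out".
        scores = mx.where(mask, scores, -float("inf"))
    return mx.softmax(scores, axis=-1) @ v

B, n_q_heads, n_kv_heads, D = 1, 8, 1, 64
L_kv, L_q = 256, 4  # a handful of query positions against a longer KV cache

q = mx.random.normal(shape=(B, n_q_heads, L_q, D))
k = mx.random.normal(shape=(B, n_kv_heads, L_kv, D))
v = mx.random.normal(shape=(B, n_kv_heads, L_kv, D))
scale = D ** -0.5

# Causal-style boolean mask, broadcastable to (B, n_q_heads, L_q, L_kv):
# query i (one of the last L_q positions) may attend to keys j <= L_kv - L_q + i.
mask = mx.arange(L_kv) <= (L_kv - L_q + mx.arange(L_q)[:, None])

out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
ref = sdpa_reference(q, k, v, scale, mask=mask)
print(mx.allclose(out, ref, atol=1e-4, rtol=1e-4))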