diff --git a/mlx/backend/metal/kernels/sdpa_vector.h b/mlx/backend/metal/kernels/sdpa_vector.h
index 1c3d23fc4..88a109b88 100644
--- a/mlx/backend/metal/kernels/sdpa_vector.h
+++ b/mlx/backend/metal/kernels/sdpa_vector.h
@@ -15,8 +15,10 @@ template
     device T* out [[buffer(3)]],
     const constant int& gqa_factor,
     const constant int& N,
-    const constant size_t& k_stride,
-    const constant size_t& v_stride,
+    const constant size_t& k_head_stride,
+    const constant size_t& k_seq_stride,
+    const constant size_t& v_head_stride,
+    const constant size_t& v_seq_stride,
     const constant float& scale,
     const device bool* mask [[function_constant(has_mask)]],
     const constant int& mask_kv_seq_stride [[function_constant(has_mask)]],
@@ -30,8 +32,8 @@ template
   constexpr int BD = 32;
   constexpr int qk_per_thread = D / BD;
   constexpr int v_per_thread = V / BD;
-  constexpr int inner_k_stride = BN * D;
-  constexpr int inner_v_stride = BN * V;
+  int inner_k_stride = BN * int(k_seq_stride);
+  int inner_v_stride = BN * int(v_seq_stride);

   typedef float U;

@@ -51,8 +53,10 @@ template
   const int q_offset =
       query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx;
   queries += q_offset * D + simd_lid * qk_per_thread;
-  keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * qk_per_thread;
-  values += kv_head_idx * v_stride + simd_gid * V + simd_lid * v_per_thread;
+  keys += kv_head_idx * k_head_stride + simd_gid * k_seq_stride +
+      simd_lid * qk_per_thread;
+  values += kv_head_idx * v_head_stride + simd_gid * v_seq_stride +
+      simd_lid * v_per_thread;
   if (has_mask) {
     mask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
         q_seq_idx * mask_q_seq_stride;
@@ -147,8 +151,10 @@ template
     device float* maxs [[buffer(5)]],
     const constant int& gqa_factor,
     const constant int& N,
-    const constant size_t& k_stride,
-    const constant size_t& v_stride,
+    const constant size_t& k_head_stride,
+    const constant size_t& k_seq_stride,
+    const constant size_t& v_head_stride,
+    const constant size_t& v_seq_stride,
     const constant float& scale,
     const device bool* mask [[function_constant(has_mask)]],
     const constant int& mask_kv_seq_stride [[function_constant(has_mask)]],
@@ -162,8 +168,8 @@ template
   constexpr int BD = 32;
   constexpr int qk_per_thread = D / BD;
   constexpr int v_per_thread = V / BD;
-  constexpr int inner_k_stride = BN * D;
-  constexpr int inner_v_stride = BN * V;
+  int inner_k_stride = BN * int(k_seq_stride);
+  int inner_v_stride = BN * int(v_seq_stride);
   constexpr int blocks = 32;

   typedef float U;
@@ -186,10 +192,10 @@ template
   const int kv_head_idx = head_idx / gqa_factor;

   queries += q_offset * D + simd_lid * qk_per_thread;
-  keys += kv_head_idx * k_stride + (block_idx * BN + simd_gid) * D +
-      simd_lid * qk_per_thread;
-  values += kv_head_idx * v_stride + (block_idx * BN + simd_gid) * V +
-      simd_lid * v_per_thread;
+  keys += kv_head_idx * k_head_stride +
+      (block_idx * BN + simd_gid) * k_seq_stride + simd_lid * qk_per_thread;
+  values += kv_head_idx * v_head_stride +
+      (block_idx * BN + simd_gid) * v_seq_stride + simd_lid * v_per_thread;
   out += o_offset * blocks * V + block_idx * V + simd_lid * v_per_thread;
   if (has_mask) {
     mask += head_idx * mask_head_stride +
diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp
index 938d7afa6..7fbd63022 100644
--- a/mlx/backend/metal/scaled_dot_product_attention.cpp
+++ b/mlx/backend/metal/scaled_dot_product_attention.cpp
@@ -131,8 +131,11 @@ void sdpa_vector(
   int gqa_factor = q.shape(1) / k.shape(1);
   int N = k.shape(2);
   int B = q.shape(0) * q.shape(1);
-  size_t k_stride = k.strides()[1];
-  size_t v_stride = v.strides()[1];
+  size_t k_head_stride = k.strides()[1];
+  size_t k_seq_stride = k.strides()[2];
+  size_t v_head_stride = v.strides()[1];
+  size_t v_seq_stride = v.strides()[2];
+
   MTL::Size group_dims(1024, 1, 1);
   MTL::Size grid_dims(B, q.shape(2), 1);

@@ -158,20 +161,23 @@ void sdpa_vector(
   compute_encoder.set_output_array(out, 3);
   compute_encoder.set_bytes(gqa_factor, 4);
   compute_encoder.set_bytes(N, 5);
-  compute_encoder.set_bytes(k_stride, 6);
-  compute_encoder.set_bytes(v_stride, 7);
-  compute_encoder.set_bytes(scale, 8);
+  compute_encoder.set_bytes(k_head_stride, 6);
+  compute_encoder.set_bytes(k_seq_stride, 7);
+  compute_encoder.set_bytes(v_head_stride, 8);
+  compute_encoder.set_bytes(v_seq_stride, 9);
+
+  compute_encoder.set_bytes(scale, 10);
   if (has_mask) {
     auto& m = *mask;
-    compute_encoder.set_input_array(m, 9);
+    compute_encoder.set_input_array(m, 11);
     auto nd = m.ndim();
     int32_t kv_seq_stride =
         nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0;
     int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0;
     int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
-    compute_encoder.set_bytes(kv_seq_stride, 10);
-    compute_encoder.set_bytes(q_seq_stride, 11);
-    compute_encoder.set_bytes(head_stride, 12);
+    compute_encoder.set_bytes(kv_seq_stride, 12);
+    compute_encoder.set_bytes(q_seq_stride, 13);
+    compute_encoder.set_bytes(head_stride, 14);
   }

   // Launch
@@ -202,8 +208,10 @@ void sdpa_vector_2pass(
   int N = k.shape(2);
   int blocks = 32;
   int B = q.shape(0) * q.shape(1);
-  auto k_stride = k.strides()[1];
-  auto v_stride = v.strides()[1];
+  size_t k_head_stride = k.strides()[1];
+  size_t k_seq_stride = k.strides()[2];
+  size_t v_head_stride = v.strides()[1];
+  size_t v_seq_stride = v.strides()[2];
   MTL::Size group_dims(8 * 32, 1, 1);
   MTL::Size grid_dims(B, q.shape(2), blocks);

@@ -250,20 +258,22 @@ void sdpa_vector_2pass(
   compute_encoder.set_output_array(maxs, 5);
   compute_encoder.set_bytes(gqa_factor, 6);
   compute_encoder.set_bytes(N, 7);
-  compute_encoder.set_bytes(k_stride, 8);
-  compute_encoder.set_bytes(v_stride, 9);
-  compute_encoder.set_bytes(scale, 10);
+  compute_encoder.set_bytes(k_head_stride, 8);
+  compute_encoder.set_bytes(k_seq_stride, 9);
+  compute_encoder.set_bytes(v_head_stride, 10);
+  compute_encoder.set_bytes(v_seq_stride, 11);
+  compute_encoder.set_bytes(scale, 12);
   if (has_mask) {
     auto& m = *mask;
-    compute_encoder.set_input_array(m, 11);
+    compute_encoder.set_input_array(m, 13);
     auto nd = m.ndim();
     int32_t kv_seq_stride =
         nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0;
     int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0;
     int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
-    compute_encoder.set_bytes(kv_seq_stride, 12);
-    compute_encoder.set_bytes(q_seq_stride, 13);
-    compute_encoder.set_bytes(head_stride, 14);
+    compute_encoder.set_bytes(kv_seq_stride, 14);
+    compute_encoder.set_bytes(q_seq_stride, 15);
+    compute_encoder.set_bytes(head_stride, 16);
   }

   // Launch
@@ -334,15 +344,6 @@ void ScaledDotProductAttention::eval_gpu(
         (strides[1] == shape[3]) && (strides[0] == strides[2] * shape[2]);
   };

-  // Returns true if the array is row contiguous except the sequence length
-  // dimension that can be sliced but with step=1.
-  auto is_contiguous_except_seq_len = [](const array& arr) {
-    auto& strides = arr.strides();
-    auto& shape = arr.shape();
-    return strides[3] == 1 && strides[2] == shape[3] &&
-        strides[0] == strides[1] * shape[1];
-  };
-
   // Checks that the headdim dimension has stride 1.
   auto is_matrix_contiguous = [](const array& arr) {
     return arr.strides(3) == 1;
@@ -351,8 +352,8 @@ void ScaledDotProductAttention::eval_gpu(
   // We are in vector mode ie single query
   if (q_pre.shape(2) <= 8) {
     const auto& q = copy_unless(is_contiguous_or_head_seq_transposed, q_pre);
-    const auto& k = copy_unless(is_contiguous_except_seq_len, k_pre);
-    const auto& v = copy_unless(is_contiguous_except_seq_len, v_pre);
+    const auto& k = copy_unless(is_matrix_contiguous, k_pre);
+    const auto& v = copy_unless(is_matrix_contiguous, v_pre);

     // Donate the query if possible
     if (q.is_donatable() && (q.shape(2) == 1 || !q.flags().row_contiguous) &&
diff --git a/python/tests/mpi_test_distributed.py b/python/tests/mpi_test_distributed.py
index 097eb7dc3..0d172cee4 100644
--- a/python/tests/mpi_test_distributed.py
+++ b/python/tests/mpi_test_distributed.py
@@ -183,9 +183,11 @@ class TestDistributed(mlx_tests.MLXTestCase):
         scale = mx.array(2.0)
         y = mx.distributed.all_sum(x)
         mx.eval(y)
+        mx.synchronize(mx.default_stream(mx.default_device()))
         all_sum_only = mx.metal.get_peak_memory()
         y = mx.distributed.all_sum(x) * scale
         mx.eval(y)
+        mx.synchronize(mx.default_stream(mx.default_device()))
         all_sum_with_binary = mx.metal.get_peak_memory()

         self.assertEqual(all_sum_only, all_sum_with_binary)
diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py
index 9baee4fb1..5426ea236 100644
--- a/python/tests/test_fast_sdpa.py
+++ b/python/tests/test_fast_sdpa.py
@@ -171,7 +171,6 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
         rtol = 1e-2
         self.assertTrue(mx.allclose(o_mlx, reference, rtol=rtol, atol=atol))

-
         q = mx.random.normal(shape=(1, 32, 1, Dk))
         k = mx.random.normal(shape=(1, 32, 32, Dk))
         v = mx.random.normal(shape=(1, 32, 128, Dk))
@@ -201,6 +200,38 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
         )
         self.assertTrue(mx.allclose(y, y_hat, atol=atol))

+    def test_fast_sdpa_vector_kv_transposed_head_seq(self):
+        D = 64
+        Nq = 4
+        Nkv = 1
+        scale = 1.0
+        mx.random.seed(0)
+        q = 5e-1 * mx.random.normal(shape=(1, Nq, 1, D))
+
+        lengths = [43, 4096]
+        for L in lengths:
+            k = 5e-1 * mx.random.normal(shape=(1, L, Nkv, D))
+            v = 5e-1 * mx.random.normal(shape=(1, L, Nkv, D))
+            k = k.swapaxes(1, 2)
+            v = v.swapaxes(1, 2)
+            masks = [
+                mx.array(True),
+                mx.array([True] * (L - 10) + [False] * 10),
+                mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
+                mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
+            ]
+
+            for m in masks:
+                ref = mlx_primitives_sdpa(q, k, v, scale, mask=m)
+                out = mx.fast.scaled_dot_product_attention(
+                    q,
+                    k,
+                    v,
+                    scale=scale,
+                    mask=m,
+                )
+                self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
+
     def test_fast_sdpa_vector(self):
         D = 64
         L = 43
@@ -292,7 +323,6 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
         )
         self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))

-        return
         L = 4096
         scale = 1.0
         mx.random.seed(0)
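Note: the user-visible effect of this patch is that `mx.fast.scaled_dot_product_attention` no longer copies keys/values whose head and sequence axes are transposed in memory. Previously the vector kernels hard-coded the in-sequence stride to the head dimension (`inner_k_stride = BN * D`), so only layouts passing `is_contiguous_except_seq_len` qualified; now the host encodes separate head/sequence strides and only head-dim contiguity (`is_matrix_contiguous`) is required. A minimal usage sketch mirroring the new test (sizes and variable names are illustrative, not part of the patch):

```python
import mlx.core as mx

# Illustrative sizes: 1 batch, 4 query heads, 1 KV head (GQA), head dim 64.
B, Nq, Nkv, L, D = 1, 4, 1, 4096, 64

q = mx.random.normal(shape=(B, Nq, 1, D))

# KV cache materialized sequence-major as (B, L, Nkv, D), then viewed
# head-major without a copy. For the transposed view, strides()[1] (head)
# is D and strides()[2] (sequence) is Nkv * D, which the kernels now
# consume directly as k_head_stride / k_seq_stride.
k = mx.random.normal(shape=(B, L, Nkv, D)).swapaxes(1, 2)
v = mx.random.normal(shape=(B, L, Nkv, D)).swapaxes(1, 2)

out = mx.fast.scaled_dot_product_attention(q, k, v, scale=D**-0.5)
mx.eval(out)
```

For `Nkv > 1` the sequence stride `Nkv * D` differs from `D`, which the old single `k_stride` interface could not express; even with `Nkv = 1` (both strides 64 here), the old contiguity check rejected the transposed view because `strides[0] != strides[1] * shape[1]`, forcing a copy before dispatch.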