From 62f297b51d10014fadff663880b41393d943bb24 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Sat, 2 Nov 2024 21:25:46 -0700 Subject: [PATCH] Sdpa fix (#1558) --- .../metal/kernels/scaled_dot_product_attention.metal | 1 + mlx/backend/metal/kernels/sdpa_vector.h | 3 ++- mlx/backend/metal/scaled_dot_product_attention.cpp | 8 +++++--- python/tests/test_fast_sdpa.py | 9 +++++++++ 4 files changed, 17 insertions(+), 4 deletions(-) diff --git a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal index 27fbb765b..478bf2207 100644 --- a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +++ b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal @@ -936,6 +936,7 @@ instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 128, 2, 2); const constant int& gqa_factor, \ const constant int& N, \ const constant size_t& k_stride, \ + const constant size_t& v_stride, \ const constant float& scale, \ uint3 tid [[threadgroup_position_in_grid]], \ uint simd_gid [[simdgroup_index_in_threadgroup]], \ diff --git a/mlx/backend/metal/kernels/sdpa_vector.h b/mlx/backend/metal/kernels/sdpa_vector.h index 4d4a9180b..5ef316811 100644 --- a/mlx/backend/metal/kernels/sdpa_vector.h +++ b/mlx/backend/metal/kernels/sdpa_vector.h @@ -13,6 +13,7 @@ template const constant int& gqa_factor, const constant int& N, const constant size_t& k_stride, + const constant size_t& v_stride, const constant float& scale, uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], @@ -38,7 +39,7 @@ template const int kv_head_idx = head_idx / gqa_factor; queries += head_idx * D + simd_lid * elem_per_thread; keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread; - values += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread; + values += kv_head_idx * v_stride + simd_gid * D + simd_lid * elem_per_thread; out += head_idx * D + simd_gid * elem_per_thread; // Read the query and 0 the output accumulator diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index 7a3fc03ba..54ec91a4c 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -162,7 +162,8 @@ void sdpa_vector( int gqa_factor = q.shape(1) / k.shape(1); int N = k.shape(2); int B = q.shape(0) * q.shape(1); - size_t stride = k.strides()[1]; + size_t k_stride = k.strides()[1]; + size_t v_stride = v.strides()[1]; MTL::Size group_dims(1024, 1, 1); MTL::Size grid_dims(1, B, 1); @@ -178,8 +179,9 @@ void sdpa_vector( compute_encoder.set_output_array(out, 3); compute_encoder->setBytes(&gqa_factor, sizeof(int), 4); compute_encoder->setBytes(&N, sizeof(int), 5); - compute_encoder->setBytes(&stride, sizeof(size_t), 6); - compute_encoder->setBytes(&scale, sizeof(float), 7); + compute_encoder->setBytes(&k_stride, sizeof(size_t), 6); + compute_encoder->setBytes(&v_stride, sizeof(size_t), 7); + compute_encoder->setBytes(&scale, sizeof(float), 8); // Launch compute_encoder.dispatchThreadgroups(grid_dims, group_dims); diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py index 13b316bd1..c736abe93 100644 --- a/python/tests/test_fast_sdpa.py +++ b/python/tests/test_fast_sdpa.py @@ -167,6 +167,15 @@ class TestFastSDPA(mlx_tests.MLXTestCase): self.assertTrue(mx.allclose(o_mlx, reference, rtol=rtol, atol=atol)) + q = mx.random.normal(shape=(1, 32, 1, Dk)) + k = mx.random.normal(shape=(1, 32, 32, Dk)) + v = mx.random.normal(shape=(1, 32, 128, Dk)) + + atol = 1e-6 + y = mlx_primitives_sdpa(q, k, v[:, :, :32], scale) + y_hat = mx.fast.scaled_dot_product_attention(q, k, v[:, :, :32], scale=scale) + self.assertTrue(mx.allclose(y, y_hat, atol=atol)) + if __name__ == "__main__": unittest.main(failfast=True)