From 870208eff52dc9ecde0edb310f5ffb3ab5f67e5a Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Sun, 15 Jun 2025 21:58:34 -0700
Subject: [PATCH] Start sdpa vector

---
 mlx/backend/cuda/CMakeLists.txt               |  1 +
 mlx/backend/cuda/primitives.cu                | 12 -----
 .../cuda/scaled_dot_product_attention.cu      | 51 +++++++++++++++++++
 3 files changed, 52 insertions(+), 12 deletions(-)
 create mode 100644 mlx/backend/cuda/scaled_dot_product_attention.cu

diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt
index d96bb8812..0b40e08b3 100644
--- a/mlx/backend/cuda/CMakeLists.txt
+++ b/mlx/backend/cuda/CMakeLists.txt
@@ -33,6 +33,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce/segmented_reduce.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
+  ${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu
index c2362bea2..28d79b15b 100644
--- a/mlx/backend/cuda/primitives.cu
+++ b/mlx/backend/cuda/primitives.cu
@@ -43,17 +43,6 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
   });
 }
 
-bool fast::ScaledDotProductAttention::use_fallback(
-    const array& q,
-    const array& k,
-    const array& v,
-    bool has_mask,
-    bool has_arr_mask,
-    bool do_causal,
-    Stream s) {
-  return true;
-}
-
 #define NO_GPU_MULTI(func)                                             \
   void func::eval_gpu(                                                 \
       const std::vector<array>& inputs, std::vector<array>& outputs) { \
@@ -94,7 +83,6 @@ NO_GPU_MULTI(Eig)
 NO_GPU_MULTI(Eigh)
 
 namespace fast {
-NO_GPU(ScaledDotProductAttention)
 NO_GPU_MULTI(AffineQuantize)
 NO_GPU_MULTI(CustomKernel)
 } // namespace fast
diff --git a/mlx/backend/cuda/scaled_dot_product_attention.cu b/mlx/backend/cuda/scaled_dot_product_attention.cu
new file mode 100644
index 000000000..c0384e578
--- /dev/null
+++ b/mlx/backend/cuda/scaled_dot_product_attention.cu
@@ -0,0 +1,51 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/device.h"
+#include "mlx/fast_primitives.h"
+
+namespace mlx::core {
+
+namespace cu {} // namespace cu
+
+namespace fast {
+
+bool ScaledDotProductAttention::use_fallback(
+    const array& q,
+    const array& k,
+    const array& v,
+    bool has_mask,
+    bool has_arr_mask,
+    bool do_causal,
+    Stream s) {
+  if (detail::in_grad_tracing()) {
+    return true;
+  }
+  if (s.device == Device::cpu) {
+    return true;
+  }
+
+  const int value_head_dim = v.shape(-1);
+  const int query_head_dim = q.shape(-1);
+  const int query_sequence_length = q.shape(2);
+  const int key_sequence_length = k.shape(2);
+
+  const bool sdpa_vector_supported_head_dim =
+      query_head_dim == value_head_dim &&
+      (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 ||
+       query_head_dim == 256);
+  const bool supports_sdpa_vector = (query_sequence_length <= 1) &&
+      (query_sequence_length <= key_sequence_length) &&
+      sdpa_vector_supported_head_dim;
+
+  return !supports_sdpa_vector;
+}
+
+void ScaledDotProductAttention::eval_gpu(
+    const std::vector<array>& inputs,
+    array& out) {
+  nvtx3::scoped_range r("ScaledDotProductAttention::eval_gpu");
+}
+
+} // namespace fast
+
+} // namespace mlx::core