diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 1a394e580..2e0be273b 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -38,6 +38,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu ${CMAKE_CURRENT_SOURCE_DIR}/rope.cu + ${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cu ${CMAKE_CURRENT_SOURCE_DIR}/scan.cu ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu index 0451c9e54..1b4a99c08 100644 --- a/mlx/backend/cuda/primitives.cu +++ b/mlx/backend/cuda/primitives.cu @@ -42,17 +42,6 @@ void Arange::eval_gpu(const std::vector& inputs, array& out) { }); } -bool fast::ScaledDotProductAttention::use_fallback( - const array& q, - const array& k, - const array& v, - bool has_mask, - bool has_arr_mask, - bool do_causal, - Stream s) { - return true; -} - #define NO_GPU_MULTI(func) \ void func::eval_gpu( \ const std::vector& inputs, std::vector& outputs) { \ @@ -89,7 +78,6 @@ NO_GPU_MULTI(Eig) NO_GPU_MULTI(Eigh) namespace fast { -NO_GPU(ScaledDotProductAttention) NO_GPU_MULTI(CustomKernel) } // namespace fast