mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	Add base cudnn attention support
This commit is contained in:
		| @@ -39,6 +39,7 @@ target_sources( | |||||||
|           ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu |           ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu | ||||||
|           ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu |           ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu | ||||||
|           ${CMAKE_CURRENT_SOURCE_DIR}/rope.cu |           ${CMAKE_CURRENT_SOURCE_DIR}/rope.cu | ||||||
|  |           ${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cu | ||||||
|           ${CMAKE_CURRENT_SOURCE_DIR}/scan.cu |           ${CMAKE_CURRENT_SOURCE_DIR}/scan.cu | ||||||
|           ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp |           ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp | ||||||
|           ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu |           ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu | ||||||
|   | |||||||
| @@ -4,19 +4,6 @@ | |||||||
| #include "mlx/fast_primitives.h" | #include "mlx/fast_primitives.h" | ||||||
| #include "mlx/primitives.h" | #include "mlx/primitives.h" | ||||||
|  |  | ||||||
| namespace mlx::core { |  | ||||||
|  |  | ||||||
| bool fast::ScaledDotProductAttention::use_fallback( |  | ||||||
|     const array& q, |  | ||||||
|     const array& k, |  | ||||||
|     const array& v, |  | ||||||
|     bool has_mask, |  | ||||||
|     bool has_arr_mask, |  | ||||||
|     bool do_causal, |  | ||||||
|     Stream s) { |  | ||||||
|   return true; |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #define NO_GPU_MULTI(func)                                             \ | #define NO_GPU_MULTI(func)                                             \ | ||||||
|   void func::eval_gpu(                                                 \ |   void func::eval_gpu(                                                 \ | ||||||
|       const std::vector<array>& inputs, std::vector<array>& outputs) { \ |       const std::vector<array>& inputs, std::vector<array>& outputs) { \ | ||||||
| @@ -53,7 +40,6 @@ NO_GPU_MULTI(Eig) | |||||||
| NO_GPU_MULTI(Eigh) | NO_GPU_MULTI(Eigh) | ||||||
|  |  | ||||||
| namespace fast { | namespace fast { | ||||||
| NO_GPU(ScaledDotProductAttention) |  | ||||||
| NO_GPU_MULTI(CustomKernel) | NO_GPU_MULTI(CustomKernel) | ||||||
| } // namespace fast | } // namespace fast | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Jagrit Digani
					Jagrit Digani