mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	Add base cudnn attention support
This commit is contained in:
		| @@ -39,6 +39,7 @@ target_sources( | ||||
|           ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu | ||||
|           ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu | ||||
|           ${CMAKE_CURRENT_SOURCE_DIR}/rope.cu | ||||
|           ${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cu | ||||
|           ${CMAKE_CURRENT_SOURCE_DIR}/scan.cu | ||||
|           ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp | ||||
|           ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu | ||||
|   | ||||
| @@ -4,19 +4,6 @@ | ||||
| #include "mlx/fast_primitives.h" | ||||
| #include "mlx/primitives.h" | ||||
|  | ||||
| namespace mlx::core { | ||||
|  | ||||
| bool fast::ScaledDotProductAttention::use_fallback( | ||||
|     const array& q, | ||||
|     const array& k, | ||||
|     const array& v, | ||||
|     bool has_mask, | ||||
|     bool has_arr_mask, | ||||
|     bool do_causal, | ||||
|     Stream s) { | ||||
|   return true; | ||||
| } | ||||
|  | ||||
| #define NO_GPU_MULTI(func)                                             \ | ||||
|   void func::eval_gpu(                                                 \ | ||||
|       const std::vector<array>& inputs, std::vector<array>& outputs) { \ | ||||
| @@ -53,7 +40,6 @@ NO_GPU_MULTI(Eig) | ||||
| NO_GPU_MULTI(Eigh) | ||||
|  | ||||
| namespace fast { | ||||
| NO_GPU(ScaledDotProductAttention) | ||||
| NO_GPU_MULTI(CustomKernel) | ||||
| } // namespace fast | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Jagrit Digani
					Jagrit Digani