Mirror of https://github.com/ml-explore/mlx.git (synced 2025-10-31 16:21:27 +08:00)
			
		
		
		
	Add stricter condition to matrix sdpa
This commit is contained in:
		| @@ -499,7 +499,8 @@ void sdpa_vector_1pass_fallback( | ||||
|       dispatch_headdim(params.D, [&](auto headdim) { | ||||
|         using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>; | ||||
|  | ||||
|-        auto kernel = cu::kernel_sdpav_1pass<DataType, do_causal(), headdim()>; | ||||
|+        auto kernel = | ||||
|+            cu::kernel_sdpav_1pass<DataType, do_causal.value, headdim.value>; | ||||
|         encoder.add_kernel_node( | ||||
|             kernel, | ||||
|             grid_dim, | ||||
| @@ -569,8 +570,8 @@ void sdpa_vector_2pass_fallback( | ||||
|         using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>; | ||||
|  | ||||
|         { | ||||
|-          auto kernel = | ||||
|-              cu::kernel_sdpav_2pass_1<DataType, do_causal(), headdim()>; | ||||
|+          auto kernel = cu:: | ||||
|+              kernel_sdpav_2pass_1<DataType, do_causal.value, headdim.value>; | ||||
|  | ||||
|           encoder.set_input_array(q); | ||||
|           encoder.set_input_array(k); | ||||
| @@ -597,8 +598,8 @@ void sdpa_vector_2pass_fallback( | ||||
|         } | ||||
|  | ||||
|         { | ||||
|-          auto kernel = | ||||
|-              cu::kernel_sdpav_2pass_2<DataType, do_causal(), headdim()>; | ||||
|+          auto kernel = cu:: | ||||
|+              kernel_sdpav_2pass_2<DataType, do_causal.value, headdim.value>; | ||||
|  | ||||
|           encoder.set_input_array(intermediate); | ||||
|           encoder.set_input_array(sums); | ||||
| @@ -963,6 +964,7 @@ bool ScaledDotProductAttention::use_fallback( | ||||
|  | ||||
|   const bool supported_matrix_config = query_sequence_length > 4 && | ||||
|       cu_device.compute_capability_major() >= 8 && | ||||
|+      query_sequence_length == key_sequence_length && | ||||
|       (q.dtype() == float16 || q.dtype() == bfloat16); | ||||
|  | ||||
|   const bool supported_config = | ||||
| @@ -1105,4 +1107,4 @@ void ScaledDotProductAttention::eval_gpu( | ||||
|  | ||||
| } // namespace fast | ||||
|  | ||||
|-} // namespace mlx::core | ||||
|+} // namespace mlx::core | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Angelos Katharopoulos
					Angelos Katharopoulos