mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Add stricter condition to matrix sdpa
This commit is contained in:
		| @@ -499,7 +499,8 @@ void sdpa_vector_1pass_fallback( | |||||||
|       dispatch_headdim(params.D, [&](auto headdim) { |       dispatch_headdim(params.D, [&](auto headdim) { | ||||||
|         using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>; |         using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>; | ||||||
|  |  | ||||||
|         auto kernel = cu::kernel_sdpav_1pass<DataType, do_causal(), headdim()>; |         auto kernel = | ||||||
|  |             cu::kernel_sdpav_1pass<DataType, do_causal.value, headdim.value>; | ||||||
|         encoder.add_kernel_node( |         encoder.add_kernel_node( | ||||||
|             kernel, |             kernel, | ||||||
|             grid_dim, |             grid_dim, | ||||||
| @@ -569,8 +570,8 @@ void sdpa_vector_2pass_fallback( | |||||||
|         using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>; |         using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>; | ||||||
|  |  | ||||||
|         { |         { | ||||||
|           auto kernel = |           auto kernel = cu:: | ||||||
|               cu::kernel_sdpav_2pass_1<DataType, do_causal(), headdim()>; |               kernel_sdpav_2pass_1<DataType, do_causal.value, headdim.value>; | ||||||
|  |  | ||||||
|           encoder.set_input_array(q); |           encoder.set_input_array(q); | ||||||
|           encoder.set_input_array(k); |           encoder.set_input_array(k); | ||||||
| @@ -597,8 +598,8 @@ void sdpa_vector_2pass_fallback( | |||||||
|         } |         } | ||||||
|  |  | ||||||
|         { |         { | ||||||
|           auto kernel = |           auto kernel = cu:: | ||||||
|               cu::kernel_sdpav_2pass_2<DataType, do_causal(), headdim()>; |               kernel_sdpav_2pass_2<DataType, do_causal.value, headdim.value>; | ||||||
|  |  | ||||||
|           encoder.set_input_array(intermediate); |           encoder.set_input_array(intermediate); | ||||||
|           encoder.set_input_array(sums); |           encoder.set_input_array(sums); | ||||||
| @@ -963,6 +964,7 @@ bool ScaledDotProductAttention::use_fallback( | |||||||
|  |  | ||||||
|   const bool supported_matrix_config = query_sequence_length > 4 && |   const bool supported_matrix_config = query_sequence_length > 4 && | ||||||
|       cu_device.compute_capability_major() >= 8 && |       cu_device.compute_capability_major() >= 8 && | ||||||
|  |       query_sequence_length == key_sequence_length && | ||||||
|       (q.dtype() == float16 || q.dtype() == bfloat16); |       (q.dtype() == float16 || q.dtype() == bfloat16); | ||||||
|  |  | ||||||
|   const bool supported_config = |   const bool supported_config = | ||||||
| @@ -1105,4 +1107,4 @@ void ScaledDotProductAttention::eval_gpu( | |||||||
|  |  | ||||||
| } // namespace fast | } // namespace fast | ||||||
|  |  | ||||||
| } // namespace mlx::core | } // namespace mlx::core | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Angelos Katharopoulos
					Angelos Katharopoulos