mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	Update routing
This commit is contained in:
		| @@ -956,7 +956,19 @@ bool ScaledDotProductAttention::use_fallback( | ||||
|   const bool sdpa_supported_head_dim = query_head_dim == value_head_dim && | ||||
|       (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128); | ||||
|  | ||||
|   return has_arr_mask || !sdpa_supported_head_dim; | ||||
|   const bool supported_vector_config = | ||||
|       sdpa_supported_head_dim && query_sequence_length < 4; | ||||
|  | ||||
|   auto& cu_device = cu::device(s.device); | ||||
|  | ||||
|   const bool supported_matrix_config = query_sequence_length > 4 && | ||||
|       cu_device.compute_capability_major() >= 8 && | ||||
|       (q.dtype() == float16 || q.dtype() == bfloat16); | ||||
|  | ||||
|   const bool supported_config = | ||||
|       (supported_matrix_config || supported_vector_config); | ||||
|  | ||||
|   return has_arr_mask || !supported_config; | ||||
| } | ||||
|  | ||||
| void ScaledDotProductAttention::eval_gpu( | ||||
| @@ -993,7 +1005,7 @@ void ScaledDotProductAttention::eval_gpu( | ||||
|   }; | ||||
|  | ||||
|   // We are in vector mode ie single query | ||||
|   if (q_pre.shape(2) <= 1) { | ||||
|   if (q_pre.shape(2) < 4) { | ||||
|     auto q_copy_unless = [](const array& arr) { | ||||
|       if (arr.flags().row_contiguous) { | ||||
|         return true; | ||||
| @@ -1056,8 +1068,8 @@ void ScaledDotProductAttention::eval_gpu( | ||||
|           flags); | ||||
|     } | ||||
|  | ||||
|     return sdpa_vector_fallback(s, encoder, q, k, v, scale_, o, do_causal_); | ||||
|     // return sdpa_cudnn(s, encoder, q, k, v, scale_, o, do_causal_); | ||||
|     // return sdpa_vector_fallback(s, encoder, q, k, v, scale_, o, do_causal_); | ||||
|     return sdpa_cudnn(s, encoder, q, k, v, scale_, o, do_causal_); | ||||
|   } | ||||
|  | ||||
|   // Full attention mode | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Jagrit Digani
					Jagrit Digani