mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-30 23:38:09 +08:00
Update routing
This commit is contained in:
@@ -956,7 +956,19 @@ bool ScaledDotProductAttention::use_fallback(
|
||||
const bool sdpa_supported_head_dim = query_head_dim == value_head_dim &&
|
||||
(query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128);
|
||||
|
||||
return has_arr_mask || !sdpa_supported_head_dim;
|
||||
const bool supported_vector_config =
|
||||
sdpa_supported_head_dim && query_sequence_length < 4;
|
||||
|
||||
auto& cu_device = cu::device(s.device);
|
||||
|
||||
const bool supported_matrix_config = query_sequence_length > 4 &&
|
||||
cu_device.compute_capability_major() >= 8 &&
|
||||
(q.dtype() == float16 || q.dtype() == bfloat16);
|
||||
|
||||
const bool supported_config =
|
||||
(supported_matrix_config || supported_vector_config);
|
||||
|
||||
return has_arr_mask || !supported_config;
|
||||
}
|
||||
|
||||
void ScaledDotProductAttention::eval_gpu(
|
||||
@@ -993,7 +1005,7 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
};
|
||||
|
||||
// We are in vector mode ie single query
|
||||
if (q_pre.shape(2) <= 1) {
|
||||
if (q_pre.shape(2) < 4) {
|
||||
auto q_copy_unless = [](const array& arr) {
|
||||
if (arr.flags().row_contiguous) {
|
||||
return true;
|
||||
@@ -1056,8 +1068,8 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
flags);
|
||||
}
|
||||
|
||||
return sdpa_vector_fallback(s, encoder, q, k, v, scale_, o, do_causal_);
|
||||
// return sdpa_cudnn(s, encoder, q, k, v, scale_, o, do_causal_);
|
||||
// return sdpa_vector_fallback(s, encoder, q, k, v, scale_, o, do_causal_);
|
||||
return sdpa_cudnn(s, encoder, q, k, v, scale_, o, do_causal_);
|
||||
}
|
||||
|
||||
// Full attention mode
|
||||
|
||||
Reference in New Issue
Block a user