Update routing

Jagrit Digani
2025-08-06 15:01:15 -07:00
parent f81edd184f
commit c66b76a8c8


@@ -956,7 +956,19 @@ bool ScaledDotProductAttention::use_fallback(
   const bool sdpa_supported_head_dim = query_head_dim == value_head_dim &&
       (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128);
-  return has_arr_mask || !sdpa_supported_head_dim;
+
+  const bool supported_vector_config =
+      sdpa_supported_head_dim && query_sequence_length < 4;
+
+  auto& cu_device = cu::device(s.device);
+  const bool supported_matrix_config = query_sequence_length > 4 &&
+      cu_device.compute_capability_major() >= 8 &&
+      (q.dtype() == float16 || q.dtype() == bfloat16);
+
+  const bool supported_config =
+      (supported_matrix_config || supported_vector_config);
+
+  return has_arr_mask || !supported_config;
 }
void ScaledDotProductAttention::eval_gpu(
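
Read on its own, the new use_fallback logic splits the fused paths in two: a vector path for very short queries and a cuDNN matrix path for longer ones. The following is a minimal free-standing restatement of that predicate for illustration; the parameter names are hypothetical stand-ins for the values the real method derives from the inputs and stream, not an MLX API.

// Hypothetical restatement of the routing predicate above; a sketch,
// not the MLX source.
bool needs_fallback(
    int query_head_dim,
    int value_head_dim,
    int query_sequence_length,
    int compute_capability_major, // 8 or higher means Ampere (SM 8.x) or newer
    bool is_half_or_bfloat,       // inputs are float16 or bfloat16
    bool has_arr_mask) {          // an explicit array mask was supplied
  // The fused vector kernels only support matching query/value head
  // dims of 64, 96, or 128.
  const bool supported_head_dim = query_head_dim == value_head_dim &&
      (query_head_dim == 64 || query_head_dim == 96 ||
       query_head_dim == 128);
  // Vector path: decode-style queries of fewer than 4 tokens.
  const bool vector_ok = supported_head_dim && query_sequence_length < 4;
  // Matrix path: longer queries routed to cuDNN, which needs an
  // Ampere-or-newer GPU and half-precision inputs.
  const bool matrix_ok = query_sequence_length > 4 &&
      compute_capability_major >= 8 && is_half_or_bfloat;
  // Everything else, and any explicit array mask, takes the fallback.
  return has_arr_mask || !(vector_ok || matrix_ok);
}

One boundary worth noting: a query sequence length of exactly 4 satisfies neither query_sequence_length < 4 nor query_sequence_length > 4, so it always routes to the fallback and never reaches the shape(2) < 4 check in eval_gpu below.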
@@ -993,7 +1005,7 @@ void ScaledDotProductAttention::eval_gpu(
   };
 
   // We are in vector mode ie single query
-  if (q_pre.shape(2) <= 1) {
+  if (q_pre.shape(2) < 4) {
     auto q_copy_unless = [](const array& arr) {
       if (arr.flags().row_contiguous) {
         return true;
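
Axis 2 of q_pre holds the query sequence length (the comment above describes shape(2) <= 1 as a single query), so widening the check to shape(2) < 4 lets queries of up to three tokens, e.g. short speculative-decode steps, stay on the fused vector kernels, matching supported_vector_config in use_fallback.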
@@ -1056,8 +1068,8 @@ void ScaledDotProductAttention::eval_gpu(
         flags);
   }
 
-  return sdpa_vector_fallback(s, encoder, q, k, v, scale_, o, do_causal_);
-  // return sdpa_cudnn(s, encoder, q, k, v, scale_, o, do_causal_);
+  // return sdpa_vector_fallback(s, encoder, q, k, v, scale_, o, do_causal_);
+  return sdpa_cudnn(s, encoder, q, k, v, scale_, o, do_causal_);
 }
 
 // Full attention mode
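
Both calls were already present in this branch, with the cuDNN one commented out; the commit flips which of the two is live, so the non-vector path now runs the fused cuDNN kernel while sdpa_vector_fallback stays behind the comment. A compilable toy of the resulting dispatch, with the kernel entry points stubbed out (stub names are hypothetical; argument plumbing elided):

#include <cstdio>

// Stand-ins for the real kernel entry points in the diff.
void sdpa_vector_stub() { std::printf("fused vector kernel\n"); }
void sdpa_cudnn_stub() { std::printf("cuDNN fused kernel\n"); }

// Toy model of eval_gpu's routing after this commit.
void dispatch(int query_sequence_length) {
  if (query_sequence_length < 4) {
    // Decode-style queries (1-3 tokens): fused vector path.
    sdpa_vector_stub();
  } else {
    // Longer queries: cuDNN path; use_fallback has already guaranteed
    // SM 8.0+ and float16/bfloat16 inputs here.
    sdpa_cudnn_stub();
  }
}

int main() {
  dispatch(1);   // prints: fused vector kernel
  dispatch(512); // prints: cuDNN fused kernel
  return 0;
}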