Update attention tests to show diff, disable array masks (#1978)

This commit is contained in:
Jagrit Digani
2025-03-20 14:25:38 -07:00
committed by GitHub
parent 9adcd1a650
commit b42d13ec84
2 changed files with 6 additions and 7 deletions

View File

@@ -748,8 +748,8 @@ array scaled_dot_product_attention(
(query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
const bool sdpa_vector_supported_mask = (!has_mask || has_bool_mask);
-const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
-(query_sequence_length <= key_sequence_length && do_causal);
+const bool sdpa_full_supported_mask =
+!has_mask || (query_sequence_length <= key_sequence_length && do_causal);
const bool supports_sdpa_full = query_sequence_length >= threshold &&
sdpa_full_supported_mask && sdpa_full_supported_head_dim &&