mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 01:50:16 +08:00
Update attention tests to show diff, disable array masks (#1978)
This commit is contained in:
@@ -748,8 +748,8 @@ array scaled_dot_product_attention(
|
||||
(query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
|
||||
|
||||
const bool sdpa_vector_supported_mask = (!has_mask || has_bool_mask);
|
||||
const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
|
||||
(query_sequence_length <= key_sequence_length && do_causal);
|
||||
const bool sdpa_full_supported_mask =
|
||||
!has_mask || (query_sequence_length <= key_sequence_length && do_causal);
|
||||
|
||||
const bool supports_sdpa_full = query_sequence_length >= threshold &&
|
||||
sdpa_full_supported_mask && sdpa_full_supported_head_dim &&
|
||||
|
Reference in New Issue
Block a user