mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add float mask to sdpa vector (#2068)
This commit is contained in:
committed by
GitHub
parent
68d1b3256b
commit
c4189a38e4
@@ -739,8 +739,6 @@ array scaled_dot_product_attention(
|
||||
const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
|
||||
(query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
|
||||
|
||||
const bool sdpa_vector_supported_mask =
|
||||
!has_mask || has_bool_mask || do_causal;
|
||||
const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
|
||||
(query_sequence_length <= key_sequence_length && do_causal);
|
||||
|
||||
@@ -749,8 +747,7 @@ array scaled_dot_product_attention(
|
||||
|
||||
const bool supports_sdpa_vector = (query_sequence_length <= 8) &&
|
||||
(query_sequence_length <= key_sequence_length) &&
|
||||
sdpa_vector_supported_mask && sdpa_vector_supported_head_dim &&
|
||||
stream.device == Device::gpu;
|
||||
sdpa_vector_supported_head_dim && stream.device == Device::gpu;
|
||||
|
||||
const bool implementation_supports_use_case =
|
||||
supports_sdpa_full || supports_sdpa_vector;
|
||||
|
||||
Reference in New Issue
Block a user