mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Allow different value dimensions in sdpa_vector (#1811)
This commit is contained in:
committed by
GitHub
parent
b7c9f1d38f
commit
f5cc1eea72
19
mlx/fast.cpp
19
mlx/fast.cpp
@@ -684,23 +684,20 @@ array scaled_dot_product_attention(
   const size_t query_head_dim = q.shape(-1);
   const size_t query_sequence_length = q.shape(2);
 
-  bool implementation_supports_use_case = query_head_dim == value_head_dim;
-
   const bool sdpa_vector_supported_head_dim =
-      query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128;
-  const bool sdpa_full_supported_head_dim =
-      query_head_dim == 64 || query_head_dim == 80;
+      query_head_dim == value_head_dim &&
+      (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128);
+  const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
+      (query_head_dim == 64 || query_head_dim == 80);
 
-  const bool supports_sdpa_full = query_sequence_length >= threshold &&
-      !mask.has_value() && sdpa_full_supported_head_dim &&
-      stream.device == Device::gpu;
+  const bool supports_sdpa_full = query_sequence_length >= threshold && !mask &&
+      sdpa_full_supported_head_dim && stream.device == Device::gpu;
 
-  const bool supported_mask = !mask || (mask->dtype() == bool_);
   const bool supports_sdpa_vector = query_sequence_length == 1 &&
-      supported_mask && sdpa_vector_supported_head_dim &&
+      (!mask || mask->dtype() == bool_) && sdpa_vector_supported_head_dim &&
       stream.device == Device::gpu;
 
-  implementation_supports_use_case &=
+  const bool implementation_supports_use_case =
       supports_sdpa_full || supports_sdpa_vector;
 
   std::vector<array> inputs = {q, k, v};
Reference in New Issue
Block a user