mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
SDPA support for small batch (over sequence) queries (#1922)
* batch query sdpa * batch sdpa for query
This commit is contained in:
@@ -715,7 +715,8 @@ array scaled_dot_product_attention(
|
||||
const bool supports_sdpa_full = query_sequence_length >= threshold && !mask &&
|
||||
sdpa_full_supported_head_dim && stream.device == Device::gpu;
|
||||
|
||||
const bool supports_sdpa_vector = query_sequence_length == 1 &&
|
||||
const bool supports_sdpa_vector = (query_sequence_length <= 8) &&
|
||||
(query_sequence_length <= k.shape(-2)) &&
|
||||
(!mask || mask->dtype() == bool_) && sdpa_vector_supported_head_dim &&
|
||||
stream.device == Device::gpu;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user