mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-15 04:51:13 +08:00
SDPA fixes
This commit is contained in:
parent
5f04c0f818
commit
8ac0a20250
@ -734,8 +734,7 @@ array scaled_dot_product_attention(
|
||||
|
||||
const bool sdpa_vector_supported_head_dim =
|
||||
query_head_dim == value_head_dim &&
|
||||
(query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 ||
|
||||
query_head_dim == 256);
|
||||
(query_head_dim % 32 == 0 || query_head_dim % 2 == 0);
|
||||
const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
|
||||
(query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
|
||||
|
||||
|
@ -5108,4 +5108,4 @@ bool Hadamard::is_equivalent(const Primitive& other) const {
|
||||
return scale_ == h_other.scale_;
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
} // namespace mlx::core
|
Loading…
Reference in New Issue
Block a user