SDPA fixes

This commit is contained in:
paramthakkar123 2025-04-19 08:15:56 +05:30
parent 5f04c0f818
commit 8ac0a20250
2 changed files with 2 additions and 3 deletions

View File

@ -734,8 +734,7 @@ array scaled_dot_product_attention(
const bool sdpa_vector_supported_head_dim =
query_head_dim == value_head_dim &&
(query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 ||
query_head_dim == 256);
(query_head_dim % 32 == 0 || query_head_dim % 2 == 0);
const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
(query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);

View File

@ -5108,4 +5108,4 @@ bool Hadamard::is_equivalent(const Primitive& other) const {
return scale_ == h_other.scale_;
}
} // namespace mlx::core
} // namespace mlx::core