From 8ac0a2025023bfeb0e623860ab9d041d51ecac63 Mon Sep 17 00:00:00 2001
From: paramthakkar123
Date: Sat, 19 Apr 2025 08:15:56 +0530
Subject: [PATCH] SDPA fixes

---
 mlx/fast.cpp       | 3 +--
 mlx/primitives.cpp | 2 +-
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/mlx/fast.cpp b/mlx/fast.cpp
index 77210f713..e95263683 100644
--- a/mlx/fast.cpp
+++ b/mlx/fast.cpp
@@ -734,8 +734,7 @@ array scaled_dot_product_attention(
 
   const bool sdpa_vector_supported_head_dim =
       query_head_dim == value_head_dim &&
-      (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 ||
-       query_head_dim == 256);
+      (query_head_dim % 32 == 0 || query_head_dim % 2 == 0);
   const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
       (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
 
diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index 590af60f6..d6de1005d 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -5108,4 +5108,4 @@ bool Hadamard::is_equivalent(const Primitive& other) const {
   return scale_ == h_other.scale_;
 }
 
-} // namespace mlx::core
+} // namespace mlx::core
\ No newline at end of file