diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 8cbf0d015a..47f17745cc 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -644,7 +644,7 @@ array scaled_dot_product_attention( const bool sdpa_vector_supported_head_dim = query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128; const bool sdpa_full_supported_head_dim = - query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128; + query_head_dim == 64 || query_head_dim == 80; const bool supports_sdpa_full = query_sequence_length >= threshold && !mask.has_value() && sdpa_full_supported_head_dim &&