diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal index 99c27de252..0d05a69328 100644 --- a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal @@ -1,7 +1,6 @@ // Copyright © 2024 Apple Inc. // clang-format off -// #include "mlx/backend/metal/kernels/bf16.h" #include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/steel/attn/attn.h" @@ -26,7 +25,7 @@ instantiate_attn(iname, itype, 32, 32, 64, 4, 1) instantiate_attn_shapes_helper(float16, half); -// instantiate_attn_shapes_helper(bfloat16, bfloat16_t); +instantiate_attn_shapes_helper(bfloat16, bfloat16_t); instantiate_attn_shapes_helper(float32, float); // clang-format on diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 47f17745cc..56c16b3661 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -648,7 +648,7 @@ array scaled_dot_product_attention( const bool supports_sdpa_full = query_sequence_length >= threshold && !mask.has_value() && sdpa_full_supported_head_dim && - final_type != bfloat16 && stream.device == Device::gpu; + stream.device == Device::gpu; const bool supports_sdpa_vector = query_sequence_length == 1 && !mask.has_value() && sdpa_vector_supported_head_dim &&