Enable bf16

This commit is contained in:
Jagrit Digani 2024-11-21 11:24:17 -08:00
parent 0404037ea6
commit 4640f865cc
2 changed files with 2 additions and 3 deletions

View File

@ -1,7 +1,6 @@
// Copyright © 2024 Apple Inc.
// clang-format off
// #include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/attn/attn.h"
@ -26,7 +25,7 @@
instantiate_attn(iname, itype, 32, 32, 64, 4, 1)
instantiate_attn_shapes_helper(float16, half);
// instantiate_attn_shapes_helper(bfloat16, bfloat16_t);
instantiate_attn_shapes_helper(bfloat16, bfloat16_t);
instantiate_attn_shapes_helper(float32, float);
// clang-format on

View File

@ -648,7 +648,7 @@ array scaled_dot_product_attention(
const bool supports_sdpa_full = query_sequence_length >= threshold &&
!mask.has_value() && sdpa_full_supported_head_dim &&
final_type != bfloat16 && stream.device == Device::gpu;
stream.device == Device::gpu;
const bool supports_sdpa_vector = query_sequence_length == 1 &&
!mask.has_value() && sdpa_vector_supported_head_dim &&