mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-03 01:06:43 +08:00
Enable bf16
This commit is contained in:
parent
0404037ea6
commit
4640f865cc
@ -1,7 +1,6 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
// clang-format off
|
||||
// #include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
#include "mlx/backend/metal/kernels/steel/attn/attn.h"
|
||||
@ -26,7 +25,7 @@
|
||||
instantiate_attn(iname, itype, 32, 32, 64, 4, 1)
|
||||
|
||||
instantiate_attn_shapes_helper(float16, half);
|
||||
// instantiate_attn_shapes_helper(bfloat16, bfloat16_t);
|
||||
instantiate_attn_shapes_helper(bfloat16, bfloat16_t);
|
||||
|
||||
instantiate_attn_shapes_helper(float32, float);
|
||||
// clang-format on
|
||||
|
@ -648,7 +648,7 @@ array scaled_dot_product_attention(
|
||||
|
||||
const bool supports_sdpa_full = query_sequence_length >= threshold &&
|
||||
!mask.has_value() && sdpa_full_supported_head_dim &&
|
||||
final_type != bfloat16 && stream.device == Device::gpu;
|
||||
stream.device == Device::gpu;
|
||||
|
||||
const bool supports_sdpa_vector = query_sequence_length == 1 &&
|
||||
!mask.has_value() && sdpa_vector_supported_head_dim &&
|
||||
|
Loading…
Reference in New Issue
Block a user