mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
Start sdpa vector
This commit is contained in:
parent
bc53f8293f
commit
870208eff5
@ -33,6 +33,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/segmented_reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
|
||||
|
@ -43,17 +43,6 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
});
|
||||
}
|
||||
|
||||
bool fast::ScaledDotProductAttention::use_fallback(
|
||||
const array& q,
|
||||
const array& k,
|
||||
const array& v,
|
||||
bool has_mask,
|
||||
bool has_arr_mask,
|
||||
bool do_causal,
|
||||
Stream s) {
|
||||
return true;
|
||||
}
|
||||
|
||||
#define NO_GPU_MULTI(func) \
|
||||
void func::eval_gpu( \
|
||||
const std::vector<array>& inputs, std::vector<array>& outputs) { \
|
||||
@ -94,7 +83,6 @@ NO_GPU_MULTI(Eig)
|
||||
NO_GPU_MULTI(Eigh)
|
||||
|
||||
namespace fast {
|
||||
NO_GPU(ScaledDotProductAttention)
|
||||
NO_GPU_MULTI(AffineQuantize)
|
||||
NO_GPU_MULTI(CustomKernel)
|
||||
} // namespace fast
|
||||
|
51
mlx/backend/cuda/scaled_dot_product_attention.cu
Normal file
51
mlx/backend/cuda/scaled_dot_product_attention.cu
Normal file
@ -0,0 +1,51 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {} // namespace cu
|
||||
|
||||
namespace fast {
|
||||
|
||||
bool ScaledDotProductAttention::use_fallback(
|
||||
const array& q,
|
||||
const array& k,
|
||||
const array& v,
|
||||
bool has_mask,
|
||||
bool has_arr_mask,
|
||||
bool do_causal,
|
||||
Stream s) {
|
||||
if (detail::in_grad_tracing()) {
|
||||
return true;
|
||||
}
|
||||
if (s.device == Device::cpu) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const int value_head_dim = v.shape(-1);
|
||||
const int query_head_dim = q.shape(-1);
|
||||
const int query_sequence_length = q.shape(2);
|
||||
const int key_sequence_length = k.shape(2);
|
||||
|
||||
const bool sdpa_vector_supported_head_dim =
|
||||
query_head_dim == value_head_dim &&
|
||||
(query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 ||
|
||||
query_head_dim == 256);
|
||||
const bool supports_sdpa_vector = (query_sequence_length <= 1) &&
|
||||
(query_sequence_length <= key_sequence_length) &&
|
||||
sdpa_vector_supported_head_dim;
|
||||
|
||||
return !supports_sdpa_vector;
|
||||
}
|
||||
|
||||
void ScaledDotProductAttention::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out) {
|
||||
nvtx3::scoped_range r("ScaledDotProductAttention::eval_gpu");
|
||||
}
|
||||
|
||||
} // namespace fast
|
||||
|
||||
} // namespace mlx::core
|
Loading…
Reference in New Issue
Block a user