Start sdpa vector

This commit is contained in:
Angelos Katharopoulos 2025-06-15 21:58:34 -07:00
parent bc53f8293f
commit 870208eff5
3 changed files with 52 additions and 12 deletions

View File

@ -33,6 +33,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/reduce/segmented_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cu
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
${CMAKE_CURRENT_SOURCE_DIR}/sort.cu

View File

@ -43,17 +43,6 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
});
}
bool fast::ScaledDotProductAttention::use_fallback(
const array& q,
const array& k,
const array& v,
bool has_mask,
bool has_arr_mask,
bool do_causal,
Stream s) {
return true;
}
#define NO_GPU_MULTI(func) \
void func::eval_gpu( \
const std::vector<array>& inputs, std::vector<array>& outputs) { \
@ -94,7 +83,6 @@ NO_GPU_MULTI(Eig)
NO_GPU_MULTI(Eigh)
namespace fast {
NO_GPU(ScaledDotProductAttention)
NO_GPU_MULTI(AffineQuantize)
NO_GPU_MULTI(CustomKernel)
} // namespace fast

View File

@ -0,0 +1,51 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/fast_primitives.h"
namespace mlx::core {
namespace cu {} // namespace cu
namespace fast {
bool ScaledDotProductAttention::use_fallback(
const array& q,
const array& k,
const array& v,
bool has_mask,
bool has_arr_mask,
bool do_causal,
Stream s) {
if (detail::in_grad_tracing()) {
return true;
}
if (s.device == Device::cpu) {
return true;
}
const int value_head_dim = v.shape(-1);
const int query_head_dim = q.shape(-1);
const int query_sequence_length = q.shape(2);
const int key_sequence_length = k.shape(2);
const bool sdpa_vector_supported_head_dim =
query_head_dim == value_head_dim &&
(query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 ||
query_head_dim == 256);
const bool supports_sdpa_vector = (query_sequence_length <= 1) &&
(query_sequence_length <= key_sequence_length) &&
sdpa_vector_supported_head_dim;
return !supports_sdpa_vector;
}
void ScaledDotProductAttention::eval_gpu(
const std::vector<array>& inputs,
array& out) {
nvtx3::scoped_range r("ScaledDotProductAttention::eval_gpu");
}
} // namespace fast
} // namespace mlx::core