mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add sdpa with sinks (#2558)
* add sdpa with sinks * fix 2 pass * fix matrix sdpa * fix perf regression * add to cuda (#2580)
This commit is contained in:
@@ -46,6 +46,7 @@ __global__ void kernel_sdpav_1pass(
|
||||
const T* K,
|
||||
const T* V,
|
||||
T* O,
|
||||
const T* sinks,
|
||||
__grid_constant__ const AttnParams params) {
|
||||
constexpr int BN = 32;
|
||||
constexpr int BD = 32;
|
||||
@@ -65,7 +66,7 @@ __global__ void kernel_sdpav_1pass(
|
||||
__shared__ U max_scores[BN];
|
||||
__shared__ U sum_exp_scores[BN];
|
||||
|
||||
const U scale_log2 = params.scale * 1.44269504089f;
|
||||
const U scale_log2 = params.scale * M_LOG2E;
|
||||
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<32>(block);
|
||||
@@ -110,6 +111,10 @@ __global__ void kernel_sdpav_1pass(
|
||||
|
||||
U max_score = -INFINITY;
|
||||
U sum_exp_score = 0.f;
|
||||
if (sinks && warp_idx == 0) {
|
||||
max_score = M_LOG2E * static_cast<U>(sinks[head_idx]);
|
||||
sum_exp_score = 1.f;
|
||||
}
|
||||
|
||||
// For each key
|
||||
for (int i = kv_seq_idx; i < params.kL; i += BN) {
|
||||
@@ -137,8 +142,9 @@ __global__ void kernel_sdpav_1pass(
|
||||
|
||||
// Update the accumulators
|
||||
U new_max = max(max_score, score);
|
||||
U factor = exp2f(max_score - new_max);
|
||||
U exp_score = exp2f(score - new_max);
|
||||
bool is_neg_inf = new_max == -INFINITY;
|
||||
U factor = is_neg_inf ? 1 : exp2f(max_score - new_max);
|
||||
U exp_score = is_neg_inf ? 0 : exp2f(score - new_max);
|
||||
|
||||
max_score = new_max;
|
||||
sum_exp_score = sum_exp_score * factor + exp_score;
|
||||
@@ -193,6 +199,7 @@ __global__ void kernel_sdpav_2pass_1(
|
||||
const T* Q,
|
||||
const T* K,
|
||||
const T* V,
|
||||
const T* sinks,
|
||||
float* partials,
|
||||
float* sums,
|
||||
float* maxs,
|
||||
@@ -268,8 +275,12 @@ __global__ void kernel_sdpav_2pass_1(
|
||||
o[i] = 0.f;
|
||||
}
|
||||
|
||||
U max_score = -1e9;
|
||||
U max_score = -INFINITY;
|
||||
U sum_exp_score = 0.f;
|
||||
if (sinks && warp_idx == 0 && block_idx == 0) {
|
||||
max_score = M_LOG2E * static_cast<U>(sinks[head_idx]);
|
||||
sum_exp_score = 1.f;
|
||||
}
|
||||
|
||||
// For each key
|
||||
for (int i = kv_seq_idx; i < params.kL; i += blocks * BN) {
|
||||
@@ -297,8 +308,9 @@ __global__ void kernel_sdpav_2pass_1(
|
||||
|
||||
// Update the accumulators
|
||||
U new_max = max(max_score, score);
|
||||
U factor = exp2f(max_score - new_max);
|
||||
U exp_score = exp2f(score - new_max);
|
||||
bool is_neg_inf = new_max == -INFINITY;
|
||||
U factor = is_neg_inf ? 1 : exp2f(max_score - new_max);
|
||||
U exp_score = is_neg_inf ? 0 : exp2f(score - new_max);
|
||||
|
||||
max_score = new_max;
|
||||
sum_exp_score = sum_exp_score * factor + exp_score;
|
||||
@@ -463,10 +475,14 @@ void sdpa_vector_1pass_fallback(
|
||||
const array& v,
|
||||
const float scale,
|
||||
array& o,
|
||||
bool do_causal_ = false) {
|
||||
bool do_causal,
|
||||
const std::optional<array>& sinks) {
|
||||
encoder.set_input_array(q);
|
||||
encoder.set_input_array(k);
|
||||
encoder.set_input_array(v);
|
||||
if (sinks) {
|
||||
encoder.set_input_array(*sinks);
|
||||
}
|
||||
encoder.set_output_array(o);
|
||||
|
||||
cu::AttnParams params{
|
||||
@@ -489,7 +505,7 @@ void sdpa_vector_1pass_fallback(
|
||||
dim3 block_dim(1024, 1, 1);
|
||||
|
||||
dispatch_float_types(o.dtype(), "kernel_sdpav_1pass", [&](auto type_tag) {
|
||||
dispatch_bool(do_causal_, [&](auto do_causal) {
|
||||
dispatch_bool(do_causal, [&](auto do_causal) {
|
||||
dispatch_headdim(params.D, [&](auto headdim) {
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
|
||||
@@ -504,6 +520,7 @@ void sdpa_vector_1pass_fallback(
|
||||
k.data<DataType>(),
|
||||
v.data<DataType>(),
|
||||
o.data<DataType>(),
|
||||
sinks ? (*sinks).data<DataType>() : nullptr,
|
||||
params);
|
||||
});
|
||||
});
|
||||
@@ -518,7 +535,8 @@ void sdpa_vector_2pass_fallback(
|
||||
const array& v,
|
||||
const float scale,
|
||||
array& o,
|
||||
bool do_causal_ = false) {
|
||||
bool do_causal,
|
||||
const std::optional<array>& sinks) {
|
||||
cu::AttnParams params{
|
||||
/* int B = */ q.shape(0),
|
||||
/* int H = */ q.shape(1),
|
||||
@@ -559,7 +577,7 @@ void sdpa_vector_2pass_fallback(
|
||||
encoder.add_temporary(maxs);
|
||||
|
||||
dispatch_float_types(o.dtype(), "kernel_sdpav_2pass", [&](auto type_tag) {
|
||||
dispatch_bool(do_causal_, [&](auto do_causal) {
|
||||
dispatch_bool(do_causal, [&](auto do_causal) {
|
||||
dispatch_headdim(params.D, [&](auto headdim) {
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
|
||||
@@ -570,6 +588,10 @@ void sdpa_vector_2pass_fallback(
|
||||
encoder.set_input_array(q);
|
||||
encoder.set_input_array(k);
|
||||
encoder.set_input_array(v);
|
||||
if (sinks) {
|
||||
encoder.set_input_array(*sinks);
|
||||
}
|
||||
|
||||
encoder.set_output_array(intermediate);
|
||||
encoder.set_output_array(sums);
|
||||
encoder.set_output_array(maxs);
|
||||
@@ -585,6 +607,7 @@ void sdpa_vector_2pass_fallback(
|
||||
q.data<DataType>(),
|
||||
k.data<DataType>(),
|
||||
v.data<DataType>(),
|
||||
sinks ? (*sinks).data<DataType>() : nullptr,
|
||||
intermediate.data<float>(),
|
||||
sums.data<float>(),
|
||||
maxs.data<float>(),
|
||||
@@ -627,15 +650,16 @@ void sdpa_vector_fallback(
|
||||
const array& v,
|
||||
const float scale,
|
||||
array& o,
|
||||
bool do_causal_ = false) {
|
||||
bool do_causal,
|
||||
const std::optional<array>& sinks) {
|
||||
int kL = k.shape(2);
|
||||
|
||||
if (kL > 1024) {
|
||||
return sdpa_vector_2pass_fallback(
|
||||
s, encoder, q, k, v, scale, o, do_causal_);
|
||||
s, encoder, q, k, v, scale, o, do_causal, sinks);
|
||||
} else {
|
||||
return sdpa_vector_1pass_fallback(
|
||||
s, encoder, q, k, v, scale, o, do_causal_);
|
||||
s, encoder, q, k, v, scale, o, do_causal, sinks);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -691,7 +715,7 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
|
||||
// Define some copy functions to ensure the layout of the inputs is as
|
||||
// expected.
|
||||
copies.reserve(3);
|
||||
copies.reserve(inputs.size());
|
||||
auto copy_unless = [&copies, &s](
|
||||
auto predicate, const array& arr) -> const array& {
|
||||
if (!predicate(arr)) {
|
||||
@@ -703,6 +727,16 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
}
|
||||
};
|
||||
|
||||
// Checks that the headdim dimension has stride 1.
|
||||
auto is_matrix_contiguous = [](const array& arr) {
|
||||
return arr.strides(-1) == 1;
|
||||
};
|
||||
|
||||
std::optional<array> sinks = std::nullopt;
|
||||
if (has_sinks_) {
|
||||
sinks = copy_unless(is_matrix_contiguous, inputs.back());
|
||||
}
|
||||
|
||||
// We are in vector mode ie single query
|
||||
if (q_pre.shape(2) < 4) {
|
||||
auto q_copy_unless = [](const array& arr) {
|
||||
@@ -740,10 +774,6 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
const auto& k = copy_unless(kv_copy_unless, k_pre);
|
||||
const auto& v = copy_unless(kv_copy_unless, v_pre);
|
||||
|
||||
for (const auto& cp : copies) {
|
||||
encoder.add_temporary(cp);
|
||||
}
|
||||
|
||||
// Donate the query if possible
|
||||
if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) {
|
||||
o.copy_shared_buffer(q);
|
||||
@@ -752,22 +782,26 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
int64_t str_oH = o.shape(3);
|
||||
int64_t str_oL = o.shape(1) * str_oH;
|
||||
int64_t str_oB = o.shape(2) * str_oL;
|
||||
size_t data_size = o.shape(0) * str_oB;
|
||||
|
||||
array::Flags flags{
|
||||
/* bool contiguous = */ 1,
|
||||
/* bool row_contiguous = */ o.shape(2) == 1,
|
||||
/* bool col_contiguous = */ 0,
|
||||
/* bool col_contiguous = */ o.size() == o.shape(3),
|
||||
};
|
||||
|
||||
o.set_data(
|
||||
allocator::malloc(o.nbytes()),
|
||||
data_size,
|
||||
o.size(),
|
||||
{str_oB, str_oH, str_oL, str_oD},
|
||||
flags);
|
||||
}
|
||||
|
||||
return sdpa_vector_fallback(s, encoder, q, k, v, scale_, o, do_causal_);
|
||||
for (const auto& cp : copies) {
|
||||
encoder.add_temporary(cp);
|
||||
}
|
||||
|
||||
return sdpa_vector_fallback(
|
||||
s, encoder, q, k, v, scale_, o, do_causal_, sinks);
|
||||
}
|
||||
|
||||
// Full attention mode should never reach here
|
||||
|
||||
Reference in New Issue
Block a user