fix perf regression

Awni Hannun
2025-09-09 13:19:58 -07:00
committed by Awni Hannun
parent 836f019d3b
commit ce54db388f
2 changed files with 25 additions and 13 deletions


@@ -463,7 +463,8 @@ void sdpa_vector_1pass_fallback(
     const array& v,
     const float scale,
     array& o,
-    bool do_causal_ = false) {
+    bool do_causal,
+    const std::optional<array>& sinks) {
   encoder.set_input_array(q);
   encoder.set_input_array(k);
   encoder.set_input_array(v);
@@ -489,7 +490,7 @@ void sdpa_vector_1pass_fallback(
   dim3 block_dim(1024, 1, 1);
   dispatch_float_types(o.dtype(), "kernel_sdpav_1pass", [&](auto type_tag) {
-    dispatch_bool(do_causal_, [&](auto do_causal) {
+    dispatch_bool(do_causal, [&](auto do_causal) {
       dispatch_headdim(params.D, [&](auto headdim) {
         using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
@@ -518,7 +519,8 @@ void sdpa_vector_2pass_fallback(
     const array& v,
     const float scale,
     array& o,
-    bool do_causal_ = false) {
+    bool do_causal,
+    const std::optional<array>& sinks) {
   cu::AttnParams params{
       /* int B = */ q.shape(0),
       /* int H = */ q.shape(1),
@@ -559,7 +561,7 @@ void sdpa_vector_2pass_fallback(
   encoder.add_temporary(maxs);
   dispatch_float_types(o.dtype(), "kernel_sdpav_2pass", [&](auto type_tag) {
-    dispatch_bool(do_causal_, [&](auto do_causal) {
+    dispatch_bool(do_causal, [&](auto do_causal) {
       dispatch_headdim(params.D, [&](auto headdim) {
         using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
@@ -627,15 +629,16 @@ void sdpa_vector_fallback(
     const array& v,
     const float scale,
     array& o,
-    bool do_causal_ = false) {
+    bool do_causal,
+    const std::optional<array>& sinks) {
   int kL = k.shape(2);
   if (kL > 1024) {
     return sdpa_vector_2pass_fallback(
-        s, encoder, q, k, v, scale, o, do_causal_);
+        s, encoder, q, k, v, scale, o, do_causal, sinks);
   } else {
     return sdpa_vector_1pass_fallback(
-        s, encoder, q, k, v, scale, o, do_causal_);
+        s, encoder, q, k, v, scale, o, do_causal, sinks);
   }
 }
@@ -703,6 +706,16 @@ void ScaledDotProductAttention::eval_gpu(
     }
   };
+  // Checks that the headdim dimension has stride 1.
+  auto is_matrix_contiguous = [](const array& arr) {
+    return arr.strides(-1) == 1;
+  };
+  std::optional<array> sinks = std::nullopt;
+  if (has_sinks_) {
+    sinks = copy_unless(is_matrix_contiguous, inputs.back());
+  }
   // We are in vector mode ie single query
   if (q_pre.shape(2) < 4) {
     auto q_copy_unless = [](const array& arr) {
@@ -766,7 +779,8 @@ void ScaledDotProductAttention::eval_gpu(
       encoder.add_temporary(cp);
     }
-    return sdpa_vector_fallback(s, encoder, q, k, v, scale_, o, do_causal_);
+    return sdpa_vector_fallback(
+        s, encoder, q, k, v, scale_, o, do_causal_, sinks);
   }
   // Full attention mode should never reach here

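To follow the plumbing in the first file: the eval_gpu hunk materializes an optional sinks array only when has_sinks_ is set, and the fallbacks now receive do_causal and sinks explicitly since the defaulted do_causal_ parameter is gone. A contrived, self-contained sketch of that pattern (hypothetical names, plain C++, not MLX's real types or APIs):

#include <optional>
#include <vector>

struct Tensor {
  std::vector<float> data;
};

// Callees take the optional by const reference and simply check it.
void attn_1pass(const Tensor& q, bool do_causal,
                const std::optional<Tensor>& sinks) {
  if (sinks) { /* fold sinks->data into the softmax denominator */ }
}

void attn_2pass(const Tensor& q, bool do_causal,
                const std::optional<Tensor>& sinks) {
  if (sinks) { /* same handling, split across two passes */ }
}

// The dispatcher forwards both arguments instead of relying on a default,
// mirroring the removal of `bool do_causal_ = false` in the diff above.
void attn_fallback(const Tensor& q, int kL, bool do_causal,
                   const std::optional<Tensor>& sinks) {
  if (kL > 1024) {
    attn_2pass(q, do_causal, sinks);
  } else {
    attn_1pass(q, do_causal, sinks);
  }
}

int main() {
  Tensor q{{1.0f, 2.0f}};
  std::optional<Tensor> sinks;  // stays nullopt when no sinks input exists
  attn_fallback(q, /*kL=*/256, /*do_causal=*/true, sinks);
}

Passing the optional by const reference keeps the call sites cheap while letting each fallback decide whether the extra input is present, which is the shape of the change across all three fallback functions here.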

@@ -116,17 +116,15 @@ template <typename T, int D, int V = D>
       score += q[j] * k[j];
     }
     score = simd_sum(score);
-    if (score < Limits<T>::finite_min) {
-      continue;
-    }
     if (float_mask) {
       score += static_cast<U>(fmask[0]);
     }
     // Update the accumulators
     U new_max = max(max_score, score);
-    U factor = fast::exp(max_score - new_max);
-    U exp_score = fast::exp(score - new_max);
+    bool is_neg_inf = new_max == -INFINITY;
+    U factor = is_neg_inf ? 1.0 : fast::exp(max_score - new_max);
+    U exp_score = is_neg_inf ? 0.0 : fast::exp(score - new_max);
     max_score = new_max;
     sum_exp_score = sum_exp_score * factor + exp_score;
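The kernel change replaces the per-key early exit on fully-masked scores with a select on the running maximum, so the accumulator update is well-defined even while every score seen so far is -inf. A minimal host-side sketch of that update (plain C++, with std::exp standing in for fast::exp; not the actual CUDA kernel) shows the NaN it avoids:

#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
  float max_score = -INFINITY; // running max over scores seen so far
  float sum_exp_score = 0.0f;  // running sum of exp(score - max_score)

  // First two keys are fully masked (score forced to -inf), third is real.
  float scores[3] = {-INFINITY, -INFINITY, 1.5f};

  for (float score : scores) {
    float new_max = std::max(max_score, score);
    // Without the guard, exp(max_score - new_max) evaluates
    // exp(-inf - -inf) = exp(nan) = nan while all keys are still masked,
    // and that nan then poisons sum_exp_score for the rest of the loop.
    bool is_neg_inf = new_max == -INFINITY;
    float factor = is_neg_inf ? 1.0f : std::exp(max_score - new_max);
    float exp_score = is_neg_inf ? 0.0f : std::exp(score - new_max);
    max_score = new_max;
    sum_exp_score = sum_exp_score * factor + exp_score;
  }
  std::printf("max=%f sum=%f\n", max_score, sum_exp_score); // max=1.5 sum=1.0
}

When all scores are finite the guard never fires and the update reduces to the usual online-softmax recurrence; masked keys simply contribute zero to the sum, which is why the removed early-exit branch is no longer needed.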