mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
fix perf regression
This commit is contained in:
@@ -463,7 +463,8 @@ void sdpa_vector_1pass_fallback(
|
|||||||
const array& v,
|
const array& v,
|
||||||
const float scale,
|
const float scale,
|
||||||
array& o,
|
array& o,
|
||||||
bool do_causal_ = false) {
|
bool do_causal,
|
||||||
|
const std::optional<array>& sinks) {
|
||||||
encoder.set_input_array(q);
|
encoder.set_input_array(q);
|
||||||
encoder.set_input_array(k);
|
encoder.set_input_array(k);
|
||||||
encoder.set_input_array(v);
|
encoder.set_input_array(v);
|
||||||
@@ -489,7 +490,7 @@ void sdpa_vector_1pass_fallback(
|
|||||||
dim3 block_dim(1024, 1, 1);
|
dim3 block_dim(1024, 1, 1);
|
||||||
|
|
||||||
dispatch_float_types(o.dtype(), "kernel_sdpav_1pass", [&](auto type_tag) {
|
dispatch_float_types(o.dtype(), "kernel_sdpav_1pass", [&](auto type_tag) {
|
||||||
dispatch_bool(do_causal_, [&](auto do_causal) {
|
dispatch_bool(do_causal, [&](auto do_causal) {
|
||||||
dispatch_headdim(params.D, [&](auto headdim) {
|
dispatch_headdim(params.D, [&](auto headdim) {
|
||||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||||
|
|
||||||
@@ -518,7 +519,8 @@ void sdpa_vector_2pass_fallback(
|
|||||||
const array& v,
|
const array& v,
|
||||||
const float scale,
|
const float scale,
|
||||||
array& o,
|
array& o,
|
||||||
bool do_causal_ = false) {
|
bool do_causal,
|
||||||
|
const std::optional<array>& sinks) {
|
||||||
cu::AttnParams params{
|
cu::AttnParams params{
|
||||||
/* int B = */ q.shape(0),
|
/* int B = */ q.shape(0),
|
||||||
/* int H = */ q.shape(1),
|
/* int H = */ q.shape(1),
|
||||||
@@ -559,7 +561,7 @@ void sdpa_vector_2pass_fallback(
|
|||||||
encoder.add_temporary(maxs);
|
encoder.add_temporary(maxs);
|
||||||
|
|
||||||
dispatch_float_types(o.dtype(), "kernel_sdpav_2pass", [&](auto type_tag) {
|
dispatch_float_types(o.dtype(), "kernel_sdpav_2pass", [&](auto type_tag) {
|
||||||
dispatch_bool(do_causal_, [&](auto do_causal) {
|
dispatch_bool(do_causal, [&](auto do_causal) {
|
||||||
dispatch_headdim(params.D, [&](auto headdim) {
|
dispatch_headdim(params.D, [&](auto headdim) {
|
||||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||||
|
|
||||||
@@ -627,15 +629,16 @@ void sdpa_vector_fallback(
|
|||||||
const array& v,
|
const array& v,
|
||||||
const float scale,
|
const float scale,
|
||||||
array& o,
|
array& o,
|
||||||
bool do_causal_ = false) {
|
bool do_causal,
|
||||||
|
const std::optional<array>& sinks) {
|
||||||
int kL = k.shape(2);
|
int kL = k.shape(2);
|
||||||
|
|
||||||
if (kL > 1024) {
|
if (kL > 1024) {
|
||||||
return sdpa_vector_2pass_fallback(
|
return sdpa_vector_2pass_fallback(
|
||||||
s, encoder, q, k, v, scale, o, do_causal_);
|
s, encoder, q, k, v, scale, o, do_causal, sinks);
|
||||||
} else {
|
} else {
|
||||||
return sdpa_vector_1pass_fallback(
|
return sdpa_vector_1pass_fallback(
|
||||||
s, encoder, q, k, v, scale, o, do_causal_);
|
s, encoder, q, k, v, scale, o, do_causal, sinks);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -703,6 +706,16 @@ void ScaledDotProductAttention::eval_gpu(
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Checks that the headdim dimension has stride 1.
|
||||||
|
auto is_matrix_contiguous = [](const array& arr) {
|
||||||
|
return arr.strides(-1) == 1;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::optional<array> sinks = std::nullopt;
|
||||||
|
if (has_sinks_) {
|
||||||
|
sinks = copy_unless(is_matrix_contiguous, inputs.back());
|
||||||
|
}
|
||||||
|
|
||||||
// We are in vector mode ie single query
|
// We are in vector mode ie single query
|
||||||
if (q_pre.shape(2) < 4) {
|
if (q_pre.shape(2) < 4) {
|
||||||
auto q_copy_unless = [](const array& arr) {
|
auto q_copy_unless = [](const array& arr) {
|
||||||
@@ -766,7 +779,8 @@ void ScaledDotProductAttention::eval_gpu(
|
|||||||
encoder.add_temporary(cp);
|
encoder.add_temporary(cp);
|
||||||
}
|
}
|
||||||
|
|
||||||
return sdpa_vector_fallback(s, encoder, q, k, v, scale_, o, do_causal_);
|
return sdpa_vector_fallback(
|
||||||
|
s, encoder, q, k, v, scale_, o, do_causal_, sinks);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Full attention mode should never reach here
|
// Full attention mode should never reach here
|
||||||
|
|||||||
@@ -116,17 +116,15 @@ template <typename T, int D, int V = D>
|
|||||||
score += q[j] * k[j];
|
score += q[j] * k[j];
|
||||||
}
|
}
|
||||||
score = simd_sum(score);
|
score = simd_sum(score);
|
||||||
if (score < Limits<T>::finite_min) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (float_mask) {
|
if (float_mask) {
|
||||||
score += static_cast<U>(fmask[0]);
|
score += static_cast<U>(fmask[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update the accumulators
|
// Update the accumulators
|
||||||
U new_max = max(max_score, score);
|
U new_max = max(max_score, score);
|
||||||
U factor = fast::exp(max_score - new_max);
|
bool is_neg_inf = new_max == -INFINITY;
|
||||||
U exp_score = fast::exp(score - new_max);
|
U factor = is_neg_inf ? 1.0 : fast::exp(max_score - new_max);
|
||||||
|
U exp_score = is_neg_inf ? 0.0 : fast::exp(score - new_max);
|
||||||
|
|
||||||
max_score = new_max;
|
max_score = new_max;
|
||||||
sum_exp_score = sum_exp_score * factor + exp_score;
|
sum_exp_score = sum_exp_score * factor + exp_score;
|
||||||
|
|||||||
Reference in New Issue
Block a user