mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
fix donation in sdpa (#1587)
This commit is contained in:
parent
dfa0b9aab4
commit
b35f1e3c9c
@ -203,12 +203,14 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
|
||||
// Define some copy functions to ensure the layout of the inputs is as
|
||||
// expected.
|
||||
auto copy_unless = [&copies, &s](auto predicate, const array& arr) {
|
||||
copies.reserve(3);
|
||||
auto copy_unless = [&copies, &s](
|
||||
auto predicate, const array& arr) -> const array& {
|
||||
if (!predicate(arr)) {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy_gpu(arr, arr_copy, CopyType::General, s);
|
||||
copies.push_back(arr_copy);
|
||||
return arr_copy;
|
||||
return copies.back();
|
||||
} else {
|
||||
return arr;
|
||||
}
|
||||
@ -237,9 +239,9 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
|
||||
// We are in vector mode ie single query
|
||||
if (q_pre.shape(2) == 1) {
|
||||
auto q = copy_unless(is_contiguous, q_pre);
|
||||
auto k = copy_unless(is_contiguous_except_seq_len, k_pre);
|
||||
auto v = copy_unless(is_contiguous_except_seq_len, v_pre);
|
||||
const auto& q = copy_unless(is_contiguous, q_pre);
|
||||
const auto& k = copy_unless(is_contiguous_except_seq_len, k_pre);
|
||||
const auto& v = copy_unless(is_contiguous_except_seq_len, v_pre);
|
||||
|
||||
// Donate the query if possible
|
||||
if (q.is_donatable()) {
|
||||
@ -247,15 +249,14 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
} else {
|
||||
o.set_data(allocator::malloc_or_wait(o.nbytes()));
|
||||
}
|
||||
|
||||
sdpa_vector(s, d, q, k, v, o, scale_);
|
||||
}
|
||||
|
||||
// Full attention mode
|
||||
else {
|
||||
auto q = copy_unless(is_matrix_contiguous, q_pre);
|
||||
auto k = copy_unless(is_matrix_contiguous, k_pre);
|
||||
auto v = copy_unless(is_matrix_contiguous, v_pre);
|
||||
const auto& q = copy_unless(is_matrix_contiguous, q_pre);
|
||||
const auto& k = copy_unless(is_matrix_contiguous, k_pre);
|
||||
const auto& v = copy_unless(is_matrix_contiguous, v_pre);
|
||||
o.set_data(allocator::malloc_or_wait(o.nbytes()));
|
||||
|
||||
sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o);
|
||||
|
Loading…
Reference in New Issue
Block a user