fix donation in sdpa (#1587)

This commit is contained in:
Awni Hannun 2024-11-13 17:21:13 -08:00 committed by GitHub
parent dfa0b9aab4
commit b35f1e3c9c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -203,12 +203,14 @@ void ScaledDotProductAttention::eval_gpu(
// Define some copy functions to ensure the layout of the inputs is as // Define some copy functions to ensure the layout of the inputs is as
// expected. // expected.
auto copy_unless = [&copies, &s](auto predicate, const array& arr) { copies.reserve(3);
auto copy_unless = [&copies, &s](
auto predicate, const array& arr) -> const array& {
if (!predicate(arr)) { if (!predicate(arr)) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s); copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(arr_copy); copies.push_back(arr_copy);
return arr_copy; return copies.back();
} else { } else {
return arr; return arr;
} }
@@ -237,9 +239,9 @@ void ScaledDotProductAttention::eval_gpu(
// We are in vector mode ie single query // We are in vector mode ie single query
if (q_pre.shape(2) == 1) { if (q_pre.shape(2) == 1) {
auto q = copy_unless(is_contiguous, q_pre); const auto& q = copy_unless(is_contiguous, q_pre);
auto k = copy_unless(is_contiguous_except_seq_len, k_pre); const auto& k = copy_unless(is_contiguous_except_seq_len, k_pre);
auto v = copy_unless(is_contiguous_except_seq_len, v_pre); const auto& v = copy_unless(is_contiguous_except_seq_len, v_pre);
// Donate the query if possible // Donate the query if possible
if (q.is_donatable()) { if (q.is_donatable()) {
@@ -247,15 +249,14 @@ void ScaledDotProductAttention::eval_gpu(
} else { } else {
o.set_data(allocator::malloc_or_wait(o.nbytes())); o.set_data(allocator::malloc_or_wait(o.nbytes()));
} }
sdpa_vector(s, d, q, k, v, o, scale_); sdpa_vector(s, d, q, k, v, o, scale_);
} }
// Full attention mode // Full attention mode
else { else {
auto q = copy_unless(is_matrix_contiguous, q_pre); const auto& q = copy_unless(is_matrix_contiguous, q_pre);
auto k = copy_unless(is_matrix_contiguous, k_pre); const auto& k = copy_unless(is_matrix_contiguous, k_pre);
auto v = copy_unless(is_matrix_contiguous, v_pre); const auto& v = copy_unless(is_matrix_contiguous, v_pre);
o.set_data(allocator::malloc_or_wait(o.nbytes())); o.set_data(allocator::malloc_or_wait(o.nbytes()));
sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o); sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o);