fix donation in sdpa (#1587)

This commit is contained in:
Awni Hannun 2024-11-13 17:21:13 -08:00 committed by GitHub
parent dfa0b9aab4
commit b35f1e3c9c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -203,12 +203,14 @@ void ScaledDotProductAttention::eval_gpu(
// Define some copy functions to ensure the layout of the inputs is as
// expected.
auto copy_unless = [&copies, &s](auto predicate, const array& arr) {
copies.reserve(3);
auto copy_unless = [&copies, &s](
auto predicate, const array& arr) -> const array& {
if (!predicate(arr)) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
return arr_copy;
return copies.back();
} else {
return arr;
}
@ -237,9 +239,9 @@ void ScaledDotProductAttention::eval_gpu(
// We are in vector mode ie single query
if (q_pre.shape(2) == 1) {
auto q = copy_unless(is_contiguous, q_pre);
auto k = copy_unless(is_contiguous_except_seq_len, k_pre);
auto v = copy_unless(is_contiguous_except_seq_len, v_pre);
const auto& q = copy_unless(is_contiguous, q_pre);
const auto& k = copy_unless(is_contiguous_except_seq_len, k_pre);
const auto& v = copy_unless(is_contiguous_except_seq_len, v_pre);
// Donate the query if possible
if (q.is_donatable()) {
@ -247,15 +249,14 @@ void ScaledDotProductAttention::eval_gpu(
} else {
o.set_data(allocator::malloc_or_wait(o.nbytes()));
}
sdpa_vector(s, d, q, k, v, o, scale_);
}
// Full attention mode
else {
auto q = copy_unless(is_matrix_contiguous, q_pre);
auto k = copy_unless(is_matrix_contiguous, k_pre);
auto v = copy_unless(is_matrix_contiguous, v_pre);
const auto& q = copy_unless(is_matrix_contiguous, q_pre);
const auto& k = copy_unless(is_matrix_contiguous, k_pre);
const auto& v = copy_unless(is_matrix_contiguous, v_pre);
o.set_data(allocator::malloc_or_wait(o.nbytes()));
sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o);