From b35f1e3c9cb1f8ed39bd222b44a8b9df6f6ca96c Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 13 Nov 2024 17:21:13 -0800 Subject: [PATCH] fix donation in sdpa (#1587) --- .../metal/scaled_dot_product_attention.cpp | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index 3071650a5..acc20b323 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -203,12 +203,14 @@ void ScaledDotProductAttention::eval_gpu( // Define some copy functions to ensure the layout of the inputs is as // expected. - auto copy_unless = [&copies, &s](auto predicate, const array& arr) { + copies.reserve(3); + auto copy_unless = [&copies, &s]( + auto predicate, const array& arr) -> const array& { if (!predicate(arr)) { array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); copy_gpu(arr, arr_copy, CopyType::General, s); copies.push_back(arr_copy); - return arr_copy; + return copies.back(); } else { return arr; } @@ -237,9 +239,9 @@ void ScaledDotProductAttention::eval_gpu( // We are in vector mode ie single query if (q_pre.shape(2) == 1) { - auto q = copy_unless(is_contiguous, q_pre); - auto k = copy_unless(is_contiguous_except_seq_len, k_pre); - auto v = copy_unless(is_contiguous_except_seq_len, v_pre); + const auto& q = copy_unless(is_contiguous, q_pre); + const auto& k = copy_unless(is_contiguous_except_seq_len, k_pre); + const auto& v = copy_unless(is_contiguous_except_seq_len, v_pre); // Donate the query if possible if (q.is_donatable()) { @@ -247,15 +249,14 @@ void ScaledDotProductAttention::eval_gpu( } else { o.set_data(allocator::malloc_or_wait(o.nbytes())); } - sdpa_vector(s, d, q, k, v, o, scale_); } // Full attention mode else { - auto q = copy_unless(is_matrix_contiguous, q_pre); - auto k = copy_unless(is_matrix_contiguous, k_pre); - auto v = copy_unless(is_matrix_contiguous, v_pre); + const auto& q = copy_unless(is_matrix_contiguous, q_pre); + const auto& k = copy_unless(is_matrix_contiguous, k_pre); + const auto& v = copy_unless(is_matrix_contiguous, v_pre); o.set_data(allocator::malloc_or_wait(o.nbytes())); sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o);