diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp
index 3071650a5..acc20b323 100644
--- a/mlx/backend/metal/scaled_dot_product_attention.cpp
+++ b/mlx/backend/metal/scaled_dot_product_attention.cpp
@@ -203,12 +203,14 @@ void ScaledDotProductAttention::eval_gpu(
 
   // Define some copy functions to ensure the layout of the inputs is as
   // expected.
-  auto copy_unless = [&copies, &s](auto predicate, const array& arr) {
+  copies.reserve(3);
+  auto copy_unless = [&copies, &s](
+                         auto predicate, const array& arr) -> const array& {
     if (!predicate(arr)) {
       array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
       copy_gpu(arr, arr_copy, CopyType::General, s);
       copies.push_back(arr_copy);
-      return arr_copy;
+      return copies.back();
     } else {
       return arr;
     }
@@ -237,9 +239,9 @@ void ScaledDotProductAttention::eval_gpu(
 
   // We are in vector mode ie single query
   if (q_pre.shape(2) == 1) {
-    auto q = copy_unless(is_contiguous, q_pre);
-    auto k = copy_unless(is_contiguous_except_seq_len, k_pre);
-    auto v = copy_unless(is_contiguous_except_seq_len, v_pre);
+    const auto& q = copy_unless(is_contiguous, q_pre);
+    const auto& k = copy_unless(is_contiguous_except_seq_len, k_pre);
+    const auto& v = copy_unless(is_contiguous_except_seq_len, v_pre);
 
     // Donate the query if possible
     if (q.is_donatable()) {
@@ -247,15 +249,14 @@ void ScaledDotProductAttention::eval_gpu(
     } else {
       o.set_data(allocator::malloc_or_wait(o.nbytes()));
     }
-
     sdpa_vector(s, d, q, k, v, o, scale_);
   }
 
   // Full attention mode
   else {
-    auto q = copy_unless(is_matrix_contiguous, q_pre);
-    auto k = copy_unless(is_matrix_contiguous, k_pre);
-    auto v = copy_unless(is_matrix_contiguous, v_pre);
+    const auto& q = copy_unless(is_matrix_contiguous, q_pre);
+    const auto& k = copy_unless(is_matrix_contiguous, k_pre);
+    const auto& v = copy_unless(is_matrix_contiguous, v_pre);
 
     o.set_data(allocator::malloc_or_wait(o.nbytes()));
     sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o);
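
Note on the pattern in this patch: once copy_unless returns a const array& that may refer to an element of the copies vector, any later push_back into that vector must not reallocate, or earlier references dangle. The copies.reserve(3) call guarantees this, since at most three copies (for q, k, and v) are ever made per dispatch. The following stand-alone sketch, using hypothetical names and std::string instead of the MLX array type, illustrates why the reserve is load-bearing; it is not part of the patch.

// Minimal sketch (hypothetical, not the MLX API): returning a reference into
// a std::vector is only safe if subsequent push_back calls cannot reallocate.
#include <cassert>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> copies;
  copies.reserve(3); // at most three copies are ever made, so no reallocation

  // Return a reference either to the caller's object or to the stored copy.
  auto copy_unless = [&copies](bool predicate,
                               const std::string& s) -> const std::string& {
    if (!predicate) {
      copies.push_back(s + " (copied)");
      return copies.back(); // stays valid because capacity never grows
    }
    return s;
  };

  const std::string q_pre = "q";
  const std::string k_pre = "k";
  const std::string v_pre = "v";

  const auto& q = copy_unless(false, q_pre);
  const auto& k = copy_unless(false, k_pre);
  const auto& v = copy_unless(true, v_pre);

  assert(q == "q (copied)"); // still valid after the later push_back
  assert(k == "k (copied)");
  assert(&v == &v_pre); // predicate held: no copy, reference to the original
  return 0;
}

Binding the results as const auto& (as the patch does at the call sites) is what makes the reference return worthwhile: with plain auto the arrays would be copied back out again and the change would buy nothing.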