From c5b0928c1fa97a61e95e9dc5b84d9a8c7042f3c2 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 5 Dec 2024 11:59:53 -0800 Subject: [PATCH] fix fallback (#1646) --- mlx/fast.cpp | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 731912d69..5a04af0bd 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -605,7 +605,7 @@ array scaled_dot_product_attention( threshold = std::max(1, memory_efficient_threshold.value()); } - auto fallback = [scale, final_type, n_q_heads, n_kv_heads, &s]( + auto fallback = [scale, final_type, n_q_heads, n_kv_heads, s]( const std::vector& inputs) { auto q = multiply(array(scale, inputs[0].dtype()), inputs[0], s); int n_repeats = n_q_heads / n_kv_heads; @@ -856,8 +856,12 @@ array affine_dequantize( auto s = to_stream(s_); - auto fallback = [&wshape, &sshape, &scales, &biases, group_size, bits, s]( - const std::vector& inputs) -> std::vector { + auto fallback = + [wshape = std::move(wshape), + sshape = std::move(sshape), + group_size, + bits, + s](const std::vector& inputs) mutable -> std::vector { auto w = inputs[0]; auto& scales = inputs[1]; auto& biases = inputs[2];