diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 731912d69..5a04af0bd 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -605,7 +605,7 @@ array scaled_dot_product_attention( threshold = std::max(1, memory_efficient_threshold.value()); } - auto fallback = [scale, final_type, n_q_heads, n_kv_heads, &s]( + auto fallback = [scale, final_type, n_q_heads, n_kv_heads, s]( const std::vector& inputs) { auto q = multiply(array(scale, inputs[0].dtype()), inputs[0], s); int n_repeats = n_q_heads / n_kv_heads; @@ -856,8 +856,12 @@ array affine_dequantize( auto s = to_stream(s_); - auto fallback = [&wshape, &sshape, &scales, &biases, group_size, bits, s]( - const std::vector& inputs) -> std::vector { + auto fallback = + [wshape = std::move(wshape), + sshape = std::move(sshape), + group_size, + bits, + s](const std::vector& inputs) mutable -> std::vector { auto w = inputs[0]; auto& scales = inputs[1]; auto& biases = inputs[2];