fix fallback (#1646)

Awni Hannun 2024-12-05 11:59:53 -08:00 committed by GitHub
parent e047fd977d
commit c5b0928c1f


@@ -605,7 +605,7 @@ array scaled_dot_product_attention(
     threshold = std::max(1, memory_efficient_threshold.value());
   }
 
-  auto fallback = [scale, final_type, n_q_heads, n_kv_heads, &s](
+  auto fallback = [scale, final_type, n_q_heads, n_kv_heads, s](
       const std::vector<array>& inputs) {
     auto q = multiply(array(scale, inputs[0].dtype()), inputs[0], s);
     int n_repeats = n_q_heads / n_kv_heads;
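
The change in scaled_dot_product_attention is a one-character fix: the stream `s` is now captured by value instead of by reference. The fallback closure is kept around and only invoked later, when the fused kernel cannot be used, so a reference to the local `s` would dangle once the enclosing function returned. A minimal sketch of this bug class (not MLX code; `std::function` and `make_fallback` are stand-ins for however the closure is actually stored and built):

#include <functional>
#include <iostream>

// Stand-in for a function that builds a fallback closure and hands it off,
// the way scaled_dot_product_attention hands its fallback to a primitive.
std::function<int()> make_fallback() {
  int s = 42; // stands in for the local stream `s`
  // Buggy version: [&s] captures a reference to a local that is destroyed
  // when make_fallback() returns, so invoking the closure later is
  // undefined behavior:
  //   return [&s]() { return s; };
  // Fixed version (what the commit does): capture by value, so the closure
  // owns its own copy of `s` for as long as it lives.
  return [s]() { return s; };
}

int main() {
  auto fallback = make_fallback();
  std::cout << fallback() << "\n"; // safe: prints 42
}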
@@ -856,8 +856,12 @@ array affine_dequantize(
   auto s = to_stream(s_);
-  auto fallback = [&wshape, &sshape, &scales, &biases, group_size, bits, s](
-                      const std::vector<array>& inputs) -> std::vector<array> {
+  auto fallback =
+      [wshape = std::move(wshape),
+       sshape = std::move(sshape),
+       group_size,
+       bits,
+       s](const std::vector<array>& inputs) mutable -> std::vector<array> {
     auto w = inputs[0];
     auto& scales = inputs[1];
     auto& biases = inputs[2];
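
The affine_dequantize fallback had the same lifetime problem, and the fix is broader: `wshape` and `sshape` are now moved into the closure via C++14 init-captures rather than captured by reference, `mutable` is added so the body may modify (or move from) the closure's own members, and the `&scales`/`&biases` captures are dropped entirely because the lambda already receives those arrays as `inputs[1]` and `inputs[2]`. A small sketch of the init-capture pattern in isolation (not MLX code):

#include <iostream>
#include <utility>
#include <vector>

int main() {
  std::vector<int> wshape = {2, 3, 4};

  // Init-capture: the closure gets its own member named wshape,
  // move-initialized from the local, so no reference into the enclosing
  // scope survives. `mutable` allows the body to modify that member.
  auto fallback = [wshape = std::move(wshape)]() mutable {
    wshape.back() = -1; // mutates the closure's copy, not the moved-from local
    return wshape;
  };

  auto shape = fallback();
  std::cout << shape[0] << " " << shape[1] << " " << shape[2] << "\n"; // 2 3 -1
}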