mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-26 02:33:21 +08:00
fix fallback (#1646)
This commit is contained in:
parent
e047fd977d
commit
c5b0928c1f
10
mlx/fast.cpp
10
mlx/fast.cpp
@ -605,7 +605,7 @@ array scaled_dot_product_attention(
|
||||
threshold = std::max(1, memory_efficient_threshold.value());
|
||||
}
|
||||
|
||||
auto fallback = [scale, final_type, n_q_heads, n_kv_heads, &s](
|
||||
auto fallback = [scale, final_type, n_q_heads, n_kv_heads, s](
|
||||
const std::vector<array>& inputs) {
|
||||
auto q = multiply(array(scale, inputs[0].dtype()), inputs[0], s);
|
||||
int n_repeats = n_q_heads / n_kv_heads;
|
||||
@ -856,8 +856,12 @@ array affine_dequantize(
|
||||
|
||||
auto s = to_stream(s_);
|
||||
|
||||
auto fallback = [&wshape, &sshape, &scales, &biases, group_size, bits, s](
|
||||
const std::vector<array>& inputs) -> std::vector<array> {
|
||||
auto fallback =
|
||||
[wshape = std::move(wshape),
|
||||
sshape = std::move(sshape),
|
||||
group_size,
|
||||
bits,
|
||||
s](const std::vector<array>& inputs) mutable -> std::vector<array> {
|
||||
auto w = inputs[0];
|
||||
auto& scales = inputs[1];
|
||||
auto& biases = inputs[2];
|
||||
|
Loading…
Reference in New Issue
Block a user