mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-26 18:51:14 +08:00
fix fallback (#1646)
This commit is contained in:
parent
e047fd977d
commit
c5b0928c1f
10
mlx/fast.cpp
10
mlx/fast.cpp
@ -605,7 +605,7 @@ array scaled_dot_product_attention(
|
|||||||
threshold = std::max(1, memory_efficient_threshold.value());
|
threshold = std::max(1, memory_efficient_threshold.value());
|
||||||
}
|
}
|
||||||
|
|
||||||
auto fallback = [scale, final_type, n_q_heads, n_kv_heads, &s](
|
auto fallback = [scale, final_type, n_q_heads, n_kv_heads, s](
|
||||||
const std::vector<array>& inputs) {
|
const std::vector<array>& inputs) {
|
||||||
auto q = multiply(array(scale, inputs[0].dtype()), inputs[0], s);
|
auto q = multiply(array(scale, inputs[0].dtype()), inputs[0], s);
|
||||||
int n_repeats = n_q_heads / n_kv_heads;
|
int n_repeats = n_q_heads / n_kv_heads;
|
||||||
@ -856,8 +856,12 @@ array affine_dequantize(
|
|||||||
|
|
||||||
auto s = to_stream(s_);
|
auto s = to_stream(s_);
|
||||||
|
|
||||||
auto fallback = [&wshape, &sshape, &scales, &biases, group_size, bits, s](
|
auto fallback =
|
||||||
const std::vector<array>& inputs) -> std::vector<array> {
|
[wshape = std::move(wshape),
|
||||||
|
sshape = std::move(sshape),
|
||||||
|
group_size,
|
||||||
|
bits,
|
||||||
|
s](const std::vector<array>& inputs) mutable -> std::vector<array> {
|
||||||
auto w = inputs[0];
|
auto w = inputs[0];
|
||||||
auto& scales = inputs[1];
|
auto& scales = inputs[1];
|
||||||
auto& biases = inputs[2];
|
auto& biases = inputs[2];
|
||||||
|
Loading…
Reference in New Issue
Block a user