mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
redesign for faster cpu/gpu synch (#1869)
* redesign for faster cpu/gpu synch * load + more async CPU * use command encoder API and move more ops to use it * make fence back-end generic + CPU only fence * faster build * fix async eval * fixes + handle temporaries * fix / improve cpu conv * remove unused status, fix siblings * fix extensions * fix * fix no cpu build * format * comments * fix perf regression, remove unecessary abort * fix events, task limit cpu * fix waiting * fix donation / temporaries in normalization
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/backend/cpu/simd/simd.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
#include "mlx/primitives.h"
|
||||
@@ -316,6 +317,76 @@ void _qmm_dispatch_typed(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void _qmm_dispatch_typed(
|
||||
array& out,
|
||||
const array& x,
|
||||
const array& w,
|
||||
const array& scales,
|
||||
const array& biases,
|
||||
int bits,
|
||||
int group_size,
|
||||
bool transposed_w,
|
||||
Stream stream) {
|
||||
int K = x.shape(-1);
|
||||
int M = x.ndim() > 1 ? x.shape(-2) : 1;
|
||||
int N = out.shape(-1);
|
||||
int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0;
|
||||
int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;
|
||||
int batch_size = x.size() / (K * M);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_input_array(scales);
|
||||
encoder.set_input_array(biases);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
auto out_ptr = out.data<T>();
|
||||
auto x_ptr = x.data<T>();
|
||||
auto w_ptr = w.data<uint32_t>();
|
||||
auto scales_ptr = scales.data<T>();
|
||||
auto biases_ptr = biases.data<T>();
|
||||
|
||||
encoder.dispatch([out_ptr,
|
||||
x_ptr,
|
||||
w_ptr,
|
||||
scales_ptr,
|
||||
biases_ptr,
|
||||
x_shape = x.shape(),
|
||||
x_strides = x.strides(),
|
||||
w_shape = w.shape(),
|
||||
w_strides = w.strides(),
|
||||
scales_shape = scales.shape(),
|
||||
scales_strides = scales.strides(),
|
||||
biases_shape = biases.shape(),
|
||||
biases_strides = biases.strides(),
|
||||
w_els,
|
||||
g_els,
|
||||
batch_size,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w] {
|
||||
for (int i = 0; i < batch_size; i++) {
|
||||
_qmm_dispatch_typed<T>(
|
||||
out_ptr + i * M * N,
|
||||
x_ptr + elem_to_loc(i * M * K, x_shape, x_strides),
|
||||
w_ptr + elem_to_loc(i * w_els, w_shape, w_strides),
|
||||
scales_ptr + elem_to_loc(i * g_els, scales_shape, scales_strides),
|
||||
biases_ptr + elem_to_loc(i * g_els, biases_shape, biases_strides),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void _qmm_dispatch(
|
||||
array& out,
|
||||
const array& x,
|
||||
@@ -324,64 +395,111 @@ void _qmm_dispatch(
|
||||
const array& biases,
|
||||
int bits,
|
||||
int group_size,
|
||||
bool transposed_w) {
|
||||
bool transposed_w,
|
||||
Stream stream) {
|
||||
switch (x.dtype()) {
|
||||
case float32:
|
||||
_qmm_dispatch_typed<float>(
|
||||
out, x, w, scales, biases, bits, group_size, transposed_w, stream);
|
||||
break;
|
||||
case float16:
|
||||
_qmm_dispatch_typed<float16_t>(
|
||||
out, x, w, scales, biases, bits, group_size, transposed_w, stream);
|
||||
break;
|
||||
case bfloat16:
|
||||
_qmm_dispatch_typed<bfloat16_t>(
|
||||
out, x, w, scales, biases, bits, group_size, transposed_w, stream);
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument(
|
||||
"[quantized_matmul] only floating types are supported");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void _bs_qmm_dispatch_typed(
|
||||
array& out,
|
||||
const array& x,
|
||||
const array& w,
|
||||
const array& scales,
|
||||
const array& biases,
|
||||
const array& lhs_indices,
|
||||
const array& rhs_indices,
|
||||
int bits,
|
||||
int group_size,
|
||||
bool transposed_w,
|
||||
Stream stream) {
|
||||
int K = x.shape(-1);
|
||||
int M = x.ndim() > 1 ? x.shape(-2) : 1;
|
||||
int M = x.shape(-2);
|
||||
int N = out.shape(-1);
|
||||
|
||||
int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0;
|
||||
int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;
|
||||
int w_els = w.shape(-1) * w.shape(-2);
|
||||
int g_els = scales.shape(-1) * scales.shape(-2);
|
||||
|
||||
int batch_size = x.size() / (K * M);
|
||||
for (int i = 0; i < batch_size; i++) {
|
||||
switch (x.dtype()) {
|
||||
case float32:
|
||||
_qmm_dispatch_typed<float>(
|
||||
out.data<float>() + i * M * N,
|
||||
x.data<float>() + elem_to_loc(i * M * K, x),
|
||||
w.data<uint32_t>() + elem_to_loc(i * w_els, w),
|
||||
scales.data<float>() + elem_to_loc(i * g_els, scales),
|
||||
biases.data<float>() + elem_to_loc(i * g_els, biases),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w);
|
||||
break;
|
||||
case float16:
|
||||
_qmm_dispatch_typed<float16_t>(
|
||||
out.data<float16_t>() + i * M * N,
|
||||
x.data<float16_t>() + elem_to_loc(i * M * K, x),
|
||||
w.data<uint32_t>() + elem_to_loc(i * w_els, w),
|
||||
scales.data<float16_t>() + elem_to_loc(i * g_els, scales),
|
||||
biases.data<float16_t>() + elem_to_loc(i * g_els, biases),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w);
|
||||
break;
|
||||
case bfloat16:
|
||||
_qmm_dispatch_typed<bfloat16_t>(
|
||||
out.data<bfloat16_t>() + i * M * N,
|
||||
x.data<bfloat16_t>() + elem_to_loc(i * M * K, x),
|
||||
w.data<uint32_t>() + elem_to_loc(i * w_els, w),
|
||||
scales.data<bfloat16_t>() + elem_to_loc(i * g_els, scales),
|
||||
biases.data<bfloat16_t>() + elem_to_loc(i * g_els, biases),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w);
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument(
|
||||
"[quantized_matmul] only floating types are supported");
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_input_array(scales);
|
||||
encoder.set_input_array(biases);
|
||||
encoder.set_input_array(lhs_indices);
|
||||
encoder.set_input_array(rhs_indices);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
auto out_ptr = out.data<T>();
|
||||
auto x_ptr = x.data<T>();
|
||||
auto w_ptr = w.data<uint32_t>();
|
||||
auto scales_ptr = scales.data<T>();
|
||||
auto biases_ptr = biases.data<T>();
|
||||
auto lhs_indices_ptr = lhs_indices.data<uint32_t>();
|
||||
auto rhs_indices_ptr = rhs_indices.data<uint32_t>();
|
||||
|
||||
encoder.dispatch([out_ptr,
|
||||
x_ptr,
|
||||
w_ptr,
|
||||
scales_ptr,
|
||||
biases_ptr,
|
||||
lhs_indices_ptr,
|
||||
rhs_indices_ptr,
|
||||
x_shape = x.shape(),
|
||||
x_strides = x.strides(),
|
||||
w_shape = w.shape(),
|
||||
w_strides = w.strides(),
|
||||
scales_shape = scales.shape(),
|
||||
scales_strides = scales.strides(),
|
||||
biases_shape = biases.shape(),
|
||||
biases_strides = biases.strides(),
|
||||
lhs_indices_shape = lhs_indices.shape(),
|
||||
lhs_indices_strides = lhs_indices.strides(),
|
||||
rhs_indices_shape = rhs_indices.shape(),
|
||||
rhs_indices_strides = rhs_indices.strides(),
|
||||
w_els,
|
||||
g_els,
|
||||
indices_size = lhs_indices.size(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w]() {
|
||||
for (int i = 0; i < indices_size; i++) {
|
||||
int x_idx = lhs_indices_ptr[elem_to_loc(
|
||||
i, lhs_indices_shape, lhs_indices_strides)];
|
||||
int w_idx = rhs_indices_ptr[elem_to_loc(
|
||||
i, rhs_indices_shape, rhs_indices_strides)];
|
||||
_qmm_dispatch_typed<T>(
|
||||
out_ptr + i * M * N,
|
||||
x_ptr + elem_to_loc(x_idx * M * K, x_shape, x_strides),
|
||||
w_ptr + elem_to_loc(w_idx * w_els, w_shape, w_strides),
|
||||
scales_ptr + elem_to_loc(w_idx * g_els, scales_shape, scales_strides),
|
||||
biases_ptr + elem_to_loc(w_idx * g_els, biases_shape, biases_strides),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void _bs_qmm_dispatch(
|
||||
@@ -394,68 +512,54 @@ void _bs_qmm_dispatch(
|
||||
const array& rhs_indices,
|
||||
int bits,
|
||||
int group_size,
|
||||
bool transposed_w) {
|
||||
int K = x.shape(-1);
|
||||
int M = x.shape(-2);
|
||||
int N = out.shape(-1);
|
||||
|
||||
int w_els = w.shape(-1) * w.shape(-2);
|
||||
int g_els = scales.shape(-1) * scales.shape(-2);
|
||||
|
||||
const uint32_t* lhs_indices_data = lhs_indices.data<uint32_t>();
|
||||
const uint32_t* rhs_indices_data = rhs_indices.data<uint32_t>();
|
||||
|
||||
for (int i = 0; i < lhs_indices.size(); i++) {
|
||||
int x_idx = lhs_indices_data[elem_to_loc(i, lhs_indices)];
|
||||
int w_idx = rhs_indices_data[elem_to_loc(i, rhs_indices)];
|
||||
|
||||
switch (x.dtype()) {
|
||||
case float32:
|
||||
_qmm_dispatch_typed<float>(
|
||||
out.data<float>() + i * M * N,
|
||||
x.data<float>() + elem_to_loc(x_idx * M * K, x),
|
||||
w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
|
||||
scales.data<float>() + elem_to_loc(w_idx * g_els, scales),
|
||||
biases.data<float>() + elem_to_loc(w_idx * g_els, biases),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w);
|
||||
break;
|
||||
case float16:
|
||||
_qmm_dispatch_typed<float16_t>(
|
||||
out.data<float16_t>() + i * M * N,
|
||||
x.data<float16_t>() + elem_to_loc(x_idx * M * K, x),
|
||||
w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
|
||||
scales.data<float16_t>() + elem_to_loc(w_idx * g_els, scales),
|
||||
biases.data<float16_t>() + elem_to_loc(w_idx * g_els, biases),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w);
|
||||
break;
|
||||
case bfloat16:
|
||||
_qmm_dispatch_typed<bfloat16_t>(
|
||||
out.data<bfloat16_t>() + i * M * N,
|
||||
x.data<bfloat16_t>() + elem_to_loc(x_idx * M * K, x),
|
||||
w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
|
||||
scales.data<bfloat16_t>() + elem_to_loc(w_idx * g_els, scales),
|
||||
biases.data<bfloat16_t>() + elem_to_loc(w_idx * g_els, biases),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w);
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument(
|
||||
"[quantized_matmul] only floating types are supported");
|
||||
}
|
||||
bool transposed_w,
|
||||
Stream stream) {
|
||||
switch (x.dtype()) {
|
||||
case float32:
|
||||
_bs_qmm_dispatch_typed<float>(
|
||||
out,
|
||||
x,
|
||||
w,
|
||||
scales,
|
||||
biases,
|
||||
lhs_indices,
|
||||
rhs_indices,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w,
|
||||
stream);
|
||||
break;
|
||||
case float16:
|
||||
_bs_qmm_dispatch_typed<float16_t>(
|
||||
out,
|
||||
x,
|
||||
w,
|
||||
scales,
|
||||
biases,
|
||||
lhs_indices,
|
||||
rhs_indices,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w,
|
||||
stream);
|
||||
break;
|
||||
case bfloat16:
|
||||
_bs_qmm_dispatch_typed<bfloat16_t>(
|
||||
out,
|
||||
x,
|
||||
w,
|
||||
scales,
|
||||
biases,
|
||||
lhs_indices,
|
||||
rhs_indices,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w,
|
||||
stream);
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument(
|
||||
"[quantized_matmul] only floating types are supported");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -469,13 +573,14 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& scales_pre = inputs[2];
|
||||
auto& biases_pre = inputs[3];
|
||||
|
||||
auto ensure_row_contiguous = [](const array& arr) {
|
||||
std::vector<array> temps;
|
||||
auto ensure_row_contiguous = [s = stream(), &temps](const array& arr) {
|
||||
if (arr.flags().row_contiguous) {
|
||||
return arr;
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy(arr, arr_copy, CopyType::General);
|
||||
return arr_copy;
|
||||
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
|
||||
copy(arr, temps.back(), CopyType::General, s);
|
||||
return temps.back();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -485,7 +590,10 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto biases = ensure_row_contiguous(biases_pre);
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
|
||||
_qmm_dispatch(
|
||||
out, x, w, scales, biases, group_size_, bits_, transpose_, stream());
|
||||
auto& enc = cpu::get_command_encoder(stream());
|
||||
enc.add_temporaries(std::move(temps));
|
||||
}
|
||||
|
||||
void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -498,15 +606,17 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& lhs_indices = inputs[4];
|
||||
auto& rhs_indices = inputs[5];
|
||||
|
||||
auto ensure_row_contiguous_last_dims = [](const array& arr) {
|
||||
std::vector<array> temps;
|
||||
auto ensure_row_contiguous_last_dims = [s = stream(),
|
||||
&temps](const array& arr) {
|
||||
auto stride_0 = arr.strides()[arr.ndim() - 2];
|
||||
auto stride_1 = arr.strides()[arr.ndim() - 1];
|
||||
if (stride_0 == arr.shape(-1) && stride_1 == 1) {
|
||||
return arr;
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy(arr, arr_copy, CopyType::General);
|
||||
return arr_copy;
|
||||
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
|
||||
copy(arr, temps.back(), CopyType::General, s);
|
||||
return temps.back();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -526,31 +636,30 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
rhs_indices,
|
||||
group_size_,
|
||||
bits_,
|
||||
transpose_);
|
||||
transpose_,
|
||||
stream());
|
||||
auto& enc = cpu::get_command_encoder(stream());
|
||||
enc.add_temporaries(std::move(temps));
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
void quantize(
|
||||
const array& w_,
|
||||
array& out_,
|
||||
array& scales_,
|
||||
array& biases_,
|
||||
const T* w,
|
||||
U* out,
|
||||
T* scales,
|
||||
T* biases,
|
||||
int bits,
|
||||
int group_size) {
|
||||
const T* w = w_.data<T>();
|
||||
|
||||
auto out = out_.data<U>();
|
||||
T* scales = scales_.data<T>();
|
||||
T* biases = biases_.data<T>();
|
||||
|
||||
int group_size,
|
||||
size_t w_size) {
|
||||
float n_bins = (1 << bits) - 1;
|
||||
float eps = 1e-7;
|
||||
|
||||
bool power_of_2_bits = is_power_of_2(bits);
|
||||
int el_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
|
||||
// For 3/6 bits we read 3 uint8s at a time instead of 1 uint32
|
||||
int bytes_per_pack = power_of_2_bits ? 1 : 3;
|
||||
int int_per_group = group_size * bytes_per_pack / el_per_int;
|
||||
size_t n_groups = w_.size() / group_size;
|
||||
size_t n_groups = w_size / group_size;
|
||||
|
||||
for (size_t i = 0; i < n_groups; ++i) {
|
||||
size_t w_idx = i * group_size;
|
||||
@@ -593,20 +702,50 @@ void quantize(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
void dispatch_quantize(
|
||||
const array& w,
|
||||
array& out,
|
||||
array& scales,
|
||||
array& biases,
|
||||
int bits,
|
||||
int group_size,
|
||||
Stream stream) {
|
||||
auto w_ptr = w.data<T>();
|
||||
auto out_ptr = out.data<U>();
|
||||
auto scales_ptr = scales.data<T>();
|
||||
auto biases_ptr = biases.data<T>();
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_input_array(scales);
|
||||
encoder.set_input_array(biases);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([w_ptr,
|
||||
out_ptr,
|
||||
scales_ptr,
|
||||
biases_ptr,
|
||||
bits,
|
||||
group_size,
|
||||
w_size = w.size()]() {
|
||||
quantize<T, U>(
|
||||
w_ptr, out_ptr, scales_ptr, biases_ptr, bits, group_size, w_size);
|
||||
});
|
||||
}
|
||||
|
||||
void fast::AffineQuantize::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
auto ensure_row_contiguous = [](const array& arr) {
|
||||
auto ensure_row_contiguous = [s = stream()](const array& arr) {
|
||||
if (arr.flags().row_contiguous) {
|
||||
return arr;
|
||||
return std::make_pair(arr, false);
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy(arr, arr_copy, CopyType::General);
|
||||
return arr_copy;
|
||||
copy(arr, arr_copy, CopyType::General, s);
|
||||
return std::make_pair(arr_copy, true);
|
||||
}
|
||||
};
|
||||
auto w = ensure_row_contiguous(inputs[0]);
|
||||
|
||||
auto [w, copied] = ensure_row_contiguous(inputs[0]);
|
||||
auto& out = outputs[0];
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
@@ -616,27 +755,35 @@ void fast::AffineQuantize::eval_cpu(
|
||||
biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
|
||||
if (w.dtype() == float16) {
|
||||
if (is_power_of_2(bits_)) {
|
||||
quantize<float16_t, uint32_t>(w, out, scales, biases, bits_, group_size_);
|
||||
dispatch_quantize<float16_t, uint32_t>(
|
||||
w, out, scales, biases, bits_, group_size_, stream());
|
||||
} else {
|
||||
quantize<float16_t, uint8_t>(w, out, scales, biases, bits_, group_size_);
|
||||
dispatch_quantize<float16_t, uint8_t>(
|
||||
w, out, scales, biases, bits_, group_size_, stream());
|
||||
}
|
||||
} else if (w.dtype() == bfloat16) {
|
||||
if (is_power_of_2(bits_)) {
|
||||
quantize<bfloat16_t, uint32_t>(
|
||||
w, out, scales, biases, bits_, group_size_);
|
||||
dispatch_quantize<bfloat16_t, uint32_t>(
|
||||
w, out, scales, biases, bits_, group_size_, stream());
|
||||
} else {
|
||||
quantize<bfloat16_t, uint8_t>(w, out, scales, biases, bits_, group_size_);
|
||||
dispatch_quantize<bfloat16_t, uint8_t>(
|
||||
w, out, scales, biases, bits_, group_size_, stream());
|
||||
}
|
||||
} else if (w.dtype() == float32) {
|
||||
if (is_power_of_2(bits_)) {
|
||||
quantize<float, uint32_t>(w, out, scales, biases, bits_, group_size_);
|
||||
dispatch_quantize<float, uint32_t>(
|
||||
w, out, scales, biases, bits_, group_size_, stream());
|
||||
} else {
|
||||
quantize<float, uint8_t>(w, out, scales, biases, bits_, group_size_);
|
||||
dispatch_quantize<float, uint8_t>(
|
||||
w, out, scales, biases, bits_, group_size_, stream());
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"[fast::AffineQuantize::eval_cpu] Only supports floating point inputs");
|
||||
}
|
||||
if (copied) {
|
||||
cpu::get_command_encoder(stream()).add_temporary(w);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
Reference in New Issue
Block a user