mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
redesign for faster cpu/gpu synch (#1869)
* redesign for faster cpu/gpu synch * load + more async CPU * use command encoder API and move more ops to use it * make fence back-end generic + CPU only fence * faster build * fix async eval * fixes + handle temporaries * fix / improve cpu conv * remove unused status, fix siblings * fix extensions * fix * fix no cpu build * format * comments * fix perf regression, remove unecessary abort * fix events, task limit cpu * fix waiting * fix donation / temporaries in normalization
This commit is contained in:
@@ -5,67 +5,83 @@
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/backend/cpu/simd/simd.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void set_unary_output_data(const array& in, array& out) {
|
||||
if (is_donatable(in, out)) {
|
||||
out.copy_shared_buffer(in);
|
||||
if (in.flags().contiguous) {
|
||||
if (is_donatable(in, out)) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
auto size = in.data_size();
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(size * out.itemsize()),
|
||||
size,
|
||||
in.strides(),
|
||||
in.flags());
|
||||
}
|
||||
} else {
|
||||
auto size = in.data_size();
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(size * out.itemsize()),
|
||||
size,
|
||||
in.strides(),
|
||||
in.flags());
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
}
|
||||
|
||||
// Apply `Op` elementwise over `shape` elements of `a`, read with step
// `stride`, writing the results contiguously into `out`.
//
// `Op` is a stateless functor type (invoked as `Op{}(x)`), which lets the
// enclosing dispatch capture only data pointers, not a callable object.
// Used as the inner kernel for the innermost axis of a non-contiguous
// traversal: `a` walks a strided row, `out` is packed.
template <typename T, typename U = T, typename Op>
void unary_op(const T* a, U* out, size_t shape, size_t stride) {
  for (size_t i = 0; i < shape; i += 1) {
    out[i] = Op{}(*a);
    a += stride;
  }
}
|
||||
|
||||
template <typename T, typename U = T, typename Op>
|
||||
void unary_op(const array& a, array& out, Op op) {
|
||||
const T* a_ptr = a.data<T>();
|
||||
if (a.flags().contiguous) {
|
||||
set_unary_output_data(a, out);
|
||||
U* dst = out.data<U>();
|
||||
constexpr int N = simd::max_size<T>;
|
||||
size_t size = a.data_size();
|
||||
while (size >= N) {
|
||||
simd::store(dst, op(simd::load<T, N>(a_ptr)));
|
||||
size -= N;
|
||||
a_ptr += N;
|
||||
dst += N;
|
||||
void unary_op(const array& a, array& out, Op) {
|
||||
set_unary_output_data(a, out);
|
||||
const T* src = a.data<T>();
|
||||
U* dst = out.data<U>();
|
||||
auto& encoder = cpu::get_command_encoder(out.primitive().stream());
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
encoder.dispatch([src,
|
||||
dst,
|
||||
contig = a.flags().contiguous,
|
||||
data_size = a.data_size(),
|
||||
size = a.size(),
|
||||
shapes = a.shape(),
|
||||
strides = a.strides()]() mutable {
|
||||
auto ndim = shapes.size();
|
||||
if (contig) {
|
||||
constexpr int N = simd::max_size<T>;
|
||||
while (data_size >= N) {
|
||||
simd::store(dst, Op{}(simd::load<T, N>(src)));
|
||||
data_size -= N;
|
||||
src += N;
|
||||
dst += N;
|
||||
}
|
||||
while (data_size > 0) {
|
||||
*dst = Op{}(*src);
|
||||
data_size--;
|
||||
dst++;
|
||||
src++;
|
||||
}
|
||||
} else {
|
||||
size_t shape = ndim > 0 ? shapes.back() : 1;
|
||||
size_t stride = ndim > 0 ? strides.back() : 1;
|
||||
if (ndim <= 1) {
|
||||
unary_op<T, U, Op>(src, dst, shape, stride);
|
||||
return;
|
||||
}
|
||||
auto it = ContiguousIterator(shapes, strides, ndim - 1);
|
||||
for (size_t elem = 0; elem < size; elem += shape) {
|
||||
unary_op<T, U, Op>(src + it.loc, dst + elem, shape, stride);
|
||||
it.step();
|
||||
}
|
||||
}
|
||||
while (size > 0) {
|
||||
*dst = op(*a_ptr);
|
||||
size--;
|
||||
dst++;
|
||||
a_ptr++;
|
||||
}
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
U* dst = out.data<U>();
|
||||
size_t shape = a.ndim() > 0 ? a.shape(-1) : 1;
|
||||
size_t stride = a.ndim() > 0 ? a.strides(-1) : 1;
|
||||
if (a.ndim() <= 1) {
|
||||
unary_op(a_ptr, dst, op, shape, stride);
|
||||
return;
|
||||
}
|
||||
ContiguousIterator it(a.shape(), a.strides(), a.ndim() - 1);
|
||||
for (size_t elem = 0; elem < a.size(); elem += shape) {
|
||||
unary_op(a_ptr + it.loc, dst + elem, op, shape, stride);
|
||||
it.step();
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
|
||||
Reference in New Issue
Block a user