redesign for faster cpu/gpu synch (#1869)

* redesign for faster cpu/gpu synch

* load + more async CPU

* use command encoder API and move more ops to use it

* make fence back-end generic + CPU only fence

* faster build

* fix async eval

* fixes + handle temporaries

* fix / improve cpu conv

* remove unused status, fix siblings

* fix extensions

* fix

* fix no cpu build

* format

* comments

* fix perf regression, remove unnecessary abort

* fix events, task limit cpu

* fix waiting

* fix donation / temporaries in normalization
Awni Hannun
2025-03-06 19:23:38 -08:00
committed by GitHub
parent 5245f12a46
commit c4230747a1
103 changed files with 5013 additions and 3873 deletions


@@ -5,67 +5,83 @@
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
void set_unary_output_data(const array& in, array& out) {
if (is_donatable(in, out)) {
out.copy_shared_buffer(in);
if (in.flags().contiguous) {
if (is_donatable(in, out)) {
out.copy_shared_buffer(in);
} else {
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
}
} else {
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
}
template <typename T, typename U = T, typename Op>
void unary_op(const T* a, U* out, Op op, size_t shape, size_t stride) {
void unary_op(const T* a, U* out, size_t shape, size_t stride) {
for (size_t i = 0; i < shape; i += 1) {
out[i] = op(*a);
out[i] = Op{}(*a);
a += stride;
}
}
template <typename T, typename U = T, typename Op>
void unary_op(const array& a, array& out, Op op) {
const T* a_ptr = a.data<T>();
if (a.flags().contiguous) {
set_unary_output_data(a, out);
U* dst = out.data<U>();
constexpr int N = simd::max_size<T>;
size_t size = a.data_size();
while (size >= N) {
simd::store(dst, op(simd::load<T, N>(a_ptr)));
size -= N;
a_ptr += N;
dst += N;
void unary_op(const array& a, array& out, Op) {
set_unary_output_data(a, out);
const T* src = a.data<T>();
U* dst = out.data<U>();
auto& encoder = cpu::get_command_encoder(out.primitive().stream());
encoder.set_input_array(a);
encoder.set_output_array(out);
encoder.dispatch([src,
dst,
contig = a.flags().contiguous,
data_size = a.data_size(),
size = a.size(),
shapes = a.shape(),
strides = a.strides()]() mutable {
auto ndim = shapes.size();
if (contig) {
constexpr int N = simd::max_size<T>;
while (data_size >= N) {
simd::store(dst, Op{}(simd::load<T, N>(src)));
data_size -= N;
src += N;
dst += N;
}
while (data_size > 0) {
*dst = Op{}(*src);
data_size--;
dst++;
src++;
}
} else {
size_t shape = ndim > 0 ? shapes.back() : 1;
size_t stride = ndim > 0 ? strides.back() : 1;
if (ndim <= 1) {
unary_op<T, U, Op>(src, dst, shape, stride);
return;
}
auto it = ContiguousIterator(shapes, strides, ndim - 1);
for (size_t elem = 0; elem < size; elem += shape) {
unary_op<T, U, Op>(src + it.loc, dst + elem, shape, stride);
it.step();
}
}
while (size > 0) {
*dst = op(*a_ptr);
size--;
dst++;
a_ptr++;
}
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
U* dst = out.data<U>();
size_t shape = a.ndim() > 0 ? a.shape(-1) : 1;
size_t stride = a.ndim() > 0 ? a.strides(-1) : 1;
if (a.ndim() <= 1) {
unary_op(a_ptr, dst, op, shape, stride);
return;
}
ContiguousIterator it(a.shape(), a.strides(), a.ndim() - 1);
for (size_t elem = 0; elem < a.size(); elem += shape) {
unary_op(a_ptr + it.loc, dst + elem, op, shape, stride);
it.step();
}
}
});
}
template <typename Op>
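
The hunk above shows the pattern this redesign applies across the CPU backend: allocate outputs up front, register inputs and outputs with the CPU command encoder, and dispatch a lambda that captures raw pointers, sizes, shapes, and strides by value so the work can run asynchronously on the op's stream. For orientation, here is a minimal sketch of a standalone elementwise op written against the same encoder calls used in the hunk; it is illustrative only and not part of this commit — the name eval_negate_cpu, the free-function form, and the fixed float element type are all assumptions, and the comments reflect a reading of the encoder API rather than documented guarantees.

// Hypothetical sketch, not from this commit: a standalone negate op using the
// encoder calls shown in the hunk above (get_command_encoder, set_input_array,
// set_output_array, dispatch).
#include <cstddef>

#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/primitives.h"

namespace mlx::core {

void eval_negate_cpu(const array& in, array& out) {
  // Allocate the output before dispatching so the task only needs to capture
  // trivially copyable state (raw pointers and sizes).
  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  auto& encoder = cpu::get_command_encoder(out.primitive().stream());
  // Register the arrays with the encoder, as the hunk above does, so it can
  // track them while the dispatched task is pending.
  encoder.set_input_array(in);
  encoder.set_output_array(out);

  // The lambda may run after this function returns, so everything it touches
  // is captured by value.
  encoder.dispatch(
      [src = in.data<float>(), dst = out.data<float>(), n = in.size()]() {
        for (size_t i = 0; i < n; ++i) {
          dst[i] = -src[i];
        }
      });
}

} // namespace mlx::core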