redesign for faster cpu/gpu synch (#1869)

* redesign for faster cpu/gpu synch

* load + more async CPU

* use command encoder API and move more ops to use it

* make fence back-end generic + CPU only fence

* faster build

* fix async eval

* fixes + handle temporaries

* fix / improve cpu conv

* remove unused status, fix siblings

* fix extensions

* fix

* fix no cpu build

* format

* comments

* fix perf regression, remove unecessary abort

* fix events, task limit cpu

* fix waiting

* fix donation / temporaries in normalization
This commit is contained in:
Awni Hannun
2025-03-06 19:23:38 -08:00
committed by GitHub
parent 5245f12a46
commit c4230747a1
103 changed files with 5013 additions and 3873 deletions

View File

@@ -7,6 +7,7 @@
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/primitives.h"
@@ -103,11 +104,11 @@ struct StridedIterator {
T* ptr_;
};
template <typename T, typename IdxT = uint32_t>
void sort(const array& in, array& out, int axis) {
template <typename T>
void sort(const array& in, array& out, int axis, Stream stream) {
// Copy input to output
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
copy(in, out, ctype);
copy(in, out, ctype, stream);
// Get axis, shape and stride info
axis = axis < 0 ? axis + in.ndim() : axis;
@@ -126,19 +127,27 @@ void sort(const array& in, array& out, int axis) {
// Perform sorting in place
ContiguousIterator src_it(
remaining_shape, remaining_strides, remaining_shape.size());
for (int i = 0; i < n_rows; i++) {
T* data_ptr = out.data<T>() + src_it.loc;
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_output_array(out);
encoder.dispatch([out_ptr = out.data<T>(),
src_it = std::move(src_it),
n_rows,
axis_size,
axis_stride]() mutable {
for (int i = 0; i < n_rows; i++) {
T* data_ptr = out_ptr + src_it.loc;
StridedIterator st(data_ptr, axis_stride, 0);
StridedIterator ed(data_ptr, axis_stride, axis_size);
StridedIterator st(data_ptr, axis_stride, 0);
StridedIterator ed(data_ptr, axis_stride, axis_size);
std::stable_sort(st, ed);
src_it.step();
}
std::stable_sort(st, ed);
src_it.step();
}
});
}
template <typename T, typename IdxT = uint32_t>
void argsort(const array& in, array& out, int axis) {
void argsort(const array& in, array& out, int axis, Stream stream) {
// Allocate output
out.set_data(allocator::malloc_or_wait(out.nbytes()));
@@ -167,35 +176,48 @@ void argsort(const array& in, array& out, int axis) {
in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
ContiguousIterator out_it(
out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
for (int i = 0; i < n_rows; i++) {
const T* data_ptr = in.data<T>() + in_it.loc;
IdxT* idx_ptr = out.data<IdxT>() + out_it.loc;
in_it.step();
out_it.step();
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(in);
encoder.set_input_array(out);
encoder.dispatch([in_ptr = in.data<T>(),
out_ptr = out.data<IdxT>(),
in_it = std::move(in_it),
out_it = std::move(out_it),
n_rows,
axis_size,
in_stride,
out_stride]() mutable {
for (int i = 0; i < n_rows; i++) {
const T* data_ptr = in_ptr + in_it.loc;
IdxT* idx_ptr = out_ptr + out_it.loc;
StridedIterator st_(idx_ptr, out_stride, 0);
StridedIterator ed_(idx_ptr, out_stride, axis_size);
in_it.step();
out_it.step();
// Initialize with iota
std::iota(st_, ed_, IdxT(0));
StridedIterator st_(idx_ptr, out_stride, 0);
StridedIterator ed_(idx_ptr, out_stride, axis_size);
// Sort according to vals
StridedIterator st(idx_ptr, out_stride, 0);
StridedIterator ed(idx_ptr, out_stride, axis_size);
// Initialize with iota
std::iota(st_, ed_, IdxT(0));
std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * in_stride];
auto v2 = data_ptr[b * in_stride];
return v1 < v2 || (v1 == v2 && a < b);
});
}
// Sort according to vals
StridedIterator st(idx_ptr, out_stride, 0);
StridedIterator ed(idx_ptr, out_stride, axis_size);
std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * in_stride];
auto v2 = data_ptr[b * in_stride];
return v1 < v2 || (v1 == v2 && a < b);
});
}
});
}
template <typename T, typename IdxT = uint32_t>
void partition(const array& in, array& out, int axis, int kth) {
template <typename T>
void partition(const array& in, array& out, int axis, int kth, Stream stream) {
// Copy input to output
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
copy(in, out, ctype);
copy(in, out, ctype, stream);
// Get axis, shape and stride info
axis = axis < 0 ? axis + in.ndim() : axis;
@@ -216,20 +238,34 @@ void partition(const array& in, array& out, int axis, int kth) {
// Perform partition in place
ContiguousIterator src_it(
remaining_shape, remaining_strides, remaining_shape.size());
for (int i = 0; i < n_rows; i++) {
T* data_ptr = out.data<T>() + src_it.loc;
src_it.step();
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_output_array(out);
encoder.dispatch([out_ptr = out.data<T>(),
src_it = std::move(src_it),
n_rows,
axis_size,
axis_stride,
kth]() mutable {
for (int i = 0; i < n_rows; i++) {
T* data_ptr = out_ptr + src_it.loc;
src_it.step();
StridedIterator st(data_ptr, axis_stride, 0);
StridedIterator md(data_ptr, axis_stride, kth);
StridedIterator ed(data_ptr, axis_stride, axis_size);
StridedIterator st(data_ptr, axis_stride, 0);
StridedIterator md(data_ptr, axis_stride, kth);
StridedIterator ed(data_ptr, axis_stride, axis_size);
std::nth_element(st, md, ed);
}
std::nth_element(st, md, ed);
}
});
}
template <typename T, typename IdxT = uint32_t>
void argpartition(const array& in, array& out, int axis, int kth) {
void argpartition(
const array& in,
array& out,
int axis,
int kth,
Stream stream) {
// Allocate output
out.set_data(allocator::malloc_or_wait(out.nbytes()));
@@ -260,29 +296,43 @@ void argpartition(const array& in, array& out, int axis, int kth) {
in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
ContiguousIterator out_it(
out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
for (int i = 0; i < n_rows; i++) {
const T* data_ptr = in.data<T>() + in_it.loc;
IdxT* idx_ptr = out.data<IdxT>() + out_it.loc;
in_it.step();
out_it.step();
StridedIterator st_(idx_ptr, out_stride, 0);
StridedIterator ed_(idx_ptr, out_stride, axis_size);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(in);
encoder.set_input_array(out);
encoder.dispatch([in_ptr = in.data<T>(),
out_ptr = out.data<IdxT>(),
in_it = std::move(in_it),
out_it = std::move(out_it),
n_rows,
axis_size,
in_stride,
out_stride,
kth]() mutable {
for (int i = 0; i < n_rows; i++) {
const T* data_ptr = in_ptr + in_it.loc;
IdxT* idx_ptr = out_ptr + out_it.loc;
in_it.step();
out_it.step();
// Initialize with iota
std::iota(st_, ed_, IdxT(0));
StridedIterator st_(idx_ptr, out_stride, 0);
StridedIterator ed_(idx_ptr, out_stride, axis_size);
// Sort according to vals
StridedIterator st(idx_ptr, out_stride, 0);
StridedIterator md(idx_ptr, out_stride, kth);
StridedIterator ed(idx_ptr, out_stride, axis_size);
// Initialize with iota
std::iota(st_, ed_, IdxT(0));
std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * in_stride];
auto v2 = data_ptr[b * in_stride];
return v1 < v2 || (v1 == v2 && a < b);
});
}
// Sort according to vals
StridedIterator st(idx_ptr, out_stride, 0);
StridedIterator md(idx_ptr, out_stride, kth);
StridedIterator ed(idx_ptr, out_stride, axis_size);
std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * in_stride];
auto v2 = data_ptr[b * in_stride];
return v1 < v2 || (v1 == v2 && a < b);
});
}
});
}
} // namespace
@@ -293,33 +343,33 @@ void ArgSort::eval_cpu(const std::vector<array>& inputs, array& out) {
switch (in.dtype()) {
case bool_:
return argsort<bool>(in, out, axis_);
return argsort<bool>(in, out, axis_, stream());
case uint8:
return argsort<uint8_t>(in, out, axis_);
return argsort<uint8_t>(in, out, axis_, stream());
case uint16:
return argsort<uint16_t>(in, out, axis_);
return argsort<uint16_t>(in, out, axis_, stream());
case uint32:
return argsort<uint32_t>(in, out, axis_);
return argsort<uint32_t>(in, out, axis_, stream());
case uint64:
return argsort<uint64_t>(in, out, axis_);
return argsort<uint64_t>(in, out, axis_, stream());
case int8:
return argsort<int8_t>(in, out, axis_);
return argsort<int8_t>(in, out, axis_, stream());
case int16:
return argsort<int16_t>(in, out, axis_);
return argsort<int16_t>(in, out, axis_, stream());
case int32:
return argsort<int32_t>(in, out, axis_);
return argsort<int32_t>(in, out, axis_, stream());
case int64:
return argsort<int64_t>(in, out, axis_);
return argsort<int64_t>(in, out, axis_, stream());
case float32:
return argsort<float>(in, out, axis_);
return argsort<float>(in, out, axis_, stream());
case float64:
return argsort<double>(in, out, axis_);
return argsort<double>(in, out, axis_, stream());
case float16:
return argsort<float16_t>(in, out, axis_);
return argsort<float16_t>(in, out, axis_, stream());
case bfloat16:
return argsort<bfloat16_t>(in, out, axis_);
return argsort<bfloat16_t>(in, out, axis_, stream());
case complex64:
return argsort<complex64_t>(in, out, axis_);
return argsort<complex64_t>(in, out, axis_, stream());
}
}
@@ -329,33 +379,33 @@ void Sort::eval_cpu(const std::vector<array>& inputs, array& out) {
switch (in.dtype()) {
case bool_:
return sort<bool>(in, out, axis_);
return sort<bool>(in, out, axis_, stream());
case uint8:
return sort<uint8_t>(in, out, axis_);
return sort<uint8_t>(in, out, axis_, stream());
case uint16:
return sort<uint16_t>(in, out, axis_);
return sort<uint16_t>(in, out, axis_, stream());
case uint32:
return sort<uint32_t>(in, out, axis_);
return sort<uint32_t>(in, out, axis_, stream());
case uint64:
return sort<uint64_t>(in, out, axis_);
return sort<uint64_t>(in, out, axis_, stream());
case int8:
return sort<int8_t>(in, out, axis_);
return sort<int8_t>(in, out, axis_, stream());
case int16:
return sort<int16_t>(in, out, axis_);
return sort<int16_t>(in, out, axis_, stream());
case int32:
return sort<int32_t>(in, out, axis_);
return sort<int32_t>(in, out, axis_, stream());
case int64:
return sort<int64_t>(in, out, axis_);
return sort<int64_t>(in, out, axis_, stream());
case float32:
return sort<float>(in, out, axis_);
return sort<float>(in, out, axis_, stream());
case float64:
return sort<double>(in, out, axis_);
return sort<double>(in, out, axis_, stream());
case float16:
return sort<float16_t>(in, out, axis_);
return sort<float16_t>(in, out, axis_, stream());
case bfloat16:
return sort<bfloat16_t>(in, out, axis_);
return sort<bfloat16_t>(in, out, axis_, stream());
case complex64:
return sort<complex64_t>(in, out, axis_);
return sort<complex64_t>(in, out, axis_, stream());
}
}
@@ -365,33 +415,33 @@ void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {
switch (in.dtype()) {
case bool_:
return argpartition<bool>(in, out, axis_, kth_);
return argpartition<bool>(in, out, axis_, kth_, stream());
case uint8:
return argpartition<uint8_t>(in, out, axis_, kth_);
return argpartition<uint8_t>(in, out, axis_, kth_, stream());
case uint16:
return argpartition<uint16_t>(in, out, axis_, kth_);
return argpartition<uint16_t>(in, out, axis_, kth_, stream());
case uint32:
return argpartition<uint32_t>(in, out, axis_, kth_);
return argpartition<uint32_t>(in, out, axis_, kth_, stream());
case uint64:
return argpartition<uint64_t>(in, out, axis_, kth_);
return argpartition<uint64_t>(in, out, axis_, kth_, stream());
case int8:
return argpartition<int8_t>(in, out, axis_, kth_);
return argpartition<int8_t>(in, out, axis_, kth_, stream());
case int16:
return argpartition<int16_t>(in, out, axis_, kth_);
return argpartition<int16_t>(in, out, axis_, kth_, stream());
case int32:
return argpartition<int32_t>(in, out, axis_, kth_);
return argpartition<int32_t>(in, out, axis_, kth_, stream());
case int64:
return argpartition<int64_t>(in, out, axis_, kth_);
return argpartition<int64_t>(in, out, axis_, kth_, stream());
case float32:
return argpartition<float>(in, out, axis_, kth_);
return argpartition<float>(in, out, axis_, kth_, stream());
case float64:
return argpartition<double>(in, out, axis_, kth_);
return argpartition<double>(in, out, axis_, kth_, stream());
case float16:
return argpartition<float16_t>(in, out, axis_, kth_);
return argpartition<float16_t>(in, out, axis_, kth_, stream());
case bfloat16:
return argpartition<bfloat16_t>(in, out, axis_, kth_);
return argpartition<bfloat16_t>(in, out, axis_, kth_, stream());
case complex64:
return argpartition<complex64_t>(in, out, axis_, kth_);
return argpartition<complex64_t>(in, out, axis_, kth_, stream());
}
}
@@ -401,33 +451,33 @@ void Partition::eval_cpu(const std::vector<array>& inputs, array& out) {
switch (in.dtype()) {
case bool_:
return partition<bool>(in, out, axis_, kth_);
return partition<bool>(in, out, axis_, kth_, stream());
case uint8:
return partition<uint8_t>(in, out, axis_, kth_);
return partition<uint8_t>(in, out, axis_, kth_, stream());
case uint16:
return partition<uint16_t>(in, out, axis_, kth_);
return partition<uint16_t>(in, out, axis_, kth_, stream());
case uint32:
return partition<uint32_t>(in, out, axis_, kth_);
return partition<uint32_t>(in, out, axis_, kth_, stream());
case uint64:
return partition<uint64_t>(in, out, axis_, kth_);
return partition<uint64_t>(in, out, axis_, kth_, stream());
case int8:
return partition<int8_t>(in, out, axis_, kth_);
return partition<int8_t>(in, out, axis_, kth_, stream());
case int16:
return partition<int16_t>(in, out, axis_, kth_);
return partition<int16_t>(in, out, axis_, kth_, stream());
case int32:
return partition<int32_t>(in, out, axis_, kth_);
return partition<int32_t>(in, out, axis_, kth_, stream());
case int64:
return partition<int64_t>(in, out, axis_, kth_);
return partition<int64_t>(in, out, axis_, kth_, stream());
case float32:
return partition<float>(in, out, axis_, kth_);
return partition<float>(in, out, axis_, kth_, stream());
case float64:
return partition<double>(in, out, axis_, kth_);
return partition<double>(in, out, axis_, kth_, stream());
case float16:
return partition<float16_t>(in, out, axis_, kth_);
return partition<float16_t>(in, out, axis_, kth_, stream());
case bfloat16:
return partition<bfloat16_t>(in, out, axis_, kth_);
return partition<bfloat16_t>(in, out, axis_, kth_, stream());
case complex64:
return partition<complex64_t>(in, out, axis_, kth_);
return partition<complex64_t>(in, out, axis_, kth_, stream());
}
}