redesign for faster cpu/gpu synch (#1869)

* redesign for faster cpu/gpu synch

* load + more async CPU

* use command encoder API and move more ops to use it

* make fence back-end generic + CPU only fence

* faster build

* fix async eval

* fixes + handle temporaries

* fix / improve cpu conv

* remove unused status, fix siblings

* fix extensions

* fix

* fix no cpu build

* format

* comments

* fix perf regression, remove unecessary abort

* fix events, task limit cpu

* fix waiting

* fix donation / temporaries in normalization
This commit is contained in:
Awni Hannun
2025-03-06 19:23:38 -08:00
committed by GitHub
parent 5245f12a46
commit c4230747a1
103 changed files with 5013 additions and 3873 deletions

View File

@@ -3,7 +3,8 @@
#include <algorithm>
#include <utility>
#include "mlx/backend/common/load.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"
namespace {
@@ -26,26 +27,31 @@ void swap_endianness(uint8_t* data_bytes, size_t N) {
namespace mlx::core {
void load(
array& out,
size_t offset,
const std::shared_ptr<io::Reader>& reader,
bool swap_endianness_) {
reader->read(out.data<char>(), out.nbytes(), offset);
if (swap_endianness_) {
switch (out.itemsize()) {
case 2:
swap_endianness<2>(out.data<uint8_t>(), out.data_size());
break;
case 4:
swap_endianness<4>(out.data<uint8_t>(), out.data_size());
break;
case 8:
swap_endianness<8>(out.data<uint8_t>(), out.data_size());
break;
void Load::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto read_task = [out_ptr = out.data<char>(),
size = out.size(),
itemsize = out.itemsize(),
offset = offset_,
reader = reader_,
swap_endianness_ = swap_endianness_]() mutable {
reader->read(out_ptr, size * itemsize, offset);
if (swap_endianness_) {
switch (itemsize) {
case 2:
swap_endianness<2>(reinterpret_cast<uint8_t*>(out_ptr), size);
break;
case 4:
swap_endianness<4>(reinterpret_cast<uint8_t*>(out_ptr), size);
break;
case 8:
swap_endianness<8>(reinterpret_cast<uint8_t*>(out_ptr), size);
break;
}
}
}
};
auto fut = io::thread_pool().enqueue(std::move(read_task)).share();
scheduler::enqueue(stream(), [fut = std::move(fut)]() { fut.wait(); });
}
} // namespace mlx::core