mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
redesign for faster cpu/gpu synch (#1869)
* redesign for faster cpu/gpu synch * load + more async CPU * use command encoder API and move more ops to use it * make fence back-end generic + CPU only fence * faster build * fix async eval * fixes + handle temporaries * fix / improve cpu conv * remove unused status, fix siblings * fix extensions * fix * fix no cpu build * format * comments * fix perf regression, remove unnecessary abort * fix events, task limit cpu * fix waiting * fix donation / temporaries in normalization
This commit is contained in:
@@ -3,7 +3,8 @@
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
|
||||
#include "mlx/backend/common/load.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/scheduler.h"
|
||||
|
||||
namespace {
|
||||
|
||||
@@ -26,26 +27,31 @@ void swap_endianness(uint8_t* data_bytes, size_t N) {
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void load(
|
||||
array& out,
|
||||
size_t offset,
|
||||
const std::shared_ptr<io::Reader>& reader,
|
||||
bool swap_endianness_) {
|
||||
reader->read(out.data<char>(), out.nbytes(), offset);
|
||||
|
||||
if (swap_endianness_) {
|
||||
switch (out.itemsize()) {
|
||||
case 2:
|
||||
swap_endianness<2>(out.data<uint8_t>(), out.data_size());
|
||||
break;
|
||||
case 4:
|
||||
swap_endianness<4>(out.data<uint8_t>(), out.data_size());
|
||||
break;
|
||||
case 8:
|
||||
swap_endianness<8>(out.data<uint8_t>(), out.data_size());
|
||||
break;
|
||||
void Load::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
auto read_task = [out_ptr = out.data<char>(),
|
||||
size = out.size(),
|
||||
itemsize = out.itemsize(),
|
||||
offset = offset_,
|
||||
reader = reader_,
|
||||
swap_endianness_ = swap_endianness_]() mutable {
|
||||
reader->read(out_ptr, size * itemsize, offset);
|
||||
if (swap_endianness_) {
|
||||
switch (itemsize) {
|
||||
case 2:
|
||||
swap_endianness<2>(reinterpret_cast<uint8_t*>(out_ptr), size);
|
||||
break;
|
||||
case 4:
|
||||
swap_endianness<4>(reinterpret_cast<uint8_t*>(out_ptr), size);
|
||||
break;
|
||||
case 8:
|
||||
swap_endianness<8>(reinterpret_cast<uint8_t*>(out_ptr), size);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
auto fut = io::thread_pool().enqueue(std::move(read_task)).share();
|
||||
scheduler::enqueue(stream(), [fut = std::move(fut)]() { fut.wait(); });
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
Reference in New Issue
Block a user