Async load (#1372)

* async load

* async load
This commit is contained in:
Awni Hannun
2024-08-28 14:21:55 -07:00
committed by GitHub
parent fcb65a3897
commit a6c3b38fba
3 changed files with 49 additions and 8 deletions

View File

@@ -5,11 +5,9 @@
#include <utility>
#include "mlx/allocator.h"
#include "mlx/io/load.h"
#include "mlx/backend/common/load.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
template <const uint8_t scalar_size>
@@ -29,11 +27,14 @@ void swap_endianness(uint8_t* data_bytes, size_t N) {
} // namespace
void Load::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 0);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
namespace mlx::core {
reader_->read(out.data<char>(), out.nbytes(), offset_);
void load(
array& out,
size_t offset,
const std::shared_ptr<io::Reader>& reader,
bool swap_endianness_) {
reader->read(out.data<char>(), out.nbytes(), offset);
if (swap_endianness_) {
switch (out.itemsize()) {
@@ -50,4 +51,11 @@ void Load::eval(const std::vector<array>& inputs, array& out) {
}
}
void Load::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 0);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
load(out, offset_, reader_, swap_endianness_);
}
} // namespace mlx::core