Async load (#1372)

* async load

* async load
This commit is contained in:
Awni Hannun 2024-08-28 14:21:55 -07:00 committed by GitHub
parent fcb65a3897
commit a6c3b38fba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 49 additions and 8 deletions

View File

@ -5,11 +5,9 @@
#include <utility> #include <utility>
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/io/load.h" #include "mlx/backend/common/load.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
namespace mlx::core {
namespace { namespace {
template <const uint8_t scalar_size> template <const uint8_t scalar_size>
@ -29,11 +27,14 @@ void swap_endianness(uint8_t* data_bytes, size_t N) {
} // namespace } // namespace
void Load::eval(const std::vector<array>& inputs, array& out) { namespace mlx::core {
assert(inputs.size() == 0);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
reader_->read(out.data<char>(), out.nbytes(), offset_); void load(
array& out,
size_t offset,
const std::shared_ptr<io::Reader>& reader,
bool swap_endianness_) {
reader->read(out.data<char>(), out.nbytes(), offset);
if (swap_endianness_) { if (swap_endianness_) {
switch (out.itemsize()) { switch (out.itemsize()) {
@ -50,4 +51,11 @@ void Load::eval(const std::vector<array>& inputs, array& out) {
} }
} }
void Load::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 0);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
load(out, offset_, reader_, swap_endianness_);
}
} // namespace mlx::core } // namespace mlx::core

14
mlx/backend/common/load.h Normal file
View File

@ -0,0 +1,14 @@
// Copyright © 2024 Apple Inc.
#include "mlx/array.h"
#include "mlx/io/load.h"
namespace mlx::core {
void load(
array& out,
size_t offset,
const std::shared_ptr<io::Reader>& reader,
bool swap_endianess);
} // namespace mlx::core

View File

@ -4,12 +4,14 @@
#include <numeric> #include <numeric>
#include <sstream> #include <sstream>
#include "mlx/backend/common/load.h"
#include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h" #include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/slicing.h" #include "mlx/backend/metal/slicing.h"
#include "mlx/backend/metal/utils.h" #include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
#include "mlx/scheduler.h"
#include "mlx/utils.h" #include "mlx/utils.h"
namespace mlx::core { namespace mlx::core {
@ -197,7 +199,24 @@ void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
} }
void Load::eval_gpu(const std::vector<array>& inputs, array& out) { void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out); static Stream io_stream = new_stream(Device::cpu);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto task = [out = out,
offset = offset_,
reader = reader_,
swap_endianness = swap_endianness_]() mutable {
load(out, offset, reader, swap_endianness);
out.event().signal();
};
scheduler::enqueue(io_stream, std::move(task));
auto& d = metal::device(stream().device);
d.end_encoding(stream().index);
auto command_buffer = d.get_command_buffer(stream().index);
command_buffer->encodeWait(
static_cast<MTL::Event*>(out.event().raw_event().get()),
out.event().value());
} }
void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) { void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {