From a6c3b38fba762c105826b93167b171ef01eef1a4 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 28 Aug 2024 14:21:55 -0700 Subject: [PATCH] Async load (#1372) * async load * async load --- mlx/backend/common/load.cpp | 22 +++++++++++++++------- mlx/backend/common/load.h | 14 ++++++++++++++ mlx/backend/metal/primitives.cpp | 21 ++++++++++++++++++++- 3 files changed, 49 insertions(+), 8 deletions(-) create mode 100644 mlx/backend/common/load.h diff --git a/mlx/backend/common/load.cpp b/mlx/backend/common/load.cpp index c4130233c..2a10ed08a 100644 --- a/mlx/backend/common/load.cpp +++ b/mlx/backend/common/load.cpp @@ -5,11 +5,9 @@ #include #include "mlx/allocator.h" -#include "mlx/io/load.h" +#include "mlx/backend/common/load.h" #include "mlx/primitives.h" -namespace mlx::core { - namespace { template @@ -29,11 +27,14 @@ void swap_endianness(uint8_t* data_bytes, size_t N) { } // namespace -void Load::eval(const std::vector& inputs, array& out) { - assert(inputs.size() == 0); - out.set_data(allocator::malloc_or_wait(out.nbytes())); +namespace mlx::core { - reader_->read(out.data(), out.nbytes(), offset_); +void load( + array& out, + size_t offset, + const std::shared_ptr& reader, + bool swap_endianness_) { + reader->read(out.data(), out.nbytes(), offset); if (swap_endianness_) { switch (out.itemsize()) { @@ -50,4 +51,11 @@ void Load::eval(const std::vector& inputs, array& out) { } } +void Load::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 0); + out.set_data(allocator::malloc_or_wait(out.nbytes())); + + load(out, offset_, reader_, swap_endianness_); +} + } // namespace mlx::core diff --git a/mlx/backend/common/load.h b/mlx/backend/common/load.h new file mode 100644 index 000000000..5806cc693 --- /dev/null +++ b/mlx/backend/common/load.h @@ -0,0 +1,14 @@ +// Copyright © 2024 Apple Inc. + +#include "mlx/array.h" +#include "mlx/io/load.h" + +namespace mlx::core { + +void load( + array& out, + size_t offset, + const std::shared_ptr& reader, + bool swap_endianess); + +} // namespace mlx::core diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index 1b7f8122a..6580dd7a7 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -4,12 +4,14 @@ #include #include +#include "mlx/backend/common/load.h" #include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/slicing.h" #include "mlx/backend/metal/utils.h" #include "mlx/primitives.h" +#include "mlx/scheduler.h" #include "mlx/utils.h" namespace mlx::core { @@ -197,7 +199,24 @@ void Full::eval_gpu(const std::vector& inputs, array& out) { } void Load::eval_gpu(const std::vector& inputs, array& out) { - eval(inputs, out); + static Stream io_stream = new_stream(Device::cpu); + out.set_data(allocator::malloc_or_wait(out.nbytes())); + + auto task = [out = out, + offset = offset_, + reader = reader_, + swap_endianness = swap_endianness_]() mutable { + load(out, offset, reader, swap_endianness); + out.event().signal(); + }; + + scheduler::enqueue(io_stream, std::move(task)); + auto& d = metal::device(stream().device); + d.end_encoding(stream().index); + auto command_buffer = d.get_command_buffer(stream().index); + command_buffer->encodeWait( + static_cast(out.event().raw_event().get()), + out.event().value()); } void NumberOfElements::eval_gpu(const std::vector& inputs, array& out) {