From be36f136de1b612058da29ebb03a51be1357edca Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Wed, 1 May 2024 17:04:52 -0700 Subject: [PATCH] Add io device and cpu::make_task --- mlx/backend/common/CMakeLists.txt | 1 + mlx/backend/common/cpu_impl.cpp | 48 +++++++++++++++++++++++++++++++ mlx/backend/common/cpu_impl.h | 18 ++++++++++++ mlx/device.h | 2 ++ mlx/primitives.cpp | 10 +++++++ mlx/primitives.h | 10 +++++++ mlx/scheduler.cpp | 12 +++++--- mlx/transforms.cpp | 36 +++++++---------------- mlx/utils.cpp | 3 ++ 9 files changed, 111 insertions(+), 29 deletions(-) create mode 100644 mlx/backend/common/cpu_impl.cpp create mode 100644 mlx/backend/common/cpu_impl.h diff --git a/mlx/backend/common/CMakeLists.txt b/mlx/backend/common/CMakeLists.txt index ea0babf18..b48413108 100644 --- a/mlx/backend/common/CMakeLists.txt +++ b/mlx/backend/common/CMakeLists.txt @@ -55,6 +55,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp ${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp ${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpu_impl.cpp ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ) diff --git a/mlx/backend/common/cpu_impl.cpp b/mlx/backend/common/cpu_impl.cpp new file mode 100644 index 000000000..00164810c --- /dev/null +++ b/mlx/backend/common/cpu_impl.cpp @@ -0,0 +1,48 @@ +// Copyright © 2023-2024 Apple Inc. + +#include "mlx/backend/common/cpu_impl.h" +#include "mlx/primitives.h" +#include "mlx/scheduler.h" + +namespace mlx::core::cpu { + +std::function make_task(array arr, bool signal) { + return [arr = std::move(arr), signal]() mutable { + auto stream = arr.primitive().stream(); + + // Wait on inputs coming from different streams/devices. + for (auto& input : arr.inputs()) { + if (input.event().valid() && + input.event().stream() != arr.primitive().stream()) { + input.event().wait(); + } + } + + // Task computation actually starting. + scheduler::notify_new_task(stream); + + // Perform the computation + auto outputs = arr.outputs(); + arr.primitive().eval_cpu(arr.inputs(), outputs); + + // Check if we need to detach and signal other arrays waiting for the + // result to be ready. + if (!arr.is_tracer()) { + arr.detach(); + } + if (signal) { + arr.event().signal(); + } + + // Task computation done. + scheduler::notify_task_completion(stream); + }; +} + +std::function make_synchronize_task( + Stream s, + std::shared_ptr> p) { + return [p = std::move(p)]() { p->set_value(); }; +} + +} // namespace mlx::core::cpu diff --git a/mlx/backend/common/cpu_impl.h b/mlx/backend/common/cpu_impl.h new file mode 100644 index 000000000..0a11a3a6c --- /dev/null +++ b/mlx/backend/common/cpu_impl.h @@ -0,0 +1,18 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include +#include +#include + +#include "mlx/array.h" + +namespace mlx::core::cpu { + +std::function make_task(array arr, bool signal); +std::function make_synchronize_task( + Stream s, + std::shared_ptr> p); + +} // namespace mlx::core::cpu diff --git a/mlx/device.h b/mlx/device.h index 2a09195f4..fdc33be24 100644 --- a/mlx/device.h +++ b/mlx/device.h @@ -8,10 +8,12 @@ struct Device { enum class DeviceType { cpu, gpu, + io, }; static constexpr DeviceType cpu = DeviceType::cpu; static constexpr DeviceType gpu = DeviceType::gpu; + static constexpr DeviceType io = DeviceType::io; Device(DeviceType type, int index = 0) : type(type), index(index) {}; diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 1f50d1e9c..c522cd41a 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -106,6 +106,16 @@ std::tuple vmap_ternary_op( } // namespace +void Primitive::eval_io( + const std::vector& inputs, + std::vector& outputs) { + std::ostringstream msg; + msg << "[Primitive::eval_io] Not implemented for "; + print(msg); + msg << "."; + throw std::invalid_argument(msg.str()); +} + std::vector Primitive::jvp( const std::vector&, const std::vector&, diff --git a/mlx/primitives.h b/mlx/primitives.h index 7d0aca52b..6bc31a31b 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -73,6 +73,16 @@ class Primitive { const std::vector& inputs, std::vector& outputs) = 0; + /** + * Some primitives are computed by an IO device (disk, network, camera etc). + * + * Like in eval_cpu/gpu the eval_io function is responsible for allocating + * the space for the array. + */ + virtual void eval_io( + const std::vector& inputs, + std::vector& outputs); + /** * The Jacobian-vector product. */ diff --git a/mlx/scheduler.cpp b/mlx/scheduler.cpp index 9e4342583..767ff855d 100644 --- a/mlx/scheduler.cpp +++ b/mlx/scheduler.cpp @@ -1,6 +1,7 @@ // Copyright © 2023 Apple Inc. #include "mlx/scheduler.h" +#include "mlx/backend/common/cpu_impl.h" #include "mlx/backend/metal/metal.h" namespace mlx::core { @@ -36,10 +37,13 @@ Stream new_stream() { void synchronize(Stream s) { auto p = std::make_shared>(); std::future f = p->get_future(); - if (s.device == mlx::core::Device::cpu) { - scheduler::enqueue(s, [p = std::move(p)]() { p->set_value(); }); - } else { - scheduler::enqueue(s, metal::make_synchronize_task(s, std::move(p))); + switch (s.device.type) { + case mlx::core::Device::cpu: + scheduler::enqueue(s, cpu::make_synchronize_task(s, std::move(p))); + break; + case mlx::core::Device::gpu: + scheduler::enqueue(s, metal::make_synchronize_task(s, std::move(p))); + break; } f.wait(); } diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index 005402a98..c93a42fd1 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -8,6 +8,7 @@ #include #include +#include "mlx/backend/common/cpu_impl.h" #include "mlx/backend/metal/metal_impl.h" #include "mlx/ops.h" #include "mlx/primitives.h" @@ -137,32 +138,17 @@ array eval_impl(std::vector outputs, bool async) { std::vector> arr_deps; bool signal = needs_signal.find(arr.id()) != needs_signal.end(); - if (arr.primitive().device() == Device::gpu) { - if (!metal::is_available()) { - throw std::runtime_error("Metal GPU is not available."); + switch (arr.primitive().device().type) { + case Device::gpu: { + if (!metal::is_available()) { + throw std::runtime_error("Metal GPU is not available."); + } + scheduler::enqueue(stream, metal::make_task(std::move(arr), signal)); + break; } - scheduler::enqueue(stream, metal::make_task(std::move(arr), signal)); - } else { - auto task = [arr = std::move(arr), stream, signal]() mutable { - for (auto& input : arr.inputs()) { - if (input.event().valid() && - input.event().stream() != arr.primitive().stream()) { - input.event().wait(); - } - } - scheduler::notify_new_task(stream); - auto outputs = arr.outputs(); - arr.primitive().eval_cpu(arr.inputs(), outputs); - if (!arr.is_tracer()) { - arr.detach(); - } - if (signal) { - arr.event().signal(); - } - - scheduler::notify_task_completion(stream); - }; - scheduler::enqueue(stream, std::move(task)); + case Device::cpu: + scheduler::enqueue(stream, cpu::make_task(std::move(arr), signal)); + break; } } return synchronizer; diff --git a/mlx/utils.cpp b/mlx/utils.cpp index 34882884b..5253ea113 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -133,6 +133,9 @@ std::ostream& operator<<(std::ostream& os, const Device& d) { case Device::gpu: os << "gpu"; break; + case Device::io: + os << "io"; + break; } os << ", " << d.index << ")"; return os;