diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt
index d2f021af5..8deff4f2f 100644
--- a/mlx/CMakeLists.txt
+++ b/mlx/CMakeLists.txt
@@ -20,6 +20,7 @@ target_sources(
 )
 
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
+add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/io)
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
 if (MLX_BUILD_ACCELERATE)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)
diff --git a/mlx/backend/common/cpu_impl.cpp b/mlx/backend/common/cpu_impl.cpp
index 00164810c..dee8127fe 100644
--- a/mlx/backend/common/cpu_impl.cpp
+++ b/mlx/backend/common/cpu_impl.cpp
@@ -12,8 +12,7 @@ std::function<void()> make_task(array arr, bool signal) {
 
     // Wait on inputs coming from different streams/devices.
     for (auto& input : arr.inputs()) {
-      if (input.event().valid() &&
-          input.event().stream() != arr.primitive().stream()) {
+      if (input.event().valid() && input.event().stream() != stream) {
         input.event().wait();
       }
     }
diff --git a/mlx/backend/io/CMakeLists.txt b/mlx/backend/io/CMakeLists.txt
new file mode 100644
index 000000000..69eaebfb5
--- /dev/null
+++ b/mlx/backend/io/CMakeLists.txt
@@ -0,0 +1,6 @@
+target_sources(
+  mlx
+  PRIVATE
+  ${CMAKE_CURRENT_SOURCE_DIR}/io_impl.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/thread_pool.cpp
+)
diff --git a/mlx/backend/io/io_impl.cpp b/mlx/backend/io/io_impl.cpp
new file mode 100644
index 000000000..21724692b
--- /dev/null
+++ b/mlx/backend/io/io_impl.cpp
@@ -0,0 +1,72 @@
+// Copyright © 2024 Apple Inc.
+
+#include "mlx/backend/io/io_impl.h"
+#include "mlx/backend/io/thread_pool.h"
+#include "mlx/primitives.h"
+#include "mlx/scheduler.h"
+
+namespace mlx::core::io {
+
+namespace {
+
+detail::ThreadPool& thread_pool() {
+  // Initialize in the declaration so the magic static makes construction
+  // thread-safe; a separate "if (pool_ptr == nullptr)" check could race
+  // when two streams submit their first IO task concurrently.
+  static std::unique_ptr<detail::ThreadPool> pool_ptr =
+      std::make_unique<detail::ThreadPool>(4);
+  return *pool_ptr;
+}
+
+} // namespace
+
+std::function<void()> make_task(array arr, bool signal) {
+  return [arr = std::move(arr), signal]() mutable {
+    auto stream = arr.primitive().stream();
+
+    // Wait on inputs coming from different streams/devices.
+    for (auto& input : arr.inputs()) {
+      if (input.event().valid() && input.event().stream() != stream) {
+        input.event().wait();
+      }
+    }
+
+    // Task computation actually starting.
+    scheduler::notify_new_task(stream);
+
+    // Schedule the computation
+    auto inputs = arr.inputs();
+    auto outputs = arr.outputs();
+    thread_pool().enqueue(
+        [arr = std::move(arr), inputs, outputs, signal, stream]() mutable {
+          // Perform the computation
+          arr.primitive().eval_io(inputs, outputs);
+
+          if (!arr.is_tracer()) {
+            arr.detach();
+          }
+
+          if (signal) {
+            thread_pool().barrier(
+                [arr = std::move(arr)]() { arr.event().signal(); });
+          }
+
+          // Task computation done.
+          scheduler::notify_task_completion(stream);
+        },
+        inputs,
+        outputs);
+  };
+}
+
+std::function<void()> make_synchronize_task(
+    Stream s,
+    std::shared_ptr<std::promise<void>> p) {
+  return [p = std::move(p)]() {
+    thread_pool().barrier().wait();
+    p->set_value();
+  };
+}
+
+} // namespace mlx::core::io
diff --git a/mlx/backend/io/io_impl.h b/mlx/backend/io/io_impl.h
new file mode 100644
index 000000000..8c32a00ba
--- /dev/null
+++ b/mlx/backend/io/io_impl.h
@@ -0,0 +1,18 @@
+// Copyright © 2024 Apple Inc.
+
+#pragma once
+
+#include <functional>
+#include <future>
+#include <memory>
+
+#include "mlx/array.h"
+
+namespace mlx::core::io {
+
+std::function<void()> make_task(array arr, bool signal);
+std::function<void()> make_synchronize_task(
+    Stream s,
+    std::shared_ptr<std::promise<void>> p);
+
+} // namespace mlx::core::io
diff --git a/mlx/backend/io/thread_pool.cpp b/mlx/backend/io/thread_pool.cpp
new file mode 100644
index 000000000..b7de9af3c
--- /dev/null
+++ b/mlx/backend/io/thread_pool.cpp
@@ -0,0 +1,202 @@
+// Copyright © 2024 Apple Inc.
+
+#include <algorithm>
+#include <iterator>
+#include <numeric>
+
+#include "mlx/backend/io/thread_pool.h"
+
+namespace mlx::core::io::detail {
+
+ThreadPool::ThreadPool(int workers)
+    // Size every per-worker container up front; mutexes and condition
+    // variables are not movable, so these vectors cannot grow later.
+    : task_queues_(workers),
+      queue_mutexes_(workers),
+      queue_cvs_(workers),
+      set_mutexes_(workers),
+      output_sets_(workers),
+      stop_(false) {
+  for (int i = 0; i < workers; i++) {
+    workers_.emplace_back(&ThreadPool::worker, this, i);
+  }
+}
+
+ThreadPool::~ThreadPool() {
+  stop_ = true;
+  // Notify under each queue lock so a worker cannot check the stop flag
+  // and then go to sleep just as the notification fires.
+  for (size_t i = 0; i < queue_cvs_.size(); i++) {
+    std::lock_guard<std::mutex> lock(queue_mutexes_[i]);
+    queue_cvs_[i].notify_one();
+  }
+
+  for (auto& t : workers_) {
+    if (t.joinable()) {
+      t.join();
+    }
+  }
+}
+
+std::future<void> ThreadPool::enqueue(
+    std::function<void()> task,
+    const std::vector<array>& inputs,
+    const std::vector<array>& outputs) {
+  // Find every worker whose in-flight outputs overlap with our inputs.
+  std::vector<int> barriers;
+  for (int i = 0; i < static_cast<int>(output_sets_.size()); i++) {
+    std::lock_guard<std::mutex> lock(set_mutexes_[i]);
+
+    for (auto& a : inputs) {
+      if (output_sets_[i].find(a.buffer().ptr()) != output_sets_[i].end()) {
+        barriers.push_back(i);
+        break;
+      }
+    }
+  }
+
+  // Case 1: Barriers is empty so add the task to the smallest queue
+  if (barriers.empty()) {
+    auto min_queue = std::min_element(
+        task_queues_.begin(), task_queues_.end(), [](auto& left, auto& right) {
+          return left.size() < right.size();
+        });
+    int worker_idx = std::distance(task_queues_.begin(), min_queue);
+
+    add_outputs_to_worker(outputs, worker_idx);
+    return enqueue(
+        remove_outputs_when_done(std::move(task), outputs, worker_idx),
+        worker_idx);
+  }
+
+  // Case 2: Barriers has only one queue so put the task into that queue
+  if (barriers.size() == 1) {
+    int worker_idx = barriers[0];
+    add_outputs_to_worker(outputs, worker_idx);
+    return enqueue(
+        remove_outputs_when_done(std::move(task), outputs, worker_idx),
+        worker_idx);
+  }
+
+  // Case 3: We need to add a barrier before our task and add it to the
+  // smallest queue of the barriers.
+  auto min_queue = std::min_element(
+      barriers.begin(), barriers.end(), [this](auto left, auto right) {
+        return task_queues_[left].size() < task_queues_[right].size();
+      });
+  int worker_idx = *min_queue;
+  barriers.erase(min_queue);
+  // A shared_future because the closure below is wrapped in a
+  // std::function, which requires a copyable callable.
+  std::shared_future<void> queue_barrier = barrier(barriers);
+  add_outputs_to_worker(outputs, worker_idx);
+  return enqueue(
+      remove_outputs_when_done(
+          [queue_barrier = std::move(queue_barrier),
+           og_task = std::move(task)]() {
+            queue_barrier.wait();
+            og_task();
+          },
+          outputs,
+          worker_idx),
+      worker_idx);
+}
+
+std::future<void> ThreadPool::enqueue(
+    std::function<void()> task,
+    int worker_idx) {
+  std::packaged_task<void()> pt(std::move(task));
+  std::future<void> result = pt.get_future();
+  {
+    std::lock_guard<std::mutex> lock(queue_mutexes_[worker_idx]);
+    task_queues_[worker_idx].emplace(std::move(pt));
+  }
+  queue_cvs_[worker_idx].notify_one();
+  return result;
+}
+
+void ThreadPool::add_outputs_to_worker(
+    const std::vector<array>& outputs,
+    int worker_idx) {
+  if (outputs.size() == 0) {
+    return;
+  }
+
+  std::lock_guard<std::mutex> lock(set_mutexes_[worker_idx]);
+  for (auto& a : outputs) {
+    output_sets_[worker_idx].insert(a.buffer().ptr());
+  }
+}
+
+std::function<void()> ThreadPool::remove_outputs_when_done(
+    std::function<void()> task,
+    const std::vector<array>& outputs,
+    int worker_idx) {
+  if (outputs.size() == 0) {
+    return task;
+  }
+
+  std::vector<const void*> output_buffers;
+  for (auto& a : outputs) {
+    output_buffers.push_back(a.buffer().ptr());
+  }
+
+  return [og_task = std::move(task),
+          buffers = std::move(output_buffers),
+          worker_idx,
+          this]() {
+    og_task();
+    {
+      std::lock_guard<std::mutex> lock(set_mutexes_[worker_idx]);
+      for (auto b : buffers) {
+        output_sets_[worker_idx].erase(b);
+      }
+    }
+  };
+}
+
+std::future<void> ThreadPool::barrier(
+    const std::vector<int>& worker_ids,
+    std::function<void()> on_barrier) {
+  auto workers = std::make_shared<std::atomic<int>>(worker_ids.size());
+  auto promise = std::make_shared<std::promise<void>>();
+  auto future = promise->get_future();
+
+  for (auto idx : worker_ids) {
+    // Copy on_barrier into every closure; any of the workers may be the
+    // last one to arrive, so each needs its own callable. The single
+    // decrement-and-test guarantees exactly one of them fires it.
+    enqueue(
+        [workers, promise, on_barrier]() {
+          if (workers->fetch_sub(1) == 1) {
+            on_barrier();
+            promise->set_value();
+          }
+        },
+        idx);
+  }
+
+  return future;
+}
+
+std::future<void> ThreadPool::barrier(const std::vector<int>& worker_ids) {
+  auto noop = []() {};
+  return barrier(worker_ids, std::move(noop));
+}
+
+std::future<void> ThreadPool::barrier(std::function<void()> on_barrier) {
+  std::vector<int> worker_ids(workers_.size());
+  std::iota(worker_ids.begin(), worker_ids.end(), 0);
+  return barrier(worker_ids, std::move(on_barrier));
+}
+
+std::future<void> ThreadPool::barrier() {
+  auto noop = []() {};
+  return barrier(std::move(noop));
+}
+
+void ThreadPool::worker(int idx) {
+  while (true) {
+    std::packaged_task<void()> task;
+    {
+      std::unique_lock<std::mutex> lock(queue_mutexes_[idx]);
+      queue_cvs_[idx].wait(
+          lock, [this, idx]() { return stop_ || !task_queues_[idx].empty(); });
+      if (task_queues_[idx].empty()) {
+        if (stop_) {
+          break;
+        } else {
+          continue;
+        }
+      }
+      task = std::move(task_queues_[idx].front());
+      task_queues_[idx].pop();
+    }
+    task();
+  }
+}
+
+} // namespace mlx::core::io::detail
diff --git a/mlx/backend/io/thread_pool.h b/mlx/backend/io/thread_pool.h
new file mode 100644
index 000000000..d1da25c4d
--- /dev/null
+++ b/mlx/backend/io/thread_pool.h
@@ -0,0 +1,52 @@
+// Copyright © 2024 Apple Inc.
+
+#pragma once
+
+#include <atomic>
+#include <condition_variable>
+#include <functional>
+#include <future>
+#include <mutex>
+#include <queue>
+#include <thread>
+#include <unordered_set>
+
+#include "mlx/array.h"
+
+namespace mlx::core::io::detail {
+
+class ThreadPool {
+ public:
+  explicit ThreadPool(int workers);
+  ~ThreadPool();
+
+  ThreadPool(ThreadPool&&) = delete;
+  ThreadPool(const ThreadPool&) = delete;
+  ThreadPool& operator=(ThreadPool&&) = delete;
+  ThreadPool& operator=(const ThreadPool&) = delete;
+
+  std::future<void> enqueue(
+      std::function<void()> task,
+      const std::vector<array>& inputs,
+      const std::vector<array>& outputs);
+  std::future<void> barrier(
+      const std::vector<int>& worker_ids,
+      std::function<void()> on_barrier);
+  std::future<void> barrier(const std::vector<int>& worker_ids);
+  std::future<void> barrier(std::function<void()> on_barrier);
+  std::future<void> barrier();
+
+ private:
+  std::future<void> enqueue(std::function<void()> task, int worker_idx);
+  void add_outputs_to_worker(
+      const std::vector<array>& outputs,
+      int worker_idx);
+  std::function<void()> remove_outputs_when_done(
+      std::function<void()> task,
+      const std::vector<array>& outputs,
+      int worker_idx);
+  void worker(int idx);
+
+  std::vector<std::queue<std::packaged_task<void()>>> task_queues_;
+  std::vector<std::mutex> queue_mutexes_;
+  std::vector<std::condition_variable> queue_cvs_;
+  std::vector<std::thread> workers_;
+  std::vector<std::mutex> set_mutexes_;
+  std::vector<std::unordered_set<const void*>> output_sets_;
+  // Atomic: written by the destructor thread, read by the workers.
+  std::atomic<bool> stop_;
+};
+
+} // namespace mlx::core::io::detail
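
The routing policy in ThreadPool::enqueue is the core of the patch: a task that reads a buffer some in-flight task writes is sent to that writer's queue, so the two serialize without extra synchronization; an independent task goes to the least-loaded queue; a barrier is inserted only when the inputs are spread over several queues. Below is a minimal standalone sketch of that rule, not part of the patch: SerialRouter and everything in it are hypothetical names, plain void* keys stand in for mlx array buffers, and case 3's barrier machinery is reduced to a comment.

// Sketch of the queue-routing rule used by io::detail::ThreadPool.
#include <iostream>
#include <unordered_set>
#include <vector>

class SerialRouter {
 public:
  explicit SerialRouter(int workers) : owned_(workers), load_(workers, 0) {}

  // Pick the worker queue for a task that reads `inputs` and writes
  // `outputs`, mirroring the three cases in ThreadPool::enqueue.
  int route(const std::vector<void*>& inputs,
            const std::vector<void*>& outputs) {
    // Workers whose in-flight outputs overlap our inputs.
    std::vector<int> conflicts;
    for (int i = 0; i < static_cast<int>(owned_.size()); i++) {
      for (void* in : inputs) {
        if (owned_[i].count(in)) {
          conflicts.push_back(i);
          break;
        }
      }
    }

    int idx;
    if (conflicts.empty()) {
      // Case 1: independent task, balance load across the queues.
      idx = 0;
      for (int i = 1; i < static_cast<int>(load_.size()); i++) {
        if (load_[i] < load_[idx]) {
          idx = i;
        }
      }
    } else {
      // Cases 2 and 3: run after a producer by joining its queue. The real
      // pool additionally waits on a barrier over the other conflicting
      // queues when there is more than one.
      idx = conflicts.front();
    }

    // The chosen worker owns the outputs until the task completes
    // (remove_outputs_when_done undoes this in the real pool).
    for (void* out : outputs) {
      owned_[idx].insert(out);
    }
    load_[idx]++;
    return idx;
  }

 private:
  std::vector<std::unordered_set<void*>> owned_;
  std::vector<int> load_;
};

int main() {
  SerialRouter router(4);
  int a = 0, b = 0, c = 0;           // stand-ins for array buffers
  int w0 = router.route({}, {&a});   // writes a
  int w1 = router.route({&a}, {&b}); // reads a: lands on w0's queue
  int w2 = router.route({&c}, {});   // independent: least-loaded queue
  std::cout << w0 << " " << w1 << " " << w2 << std::endl; // w0 == w1
}

The design consequence, visible in the sketch: reader-after-writer ordering falls out of queue membership alone, so the per-task path needs no locks around the user task itself, only around the small output-ownership sets.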