From bae159738f7380a8fb5ad97b4a2a57974d6dded2 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Wed, 8 May 2024 22:17:25 -0700
Subject: [PATCH] Working IO primitives

---
 mlx/backend/io/thread_pool.cpp | 46 ++++++++++++++++++++++------------
 mlx/backend/io/thread_pool.h   |  4 +--
 mlx/scheduler.cpp              |  4 +++
 mlx/scheduler.h                |  1 +
 mlx/transforms.cpp             |  1 +
 5 files changed, 38 insertions(+), 18 deletions(-)

diff --git a/mlx/backend/io/thread_pool.cpp b/mlx/backend/io/thread_pool.cpp
index b7de9af3c..79a68c6f3 100644
--- a/mlx/backend/io/thread_pool.cpp
+++ b/mlx/backend/io/thread_pool.cpp
@@ -6,7 +6,13 @@
 
 namespace mlx::core::io::detail {
 
-ThreadPool::ThreadPool(int workers) : stop_(false) {
+ThreadPool::ThreadPool(int workers)
+    : task_queues_(workers),
+      queue_mutexes_(workers),
+      queue_cvs_(workers),
+      set_mutexes_(workers),
+      output_sets_(workers),
+      stop_(false) {
   for (int i = 0; i < workers; i++) {
     workers_.emplace_back(&ThreadPool::worker, this, i);
   }
@@ -30,13 +36,15 @@ std::future<void> ThreadPool::enqueue(
     const std::vector<array>& inputs,
     const std::vector<array>& outputs) {
   std::vector<int> barriers;
-  for (int i = 0; i < output_sets_.size(); i++) {
-    std::lock_guard lock(set_mutexes_[i]);
+  if (!inputs.empty()) {
+    for (int i = 0; i < output_sets_.size(); i++) {
+      std::lock_guard lock(set_mutexes_[i]);
 
-    for (auto& a : inputs) {
-      if (output_sets_[i].find(a.buffer().ptr()) != output_sets_[i].end()) {
-        barriers.push_back(i);
-        break;
+      for (auto& a : inputs) {
+        if (output_sets_[i].find(a.id()) != output_sets_[i].end()) {
+          barriers.push_back(i);
+          break;
+        }
       }
     }
   }
@@ -44,7 +52,9 @@ std::future<void> ThreadPool::enqueue(
   // Case 1: Barriers is empty so try to add it to the smallest queue
   if (barriers.empty()) {
     auto min_queue = std::min_element(
-        task_queues_.begin(), task_queues_.end(), [](auto& left, auto& right) {
+        task_queues_.begin(),
+        task_queues_.end(),
+        [](const auto& left, const auto& right) {
           return left.size() < right.size();
         });
     int worker_idx = std::distance(task_queues_.begin(), min_queue);
@@ -67,7 +77,7 @@ std::future<void> ThreadPool::enqueue(
   // Case 3: We need to add a barrier before our task and add it to the
   // smallest queue of the barriers.
   auto min_queue = std::min_element(
-      barriers.begin(), barriers.end(), [this](auto left, auto right) {
+      barriers.begin(), barriers.end(), [this](int left, int right) {
         return task_queues_[left].size() < task_queues_[right].size();
       });
   int worker_idx = *min_queue;
@@ -109,7 +119,7 @@ void ThreadPool::add_outputs_to_worker(
   std::lock_guard lock(set_mutexes_[worker_idx]);
 
   for (auto& a : outputs) {
-    output_sets_[worker_idx].insert(a.buffer().ptr());
+    output_sets_[worker_idx].insert(a.id());
   }
 }
 
@@ -121,20 +131,20 @@ std::function<void()> ThreadPool::remove_outputs_when_done(
     return task;
   }
 
-  std::vector<const void*> output_buffers;
+  std::vector<std::uintptr_t> output_ids;
   for (auto& a : outputs) {
-    output_buffers.push_back(a.buffer().ptr());
+    output_ids.push_back(a.id());
   }
 
   return [og_task = std::move(task),
-          buffers = std::move(output_buffers),
+          ids = std::move(output_ids),
           worker_idx,
           this]() {
     og_task();
     {
       std::lock_guard lock(set_mutexes_[worker_idx]);
-      for (auto b : buffers) {
-        output_sets_[worker_idx].erase(b);
+      for (auto id : ids) {
+        output_sets_[worker_idx].erase(id);
       }
     }
   };
@@ -195,7 +205,11 @@ void ThreadPool::worker(int idx) {
       task = std::move(task_queues_[idx].front());
       task_queues_[idx].pop();
     }
-    task();
+    try {
+      task();
+    } catch (...) {
+      // do nothing?
+    }
   }
 }
 }
diff --git a/mlx/backend/io/thread_pool.h b/mlx/backend/io/thread_pool.h
index d1da25c4d..1dc5bf2c6 100644
--- a/mlx/backend/io/thread_pool.h
+++ b/mlx/backend/io/thread_pool.h
@@ -43,10 +43,10 @@ class ThreadPool {
   std::vector<std::queue<std::packaged_task<void()>>> task_queues_;
   std::vector<std::mutex> queue_mutexes_;
   std::vector<std::condition_variable> queue_cvs_;
-  std::vector<std::thread> workers_;
   std::vector<std::mutex> set_mutexes_;
-  std::vector<std::unordered_set<const void*>> output_sets_;
+  std::vector<std::unordered_set<std::uintptr_t>> output_sets_;
   bool stop_;
+  std::vector<std::thread> workers_;
 };
 
 } // namespace mlx::core::io::detail
diff --git a/mlx/scheduler.cpp b/mlx/scheduler.cpp
index 767ff855d..790abbdb4 100644
--- a/mlx/scheduler.cpp
+++ b/mlx/scheduler.cpp
@@ -2,6 +2,7 @@
 
 #include "mlx/scheduler.h"
 #include "mlx/backend/common/cpu_impl.h"
+#include "mlx/backend/io/io_impl.h"
 #include "mlx/backend/metal/metal.h"
 
 namespace mlx::core {
@@ -44,6 +45,9 @@ void synchronize(Stream s) {
     case mlx::core::Device::gpu:
       scheduler::enqueue(s, metal::make_synchronize_task(s, std::move(p)));
       break;
+    case mlx::core::Device::io:
+      scheduler::enqueue(s, io::make_synchronize_task(s, std::move(p)));
+      break;
   }
   f.wait();
 }
diff --git a/mlx/scheduler.h b/mlx/scheduler.h
index d14dd4fd5..fb3525e8f 100644
--- a/mlx/scheduler.h
+++ b/mlx/scheduler.h
@@ -76,6 +76,7 @@ class Scheduler {
       default_streams_.insert({Device::gpu, new_stream(Device::gpu)});
     }
     default_streams_.insert({Device::cpu, new_stream(Device::cpu)});
+    default_streams_.insert({Device::io, new_stream(Device::io)});
   }
 
   // Not copyable or moveable
diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp
index bb06b1b23..db278292b 100644
--- a/mlx/transforms.cpp
+++ b/mlx/transforms.cpp
@@ -28,6 +28,7 @@ class Synchronizer : public Primitive {
 
   void eval_cpu(const std::vector<array>&, std::vector<array>&) override {};
   void eval_gpu(const std::vector<array>&, std::vector<array>&) override {};
+  void eval_io(const std::vector<array>&, std::vector<array>&) override {};
 
   DEFINE_PRINT(Synchronize);
 };