Working IO primitives

This commit is contained in:
Angelos Katharopoulos 2024-05-08 22:17:25 -07:00
parent b193741050
commit bae159738f
5 changed files with 38 additions and 18 deletions

View File

@ -6,7 +6,13 @@
namespace mlx::core::io::detail { namespace mlx::core::io::detail {
ThreadPool::ThreadPool(int workers) : stop_(false) { ThreadPool::ThreadPool(int workers)
: task_queues_(workers),
queue_mutexes_(workers),
queue_cvs_(workers),
set_mutexes_(workers),
output_sets_(workers),
stop_(false) {
for (int i = 0; i < workers; i++) { for (int i = 0; i < workers; i++) {
workers_.emplace_back(&ThreadPool::worker, this, i); workers_.emplace_back(&ThreadPool::worker, this, i);
} }
@ -30,13 +36,15 @@ std::future<void> ThreadPool::enqueue(
const std::vector<array>& inputs, const std::vector<array>& inputs,
const std::vector<array>& outputs) { const std::vector<array>& outputs) {
std::vector<int> barriers; std::vector<int> barriers;
for (int i = 0; i < output_sets_.size(); i++) { if (!inputs.empty()) {
std::lock_guard<std::mutex> lock(set_mutexes_[i]); for (int i = 0; i < output_sets_.size(); i++) {
std::lock_guard<std::mutex> lock(set_mutexes_[i]);
for (auto& a : inputs) { for (auto& a : inputs) {
if (output_sets_[i].find(a.buffer().ptr()) != output_sets_[i].end()) { if (output_sets_[i].find(a.id()) != output_sets_[i].end()) {
barriers.push_back(i); barriers.push_back(i);
break; break;
}
} }
} }
} }
@ -44,7 +52,9 @@ std::future<void> ThreadPool::enqueue(
// Case 1: Barriers is empty so try to add it to the smallest queue // Case 1: Barriers is empty so try to add it to the smallest queue
if (barriers.empty()) { if (barriers.empty()) {
auto min_queue = std::min_element( auto min_queue = std::min_element(
task_queues_.begin(), task_queues_.end(), [](auto& left, auto& right) { task_queues_.begin(),
task_queues_.end(),
[](const auto& left, const auto& right) {
return left.size() < right.size(); return left.size() < right.size();
}); });
int worker_idx = std::distance(task_queues_.begin(), min_queue); int worker_idx = std::distance(task_queues_.begin(), min_queue);
@ -67,7 +77,7 @@ std::future<void> ThreadPool::enqueue(
// Case 3: We need to add a barrier before our task and add it to the // Case 3: We need to add a barrier before our task and add it to the
// smallest queue of the barriers. // smallest queue of the barriers.
auto min_queue = std::min_element( auto min_queue = std::min_element(
barriers.begin(), barriers.end(), [this](auto left, auto right) { barriers.begin(), barriers.end(), [this](int left, int right) {
return task_queues_[left].size() < task_queues_[right].size(); return task_queues_[left].size() < task_queues_[right].size();
}); });
int worker_idx = *min_queue; int worker_idx = *min_queue;
@ -109,7 +119,7 @@ void ThreadPool::add_outputs_to_worker(
std::lock_guard<std::mutex> lock(set_mutexes_[worker_idx]); std::lock_guard<std::mutex> lock(set_mutexes_[worker_idx]);
for (auto& a : outputs) { for (auto& a : outputs) {
output_sets_[worker_idx].insert(a.buffer().ptr()); output_sets_[worker_idx].insert(a.id());
} }
} }
@ -121,20 +131,20 @@ std::function<void()> ThreadPool::remove_outputs_when_done(
return task; return task;
} }
std::vector<const void*> output_buffers; std::vector<std::uintptr_t> output_ids;
for (auto& a : outputs) { for (auto& a : outputs) {
output_buffers.push_back(a.buffer().ptr()); output_ids.push_back(a.id());
} }
return [og_task = std::move(task), return [og_task = std::move(task),
buffers = std::move(output_buffers), ids = std::move(output_ids),
worker_idx, worker_idx,
this]() { this]() {
og_task(); og_task();
{ {
std::lock_guard<std::mutex> lock(set_mutexes_[worker_idx]); std::lock_guard<std::mutex> lock(set_mutexes_[worker_idx]);
for (auto b : buffers) { for (auto id : ids) {
output_sets_[worker_idx].erase(b); output_sets_[worker_idx].erase(id);
} }
} }
}; };
@ -195,7 +205,11 @@ void ThreadPool::worker(int idx) {
task = std::move(task_queues_[idx].front()); task = std::move(task_queues_[idx].front());
task_queues_[idx].pop(); task_queues_[idx].pop();
} }
task(); try {
task();
} catch (...) {
// do nothing?
}
} }
} }

View File

@ -43,10 +43,10 @@ class ThreadPool {
std::vector<std::queue<std::packaged_task<void()>>> task_queues_; std::vector<std::queue<std::packaged_task<void()>>> task_queues_;
std::vector<std::mutex> queue_mutexes_; std::vector<std::mutex> queue_mutexes_;
std::vector<std::condition_variable> queue_cvs_; std::vector<std::condition_variable> queue_cvs_;
std::vector<std::thread> workers_;
std::vector<std::mutex> set_mutexes_; std::vector<std::mutex> set_mutexes_;
std::vector<std::unordered_set<const void*>> output_sets_; std::vector<std::unordered_set<std::uintptr_t>> output_sets_;
bool stop_; bool stop_;
std::vector<std::thread> workers_;
}; };
} // namespace mlx::core::io::detail } // namespace mlx::core::io::detail

View File

@ -2,6 +2,7 @@
#include "mlx/scheduler.h" #include "mlx/scheduler.h"
#include "mlx/backend/common/cpu_impl.h" #include "mlx/backend/common/cpu_impl.h"
#include "mlx/backend/io/io_impl.h"
#include "mlx/backend/metal/metal.h" #include "mlx/backend/metal/metal.h"
namespace mlx::core { namespace mlx::core {
@ -44,6 +45,9 @@ void synchronize(Stream s) {
case mlx::core::Device::gpu: case mlx::core::Device::gpu:
scheduler::enqueue(s, metal::make_synchronize_task(s, std::move(p))); scheduler::enqueue(s, metal::make_synchronize_task(s, std::move(p)));
break; break;
case mlx::core::Device::io:
scheduler::enqueue(s, io::make_synchronize_task(s, std::move(p)));
break;
} }
f.wait(); f.wait();
} }

View File

@ -76,6 +76,7 @@ class Scheduler {
default_streams_.insert({Device::gpu, new_stream(Device::gpu)}); default_streams_.insert({Device::gpu, new_stream(Device::gpu)});
} }
default_streams_.insert({Device::cpu, new_stream(Device::cpu)}); default_streams_.insert({Device::cpu, new_stream(Device::cpu)});
default_streams_.insert({Device::io, new_stream(Device::io)});
} }
// Not copyable or moveable // Not copyable or moveable

View File

@ -28,6 +28,7 @@ class Synchronizer : public Primitive {
void eval_cpu(const std::vector<array>&, std::vector<array>&) override {}; void eval_cpu(const std::vector<array>&, std::vector<array>&) override {};
void eval_gpu(const std::vector<array>&, std::vector<array>&) override {}; void eval_gpu(const std::vector<array>&, std::vector<array>&) override {};
void eval_io(const std::vector<array>&, std::vector<array>&) override {};
DEFINE_PRINT(Synchronize); DEFINE_PRINT(Synchronize);
}; };