Working IO primitives

This commit is contained in:
Angelos Katharopoulos 2024-05-08 22:17:25 -07:00
parent b193741050
commit bae159738f
5 changed files with 38 additions and 18 deletions

View File

@ -6,7 +6,13 @@
namespace mlx::core::io::detail {
ThreadPool::ThreadPool(int workers) : stop_(false) {
ThreadPool::ThreadPool(int workers)
: task_queues_(workers),
queue_mutexes_(workers),
queue_cvs_(workers),
set_mutexes_(workers),
output_sets_(workers),
stop_(false) {
for (int i = 0; i < workers; i++) {
workers_.emplace_back(&ThreadPool::worker, this, i);
}
@ -30,13 +36,15 @@ std::future<void> ThreadPool::enqueue(
const std::vector<array>& inputs,
const std::vector<array>& outputs) {
std::vector<int> barriers;
for (int i = 0; i < output_sets_.size(); i++) {
std::lock_guard<std::mutex> lock(set_mutexes_[i]);
if (!inputs.empty()) {
for (int i = 0; i < output_sets_.size(); i++) {
std::lock_guard<std::mutex> lock(set_mutexes_[i]);
for (auto& a : inputs) {
if (output_sets_[i].find(a.buffer().ptr()) != output_sets_[i].end()) {
barriers.push_back(i);
break;
for (auto& a : inputs) {
if (output_sets_[i].find(a.id()) != output_sets_[i].end()) {
barriers.push_back(i);
break;
}
}
}
}
@ -44,7 +52,9 @@ std::future<void> ThreadPool::enqueue(
// Case 1: Barriers is empty so try to add it to the smallest queue
if (barriers.empty()) {
auto min_queue = std::min_element(
task_queues_.begin(), task_queues_.end(), [](auto& left, auto& right) {
task_queues_.begin(),
task_queues_.end(),
[](const auto& left, const auto& right) {
return left.size() < right.size();
});
int worker_idx = std::distance(task_queues_.begin(), min_queue);
@ -67,7 +77,7 @@ std::future<void> ThreadPool::enqueue(
// Case 3: We need to add a barrier before our task and add it to the
// smallest queue of the barriers.
auto min_queue = std::min_element(
barriers.begin(), barriers.end(), [this](auto left, auto right) {
barriers.begin(), barriers.end(), [this](int left, int right) {
return task_queues_[left].size() < task_queues_[right].size();
});
int worker_idx = *min_queue;
@ -109,7 +119,7 @@ void ThreadPool::add_outputs_to_worker(
std::lock_guard<std::mutex> lock(set_mutexes_[worker_idx]);
for (auto& a : outputs) {
output_sets_[worker_idx].insert(a.buffer().ptr());
output_sets_[worker_idx].insert(a.id());
}
}
@ -121,20 +131,20 @@ std::function<void()> ThreadPool::remove_outputs_when_done(
return task;
}
std::vector<const void*> output_buffers;
std::vector<std::uintptr_t> output_ids;
for (auto& a : outputs) {
output_buffers.push_back(a.buffer().ptr());
output_ids.push_back(a.id());
}
return [og_task = std::move(task),
buffers = std::move(output_buffers),
ids = std::move(output_ids),
worker_idx,
this]() {
og_task();
{
std::lock_guard<std::mutex> lock(set_mutexes_[worker_idx]);
for (auto b : buffers) {
output_sets_[worker_idx].erase(b);
for (auto id : ids) {
output_sets_[worker_idx].erase(id);
}
}
};
@ -195,7 +205,11 @@ void ThreadPool::worker(int idx) {
task = std::move(task_queues_[idx].front());
task_queues_[idx].pop();
}
task();
try {
task();
} catch (...) {
// do nothing?
}
}
}

View File

@ -43,10 +43,10 @@ class ThreadPool {
std::vector<std::queue<std::packaged_task<void()>>> task_queues_;
std::vector<std::mutex> queue_mutexes_;
std::vector<std::condition_variable> queue_cvs_;
std::vector<std::thread> workers_;
std::vector<std::mutex> set_mutexes_;
std::vector<std::unordered_set<const void*>> output_sets_;
std::vector<std::unordered_set<std::uintptr_t>> output_sets_;
bool stop_;
std::vector<std::thread> workers_;
};
} // namespace mlx::core::io::detail

View File

@ -2,6 +2,7 @@
#include "mlx/scheduler.h"
#include "mlx/backend/common/cpu_impl.h"
#include "mlx/backend/io/io_impl.h"
#include "mlx/backend/metal/metal.h"
namespace mlx::core {
@ -44,6 +45,9 @@ void synchronize(Stream s) {
case mlx::core::Device::gpu:
scheduler::enqueue(s, metal::make_synchronize_task(s, std::move(p)));
break;
case mlx::core::Device::io:
scheduler::enqueue(s, io::make_synchronize_task(s, std::move(p)));
break;
}
f.wait();
}

View File

@ -76,6 +76,7 @@ class Scheduler {
default_streams_.insert({Device::gpu, new_stream(Device::gpu)});
}
default_streams_.insert({Device::cpu, new_stream(Device::cpu)});
default_streams_.insert({Device::io, new_stream(Device::io)});
}
// Not copyable or moveable

View File

@ -28,6 +28,7 @@ class Synchronizer : public Primitive {
void eval_cpu(const std::vector<array>&, std::vector<array>&) override {};
void eval_gpu(const std::vector<array>&, std::vector<array>&) override {};
void eval_io(const std::vector<array>&, std::vector<array>&) override {};
DEFINE_PRINT(Synchronize);
};