Add the io threadpool and task

This commit is contained in:
Angelos Katharopoulos 2024-05-08 15:45:44 -07:00
parent be36f136de
commit c8e2b42ced
7 changed files with 352 additions and 2 deletions

View File

@ -20,6 +20,7 @@ target_sources(
)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/io)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
if (MLX_BUILD_ACCELERATE)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)

View File

@ -12,8 +12,7 @@ std::function<void()> make_task(array arr, bool signal) {
// Wait on inputs coming from different streams/devices.
for (auto& input : arr.inputs()) {
if (input.event().valid() &&
input.event().stream() != arr.primitive().stream()) {
if (input.event().valid() && input.event().stream() != stream) {
input.event().wait();
}
}

View File

@ -0,0 +1,6 @@
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/io_impl.cpp
${CMAKE_CURRENT_SOURCE_DIR}/thread_pool.cpp
)

View File

@ -0,0 +1,72 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/io/io_impl.h"
#include "mlx/backend/io/thread_pool.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"
namespace mlx::core::io {
namespace {
detail::ThreadPool& thread_pool() {
static std::unique_ptr<detail::ThreadPool> pool_ptr;
if (pool_ptr == nullptr) {
pool_ptr = std::make_unique<detail::ThreadPool>(4);
}
return *pool_ptr;
}
} // namespace
std::function<void()> make_task(array arr, bool signal) {
return [arr = std::move(arr), signal]() mutable {
auto stream = arr.primitive().stream();
// Wait on inputs coming from different streams/devices.
for (auto& input : arr.inputs()) {
if (input.event().valid() && input.event().stream() != stream) {
input.event().wait();
}
}
// Task computation actually starting.
scheduler::notify_new_task(stream);
// Schedule the computation
auto inputs = arr.inputs();
auto outputs = arr.outputs();
thread_pool().enqueue(
[arr = std::move(arr), inputs, outputs, signal, stream]() mutable {
// Perform the computation
arr.primitive().eval_io(inputs, outputs);
if (!arr.is_tracer()) {
arr.detach();
}
if (signal) {
thread_pool().barrier(
[arr = std::move(arr)]() { arr.event().signal(); });
}
// Task computation done.
scheduler::notify_task_completion(stream);
},
inputs,
outputs);
};
}
std::function<void()> make_synchronize_task(
Stream s,
std::shared_ptr<std::promise<void>> p) {
return [p = std::move(p)]() {
thread_pool().barrier().wait();
p->set_value();
};
}
} // namespace mlx::core::io

18
mlx/backend/io/io_impl.h Normal file
View File

@ -0,0 +1,18 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include <functional>
#include <future>
#include <memory>
#include "mlx/array.h"
namespace mlx::core::io {
std::function<void()> make_task(array arr, bool signal);
std::function<void()> make_synchronize_task(
Stream s,
std::shared_ptr<std::promise<void>> p);
} // namespace mlx::core::io

View File

@ -0,0 +1,202 @@
// Copyright © 2024 Apple Inc.
#include <numeric>
#include "mlx/backend/io/thread_pool.h"
namespace mlx::core::io::detail {
ThreadPool::ThreadPool(int workers) : stop_(false) {
for (int i = 0; i < workers; i++) {
workers_.emplace_back(&ThreadPool::worker, this, i);
}
}
ThreadPool::~ThreadPool() {
stop_ = true;
for (auto& cv : queue_cvs_) {
cv.notify_one();
}
for (auto& t : workers_) {
if (t.joinable()) {
t.join();
}
}
}
std::future<void> ThreadPool::enqueue(
std::function<void()> task,
const std::vector<array>& inputs,
const std::vector<array>& outputs) {
std::vector<int> barriers;
for (int i = 0; i < output_sets_.size(); i++) {
std::lock_guard<std::mutex> lock(set_mutexes_[i]);
for (auto& a : inputs) {
if (output_sets_[i].find(a.buffer().ptr()) != output_sets_[i].end()) {
barriers.push_back(i);
break;
}
}
}
// Case 1: Barriers is empty so try to add it to the smallest queue
if (barriers.empty()) {
auto min_queue = std::min_element(
task_queues_.begin(), task_queues_.end(), [](auto& left, auto& right) {
return left.size() < right.size();
});
int worker_idx = std::distance(task_queues_.begin(), min_queue);
add_outputs_to_worker(outputs, worker_idx);
return enqueue(
remove_outputs_when_done(std::move(task), outputs, worker_idx),
worker_idx);
}
// Case 2: Barriers has only one queue so put that into that queue
if (barriers.size() == 1) {
int worker_idx = barriers[0];
add_outputs_to_worker(outputs, worker_idx);
return enqueue(
remove_outputs_when_done(std::move(task), outputs, worker_idx),
worker_idx);
}
// Case 3: We need to add a barrier before our task and add it to the
// smallest queue of the barriers.
auto min_queue = std::min_element(
barriers.begin(), barriers.end(), [this](auto left, auto right) {
return task_queues_[left].size() < task_queues_[right].size();
});
int worker_idx = *min_queue;
barriers.erase(min_queue);
std::shared_future<void> queue_barrier =
barrier(barriers); // We shouldn't need shared future here
add_outputs_to_worker(outputs, worker_idx);
return enqueue(
remove_outputs_when_done(
[queue_barrier = std::move(queue_barrier),
og_task = std::move(task)]() {
queue_barrier.wait();
og_task();
},
outputs,
worker_idx),
worker_idx);
}
std::future<void> ThreadPool::enqueue(
std::function<void()> task,
int worker_idx) {
std::packaged_task<void()> pt(std::move(task));
std::future<void> result = pt.get_future();
{
std::lock_guard<std::mutex> lock(queue_mutexes_[worker_idx]);
task_queues_[worker_idx].emplace(std::move(pt));
}
queue_cvs_[worker_idx].notify_one();
return result;
}
void ThreadPool::add_outputs_to_worker(
const std::vector<array>& outputs,
int worker_idx) {
if (outputs.size() == 0) {
return;
}
std::lock_guard<std::mutex> lock(set_mutexes_[worker_idx]);
for (auto& a : outputs) {
output_sets_[worker_idx].insert(a.buffer().ptr());
}
}
std::function<void()> ThreadPool::remove_outputs_when_done(
std::function<void()> task,
const std::vector<array>& outputs,
int worker_idx) {
if (outputs.size() == 0) {
return task;
}
std::vector<const void*> output_buffers;
for (auto& a : outputs) {
output_buffers.push_back(a.buffer().ptr());
}
return [og_task = std::move(task),
buffers = std::move(output_buffers),
worker_idx,
this]() {
og_task();
{
std::lock_guard<std::mutex> lock(set_mutexes_[worker_idx]);
for (auto b : buffers) {
output_sets_[worker_idx].erase(b);
}
}
};
}
std::future<void> ThreadPool::barrier(
const std::vector<int>& worker_ids,
std::function<void()> on_barrier) {
auto workers = std::make_shared<std::atomic<int>>(worker_ids.size());
auto promise = std::make_shared<std::promise<void>>();
auto future = promise->get_future();
for (auto idx : worker_ids) {
enqueue(
[workers, promise, on_barrier = std::move(on_barrier)]() {
(*workers)--;
if (*workers <= 0) {
on_barrier();
promise->set_value();
}
},
idx);
}
return future;
}
std::future<void> ThreadPool::barrier(const std::vector<int>& worker_ids) {
auto noop = []() {};
return barrier(worker_ids, std::move(noop));
}
std::future<void> ThreadPool::barrier(std::function<void()> on_barrier) {
std::vector<int> worker_ids(workers_.size());
std::iota(worker_ids.begin(), worker_ids.end(), 0);
return barrier(worker_ids, std::move(on_barrier));
}
std::future<void> ThreadPool::barrier() {
auto noop = []() {};
return barrier(std::move(noop));
}
void ThreadPool::worker(int idx) {
while (true) {
std::packaged_task<void()> task;
{
std::unique_lock<std::mutex> lock(queue_mutexes_[idx]);
queue_cvs_[idx].wait(
lock, [this, idx]() { return stop_ || !task_queues_[idx].empty(); });
if (task_queues_[idx].empty()) {
if (stop_) {
break;
} else {
continue;
}
}
task = std::move(task_queues_[idx].front());
task_queues_[idx].pop();
}
task();
}
}
} // namespace mlx::core::io::detail

View File

@ -0,0 +1,52 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include <future>
#include <queue>
#include <unordered_set>
#include "mlx/array.h"
namespace mlx::core::io::detail {
class ThreadPool {
public:
explicit ThreadPool(int workers);
~ThreadPool();
ThreadPool(ThreadPool&&) = delete;
ThreadPool(const ThreadPool&) = delete;
ThreadPool& operator=(ThreadPool&&) = delete;
ThreadPool& operator=(const ThreadPool&) = delete;
std::future<void> enqueue(
std::function<void()> task,
const std::vector<array>& inputs,
const std::vector<array>& outputs);
std::future<void> barrier(
const std::vector<int>& worker_ids,
std::function<void()> on_barrier);
std::future<void> barrier(const std::vector<int>& worker_ids);
std::future<void> barrier(std::function<void()> on_barrier);
std::future<void> barrier();
private:
std::future<void> enqueue(std::function<void()> task, int worker_idx);
void add_outputs_to_worker(const std::vector<array>& outputs, int worker_idx);
std::function<void()> remove_outputs_when_done(
std::function<void()> task,
const std::vector<array>& outputs,
int worker_idx);
void worker(int idx);
std::vector<std::queue<std::packaged_task<void()>>> task_queues_;
std::vector<std::mutex> queue_mutexes_;
std::vector<std::condition_variable> queue_cvs_;
std::vector<std::thread> workers_;
std::vector<std::mutex> set_mutexes_;
std::vector<std::unordered_set<const void*>> output_sets_;
bool stop_;
};
} // namespace mlx::core::io::detail