mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Working IO primitives
This commit is contained in:
parent
b193741050
commit
bae159738f
@ -6,7 +6,13 @@
|
||||
|
||||
namespace mlx::core::io::detail {
|
||||
|
||||
ThreadPool::ThreadPool(int workers) : stop_(false) {
|
||||
ThreadPool::ThreadPool(int workers)
|
||||
: task_queues_(workers),
|
||||
queue_mutexes_(workers),
|
||||
queue_cvs_(workers),
|
||||
set_mutexes_(workers),
|
||||
output_sets_(workers),
|
||||
stop_(false) {
|
||||
for (int i = 0; i < workers; i++) {
|
||||
workers_.emplace_back(&ThreadPool::worker, this, i);
|
||||
}
|
||||
@ -30,13 +36,15 @@ std::future<void> ThreadPool::enqueue(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs) {
|
||||
std::vector<int> barriers;
|
||||
for (int i = 0; i < output_sets_.size(); i++) {
|
||||
std::lock_guard<std::mutex> lock(set_mutexes_[i]);
|
||||
if (!inputs.empty()) {
|
||||
for (int i = 0; i < output_sets_.size(); i++) {
|
||||
std::lock_guard<std::mutex> lock(set_mutexes_[i]);
|
||||
|
||||
for (auto& a : inputs) {
|
||||
if (output_sets_[i].find(a.buffer().ptr()) != output_sets_[i].end()) {
|
||||
barriers.push_back(i);
|
||||
break;
|
||||
for (auto& a : inputs) {
|
||||
if (output_sets_[i].find(a.id()) != output_sets_[i].end()) {
|
||||
barriers.push_back(i);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -44,7 +52,9 @@ std::future<void> ThreadPool::enqueue(
|
||||
// Case 1: Barriers is empty so try to add it to the smallest queue
|
||||
if (barriers.empty()) {
|
||||
auto min_queue = std::min_element(
|
||||
task_queues_.begin(), task_queues_.end(), [](auto& left, auto& right) {
|
||||
task_queues_.begin(),
|
||||
task_queues_.end(),
|
||||
[](const auto& left, const auto& right) {
|
||||
return left.size() < right.size();
|
||||
});
|
||||
int worker_idx = std::distance(task_queues_.begin(), min_queue);
|
||||
@ -67,7 +77,7 @@ std::future<void> ThreadPool::enqueue(
|
||||
// Case 3: We need to add a barrier before our task and add it to the
|
||||
// smallest queue of the barriers.
|
||||
auto min_queue = std::min_element(
|
||||
barriers.begin(), barriers.end(), [this](auto left, auto right) {
|
||||
barriers.begin(), barriers.end(), [this](int left, int right) {
|
||||
return task_queues_[left].size() < task_queues_[right].size();
|
||||
});
|
||||
int worker_idx = *min_queue;
|
||||
@ -109,7 +119,7 @@ void ThreadPool::add_outputs_to_worker(
|
||||
|
||||
std::lock_guard<std::mutex> lock(set_mutexes_[worker_idx]);
|
||||
for (auto& a : outputs) {
|
||||
output_sets_[worker_idx].insert(a.buffer().ptr());
|
||||
output_sets_[worker_idx].insert(a.id());
|
||||
}
|
||||
}
|
||||
|
||||
@ -121,20 +131,20 @@ std::function<void()> ThreadPool::remove_outputs_when_done(
|
||||
return task;
|
||||
}
|
||||
|
||||
std::vector<const void*> output_buffers;
|
||||
std::vector<std::uintptr_t> output_ids;
|
||||
for (auto& a : outputs) {
|
||||
output_buffers.push_back(a.buffer().ptr());
|
||||
output_ids.push_back(a.id());
|
||||
}
|
||||
|
||||
return [og_task = std::move(task),
|
||||
buffers = std::move(output_buffers),
|
||||
ids = std::move(output_ids),
|
||||
worker_idx,
|
||||
this]() {
|
||||
og_task();
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(set_mutexes_[worker_idx]);
|
||||
for (auto b : buffers) {
|
||||
output_sets_[worker_idx].erase(b);
|
||||
for (auto id : ids) {
|
||||
output_sets_[worker_idx].erase(id);
|
||||
}
|
||||
}
|
||||
};
|
||||
@ -195,7 +205,11 @@ void ThreadPool::worker(int idx) {
|
||||
task = std::move(task_queues_[idx].front());
|
||||
task_queues_[idx].pop();
|
||||
}
|
||||
task();
|
||||
try {
|
||||
task();
|
||||
} catch (...) {
|
||||
// do nothing?
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -43,10 +43,10 @@ class ThreadPool {
|
||||
std::vector<std::queue<std::packaged_task<void()>>> task_queues_;
|
||||
std::vector<std::mutex> queue_mutexes_;
|
||||
std::vector<std::condition_variable> queue_cvs_;
|
||||
std::vector<std::thread> workers_;
|
||||
std::vector<std::mutex> set_mutexes_;
|
||||
std::vector<std::unordered_set<const void*>> output_sets_;
|
||||
std::vector<std::unordered_set<std::uintptr_t>> output_sets_;
|
||||
bool stop_;
|
||||
std::vector<std::thread> workers_;
|
||||
};
|
||||
|
||||
} // namespace mlx::core::io::detail
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
#include "mlx/scheduler.h"
|
||||
#include "mlx/backend/common/cpu_impl.h"
|
||||
#include "mlx/backend/io/io_impl.h"
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
|
||||
namespace mlx::core {
|
||||
@ -44,6 +45,9 @@ void synchronize(Stream s) {
|
||||
case mlx::core::Device::gpu:
|
||||
scheduler::enqueue(s, metal::make_synchronize_task(s, std::move(p)));
|
||||
break;
|
||||
case mlx::core::Device::io:
|
||||
scheduler::enqueue(s, io::make_synchronize_task(s, std::move(p)));
|
||||
break;
|
||||
}
|
||||
f.wait();
|
||||
}
|
||||
|
@ -76,6 +76,7 @@ class Scheduler {
|
||||
default_streams_.insert({Device::gpu, new_stream(Device::gpu)});
|
||||
}
|
||||
default_streams_.insert({Device::cpu, new_stream(Device::cpu)});
|
||||
default_streams_.insert({Device::io, new_stream(Device::io)});
|
||||
}
|
||||
|
||||
// Not copyable or moveable
|
||||
|
@ -28,6 +28,7 @@ class Synchronizer : public Primitive {
|
||||
|
||||
void eval_cpu(const std::vector<array>&, std::vector<array>&) override {};
|
||||
void eval_gpu(const std::vector<array>&, std::vector<array>&) override {};
|
||||
void eval_io(const std::vector<array>&, std::vector<array>&) override {};
|
||||
|
||||
DEFINE_PRINT(Synchronize);
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user