mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Working IO primitives
This commit is contained in:
parent
b193741050
commit
bae159738f
@ -6,7 +6,13 @@
|
|||||||
|
|
||||||
namespace mlx::core::io::detail {
|
namespace mlx::core::io::detail {
|
||||||
|
|
||||||
ThreadPool::ThreadPool(int workers) : stop_(false) {
|
ThreadPool::ThreadPool(int workers)
|
||||||
|
: task_queues_(workers),
|
||||||
|
queue_mutexes_(workers),
|
||||||
|
queue_cvs_(workers),
|
||||||
|
set_mutexes_(workers),
|
||||||
|
output_sets_(workers),
|
||||||
|
stop_(false) {
|
||||||
for (int i = 0; i < workers; i++) {
|
for (int i = 0; i < workers; i++) {
|
||||||
workers_.emplace_back(&ThreadPool::worker, this, i);
|
workers_.emplace_back(&ThreadPool::worker, this, i);
|
||||||
}
|
}
|
||||||
@ -30,13 +36,15 @@ std::future<void> ThreadPool::enqueue(
|
|||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
const std::vector<array>& outputs) {
|
const std::vector<array>& outputs) {
|
||||||
std::vector<int> barriers;
|
std::vector<int> barriers;
|
||||||
for (int i = 0; i < output_sets_.size(); i++) {
|
if (!inputs.empty()) {
|
||||||
std::lock_guard<std::mutex> lock(set_mutexes_[i]);
|
for (int i = 0; i < output_sets_.size(); i++) {
|
||||||
|
std::lock_guard<std::mutex> lock(set_mutexes_[i]);
|
||||||
|
|
||||||
for (auto& a : inputs) {
|
for (auto& a : inputs) {
|
||||||
if (output_sets_[i].find(a.buffer().ptr()) != output_sets_[i].end()) {
|
if (output_sets_[i].find(a.id()) != output_sets_[i].end()) {
|
||||||
barriers.push_back(i);
|
barriers.push_back(i);
|
||||||
break;
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -44,7 +52,9 @@ std::future<void> ThreadPool::enqueue(
|
|||||||
// Case 1: Barriers is empty so try to add it to the smallest queue
|
// Case 1: Barriers is empty so try to add it to the smallest queue
|
||||||
if (barriers.empty()) {
|
if (barriers.empty()) {
|
||||||
auto min_queue = std::min_element(
|
auto min_queue = std::min_element(
|
||||||
task_queues_.begin(), task_queues_.end(), [](auto& left, auto& right) {
|
task_queues_.begin(),
|
||||||
|
task_queues_.end(),
|
||||||
|
[](const auto& left, const auto& right) {
|
||||||
return left.size() < right.size();
|
return left.size() < right.size();
|
||||||
});
|
});
|
||||||
int worker_idx = std::distance(task_queues_.begin(), min_queue);
|
int worker_idx = std::distance(task_queues_.begin(), min_queue);
|
||||||
@ -67,7 +77,7 @@ std::future<void> ThreadPool::enqueue(
|
|||||||
// Case 3: We need to add a barrier before our task and add it to the
|
// Case 3: We need to add a barrier before our task and add it to the
|
||||||
// smallest queue of the barriers.
|
// smallest queue of the barriers.
|
||||||
auto min_queue = std::min_element(
|
auto min_queue = std::min_element(
|
||||||
barriers.begin(), barriers.end(), [this](auto left, auto right) {
|
barriers.begin(), barriers.end(), [this](int left, int right) {
|
||||||
return task_queues_[left].size() < task_queues_[right].size();
|
return task_queues_[left].size() < task_queues_[right].size();
|
||||||
});
|
});
|
||||||
int worker_idx = *min_queue;
|
int worker_idx = *min_queue;
|
||||||
@ -109,7 +119,7 @@ void ThreadPool::add_outputs_to_worker(
|
|||||||
|
|
||||||
std::lock_guard<std::mutex> lock(set_mutexes_[worker_idx]);
|
std::lock_guard<std::mutex> lock(set_mutexes_[worker_idx]);
|
||||||
for (auto& a : outputs) {
|
for (auto& a : outputs) {
|
||||||
output_sets_[worker_idx].insert(a.buffer().ptr());
|
output_sets_[worker_idx].insert(a.id());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -121,20 +131,20 @@ std::function<void()> ThreadPool::remove_outputs_when_done(
|
|||||||
return task;
|
return task;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<const void*> output_buffers;
|
std::vector<std::uintptr_t> output_ids;
|
||||||
for (auto& a : outputs) {
|
for (auto& a : outputs) {
|
||||||
output_buffers.push_back(a.buffer().ptr());
|
output_ids.push_back(a.id());
|
||||||
}
|
}
|
||||||
|
|
||||||
return [og_task = std::move(task),
|
return [og_task = std::move(task),
|
||||||
buffers = std::move(output_buffers),
|
ids = std::move(output_ids),
|
||||||
worker_idx,
|
worker_idx,
|
||||||
this]() {
|
this]() {
|
||||||
og_task();
|
og_task();
|
||||||
{
|
{
|
||||||
std::lock_guard<std::mutex> lock(set_mutexes_[worker_idx]);
|
std::lock_guard<std::mutex> lock(set_mutexes_[worker_idx]);
|
||||||
for (auto b : buffers) {
|
for (auto id : ids) {
|
||||||
output_sets_[worker_idx].erase(b);
|
output_sets_[worker_idx].erase(id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -195,7 +205,11 @@ void ThreadPool::worker(int idx) {
|
|||||||
task = std::move(task_queues_[idx].front());
|
task = std::move(task_queues_[idx].front());
|
||||||
task_queues_[idx].pop();
|
task_queues_[idx].pop();
|
||||||
}
|
}
|
||||||
task();
|
try {
|
||||||
|
task();
|
||||||
|
} catch (...) {
|
||||||
|
// do nothing?
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -43,10 +43,10 @@ class ThreadPool {
|
|||||||
std::vector<std::queue<std::packaged_task<void()>>> task_queues_;
|
std::vector<std::queue<std::packaged_task<void()>>> task_queues_;
|
||||||
std::vector<std::mutex> queue_mutexes_;
|
std::vector<std::mutex> queue_mutexes_;
|
||||||
std::vector<std::condition_variable> queue_cvs_;
|
std::vector<std::condition_variable> queue_cvs_;
|
||||||
std::vector<std::thread> workers_;
|
|
||||||
std::vector<std::mutex> set_mutexes_;
|
std::vector<std::mutex> set_mutexes_;
|
||||||
std::vector<std::unordered_set<const void*>> output_sets_;
|
std::vector<std::unordered_set<std::uintptr_t>> output_sets_;
|
||||||
bool stop_;
|
bool stop_;
|
||||||
|
std::vector<std::thread> workers_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace mlx::core::io::detail
|
} // namespace mlx::core::io::detail
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
#include "mlx/scheduler.h"
|
#include "mlx/scheduler.h"
|
||||||
#include "mlx/backend/common/cpu_impl.h"
|
#include "mlx/backend/common/cpu_impl.h"
|
||||||
|
#include "mlx/backend/io/io_impl.h"
|
||||||
#include "mlx/backend/metal/metal.h"
|
#include "mlx/backend/metal/metal.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
@ -44,6 +45,9 @@ void synchronize(Stream s) {
|
|||||||
case mlx::core::Device::gpu:
|
case mlx::core::Device::gpu:
|
||||||
scheduler::enqueue(s, metal::make_synchronize_task(s, std::move(p)));
|
scheduler::enqueue(s, metal::make_synchronize_task(s, std::move(p)));
|
||||||
break;
|
break;
|
||||||
|
case mlx::core::Device::io:
|
||||||
|
scheduler::enqueue(s, io::make_synchronize_task(s, std::move(p)));
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
f.wait();
|
f.wait();
|
||||||
}
|
}
|
||||||
|
@ -76,6 +76,7 @@ class Scheduler {
|
|||||||
default_streams_.insert({Device::gpu, new_stream(Device::gpu)});
|
default_streams_.insert({Device::gpu, new_stream(Device::gpu)});
|
||||||
}
|
}
|
||||||
default_streams_.insert({Device::cpu, new_stream(Device::cpu)});
|
default_streams_.insert({Device::cpu, new_stream(Device::cpu)});
|
||||||
|
default_streams_.insert({Device::io, new_stream(Device::io)});
|
||||||
}
|
}
|
||||||
|
|
||||||
// Not copyable or moveable
|
// Not copyable or moveable
|
||||||
|
@ -28,6 +28,7 @@ class Synchronizer : public Primitive {
|
|||||||
|
|
||||||
void eval_cpu(const std::vector<array>&, std::vector<array>&) override {};
|
void eval_cpu(const std::vector<array>&, std::vector<array>&) override {};
|
||||||
void eval_gpu(const std::vector<array>&, std::vector<array>&) override {};
|
void eval_gpu(const std::vector<array>&, std::vector<array>&) override {};
|
||||||
|
void eval_io(const std::vector<array>&, std::vector<array>&) override {};
|
||||||
|
|
||||||
DEFINE_PRINT(Synchronize);
|
DEFINE_PRINT(Synchronize);
|
||||||
};
|
};
|
||||||
|
Loading…
Reference in New Issue
Block a user