[CUDA] Simplify allocator (#2392)

* simplify allocator and fixe race with small pool

* Don't use shared event in worker

* use cuda buffer in small pool

* comment

* comment
This commit is contained in:
Awni Hannun
2025-07-22 08:24:01 -07:00
committed by GitHub
parent 74eccbf3fa
commit 1e496ddb82
9 changed files with 100 additions and 162 deletions

View File

@@ -1,7 +1,6 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/worker.h"
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/device.h"
namespace mlx::core::cu {
@@ -12,10 +11,10 @@ Worker::Worker()
Worker::~Worker() {
{
std::lock_guard lock(worker_mutex_);
std::lock_guard lock(mtx_);
stop_ = true;
}
worker_event_.signal(batch_ + 1);
cond_.notify_one();
worker_.join();
}
@@ -23,53 +22,41 @@ void Worker::add_task(std::function<void()> task) {
pending_tasks_.push_back(std::move(task));
}
void Worker::consume_in_this_thread() {
for (auto& task : pending_tasks_) {
task();
}
pending_tasks_.clear();
}
void Worker::end_batch() {
batch_++;
void Worker::signal(void* data) {
auto w = static_cast<Worker*>(data);
{
std::lock_guard lock(worker_mutex_);
worker_tasks_[batch_] = std::move(pending_tasks_);
std::lock_guard lock(w->mtx_);
w->signaled_batch_++;
}
uncommited_batches_++;
}
void Worker::commit() {
if (uncommited_batches_ == 0) {
return;
}
uncommited_batches_ = 0;
worker_event_.signal(batch_);
w->cond_.notify_one();
}
void Worker::commit(cudaStream_t stream) {
if (uncommited_batches_ == 0) {
// Move pending tasks into tasks
if (pending_tasks_.empty()) {
return;
}
uncommited_batches_ = 0;
// Signal the |worker_event_| in |signal_stream_| after the kernels in
// |stream_| finish running.
{
std::lock_guard lock(mtx_);
// Move pending tasks into ready tasks
worker_tasks_[++committed_batch_] = std::move(pending_tasks_);
}
signal_event_.record(stream);
signal_event_.wait(signal_stream_);
worker_event_.signal(signal_stream_, batch_);
cudaLaunchHostFunc(signal_stream_, signal, this);
}
void Worker::thread_fn() {
// The worker thread is safe to free buffers.
allocator().register_this_thread();
while (!stop_) {
uint64_t batch = worker_event_.value();
uint64_t current_batch = 0;
Tasks tasks;
{
std::lock_guard lock(worker_mutex_);
// Move tasks in signaled batches.
auto end = worker_tasks_.upper_bound(batch);
std::unique_lock<std::mutex> lk(mtx_);
cond_.wait(lk, [this, &current_batch] {
return this->signaled_batch_ > current_batch || this->stop_;
});
current_batch = signaled_batch_;
auto end = worker_tasks_.upper_bound(current_batch);
for (auto it = worker_tasks_.begin(); it != end; ++it) {
if (tasks.empty()) {
tasks = std::move(it->second);
@@ -85,7 +72,6 @@ void Worker::thread_fn() {
auto task = std::move(tasks[i]);
task();
}
worker_event_.wait(batch + 1);
}
}