Don't use shared event in worker

This commit is contained in:
Awni Hannun
2025-07-19 13:35:57 -07:00
parent b62368f292
commit 60e20bedb6
4 changed files with 41 additions and 33 deletions

View File

@@ -30,6 +30,9 @@ SmallSizePool::SmallSizePool() {
reinterpret_cast<char*>(buffer_) + small_pool_size); reinterpret_cast<char*>(buffer_) + small_pool_size);
next_free_ = reinterpret_cast<Block*>(buffer_); next_free_ = reinterpret_cast<Block*>(buffer_);
CHECK_CUDA_ERROR(
cudaMemAdvise(buffer_, small_pool_size, cudaMemAdviseSetReadMostly, 0));
auto num_blocks = small_pool_size / small_block_size; auto num_blocks = small_pool_size / small_block_size;
auto curr = next_free_; auto curr = next_free_;
for (size_t i = 0; i < num_blocks - 1; ++i) { for (size_t i = 0; i < num_blocks - 1; ++i) {

View File

@@ -306,7 +306,6 @@ void CommandEncoder::commit() {
} }
// Put completion handlers in a batch. // Put completion handlers in a batch.
worker_.end_batch();
worker_.commit(stream_); worker_.commit(stream_);
} }

View File

@@ -11,10 +11,10 @@ Worker::Worker()
Worker::~Worker() { Worker::~Worker() {
{ {
std::lock_guard lock(worker_mutex_); std::lock_guard lock(mtx_);
stop_ = true; stop_ = true;
} }
worker_event_.signal(batch_ + 1); cond_.notify_one();
worker_.join(); worker_.join();
} }
@@ -22,35 +22,45 @@ void Worker::add_task(std::function<void()> task) {
pending_tasks_.push_back(std::move(task)); pending_tasks_.push_back(std::move(task));
} }
void Worker::end_batch() { void signal_worker(void* data) {
batch_++; auto w = static_cast<Worker*>(data);
{ w->signal_();
std::lock_guard lock(worker_mutex_);
worker_tasks_[batch_] = std::move(pending_tasks_);
} }
uncommited_batches_++;
void Worker::signal_() {
{
std::lock_guard lock(mtx_);
signaled_batch_++;
}
cond_.notify_one();
} }
void Worker::commit(cudaStream_t stream) { void Worker::commit(cudaStream_t stream) {
if (uncommited_batches_ == 0) { // Move pending tasks into tasks
if (pending_tasks_.empty()) {
return; return;
} }
uncommited_batches_ = 0; {
// Signal the |worker_event_| in |signal_stream_| after the kernels in std::lock_guard lock(mtx_);
// |stream_| finish running. // Move pending tasks into ready tasks
worker_tasks_[++committed_batch_] = std::move(pending_tasks_);
}
signal_event_.record(stream); signal_event_.record(stream);
signal_event_.wait(signal_stream_); signal_event_.wait(signal_stream_);
worker_event_.signal(signal_stream_, batch_); cudaLaunchHostFunc(signal_stream_, signal_worker, this);
} }
void Worker::thread_fn() { void Worker::thread_fn() {
while (!stop_) { while (!stop_) {
uint64_t batch = worker_event_.value(); uint64_t current_batch = 0;
Tasks tasks; Tasks tasks;
{ {
std::lock_guard lock(worker_mutex_); std::unique_lock<std::mutex> lk(mtx_);
// Move tasks in signaled batches. cond_.wait(lk, [this, &current_batch] {
auto end = worker_tasks_.upper_bound(batch); return this->signaled_batch_ > current_batch || this->stop_;
});
current_batch = signaled_batch_;
auto end = worker_tasks_.upper_bound(current_batch);
for (auto it = worker_tasks_.begin(); it != end; ++it) { for (auto it = worker_tasks_.begin(); it != end; ++it) {
if (tasks.empty()) { if (tasks.empty()) {
tasks = std::move(it->second); tasks = std::move(it->second);
@@ -66,7 +76,6 @@ void Worker::thread_fn() {
auto task = std::move(tasks[i]); auto task = std::move(tasks[i]);
task(); task();
} }
worker_event_.wait(batch + 1);
} }
} }

View File

@@ -5,6 +5,7 @@
#include "mlx/backend/cuda/event.h" #include "mlx/backend/cuda/event.h"
#include "mlx/backend/cuda/utils.h" #include "mlx/backend/cuda/utils.h"
#include <condition_variable>
#include <functional> #include <functional>
#include <map> #include <map>
#include <mutex> #include <mutex>
@@ -12,6 +13,8 @@
namespace mlx::core::cu { namespace mlx::core::cu {
void signal_worker(void* data);
// Run tasks in worker thread, synchronized with cuda stream. // Run tasks in worker thread, synchronized with cuda stream.
class Worker { class Worker {
public: public:
@@ -24,32 +27,25 @@ class Worker {
// Add a pending |task| that will run when consumed or commited. // Add a pending |task| that will run when consumed or commited.
void add_task(std::function<void()> task); void add_task(std::function<void()> task);
// Put pending tasks in a batch.
void end_batch();
// Inform worker thread to run current batches after kernels in |stream| // Inform worker thread to run current batches after kernels in |stream|
// finish running. // finish running.
void commit(cudaStream_t stream); void commit(cudaStream_t stream);
// Return how many batches have been added but not committed yet.
size_t uncommited_batches() const {
return uncommited_batches_;
}
private: private:
void thread_fn(); friend void signal_worker(void*);
uint64_t batch_{0}; void signal_();
size_t uncommited_batches_{0}; void thread_fn();
std::mutex mtx_;
std::condition_variable cond_;
uint64_t committed_batch_{0};
uint64_t signaled_batch_{0};
// Cuda stream and event for signaling kernel completion. // Cuda stream and event for signaling kernel completion.
CudaStream signal_stream_; CudaStream signal_stream_;
CudaEvent signal_event_; CudaEvent signal_event_;
// Worker thread.
SharedEvent worker_event_;
std::thread worker_;
std::mutex worker_mutex_;
bool stop_{false}; bool stop_{false};
// Tasks are put in |pending_tasks_| first, and then moved to // Tasks are put in |pending_tasks_| first, and then moved to
@@ -57,6 +53,7 @@ class Worker {
using Tasks = std::vector<std::function<void()>>; using Tasks = std::vector<std::function<void()>>;
Tasks pending_tasks_; Tasks pending_tasks_;
std::map<uint64_t, Tasks> worker_tasks_; std::map<uint64_t, Tasks> worker_tasks_;
std::thread worker_;
}; };
} // namespace mlx::core::cu } // namespace mlx::core::cu