diff --git a/mlx/backend/cuda/worker.cpp b/mlx/backend/cuda/worker.cpp index ee5c664e6..e6718b4df 100644 --- a/mlx/backend/cuda/worker.cpp +++ b/mlx/backend/cuda/worker.cpp @@ -22,18 +22,15 @@ void Worker::add_task(std::function task) { pending_tasks_.push_back(std::move(task)); } -void signal_worker(void* data) { +void Worker::signal(void* data) { auto w = static_cast(data); - w->signal_(); + { + std::lock_guard lock(w->mtx_); + w->signaled_batch_++; + } + w->cond_.notify_one(); } -void Worker::signal_() { - { - std::lock_guard lock(mtx_); - signaled_batch_++; - } - cond_.notify_one(); -} void Worker::commit(cudaStream_t stream) { // Move pending tasks into tasks @@ -47,7 +44,7 @@ void Worker::commit(cudaStream_t stream) { } signal_event_.record(stream); signal_event_.wait(signal_stream_); - cudaLaunchHostFunc(signal_stream_, signal_worker, this); + cudaLaunchHostFunc(signal_stream_, signal, this); } void Worker::thread_fn() { diff --git a/mlx/backend/cuda/worker.h b/mlx/backend/cuda/worker.h index 9e6b7b5f4..df6647e2b 100644 --- a/mlx/backend/cuda/worker.h +++ b/mlx/backend/cuda/worker.h @@ -13,8 +13,6 @@ namespace mlx::core::cu { -void signal_worker(void* data); - // Run tasks in worker thread, synchronized with cuda stream. class Worker { public: @@ -32,9 +30,8 @@ class Worker { void commit(cudaStream_t stream); private: - friend void signal_worker(void*); + static void signal(void*); - void signal_(); void thread_fn(); std::mutex mtx_; std::condition_variable cond_;