This commit is contained in:
Awni Hannun
2025-07-22 07:18:28 -07:00
parent 4fd39d662d
commit b1a44ef240
2 changed files with 8 additions and 14 deletions

View File

@@ -22,18 +22,15 @@ void Worker::add_task(std::function<void()> task) {
pending_tasks_.push_back(std::move(task)); pending_tasks_.push_back(std::move(task));
} }
void signal_worker(void* data) { void Worker::signal(void* data) {
auto w = static_cast<Worker*>(data); auto w = static_cast<Worker*>(data);
w->signal_(); {
std::lock_guard lock(w->mtx_);
w->signaled_batch_++;
}
w->cond_.notify_one();
} }
void Worker::signal_() {
{
std::lock_guard lock(mtx_);
signaled_batch_++;
}
cond_.notify_one();
}
void Worker::commit(cudaStream_t stream) { void Worker::commit(cudaStream_t stream) {
// Move pending tasks into tasks // Move pending tasks into tasks
@@ -47,7 +44,7 @@ void Worker::commit(cudaStream_t stream) {
} }
signal_event_.record(stream); signal_event_.record(stream);
signal_event_.wait(signal_stream_); signal_event_.wait(signal_stream_);
cudaLaunchHostFunc(signal_stream_, signal_worker, this); cudaLaunchHostFunc(signal_stream_, signal, this);
} }
void Worker::thread_fn() { void Worker::thread_fn() {

View File

@@ -13,8 +13,6 @@
namespace mlx::core::cu { namespace mlx::core::cu {
void signal_worker(void* data);
// Run tasks in worker thread, synchronized with cuda stream. // Run tasks in worker thread, synchronized with cuda stream.
class Worker { class Worker {
public: public:
@@ -32,9 +30,8 @@ class Worker {
void commit(cudaStream_t stream); void commit(cudaStream_t stream);
private: private:
friend void signal_worker(void*); static void signal(void*);
void signal_();
void thread_fn(); void thread_fn();
std::mutex mtx_; std::mutex mtx_;
std::condition_variable cond_; std::condition_variable cond_;