This commit is contained in:
Awni Hannun
2025-07-22 07:18:28 -07:00
parent 4fd39d662d
commit b1a44ef240
2 changed files with 8 additions and 14 deletions

View File

@@ -22,18 +22,15 @@ void Worker::add_task(std::function<void()> task) {
pending_tasks_.push_back(std::move(task));
}
void signal_worker(void* data) {
void Worker::signal(void* data) {
auto w = static_cast<Worker*>(data);
w->signal_();
{
std::lock_guard lock(w->mtx_);
w->signaled_batch_++;
}
w->cond_.notify_one();
}
void Worker::signal_() {
{
std::lock_guard lock(mtx_);
signaled_batch_++;
}
cond_.notify_one();
}
void Worker::commit(cudaStream_t stream) {
// Move pending tasks into tasks
@@ -47,7 +44,7 @@ void Worker::commit(cudaStream_t stream) {
}
signal_event_.record(stream);
signal_event_.wait(signal_stream_);
cudaLaunchHostFunc(signal_stream_, signal_worker, this);
cudaLaunchHostFunc(signal_stream_, signal, this);
}
void Worker::thread_fn() {

View File

@@ -13,8 +13,6 @@
namespace mlx::core::cu {
void signal_worker(void* data);
// Run tasks in worker thread, synchronized with cuda stream.
class Worker {
public:
@@ -32,9 +30,8 @@ class Worker {
void commit(cudaStream_t stream);
private:
friend void signal_worker(void*);
static void signal(void*);
void signal_();
void thread_fn();
std::mutex mtx_;
std::condition_variable cond_;