diff --git a/mlx/backend/metal/event.cpp b/mlx/backend/metal/event.cpp index 58d1521c9..246d6bcc5 100644 --- a/mlx/backend/metal/event.cpp +++ b/mlx/backend/metal/event.cpp @@ -24,10 +24,6 @@ void Event::wait() { } } -void Event::signal() { - static_cast(event_.get())->setSignaledValue(value()); -} - void Event::wait(Stream stream) { if (stream.device == Device::cpu) { scheduler::enqueue(stream, [*this]() mutable { wait(); }); @@ -42,7 +38,9 @@ void Event::wait(Stream stream) { void Event::signal(Stream stream) { if (stream.device == Device::cpu) { - scheduler::enqueue(stream, [*this]() mutable { signal(); }); + scheduler::enqueue(stream, [*this]() mutable { + static_cast(event_.get())->setSignaledValue(value()); + }); } else { auto& d = metal::device(stream.device); d.end_encoding(stream.index); diff --git a/mlx/backend/no_metal/event.cpp b/mlx/backend/no_metal/event.cpp index 692be20d2..6dde047ab 100644 --- a/mlx/backend/no_metal/event.cpp +++ b/mlx/backend/no_metal/event.cpp @@ -28,21 +28,19 @@ void Event::wait() { ec->cv.wait(lk, [value = value(), ec] { return ec->value >= value; }); } -void Event::signal() { - auto ec = static_cast(event_.get()); - { - std::lock_guard lk(ec->mtx); - ec->value = value(); - } - ec->cv.notify_all(); -} - void Event::wait(Stream stream) { scheduler::enqueue(stream, [*this]() mutable { wait(); }); } void Event::signal(Stream stream) { - scheduler::enqueue(stream, [*this]() mutable { signal(); }); + scheduler::enqueue(stream, [*this]() mutable { + auto ec = static_cast(event_.get()); + { + std::lock_guard lk(ec->mtx); + ec->value = value(); + } + ec->cv.notify_all(); + }); } bool Event::is_signaled() const { diff --git a/mlx/event.h b/mlx/event.h index 7054b60c7..937028e2a 100644 --- a/mlx/event.h +++ b/mlx/event.h @@ -16,9 +16,6 @@ class Event { // Wait for the event to be signaled at its current value void wait(); - // Signal the event at its current value - void signal(); - // Wait in the given stream for the event to be signaled at its current value void wait(Stream stream);