diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index 5949a40bb..567d0567c 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include @@ -35,10 +36,41 @@ class Synchronizer : public Primitive { DEFINE_PRINT(Synchronize); }; -std::atomic& interrupt_flag() { - static std::atomic interrupt_{false}; - return interrupt_; -} +class Interrupt { + private: + static std::mutex mutex_; + static bool eval_running_; + static bool interrupt_; + + public: + Interrupt() { + std::unique_lock lk(mutex_); + eval_running_ = true; + } + + static bool interrupt() { + std::unique_lock lk(mutex_); + if (eval_running_) { + interrupt_ = true; + return true; + } + return false; + } + + static bool interrupted() { + std::unique_lock lk(mutex_); + return interrupt_; + } + + ~Interrupt() { + std::unique_lock lk(mutex_); + eval_running_ = false; + interrupt_ = false; + } +}; +std::mutex Interrupt::mutex_{}; +bool Interrupt::eval_running_ = false; +bool Interrupt::interrupt_ = false; // Initialize the static tracing members from transforms_impl.h // @@ -50,6 +82,8 @@ std::vector detail::InTracing::trace_stack{}; int detail::RetainGraph::tracing_counter{0}; array eval_impl(std::vector outputs, bool async) { + Interrupt interrupt; + std::deque tape; // Make an effort to choose a good output stream @@ -255,8 +289,7 @@ array eval_impl(std::vector outputs, bool async) { arr.detach(); } - if (interrupt_flag()) { - interrupt_flag() = false; + if (Interrupt::interrupted()) { synchronizer.attach_event(Event{stream}); break; } @@ -274,8 +307,8 @@ array eval_impl(std::vector outputs, bool async) { return synchronizer; } -void interrupt_eval() { - interrupt_flag() = true; +bool interrupt_eval() { + return Interrupt::interrupt(); } void async_eval(std::vector outputs) { diff --git a/mlx/transforms.h b/mlx/transforms.h index 89b37c720..2cde23f9b 100644 --- a/mlx/transforms.h +++ b/mlx/transforms.h @@ -23,9 +23,12 @@ void eval(Arrays&&... outputs) { } /** - * Interrupt an ongoing eval. Leaves the graph in a valid state. + * Interrupt an ongoing eval. + * + * Leaves the graph in a valid state. Returns true if an ongoing eval was + * interrupted and false otherwise. */ -void interrupt_eval(); +bool interrupt_eval(); /** * Computes the output and vector-Jacobian product (VJP) of a function. diff --git a/tests/eval_tests.cpp b/tests/eval_tests.cpp index cd1b5dcaa..46b5c97be 100644 --- a/tests/eval_tests.cpp +++ b/tests/eval_tests.cpp @@ -97,7 +97,8 @@ TEST_CASE("test interrupt eval") { x = x + 1; } std::thread t([x]() { eval(x); }); - interrupt_eval(); + while (!interrupt_eval()) { + } t.join(); // Check that x is not evaluated CHECK(!x.is_available());