From 9ffe88841c86fd0c9845eee95640b59eeb837730 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 18 Mar 2025 17:23:31 -0700 Subject: [PATCH 1/2] interruptable eval --- mlx/transforms.cpp | 15 +++++++++++++++ mlx/transforms.h | 5 +++++ tests/eval_tests.cpp | 14 +++++++++++++- 3 files changed, 33 insertions(+), 1 deletion(-) diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index f01082418..5949a40bb 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -35,6 +35,11 @@ class Synchronizer : public Primitive { DEFINE_PRINT(Synchronize); }; +std::atomic& interrupt_flag() { + static std::atomic interrupt_{false}; + return interrupt_; +} + // Initialize the static tracing members from transforms_impl.h // // These are used to implement the in_tracing() function the returns true if we @@ -249,6 +254,12 @@ array eval_impl(std::vector outputs, bool async) { if (!arr.is_tracer()) { arr.detach(); } + + if (interrupt_flag()) { + interrupt_flag() = false; + synchronizer.attach_event(Event{stream}); + break; + } } // Signal the event in its stream @@ -263,6 +274,10 @@ array eval_impl(std::vector outputs, bool async) { return synchronizer; } +void interrupt_eval() { + interrupt_flag() = true; +} + void async_eval(std::vector outputs) { if (outputs.empty()) { return; diff --git a/mlx/transforms.h b/mlx/transforms.h index 4afb21e23..89b37c720 100644 --- a/mlx/transforms.h +++ b/mlx/transforms.h @@ -22,6 +22,11 @@ void eval(Arrays&&... outputs) { eval(std::vector{std::forward(outputs)...}); } +/** + * Interrupt an ongoing eval. Leaves the graph in a valid state. + */ +void interrupt_eval(); + /** * Computes the output and vector-Jacobian product (VJP) of a function. * diff --git a/tests/eval_tests.cpp b/tests/eval_tests.cpp index ce4bc980f..cd1b5dcaa 100644 --- a/tests/eval_tests.cpp +++ b/tests/eval_tests.cpp @@ -1,5 +1,4 @@ // Copyright © 2023 Apple Inc. - #include "doctest/doctest.h" #include "mlx/mlx.h" @@ -91,3 +90,16 @@ TEST_CASE("test eval graph retention when not tracing") { CHECK(!a.has_primitive()); CHECK(a.is_available()); } + +TEST_CASE("test interrupt eval") { + auto x = zeros({1024}, int32); + for (int i = 0; i < 1000; ++i) { + x = x + 1; + } + std::thread t([x]() { eval(x); }); + interrupt_eval(); + t.join(); + // Check that x is not evaluated + CHECK(!x.is_available()); + CHECK(array_equal(x, full({1024}, 1000, int32)).item()); +} From 688e421184135b585527aa58969a1dc9c7bf74e2 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 19 Mar 2025 07:56:26 -0700 Subject: [PATCH 2/2] only interrupt during an eval --- mlx/transforms.cpp | 49 ++++++++++++++++++++++++++++++++++++-------- mlx/transforms.h | 7 +++++-- tests/eval_tests.cpp | 3 ++- 3 files changed, 48 insertions(+), 11 deletions(-) diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index 5949a40bb..567d0567c 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include @@ -35,10 +36,41 @@ class Synchronizer : public Primitive { DEFINE_PRINT(Synchronize); }; -std::atomic& interrupt_flag() { - static std::atomic interrupt_{false}; - return interrupt_; -} +class Interrupt { + private: + static std::mutex mutex_; + static bool eval_running_; + static bool interrupt_; + + public: + Interrupt() { + std::unique_lock lk(mutex_); + eval_running_ = true; + } + + static bool interrupt() { + std::unique_lock lk(mutex_); + if (eval_running_) { + interrupt_ = true; + return true; + } + return false; + } + + static bool interrupted() { + std::unique_lock lk(mutex_); + return interrupt_; + } + + ~Interrupt() { + std::unique_lock lk(mutex_); + eval_running_ = false; + interrupt_ = false; + } +}; +std::mutex Interrupt::mutex_{}; +bool Interrupt::eval_running_ = false; +bool Interrupt::interrupt_ = false; // Initialize the static tracing members from transforms_impl.h // @@ -50,6 +82,8 @@ std::vector detail::InTracing::trace_stack{}; int detail::RetainGraph::tracing_counter{0}; array eval_impl(std::vector outputs, bool async) { + Interrupt interrupt; + std::deque tape; // Make an effort to choose a good output stream @@ -255,8 +289,7 @@ array eval_impl(std::vector outputs, bool async) { arr.detach(); } - if (interrupt_flag()) { - interrupt_flag() = false; + if (Interrupt::interrupted()) { synchronizer.attach_event(Event{stream}); break; } @@ -274,8 +307,8 @@ array eval_impl(std::vector outputs, bool async) { return synchronizer; } -void interrupt_eval() { - interrupt_flag() = true; +bool interrupt_eval() { + return Interrupt::interrupt(); } void async_eval(std::vector outputs) { diff --git a/mlx/transforms.h b/mlx/transforms.h index 89b37c720..2cde23f9b 100644 --- a/mlx/transforms.h +++ b/mlx/transforms.h @@ -23,9 +23,12 @@ void eval(Arrays&&... outputs) { } /** - * Interrupt an ongoing eval. Leaves the graph in a valid state. + * Interrupt an ongoing eval. + * + * Leaves the graph in a valid state. Returns true if an ongoing eval was + * interrupted and false otherwise. */ -void interrupt_eval(); +bool interrupt_eval(); /** * Computes the output and vector-Jacobian product (VJP) of a function. diff --git a/tests/eval_tests.cpp b/tests/eval_tests.cpp index cd1b5dcaa..46b5c97be 100644 --- a/tests/eval_tests.cpp +++ b/tests/eval_tests.cpp @@ -97,7 +97,8 @@ TEST_CASE("test interrupt eval") { x = x + 1; } std::thread t([x]() { eval(x); }); - interrupt_eval(); + while (!interrupt_eval()) { + } t.join(); // Check that x is not evaluated CHECK(!x.is_available());