diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index 2d9942eda..d20ba1ad3 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include @@ -36,6 +37,42 @@ class Synchronizer : public Primitive { DEFINE_PRINT(Synchronize); }; +class Interrupt { + private: + static std::mutex mutex_; + static bool eval_running_; + static bool interrupt_; + + public: + Interrupt() { + std::unique_lock lk(mutex_); + eval_running_ = true; + } + + static bool interrupt() { + std::unique_lock lk(mutex_); + if (eval_running_) { + interrupt_ = true; + return true; + } + return false; + } + + static bool interrupted() { + std::unique_lock lk(mutex_); + return interrupt_; + } + + ~Interrupt() { + std::unique_lock lk(mutex_); + eval_running_ = false; + interrupt_ = false; + } +}; +std::mutex Interrupt::mutex_{}; +bool Interrupt::eval_running_ = false; +bool Interrupt::interrupt_ = false; + // Initialize the static tracing members from transforms_impl.h // // These are used to implement the in_tracing() function the returns true if we @@ -50,6 +87,8 @@ int detail::InTracing::grad_counter{0}; int detail::RetainGraph::tracing_counter{0}; array eval_impl(std::vector outputs, bool async) { + Interrupt interrupt; + std::deque tape; // Make an effort to choose a good output stream @@ -260,6 +299,11 @@ array eval_impl(std::vector outputs, bool async) { if (!arr.is_tracer()) { arr.detach(); } + + if (Interrupt::interrupted()) { + synchronizer.attach_event(Event{stream}); + break; + } } // Signal the event in its stream @@ -274,6 +318,10 @@ array eval_impl(std::vector outputs, bool async) { return synchronizer; } +bool interrupt_eval() { + return Interrupt::interrupt(); +} + void async_eval(std::vector outputs) { if (outputs.empty()) { return; diff --git a/mlx/transforms.h b/mlx/transforms.h index 4afb21e23..2cde23f9b 100644 --- a/mlx/transforms.h +++ b/mlx/transforms.h @@ -22,6 +22,14 @@ void eval(Arrays&&... outputs) { eval(std::vector{std::forward(outputs)...}); } +/** + * Interrupt an ongoing eval. + * + * Leaves the graph in a valid state. Returns true if an ongoing eval was + * interrupted and false otherwise. + */ +bool interrupt_eval(); + /** * Computes the output and vector-Jacobian product (VJP) of a function. * diff --git a/tests/eval_tests.cpp b/tests/eval_tests.cpp index ce4bc980f..46b5c97be 100644 --- a/tests/eval_tests.cpp +++ b/tests/eval_tests.cpp @@ -1,5 +1,4 @@ // Copyright © 2023 Apple Inc. - #include "doctest/doctest.h" #include "mlx/mlx.h" @@ -91,3 +90,17 @@ TEST_CASE("test eval graph retention when not tracing") { CHECK(!a.has_primitive()); CHECK(a.is_available()); } + +TEST_CASE("test interrupt eval") { + auto x = zeros({1024}, int32); + for (int i = 0; i < 1000; ++i) { + x = x + 1; + } + std::thread t([x]() { eval(x); }); + while (!interrupt_eval()) { + } + t.join(); + // Check that x is not evaluated + CHECK(!x.is_available()); + CHECK(array_equal(x, full({1024}, 1000, int32)).item()); +}