From 9ffe88841c86fd0c9845eee95640b59eeb837730 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 18 Mar 2025 17:23:31 -0700 Subject: [PATCH] interruptable eval --- mlx/transforms.cpp | 15 +++++++++++++++ mlx/transforms.h | 5 +++++ tests/eval_tests.cpp | 14 +++++++++++++- 3 files changed, 33 insertions(+), 1 deletion(-) diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index f01082418..5949a40bb 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -35,6 +35,11 @@ class Synchronizer : public Primitive { DEFINE_PRINT(Synchronize); }; +std::atomic& interrupt_flag() { + static std::atomic interrupt_{false}; + return interrupt_; +} + // Initialize the static tracing members from transforms_impl.h // // These are used to implement the in_tracing() function the returns true if we @@ -249,6 +254,12 @@ array eval_impl(std::vector outputs, bool async) { if (!arr.is_tracer()) { arr.detach(); } + + if (interrupt_flag()) { + interrupt_flag() = false; + synchronizer.attach_event(Event{stream}); + break; + } } // Signal the event in its stream @@ -263,6 +274,10 @@ array eval_impl(std::vector outputs, bool async) { return synchronizer; } +void interrupt_eval() { + interrupt_flag() = true; +} + void async_eval(std::vector outputs) { if (outputs.empty()) { return; diff --git a/mlx/transforms.h b/mlx/transforms.h index 4afb21e23..89b37c720 100644 --- a/mlx/transforms.h +++ b/mlx/transforms.h @@ -22,6 +22,11 @@ void eval(Arrays&&... outputs) { eval(std::vector{std::forward(outputs)...}); } +/** + * Interrupt an ongoing eval. Leaves the graph in a valid state. + */ +void interrupt_eval(); + /** * Computes the output and vector-Jacobian product (VJP) of a function. * diff --git a/tests/eval_tests.cpp b/tests/eval_tests.cpp index ce4bc980f..cd1b5dcaa 100644 --- a/tests/eval_tests.cpp +++ b/tests/eval_tests.cpp @@ -1,5 +1,4 @@ // Copyright © 2023 Apple Inc. - #include "doctest/doctest.h" #include "mlx/mlx.h" @@ -91,3 +90,16 @@ TEST_CASE("test eval graph retention when not tracing") { CHECK(!a.has_primitive()); CHECK(a.is_available()); } + +TEST_CASE("test interrupt eval") { + auto x = zeros({1024}, int32); + for (int i = 0; i < 1000; ++i) { + x = x + 1; + } + std::thread t([x]() { eval(x); }); + interrupt_eval(); + t.join(); + // Check that x is not evaluated + CHECK(!x.is_available()); + CHECK(array_equal(x, full({1024}, 1000, int32)).item()); +}