interruptable eval

This commit is contained in:
Awni Hannun 2025-03-18 17:23:31 -07:00
parent 0a9777aa5c
commit 9ffe88841c
3 changed files with 33 additions and 1 deletions

View File

@ -35,6 +35,11 @@ class Synchronizer : public Primitive {
DEFINE_PRINT(Synchronize); DEFINE_PRINT(Synchronize);
}; };
std::atomic<bool>& interrupt_flag() {
static std::atomic<bool> interrupt_{false};
return interrupt_;
}
// Initialize the static tracing members from transforms_impl.h // Initialize the static tracing members from transforms_impl.h
// //
// These are used to implement the in_tracing() function the returns true if we // These are used to implement the in_tracing() function the returns true if we
@ -249,6 +254,12 @@ array eval_impl(std::vector<array> outputs, bool async) {
if (!arr.is_tracer()) { if (!arr.is_tracer()) {
arr.detach(); arr.detach();
} }
if (interrupt_flag()) {
interrupt_flag() = false;
synchronizer.attach_event(Event{stream});
break;
}
} }
// Signal the event in its stream // Signal the event in its stream
@ -263,6 +274,10 @@ array eval_impl(std::vector<array> outputs, bool async) {
return synchronizer; return synchronizer;
} }
void interrupt_eval() {
interrupt_flag() = true;
}
void async_eval(std::vector<array> outputs) { void async_eval(std::vector<array> outputs) {
if (outputs.empty()) { if (outputs.empty()) {
return; return;

View File

@ -22,6 +22,11 @@ void eval(Arrays&&... outputs) {
eval(std::vector<array>{std::forward<Arrays>(outputs)...}); eval(std::vector<array>{std::forward<Arrays>(outputs)...});
} }
/**
* Interrupt an ongoing eval. Leaves the graph in a valid state.
*/
void interrupt_eval();
/** /**
* Computes the output and vector-Jacobian product (VJP) of a function. * Computes the output and vector-Jacobian product (VJP) of a function.
* *

View File

@ -1,5 +1,4 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023 Apple Inc.
#include "doctest/doctest.h" #include "doctest/doctest.h"
#include "mlx/mlx.h" #include "mlx/mlx.h"
@ -91,3 +90,16 @@ TEST_CASE("test eval graph retention when not tracing") {
CHECK(!a.has_primitive()); CHECK(!a.has_primitive());
CHECK(a.is_available()); CHECK(a.is_available());
} }
TEST_CASE("test interrupt eval") {
auto x = zeros({1024}, int32);
for (int i = 0; i < 1000; ++i) {
x = x + 1;
}
std::thread t([x]() { eval(x); });
interrupt_eval();
t.join();
// Check that x is not evaluated
CHECK(!x.is_available());
CHECK(array_equal(x, full({1024}, 1000, int32)).item<bool>());
}