only interrupt during an eval

This commit is contained in:
Awni Hannun 2025-03-19 07:56:26 -07:00
parent 9ffe88841c
commit 688e421184
3 changed files with 48 additions and 11 deletions

View File

@ -2,6 +2,7 @@
#include <algorithm> #include <algorithm>
#include <deque> #include <deque>
#include <future> #include <future>
#include <mutex>
#include <numeric> #include <numeric>
#include <set> #include <set>
#include <sstream> #include <sstream>
@ -35,10 +36,41 @@ class Synchronizer : public Primitive {
DEFINE_PRINT(Synchronize); DEFINE_PRINT(Synchronize);
}; };
std::atomic<bool>& interrupt_flag() { class Interrupt {
static std::atomic<bool> interrupt_{false}; private:
return interrupt_; static std::mutex mutex_;
} static bool eval_running_;
static bool interrupt_;
public:
Interrupt() {
std::unique_lock lk(mutex_);
eval_running_ = true;
}
static bool interrupt() {
std::unique_lock lk(mutex_);
if (eval_running_) {
interrupt_ = true;
return true;
}
return false;
}
static bool interrupted() {
std::unique_lock lk(mutex_);
return interrupt_;
}
~Interrupt() {
std::unique_lock lk(mutex_);
eval_running_ = false;
interrupt_ = false;
}
};
std::mutex Interrupt::mutex_{};
bool Interrupt::eval_running_ = false;
bool Interrupt::interrupt_ = false;
// Initialize the static tracing members from transforms_impl.h // Initialize the static tracing members from transforms_impl.h
// //
@ -50,6 +82,8 @@ std::vector<char> detail::InTracing::trace_stack{};
int detail::RetainGraph::tracing_counter{0}; int detail::RetainGraph::tracing_counter{0};
array eval_impl(std::vector<array> outputs, bool async) { array eval_impl(std::vector<array> outputs, bool async) {
Interrupt interrupt;
std::deque<array> tape; std::deque<array> tape;
// Make an effort to choose a good output stream // Make an effort to choose a good output stream
@ -255,8 +289,7 @@ array eval_impl(std::vector<array> outputs, bool async) {
arr.detach(); arr.detach();
} }
if (interrupt_flag()) { if (Interrupt::interrupted()) {
interrupt_flag() = false;
synchronizer.attach_event(Event{stream}); synchronizer.attach_event(Event{stream});
break; break;
} }
@ -274,8 +307,8 @@ array eval_impl(std::vector<array> outputs, bool async) {
return synchronizer; return synchronizer;
} }
void interrupt_eval() { bool interrupt_eval() {
interrupt_flag() = true; return Interrupt::interrupt();
} }
void async_eval(std::vector<array> outputs) { void async_eval(std::vector<array> outputs) {

View File

@ -23,9 +23,12 @@ void eval(Arrays&&... outputs) {
} }
/** /**
* Interrupt an ongoing eval. Leaves the graph in a valid state. * Interrupt an ongoing eval.
*
* Leaves the graph in a valid state. Returns true if an ongoing eval was
* interrupted and false otherwise.
*/ */
void interrupt_eval(); bool interrupt_eval();
/** /**
* Computes the output and vector-Jacobian product (VJP) of a function. * Computes the output and vector-Jacobian product (VJP) of a function.

View File

@ -97,7 +97,8 @@ TEST_CASE("test interrupt eval") {
x = x + 1; x = x + 1;
} }
std::thread t([x]() { eval(x); }); std::thread t([x]() { eval(x); });
interrupt_eval(); while (!interrupt_eval()) {
}
t.join(); t.join();
// Check that x is not evaluated // Check that x is not evaluated
CHECK(!x.is_available()); CHECK(!x.is_available());