This commit is contained in:
Awni Hannun 2025-06-16 19:14:17 +02:00 committed by GitHub
commit 0aa9ccd158
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 70 additions and 1 deletions

View File

@ -2,6 +2,7 @@
#include <algorithm>
#include <deque>
#include <future>
#include <mutex>
#include <numeric>
#include <set>
#include <sstream>
@ -36,6 +37,42 @@ class Synchronizer : public Primitive {
DEFINE_PRINT(Synchronize);
};
class Interrupt {
private:
static std::mutex mutex_;
static bool eval_running_;
static bool interrupt_;
public:
Interrupt() {
std::unique_lock lk(mutex_);
eval_running_ = true;
}
static bool interrupt() {
std::unique_lock lk(mutex_);
if (eval_running_) {
interrupt_ = true;
return true;
}
return false;
}
static bool interrupted() {
std::unique_lock lk(mutex_);
return interrupt_;
}
~Interrupt() {
std::unique_lock lk(mutex_);
eval_running_ = false;
interrupt_ = false;
}
};
std::mutex Interrupt::mutex_{};
bool Interrupt::eval_running_ = false;
bool Interrupt::interrupt_ = false;
// Initialize the static tracing members from transforms_impl.h
//
// These are used to implement the in_tracing() function the returns true if we
@ -50,6 +87,8 @@ int detail::InTracing::grad_counter{0};
int detail::RetainGraph::tracing_counter{0};
array eval_impl(std::vector<array> outputs, bool async) {
Interrupt interrupt;
std::deque<array> tape;
// Make an effort to choose a good output stream
@ -260,6 +299,11 @@ array eval_impl(std::vector<array> outputs, bool async) {
if (!arr.is_tracer()) {
arr.detach();
}
if (Interrupt::interrupted()) {
synchronizer.attach_event(Event{stream});
break;
}
}
// Signal the event in its stream
@ -274,6 +318,10 @@ array eval_impl(std::vector<array> outputs, bool async) {
return synchronizer;
}
bool interrupt_eval() {
return Interrupt::interrupt();
}
void async_eval(std::vector<array> outputs) {
if (outputs.empty()) {
return;

View File

@ -22,6 +22,14 @@ void eval(Arrays&&... outputs) {
eval(std::vector<array>{std::forward<Arrays>(outputs)...});
}
/**
* Interrupt an ongoing eval.
*
* Leaves the graph in a valid state. Returns true if an ongoing eval was
* interrupted and false otherwise.
*/
bool interrupt_eval();
/**
* Computes the output and vector-Jacobian product (VJP) of a function.
*

View File

@ -1,5 +1,4 @@
// Copyright © 2023 Apple Inc.
#include "doctest/doctest.h"
#include "mlx/mlx.h"
@ -91,3 +90,17 @@ TEST_CASE("test eval graph retention when not tracing") {
CHECK(!a.has_primitive());
CHECK(a.is_available());
}
TEST_CASE("test interrupt eval") {
auto x = zeros({1024}, int32);
for (int i = 0; i < 1000; ++i) {
x = x + 1;
}
std::thread t([x]() { eval(x); });
while (!interrupt_eval()) {
}
t.join();
// Check that x is not evaluated
CHECK(!x.is_available());
CHECK(array_equal(x, full({1024}, 1000, int32)).item<bool>());
}