mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 18:11:15 +08:00
only interrupt during an eval
This commit is contained in:
parent
9ffe88841c
commit
688e421184
@ -2,6 +2,7 @@
|
|||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <deque>
|
#include <deque>
|
||||||
#include <future>
|
#include <future>
|
||||||
|
#include <mutex>
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
@ -35,11 +36,42 @@ class Synchronizer : public Primitive {
|
|||||||
DEFINE_PRINT(Synchronize);
|
DEFINE_PRINT(Synchronize);
|
||||||
};
|
};
|
||||||
|
|
||||||
std::atomic<bool>& interrupt_flag() {
|
class Interrupt {
|
||||||
static std::atomic<bool> interrupt_{false};
|
private:
|
||||||
|
static std::mutex mutex_;
|
||||||
|
static bool eval_running_;
|
||||||
|
static bool interrupt_;
|
||||||
|
|
||||||
|
public:
|
||||||
|
Interrupt() {
|
||||||
|
std::unique_lock lk(mutex_);
|
||||||
|
eval_running_ = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool interrupt() {
|
||||||
|
std::unique_lock lk(mutex_);
|
||||||
|
if (eval_running_) {
|
||||||
|
interrupt_ = true;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool interrupted() {
|
||||||
|
std::unique_lock lk(mutex_);
|
||||||
return interrupt_;
|
return interrupt_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
~Interrupt() {
|
||||||
|
std::unique_lock lk(mutex_);
|
||||||
|
eval_running_ = false;
|
||||||
|
interrupt_ = false;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
std::mutex Interrupt::mutex_{};
|
||||||
|
bool Interrupt::eval_running_ = false;
|
||||||
|
bool Interrupt::interrupt_ = false;
|
||||||
|
|
||||||
// Initialize the static tracing members from transforms_impl.h
|
// Initialize the static tracing members from transforms_impl.h
|
||||||
//
|
//
|
||||||
// These are used to implement the in_tracing() function the returns true if we
|
// These are used to implement the in_tracing() function the returns true if we
|
||||||
@ -50,6 +82,8 @@ std::vector<char> detail::InTracing::trace_stack{};
|
|||||||
int detail::RetainGraph::tracing_counter{0};
|
int detail::RetainGraph::tracing_counter{0};
|
||||||
|
|
||||||
array eval_impl(std::vector<array> outputs, bool async) {
|
array eval_impl(std::vector<array> outputs, bool async) {
|
||||||
|
Interrupt interrupt;
|
||||||
|
|
||||||
std::deque<array> tape;
|
std::deque<array> tape;
|
||||||
|
|
||||||
// Make an effort to choose a good output stream
|
// Make an effort to choose a good output stream
|
||||||
@ -255,8 +289,7 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
|||||||
arr.detach();
|
arr.detach();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (interrupt_flag()) {
|
if (Interrupt::interrupted()) {
|
||||||
interrupt_flag() = false;
|
|
||||||
synchronizer.attach_event(Event{stream});
|
synchronizer.attach_event(Event{stream});
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -274,8 +307,8 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
|||||||
return synchronizer;
|
return synchronizer;
|
||||||
}
|
}
|
||||||
|
|
||||||
void interrupt_eval() {
|
bool interrupt_eval() {
|
||||||
interrupt_flag() = true;
|
return Interrupt::interrupt();
|
||||||
}
|
}
|
||||||
|
|
||||||
void async_eval(std::vector<array> outputs) {
|
void async_eval(std::vector<array> outputs) {
|
||||||
|
@ -23,9 +23,12 @@ void eval(Arrays&&... outputs) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Interrupt an ongoing eval. Leaves the graph in a valid state.
|
* Interrupt an ongoing eval.
|
||||||
|
*
|
||||||
|
* Leaves the graph in a valid state. Returns true if an ongoing eval was
|
||||||
|
* interrupted and false otherwise.
|
||||||
*/
|
*/
|
||||||
void interrupt_eval();
|
bool interrupt_eval();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Computes the output and vector-Jacobian product (VJP) of a function.
|
* Computes the output and vector-Jacobian product (VJP) of a function.
|
||||||
|
@ -97,7 +97,8 @@ TEST_CASE("test interrupt eval") {
|
|||||||
x = x + 1;
|
x = x + 1;
|
||||||
}
|
}
|
||||||
std::thread t([x]() { eval(x); });
|
std::thread t([x]() { eval(x); });
|
||||||
interrupt_eval();
|
while (!interrupt_eval()) {
|
||||||
|
}
|
||||||
t.join();
|
t.join();
|
||||||
// Check that x is not evaluated
|
// Check that x is not evaluated
|
||||||
CHECK(!x.is_available());
|
CHECK(!x.is_available());
|
||||||
|
Loading…
Reference in New Issue
Block a user