mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-18 15:11:14 +08:00
interruptable eval
This commit is contained in:
parent
0a9777aa5c
commit
9ffe88841c
@ -35,6 +35,11 @@ class Synchronizer : public Primitive {
|
|||||||
DEFINE_PRINT(Synchronize);
|
DEFINE_PRINT(Synchronize);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
std::atomic<bool>& interrupt_flag() {
|
||||||
|
static std::atomic<bool> interrupt_{false};
|
||||||
|
return interrupt_;
|
||||||
|
}
|
||||||
|
|
||||||
// Initialize the static tracing members from transforms_impl.h
|
// Initialize the static tracing members from transforms_impl.h
|
||||||
//
|
//
|
||||||
// These are used to implement the in_tracing() function the returns true if we
|
// These are used to implement the in_tracing() function the returns true if we
|
||||||
@ -249,6 +254,12 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
|||||||
if (!arr.is_tracer()) {
|
if (!arr.is_tracer()) {
|
||||||
arr.detach();
|
arr.detach();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (interrupt_flag()) {
|
||||||
|
interrupt_flag() = false;
|
||||||
|
synchronizer.attach_event(Event{stream});
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Signal the event in its stream
|
// Signal the event in its stream
|
||||||
@ -263,6 +274,10 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
|||||||
return synchronizer;
|
return synchronizer;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void interrupt_eval() {
|
||||||
|
interrupt_flag() = true;
|
||||||
|
}
|
||||||
|
|
||||||
void async_eval(std::vector<array> outputs) {
|
void async_eval(std::vector<array> outputs) {
|
||||||
if (outputs.empty()) {
|
if (outputs.empty()) {
|
||||||
return;
|
return;
|
||||||
|
@ -22,6 +22,11 @@ void eval(Arrays&&... outputs) {
|
|||||||
eval(std::vector<array>{std::forward<Arrays>(outputs)...});
|
eval(std::vector<array>{std::forward<Arrays>(outputs)...});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Interrupt an ongoing eval. Leaves the graph in a valid state.
|
||||||
|
*/
|
||||||
|
void interrupt_eval();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Computes the output and vector-Jacobian product (VJP) of a function.
|
* Computes the output and vector-Jacobian product (VJP) of a function.
|
||||||
*
|
*
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
// Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
#include "doctest/doctest.h"
|
#include "doctest/doctest.h"
|
||||||
|
|
||||||
#include "mlx/mlx.h"
|
#include "mlx/mlx.h"
|
||||||
@ -91,3 +90,16 @@ TEST_CASE("test eval graph retention when not tracing") {
|
|||||||
CHECK(!a.has_primitive());
|
CHECK(!a.has_primitive());
|
||||||
CHECK(a.is_available());
|
CHECK(a.is_available());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test interrupt eval") {
|
||||||
|
auto x = zeros({1024}, int32);
|
||||||
|
for (int i = 0; i < 1000; ++i) {
|
||||||
|
x = x + 1;
|
||||||
|
}
|
||||||
|
std::thread t([x]() { eval(x); });
|
||||||
|
interrupt_eval();
|
||||||
|
t.join();
|
||||||
|
// Check that x is not evaluated
|
||||||
|
CHECK(!x.is_available());
|
||||||
|
CHECK(array_equal(x, full({1024}, 1000, int32)).item<bool>());
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user