Fix array is_available race cases (#1468)

This commit is contained in:
Awni Hannun 2024-10-07 19:13:50 -07:00 committed by GitHub
parent 9b12093739
commit 3274c6a087
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 77 additions and 23 deletions

View File

@ -14,7 +14,7 @@ option(MLX_BUILD_TESTS "Build tests for mlx" ON)
option(MLX_BUILD_EXAMPLES "Build examples for mlx" ON) option(MLX_BUILD_EXAMPLES "Build examples for mlx" ON)
option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF) option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF) option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
option(MLX_BUILD_METAL "Build metal backend" OFF) option(MLX_BUILD_METAL "Build metal backend" ON)
option(MLX_BUILD_CPU "Build cpu backend" ON) option(MLX_BUILD_CPU "Build cpu backend" ON)
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF) option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF) option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
@ -47,13 +47,13 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
"https://ml-explore.github.io/mlx/build/html/install.html#build-from-source" "https://ml-explore.github.io/mlx/build/html/install.html#build-from-source"
) )
else() else()
set(MLX_BUILD_METAL OFF)
message(WARNING "Building for x86_64 arch is not officially supported.") message(WARNING "Building for x86_64 arch is not officially supported.")
endif() endif()
elseif(${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm64")
set(MLX_BUILD_METAL ON)
endif() endif()
else() else()
set(MLX_BUILD_METAL OFF)
message(WARNING "MLX is prioritised for Apple silicon systems using macOS.") message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
endif() endif()

View File

@ -95,13 +95,29 @@ void array::detach() {
array_desc_->primitive = nullptr; array_desc_->primitive = nullptr;
} }
void array::eval() { bool array::is_available() const {
// Ensure the array is ready to be read if (status() == Status::available) {
if (status() == Status::scheduled) { return true;
} else if (status() == Status::evaluated && event().is_signaled()) {
set_status(Status::available);
return true;
}
return false;
}
void array::wait() {
if (!is_available()) {
event().wait(); event().wait();
set_status(Status::available); set_status(Status::available);
} else if (status() == Status::unscheduled) { }
}
void array::eval() {
// Ensure the array is ready to be read
if (status() == Status::unscheduled) {
mlx::core::eval({*this}); mlx::core::eval({*this});
} else {
wait();
} }
} }

View File

@ -344,11 +344,33 @@ class array {
return static_cast<T*>(array_desc_->data_ptr); return static_cast<T*>(array_desc_->data_ptr);
} }
enum Status { unscheduled, scheduled, available }; enum Status {
// The output of a computation which has not been scheduled.
// For example, the status of `x` in `auto x = a + b`.
unscheduled,
bool is_available() const { // The output of a computation which has been scheduled but `eval_*` has
return status() == Status::available; // not yet been called on the array's primitive. A possible
} // status of `x` in `auto x = a + b; eval(x);`
scheduled,
// The array's `eval_*` function has been run, but the computation is not
// necessarily complete. The array will have memory allocated and if it is
// not a tracer then it will be detached from the graph.
evaluated,
// If the array is the output of a computation then the computation
// is complete. Constant arrays are always available (e.g. `array({1, 2,
// 3})`)
available
};
// Check if the array is safe to read.
bool is_available() const;
// Wait on the array to be available. After this `is_available` returns
// `true`.
void wait();
Status status() const { Status status() const {
return array_desc_->status; return array_desc_->status;

View File

@ -27,4 +27,9 @@ void Event::signal() {
static_cast<MTL::SharedEvent*>(raw_event().get())->setSignaledValue(value()); static_cast<MTL::SharedEvent*>(raw_event().get())->setSignaledValue(value());
} }
bool Event::is_signaled() const {
return static_cast<MTL::SharedEvent*>(raw_event().get())->signaledValue() >=
value();
}
} // namespace mlx::core } // namespace mlx::core

View File

@ -74,7 +74,7 @@ std::function<void()> make_task(array arr, bool signal) {
arr.detach(); arr.detach();
} }
for (auto& out : outputs) { for (auto& out : outputs) {
out.set_status(array::Status::available); out.set_status(array::Status::evaluated);
} }
if (signal || d.get_command_buffer_ops(s.index) >= MAX_OPS_PER_BUFFER) { if (signal || d.get_command_buffer_ops(s.index) >= MAX_OPS_PER_BUFFER) {

View File

@ -36,4 +36,11 @@ void Event::signal() {
ec->cv.notify_all(); ec->cv.notify_all();
} }
// Report whether the event has been signaled at (or past) its current value.
//
// The counter is read while holding the counter's mutex. The comparison is
// inclusive (`>=`) so that it agrees with the Metal backend's `is_signaled`;
// with a strict `>` an event signaled at exactly `value()` would be reported
// as not signaled.
// NOTE(review): assumes `signal()` stores `value()` into the counter -- confirm
// against the `signal()` / `wait()` implementations in this file.
bool Event::is_signaled() const {
  auto ec = static_cast<EventCounter*>(raw_event().get());
  // The guard lives for the whole function body, so no extra scope is needed.
  std::lock_guard<std::mutex> lk(ec->mtx);
  return ec->value >= value();
}
} // namespace mlx::core } // namespace mlx::core

View File

@ -20,6 +20,9 @@ class Event {
// Signal the event at its current value // Signal the event at its current value
void signal(); void signal();
// Check if the event has been signaled at its current value
bool is_signaled() const;
// Check if the event is valid // Check if the event is valid
bool valid() const { bool valid() const {
return event_ != nullptr; return event_ != nullptr;

View File

@ -79,7 +79,7 @@ array eval_impl(std::vector<array> outputs, bool async) {
continue; continue;
} }
if (!in.is_available()) { if (in.status() == array::Status::unscheduled) {
if (async && in.is_tracer()) { if (async && in.is_tracer()) {
throw std::invalid_argument( throw std::invalid_argument(
"[async_eval] Not allowed inside a graph transformation."); "[async_eval] Not allowed inside a graph transformation.");
@ -115,7 +115,8 @@ array eval_impl(std::vector<array> outputs, bool async) {
} }
// All inputs are done being processed, process this array // All inputs are done being processed, process this array
if (a.is_available() && !a.is_tracer() && a.has_primitive()) { if ((a.status() != array::Status::unscheduled) && !a.is_tracer() &&
a.has_primitive()) {
// If the array is evaluated and is no longer a tracer, detach it // If the array is evaluated and is no longer a tracer, detach it
a.detach(); a.detach();
} else if (a.status() == array::Status::unscheduled) { } else if (a.status() == array::Status::unscheduled) {
@ -208,14 +209,12 @@ void eval(std::vector<array> outputs) {
return x.status() == array::Status::unscheduled; return x.status() == array::Status::unscheduled;
})) { })) {
for (auto& x : outputs) { for (auto& x : outputs) {
if (!x.is_available()) { x.wait();
x.event().wait();
}
} }
return; return;
} }
eval_impl(std::move(outputs), false).event().wait(); eval_impl(std::move(outputs), false).wait();
} }
std::pair<std::vector<array>, std::vector<array>> vjp( std::pair<std::vector<array>, std::vector<array>> vjp(

View File

@ -51,11 +51,13 @@ class TestEval(mlx_tests.MLXTestCase):
self.assertTrue(mx.array_equal(z, mx.array([4, 8, 12]))) self.assertTrue(mx.array_equal(z, mx.array([4, 8, 12])))
def test_async_eval_twice(self): def test_async_eval_twice(self):
x = mx.array(1) + mx.array(1) + mx.array(1) for _ in range(1000):
mx.async_eval(x) x = mx.array(1) + mx.array(1) + mx.array(1)
y = x + 1 mx.async_eval(x)
mx.async_eval(y) y = x + 1
self.assertEqual(x.item(), 3) mx.async_eval(y)
self.assertEqual(x.item(), 3)
self.assertEqual(y.item(), 4)
def test_async_eval_in_trace(self): def test_async_eval_in_trace(self):
def fun(x): def fun(x):