diff --git a/CMakeLists.txt b/CMakeLists.txt
index 1688b2052..1d7e3d7af 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -14,7 +14,7 @@ option(MLX_BUILD_TESTS "Build tests for mlx" ON)
 option(MLX_BUILD_EXAMPLES "Build examples for mlx" ON)
 option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
 option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
-option(MLX_BUILD_METAL "Build metal backend" OFF)
+option(MLX_BUILD_METAL "Build metal backend" ON)
 option(MLX_BUILD_CPU "Build cpu backend" ON)
 option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
 option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
@@ -47,13 +47,13 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
           "https://ml-explore.github.io/mlx/build/html/install.html#build-from-source"
       )
     else()
+      set(MLX_BUILD_METAL OFF)
       message(WARNING "Building for x86_64 arch is not officially supported.")
     endif()
-  elseif(${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm64")
-    set(MLX_BUILD_METAL ON)
   endif()
 
 else()
+  set(MLX_BUILD_METAL OFF)
   message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
 endif()
 
diff --git a/mlx/array.cpp b/mlx/array.cpp
index 99c6c23b7..e7040eb5f 100644
--- a/mlx/array.cpp
+++ b/mlx/array.cpp
@@ -95,13 +95,29 @@ void array::detach() {
   array_desc_->primitive = nullptr;
 }
 
-void array::eval() {
-  // Ensure the array is ready to be read
-  if (status() == Status::scheduled) {
+bool array::is_available() const {
+  if (status() == Status::available) {
+    return true;
+  } else if (status() == Status::evaluated && event().is_signaled()) {
+    set_status(Status::available);
+    return true;
+  }
+  return false;
+}
+
+void array::wait() {
+  if (!is_available()) {
     event().wait();
     set_status(Status::available);
-  } else if (status() == Status::unscheduled) {
+  }
+}
+
+void array::eval() {
+  // Ensure the array is ready to be read
+  if (status() == Status::unscheduled) {
     mlx::core::eval({*this});
+  } else {
+    wait();
   }
 }
 
diff --git a/mlx/array.h b/mlx/array.h
index d6d6c921d..f41baf568 100644
--- a/mlx/array.h
+++ b/mlx/array.h
@@ -344,11 +344,33 @@ class array {
     return static_cast(array_desc_->data_ptr);
   }
 
-  enum Status { unscheduled, scheduled, available };
+  enum Status {
+    // The output of a computation which has not been scheduled.
+    // For example, the status of `x` in `auto x = a + b`.
+    unscheduled,
 
-  bool is_available() const {
-    return status() == Status::available;
-  }
+    // The output of a computation which has been scheduled but `eval_*` has
+    // not yet been called on the array's primitive. A possible
+    // status of `x` in `auto x = a + b; eval(x);`
+    scheduled,
+
+    // The array's `eval_*` function has been run, but the computation is not
+    // necessarily complete. The array will have memory allocated and if it is
+    // not a tracer then it will be detached from the graph.
+    evaluated,
+
+    // If the array is the output of a computation then the computation
+    // is complete. Constant arrays are always available (e.g. `array({1, 2,
+    // 3})`)
+    available
+  };
+
+  // Check if the array is safe to read.
+  bool is_available() const;
+
+  // Wait on the array to be available. After this `is_available` returns
+  // `true`.
+  void wait();
 
   Status status() const {
     return array_desc_->status;
diff --git a/mlx/backend/metal/event.cpp b/mlx/backend/metal/event.cpp
index fde5749c5..be29d4533 100644
--- a/mlx/backend/metal/event.cpp
+++ b/mlx/backend/metal/event.cpp
@@ -27,4 +27,9 @@ void Event::signal() {
   static_cast(raw_event().get())->setSignaledValue(value());
 }
 
+bool Event::is_signaled() const {
+  return static_cast(raw_event().get())->signaledValue() >=
+      value();
+}
+
 } // namespace mlx::core
diff --git a/mlx/backend/metal/metal.cpp b/mlx/backend/metal/metal.cpp
index 8755e73f1..2a5e6334e 100644
--- a/mlx/backend/metal/metal.cpp
+++ b/mlx/backend/metal/metal.cpp
@@ -74,7 +74,7 @@ std::function make_task(array arr, bool signal) {
       arr.detach();
     }
     for (auto& out : outputs) {
-      out.set_status(array::Status::available);
+      out.set_status(array::Status::evaluated);
     }
 
     if (signal || d.get_command_buffer_ops(s.index) >= MAX_OPS_PER_BUFFER) {
diff --git a/mlx/backend/no_metal/event.cpp b/mlx/backend/no_metal/event.cpp
index 2945894ac..a41b4bb6e 100644
--- a/mlx/backend/no_metal/event.cpp
+++ b/mlx/backend/no_metal/event.cpp
@@ -36,4 +36,12 @@ void Event::signal() {
   ec->cv.notify_all();
 }
 
+bool Event::is_signaled() const {
+  auto ec = static_cast(raw_event().get());
+  {
+    std::lock_guard lk(ec->mtx);
+    // Signaled once the counter has reached this event's value (inclusive).
+    return (ec->value >= value());
+  }
+}
 } // namespace mlx::core
diff --git a/mlx/event.h b/mlx/event.h
index b58b7cad2..e492370a2 100644
--- a/mlx/event.h
+++ b/mlx/event.h
@@ -20,6 +20,9 @@ class Event {
   // Signal the event at its current value
   void signal();
 
+  // Check if the event has been signaled at its current value
+  bool is_signaled() const;
+
   // Check if the event is valid
   bool valid() const {
     return event_ != nullptr;
diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp
index 19387388e..54f676203 100644
--- a/mlx/transforms.cpp
+++ b/mlx/transforms.cpp
@@ -79,7 +79,7 @@ array eval_impl(std::vector outputs, bool async) {
           continue;
         }
 
-        if (!in.is_available()) {
+        if (in.status() == array::Status::unscheduled) {
           if (async && in.is_tracer()) {
             throw std::invalid_argument(
                 "[async_eval] Not allowed inside a graph transformation.");
@@ -115,7 +115,8 @@ array eval_impl(std::vector outputs, bool async) {
       }
 
       // All inputs are done being processed, process this array
-      if (a.is_available() && !a.is_tracer() && a.has_primitive()) {
+      if ((a.status() != array::Status::unscheduled) && !a.is_tracer() &&
+          a.has_primitive()) {
         // If the array is evaluated and is no longer a tracer, detach it
         a.detach();
       } else if (a.status() == array::Status::unscheduled) {
@@ -208,14 +209,12 @@ void eval(std::vector outputs) {
         return x.status() == array::Status::unscheduled;
       })) {
     for (auto& x : outputs) {
-      if (!x.is_available()) {
-        x.event().wait();
-      }
+      x.wait();
     }
     return;
   }
 
-  eval_impl(std::move(outputs), false).event().wait();
+  eval_impl(std::move(outputs), false).wait();
 }
 
 std::pair, std::vector> vjp(
diff --git a/python/tests/test_eval.py b/python/tests/test_eval.py
index 9f98c9600..81b51209a 100644
--- a/python/tests/test_eval.py
+++ b/python/tests/test_eval.py
@@ -51,11 +51,13 @@ class TestEval(mlx_tests.MLXTestCase):
         self.assertTrue(mx.array_equal(z, mx.array([4, 8, 12])))
 
     def test_async_eval_twice(self):
-        x = mx.array(1) + mx.array(1) + mx.array(1)
-        mx.async_eval(x)
-        y = x + 1
-        mx.async_eval(y)
-        self.assertEqual(x.item(), 3)
+        for _ in range(1000):
+            x = mx.array(1) + mx.array(1) + mx.array(1)
+            mx.async_eval(x)
+            y = x + 1
+            mx.async_eval(y)
+            self.assertEqual(x.item(), 3)
+            self.assertEqual(y.item(), 4)
 
     def test_async_eval_in_trace(self):
         def fun(x):