diff --git a/CMakeLists.txt b/CMakeLists.txt index 7a5074e86..c9de50875 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -82,8 +82,10 @@ elseif (MLX_BUILD_METAL) message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}") if (${MACOS_VERSION} GREATER_EQUAL 14.2) + set(METAL_CPP_PATCH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/metal.14.2.diff) set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14.2_iOS17.2.zip) elseif (${MACOS_VERSION} GREATER_EQUAL 14.0) + set(METAL_CPP_PATCH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/metal.14.0.diff) set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14_iOS17-beta.zip) else() message(FATAL_ERROR "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON" ) @@ -92,6 +94,7 @@ elseif (MLX_BUILD_METAL) FetchContent_Declare( metal_cpp URL ${METAL_CPP_URL} + PATCH_COMMAND patch -N -i ${METAL_CPP_PATCH} || true ) FetchContent_MakeAvailable(metal_cpp) diff --git a/cmake/metal.14.0.diff b/cmake/metal.14.0.diff new file mode 100644 index 000000000..3609fd916 --- /dev/null +++ b/cmake/metal.14.0.diff @@ -0,0 +1,36 @@ +diff -ur Metal/MTLEvent.hpp MetalNew/MTLEvent.hpp +--- Metal/MTLEvent.hpp 2023-06-01 12:18:26 ++++ MetalNew/MTLEvent.hpp 2024-04-15 07:36:59 +@@ -62,6 +62,7 @@ + + uint64_t signaledValue() const; + void setSignaledValue(uint64_t signaledValue); ++ bool waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS); + }; + + class SharedEventHandle : public NS::SecureCoding +@@ -138,6 +139,11 @@ + _MTL_INLINE void MTL::SharedEvent::setSignaledValue(uint64_t signaledValue) + { + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSignaledValue_), signaledValue); ++} ++ ++// method: waitUntilSignaledValue ++_MTL_INLINE bool MTL::SharedEvent::waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS) { ++ return Object::sendMessage(this, _MTL_PRIVATE_SEL(waitUntilSignaledValue_timeoutMS_), signaledValue, timeoutMS); + } + + // static method: alloc +diff -ur 
Metal/MTLHeaderBridge.hpp MetalNew/MTLHeaderBridge.hpp +--- Metal/MTLHeaderBridge.hpp 2023-06-01 12:18:26 ++++ MetalNew/MTLHeaderBridge.hpp 2024-04-15 07:37:29 +@@ -1906,6 +1906,9 @@ + "setShouldMaximizeConcurrentCompilation:"); + _MTL_PRIVATE_DEF_SEL(setSignaledValue_, + "setSignaledValue:"); ++_MTL_PRIVATE_DEF_SEL( ++ waitUntilSignaledValue_timeoutMS_, ++ "waitUntilSignaledValue:timeoutMS:"); + _MTL_PRIVATE_DEF_SEL(setSize_, + "setSize:"); + _MTL_PRIVATE_DEF_SEL(setSlice_, diff --git a/cmake/metal.14.2.diff b/cmake/metal.14.2.diff new file mode 100644 index 000000000..8634afaa7 --- /dev/null +++ b/cmake/metal.14.2.diff @@ -0,0 +1,36 @@ +diff -ur Metal/MTLEvent.hpp MetalNew/MTLEvent.hpp +--- Metal/MTLEvent.hpp 2024-04-15 07:12:10 ++++ MetalNew/MTLEvent.hpp 2024-04-15 07:15:50 +@@ -62,6 +62,7 @@ + + uint64_t signaledValue() const; + void setSignaledValue(uint64_t signaledValue); ++ bool waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS); + }; + + class SharedEventHandle : public NS::SecureCoding +@@ -138,6 +139,11 @@ + _MTL_INLINE void MTL::SharedEvent::setSignaledValue(uint64_t signaledValue) + { + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSignaledValue_), signaledValue); ++} ++ ++// method: waitUntilSignaledValue ++_MTL_INLINE bool MTL::SharedEvent::waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS) { ++ return Object::sendMessage(this, _MTL_PRIVATE_SEL(waitUntilSignaledValue_timeoutMS_), signaledValue, timeoutMS); + } + + // static method: alloc +diff -ur Metal/MTLHeaderBridge.hpp MetalNew/MTLHeaderBridge.hpp +--- Metal/MTLHeaderBridge.hpp 2024-04-15 07:12:10 ++++ MetalNew/MTLHeaderBridge.hpp 2024-04-15 07:16:15 +@@ -1918,6 +1918,9 @@ + "setShouldMaximizeConcurrentCompilation:"); + _MTL_PRIVATE_DEF_SEL(setSignaledValue_, + "setSignaledValue:"); ++_MTL_PRIVATE_DEF_SEL( ++ waitUntilSignaledValue_timeoutMS_, ++ "waitUntilSignaledValue:timeoutMS:"); + _MTL_PRIVATE_DEF_SEL(setSize_, + "setSize:"); + _MTL_PRIVATE_DEF_SEL(setSlice_, 
diff --git a/mlx/array.cpp b/mlx/array.cpp index ff058a833..f655bc6a6 100644 --- a/mlx/array.cpp +++ b/mlx/array.cpp @@ -93,7 +93,11 @@ void array::detach() { } void array::eval() { - if (!is_evaled()) { + // Ensure the array is ready to be read + if (status() == Status::scheduled) { + event().wait(); + set_status(Status::available); + } else if (status() == Status::unscheduled) { mlx::core::eval({*this}); } } @@ -176,7 +180,7 @@ void array::ArrayDesc::init() { } array::ArrayDesc::ArrayDesc(std::vector shape, Dtype dtype) - : shape(std::move(shape)), dtype(dtype) { + : shape(std::move(shape)), dtype(dtype), status(Status::available) { init(); } @@ -187,6 +191,7 @@ array::ArrayDesc::ArrayDesc( std::vector inputs) : shape(std::move(shape)), dtype(dtype), + status(Status::unscheduled), primitive(std::move(primitive)), inputs(std::move(inputs)) { init(); diff --git a/mlx/array.h b/mlx/array.h index a3b2b2c44..aeb76d9c8 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -9,6 +9,7 @@ #include "mlx/allocator.h" #include "mlx/dtype.h" +#include "mlx/event.h" namespace mlx::core { @@ -315,9 +316,27 @@ class array { return static_cast(array_desc_->data_ptr); }; - // Check if the array has been evaluated - bool is_evaled() const { - return array_desc_->data != nullptr; + enum Status { unscheduled, scheduled, available }; + + bool is_available() const { + return status() == Status::available; + } + const Status status() const { + return array_desc_->status; + } + + void set_status(Status s) const { + array_desc_->status = s; + } + + // Get the array's shared event + Event& event() const { + return array_desc_->event; + } + + // Attach an event to a not yet evaluated array + void attach_event(Event e) const { + array_desc_->event = std::move(e); } // Mark the array as a tracer array (true) or not. 
@@ -370,6 +389,11 @@ class array { Dtype dtype; std::shared_ptr primitive; + Status status; + + // An event on the array used for synchronization + Event event; + // Indicates an array is being used in a graph transform // and should not be detached from the graph bool is_tracer{false}; @@ -470,10 +494,11 @@ T array::item() const { if (size() != 1) { throw std::invalid_argument("item can only be called on arrays of size 1."); } - if (!is_evaled()) { + if (status() == Status::unscheduled) { throw std::invalid_argument( "item() const can only be called on evaled arrays"); } + const_cast(this)->eval(); return *data(); } diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt index f6f374686..0a77f0bda 100644 --- a/mlx/backend/metal/CMakeLists.txt +++ b/mlx/backend/metal/CMakeLists.txt @@ -26,6 +26,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/event.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index 1ee9335c0..23dae8d51 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -544,11 +544,12 @@ Device& device(mlx::core::Device) { return metal_device; } -std::shared_ptr new_scoped_memory_pool() { +std::unique_ptr> new_scoped_memory_pool() { auto dtor = [](void* ptr) { static_cast(ptr)->release(); }; - return std::shared_ptr(NS::AutoreleasePool::alloc()->init(), dtor); + return std::unique_ptr>( + NS::AutoreleasePool::alloc()->init(), dtor); } void new_stream(Stream stream) { diff --git a/mlx/backend/metal/event.cpp b/mlx/backend/metal/event.cpp new file mode 100644 index 000000000..fde5749c5 --- /dev/null +++ b/mlx/backend/metal/event.cpp @@ -0,0 +1,30 @@ +// Copyright © 2024 Apple Inc. 
+ +#include "mlx/event.h" +#include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/metal_impl.h" + +namespace mlx::core { + +Event::Event(const Stream& stream) : stream_(stream) { + auto dtor = [](void* ptr) { + auto p = metal::new_scoped_memory_pool(); + static_cast(ptr)->release(); + }; + auto p = metal::new_scoped_memory_pool(); + event_ = std::shared_ptr( + metal::device(stream.device).mtl_device()->newSharedEvent(), dtor); +} + +void Event::wait() { + if (!static_cast(raw_event().get()) + ->waitUntilSignaledValue(value(), -1)) { + throw std::runtime_error("[Event::wait] Timed out"); + } +} + +void Event::signal() { + static_cast(raw_event().get())->setSignaledValue(value()); +} + +} // namespace mlx::core diff --git a/mlx/backend/metal/metal.cpp b/mlx/backend/metal/metal.cpp index c18e5c658..1de8ceec5 100644 --- a/mlx/backend/metal/metal.cpp +++ b/mlx/backend/metal/metal.cpp @@ -55,17 +55,20 @@ inline void check_error(MTL::CommandBuffer* cbuf) { } } -std::function make_task( - array& arr, - std::vector> deps, - std::shared_ptr> p) { - auto task = [arr, deps = std::move(deps), p = std::move(p)]() mutable { +std::function make_task(array arr, bool signal) { + auto task = [arr = std::move(arr), signal]() mutable { auto pool = new_scoped_memory_pool(); - for (auto& d : deps) { - d.wait(); - } auto s = arr.primitive().stream(); auto command_buffer = increment_command_buffer(s); + for (auto& input : arr.inputs()) { + if (input.event().valid() && + input.event().stream() != arr.primitive().stream()) { + // TODO, consider committing the buffer and encoding a wait in the new + // buffer rather than on the task thread + input.event().wait(); + } + } + auto outputs = arr.outputs(); { // If the array is a tracer hold a reference @@ -88,13 +91,16 @@ std::function make_task( if (!arr.is_tracer()) { arr.detach(); } - if (p) { + + if (signal) { metal::device(s.device).end_encoding(s.index); + command_buffer->encodeSignalEvent( + 
static_cast(arr.event().raw_event().get()), + arr.event().value()); scheduler::notify_new_task(s); command_buffer->addCompletedHandler( - [s, buffers = std::move(buffers), p = std::move(p)]( + [s, buffers = std::move(buffers), event = arr.event()]( MTL::CommandBuffer* cbuf) { - p->set_value(); scheduler::notify_task_completion(s); check_error(cbuf); }); diff --git a/mlx/backend/metal/metal_impl.h b/mlx/backend/metal/metal_impl.h index 3487558b8..885dd33e3 100644 --- a/mlx/backend/metal/metal_impl.h +++ b/mlx/backend/metal/metal_impl.h @@ -2,9 +2,7 @@ #pragma once -#include #include -#include #include "mlx/array.h" #include "mlx/stream.h" @@ -12,11 +10,9 @@ namespace mlx::core::metal { void new_stream(Stream stream); -std::shared_ptr new_scoped_memory_pool(); -std::function make_task( - array& arr, - std::vector> deps, - std::shared_ptr> p); +std::unique_ptr> new_scoped_memory_pool(); + +std::function make_task(array arr, bool signal); } // namespace mlx::core::metal diff --git a/mlx/backend/no_metal/CMakeLists.txt b/mlx/backend/no_metal/CMakeLists.txt index 6aaa766d6..8f507e771 100644 --- a/mlx/backend/no_metal/CMakeLists.txt +++ b/mlx/backend/no_metal/CMakeLists.txt @@ -2,6 +2,7 @@ target_sources( mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/event.cpp ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp ) diff --git a/mlx/backend/no_metal/event.cpp b/mlx/backend/no_metal/event.cpp new file mode 100644 index 000000000..2945894ac --- /dev/null +++ b/mlx/backend/no_metal/event.cpp @@ -0,0 +1,39 @@ +// Copyright © 2024 Apple Inc. 
+ +#include "mlx/event.h" + +#include +#include + +namespace mlx::core { + +struct EventCounter { + uint64_t value{0}; + std::mutex mtx; + std::condition_variable cv; +}; + +Event::Event(const Stream& stream) : stream_(stream) { + auto dtor = [](void* ptr) { delete static_cast(ptr); }; + event_ = std::shared_ptr(new EventCounter{}, dtor); +} + +void Event::wait() { + auto ec = static_cast(raw_event().get()); + std::unique_lock lk(ec->mtx); + if (ec->value >= value()) { + return; + } + ec->cv.wait(lk, [value = value(), ec] { return ec->value >= value; }); +} + +void Event::signal() { + auto ec = static_cast(raw_event().get()); + { + std::lock_guard lk(ec->mtx); + ec->value = value(); + } + ec->cv.notify_all(); +} + +} // namespace mlx::core diff --git a/mlx/backend/no_metal/metal.cpp b/mlx/backend/no_metal/metal.cpp index fe177a467..4993a7313 100644 --- a/mlx/backend/no_metal/metal.cpp +++ b/mlx/backend/no_metal/metal.cpp @@ -12,14 +12,12 @@ bool is_available() { } void new_stream(Stream) {} -std::shared_ptr new_scoped_memory_pool() { + +std::unique_ptr> new_scoped_memory_pool() { return nullptr; } -std::function make_task( - array& arr, - std::vector> deps, - std::shared_ptr> p) { +std::function make_task(array arr, bool signal) { throw std::runtime_error( "[metal::make_task] Cannot make GPU task without metal backend"); } diff --git a/mlx/compile.cpp b/mlx/compile.cpp index d187c1e10..456f658a1 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -352,7 +352,8 @@ void compile_simplify( // Helpers to identify identical scalars std::map, array> scalars; auto is_scalar = [](const array& a) { - return a.is_evaled() && a.ndim() == 0; + // Condition for when it's safe to read an array + return a.is_available() && a.ndim() == 0; }; auto get_scalar_rep = [](const array& a) { uint64_t v = 0; diff --git a/mlx/event.h b/mlx/event.h new file mode 100644 index 000000000..4fee164eb --- /dev/null +++ b/mlx/event.h @@ -0,0 +1,56 @@ +// Copyright © 2024 Apple Inc. 
+#pragma once + +#include +#include + +#include "mlx/stream.h" + +namespace mlx::core { + +class Event { + public: + Event(){}; + + Event(const Stream& steam); + + // Wait for the event to be signaled at its current value + void wait(); + + // Signal the event at its current value + void signal(); + + // Check if the event is valid + bool valid() { + return event_ != nullptr; + }; + + uint64_t value() { + return value_; + }; + + void set_value(uint64_t v) { + value_ = v; + }; + + const Stream& stream() { + if (!valid()) { + throw std::runtime_error( + "[Event::stream] Cannot access stream on invalid event."); + } + return stream_; + }; + + const std::shared_ptr& raw_event() { + return event_; + }; + + private: + // Default constructed stream should never be used + // since the event is not yet valid + Stream stream_{0, Device::cpu}; + std::shared_ptr event_{nullptr}; + uint64_t value_{0}; +}; + +} // namespace mlx::core diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index f6ab6f747..4c497ac9b 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -36,29 +36,32 @@ class Synchronizer : public Primitive { // are currently under a function transformation. 
int detail::InTracing::tracing_counter{0}; -std::shared_future async_eval(std::vector outputs) { - static std::shared_future global_synchronizer; - // Catch up with previous async eval if needed - if (global_synchronizer.valid()) { - global_synchronizer.wait(); - } +array eval_impl(std::vector outputs, bool async) { std::queue tape; - std::unordered_set cache; - std::unordered_map> deps; + + // stream events to use for synchronization + std::unordered_map events; // Make an effort to choose a good output stream Stream stream = default_stream(default_device()); for (auto& o : outputs) { - if (!o.is_evaled() && o.has_primitive()) { + if (o.status() == array::Status::unscheduled && o.has_primitive()) { stream = o.primitive().stream(); break; } } + std::unordered_set needs_signal; + auto synchronizer = array( {}, bool_, std::make_shared(stream), std::move(outputs)); + needs_signal.insert(synchronizer.id()); + + // Make an event for the synchronizer stream + events.emplace(stream.index, Event{stream}); { + std::unordered_set cache; std::stack, int>> dfs; dfs.emplace(synchronizer, 0); while (!dfs.empty()) { @@ -67,16 +70,23 @@ std::shared_future async_eval(std::vector outputs) { if (idx < a.inputs().size()) { // Add an input, and continue auto& in = a.inputs()[idx++]; - if (!in.is_evaled()) { + + // Ignore arrays already scheduled + if (in.status() == array::Status::scheduled) { + continue; + } + + if (!in.is_available()) { + if (async && in.is_tracer()) { + throw std::invalid_argument( + "[async_eval] Not allowed inside a graph transformation."); + } if (!in.has_primitive()) { throw std::invalid_argument( "[eval] Attempting to eval an array without a primitive."); } - - // If the input is being computed on a different stream, we need to - // manage the dependency. 
if (a.primitive().stream() != in.primitive().stream()) { - deps.insert({in.output(0).id(), std::shared_future{}}); + needs_signal.insert(in.id()); } } @@ -91,52 +101,54 @@ std::shared_future async_eval(std::vector outputs) { } // All inputs are done being processed, process this array - if (!a.is_evaled() || (!a.is_tracer() && a.has_primitive())) { + if (a.is_available() && !a.is_tracer() && a.has_primitive()) { + // If the array is evaluated and is no longer a tracer, detach it + a.detach(); + } else if (a.status() == array::Status::unscheduled) { tape.push(a); + // Lookup corresponding event and increment counter + auto& stream = a.primitive().stream(); + auto e = events.find(stream.index); + if (e == events.end()) { + e = events.emplace(stream.index, Event{stream}).first; + } + e->second.set_value(e->second.value() + 1); + a.attach_event(e->second); + for (auto& s : a.siblings()) { + s.attach_event(e->second); + } } dfs.pop(); } } - deps.insert({synchronizer.id(), std::shared_future{}}); - std::vector>> ps; while (!tape.empty()) { auto arr = std::move(tape.front()); tape.pop(); - if (arr.is_evaled()) { - if (!arr.is_tracer() && arr.has_primitive()) { - arr.detach(); - } - continue; + + // Set the status of the array and siblings. + auto status = async ? 
array::Status::scheduled : array::Status::available; + arr.set_status(status); + for (auto& s : arr.siblings()) { + s.set_status(status); } auto stream = arr.primitive().stream(); std::vector> arr_deps; - for (auto& in : arr.inputs()) { - if (auto it = deps.find(in.output(0).id()); it != deps.end()) { - arr_deps.push_back(it->second); - } - } - std::shared_ptr> p; - if (auto it = deps.find(arr.output(0).id()); it != deps.end()) { - p = std::make_shared>(); - ps.push_back(p); - it->second = p->get_future().share(); - } + bool signal = needs_signal.find(arr.id()) != needs_signal.end(); if (arr.primitive().device() == Device::gpu) { if (!metal::is_available()) { throw std::runtime_error("Metal GPU is not available."); } - scheduler::enqueue( - stream, metal::make_task(arr, std::move(arr_deps), std::move(p))); + scheduler::enqueue(stream, metal::make_task(std::move(arr), signal)); } else { - auto task = [arr, - stream, - deps = std::move(arr_deps), - p = std::move(p)]() mutable { - for (auto& d : deps) { - d.wait(); + auto task = [arr = std::move(arr), stream, signal]() mutable { + for (auto& input : arr.inputs()) { + if (input.event().valid() && + input.event().stream() != arr.primitive().stream()) { + input.event().wait(); + } } scheduler::notify_new_task(stream); auto outputs = arr.outputs(); @@ -144,20 +156,24 @@ std::shared_future async_eval(std::vector outputs) { if (!arr.is_tracer()) { arr.detach(); } - if (p) { - p->set_value(); + if (signal) { + arr.event().signal(); } + scheduler::notify_task_completion(stream); }; scheduler::enqueue(stream, std::move(task)); } } - global_synchronizer = std::move(deps[synchronizer.id()]); - return global_synchronizer; + return synchronizer; +} + +void async_eval(std::vector outputs) { + eval_impl(std::move(outputs), true); } void eval(std::vector outputs) { - async_eval(std::move(outputs)).wait(); + eval_impl(std::move(outputs), false).event().wait(); } std::pair, std::vector> vjp( diff --git a/mlx/transforms.h 
b/mlx/transforms.h index eb2c26780..d64f6060e 100644 --- a/mlx/transforms.h +++ b/mlx/transforms.h @@ -2,12 +2,11 @@ #pragma once -#include #include "mlx/array.h" namespace mlx::core { -std::shared_future async_eval(std::vector outputs); +void async_eval(std::vector outputs); void eval(std::vector outputs); diff --git a/mlx/utils.cpp b/mlx/utils.cpp index 07ef696f0..34882884b 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -264,9 +264,7 @@ std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k) { } std::ostream& operator<<(std::ostream& os, array a) { - if (!a.is_evaled()) { - a.eval(); - } + a.eval(); switch (a.dtype()) { case bool_: print_array(os, a); diff --git a/python/src/array.cpp b/python/src/array.cpp index 789b8c00f..b0f1e92b5 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -946,10 +946,7 @@ void init_array(nb::module_& m) { .def( "__repr__", [](array& a) { - if (!a.is_evaled()) { - nb::gil_scoped_release nogil; - a.eval(); - } + nb::gil_scoped_release nogil; std::ostringstream os; os << a; return os.str(); diff --git a/python/src/buffer.h b/python/src/buffer.h index 500236789..112fd7aaf 100644 --- a/python/src/buffer.h +++ b/python/src/buffer.h @@ -86,7 +86,7 @@ extern "C" inline int getbuffer(PyObject* obj, Py_buffer* view, int flags) { std::memset(view, 0, sizeof(Py_buffer)); auto a = nb::cast(nb::handle(obj)); - if (!a.is_evaled()) { + { nb::gil_scoped_release nogil; a.eval(); } diff --git a/python/src/convert.cpp b/python/src/convert.cpp index a27e6313e..cee16da80 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -104,8 +104,7 @@ template nb::ndarray mlx_to_nd_array( array a, std::optional t = {}) { - // Eval if not already evaled - if (!a.is_evaled()) { + { nb::gil_scoped_release nogil; a.eval(); } diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index 32df46e1e..9e89d9b45 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -595,14 +595,6 @@ class PyCheckpointedFun 
{ }; void init_transforms(nb::module_& m) { - nb::class_>( - m, - "Synchronizer", - R"pbdoc( - A synchronization object returned by :func:`async_eval`. - )pbdoc") - .def("wait", [](const std::shared_future& f) { f.wait(); }); - m.def( "eval", [](const nb::args& args) { @@ -629,19 +621,14 @@ void init_transforms(nb::module_& m) { std::vector arrays = tree_flatten(args, false); { nb::gil_scoped_release nogil; - return async_eval(arrays); + async_eval(arrays); } }, nb::arg(), - nb::sig("def async_eval(*args) -> Synchronizer"), + nb::sig("def async_eval(*args)"), R"pbdoc( Asynchronously evaluate an :class:`array` or tree of :class:`array`. - .. warning:: - - You must call ``wait`` on the returned synchronization object before - using any arrays that are asynchronously evaluated. - .. note:: This is an experimental API and may change in future versions. @@ -652,8 +639,17 @@ void init_transforms(nb::module_& m) { :class:`list`, :class:`tuple` or :class:`dict`. Leaves which are not arrays are ignored. - Returns: - Synchronizer: A synchronization object. 
+ Example: + >>> x = mx.array(1.0) + >>> y = mx.exp(x) + >>> mx.async_eval(y) + >>> print(y) + >>> + >>> y = mx.exp(x) + >>> mx.async_eval(y) + >>> z = y + 3 + >>> mx.async_eval(z) + >>> print(z) )pbdoc"); m.def( "jvp", diff --git a/python/tests/test_eval.py b/python/tests/test_eval.py index d4930ae81..7b972ac1c 100644 --- a/python/tests/test_eval.py +++ b/python/tests/test_eval.py @@ -34,16 +34,75 @@ class TestEval(mlx_tests.MLXTestCase): def test_async_eval(self): x = mx.array(1) + mx.array(1) + mx.array(1) - sync = mx.async_eval(x) - sync.wait() + mx.async_eval(x) self.assertEqual(x.item(), 3) # It should be safe to call eval on the array which has been async # eval'ed x = mx.array(1) + mx.array(1) + mx.array(1) - sync = mx.async_eval(x) self.assertEqual(x.item(), 3) + x = mx.array([1, 2, 3]) + y = 2 * x + mx.async_eval(y) + z = 2 * y + mx.async_eval(z) + self.assertTrue(mx.array_equal(y, mx.array([2, 4, 6]))) + self.assertTrue(mx.array_equal(z, mx.array([4, 8, 12]))) + + def test_async_eval_twice(self): + x = mx.array(1) + mx.array(1) + mx.array(1) + mx.async_eval(x) + y = x + 1 + mx.async_eval(y) + self.assertEqual(x.item(), 3) + + def test_async_eval_in_trace(self): + def fun(x): + y = x + 1.0 + mx.async_eval(y) + return mx.exp(y) + + # Raises + with self.assertRaises(ValueError): + mx.grad(fun)(mx.array(1.0)) + + # Also raises + with self.assertRaises(ValueError): + mx.vmap(fun)(mx.ones((2, 2))) + + def test_async_eval_into_eval(self): + x = mx.array(1) + y = x + 1 + mx.async_eval(y) + a = y - 10 + b = mx.abs(a) + self.assertEqual(b.item(), 8) + + def test_async_eval_into_eval_diff_stream(self): + s = mx.new_stream(mx.cpu) + x = mx.array(0) + y = x - 5 + mx.async_eval(y) + z = mx.abs(y, stream=s) + self.assertEqual(z.item(), 5) + + def test_eval_slow_fast_multi_stream(self): + x = mx.ones((8000,)) + y = mx.abs(mx.array(-1.0)) + for _ in range(20): + x = x + mx.array(1.0) + z = mx.add(x, y, stream=mx.cpu) + self.assertTrue(mx.allclose(z, mx.full((8000,), 
22.0))) + + # Switch eval order + x = mx.ones((8000,)) + y = mx.abs(mx.array(-1.0)) + for _ in range(20): + x = x + mx.array(1.0) + z = mx.add(y, x, stream=mx.cpu) + self.assertTrue(mx.allclose(z, mx.full((8000,), 22.0))) + if __name__ == "__main__": unittest.main() diff --git a/python/tests/test_metal.py b/python/tests/test_metal.py index 53b269772..51bceb38f 100644 --- a/python/tests/test_metal.py +++ b/python/tests/test_metal.py @@ -24,12 +24,12 @@ class TestMetal(mlx_tests.MLXTestCase): self.assertTrue(mx.metal.set_memory_limit(old_limit), old_limit) # Query active and peak memory - a = mx.zeros((4096,)) + a = mx.zeros((4096,), stream=mx.cpu) mx.eval(a) active_mem = mx.metal.get_active_memory() self.assertTrue(active_mem >= 4096 * 4) - b = mx.zeros((4096,)) + b = mx.zeros((4096,), stream=mx.cpu) mx.eval(b) del b diff --git a/tests/arg_reduce_tests.cpp b/tests/arg_reduce_tests.cpp index 7fa01d837..55b966f78 100644 --- a/tests/arg_reduce_tests.cpp +++ b/tests/arg_reduce_tests.cpp @@ -49,7 +49,6 @@ TEST_CASE("test arg reduce small") { {0, 2, 1, 7, 5, -5, 0, 2, 1, 7, 5, -5, 0, 2, 1, 7, 5, -5, 0, 2, 1, 7, 5, -5}, {2, 3, 4}); - x.eval(); test_arg_reduce_small( Device::cpu, x, ArgReduce::ArgMin, {2, 3}, 2, {0, 1, 3, 0, 1, 3}); test_arg_reduce_small( diff --git a/tests/autograd_tests.cpp b/tests/autograd_tests.cpp index 1357d40a9..0d01f5d10 100644 --- a/tests/autograd_tests.cpp +++ b/tests/autograd_tests.cpp @@ -100,7 +100,7 @@ TEST_CASE("test jvp") { auto fun1 = [](const array& x) { auto y = 3 * x; eval(y); - CHECK(y.is_evaled()); + CHECK(y.is_available()); CHECK(y.has_primitive()); CHECK(y.is_tracer()); return 2 * y; @@ -253,7 +253,7 @@ TEST_CASE("test grad") { eval(y); CHECK(x.is_tracer()); CHECK(!y.is_tracer()); - CHECK(y.is_evaled()); + CHECK(y.is_available()); CHECK(!y.has_primitive()); return square(x); }; @@ -265,7 +265,7 @@ TEST_CASE("test grad") { x = x + 2.0f; eval(x); CHECK(x.is_tracer()); - CHECK(x.is_evaled()); + CHECK(x.is_available()); 
CHECK(x.has_primitive()); return square(x); }; @@ -1259,7 +1259,7 @@ TEST_CASE("test update state") { grad(fn)(y); eval(state); CHECK(!state.has_primitive()); - CHECK(state.is_evaled()); + CHECK(state.is_available()); CHECK(array_equal(state, array({1.0, 1.0})).item()); } diff --git a/tests/eval_tests.cpp b/tests/eval_tests.cpp index 1c0ba857f..ce4bc980f 100644 --- a/tests/eval_tests.cpp +++ b/tests/eval_tests.cpp @@ -56,13 +56,13 @@ TEST_CASE("test eval with tracer when not tracing") { CHECK(!x.is_tracer()); eval(x); CHECK(!x.has_primitive()); - CHECK(x.is_evaled()); + CHECK(x.is_available()); x = ones({2, 3}); x.set_tracer(true); eval(x); CHECK(!x.has_primitive()); - CHECK(x.is_evaled()); + CHECK(x.is_available()); } TEST_CASE("test eval graph retention when not tracing") { @@ -74,20 +74,20 @@ TEST_CASE("test eval graph retention when not tracing") { auto z = x + y; eval(z); CHECK(!z.has_primitive()); - CHECK(z.is_evaled()); + CHECK(z.is_available()); CHECK_EQ(z.item(), 3); z.set_tracer(false); CHECK_EQ(z.item(), 3); CHECK(!z.has_primitive()); - CHECK(z.is_evaled()); + CHECK(z.is_available()); z = x + y; auto a = z + x; auto b = a + y; eval(b); CHECK(!z.has_primitive()); - CHECK(z.is_evaled()); + CHECK(z.is_available()); CHECK(!a.has_primitive()); - CHECK(a.is_evaled()); + CHECK(a.is_available()); } diff --git a/tests/metal_tests.cpp b/tests/metal_tests.cpp index c7a0c8c14..976317f2f 100644 --- a/tests/metal_tests.cpp +++ b/tests/metal_tests.cpp @@ -495,12 +495,14 @@ TEST_CASE("test metal memory info") { // Query active and peak memory { - auto a = zeros({4096}); + // Do these tests on the CPU since deallocation is synchronized + // with the main thread. + auto a = zeros({4096}, Device::cpu); eval(a); auto active_mem = metal::get_active_memory(); CHECK(active_mem >= 4096 * 4); { - auto b = zeros({4096}); + auto b = zeros({4096}, Device::cpu); eval(b); } auto new_active_mem = metal::get_active_memory();