Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-24 17:31:16 +08:00)
Shared events for synchronization + async eval (#998)

* more async eval
* fix rebase
* try correct async eval
* fix async
* more tests for async eval
* use shared events for synchronization
* comment + cleanup
* with autorelease pool
* fix no metal build
* fix compile
* fix patch
* don't eval if async eval'ed
* don't use is_evaled
* comments
* more multi stream tests
* try and cleanup use of is_evaled
* use a status flag
This commit is contained in:
parent b18468bf81
commit 8a0677d56d
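
This commit replaces the future-based async_eval with a void-returning version: every array carries a status flag (unscheduled / scheduled / available) plus a shared event, and readers wait on that event instead of a std::shared_future. A minimal usage sketch under those assumptions (illustrative only, not code from the commit; it uses only mlx/mlx.h, async_eval, eval, and item as they appear in the diff below):

    // Hedged sketch of the post-commit API.
    #include "mlx/mlx.h"

    using namespace mlx::core;

    int main() {
      auto x = array(1.0f);
      auto y = exp(x);

      // Schedules the computation and returns immediately; y becomes
      // "scheduled" and gets a shared event attached to it.
      async_eval({y});

      // Any read (eval, item, printing) waits on y's event first, then
      // marks it "available".
      float v = y.item<float>();
      return v > 0.0f ? 0 : 1;
    }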
@@ -82,8 +82,10 @@ elseif (MLX_BUILD_METAL)
   message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
 
   if (${MACOS_VERSION} GREATER_EQUAL 14.2)
+    set(METAL_CPP_PATCH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/metal.14.2.diff)
     set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14.2_iOS17.2.zip)
   elseif (${MACOS_VERSION} GREATER_EQUAL 14.0)
+    set(METAL_CPP_PATCH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/metal.14.0.diff)
     set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14_iOS17-beta.zip)
   else()
     message(FATAL_ERROR "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON" )
@@ -92,6 +94,7 @@ elseif (MLX_BUILD_METAL)
   FetchContent_Declare(
     metal_cpp
     URL ${METAL_CPP_URL}
+    PATCH_COMMAND patch -N -i ${METAL_CPP_PATCH} || true
   )
   FetchContent_MakeAvailable(metal_cpp)

cmake/metal.14.0.diff (new file, 36 lines):

diff -ur Metal/MTLEvent.hpp MetalNew/MTLEvent.hpp
--- Metal/MTLEvent.hpp	2023-06-01 12:18:26
+++ MetalNew/MTLEvent.hpp	2024-04-15 07:36:59
@@ -62,6 +62,7 @@
 
 uint64_t signaledValue() const;
 void setSignaledValue(uint64_t signaledValue);
+bool waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS);
 };
 
 class SharedEventHandle : public NS::SecureCoding<SharedEventHandle>
@@ -138,6 +139,11 @@
 _MTL_INLINE void MTL::SharedEvent::setSignaledValue(uint64_t signaledValue)
 {
     Object::sendMessage<void>(this, _MTL_PRIVATE_SEL(setSignaledValue_), signaledValue);
+}
+
+// method: waitUntilSignaledValue
+_MTL_INLINE bool MTL::SharedEvent::waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS) {
+    return Object::sendMessage<bool>(this, _MTL_PRIVATE_SEL(waitUntilSignaledValue_timeoutMS_), signaledValue, timeoutMS);
 }
 
 // static method: alloc
diff -ur Metal/MTLHeaderBridge.hpp MetalNew/MTLHeaderBridge.hpp
--- Metal/MTLHeaderBridge.hpp	2023-06-01 12:18:26
+++ MetalNew/MTLHeaderBridge.hpp	2024-04-15 07:37:29
@@ -1906,6 +1906,9 @@
 "setShouldMaximizeConcurrentCompilation:");
 _MTL_PRIVATE_DEF_SEL(setSignaledValue_,
 "setSignaledValue:");
+_MTL_PRIVATE_DEF_SEL(
+    waitUntilSignaledValue_timeoutMS_,
+    "waitUntilSignaledValue:timeoutMS:");
 _MTL_PRIVATE_DEF_SEL(setSize_,
 "setSize:");
 _MTL_PRIVATE_DEF_SEL(setSlice_,

cmake/metal.14.2.diff (new file, 36 lines):

diff -ur Metal/MTLEvent.hpp MetalNew/MTLEvent.hpp
--- Metal/MTLEvent.hpp	2024-04-15 07:12:10
+++ MetalNew/MTLEvent.hpp	2024-04-15 07:15:50
@@ -62,6 +62,7 @@
 
 uint64_t signaledValue() const;
 void setSignaledValue(uint64_t signaledValue);
+bool waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS);
 };
 
 class SharedEventHandle : public NS::SecureCoding<SharedEventHandle>
@@ -138,6 +139,11 @@
 _MTL_INLINE void MTL::SharedEvent::setSignaledValue(uint64_t signaledValue)
 {
     Object::sendMessage<void>(this, _MTL_PRIVATE_SEL(setSignaledValue_), signaledValue);
+}
+
+// method: waitUntilSignaledValue
+_MTL_INLINE bool MTL::SharedEvent::waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS) {
+    return Object::sendMessage<bool>(this, _MTL_PRIVATE_SEL(waitUntilSignaledValue_timeoutMS_), signaledValue, timeoutMS);
 }
 
 // static method: alloc
diff -ur Metal/MTLHeaderBridge.hpp MetalNew/MTLHeaderBridge.hpp
--- Metal/MTLHeaderBridge.hpp	2024-04-15 07:12:10
+++ MetalNew/MTLHeaderBridge.hpp	2024-04-15 07:16:15
@@ -1918,6 +1918,9 @@
 "setShouldMaximizeConcurrentCompilation:");
 _MTL_PRIVATE_DEF_SEL(setSignaledValue_,
 "setSignaledValue:");
+_MTL_PRIVATE_DEF_SEL(
+    waitUntilSignaledValue_timeoutMS_,
+    "waitUntilSignaledValue:timeoutMS:");
 _MTL_PRIVATE_DEF_SEL(setSize_,
 "setSize:");
 _MTL_PRIVATE_DEF_SEL(setSlice_,

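Both patch files add the same metal-cpp wrapper for the Objective-C selector waitUntilSignaledValue:timeoutMS: on MTLSharedEvent; only the timestamps and the header-bridge line numbers differ. A hedged sketch of how the patched call is used (illustrative only; MTL::CreateSystemDefaultDevice, newSharedEvent, and setSignaledValue are standard metal-cpp, while waitUntilSignaledValue is what the patch introduces and is how mlx/backend/metal/event.cpp below blocks on an event):

    // Hedged sketch, not part of the commit.
    // Note: metal-cpp requires the *_PRIVATE_IMPLEMENTATION defines in one
    // translation unit in order to link.
    #include <Metal/Metal.hpp>

    int main() {
      MTL::Device* device = MTL::CreateSystemDefaultDevice();
      MTL::SharedEvent* event = device->newSharedEvent();

      // A command buffer (or another thread) eventually raises the counter.
      event->setSignaledValue(1);

      // The patched wrapper blocks until the counter reaches the given value,
      // with a timeout in milliseconds; mlx passes -1 (i.e. UINT64_MAX) to
      // wait indefinitely.
      bool ok = event->waitUntilSignaledValue(1, /* timeoutMS */ 1000);

      event->release();
      device->release();
      return ok ? 0 : 1;
    }
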
@@ -93,7 +93,11 @@ void array::detach() {
 }
 
 void array::eval() {
-  if (!is_evaled()) {
+  // Ensure the array is ready to be read
+  if (status() == Status::scheduled) {
+    event().wait();
+    set_status(Status::available);
+  } else if (status() == Status::unscheduled) {
     mlx::core::eval({*this});
   }
 }
@@ -176,7 +180,7 @@ void array::ArrayDesc::init() {
 }
 
 array::ArrayDesc::ArrayDesc(std::vector<int> shape, Dtype dtype)
-    : shape(std::move(shape)), dtype(dtype) {
+    : shape(std::move(shape)), dtype(dtype), status(Status::available) {
   init();
 }
 
@@ -187,6 +191,7 @@ array::ArrayDesc::ArrayDesc(
     std::vector<array> inputs)
     : shape(std::move(shape)),
       dtype(dtype),
+      status(Status::unscheduled),
       primitive(std::move(primitive)),
       inputs(std::move(inputs)) {
   init();

mlx/array.h (33 changed lines):

@@ -9,6 +9,7 @@
 
 #include "mlx/allocator.h"
 #include "mlx/dtype.h"
+#include "mlx/event.h"
 
 namespace mlx::core {
 
@@ -315,9 +316,27 @@ class array {
     return static_cast<T*>(array_desc_->data_ptr);
   };
 
-  // Check if the array has been evaluated
-  bool is_evaled() const {
-    return array_desc_->data != nullptr;
+  enum Status { unscheduled, scheduled, available };
+
+  bool is_available() const {
+    return status() == Status::available;
+  }
+  const Status status() const {
+    return array_desc_->status;
+  }
+
+  void set_status(Status s) const {
+    array_desc_->status = s;
+  }
+
+  // Get the array's shared event
+  Event& event() const {
+    return array_desc_->event;
+  }
+
+  // Attach an event to a not yet evaluated array
+  void attach_event(Event e) const {
+    array_desc_->event = std::move(e);
   }
 
   // Mark the array as a tracer array (true) or not.
@@ -370,6 +389,11 @@ class array {
     Dtype dtype;
     std::shared_ptr<Primitive> primitive;
 
+    Status status;
+
+    // An event on the array used for synchronization
+    Event event;
+
     // Indicates an array is being used in a graph transform
     // and should not be detached from the graph
     bool is_tracer{false};
@@ -470,10 +494,11 @@ T array::item() const {
   if (size() != 1) {
     throw std::invalid_argument("item can only be called on arrays of size 1.");
   }
-  if (!is_evaled()) {
+  if (status() == Status::unscheduled) {
     throw std::invalid_argument(
         "item() const can only be called on evaled arrays");
   }
+  const_cast<array*>(this)->eval();
   return *data<T>();
 }
 

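The new Status flag replaces the old is_evaled() data-pointer check. A hedged sketch of the lifecycle the accessors above imply, using only names that appear in these hunks (illustrative only, not code from the commit):

    #include <cassert>
    #include "mlx/mlx.h"

    using namespace mlx::core;

    int main() {
      // A lazy node built from a primitive: its ArrayDesc starts unscheduled.
      auto a = add(array(1), array(2));
      assert(a.status() == array::Status::unscheduled);

      // async_eval queues the work, attaches an Event, and flips the status.
      async_eval({a});
      assert(a.status() == array::Status::scheduled);

      // eval() sees "scheduled", waits on the event, then marks it available.
      a.eval();
      assert(a.is_available());
      assert(a.item<int>() == 3);
      return 0;
    }
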
@@ -26,6 +26,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp

@@ -544,11 +544,12 @@ Device& device(mlx::core::Device) {
   return metal_device;
 }
 
-std::shared_ptr<void> new_scoped_memory_pool() {
+std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool() {
   auto dtor = [](void* ptr) {
     static_cast<NS::AutoreleasePool*>(ptr)->release();
   };
-  return std::shared_ptr<void>(NS::AutoreleasePool::alloc()->init(), dtor);
+  return std::unique_ptr<void, std::function<void(void*)>>(
+      NS::AutoreleasePool::alloc()->init(), dtor);
 }
 
 void new_stream(Stream stream) {

mlx/backend/metal/event.cpp (new file, 30 lines):

// Copyright © 2024 Apple Inc.

#include "mlx/event.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/metal_impl.h"

namespace mlx::core {

Event::Event(const Stream& stream) : stream_(stream) {
  auto dtor = [](void* ptr) {
    auto p = metal::new_scoped_memory_pool();
    static_cast<MTL::SharedEvent*>(ptr)->release();
  };
  auto p = metal::new_scoped_memory_pool();
  event_ = std::shared_ptr<void>(
      metal::device(stream.device).mtl_device()->newSharedEvent(), dtor);
}

void Event::wait() {
  if (!static_cast<MTL::SharedEvent*>(raw_event().get())
           ->waitUntilSignaledValue(value(), -1)) {
    throw std::runtime_error("[Event::wait] Timed out");
  }
}

void Event::signal() {
  static_cast<MTL::SharedEvent*>(raw_event().get())->setSignaledValue(value());
}

} // namespace mlx::core

@@ -55,17 +55,20 @@ inline void check_error(MTL::CommandBuffer* cbuf) {
   }
 }
 
-std::function<void()> make_task(
-    array& arr,
-    std::vector<std::shared_future<void>> deps,
-    std::shared_ptr<std::promise<void>> p) {
-  auto task = [arr, deps = std::move(deps), p = std::move(p)]() mutable {
+std::function<void()> make_task(array arr, bool signal) {
+  auto task = [arr = std::move(arr), signal]() mutable {
     auto pool = new_scoped_memory_pool();
-    for (auto& d : deps) {
-      d.wait();
-    }
     auto s = arr.primitive().stream();
     auto command_buffer = increment_command_buffer(s);
+    for (auto& input : arr.inputs()) {
+      if (input.event().valid() &&
+          input.event().stream() != arr.primitive().stream()) {
+        // TODO, consider committing the buffer and encoding a wait in the new
+        // buffer rather than on the task thread
+        input.event().wait();
+      }
+    }
+
     auto outputs = arr.outputs();
     {
       // If the array is a tracer hold a reference
@@ -88,13 +91,16 @@ std::function<void()> make_task(
     if (!arr.is_tracer()) {
       arr.detach();
     }
-    if (p) {
+
+    if (signal) {
       metal::device(s.device).end_encoding(s.index);
+      command_buffer->encodeSignalEvent(
+          static_cast<MTL::Event*>(arr.event().raw_event().get()),
+          arr.event().value());
       scheduler::notify_new_task(s);
       command_buffer->addCompletedHandler(
-          [s, buffers = std::move(buffers), p = std::move(p)](
+          [s, buffers = std::move(buffers), event = arr.event()](
               MTL::CommandBuffer* cbuf) {
-            p->set_value();
             scheduler::notify_task_completion(s);
             check_error(cbuf);
           });

@@ -2,9 +2,7 @@
 
 #pragma once
 
-#include <future>
 #include <memory>
-#include <vector>
 
 #include "mlx/array.h"
 #include "mlx/stream.h"
@@ -12,11 +10,9 @@
 namespace mlx::core::metal {
 
 void new_stream(Stream stream);
-std::shared_ptr<void> new_scoped_memory_pool();
 
-std::function<void()> make_task(
-    array& arr,
-    std::vector<std::shared_future<void>> deps,
-    std::shared_ptr<std::promise<void>> p);
+std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool();
+
+std::function<void()> make_task(array arr, bool signal);
 
 } // namespace mlx::core::metal

@@ -2,6 +2,7 @@ target_sources(
   mlx
   PRIVATE
   ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
 )

mlx/backend/no_metal/event.cpp (new file, 39 lines):

// Copyright © 2024 Apple Inc.

#include "mlx/event.h"

#include <condition_variable>
#include <mutex>

namespace mlx::core {

struct EventCounter {
  uint64_t value{0};
  std::mutex mtx;
  std::condition_variable cv;
};

Event::Event(const Stream& stream) : stream_(stream) {
  auto dtor = [](void* ptr) { delete static_cast<EventCounter*>(ptr); };
  event_ = std::shared_ptr<void>(new EventCounter{}, dtor);
}

void Event::wait() {
  auto ec = static_cast<EventCounter*>(raw_event().get());
  std::unique_lock<std::mutex> lk(ec->mtx);
  if (ec->value >= value()) {
    return;
  }
  ec->cv.wait(lk, [value = value(), ec] { return ec->value >= value; });
}

void Event::signal() {
  auto ec = static_cast<EventCounter*>(raw_event().get());
  {
    std::lock_guard<std::mutex> lk(ec->mtx);
    ec->value = value();
  }
  ec->cv.notify_all();
}

} // namespace mlx::core

@@ -12,14 +12,12 @@ bool is_available() {
 }
 
 void new_stream(Stream) {}
-std::shared_ptr<void> new_scoped_memory_pool() {
+
+std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool() {
   return nullptr;
 }
 
-std::function<void()> make_task(
-    array& arr,
-    std::vector<std::shared_future<void>> deps,
-    std::shared_ptr<std::promise<void>> p) {
+std::function<void()> make_task(array arr, bool signal) {
   throw std::runtime_error(
       "[metal::make_task] Cannot make GPU task without metal backend");
 }

@@ -352,7 +352,8 @@ void compile_simplify(
   // Helpers to identify identical scalars
   std::map<std::pair<uint64_t, Dtype::Val>, array> scalars;
   auto is_scalar = [](const array& a) {
-    return a.is_evaled() && a.ndim() == 0;
+    // Condition for when it's safe to read an array
+    return a.is_available() && a.ndim() == 0;
   };
   auto get_scalar_rep = [](const array& a) {
     uint64_t v = 0;

mlx/event.h (new file, 56 lines):

// Copyright © 2024 Apple Inc.
#pragma once

#include <memory>
#include <stdexcept>

#include "mlx/stream.h"

namespace mlx::core {

class Event {
 public:
  Event(){};

  Event(const Stream& steam);

  // Wait for the event to be signaled at its curent value
  void wait();

  // Signal the event at its current value
  void signal();

  // Check if the event is valid
  bool valid() {
    return event_ != nullptr;
  };

  uint64_t value() {
    return value_;
  };

  void set_value(uint64_t v) {
    value_ = v;
  };

  const Stream& stream() {
    if (!valid()) {
      throw std::runtime_error(
          "[Event::stream] Cannot access stream on invalid event.");
    }
    return stream_;
  };

  const std::shared_ptr<void>& raw_event() {
    return event_;
  };

 private:
  // Default constructed stream should never be used
  // since the event is not yet valid
  Stream stream_{0, Device::cpu};
  std::shared_ptr<void> event_{nullptr};
  uint64_t value_{0};
};

} // namespace mlx::core

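Event is a thin, type-erased handle: the payload behind raw_event() is an MTL::SharedEvent on Metal builds (mlx/backend/metal/event.cpp above) and the condition-variable counter otherwise (mlx/backend/no_metal/event.cpp above). A hedged usage sketch exercising only the interface declared here (illustrative, not code from the commit):

    #include <thread>
    #include "mlx/event.h"

    using namespace mlx::core;

    int main() {
      Event e{Stream{0, Device::cpu}};  // a valid event bound to a stream
      e.set_value(1);                   // the counter value this wait targets

      // Copies share the underlying counter, since event_ is a shared_ptr.
      std::thread producer([e]() mutable { e.signal(); });

      e.wait();  // returns once the counter reaches value()
      producer.join();
      return e.valid() ? 0 : 1;
    }
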
@@ -36,29 +36,32 @@ class Synchronizer : public Primitive {
 // are currently under a function transformation.
 int detail::InTracing::tracing_counter{0};
 
-std::shared_future<void> async_eval(std::vector<array> outputs) {
-  static std::shared_future<void> global_synchronizer;
-  // Catch up with previous async eval if needed
-  if (global_synchronizer.valid()) {
-    global_synchronizer.wait();
-  }
+array eval_impl(std::vector<array> outputs, bool async) {
   std::queue<array> tape;
-  std::unordered_set<std::uintptr_t> cache;
-  std::unordered_map<std::uintptr_t, std::shared_future<void>> deps;
+
+  // stream events to use for synchronization
+  std::unordered_map<uint32_t, Event> events;
 
   // Make an effort to choose a good output stream
   Stream stream = default_stream(default_device());
   for (auto& o : outputs) {
-    if (!o.is_evaled() && o.has_primitive()) {
+    if (o.status() == array::Status::unscheduled && o.has_primitive()) {
       stream = o.primitive().stream();
       break;
     }
   }
 
+  std::unordered_set<uintptr_t> needs_signal;
+
   auto synchronizer = array(
       {}, bool_, std::make_shared<Synchronizer>(stream), std::move(outputs));
+  needs_signal.insert(synchronizer.id());
+
+  // Make an event for the synchronizer stream
+  events.emplace(stream.index, Event{stream});
 
   {
+    std::unordered_set<std::uintptr_t> cache;
     std::stack<std::pair<std::reference_wrapper<array>, int>> dfs;
     dfs.emplace(synchronizer, 0);
     while (!dfs.empty()) {
@@ -67,16 +70,23 @@ std::shared_future<void> async_eval(std::vector<array> outputs) {
       if (idx < a.inputs().size()) {
         // Add an input, and continue
         auto& in = a.inputs()[idx++];
-        if (!in.is_evaled()) {
+
+        // Ignore arrays already scheduled
+        if (in.status() == array::Status::scheduled) {
+          continue;
+        }
+
+        if (!in.is_available()) {
+          if (async && in.is_tracer()) {
+            throw std::invalid_argument(
+                "[async_eval] Not allowed inside a graph transformation.");
+          }
           if (!in.has_primitive()) {
             throw std::invalid_argument(
                 "[eval] Attempting to eval an array without a primitive.");
           }
-
-          // If the input is being computed on a different stream, we need to
-          // manage the dependency.
           if (a.primitive().stream() != in.primitive().stream()) {
-            deps.insert({in.output(0).id(), std::shared_future<void>{}});
+            needs_signal.insert(in.id());
           }
         }
 
@@ -91,52 +101,54 @@ std::shared_future<void> async_eval(std::vector<array> outputs) {
       }
 
       // All inputs are done being processed, process this array
-      if (!a.is_evaled() || (!a.is_tracer() && a.has_primitive())) {
+      if (a.is_available() && !a.is_tracer() && a.has_primitive()) {
+        // If the array is evaluated and is no longer a tracer, detach it
+        a.detach();
+      } else if (a.status() == array::Status::unscheduled) {
         tape.push(a);
+        // Lookup corresponding event and increment counter
+        auto& stream = a.primitive().stream();
+        auto e = events.find(stream.index);
+        if (e == events.end()) {
+          e = events.emplace(stream.index, Event{stream}).first;
+        }
+        e->second.set_value(e->second.value() + 1);
+        a.attach_event(e->second);
+        for (auto& s : a.siblings()) {
+          s.attach_event(e->second);
+        }
       }
       dfs.pop();
     }
   }
-  deps.insert({synchronizer.id(), std::shared_future<void>{}});
 
-  std::vector<std::shared_ptr<std::promise<void>>> ps;
   while (!tape.empty()) {
     auto arr = std::move(tape.front());
     tape.pop();
-    if (arr.is_evaled()) {
-      if (!arr.is_tracer() && arr.has_primitive()) {
-        arr.detach();
-      }
-      continue;
+
+    // Set the status of the array and siblings.
+    auto status = async ? array::Status::scheduled : array::Status::available;
+    arr.set_status(status);
+    for (auto& s : arr.siblings()) {
+      s.set_status(status);
     }
 
     auto stream = arr.primitive().stream();
     std::vector<std::shared_future<void>> arr_deps;
-    for (auto& in : arr.inputs()) {
-      if (auto it = deps.find(in.output(0).id()); it != deps.end()) {
-        arr_deps.push_back(it->second);
-      }
-    }
-    std::shared_ptr<std::promise<void>> p;
-    if (auto it = deps.find(arr.output(0).id()); it != deps.end()) {
-      p = std::make_shared<std::promise<void>>();
-      ps.push_back(p);
-      it->second = p->get_future().share();
-    }
+    bool signal = needs_signal.find(arr.id()) != needs_signal.end();
 
     if (arr.primitive().device() == Device::gpu) {
      if (!metal::is_available()) {
        throw std::runtime_error("Metal GPU is not available.");
      }
-      scheduler::enqueue(
-          stream, metal::make_task(arr, std::move(arr_deps), std::move(p)));
+      scheduler::enqueue(stream, metal::make_task(std::move(arr), signal));
    } else {
-      auto task = [arr,
-                   stream,
-                   deps = std::move(arr_deps),
-                   p = std::move(p)]() mutable {
-        for (auto& d : deps) {
-          d.wait();
-        }
+      auto task = [arr = std::move(arr), stream, signal]() mutable {
+        for (auto& input : arr.inputs()) {
+          if (input.event().valid() &&
+              input.event().stream() != arr.primitive().stream()) {
+            input.event().wait();
+          }
+        }
         scheduler::notify_new_task(stream);
         auto outputs = arr.outputs();
@@ -144,20 +156,24 @@ std::shared_future<void> async_eval(std::vector<array> outputs) {
         if (!arr.is_tracer()) {
           arr.detach();
         }
-        if (p) {
-          p->set_value();
+        if (signal) {
+          arr.event().signal();
         }
+
         scheduler::notify_task_completion(stream);
       };
       scheduler::enqueue(stream, std::move(task));
     }
   }
-  global_synchronizer = std::move(deps[synchronizer.id()]);
-  return global_synchronizer;
+  return synchronizer;
+}
+
+void async_eval(std::vector<array> outputs) {
+  eval_impl(std::move(outputs), true);
 }
 
 void eval(std::vector<array> outputs) {
-  async_eval(std::move(outputs)).wait();
+  eval_impl(std::move(outputs), false).event().wait();
 }
 
 std::pair<std::vector<array>, std::vector<array>> vjp(

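Both entry points now funnel through eval_impl: async_eval(outputs) schedules the tape and returns immediately, while eval(outputs) waits on the synchronizer's event. A hedged sketch of the resulting contract (illustrative only; it mirrors the Python test test_async_eval_into_eval added further down):

    #include "mlx/mlx.h"

    using namespace mlx::core;

    int main() {
      auto x = array(1);
      auto y = x + 1;
      async_eval({y});   // eval_impl(..., /*async=*/true): schedule, don't block

      auto a = y - 10;   // builds on a scheduled (not yet available) input
      auto b = abs(a);
      eval({b});         // eval_impl(..., false).event().wait(): blocks until done
      return b.item<int>() == 8 ? 0 : 1;
    }
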
@@ -2,12 +2,11 @@
 
 #pragma once
 
-#include <future>
 #include "mlx/array.h"
 
 namespace mlx::core {
 
-std::shared_future<void> async_eval(std::vector<array> outputs);
+void async_eval(std::vector<array> outputs);
 
 void eval(std::vector<array> outputs);
 

|
|||||||
}
|
}
|
||||||
|
|
||||||
std::ostream& operator<<(std::ostream& os, array a) {
|
std::ostream& operator<<(std::ostream& os, array a) {
|
||||||
if (!a.is_evaled()) {
|
a.eval();
|
||||||
a.eval();
|
|
||||||
}
|
|
||||||
switch (a.dtype()) {
|
switch (a.dtype()) {
|
||||||
case bool_:
|
case bool_:
|
||||||
print_array<bool>(os, a);
|
print_array<bool>(os, a);
|
||||||
|
@@ -946,10 +946,7 @@ void init_array(nb::module_& m) {
       .def(
           "__repr__",
           [](array& a) {
-            if (!a.is_evaled()) {
-              nb::gil_scoped_release nogil;
-              a.eval();
-            }
+            nb::gil_scoped_release nogil;
             std::ostringstream os;
             os << a;
             return os.str();

@@ -86,7 +86,7 @@ extern "C" inline int getbuffer(PyObject* obj, Py_buffer* view, int flags) {
   std::memset(view, 0, sizeof(Py_buffer));
   auto a = nb::cast<array>(nb::handle(obj));
 
-  if (!a.is_evaled()) {
+  {
     nb::gil_scoped_release nogil;
     a.eval();
   }
@@ -104,8 +104,7 @@ template <typename Lib, typename T>
 nb::ndarray<Lib> mlx_to_nd_array(
     array a,
     std::optional<nb::dlpack::dtype> t = {}) {
-  // Eval if not already evaled
-  if (!a.is_evaled()) {
+  {
     nb::gil_scoped_release nogil;
     a.eval();
   }

@@ -595,14 +595,6 @@ class PyCheckpointedFun {
 };
 
 void init_transforms(nb::module_& m) {
-  nb::class_<std::shared_future<void>>(
-      m,
-      "Synchronizer",
-      R"pbdoc(
-      A synchronization object returned by :func:`async_eval`.
-      )pbdoc")
-      .def("wait", [](const std::shared_future<void>& f) { f.wait(); });
-
   m.def(
       "eval",
       [](const nb::args& args) {
@@ -629,19 +621,14 @@ void init_transforms(nb::module_& m) {
         std::vector<array> arrays = tree_flatten(args, false);
         {
           nb::gil_scoped_release nogil;
-          return async_eval(arrays);
+          async_eval(arrays);
         }
       },
       nb::arg(),
-      nb::sig("def async_eval(*args) -> Synchronizer"),
+      nb::sig("def async_eval(*args)"),
       R"pbdoc(
         Asynchronously evaluate an :class:`array` or tree of :class:`array`.
 
-        .. warning::
-
-          You must call ``wait`` on the returned synchronization object before
-          using any arrays that are asynchronously evaluated.
-
         .. note::
 
           This is an experimental API and may change in future versions.
@@ -652,8 +639,17 @@ void init_transforms(nb::module_& m) {
             :class:`list`, :class:`tuple` or :class:`dict`. Leaves which are not
             arrays are ignored.
 
-        Returns:
-            Synchronizer: A synchronization object.
+        Example:
+          >>> x = mx.array(1.0)
+          >>> y = mx.exp(x)
+          >>> mx.async_eval(y)
+          >>> print(y)
+          >>>
+          >>> y = mx.exp(x)
+          >>> mx.async_eval(y)
+          >>> z = y + 3
+          >>> mx.async_eval(z)
+          >>> print(z)
       )pbdoc");
   m.def(
       "jvp",

@@ -34,16 +34,75 @@ class TestEval(mlx_tests.MLXTestCase):
 
     def test_async_eval(self):
         x = mx.array(1) + mx.array(1) + mx.array(1)
-        sync = mx.async_eval(x)
-        sync.wait()
+        mx.async_eval(x)
         self.assertEqual(x.item(), 3)
 
         # It should be safe to call eval on the array which has been async
         # eval'ed
        x = mx.array(1) + mx.array(1) + mx.array(1)
-        sync = mx.async_eval(x)
         self.assertEqual(x.item(), 3)
+
+        x = mx.array([1, 2, 3])
+        y = 2 * x
+        mx.async_eval(y)
+        z = 2 * y
+        mx.async_eval(z)
+        self.assertTrue(mx.array_equal(y, mx.array([2, 4, 6])))
+        self.assertTrue(mx.array_equal(z, mx.array([4, 8, 12])))
+
+    def test_async_eval_twice(self):
+        x = mx.array(1) + mx.array(1) + mx.array(1)
+        mx.async_eval(x)
+        y = x + 1
+        mx.async_eval(y)
+        self.assertEqual(x.item(), 3)
+
+    def test_async_eval_in_trace(self):
+        def fun(x):
+            y = x + 1.0
+            mx.async_eval(y)
+            return mx.exp(y)
+
+        # Raises
+        with self.assertRaises(ValueError):
+            mx.grad(fun)(mx.array(1.0))
+
+        # Also raises
+        with self.assertRaises(ValueError):
+            mx.vmap(fun)(mx.ones((2, 2)))
+
+    def test_async_eval_into_eval(self):
+        x = mx.array(1)
+        y = x + 1
+        mx.async_eval(y)
+        a = y - 10
+        b = mx.abs(a)
+        self.assertEqual(b.item(), 8)
+
+    def test_async_eval_into_eval_diff_stream(self):
+        s = mx.new_stream(mx.cpu)
+        x = mx.array(0)
+        y = x - 5
+        mx.async_eval(y)
+        z = mx.abs(y, stream=s)
+        self.assertEqual(z.item(), 5)
+
+    def test_eval_slow_fast_multi_stream(self):
+        x = mx.ones((8000,))
+        y = mx.abs(mx.array(-1.0))
+        for _ in range(20):
+            x = x + mx.array(1.0)
+        z = mx.add(x, y, stream=mx.cpu)
+        self.assertTrue(mx.allclose(z, mx.full((8000,), 22.0)))
+
+        # Switch eval order
+        x = mx.ones((8000,))
+        y = mx.abs(mx.array(-1.0))
+        for _ in range(20):
+            x = x + mx.array(1.0)
+        z = mx.add(y, x, stream=mx.cpu)
+        self.assertTrue(mx.allclose(z, mx.full((8000,), 22.0)))
 
 
 if __name__ == "__main__":
     unittest.main()

@@ -24,12 +24,12 @@ class TestMetal(mlx_tests.MLXTestCase):
         self.assertTrue(mx.metal.set_memory_limit(old_limit), old_limit)
 
         # Query active and peak memory
-        a = mx.zeros((4096,))
+        a = mx.zeros((4096,), stream=mx.cpu)
         mx.eval(a)
         active_mem = mx.metal.get_active_memory()
         self.assertTrue(active_mem >= 4096 * 4)
 
-        b = mx.zeros((4096,))
+        b = mx.zeros((4096,), stream=mx.cpu)
         mx.eval(b)
         del b
 

@@ -49,7 +49,6 @@ TEST_CASE("test arg reduce small") {
       {0, 2, 1, 7, 5, -5, 0, 2, 1, 7, 5, -5,
        0, 2, 1, 7, 5, -5, 0, 2, 1, 7, 5, -5},
       {2, 3, 4});
-  x.eval();
   test_arg_reduce_small(
       Device::cpu, x, ArgReduce::ArgMin, {2, 3}, 2, {0, 1, 3, 0, 1, 3});
   test_arg_reduce_small(
@@ -100,7 +100,7 @@ TEST_CASE("test jvp") {
   auto fun1 = [](const array& x) {
     auto y = 3 * x;
     eval(y);
-    CHECK(y.is_evaled());
+    CHECK(y.is_available());
     CHECK(y.has_primitive());
     CHECK(y.is_tracer());
     return 2 * y;
@@ -253,7 +253,7 @@ TEST_CASE("test grad") {
     eval(y);
     CHECK(x.is_tracer());
     CHECK(!y.is_tracer());
-    CHECK(y.is_evaled());
+    CHECK(y.is_available());
     CHECK(!y.has_primitive());
     return square(x);
   };
@@ -265,7 +265,7 @@ TEST_CASE("test grad") {
     x = x + 2.0f;
     eval(x);
     CHECK(x.is_tracer());
-    CHECK(x.is_evaled());
+    CHECK(x.is_available());
     CHECK(x.has_primitive());
     return square(x);
   };
@@ -1259,7 +1259,7 @@ TEST_CASE("test update state") {
   grad(fn)(y);
   eval(state);
   CHECK(!state.has_primitive());
-  CHECK(state.is_evaled());
+  CHECK(state.is_available());
   CHECK(array_equal(state, array({1.0, 1.0})).item<bool>());
 }
 

@@ -56,13 +56,13 @@ TEST_CASE("test eval with tracer when not tracing") {
   CHECK(!x.is_tracer());
   eval(x);
   CHECK(!x.has_primitive());
-  CHECK(x.is_evaled());
+  CHECK(x.is_available());
 
   x = ones({2, 3});
   x.set_tracer(true);
   eval(x);
   CHECK(!x.has_primitive());
-  CHECK(x.is_evaled());
+  CHECK(x.is_available());
 }
 
 TEST_CASE("test eval graph retention when not tracing") {
@@ -74,20 +74,20 @@ TEST_CASE("test eval graph retention when not tracing") {
   auto z = x + y;
   eval(z);
   CHECK(!z.has_primitive());
-  CHECK(z.is_evaled());
+  CHECK(z.is_available());
   CHECK_EQ(z.item<int>(), 3);
 
   z.set_tracer(false);
   CHECK_EQ(z.item<int>(), 3);
   CHECK(!z.has_primitive());
-  CHECK(z.is_evaled());
+  CHECK(z.is_available());
 
   z = x + y;
   auto a = z + x;
   auto b = a + y;
   eval(b);
   CHECK(!z.has_primitive());
-  CHECK(z.is_evaled());
+  CHECK(z.is_available());
   CHECK(!a.has_primitive());
-  CHECK(a.is_evaled());
+  CHECK(a.is_available());
 }

@@ -495,12 +495,14 @@ TEST_CASE("test metal memory info") {
 
   // Query active and peak memory
   {
-    auto a = zeros({4096});
+    // Do these tests on the CPU since deallocation is synchronized
+    // with the main thread.
+    auto a = zeros({4096}, Device::cpu);
     eval(a);
     auto active_mem = metal::get_active_memory();
     CHECK(active_mem >= 4096 * 4);
     {
-      auto b = zeros({4096});
+      auto b = zeros({4096}, Device::cpu);
       eval(b);
     }
     auto new_active_mem = metal::get_active_memory();
