mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-23 18:11:17 +08:00
Fix array is_available race cases (#1468)
This commit is contained in:
parent
9b12093739
commit
3274c6a087
@ -14,7 +14,7 @@ option(MLX_BUILD_TESTS "Build tests for mlx" ON)
|
|||||||
option(MLX_BUILD_EXAMPLES "Build examples for mlx" ON)
|
option(MLX_BUILD_EXAMPLES "Build examples for mlx" ON)
|
||||||
option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
|
option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
|
||||||
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
|
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
|
||||||
option(MLX_BUILD_METAL "Build metal backend" OFF)
|
option(MLX_BUILD_METAL "Build metal backend" ON)
|
||||||
option(MLX_BUILD_CPU "Build cpu backend" ON)
|
option(MLX_BUILD_CPU "Build cpu backend" ON)
|
||||||
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
|
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
|
||||||
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
|
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
|
||||||
@ -47,13 +47,13 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
|||||||
"https://ml-explore.github.io/mlx/build/html/install.html#build-from-source"
|
"https://ml-explore.github.io/mlx/build/html/install.html#build-from-source"
|
||||||
)
|
)
|
||||||
else()
|
else()
|
||||||
|
set(MLX_BUILD_METAL OFF)
|
||||||
message(WARNING "Building for x86_64 arch is not officially supported.")
|
message(WARNING "Building for x86_64 arch is not officially supported.")
|
||||||
endif()
|
endif()
|
||||||
elseif(${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm64")
|
|
||||||
set(MLX_BUILD_METAL ON)
|
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
else()
|
else()
|
||||||
|
set(MLX_BUILD_METAL OFF)
|
||||||
message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
|
message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
@ -95,13 +95,29 @@ void array::detach() {
|
|||||||
array_desc_->primitive = nullptr;
|
array_desc_->primitive = nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
void array::eval() {
|
bool array::is_available() const {
|
||||||
// Ensure the array is ready to be read
|
if (status() == Status::available) {
|
||||||
if (status() == Status::scheduled) {
|
return true;
|
||||||
|
} else if (status() == Status::evaluated && event().is_signaled()) {
|
||||||
|
set_status(Status::available);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
void array::wait() {
|
||||||
|
if (!is_available()) {
|
||||||
event().wait();
|
event().wait();
|
||||||
set_status(Status::available);
|
set_status(Status::available);
|
||||||
} else if (status() == Status::unscheduled) {
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void array::eval() {
|
||||||
|
// Ensure the array is ready to be read
|
||||||
|
if (status() == Status::unscheduled) {
|
||||||
mlx::core::eval({*this});
|
mlx::core::eval({*this});
|
||||||
|
} else {
|
||||||
|
wait();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
30
mlx/array.h
30
mlx/array.h
@ -344,11 +344,33 @@ class array {
|
|||||||
return static_cast<T*>(array_desc_->data_ptr);
|
return static_cast<T*>(array_desc_->data_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
enum Status { unscheduled, scheduled, available };
|
enum Status {
|
||||||
|
// The ouptut of a computation which has not been scheduled.
|
||||||
|
// For example, the status of `x` in `auto x = a + b`.
|
||||||
|
unscheduled,
|
||||||
|
|
||||||
bool is_available() const {
|
// The ouptut of a computation which has been scheduled but `eval_*` has
|
||||||
return status() == Status::available;
|
// not yet been called on the array's primitive. A possible
|
||||||
}
|
// status of `x` in `auto x = a + b; eval(x);`
|
||||||
|
scheduled,
|
||||||
|
|
||||||
|
// The array's `eval_*` function has been run, but the computation is not
|
||||||
|
// necessarily complete. The array will have memory allocated and if it is
|
||||||
|
// not a tracer then it will be detached from the graph.
|
||||||
|
evaluated,
|
||||||
|
|
||||||
|
// If the array is the output of a computation then the computation
|
||||||
|
// is complete. Constant arrays are always available (e.g. `array({1, 2,
|
||||||
|
// 3})`)
|
||||||
|
available
|
||||||
|
};
|
||||||
|
|
||||||
|
// Check if the array is safe to read.
|
||||||
|
bool is_available() const;
|
||||||
|
|
||||||
|
// Wait on the array to be available. After this `is_available` returns
|
||||||
|
// `true`.
|
||||||
|
void wait();
|
||||||
|
|
||||||
Status status() const {
|
Status status() const {
|
||||||
return array_desc_->status;
|
return array_desc_->status;
|
||||||
|
@ -27,4 +27,9 @@ void Event::signal() {
|
|||||||
static_cast<MTL::SharedEvent*>(raw_event().get())->setSignaledValue(value());
|
static_cast<MTL::SharedEvent*>(raw_event().get())->setSignaledValue(value());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool Event::is_signaled() const {
|
||||||
|
return static_cast<MTL::SharedEvent*>(raw_event().get())->signaledValue() >=
|
||||||
|
value();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -74,7 +74,7 @@ std::function<void()> make_task(array arr, bool signal) {
|
|||||||
arr.detach();
|
arr.detach();
|
||||||
}
|
}
|
||||||
for (auto& out : outputs) {
|
for (auto& out : outputs) {
|
||||||
out.set_status(array::Status::available);
|
out.set_status(array::Status::evaluated);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (signal || d.get_command_buffer_ops(s.index) >= MAX_OPS_PER_BUFFER) {
|
if (signal || d.get_command_buffer_ops(s.index) >= MAX_OPS_PER_BUFFER) {
|
||||||
|
@ -36,4 +36,11 @@ void Event::signal() {
|
|||||||
ec->cv.notify_all();
|
ec->cv.notify_all();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool Event::is_signaled() const {
|
||||||
|
auto ec = static_cast<EventCounter*>(raw_event().get());
|
||||||
|
{
|
||||||
|
std::lock_guard<std::mutex> lk(ec->mtx);
|
||||||
|
return (ec->value > value());
|
||||||
|
}
|
||||||
|
}
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -20,6 +20,9 @@ class Event {
|
|||||||
// Signal the event at its current value
|
// Signal the event at its current value
|
||||||
void signal();
|
void signal();
|
||||||
|
|
||||||
|
// Check if the event has been signaled at its current value
|
||||||
|
bool is_signaled() const;
|
||||||
|
|
||||||
// Check if the event is valid
|
// Check if the event is valid
|
||||||
bool valid() const {
|
bool valid() const {
|
||||||
return event_ != nullptr;
|
return event_ != nullptr;
|
||||||
|
@ -79,7 +79,7 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!in.is_available()) {
|
if (in.status() == array::Status::unscheduled) {
|
||||||
if (async && in.is_tracer()) {
|
if (async && in.is_tracer()) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[async_eval] Not allowed inside a graph transformation.");
|
"[async_eval] Not allowed inside a graph transformation.");
|
||||||
@ -115,7 +115,8 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// All inputs are done being processed, process this array
|
// All inputs are done being processed, process this array
|
||||||
if (a.is_available() && !a.is_tracer() && a.has_primitive()) {
|
if ((a.status() != array::Status::unscheduled) && !a.is_tracer() &&
|
||||||
|
a.has_primitive()) {
|
||||||
// If the array is evaluated and is no longer a tracer, detach it
|
// If the array is evaluated and is no longer a tracer, detach it
|
||||||
a.detach();
|
a.detach();
|
||||||
} else if (a.status() == array::Status::unscheduled) {
|
} else if (a.status() == array::Status::unscheduled) {
|
||||||
@ -208,14 +209,12 @@ void eval(std::vector<array> outputs) {
|
|||||||
return x.status() == array::Status::unscheduled;
|
return x.status() == array::Status::unscheduled;
|
||||||
})) {
|
})) {
|
||||||
for (auto& x : outputs) {
|
for (auto& x : outputs) {
|
||||||
if (!x.is_available()) {
|
x.wait();
|
||||||
x.event().wait();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
eval_impl(std::move(outputs), false).event().wait();
|
eval_impl(std::move(outputs), false).wait();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::pair<std::vector<array>, std::vector<array>> vjp(
|
std::pair<std::vector<array>, std::vector<array>> vjp(
|
||||||
|
@ -51,11 +51,13 @@ class TestEval(mlx_tests.MLXTestCase):
|
|||||||
self.assertTrue(mx.array_equal(z, mx.array([4, 8, 12])))
|
self.assertTrue(mx.array_equal(z, mx.array([4, 8, 12])))
|
||||||
|
|
||||||
def test_async_eval_twice(self):
|
def test_async_eval_twice(self):
|
||||||
|
for _ in range(1000):
|
||||||
x = mx.array(1) + mx.array(1) + mx.array(1)
|
x = mx.array(1) + mx.array(1) + mx.array(1)
|
||||||
mx.async_eval(x)
|
mx.async_eval(x)
|
||||||
y = x + 1
|
y = x + 1
|
||||||
mx.async_eval(y)
|
mx.async_eval(y)
|
||||||
self.assertEqual(x.item(), 3)
|
self.assertEqual(x.item(), 3)
|
||||||
|
self.assertEqual(y.item(), 4)
|
||||||
|
|
||||||
def test_async_eval_in_trace(self):
|
def test_async_eval_in_trace(self):
|
||||||
def fun(x):
|
def fun(x):
|
||||||
|
Loading…
Reference in New Issue
Block a user