mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Shared events for synchronization + async eval (#998)
* more async eval * fix rebase * try correct async eval * fix async * more tests for async eval * use shared events for synchronization * comment + cleanup * with autorelease pool * fix no metal build * fix compile * fix patch * don't eval if asyn evale'd * don't use is_evaled * comments * more multi stream tests * try and cleanup use of is_evaled * use a status flag
This commit is contained in:
33
mlx/array.h
33
mlx/array.h
@@ -9,6 +9,7 @@
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/dtype.h"
|
||||
#include "mlx/event.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
@@ -315,9 +316,27 @@ class array {
|
||||
return static_cast<T*>(array_desc_->data_ptr);
|
||||
};
|
||||
|
||||
// Check if the array has been evaluated
|
||||
bool is_evaled() const {
|
||||
return array_desc_->data != nullptr;
|
||||
enum Status { unscheduled, scheduled, available };
|
||||
|
||||
bool is_available() const {
|
||||
return status() == Status::available;
|
||||
}
|
||||
const Status status() const {
|
||||
return array_desc_->status;
|
||||
}
|
||||
|
||||
void set_status(Status s) const {
|
||||
array_desc_->status = s;
|
||||
}
|
||||
|
||||
// Get the array's shared event
|
||||
Event& event() const {
|
||||
return array_desc_->event;
|
||||
}
|
||||
|
||||
// Attach an event to a not yet evaluated array
|
||||
void attach_event(Event e) const {
|
||||
array_desc_->event = std::move(e);
|
||||
}
|
||||
|
||||
// Mark the array as a tracer array (true) or not.
|
||||
@@ -370,6 +389,11 @@ class array {
|
||||
Dtype dtype;
|
||||
std::shared_ptr<Primitive> primitive;
|
||||
|
||||
Status status;
|
||||
|
||||
// An event on the array used for synchronization
|
||||
Event event;
|
||||
|
||||
// Indicates an array is being used in a graph transform
|
||||
// and should not be detached from the graph
|
||||
bool is_tracer{false};
|
||||
@@ -470,10 +494,11 @@ T array::item() const {
|
||||
if (size() != 1) {
|
||||
throw std::invalid_argument("item can only be called on arrays of size 1.");
|
||||
}
|
||||
if (!is_evaled()) {
|
||||
if (status() == Status::unscheduled) {
|
||||
throw std::invalid_argument(
|
||||
"item() const can only be called on evaled arrays");
|
||||
}
|
||||
const_cast<array*>(this)->eval();
|
||||
return *data<T>();
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user