Shared events for synchronization + async eval (#998)

* more async eval

* fix rebase

* try correct async eval

* fix async

* more tests for async eval

* use shared events for synchronization

* comment + cleanup

* with autorelease pool

* fix no metal build

* fix compile

* fix patch

* don't eval if async eval'ed

* don't use is_evaled

* comments

* more multi stream tests

* try and cleanup use of is_evaled

* use a status flag
This commit is contained in:
Awni Hannun
2024-04-17 06:16:02 -07:00
committed by GitHub
parent b18468bf81
commit 8a0677d56d
28 changed files with 424 additions and 125 deletions

View File

@@ -93,7 +93,11 @@ void array::detach() {
}
void array::eval() {
if (!is_evaled()) {
// Ensure the array is ready to be read
if (status() == Status::scheduled) {
event().wait();
set_status(Status::available);
} else if (status() == Status::unscheduled) {
mlx::core::eval({*this});
}
}
@@ -176,7 +180,7 @@ void array::ArrayDesc::init() {
}
array::ArrayDesc::ArrayDesc(std::vector<int> shape, Dtype dtype)
: shape(std::move(shape)), dtype(dtype) {
: shape(std::move(shape)), dtype(dtype), status(Status::available) {
init();
}
@@ -187,6 +191,7 @@ array::ArrayDesc::ArrayDesc(
std::vector<array> inputs)
: shape(std::move(shape)),
dtype(dtype),
status(Status::unscheduled),
primitive(std::move(primitive)),
inputs(std::move(inputs)) {
init();

View File

@@ -9,6 +9,7 @@
#include "mlx/allocator.h"
#include "mlx/dtype.h"
#include "mlx/event.h"
namespace mlx::core {
@@ -315,9 +316,27 @@ class array {
return static_cast<T*>(array_desc_->data_ptr);
};
// Check if the array has been evaluated
bool is_evaled() const {
return array_desc_->data != nullptr;
enum Status { unscheduled, scheduled, available };
bool is_available() const {
return status() == Status::available;
}
const Status status() const {
return array_desc_->status;
}
void set_status(Status s) const {
array_desc_->status = s;
}
// Get the array's shared event
Event& event() const {
return array_desc_->event;
}
// Attach an event to a not yet evaluated array
void attach_event(Event e) const {
array_desc_->event = std::move(e);
}
// Mark the array as a tracer array (true) or not.
@@ -370,6 +389,11 @@ class array {
Dtype dtype;
std::shared_ptr<Primitive> primitive;
Status status;
// An event on the array used for synchronization
Event event;
// Indicates an array is being used in a graph transform
// and should not be detached from the graph
bool is_tracer{false};
@@ -470,10 +494,11 @@ T array::item() const {
if (size() != 1) {
throw std::invalid_argument("item can only be called on arrays of size 1.");
}
if (!is_evaled()) {
if (status() == Status::unscheduled) {
throw std::invalid_argument(
"item() const can only be called on evaled arrays");
}
const_cast<array*>(this)->eval();
return *data<T>();
}

View File

@@ -26,6 +26,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp

View File

@@ -544,11 +544,12 @@ Device& device(mlx::core::Device) {
return metal_device;
}
std::shared_ptr<void> new_scoped_memory_pool() {
std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool() {
auto dtor = [](void* ptr) {
static_cast<NS::AutoreleasePool*>(ptr)->release();
};
return std::shared_ptr<void>(NS::AutoreleasePool::alloc()->init(), dtor);
return std::unique_ptr<void, std::function<void(void*)>>(
NS::AutoreleasePool::alloc()->init(), dtor);
}
void new_stream(Stream stream) {

View File

@@ -0,0 +1,30 @@
// Copyright © 2024 Apple Inc.
#include "mlx/event.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/metal_impl.h"

namespace mlx::core {

// Back the event with an MTL::SharedEvent so work on different command
// queues / streams can synchronize on it.
Event::Event(const Stream& stream) : stream_(stream) {
  auto dtor = [](void* ptr) {
    // The deleter may run on a thread with no autorelease pool, so make a
    // scoped one before releasing the Metal object
    auto p = metal::new_scoped_memory_pool();
    static_cast<MTL::SharedEvent*>(ptr)->release();
  };
  // Pool covering the newSharedEvent() allocation below
  auto p = metal::new_scoped_memory_pool();
  event_ = std::shared_ptr<void>(
      metal::device(stream.device).mtl_device()->newSharedEvent(), dtor);
}

void Event::wait() {
  // Block until the shared event reaches this event's value; -1 is
  // presumably "effectively no timeout" — TODO confirm against the
  // MTLSharedEvent docs. waitUntilSignaledValue returns false on timeout.
  if (!static_cast<MTL::SharedEvent*>(raw_event().get())
          ->waitUntilSignaledValue(value(), -1)) {
    throw std::runtime_error("[Event::wait] Timed out");
  }
}

void Event::signal() {
  // Advance the shared event's signaled value, releasing waiters at or
  // below this event's value
  static_cast<MTL::SharedEvent*>(raw_event().get())->setSignaledValue(value());
}

} // namespace mlx::core

View File

@@ -55,17 +55,20 @@ inline void check_error(MTL::CommandBuffer* cbuf) {
}
}
std::function<void()> make_task(
array& arr,
std::vector<std::shared_future<void>> deps,
std::shared_ptr<std::promise<void>> p) {
auto task = [arr, deps = std::move(deps), p = std::move(p)]() mutable {
std::function<void()> make_task(array arr, bool signal) {
auto task = [arr = std::move(arr), signal]() mutable {
auto pool = new_scoped_memory_pool();
for (auto& d : deps) {
d.wait();
}
auto s = arr.primitive().stream();
auto command_buffer = increment_command_buffer(s);
for (auto& input : arr.inputs()) {
if (input.event().valid() &&
input.event().stream() != arr.primitive().stream()) {
// TODO, consider committing the buffer and encoding a wait in the new
// buffer rather than on the task thread
input.event().wait();
}
}
auto outputs = arr.outputs();
{
// If the array is a tracer hold a reference
@@ -88,13 +91,16 @@ std::function<void()> make_task(
if (!arr.is_tracer()) {
arr.detach();
}
if (p) {
if (signal) {
metal::device(s.device).end_encoding(s.index);
command_buffer->encodeSignalEvent(
static_cast<MTL::Event*>(arr.event().raw_event().get()),
arr.event().value());
scheduler::notify_new_task(s);
command_buffer->addCompletedHandler(
[s, buffers = std::move(buffers), p = std::move(p)](
[s, buffers = std::move(buffers), event = arr.event()](
MTL::CommandBuffer* cbuf) {
p->set_value();
scheduler::notify_task_completion(s);
check_error(cbuf);
});

View File

@@ -2,9 +2,7 @@
#pragma once
#include <future>
#include <memory>
#include <vector>
#include "mlx/array.h"
#include "mlx/stream.h"
@@ -12,11 +10,9 @@
namespace mlx::core::metal {
void new_stream(Stream stream);
std::shared_ptr<void> new_scoped_memory_pool();
std::function<void()> make_task(
array& arr,
std::vector<std::shared_future<void>> deps,
std::shared_ptr<std::promise<void>> p);
std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool();
std::function<void()> make_task(array arr, bool signal);
} // namespace mlx::core::metal

View File

@@ -2,6 +2,7 @@ target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
)

View File

@@ -0,0 +1,39 @@
// Copyright © 2024 Apple Inc.
#include "mlx/event.h"

#include <condition_variable>
#include <mutex>

namespace mlx::core {

// CPU backing for an event: a monotonic counter guarded by a mutex.
// Waiters block on the condition variable until the counter reaches the
// value they are waiting for.
struct EventCounter {
  uint64_t value{0};
  std::mutex mtx;
  std::condition_variable cv;
};

Event::Event(const Stream& stream) : stream_(stream) {
  event_ = std::shared_ptr<void>(
      new EventCounter{},
      [](void* ptr) { delete static_cast<EventCounter*>(ptr); });
}

void Event::wait() {
  auto counter = static_cast<EventCounter*>(raw_event().get());
  std::unique_lock<std::mutex> lock(counter->mtx);
  // Block until the counter has reached (or passed) this event's value;
  // if it already has, the loop body never runs and we return immediately
  while (counter->value < value()) {
    counter->cv.wait(lock);
  }
}

void Event::signal() {
  auto counter = static_cast<EventCounter*>(raw_event().get());
  {
    // Update under the lock so a concurrent waiter cannot miss the change
    std::lock_guard<std::mutex> guard(counter->mtx);
    counter->value = value();
  }
  counter->cv.notify_all();
}

} // namespace mlx::core

View File

@@ -12,14 +12,12 @@ bool is_available() {
}
void new_stream(Stream) {}
std::shared_ptr<void> new_scoped_memory_pool() {
std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool() {
return nullptr;
}
std::function<void()> make_task(
array& arr,
std::vector<std::shared_future<void>> deps,
std::shared_ptr<std::promise<void>> p) {
std::function<void()> make_task(array arr, bool signal) {
throw std::runtime_error(
"[metal::make_task] Cannot make GPU task without metal backend");
}

View File

@@ -352,7 +352,8 @@ void compile_simplify(
// Helpers to identify identical scalars
std::map<std::pair<uint64_t, Dtype::Val>, array> scalars;
auto is_scalar = [](const array& a) {
return a.is_evaled() && a.ndim() == 0;
// Condition for when it's safe to read an array
return a.is_available() && a.ndim() == 0;
};
auto get_scalar_rep = [](const array& a) {
uint64_t v = 0;

56
mlx/event.h Normal file
View File

@@ -0,0 +1,56 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include <memory>
#include <stdexcept>

#include "mlx/stream.h"

namespace mlx::core {

// A shared synchronization primitive attached to arrays so consumers can
// wait for producers scheduled on other streams. The backing object
// (event_) is backend-specific (MTL::SharedEvent on Metal, a mutex/condvar
// counter on CPU).
class Event {
 public:
  Event() {}

  // Construct a valid event associated with the given stream
  // (fixed typo: parameter was previously spelled "steam")
  Event(const Stream& stream);

  // Wait for the event to be signaled at its current value
  void wait();

  // Signal the event at its current value
  void signal();

  // Check if the event is valid (default-constructed events are not)
  bool valid() {
    return event_ != nullptr;
  }

  // The value waiters wait for and signalers signal
  uint64_t value() {
    return value_;
  }

  void set_value(uint64_t v) {
    value_ = v;
  }

  // The stream this event was created for; throws on an invalid event
  const Stream& stream() {
    if (!valid()) {
      throw std::runtime_error(
          "[Event::stream] Cannot access stream on invalid event.");
    }
    return stream_;
  }

  // Backend-specific underlying event object (type-erased)
  const std::shared_ptr<void>& raw_event() {
    return event_;
  }

 private:
  // Default constructed stream should never be used
  // since the event is not yet valid
  Stream stream_{0, Device::cpu};
  std::shared_ptr<void> event_{nullptr};
  uint64_t value_{0};
};

} // namespace mlx::core

View File

@@ -36,29 +36,32 @@ class Synchronizer : public Primitive {
// are currently under a function transformation.
int detail::InTracing::tracing_counter{0};
std::shared_future<void> async_eval(std::vector<array> outputs) {
static std::shared_future<void> global_synchronizer;
// Catch up with previous async eval if needed
if (global_synchronizer.valid()) {
global_synchronizer.wait();
}
array eval_impl(std::vector<array> outputs, bool async) {
std::queue<array> tape;
std::unordered_set<std::uintptr_t> cache;
std::unordered_map<std::uintptr_t, std::shared_future<void>> deps;
// stream events to use for synchronization
std::unordered_map<uint32_t, Event> events;
// Make an effort to choose a good output stream
Stream stream = default_stream(default_device());
for (auto& o : outputs) {
if (!o.is_evaled() && o.has_primitive()) {
if (o.status() == array::Status::unscheduled && o.has_primitive()) {
stream = o.primitive().stream();
break;
}
}
std::unordered_set<uintptr_t> needs_signal;
auto synchronizer = array(
{}, bool_, std::make_shared<Synchronizer>(stream), std::move(outputs));
needs_signal.insert(synchronizer.id());
// Make an event for the synchronizer stream
events.emplace(stream.index, Event{stream});
{
std::unordered_set<std::uintptr_t> cache;
std::stack<std::pair<std::reference_wrapper<array>, int>> dfs;
dfs.emplace(synchronizer, 0);
while (!dfs.empty()) {
@@ -67,16 +70,23 @@ std::shared_future<void> async_eval(std::vector<array> outputs) {
if (idx < a.inputs().size()) {
// Add an input, and continue
auto& in = a.inputs()[idx++];
if (!in.is_evaled()) {
// Ignore arrays already scheduled
if (in.status() == array::Status::scheduled) {
continue;
}
if (!in.is_available()) {
if (async && in.is_tracer()) {
throw std::invalid_argument(
"[async_eval] Not allowed inside a graph transformation.");
}
if (!in.has_primitive()) {
throw std::invalid_argument(
"[eval] Attempting to eval an array without a primitive.");
}
// If the input is being computed on a different stream, we need to
// manage the dependency.
if (a.primitive().stream() != in.primitive().stream()) {
deps.insert({in.output(0).id(), std::shared_future<void>{}});
needs_signal.insert(in.id());
}
}
@@ -91,52 +101,54 @@ std::shared_future<void> async_eval(std::vector<array> outputs) {
}
// All inputs are done being processed, process this array
if (!a.is_evaled() || (!a.is_tracer() && a.has_primitive())) {
if (a.is_available() && !a.is_tracer() && a.has_primitive()) {
// If the array is evaluated and is no longer a tracer, detach it
a.detach();
} else if (a.status() == array::Status::unscheduled) {
tape.push(a);
// Lookup corresponding event and increment counter
auto& stream = a.primitive().stream();
auto e = events.find(stream.index);
if (e == events.end()) {
e = events.emplace(stream.index, Event{stream}).first;
}
e->second.set_value(e->second.value() + 1);
a.attach_event(e->second);
for (auto& s : a.siblings()) {
s.attach_event(e->second);
}
}
dfs.pop();
}
}
deps.insert({synchronizer.id(), std::shared_future<void>{}});
std::vector<std::shared_ptr<std::promise<void>>> ps;
while (!tape.empty()) {
auto arr = std::move(tape.front());
tape.pop();
if (arr.is_evaled()) {
if (!arr.is_tracer() && arr.has_primitive()) {
arr.detach();
}
continue;
// Set the status of the array and siblings.
auto status = async ? array::Status::scheduled : array::Status::available;
arr.set_status(status);
for (auto& s : arr.siblings()) {
s.set_status(status);
}
auto stream = arr.primitive().stream();
std::vector<std::shared_future<void>> arr_deps;
for (auto& in : arr.inputs()) {
if (auto it = deps.find(in.output(0).id()); it != deps.end()) {
arr_deps.push_back(it->second);
}
}
std::shared_ptr<std::promise<void>> p;
if (auto it = deps.find(arr.output(0).id()); it != deps.end()) {
p = std::make_shared<std::promise<void>>();
ps.push_back(p);
it->second = p->get_future().share();
}
bool signal = needs_signal.find(arr.id()) != needs_signal.end();
if (arr.primitive().device() == Device::gpu) {
if (!metal::is_available()) {
throw std::runtime_error("Metal GPU is not available.");
}
scheduler::enqueue(
stream, metal::make_task(arr, std::move(arr_deps), std::move(p)));
scheduler::enqueue(stream, metal::make_task(std::move(arr), signal));
} else {
auto task = [arr,
stream,
deps = std::move(arr_deps),
p = std::move(p)]() mutable {
for (auto& d : deps) {
d.wait();
auto task = [arr = std::move(arr), stream, signal]() mutable {
for (auto& input : arr.inputs()) {
if (input.event().valid() &&
input.event().stream() != arr.primitive().stream()) {
input.event().wait();
}
}
scheduler::notify_new_task(stream);
auto outputs = arr.outputs();
@@ -144,20 +156,24 @@ std::shared_future<void> async_eval(std::vector<array> outputs) {
if (!arr.is_tracer()) {
arr.detach();
}
if (p) {
p->set_value();
if (signal) {
arr.event().signal();
}
scheduler::notify_task_completion(stream);
};
scheduler::enqueue(stream, std::move(task));
}
}
global_synchronizer = std::move(deps[synchronizer.id()]);
return global_synchronizer;
return synchronizer;
}
void async_eval(std::vector<array> outputs) {
eval_impl(std::move(outputs), true);
}
void eval(std::vector<array> outputs) {
async_eval(std::move(outputs)).wait();
eval_impl(std::move(outputs), false).event().wait();
}
std::pair<std::vector<array>, std::vector<array>> vjp(

View File

@@ -2,12 +2,11 @@
#pragma once
#include <future>
#include "mlx/array.h"
namespace mlx::core {
std::shared_future<void> async_eval(std::vector<array> outputs);
void async_eval(std::vector<array> outputs);
void eval(std::vector<array> outputs);

View File

@@ -264,9 +264,7 @@ std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k) {
}
std::ostream& operator<<(std::ostream& os, array a) {
if (!a.is_evaled()) {
a.eval();
}
a.eval();
switch (a.dtype()) {
case bool_:
print_array<bool>(os, a);