[Metal] Release metal events (#2412)

* release metal events

* fix

* fix
This commit is contained in:
Awni Hannun 2025-07-23 19:53:42 -07:00 committed by GitHub
parent d1f4d291e8
commit 4e504039f5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 34 additions and 17 deletions

View File

@ -14,6 +14,10 @@ Event::Event(Stream stream) : stream_(stream) {
auto p = metal::new_scoped_memory_pool();
event_ = std::shared_ptr<void>(
metal::device(Device::gpu).mtl_device()->newSharedEvent(), dtor);
if (event_ == nullptr) {
throw std::runtime_error(
"[Event::Event] Failed to create Metal shared event.");
}
}
void Event::wait() {

View File

@ -72,7 +72,12 @@ array eval_impl(std::vector<array> outputs, bool async) {
// Stream events for synchronization after eval
std::unordered_map<uint32_t, Event> events;
events.emplace(stream.index, Event{stream});
{
auto e = Event{stream};
e.set_value(1);
synchronizer.attach_event(e);
events.emplace(stream.index, std::move(e));
}
{
// Record the degree of each input
@ -184,21 +189,26 @@ array eval_impl(std::vector<array> outputs, bool async) {
}
}
std::unordered_set<int> open_streams;
while (!tape.empty()) {
auto arr = std::move(tape.back());
tape.pop_back();
auto stream = arr.primitive().stream();
open_streams.insert(stream.index);
// Lookup corresponding event
auto e = events.find(stream.index);
if (e == events.end()) {
e = events.emplace(stream.index, Event{stream}).first;
}
e->second.set_value(1);
arr.attach_event(e->second);
for (auto& s : arr.siblings()) {
s.attach_event(e->second);
if (async) {
// Lookup corresponding event
auto e = events.find(stream.index);
if (e == events.end()) {
e = events.emplace(stream.index, Event{stream}).first;
}
e->second.set_value(1);
arr.attach_event(e->second);
for (auto& s : arr.siblings()) {
s.attach_event(e->second);
}
}
for (auto& in : arr.inputs()) {
@ -227,9 +237,10 @@ array eval_impl(std::vector<array> outputs, bool async) {
(get_active_memory() > get_memory_limit() &&
scheduler::n_active_tasks() > 0)) {
// Commit any open streams
for (auto& [_, e] : events) {
if (e.stream().device == Device::gpu) {
gpu::finalize(e.stream());
for (auto i : open_streams) {
auto s = get_stream(i);
if (s.device == Device::gpu) {
gpu::finalize(s);
}
}
scheduler::wait_for_one();
@ -263,9 +274,11 @@ array eval_impl(std::vector<array> outputs, bool async) {
}
// Signal the event in its stream
for (auto& [_, e] : events) {
auto s = e.stream();
e.signal(s);
for (auto i : open_streams) {
auto s = get_stream(i);
if (auto e = events.find(i); e != events.end()) {
e->second.signal(s);
}
if (s.device == Device::gpu) {
gpu::finalize(s);
}
@ -302,7 +315,7 @@ void eval(std::vector<array> outputs) {
return;
}
eval_impl(std::move(outputs), false).event().wait();
eval_impl(std::move(outputs), false).wait();
}
std::pair<std::vector<array>, std::vector<array>> vjp(