mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-28 13:11:26 +08:00
[Metal] Release metal events (#2412)
* release metal events * fix * fix
This commit is contained in:
parent
d1f4d291e8
commit
4e504039f5
@ -14,6 +14,10 @@ Event::Event(Stream stream) : stream_(stream) {
|
||||
auto p = metal::new_scoped_memory_pool();
|
||||
event_ = std::shared_ptr<void>(
|
||||
metal::device(Device::gpu).mtl_device()->newSharedEvent(), dtor);
|
||||
if (event_ == nullptr) {
|
||||
throw std::runtime_error(
|
||||
"[Event::Event] Failed to create Metal shared event.");
|
||||
}
|
||||
}
|
||||
|
||||
void Event::wait() {
|
||||
|
@ -72,7 +72,12 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
||||
|
||||
// Stream events for synchronization after eval
|
||||
std::unordered_map<uint32_t, Event> events;
|
||||
events.emplace(stream.index, Event{stream});
|
||||
{
|
||||
auto e = Event{stream};
|
||||
e.set_value(1);
|
||||
synchronizer.attach_event(e);
|
||||
events.emplace(stream.index, std::move(e));
|
||||
}
|
||||
|
||||
{
|
||||
// Record the degree of each input
|
||||
@ -184,21 +189,26 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
||||
}
|
||||
}
|
||||
|
||||
std::unordered_set<int> open_streams;
|
||||
|
||||
while (!tape.empty()) {
|
||||
auto arr = std::move(tape.back());
|
||||
tape.pop_back();
|
||||
|
||||
auto stream = arr.primitive().stream();
|
||||
open_streams.insert(stream.index);
|
||||
|
||||
// Lookup corresponding event
|
||||
auto e = events.find(stream.index);
|
||||
if (e == events.end()) {
|
||||
e = events.emplace(stream.index, Event{stream}).first;
|
||||
}
|
||||
e->second.set_value(1);
|
||||
arr.attach_event(e->second);
|
||||
for (auto& s : arr.siblings()) {
|
||||
s.attach_event(e->second);
|
||||
if (async) {
|
||||
// Lookup corresponding event
|
||||
auto e = events.find(stream.index);
|
||||
if (e == events.end()) {
|
||||
e = events.emplace(stream.index, Event{stream}).first;
|
||||
}
|
||||
e->second.set_value(1);
|
||||
arr.attach_event(e->second);
|
||||
for (auto& s : arr.siblings()) {
|
||||
s.attach_event(e->second);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& in : arr.inputs()) {
|
||||
@ -227,9 +237,10 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
||||
(get_active_memory() > get_memory_limit() &&
|
||||
scheduler::n_active_tasks() > 0)) {
|
||||
// Commit any open streams
|
||||
for (auto& [_, e] : events) {
|
||||
if (e.stream().device == Device::gpu) {
|
||||
gpu::finalize(e.stream());
|
||||
for (auto i : open_streams) {
|
||||
auto s = get_stream(i);
|
||||
if (s.device == Device::gpu) {
|
||||
gpu::finalize(s);
|
||||
}
|
||||
}
|
||||
scheduler::wait_for_one();
|
||||
@ -263,9 +274,11 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
||||
}
|
||||
|
||||
// Signal the event in its stream
|
||||
for (auto& [_, e] : events) {
|
||||
auto s = e.stream();
|
||||
e.signal(s);
|
||||
for (auto i : open_streams) {
|
||||
auto s = get_stream(i);
|
||||
if (auto e = events.find(i); e != events.end()) {
|
||||
e->second.signal(s);
|
||||
}
|
||||
if (s.device == Device::gpu) {
|
||||
gpu::finalize(s);
|
||||
}
|
||||
@ -302,7 +315,7 @@ void eval(std::vector<array> outputs) {
|
||||
return;
|
||||
}
|
||||
|
||||
eval_impl(std::move(outputs), false).event().wait();
|
||||
eval_impl(std::move(outputs), false).wait();
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<array>> vjp(
|
||||
|
Loading…
Reference in New Issue
Block a user