mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-29 13:55:29 +08:00
[Metal] Release metal events (#2412)
* release metal events * fix * fix
This commit is contained in:
parent
d1f4d291e8
commit
4e504039f5
@ -14,6 +14,10 @@ Event::Event(Stream stream) : stream_(stream) {
|
|||||||
auto p = metal::new_scoped_memory_pool();
|
auto p = metal::new_scoped_memory_pool();
|
||||||
event_ = std::shared_ptr<void>(
|
event_ = std::shared_ptr<void>(
|
||||||
metal::device(Device::gpu).mtl_device()->newSharedEvent(), dtor);
|
metal::device(Device::gpu).mtl_device()->newSharedEvent(), dtor);
|
||||||
|
if (event_ == nullptr) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[Event::Event] Failed to create Metal shared event.");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Event::wait() {
|
void Event::wait() {
|
||||||
|
@ -72,7 +72,12 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
|||||||
|
|
||||||
// Stream events for synchronization after eval
|
// Stream events for synchronization after eval
|
||||||
std::unordered_map<uint32_t, Event> events;
|
std::unordered_map<uint32_t, Event> events;
|
||||||
events.emplace(stream.index, Event{stream});
|
{
|
||||||
|
auto e = Event{stream};
|
||||||
|
e.set_value(1);
|
||||||
|
synchronizer.attach_event(e);
|
||||||
|
events.emplace(stream.index, std::move(e));
|
||||||
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
// Record the degree of each input
|
// Record the degree of each input
|
||||||
@ -184,21 +189,26 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::unordered_set<int> open_streams;
|
||||||
|
|
||||||
while (!tape.empty()) {
|
while (!tape.empty()) {
|
||||||
auto arr = std::move(tape.back());
|
auto arr = std::move(tape.back());
|
||||||
tape.pop_back();
|
tape.pop_back();
|
||||||
|
|
||||||
auto stream = arr.primitive().stream();
|
auto stream = arr.primitive().stream();
|
||||||
|
open_streams.insert(stream.index);
|
||||||
|
|
||||||
// Lookup corresponding event
|
if (async) {
|
||||||
auto e = events.find(stream.index);
|
// Lookup corresponding event
|
||||||
if (e == events.end()) {
|
auto e = events.find(stream.index);
|
||||||
e = events.emplace(stream.index, Event{stream}).first;
|
if (e == events.end()) {
|
||||||
}
|
e = events.emplace(stream.index, Event{stream}).first;
|
||||||
e->second.set_value(1);
|
}
|
||||||
arr.attach_event(e->second);
|
e->second.set_value(1);
|
||||||
for (auto& s : arr.siblings()) {
|
arr.attach_event(e->second);
|
||||||
s.attach_event(e->second);
|
for (auto& s : arr.siblings()) {
|
||||||
|
s.attach_event(e->second);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto& in : arr.inputs()) {
|
for (auto& in : arr.inputs()) {
|
||||||
@ -227,9 +237,10 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
|||||||
(get_active_memory() > get_memory_limit() &&
|
(get_active_memory() > get_memory_limit() &&
|
||||||
scheduler::n_active_tasks() > 0)) {
|
scheduler::n_active_tasks() > 0)) {
|
||||||
// Commit any open streams
|
// Commit any open streams
|
||||||
for (auto& [_, e] : events) {
|
for (auto i : open_streams) {
|
||||||
if (e.stream().device == Device::gpu) {
|
auto s = get_stream(i);
|
||||||
gpu::finalize(e.stream());
|
if (s.device == Device::gpu) {
|
||||||
|
gpu::finalize(s);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
scheduler::wait_for_one();
|
scheduler::wait_for_one();
|
||||||
@ -263,9 +274,11 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Signal the event in its stream
|
// Signal the event in its stream
|
||||||
for (auto& [_, e] : events) {
|
for (auto i : open_streams) {
|
||||||
auto s = e.stream();
|
auto s = get_stream(i);
|
||||||
e.signal(s);
|
if (auto e = events.find(i); e != events.end()) {
|
||||||
|
e->second.signal(s);
|
||||||
|
}
|
||||||
if (s.device == Device::gpu) {
|
if (s.device == Device::gpu) {
|
||||||
gpu::finalize(s);
|
gpu::finalize(s);
|
||||||
}
|
}
|
||||||
@ -302,7 +315,7 @@ void eval(std::vector<array> outputs) {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
eval_impl(std::move(outputs), false).event().wait();
|
eval_impl(std::move(outputs), false).wait();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::pair<std::vector<array>, std::vector<array>> vjp(
|
std::pair<std::vector<array>, std::vector<array>> vjp(
|
||||||
|
Loading…
Reference in New Issue
Block a user