mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	[Metal] Release metal events (#2412)
* release metal events * fix * fix
This commit is contained in:
		| @@ -14,6 +14,10 @@ Event::Event(Stream stream) : stream_(stream) { | ||||
|   auto p = metal::new_scoped_memory_pool(); | ||||
|   event_ = std::shared_ptr<void>( | ||||
|       metal::device(Device::gpu).mtl_device()->newSharedEvent(), dtor); | ||||
|   if (event_ == nullptr) { | ||||
|     throw std::runtime_error( | ||||
|         "[Event::Event] Failed to create Metal shared event."); | ||||
|   } | ||||
| } | ||||
|  | ||||
| void Event::wait() { | ||||
|   | ||||
| @@ -72,7 +72,12 @@ array eval_impl(std::vector<array> outputs, bool async) { | ||||
|  | ||||
|   // Stream events for synchronization after eval | ||||
|   std::unordered_map<uint32_t, Event> events; | ||||
|   events.emplace(stream.index, Event{stream}); | ||||
|   { | ||||
|     auto e = Event{stream}; | ||||
|     e.set_value(1); | ||||
|     synchronizer.attach_event(e); | ||||
|     events.emplace(stream.index, std::move(e)); | ||||
|   } | ||||
|  | ||||
|   { | ||||
|     // Record the degree of each input | ||||
| @@ -184,21 +189,26 @@ array eval_impl(std::vector<array> outputs, bool async) { | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   std::unordered_set<int> open_streams; | ||||
|  | ||||
|   while (!tape.empty()) { | ||||
|     auto arr = std::move(tape.back()); | ||||
|     tape.pop_back(); | ||||
|  | ||||
|     auto stream = arr.primitive().stream(); | ||||
|     open_streams.insert(stream.index); | ||||
|  | ||||
|     // Lookup corresponding event | ||||
|     auto e = events.find(stream.index); | ||||
|     if (e == events.end()) { | ||||
|       e = events.emplace(stream.index, Event{stream}).first; | ||||
|     } | ||||
|     e->second.set_value(1); | ||||
|     arr.attach_event(e->second); | ||||
|     for (auto& s : arr.siblings()) { | ||||
|       s.attach_event(e->second); | ||||
|     if (async) { | ||||
|       // Lookup corresponding event | ||||
|       auto e = events.find(stream.index); | ||||
|       if (e == events.end()) { | ||||
|         e = events.emplace(stream.index, Event{stream}).first; | ||||
|       } | ||||
|       e->second.set_value(1); | ||||
|       arr.attach_event(e->second); | ||||
|       for (auto& s : arr.siblings()) { | ||||
|         s.attach_event(e->second); | ||||
|       } | ||||
|     } | ||||
|  | ||||
|     for (auto& in : arr.inputs()) { | ||||
| @@ -227,9 +237,10 @@ array eval_impl(std::vector<array> outputs, bool async) { | ||||
|         (get_active_memory() > get_memory_limit() && | ||||
|          scheduler::n_active_tasks() > 0)) { | ||||
|       // Commit any open streams | ||||
|       for (auto& [_, e] : events) { | ||||
|         if (e.stream().device == Device::gpu) { | ||||
|           gpu::finalize(e.stream()); | ||||
|       for (auto i : open_streams) { | ||||
|         auto s = get_stream(i); | ||||
|         if (s.device == Device::gpu) { | ||||
|           gpu::finalize(s); | ||||
|         } | ||||
|       } | ||||
|       scheduler::wait_for_one(); | ||||
| @@ -263,9 +274,11 @@ array eval_impl(std::vector<array> outputs, bool async) { | ||||
|   } | ||||
|  | ||||
|   // Signal the event in its stream | ||||
|   for (auto& [_, e] : events) { | ||||
|     auto s = e.stream(); | ||||
|     e.signal(s); | ||||
|   for (auto i : open_streams) { | ||||
|     auto s = get_stream(i); | ||||
|     if (auto e = events.find(i); e != events.end()) { | ||||
|       e->second.signal(s); | ||||
|     } | ||||
|     if (s.device == Device::gpu) { | ||||
|       gpu::finalize(s); | ||||
|     } | ||||
| @@ -302,7 +315,7 @@ void eval(std::vector<array> outputs) { | ||||
|     return; | ||||
|   } | ||||
|  | ||||
|   eval_impl(std::move(outputs), false).event().wait(); | ||||
|   eval_impl(std::move(outputs), false).wait(); | ||||
| } | ||||
|  | ||||
| std::pair<std::vector<array>, std::vector<array>> vjp( | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun