Mirror of https://github.com/ml-explore/mlx.git (synced 2025-10-31 07:58:14 +08:00)
			
		
		
		
	| @@ -40,7 +40,7 @@ int detail::InTracing::tracing_counter{0}; | ||||
| int detail::RetainGraph::tracing_counter{0}; | ||||
|  | ||||
| array eval_impl(std::vector<array> outputs, bool async) { | ||||
|   std::queue<array> tape; | ||||
|   std::vector<array> tape; | ||||
|  | ||||
|   // stream events to use for synchronization | ||||
|   std::unordered_map<uint32_t, Event> events; | ||||
| @@ -64,7 +64,9 @@ array eval_impl(std::vector<array> outputs, bool async) { | ||||
|   events.emplace(stream.index, Event{stream}); | ||||
|  | ||||
|   { | ||||
|     std::unordered_set<std::uintptr_t> cache; | ||||
|     // Record the degree of each input | ||||
|     std::unordered_map<std::uintptr_t, int> cache; | ||||
|  | ||||
|     std::stack<std::pair<std::reference_wrapper<array>, int>> dfs; | ||||
|     dfs.emplace(synchronizer, 0); | ||||
|     while (!dfs.empty()) { | ||||
| @@ -104,42 +106,75 @@ array eval_impl(std::vector<array> outputs, bool async) { | ||||
|           } | ||||
|         } | ||||
|  | ||||
|         if (cache.find(in.id()) == cache.end()) { | ||||
|         // All siblings have the same degree | ||||
|         auto cache_it = cache.find(in.id()); | ||||
|         if (cache_it == cache.end()) { | ||||
|           dfs.emplace(in, 0); | ||||
|           cache.insert(in.id()); | ||||
|           cache.insert({in.id(), 1}); | ||||
|           for (auto& s : in.siblings()) { | ||||
|             cache.insert(s.id()); | ||||
|             cache.insert({s.id(), 1}); | ||||
|           } | ||||
|         } else { | ||||
|           cache_it->second++; | ||||
|           for (auto& s : in.siblings()) { | ||||
|             cache[s.id()]++; | ||||
|           } | ||||
|         } | ||||
|         continue; | ||||
|       } | ||||
|  | ||||
|       // All inputs are done being processed, process this array | ||||
|       if ((a.status() != array::Status::unscheduled) && !a.is_tracer() && | ||||
|           a.has_primitive()) { | ||||
|         // If the array is evaluated and is no longer a tracer, detach it | ||||
|         a.detach(); | ||||
|       } else if (a.status() == array::Status::unscheduled) { | ||||
|         tape.push(a); | ||||
|         // Lookup corresponding event and increment counter | ||||
|         auto& stream = a.primitive().stream(); | ||||
|         auto e = events.find(stream.index); | ||||
|         if (e == events.end()) { | ||||
|           e = events.emplace(stream.index, Event{stream}).first; | ||||
|         } | ||||
|         e->second.set_value(e->second.value() + 1); | ||||
|         a.attach_event(e->second); | ||||
|         for (auto& s : a.siblings()) { | ||||
|           s.attach_event(e->second); | ||||
|         } | ||||
|       } | ||||
|       dfs.pop(); | ||||
|     } | ||||
|  | ||||
|     // Build the tape in BFS order | ||||
|     tape.push_back(synchronizer); | ||||
|     for (int i = 0; !cache.empty() && i < tape.size(); ++i) { | ||||
|       auto& a = tape[i]; | ||||
|       for (auto& in : a.inputs()) { | ||||
|         if (in.status() != array::Status::unscheduled) { | ||||
|           continue; | ||||
|         } | ||||
|         auto it = cache.find(in.id()); | ||||
|         it->second -= 1; | ||||
|  | ||||
|         if (it->second != 0) { | ||||
|           for (auto& s : in.siblings()) { | ||||
|             cache[s.id()] -= 1; | ||||
|           } | ||||
|           continue; | ||||
|         } | ||||
|  | ||||
|         // Remove input and siblings from cache | ||||
|         cache.erase(it); | ||||
|         for (auto& s : in.siblings()) { | ||||
|           cache.erase(s.id()); | ||||
|         } | ||||
|  | ||||
|         tape.push_back(in); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   while (!tape.empty()) { | ||||
|     auto arr = std::move(tape.front()); | ||||
|     tape.pop(); | ||||
|     auto arr = std::move(tape.back()); | ||||
|     tape.pop_back(); | ||||
|  | ||||
|     auto stream = arr.primitive().stream(); | ||||
|  | ||||
|     // Lookup corresponding event and increment counter | ||||
|     auto e = events.find(stream.index); | ||||
|     if (e == events.end()) { | ||||
|       e = events.emplace(stream.index, Event{stream}).first; | ||||
|     } | ||||
|     e->second.set_value(e->second.value() + 1); | ||||
|     arr.attach_event(e->second); | ||||
|     for (auto& s : arr.siblings()) { | ||||
|       s.attach_event(e->second); | ||||
|     } | ||||
|  | ||||
|     // Set the status of the array and siblings. | ||||
|     arr.set_status(array::Status::scheduled); | ||||
| @@ -147,7 +182,6 @@ array eval_impl(std::vector<array> outputs, bool async) { | ||||
|       s.set_status(array::Status::scheduled); | ||||
|     } | ||||
|  | ||||
|     auto stream = arr.primitive().stream(); | ||||
|     std::vector<std::shared_future<void>> arr_deps; | ||||
|     bool signal = needs_signal.find(arr.id()) != needs_signal.end(); | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	Awni Hannun