mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	fix compile when compiling multiple lambdas with the same capture
This commit is contained in:
		| @@ -412,7 +412,7 @@ compile_trace( | |||||||
| // Traverses the graph to build a tape and a map of array ids to their parents | // Traverses the graph to build a tape and a map of array ids to their parents | ||||||
| std::pair<std::vector<array>, ParentsMap> compile_dfs( | std::pair<std::vector<array>, ParentsMap> compile_dfs( | ||||||
|     const std::vector<array>& inputs, |     const std::vector<array>& inputs, | ||||||
|     const std::vector<array>& outputs, |     std::vector<array>& outputs, | ||||||
|     const std::vector<array>& original_inputs) { |     const std::vector<array>& original_inputs) { | ||||||
|   std::function<void(const array&)> recurse; |   std::function<void(const array&)> recurse; | ||||||
|   std::vector<array> tape; |   std::vector<array> tape; | ||||||
| @@ -457,6 +457,71 @@ std::pair<std::vector<array>, ParentsMap> compile_dfs( | |||||||
|   for (auto& a : outputs) { |   for (auto& a : outputs) { | ||||||
|     recurse(a); |     recurse(a); | ||||||
|   } |   } | ||||||
|  |  | ||||||
|  |   // Deep copy the tape and parents map | ||||||
|  |   std::vector<array> new_tape; | ||||||
|  |   std::unordered_set<uintptr_t> io_set; | ||||||
|  |   std::unordered_map<uintptr_t, array> old_to_new; | ||||||
|  |   for (auto& o : outputs) { | ||||||
|  |     io_set.insert(o.id()); | ||||||
|  |   } | ||||||
|  |   for (auto& i : inputs) { | ||||||
|  |     io_set.insert(i.id()); | ||||||
|  |     old_to_new.insert({i.id(), i}); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   new_tape.reserve(tape.size()); | ||||||
|  |   for (auto& arr : tape) { | ||||||
|  |     if (!arr.has_primitive() || (io_set.find(arr.id()) != io_set.end())) { | ||||||
|  |       old_to_new.insert({arr.id(), arr}); | ||||||
|  |       new_tape.push_back(arr); | ||||||
|  |       continue; | ||||||
|  |     } | ||||||
|  |     std::vector<array> inputs; | ||||||
|  |     inputs.reserve(arr.inputs().size()); | ||||||
|  |     for (auto& i : arr.inputs()) { | ||||||
|  |       inputs.push_back(old_to_new.find(i.id())->second); | ||||||
|  |     } | ||||||
|  |     if (arr.siblings().size() > 0) { | ||||||
|  |       // use make_arrays | ||||||
|  |       std::vector<Dtype> types; | ||||||
|  |       std::vector<Shape> shapes; | ||||||
|  |       auto out = arr.outputs(); | ||||||
|  |       for (auto& o : out) { | ||||||
|  |         types.push_back(o.dtype()); | ||||||
|  |         shapes.push_back(o.shape()); | ||||||
|  |       } | ||||||
|  |       auto as = array::make_arrays( | ||||||
|  |           std::move(shapes), types, arr.primitive_ptr(), std::move(inputs)); | ||||||
|  |       for (int i = 0; i < out.size(); ++i) { | ||||||
|  |         old_to_new.insert({out[i].id(), as[i]}); | ||||||
|  |       } | ||||||
|  |       // TODO maybe need to preserve position of sibling that is in tape | ||||||
|  |       new_tape.push_back(as[0]); | ||||||
|  |     } else { | ||||||
|  |       auto a = array( | ||||||
|  |           arr.shape(), arr.dtype(), arr.primitive_ptr(), std::move(inputs)); | ||||||
|  |       old_to_new.insert({arr.id(), a}); | ||||||
|  |       new_tape.push_back(a); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |   for (auto& o : outputs) { | ||||||
|  |     for (auto& i : o.inputs()) { | ||||||
|  |       i = old_to_new.find(i.id())->second; | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |   tape = std::move(new_tape); | ||||||
|  |  | ||||||
|  |   std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>> | ||||||
|  |       new_parents_map; | ||||||
|  |   for (auto& [id, vec] : parents_map) { | ||||||
|  |     for (auto& [a, _] : vec) { | ||||||
|  |       a = old_to_new.find(a.id())->second; | ||||||
|  |     } | ||||||
|  |     new_parents_map[old_to_new.find(id)->second.id()] = std::move(vec); | ||||||
|  |   } | ||||||
|  |   parents_map = std::move(new_parents_map); | ||||||
|  |  | ||||||
|   return {tape, parents_map}; |   return {tape, parents_map}; | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -47,7 +47,7 @@ using ParentsMap = | |||||||
| // Traverses the graph to build a tape and a map of array ids to their parents | // Traverses the graph to build a tape and a map of array ids to their parents | ||||||
| std::pair<std::vector<array>, ParentsMap> compile_dfs( | std::pair<std::vector<array>, ParentsMap> compile_dfs( | ||||||
|     const std::vector<array>& inputs, |     const std::vector<array>& inputs, | ||||||
|     const std::vector<array>& outputs, |     std::vector<array>& outputs, | ||||||
|     const std::vector<array>& original_inputs); |     const std::vector<array>& original_inputs); | ||||||
|  |  | ||||||
| // Simplify the tape. | // Simplify the tape. | ||||||
|   | |||||||
| @@ -194,8 +194,7 @@ auto multi_one(const std::vector<array>&) { | |||||||
| auto multi_two(const std::vector<array>&) { | auto multi_two(const std::vector<array>&) { | ||||||
|   auto a = array(1.0); |   auto a = array(1.0); | ||||||
|   auto b = array(1.0); |   auto b = array(1.0); | ||||||
|   auto c = divmod(a, b); |   return divmod(a, b); | ||||||
|   return std::vector<array>{c}; |  | ||||||
| } | } | ||||||
|  |  | ||||||
| auto multi_three(const std::vector<array>&) { | auto multi_three(const std::vector<array>&) { | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun