mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	add test
This commit is contained in:
		| @@ -294,6 +294,11 @@ class array { | ||||
|     return array_desc_->siblings; | ||||
|   } | ||||
|  | ||||
|   /** The array's position in the sibling list. */ | ||||
|   int sibling_position() const { | ||||
|     return array_desc_->position; | ||||
|   } | ||||
|  | ||||
|   void set_siblings(std::vector<array> siblings, uint16_t position) { | ||||
|     array_desc_->siblings = std::move(siblings); | ||||
|     array_desc_->position = position; | ||||
|   | ||||
| @@ -414,55 +414,58 @@ std::pair<std::vector<array>, ParentsMap> compile_dfs( | ||||
|     const std::vector<array>& inputs, | ||||
|     std::vector<array>& outputs, | ||||
|     const std::vector<array>& original_inputs) { | ||||
|   std::function<void(const array&)> recurse; | ||||
|   std::vector<array> tape; | ||||
|   std::unordered_set<std::uintptr_t> input_set; | ||||
|   std::unordered_set<std::uintptr_t> original_input_set; | ||||
|   std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>> | ||||
|       parents_map; | ||||
|   for (int i = 0; i < inputs.size(); ++i) { | ||||
|     input_set.insert(inputs[i].id()); | ||||
|     original_input_set.insert(original_inputs[i].id()); | ||||
|   } | ||||
|   { | ||||
|     std::function<void(const array&)> recurse; | ||||
|     std::unordered_set<std::uintptr_t> input_set; | ||||
|     std::unordered_set<std::uintptr_t> original_input_set; | ||||
|     for (int i = 0; i < inputs.size(); ++i) { | ||||
|       input_set.insert(inputs[i].id()); | ||||
|       original_input_set.insert(original_inputs[i].id()); | ||||
|     } | ||||
|  | ||||
|   // DFS the graph to build the tape, and log parents and scalars | ||||
|   std::unordered_set<std::uintptr_t> cache; | ||||
|   recurse = [&](const array& a) { | ||||
|     auto id = a.id(); | ||||
|     if (original_input_set.find(id) != original_input_set.end()) { | ||||
|       throw std::invalid_argument( | ||||
|           "[compile] Attempting to compile a function with uncaptured inputs is not allowed."); | ||||
|     } | ||||
|     if (cache.find(id) != cache.end()) { | ||||
|       return; | ||||
|     } | ||||
|     for (int i = 0; i < a.inputs().size(); i++) { | ||||
|       auto& in = a.inputs()[i]; | ||||
|       parents_map[in.id()].push_back({a, i}); | ||||
|     // DFS the graph to build the tape, and log parents and scalars | ||||
|     std::unordered_set<std::uintptr_t> cache; | ||||
|     recurse = [&](const array& a) { | ||||
|       auto id = a.id(); | ||||
|       if (original_input_set.find(id) != original_input_set.end()) { | ||||
|         throw std::invalid_argument( | ||||
|             "[compile] Attempting to compile a function with uncaptured inputs is not allowed."); | ||||
|       } | ||||
|       if (cache.find(id) != cache.end()) { | ||||
|         return; | ||||
|       } | ||||
|       for (int i = 0; i < a.inputs().size(); i++) { | ||||
|         auto& in = a.inputs()[i]; | ||||
|         parents_map[in.id()].push_back({a, i}); | ||||
|         for (auto& s : a.siblings()) { | ||||
|           parents_map[in.id()].push_back({s, i}); | ||||
|         } | ||||
|         // Don't recurse on inputs (but add them to the tape for the purpose | ||||
|         // of future optimizations) | ||||
|         if (input_set.find(a.id()) == input_set.end()) { | ||||
|           recurse(in); | ||||
|         } | ||||
|       } | ||||
|       cache.insert(id); | ||||
|       for (auto& s : a.siblings()) { | ||||
|         parents_map[in.id()].push_back({s, i}); | ||||
|       } | ||||
|       // Don't recurse on inputs (but add them to the tape for the purpose | ||||
|       // of future optimizations) | ||||
|       if (input_set.find(a.id()) == input_set.end()) { | ||||
|         recurse(in); | ||||
|         cache.insert(s.id()); | ||||
|       } | ||||
|       tape.push_back(a); | ||||
|     }; | ||||
|     for (auto& a : outputs) { | ||||
|       recurse(a); | ||||
|     } | ||||
|     cache.insert(id); | ||||
|     for (auto& s : a.siblings()) { | ||||
|       cache.insert(s.id()); | ||||
|     } | ||||
|     tape.push_back(a); | ||||
|   }; | ||||
|   for (auto& a : outputs) { | ||||
|     recurse(a); | ||||
|   } | ||||
|  | ||||
|   // Deep copy the tape and parents map | ||||
|   // Deep copy the tape and parents map while preserving inputs and outputs | ||||
|   std::vector<array> new_tape; | ||||
|   std::unordered_set<uintptr_t> io_set; | ||||
|   std::unordered_map<uintptr_t, array> old_to_new; | ||||
|   for (auto& o : outputs) { | ||||
|     old_to_new.insert({o.id(), o}); | ||||
|     io_set.insert(o.id()); | ||||
|   } | ||||
|   for (auto& i : inputs) { | ||||
| @@ -483,7 +486,6 @@ std::pair<std::vector<array>, ParentsMap> compile_dfs( | ||||
|       inputs.push_back(old_to_new.find(i.id())->second); | ||||
|     } | ||||
|     if (arr.siblings().size() > 0) { | ||||
|       // use make_arrays | ||||
|       std::vector<Dtype> types; | ||||
|       std::vector<Shape> shapes; | ||||
|       auto out = arr.outputs(); | ||||
| @@ -496,8 +498,7 @@ std::pair<std::vector<array>, ParentsMap> compile_dfs( | ||||
|       for (int i = 0; i < out.size(); ++i) { | ||||
|         old_to_new.insert({out[i].id(), as[i]}); | ||||
|       } | ||||
|       // TODO maybe need to preserve position of sibling that is in tape | ||||
|       new_tape.push_back(as[0]); | ||||
|       new_tape.push_back(as[arr.sibling_position()]); | ||||
|     } else { | ||||
|       auto a = array( | ||||
|           arr.shape(), arr.dtype(), arr.primitive_ptr(), std::move(inputs)); | ||||
| @@ -505,7 +506,11 @@ std::pair<std::vector<array>, ParentsMap> compile_dfs( | ||||
|       new_tape.push_back(a); | ||||
|     } | ||||
|   } | ||||
|   io_set.clear(); | ||||
|   for (auto& o : outputs) { | ||||
|     if (!(io_set.insert(o.id()).second)) { | ||||
|       continue; | ||||
|     } | ||||
|     for (auto& i : o.inputs()) { | ||||
|       i = old_to_new.find(i.id())->second; | ||||
|     } | ||||
|   | ||||
| @@ -1134,6 +1134,30 @@ class TestCompile(mlx_tests.MLXTestCase): | ||||
|         a = fun2(mx.array(-1.0)) | ||||
|         self.assertEqual(a.item(), 1.0) | ||||
|  | ||||
|     def test_multiple_compile_same_capture(self): | ||||
|         def fun(do_compile): | ||||
|             t = mx.ones((10,)) | ||||
|             u = (1.0 - t) * 0.0 + t * 3.0 | ||||
|  | ||||
|             o = mx.ones((6,)) | ||||
|             b = o[:, None] * u | ||||
|  | ||||
|             c = b * mx.ones_like(u) | ||||
|  | ||||
|             a = mx.ones((6,)) | ||||
|             if do_compile: | ||||
|                 d = mx.compile(lambda x: x @ b)(a) | ||||
|                 e = mx.compile(lambda x: x @ c.T)(d) | ||||
|             else: | ||||
|                 d = a @ b | ||||
|                 e = d @ c.T | ||||
|             return e | ||||
|  | ||||
|         out = fun(True) | ||||
|         mx.eval(out) | ||||
|         expected = fun(False) | ||||
|         self.assertTrue(mx.allclose(out, expected)) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     mlx_tests.MLXTestRunner() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun