mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	add test
This commit is contained in:
		| @@ -294,6 +294,11 @@ class array { | |||||||
|     return array_desc_->siblings; |     return array_desc_->siblings; | ||||||
|   } |   } | ||||||
|  |  | ||||||
|  |   /** The array's position in the sibling list. */ | ||||||
|  |   int sibling_position() const { | ||||||
|  |     return array_desc_->position; | ||||||
|  |   } | ||||||
|  |  | ||||||
|   void set_siblings(std::vector<array> siblings, uint16_t position) { |   void set_siblings(std::vector<array> siblings, uint16_t position) { | ||||||
|     array_desc_->siblings = std::move(siblings); |     array_desc_->siblings = std::move(siblings); | ||||||
|     array_desc_->position = position; |     array_desc_->position = position; | ||||||
|   | |||||||
| @@ -414,55 +414,58 @@ std::pair<std::vector<array>, ParentsMap> compile_dfs( | |||||||
|     const std::vector<array>& inputs, |     const std::vector<array>& inputs, | ||||||
|     std::vector<array>& outputs, |     std::vector<array>& outputs, | ||||||
|     const std::vector<array>& original_inputs) { |     const std::vector<array>& original_inputs) { | ||||||
|   std::function<void(const array&)> recurse; |  | ||||||
|   std::vector<array> tape; |   std::vector<array> tape; | ||||||
|   std::unordered_set<std::uintptr_t> input_set; |  | ||||||
|   std::unordered_set<std::uintptr_t> original_input_set; |  | ||||||
|   std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>> |   std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>> | ||||||
|       parents_map; |       parents_map; | ||||||
|   for (int i = 0; i < inputs.size(); ++i) { |   { | ||||||
|     input_set.insert(inputs[i].id()); |     std::function<void(const array&)> recurse; | ||||||
|     original_input_set.insert(original_inputs[i].id()); |     std::unordered_set<std::uintptr_t> input_set; | ||||||
|   } |     std::unordered_set<std::uintptr_t> original_input_set; | ||||||
|  |     for (int i = 0; i < inputs.size(); ++i) { | ||||||
|  |       input_set.insert(inputs[i].id()); | ||||||
|  |       original_input_set.insert(original_inputs[i].id()); | ||||||
|  |     } | ||||||
|  |  | ||||||
|   // DFS the graph to build the tape, and log parents and scalars |     // DFS the graph to build the tape, and log parents and scalars | ||||||
|   std::unordered_set<std::uintptr_t> cache; |     std::unordered_set<std::uintptr_t> cache; | ||||||
|   recurse = [&](const array& a) { |     recurse = [&](const array& a) { | ||||||
|     auto id = a.id(); |       auto id = a.id(); | ||||||
|     if (original_input_set.find(id) != original_input_set.end()) { |       if (original_input_set.find(id) != original_input_set.end()) { | ||||||
|       throw std::invalid_argument( |         throw std::invalid_argument( | ||||||
|           "[compile] Attempting to compile a function with uncaptured inputs is not allowed."); |             "[compile] Attempting to compile a function with uncaptured inputs is not allowed."); | ||||||
|     } |       } | ||||||
|     if (cache.find(id) != cache.end()) { |       if (cache.find(id) != cache.end()) { | ||||||
|       return; |         return; | ||||||
|     } |       } | ||||||
|     for (int i = 0; i < a.inputs().size(); i++) { |       for (int i = 0; i < a.inputs().size(); i++) { | ||||||
|       auto& in = a.inputs()[i]; |         auto& in = a.inputs()[i]; | ||||||
|       parents_map[in.id()].push_back({a, i}); |         parents_map[in.id()].push_back({a, i}); | ||||||
|  |         for (auto& s : a.siblings()) { | ||||||
|  |           parents_map[in.id()].push_back({s, i}); | ||||||
|  |         } | ||||||
|  |         // Don't recurse on inputs (but add them to the tape for the purpose | ||||||
|  |         // of future optimizations) | ||||||
|  |         if (input_set.find(a.id()) == input_set.end()) { | ||||||
|  |           recurse(in); | ||||||
|  |         } | ||||||
|  |       } | ||||||
|  |       cache.insert(id); | ||||||
|       for (auto& s : a.siblings()) { |       for (auto& s : a.siblings()) { | ||||||
|         parents_map[in.id()].push_back({s, i}); |         cache.insert(s.id()); | ||||||
|       } |  | ||||||
|       // Don't recurse on inputs (but add them to the tape for the purpose |  | ||||||
|       // of future optimizations) |  | ||||||
|       if (input_set.find(a.id()) == input_set.end()) { |  | ||||||
|         recurse(in); |  | ||||||
|       } |       } | ||||||
|  |       tape.push_back(a); | ||||||
|  |     }; | ||||||
|  |     for (auto& a : outputs) { | ||||||
|  |       recurse(a); | ||||||
|     } |     } | ||||||
|     cache.insert(id); |  | ||||||
|     for (auto& s : a.siblings()) { |  | ||||||
|       cache.insert(s.id()); |  | ||||||
|     } |  | ||||||
|     tape.push_back(a); |  | ||||||
|   }; |  | ||||||
|   for (auto& a : outputs) { |  | ||||||
|     recurse(a); |  | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   // Deep copy the tape and parents map |   // Deep copy the tape and parents map while preserving inputs and outputs | ||||||
|   std::vector<array> new_tape; |   std::vector<array> new_tape; | ||||||
|   std::unordered_set<uintptr_t> io_set; |   std::unordered_set<uintptr_t> io_set; | ||||||
|   std::unordered_map<uintptr_t, array> old_to_new; |   std::unordered_map<uintptr_t, array> old_to_new; | ||||||
|   for (auto& o : outputs) { |   for (auto& o : outputs) { | ||||||
|  |     old_to_new.insert({o.id(), o}); | ||||||
|     io_set.insert(o.id()); |     io_set.insert(o.id()); | ||||||
|   } |   } | ||||||
|   for (auto& i : inputs) { |   for (auto& i : inputs) { | ||||||
| @@ -483,7 +486,6 @@ std::pair<std::vector<array>, ParentsMap> compile_dfs( | |||||||
|       inputs.push_back(old_to_new.find(i.id())->second); |       inputs.push_back(old_to_new.find(i.id())->second); | ||||||
|     } |     } | ||||||
|     if (arr.siblings().size() > 0) { |     if (arr.siblings().size() > 0) { | ||||||
|       // use make_arrays |  | ||||||
|       std::vector<Dtype> types; |       std::vector<Dtype> types; | ||||||
|       std::vector<Shape> shapes; |       std::vector<Shape> shapes; | ||||||
|       auto out = arr.outputs(); |       auto out = arr.outputs(); | ||||||
| @@ -496,8 +498,7 @@ std::pair<std::vector<array>, ParentsMap> compile_dfs( | |||||||
|       for (int i = 0; i < out.size(); ++i) { |       for (int i = 0; i < out.size(); ++i) { | ||||||
|         old_to_new.insert({out[i].id(), as[i]}); |         old_to_new.insert({out[i].id(), as[i]}); | ||||||
|       } |       } | ||||||
|       // TODO maybe need to preserve position of sibling that is in tape |       new_tape.push_back(as[arr.sibling_position()]); | ||||||
|       new_tape.push_back(as[0]); |  | ||||||
|     } else { |     } else { | ||||||
|       auto a = array( |       auto a = array( | ||||||
|           arr.shape(), arr.dtype(), arr.primitive_ptr(), std::move(inputs)); |           arr.shape(), arr.dtype(), arr.primitive_ptr(), std::move(inputs)); | ||||||
| @@ -505,7 +506,11 @@ std::pair<std::vector<array>, ParentsMap> compile_dfs( | |||||||
|       new_tape.push_back(a); |       new_tape.push_back(a); | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
|  |   io_set.clear(); | ||||||
|   for (auto& o : outputs) { |   for (auto& o : outputs) { | ||||||
|  |     if (!(io_set.insert(o.id()).second)) { | ||||||
|  |       continue; | ||||||
|  |     } | ||||||
|     for (auto& i : o.inputs()) { |     for (auto& i : o.inputs()) { | ||||||
|       i = old_to_new.find(i.id())->second; |       i = old_to_new.find(i.id())->second; | ||||||
|     } |     } | ||||||
|   | |||||||
| @@ -1134,6 +1134,30 @@ class TestCompile(mlx_tests.MLXTestCase): | |||||||
|         a = fun2(mx.array(-1.0)) |         a = fun2(mx.array(-1.0)) | ||||||
|         self.assertEqual(a.item(), 1.0) |         self.assertEqual(a.item(), 1.0) | ||||||
|  |  | ||||||
|  |     def test_multiple_compile_same_capture(self): | ||||||
|  |         def fun(do_compile): | ||||||
|  |             t = mx.ones((10,)) | ||||||
|  |             u = (1.0 - t) * 0.0 + t * 3.0 | ||||||
|  |  | ||||||
|  |             o = mx.ones((6,)) | ||||||
|  |             b = o[:, None] * u | ||||||
|  |  | ||||||
|  |             c = b * mx.ones_like(u) | ||||||
|  |  | ||||||
|  |             a = mx.ones((6,)) | ||||||
|  |             if do_compile: | ||||||
|  |                 d = mx.compile(lambda x: x @ b)(a) | ||||||
|  |                 e = mx.compile(lambda x: x @ c.T)(d) | ||||||
|  |             else: | ||||||
|  |                 d = a @ b | ||||||
|  |                 e = d @ c.T | ||||||
|  |             return e | ||||||
|  |  | ||||||
|  |         out = fun(True) | ||||||
|  |         mx.eval(out) | ||||||
|  |         expected = fun(False) | ||||||
|  |         self.assertTrue(mx.allclose(out, expected)) | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|     mlx_tests.MLXTestRunner() |     mlx_tests.MLXTestRunner() | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun