mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Fix compile with non standard types (#745)
* refactor tree utils * fix compile + tree code refactor * Add an extra test * add a few missing activations to docs * hash structure * Encode the full argument structure --------- Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
		| @@ -539,6 +539,48 @@ class TestCompile(mlx_tests.MLXTestCase): | ||||
|         z = fun(mx.array(1), "two") | ||||
|         self.assertEqual(z.item(), 3) | ||||
|  | ||||
|         # Test nested constant | ||||
|         @partial(mx.compile) | ||||
|         def fun(x, y): | ||||
|             if y[0][0] == 1: | ||||
|                 return x + 1 | ||||
|             else: | ||||
|                 return x + 2 | ||||
|  | ||||
|         z = fun(mx.array(1), [[1]]) | ||||
|         self.assertEqual(z.item(), 2) | ||||
|  | ||||
|         z = fun(mx.array(1), [[0]]) | ||||
|         self.assertEqual(z.item(), 3) | ||||
|  | ||||
|         @partial(mx.compile) | ||||
|         def fun(x, a, b): | ||||
|             for ai in a: | ||||
|                 for bi in b: | ||||
|                     x = bi * x + ai | ||||
|             return x | ||||
|  | ||||
|         z = fun(mx.array(1), [1, 1], [2]) | ||||
|         self.assertEqual(z.item(), 7) | ||||
|  | ||||
|         z = fun(mx.array(1), [1], [1, 2]) | ||||
|         self.assertEqual(z.item(), 5) | ||||
|  | ||||
|         counter = [0] | ||||
|  | ||||
|         @partial(mx.compile) | ||||
|         def fun(x, y): | ||||
|             counter[0] += 1 | ||||
|             return x + y | ||||
|  | ||||
|         z = fun(mx.array(1), 1) | ||||
|         self.assertEqual(z.item(), 2) | ||||
|  | ||||
|         z = fun(1, mx.array(1)) | ||||
|         self.assertEqual(z.item(), 2) | ||||
|  | ||||
|         self.assertEqual(counter[0], 2) | ||||
|  | ||||
|     def test_compile_inf(self): | ||||
|  | ||||
|         @mx.compile | ||||
| @@ -548,6 +590,21 @@ class TestCompile(mlx_tests.MLXTestCase): | ||||
|         out = fun(mx.array([0.0])) | ||||
|         self.assertEqual(out.item(), False) | ||||
|  | ||||
|     def test_unsupported_input_types(self): | ||||
|  | ||||
|         class MyClass: | ||||
|             value = 1 | ||||
|  | ||||
|         @mx.compile | ||||
|         def fun(x, y): | ||||
|             return x + y.value | ||||
|  | ||||
|         with self.assertRaises(ValueError): | ||||
|             out = fun(mx.array(0.0), MyClass()) | ||||
|  | ||||
|         with self.assertRaises(ValueError): | ||||
|             out = fun(mx.array(0.0), y=MyClass()) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun