mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	Limit compile buffers (#1887)
* limit compile buffers * maybe not flaky test
This commit is contained in:
		| @@ -177,6 +177,7 @@ class TestDistributed(mlx_tests.MLXTestCase): | ||||
|     def test_donation(self): | ||||
|         x = mx.random.normal((1024,)) | ||||
|         mx.eval(x) | ||||
|         mx.synchronize(mx.default_stream(mx.default_device())) | ||||
|  | ||||
|         mx.metal.reset_peak_memory() | ||||
|         scale = mx.array(2.0) | ||||
|   | ||||
| @@ -815,6 +815,31 @@ class TestCompile(mlx_tests.MLXTestCase): | ||||
|         out = fun(*inputs) | ||||
|         self.assertTrue(mx.allclose(out, mx.full((2, 2), 20))) | ||||
|  | ||||
|         @mx.compile | ||||
|         def fun(arrs): | ||||
|             for _ in range(6): | ||||
|                 arrs = [x + y for x, y in zip(arrs[::2], arrs[1::2])] | ||||
|             return arrs[0] | ||||
|  | ||||
|         arrs = [mx.array([1.0, 2.0]) for _ in range(64)] | ||||
|         out = fun(arrs) | ||||
|         self.assertTrue(mx.allclose(out, mx.array([64.0, 128.0]))) | ||||
|  | ||||
|     def test_compile_many_outputs(self): | ||||
|  | ||||
|         @mx.compile | ||||
|         def fun(arr): | ||||
|             arrs = [arr] * 64 | ||||
|             first_arrs = None | ||||
|             for _ in range(6): | ||||
|                 arrs = [x + y for x, y in zip(arrs[::2], arrs[1::2])] | ||||
|                 if first_arrs is None: | ||||
|                     first_arrs = arrs | ||||
|             return arrs[0], first_arrs | ||||
|  | ||||
|         out = fun(mx.array([1.0, 2.0])) | ||||
|         self.assertTrue(mx.allclose(out[0], mx.array([64.0, 128.0]))) | ||||
|  | ||||
|     def test_shapeless_compile_matmul(self): | ||||
|         a = mx.array([0.0, 1.0, 2.0]) | ||||
|         b = mx.array([0.0, 1.0, 2.0]) | ||||
|   | ||||
| @@ -385,6 +385,7 @@ class TestLoad(mlx_tests.MLXTestCase): | ||||
|         mx.eval(x) | ||||
|         save_file = os.path.join(self.test_dir, "donation.npy") | ||||
|         mx.save(save_file, x) | ||||
|         mx.synchronize(mx.default_stream(mx.default_device())) | ||||
|  | ||||
|         mx.metal.reset_peak_memory() | ||||
|         scale = mx.array(2.0) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun