# Copyright © 2023 Apple Inc.

import unittest
from functools import partial

import mlx.core as mx
import mlx_tests


class TestEval(mlx_tests.MLXTestCase):
    def test_eval(self):
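        # Evaluating several arrays in one mx.eval call should materialize all of them.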
        arrs = [mx.ones((2, 2)) for _ in range(4)]
        mx.eval(*arrs)
        for x in arrs:
            self.assertEqual(x.tolist(), [[1, 1], [1, 1]])

    def test_retain_graph(self):
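        # Calling mx.eval on an intermediate value inside a function should not break taking its gradient.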
        def fun(x):
            y = 3 * x
            mx.eval(y)
            return 2 * y

        dfun_dx = mx.grad(fun)
        y = dfun_dx(mx.array(1.0))
        self.assertEqual(y.item(), 6.0)

    def test_eval_mixed(self):
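        # mx.eval should accept containers mixing arrays with non-array Python values.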
        x = mx.array(1) + 1 + 1
        y = 0
        z = "hello"
        state = [x, y, z]
        mx.eval(state)
        self.assertEqual(x.item(), 3)

    def test_async_eval(self):
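        # Reading an array after mx.async_eval should return the fully computed result.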
        x = mx.array(1) + mx.array(1) + mx.array(1)
        mx.async_eval(x)
        self.assertEqual(x.item(), 3)

        # It should be safe to call eval on the array which has been async
        # eval'ed
        x = mx.array(1) + mx.array(1) + mx.array(1)
        self.assertEqual(x.item(), 3)

        x = mx.array([1, 2, 3])
        y = 2 * x
        mx.async_eval(y)
        z = 2 * y
        mx.async_eval(z)
        self.assertTrue(mx.array_equal(y, mx.array([2, 4, 6])))
        self.assertTrue(mx.array_equal(z, mx.array([4, 8, 12])))

    def test_async_eval_twice(self):
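        # Repeatedly chaining async evaluations should remain correct over many iterations.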
        for _ in range(1000):
            x = mx.array(1) + mx.array(1) + mx.array(1)
            mx.async_eval(x)
            y = x + 1
            mx.async_eval(y)
            self.assertEqual(x.item(), 3)
            self.assertEqual(y.item(), 4)

    def test_async_eval_in_trace(self):
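        # Calling mx.async_eval inside a function being traced by grad or vmap should raise.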
        def fun(x):
            y = x + 1.0
            mx.async_eval(y)
            return mx.exp(y)

        # Raises
        with self.assertRaises(ValueError):
            mx.grad(fun)(mx.array(1.0))

        # Also raises
        with self.assertRaises(ValueError):
            mx.vmap(fun)(mx.ones((2, 2)))

    def test_async_eval_into_eval(self):
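        # An async-eval'ed result should be usable as input to further lazily evaluated ops.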
        x = mx.array(1)
        y = x + 1
        mx.async_eval(y)
        a = y - 10
        b = mx.abs(a)
        self.assertEqual(b.item(), 8)

    def test_async_eval_into_eval_diff_stream(self):
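        # An async-eval'ed result should be consumable by an op running on a different stream.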
        s = mx.new_stream(mx.cpu)
        x = mx.array(0)
        y = x - 5
        mx.async_eval(y)
        z = mx.abs(y, stream=s)
        self.assertEqual(z.item(), 5)

    def test_eval_slow_fast_multi_stream(self):
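        # Mixing a long chain of ops on the default stream with a small CPU-stream op should give the same result in either argument order.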
        x = mx.ones((8000,))
        y = mx.abs(mx.array(-1.0))
        for _ in range(20):
            x = x + mx.array(1.0)
        z = mx.add(x, y, stream=mx.cpu)
        self.assertTrue(mx.allclose(z, mx.full((8000,), 22.0)))

        # Switch eval order
        x = mx.ones((8000,))
        y = mx.abs(mx.array(-1.0))
        for _ in range(20):
            x = x + mx.array(1.0)
        z = mx.add(y, x, stream=mx.cpu)
        self.assertTrue(mx.allclose(z, mx.full((8000,), 22.0)))

    def test_multi_output_eval_during_transform(self):
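        # Evaluating one output of a multi-output op inside a vjp should not increase peak memory across repeated calls.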
        x = mx.random.uniform(shape=(1024,))
        y = mx.ones((1024,))
        mx.eval(x, y)

        def fn(x):
            a, b = mx.divmod(x, x)
            mx.eval(a)
            return a

        out = mx.vjp(fn, (x,), (y,))
        out = mx.vjp(fn, (x,), (y,))
        peak_mem = mx.get_peak_memory()
        out = mx.vjp(fn, (x,), (y,))
        self.assertEqual(peak_mem, mx.get_peak_memory())

    def test_async_eval_with_multiple_streams(self):
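        # Interleaving async_eval and eval while a second stream exists should run cleanly.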
        x = mx.array([1.0])
        y = mx.array([1.0])
        a = mx.array([1.0])
        b = mx.array([1.0])

        d = mx.default_device()
        s2 = mx.new_stream(d)

        for _ in range(50):
            for _ in range(20):
                x = x + y
            mx.async_eval(x)
            mx.eval(a + b)

    def test_donation_for_noops(self):
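        # Chains of no-op-like transformations should reuse (donate) their input buffers so peak memory stays flat.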
        def fun(x):
            s = x.shape
            for _ in range(10):
                x = mx.abs(x)
                x = mx.reshape(x, (-1,))
                x = x.T.T
                x = mx.stop_gradient(x)
                x = mx.abs(x)
            return x

        x = mx.zeros((4096, 4096))
        mx.eval(x)
        pre = mx.get_peak_memory()
        out = fun(x)
        del x
        mx.eval(out)
        post = mx.get_peak_memory()
        self.assertEqual(pre, post)

        def fun(x):
            for _ in range(10):
                x = mx.abs(x)
                x = x[:-1]
                x = mx.abs(x)
            return x

        x = mx.zeros((4096 * 4096,))
        mx.eval(x)
        pre = mx.get_peak_memory()
        out = fun(x)
        del x
        mx.eval(out)
        post = mx.get_peak_memory()
        self.assertEqual(pre, post)

    @unittest.skipIf(not mx.is_available(mx.gpu), "GPU is not available")
    def test_multistream_deadlock(self):
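        # Long chains of work crossing streams, including under a tight memory limit, should evaluate without deadlocking.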
        s1 = mx.default_stream(mx.gpu)
        s2 = mx.new_stream(mx.gpu)

        x = mx.array(1.0)
        x = mx.abs(x, stream=s1)
        for _ in range(1000):
            x = mx.abs(x, stream=s2)
        mx.eval(x)

        s1 = mx.default_stream(mx.gpu)
        s2 = mx.new_stream(mx.gpu)
        old_limit = mx.set_memory_limit(1000)

        x = mx.ones((512, 512), stream=s2)
        for _ in range(80):
            x = mx.abs(x, stream=s1)
        y = mx.abs(x, stream=s2)
        z = mx.abs(y, stream=s2)
        mx.eval(z)
        mx.set_memory_limit(old_limit)


if __name__ == "__main__":
    mlx_tests.MLXTestRunner()