mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			64 lines
		
	
	
		
			1.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			64 lines
		
	
	
		
			1.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright © 2023-2024 Apple Inc.
 | |
| 
 | |
| import unittest
 | |
| 
 | |
| import mlx.core as mx
 | |
| import mlx_tests
 | |
| 
 | |
| 
 | |
class TestMemory(mlx_tests.MLXTestCase):
    """Tests for the memory limit setters and memory usage counters in ``mx``."""

    def test_memory_info(self):
        """Exercise cache/memory limits and the active, peak and cache counters."""
        # With the cache limit at zero, a released buffer should not be kept
        # in the cache.
        old_limit = mx.set_cache_limit(0)

        a = mx.zeros((4096,))
        mx.eval(a)
        del a
        self.assertEqual(mx.get_cache_memory(), 0)
        # The setter is expected to return the limit that was in effect
        # before the call, so restoring twice yields 0 and then old_limit.
        self.assertEqual(mx.set_cache_limit(old_limit), 0)
        self.assertEqual(mx.set_cache_limit(old_limit), old_limit)

        old_limit = mx.set_memory_limit(10)
        # These must be assertEqual: assertTrue(x, y) treats ``y`` as the
        # failure message and only checks that ``x`` is truthy.
        self.assertEqual(mx.set_memory_limit(old_limit), 10)
        self.assertEqual(mx.set_memory_limit(old_limit), old_limit)

        # Query active and peak memory
        a = mx.zeros((4096,))
        mx.eval(a)
        mx.synchronize()
        active_mem = mx.get_active_memory()
        # 4096 elements at 4 bytes each (presumably the default dtype is
        # float32 -- confirm against mx.zeros).
        self.assertGreaterEqual(active_mem, 4096 * 4)

        b = mx.zeros((4096,))
        mx.eval(b)
        del b
        mx.synchronize()

        # Releasing ``b`` should bring active memory back to where it was
        # while only ``a`` was alive.
        new_active_mem = mx.get_active_memory()
        self.assertEqual(new_active_mem, active_mem)
        # ``a`` and ``b`` were alive at the same time, so the peak covers both.
        peak_mem = mx.get_peak_memory()
        self.assertGreaterEqual(peak_mem, 4096 * 8)

        if mx.metal.is_available():
            # Only checked on Metal: the buffer released by ``b`` is expected
            # to be sitting in the cache at this point.
            cache_mem = mx.get_cache_memory()
            self.assertGreaterEqual(cache_mem, 4096 * 4)

        mx.clear_cache()
        self.assertEqual(mx.get_cache_memory(), 0)

        mx.reset_peak_memory()
        self.assertEqual(mx.get_peak_memory(), 0)

    @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
    def test_wired_memory(self):
        """Check the wired limit setter's return value and its upper bound."""
        # The value returned by the first call is whatever limit was set
        # before the test and is not checked.
        mx.set_wired_limit(1000)
        old_limit = mx.set_wired_limit(0)
        self.assertEqual(old_limit, 1000)

        # A limit above the device's recommended working set size is rejected.
        max_size = mx.metal.device_info()["max_recommended_working_set_size"]
        with self.assertRaises(ValueError):
            mx.set_wired_limit(max_size + 10)
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    # Script entry point: hand control to the runner provided by mlx_tests
    # (presumably a unittest-style runner -- see mlx_tests for details).
    mlx_tests.MLXTestRunner()
 | 
