mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-05 03:18:12 +08:00
move memory APIs into top level mlx.core (#1982)
This commit is contained in:
60
python/tests/test_memory.py
Normal file
60
python/tests/test_memory.py
Normal file
@@ -0,0 +1,60 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx_tests
|
||||
|
||||
|
||||
class TestMemory(mlx_tests.MLXTestCase):
|
||||
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
||||
def test_memory_info(self):
|
||||
old_limit = mx.set_cache_limit(0)
|
||||
|
||||
a = mx.zeros((4096,))
|
||||
mx.eval(a)
|
||||
del a
|
||||
self.assertEqual(mx.get_cache_memory(), 0)
|
||||
self.assertEqual(mx.set_cache_limit(old_limit), 0)
|
||||
self.assertEqual(mx.set_cache_limit(old_limit), old_limit)
|
||||
|
||||
old_limit = mx.set_memory_limit(10)
|
||||
self.assertTrue(mx.set_memory_limit(old_limit), 10)
|
||||
self.assertTrue(mx.set_memory_limit(old_limit), old_limit)
|
||||
|
||||
# Query active and peak memory
|
||||
a = mx.zeros((4096,))
|
||||
mx.eval(a)
|
||||
mx.synchronize()
|
||||
active_mem = mx.get_active_memory()
|
||||
self.assertTrue(active_mem >= 4096 * 4)
|
||||
|
||||
b = mx.zeros((4096,))
|
||||
mx.eval(b)
|
||||
del b
|
||||
mx.synchronize()
|
||||
|
||||
new_active_mem = mx.get_active_memory()
|
||||
self.assertEqual(new_active_mem, active_mem)
|
||||
peak_mem = mx.get_peak_memory()
|
||||
self.assertTrue(peak_mem >= 4096 * 8)
|
||||
cache_mem = mx.get_cache_memory()
|
||||
self.assertTrue(cache_mem >= 4096 * 4)
|
||||
|
||||
mx.clear_cache()
|
||||
self.assertEqual(mx.get_cache_memory(), 0)
|
||||
|
||||
mx.reset_peak_memory()
|
||||
self.assertEqual(mx.get_peak_memory(), 0)
|
||||
|
||||
old_limit = mx.set_wired_limit(1000)
|
||||
old_limit = mx.set_wired_limit(0)
|
||||
self.assertEqual(old_limit, 1000)
|
||||
|
||||
max_size = mx.metal.device_info()["max_recommended_working_set_size"]
|
||||
with self.assertRaises(ValueError):
|
||||
mx.set_wired_limit(max_size + 10)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user