mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-10 21:37:50 +08:00
move memory APIs into top level mlx.core (#1982)
This commit is contained in:
@@ -179,16 +179,16 @@ class TestDistributed(mlx_tests.MLXTestCase):
|
||||
mx.eval(x)
|
||||
mx.synchronize(mx.default_stream(mx.default_device()))
|
||||
|
||||
mx.metal.reset_peak_memory()
|
||||
mx.reset_peak_memory()
|
||||
scale = mx.array(2.0)
|
||||
y = mx.distributed.all_sum(x)
|
||||
mx.eval(y)
|
||||
mx.synchronize(mx.default_stream(mx.default_device()))
|
||||
all_sum_only = mx.metal.get_peak_memory()
|
||||
all_sum_only = mx.get_peak_memory()
|
||||
y = mx.distributed.all_sum(x) * scale
|
||||
mx.eval(y)
|
||||
mx.synchronize(mx.default_stream(mx.default_device()))
|
||||
all_sum_with_binary = mx.metal.get_peak_memory()
|
||||
all_sum_with_binary = mx.get_peak_memory()
|
||||
|
||||
self.assertEqual(all_sum_only, all_sum_with_binary)
|
||||
|
||||
|
@@ -1813,10 +1813,10 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
|
||||
fun()
|
||||
mx.synchronize()
|
||||
peak_1 = mx.metal.get_peak_memory()
|
||||
peak_1 = mx.get_peak_memory()
|
||||
fun()
|
||||
mx.synchronize()
|
||||
peak_2 = mx.metal.get_peak_memory()
|
||||
peak_2 = mx.get_peak_memory()
|
||||
self.assertEqual(peak_1, peak_2)
|
||||
|
||||
def fun():
|
||||
@@ -1826,10 +1826,10 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
|
||||
fun()
|
||||
mx.synchronize()
|
||||
peak_1 = mx.metal.get_peak_memory()
|
||||
peak_1 = mx.get_peak_memory()
|
||||
fun()
|
||||
mx.synchronize()
|
||||
peak_2 = mx.metal.get_peak_memory()
|
||||
peak_2 = mx.get_peak_memory()
|
||||
self.assertEqual(peak_1, peak_2)
|
||||
|
||||
def test_add_numpy(self):
|
||||
|
@@ -747,7 +747,7 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
||||
]:
|
||||
if mx.metal.is_available():
|
||||
mx.synchronize(mx.default_stream(mx.default_device()))
|
||||
mem_pre = mx.metal.get_active_memory()
|
||||
mem_pre = mx.get_active_memory()
|
||||
else:
|
||||
mem_pre = 0
|
||||
|
||||
@@ -765,7 +765,7 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
||||
gc.collect()
|
||||
|
||||
if mx.metal.is_available():
|
||||
mem_post = mx.metal.get_active_memory()
|
||||
mem_post = mx.get_active_memory()
|
||||
else:
|
||||
mem_post = 0
|
||||
|
||||
|
@@ -955,7 +955,7 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
def test_leaks(self):
|
||||
gc.collect()
|
||||
if mx.metal.is_available():
|
||||
mem_pre = mx.metal.get_active_memory()
|
||||
mem_pre = mx.get_active_memory()
|
||||
else:
|
||||
mem_pre = 0
|
||||
|
||||
@@ -973,7 +973,7 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
gc.collect()
|
||||
|
||||
if mx.metal.is_available():
|
||||
mem_post = mx.metal.get_active_memory()
|
||||
mem_post = mx.get_active_memory()
|
||||
else:
|
||||
mem_post = 0
|
||||
|
||||
|
@@ -118,9 +118,9 @@ class TestEval(mlx_tests.MLXTestCase):
|
||||
out = mx.vjp(fn, (x,), (y,))
|
||||
out = mx.vjp(fn, (x,), (y,))
|
||||
if mx.metal.is_available():
|
||||
peak_mem = mx.metal.get_peak_memory()
|
||||
peak_mem = mx.get_peak_memory()
|
||||
out = mx.vjp(fn, (x,), (y,))
|
||||
self.assertEqual(peak_mem, mx.metal.get_peak_memory())
|
||||
self.assertEqual(peak_mem, mx.get_peak_memory())
|
||||
|
||||
def test_async_eval_with_multiple_streams(self):
|
||||
x = mx.array([1.0])
|
||||
@@ -151,11 +151,11 @@ class TestEval(mlx_tests.MLXTestCase):
|
||||
|
||||
x = mx.zeros((4096, 4096))
|
||||
mx.eval(x)
|
||||
pre = mx.metal.get_peak_memory()
|
||||
pre = mx.get_peak_memory()
|
||||
out = fun(x)
|
||||
del x
|
||||
mx.eval(out)
|
||||
post = mx.metal.get_peak_memory()
|
||||
post = mx.get_peak_memory()
|
||||
self.assertEqual(pre, post)
|
||||
|
||||
def fun(x):
|
||||
@@ -167,11 +167,11 @@ class TestEval(mlx_tests.MLXTestCase):
|
||||
|
||||
x = mx.zeros((4096 * 4096,))
|
||||
mx.eval(x)
|
||||
pre = mx.metal.get_peak_memory()
|
||||
pre = mx.get_peak_memory()
|
||||
out = fun(x)
|
||||
del x
|
||||
mx.eval(out)
|
||||
post = mx.metal.get_peak_memory()
|
||||
post = mx.get_peak_memory()
|
||||
self.assertEqual(pre, post)
|
||||
|
||||
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
||||
@@ -187,7 +187,7 @@ class TestEval(mlx_tests.MLXTestCase):
|
||||
|
||||
s1 = mx.default_stream(mx.gpu)
|
||||
s2 = mx.new_stream(mx.gpu)
|
||||
old_limit = mx.metal.set_memory_limit(1000)
|
||||
old_limit = mx.set_memory_limit(1000)
|
||||
|
||||
x = mx.ones((512, 512), stream=s2)
|
||||
for _ in range(80):
|
||||
@@ -195,7 +195,7 @@ class TestEval(mlx_tests.MLXTestCase):
|
||||
y = mx.abs(x, stream=s2)
|
||||
z = mx.abs(y, stream=s2)
|
||||
mx.eval(z)
|
||||
mx.metal.set_memory_limit(old_limit)
|
||||
mx.set_memory_limit(old_limit)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@@ -243,7 +243,7 @@ class TestExportImport(mlx_tests.MLXTestCase):
|
||||
def test_leaks(self):
|
||||
path = os.path.join(self.test_dir, "fn.mlxfn")
|
||||
if mx.metal.is_available():
|
||||
mem_pre = mx.metal.get_active_memory()
|
||||
mem_pre = mx.get_active_memory()
|
||||
else:
|
||||
mem_pre = 0
|
||||
|
||||
@@ -261,7 +261,7 @@ class TestExportImport(mlx_tests.MLXTestCase):
|
||||
gc.collect()
|
||||
|
||||
if mx.metal.is_available():
|
||||
mem_post = mx.metal.get_active_memory()
|
||||
mem_post = mx.get_active_memory()
|
||||
else:
|
||||
mem_post = 0
|
||||
|
||||
|
@@ -387,14 +387,14 @@ class TestLoad(mlx_tests.MLXTestCase):
|
||||
mx.save(save_file, x)
|
||||
mx.synchronize(mx.default_stream(mx.default_device()))
|
||||
|
||||
mx.metal.reset_peak_memory()
|
||||
mx.reset_peak_memory()
|
||||
scale = mx.array(2.0)
|
||||
y = mx.load(save_file)
|
||||
mx.eval(y)
|
||||
load_only = mx.metal.get_peak_memory()
|
||||
load_only = mx.get_peak_memory()
|
||||
y = mx.load(save_file) * scale
|
||||
mx.eval(y)
|
||||
load_with_binary = mx.metal.get_peak_memory()
|
||||
load_with_binary = mx.get_peak_memory()
|
||||
|
||||
self.assertEqual(load_only, load_with_binary)
|
||||
|
||||
|
60
python/tests/test_memory.py
Normal file
60
python/tests/test_memory.py
Normal file
@@ -0,0 +1,60 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx_tests
|
||||
|
||||
|
||||
class TestMemory(mlx_tests.MLXTestCase):
|
||||
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
||||
def test_memory_info(self):
|
||||
old_limit = mx.set_cache_limit(0)
|
||||
|
||||
a = mx.zeros((4096,))
|
||||
mx.eval(a)
|
||||
del a
|
||||
self.assertEqual(mx.get_cache_memory(), 0)
|
||||
self.assertEqual(mx.set_cache_limit(old_limit), 0)
|
||||
self.assertEqual(mx.set_cache_limit(old_limit), old_limit)
|
||||
|
||||
old_limit = mx.set_memory_limit(10)
|
||||
self.assertTrue(mx.set_memory_limit(old_limit), 10)
|
||||
self.assertTrue(mx.set_memory_limit(old_limit), old_limit)
|
||||
|
||||
# Query active and peak memory
|
||||
a = mx.zeros((4096,))
|
||||
mx.eval(a)
|
||||
mx.synchronize()
|
||||
active_mem = mx.get_active_memory()
|
||||
self.assertTrue(active_mem >= 4096 * 4)
|
||||
|
||||
b = mx.zeros((4096,))
|
||||
mx.eval(b)
|
||||
del b
|
||||
mx.synchronize()
|
||||
|
||||
new_active_mem = mx.get_active_memory()
|
||||
self.assertEqual(new_active_mem, active_mem)
|
||||
peak_mem = mx.get_peak_memory()
|
||||
self.assertTrue(peak_mem >= 4096 * 8)
|
||||
cache_mem = mx.get_cache_memory()
|
||||
self.assertTrue(cache_mem >= 4096 * 4)
|
||||
|
||||
mx.clear_cache()
|
||||
self.assertEqual(mx.get_cache_memory(), 0)
|
||||
|
||||
mx.reset_peak_memory()
|
||||
self.assertEqual(mx.get_peak_memory(), 0)
|
||||
|
||||
old_limit = mx.set_wired_limit(1000)
|
||||
old_limit = mx.set_wired_limit(0)
|
||||
self.assertEqual(old_limit, 1000)
|
||||
|
||||
max_size = mx.metal.device_info()["max_recommended_working_set_size"]
|
||||
with self.assertRaises(ValueError):
|
||||
mx.set_wired_limit(max_size + 10)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
@@ -1,60 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx_tests
|
||||
|
||||
|
||||
class TestMetal(mlx_tests.MLXTestCase):
|
||||
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
||||
def test_memory_info(self):
|
||||
old_limit = mx.metal.set_cache_limit(0)
|
||||
|
||||
a = mx.zeros((4096,))
|
||||
mx.eval(a)
|
||||
del a
|
||||
self.assertEqual(mx.metal.get_cache_memory(), 0)
|
||||
self.assertEqual(mx.metal.set_cache_limit(old_limit), 0)
|
||||
self.assertEqual(mx.metal.set_cache_limit(old_limit), old_limit)
|
||||
|
||||
old_limit = mx.metal.set_memory_limit(10)
|
||||
self.assertTrue(mx.metal.set_memory_limit(old_limit), 10)
|
||||
self.assertTrue(mx.metal.set_memory_limit(old_limit), old_limit)
|
||||
|
||||
# Query active and peak memory
|
||||
a = mx.zeros((4096,))
|
||||
mx.eval(a)
|
||||
mx.synchronize()
|
||||
active_mem = mx.metal.get_active_memory()
|
||||
self.assertTrue(active_mem >= 4096 * 4)
|
||||
|
||||
b = mx.zeros((4096,))
|
||||
mx.eval(b)
|
||||
del b
|
||||
mx.synchronize()
|
||||
|
||||
new_active_mem = mx.metal.get_active_memory()
|
||||
self.assertEqual(new_active_mem, active_mem)
|
||||
peak_mem = mx.metal.get_peak_memory()
|
||||
self.assertTrue(peak_mem >= 4096 * 8)
|
||||
cache_mem = mx.metal.get_cache_memory()
|
||||
self.assertTrue(cache_mem >= 4096 * 4)
|
||||
|
||||
mx.metal.clear_cache()
|
||||
self.assertEqual(mx.metal.get_cache_memory(), 0)
|
||||
|
||||
mx.metal.reset_peak_memory()
|
||||
self.assertEqual(mx.metal.get_peak_memory(), 0)
|
||||
|
||||
old_limit = mx.metal.set_wired_limit(1000)
|
||||
old_limit = mx.metal.set_wired_limit(0)
|
||||
self.assertEqual(old_limit, 1000)
|
||||
|
||||
max_size = mx.metal.device_info()["max_recommended_working_set_size"]
|
||||
with self.assertRaises(ValueError):
|
||||
mx.metal.set_wired_limit(max_size + 10)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
@@ -1904,10 +1904,10 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
mx.synchronize(mx.default_stream(mx.default_device()))
|
||||
mx.eval(fn(2))
|
||||
mx.synchronize(mx.default_stream(mx.default_device()))
|
||||
mem2 = mx.metal.get_peak_memory()
|
||||
mem2 = mx.get_peak_memory()
|
||||
mx.eval(fn(4))
|
||||
mx.synchronize(mx.default_stream(mx.default_device()))
|
||||
mem4 = mx.metal.get_peak_memory()
|
||||
mem4 = mx.get_peak_memory()
|
||||
self.assertEqual(mem2, mem4)
|
||||
|
||||
def test_squeeze_expand(self):
|
||||
|
@@ -635,7 +635,7 @@ class TestVmap(mlx_tests.MLXTestCase):
|
||||
|
||||
def test_leaks(self):
|
||||
if mx.metal.is_available():
|
||||
mem_pre = mx.metal.get_active_memory()
|
||||
mem_pre = mx.get_active_memory()
|
||||
else:
|
||||
mem_pre = 0
|
||||
|
||||
@@ -653,7 +653,7 @@ class TestVmap(mlx_tests.MLXTestCase):
|
||||
gc.collect()
|
||||
|
||||
if mx.metal.is_available():
|
||||
mem_post = mx.metal.get_active_memory()
|
||||
mem_post = mx.get_active_memory()
|
||||
else:
|
||||
mem_post = 0
|
||||
|
||||
|
Reference in New Issue
Block a user