move memory APIs into top level mlx.core (#1982)

This commit is contained in:
Awni Hannun
2025-03-21 07:25:12 -07:00
committed by GitHub
parent 65a38c452b
commit 4e1994e9d7
25 changed files with 418 additions and 323 deletions

View File

@@ -179,16 +179,16 @@ class TestDistributed(mlx_tests.MLXTestCase):
mx.eval(x)
mx.synchronize(mx.default_stream(mx.default_device()))
mx.metal.reset_peak_memory()
mx.reset_peak_memory()
scale = mx.array(2.0)
y = mx.distributed.all_sum(x)
mx.eval(y)
mx.synchronize(mx.default_stream(mx.default_device()))
all_sum_only = mx.metal.get_peak_memory()
all_sum_only = mx.get_peak_memory()
y = mx.distributed.all_sum(x) * scale
mx.eval(y)
mx.synchronize(mx.default_stream(mx.default_device()))
all_sum_with_binary = mx.metal.get_peak_memory()
all_sum_with_binary = mx.get_peak_memory()
self.assertEqual(all_sum_only, all_sum_with_binary)

View File

@@ -1813,10 +1813,10 @@ class TestArray(mlx_tests.MLXTestCase):
fun()
mx.synchronize()
peak_1 = mx.metal.get_peak_memory()
peak_1 = mx.get_peak_memory()
fun()
mx.synchronize()
peak_2 = mx.metal.get_peak_memory()
peak_2 = mx.get_peak_memory()
self.assertEqual(peak_1, peak_2)
def fun():
@@ -1826,10 +1826,10 @@ class TestArray(mlx_tests.MLXTestCase):
fun()
mx.synchronize()
peak_1 = mx.metal.get_peak_memory()
peak_1 = mx.get_peak_memory()
fun()
mx.synchronize()
peak_2 = mx.metal.get_peak_memory()
peak_2 = mx.get_peak_memory()
self.assertEqual(peak_1, peak_2)
def test_add_numpy(self):

View File

@@ -747,7 +747,7 @@ class TestAutograd(mlx_tests.MLXTestCase):
]:
if mx.metal.is_available():
mx.synchronize(mx.default_stream(mx.default_device()))
mem_pre = mx.metal.get_active_memory()
mem_pre = mx.get_active_memory()
else:
mem_pre = 0
@@ -765,7 +765,7 @@ class TestAutograd(mlx_tests.MLXTestCase):
gc.collect()
if mx.metal.is_available():
mem_post = mx.metal.get_active_memory()
mem_post = mx.get_active_memory()
else:
mem_post = 0

View File

@@ -955,7 +955,7 @@ class TestCompile(mlx_tests.MLXTestCase):
def test_leaks(self):
gc.collect()
if mx.metal.is_available():
mem_pre = mx.metal.get_active_memory()
mem_pre = mx.get_active_memory()
else:
mem_pre = 0
@@ -973,7 +973,7 @@ class TestCompile(mlx_tests.MLXTestCase):
gc.collect()
if mx.metal.is_available():
mem_post = mx.metal.get_active_memory()
mem_post = mx.get_active_memory()
else:
mem_post = 0

View File

@@ -118,9 +118,9 @@ class TestEval(mlx_tests.MLXTestCase):
out = mx.vjp(fn, (x,), (y,))
out = mx.vjp(fn, (x,), (y,))
if mx.metal.is_available():
peak_mem = mx.metal.get_peak_memory()
peak_mem = mx.get_peak_memory()
out = mx.vjp(fn, (x,), (y,))
self.assertEqual(peak_mem, mx.metal.get_peak_memory())
self.assertEqual(peak_mem, mx.get_peak_memory())
def test_async_eval_with_multiple_streams(self):
x = mx.array([1.0])
@@ -151,11 +151,11 @@ class TestEval(mlx_tests.MLXTestCase):
x = mx.zeros((4096, 4096))
mx.eval(x)
pre = mx.metal.get_peak_memory()
pre = mx.get_peak_memory()
out = fun(x)
del x
mx.eval(out)
post = mx.metal.get_peak_memory()
post = mx.get_peak_memory()
self.assertEqual(pre, post)
def fun(x):
@@ -167,11 +167,11 @@ class TestEval(mlx_tests.MLXTestCase):
x = mx.zeros((4096 * 4096,))
mx.eval(x)
pre = mx.metal.get_peak_memory()
pre = mx.get_peak_memory()
out = fun(x)
del x
mx.eval(out)
post = mx.metal.get_peak_memory()
post = mx.get_peak_memory()
self.assertEqual(pre, post)
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
@@ -187,7 +187,7 @@ class TestEval(mlx_tests.MLXTestCase):
s1 = mx.default_stream(mx.gpu)
s2 = mx.new_stream(mx.gpu)
old_limit = mx.metal.set_memory_limit(1000)
old_limit = mx.set_memory_limit(1000)
x = mx.ones((512, 512), stream=s2)
for _ in range(80):
@@ -195,7 +195,7 @@ class TestEval(mlx_tests.MLXTestCase):
y = mx.abs(x, stream=s2)
z = mx.abs(y, stream=s2)
mx.eval(z)
mx.metal.set_memory_limit(old_limit)
mx.set_memory_limit(old_limit)
if __name__ == "__main__":

View File

@@ -243,7 +243,7 @@ class TestExportImport(mlx_tests.MLXTestCase):
def test_leaks(self):
path = os.path.join(self.test_dir, "fn.mlxfn")
if mx.metal.is_available():
mem_pre = mx.metal.get_active_memory()
mem_pre = mx.get_active_memory()
else:
mem_pre = 0
@@ -261,7 +261,7 @@ class TestExportImport(mlx_tests.MLXTestCase):
gc.collect()
if mx.metal.is_available():
mem_post = mx.metal.get_active_memory()
mem_post = mx.get_active_memory()
else:
mem_post = 0

View File

@@ -387,14 +387,14 @@ class TestLoad(mlx_tests.MLXTestCase):
mx.save(save_file, x)
mx.synchronize(mx.default_stream(mx.default_device()))
mx.metal.reset_peak_memory()
mx.reset_peak_memory()
scale = mx.array(2.0)
y = mx.load(save_file)
mx.eval(y)
load_only = mx.metal.get_peak_memory()
load_only = mx.get_peak_memory()
y = mx.load(save_file) * scale
mx.eval(y)
load_with_binary = mx.metal.get_peak_memory()
load_with_binary = mx.get_peak_memory()
self.assertEqual(load_only, load_with_binary)

View File

@@ -0,0 +1,60 @@
# Copyright © 2023-2024 Apple Inc.
import unittest
import mlx.core as mx
import mlx_tests
class TestMemory(mlx_tests.MLXTestCase):
    """Tests for the memory-management functions exposed on ``mlx.core``."""

    @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
    def test_memory_info(self):
        """Check the limit setters and the active/peak/cache memory queries."""
        # With the cache limit set to zero, the test expects a freed buffer
        # to leave nothing in the cache.
        old_limit = mx.set_cache_limit(0)

        a = mx.zeros((4096,))
        mx.eval(a)
        del a
        self.assertEqual(mx.get_cache_memory(), 0)
        # Each setter call is expected to return the limit it replaced.
        self.assertEqual(mx.set_cache_limit(old_limit), 0)
        self.assertEqual(mx.set_cache_limit(old_limit), old_limit)

        old_limit = mx.set_memory_limit(10)
        # Use assertEqual here: assertTrue(value, 10) would treat 10 as the
        # failure message and only check that the returned value is truthy.
        self.assertEqual(mx.set_memory_limit(old_limit), 10)
        self.assertEqual(mx.set_memory_limit(old_limit), old_limit)

        # Query active and peak memory
        a = mx.zeros((4096,))
        mx.eval(a)
        mx.synchronize()
        active_mem = mx.get_active_memory()
        # Presumably 4096 elements of 4 bytes each (default dtype) -- confirm.
        self.assertGreaterEqual(active_mem, 4096 * 4)

        b = mx.zeros((4096,))
        mx.eval(b)
        del b
        mx.synchronize()

        # Releasing ``b`` must bring active memory back to the earlier value.
        new_active_mem = mx.get_active_memory()
        self.assertEqual(new_active_mem, active_mem)
        # Both ``a`` and ``b`` were alive at the same time, hence two buffers.
        peak_mem = mx.get_peak_memory()
        self.assertGreaterEqual(peak_mem, 4096 * 8)
        cache_mem = mx.get_cache_memory()
        self.assertGreaterEqual(cache_mem, 4096 * 4)

        mx.clear_cache()
        self.assertEqual(mx.get_cache_memory(), 0)

        mx.reset_peak_memory()
        self.assertEqual(mx.get_peak_memory(), 0)

        # The second call must report the value installed by the first one.
        old_limit = mx.set_wired_limit(1000)
        old_limit = mx.set_wired_limit(0)
        self.assertEqual(old_limit, 1000)

        # A wired limit above the recommended working-set size is rejected.
        max_size = mx.metal.device_info()["max_recommended_working_set_size"]
        with self.assertRaises(ValueError):
            mx.set_wired_limit(max_size + 10)
# Run this module's tests when the file is executed directly as a script.
if __name__ == "__main__":
unittest.main()

View File

@@ -1,60 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
import unittest
import mlx.core as mx
import mlx_tests
# Removed by this commit: the same memory test, written against the
# ``mx.metal`` namespace instead of the top-level ``mx`` functions.
class TestMetal(mlx_tests.MLXTestCase):
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
def test_memory_info(self):
# With the cache limit set to zero, the test expects a freed buffer to
# leave nothing in the cache.
old_limit = mx.metal.set_cache_limit(0)
a = mx.zeros((4096,))
mx.eval(a)
del a
self.assertEqual(mx.metal.get_cache_memory(), 0)
# Each setter call is expected to return the limit it replaced.
self.assertEqual(mx.metal.set_cache_limit(old_limit), 0)
self.assertEqual(mx.metal.set_cache_limit(old_limit), old_limit)
old_limit = mx.metal.set_memory_limit(10)
# NOTE(review): assertTrue takes (expr, msg), so the second argument below is
# only a failure message; these two lines check truthiness, not equality.
self.assertTrue(mx.metal.set_memory_limit(old_limit), 10)
self.assertTrue(mx.metal.set_memory_limit(old_limit), old_limit)
# Query active and peak memory
a = mx.zeros((4096,))
mx.eval(a)
mx.synchronize()
active_mem = mx.metal.get_active_memory()
# Presumably 4096 elements of 4 bytes each (default dtype) -- confirm.
self.assertTrue(active_mem >= 4096 * 4)
b = mx.zeros((4096,))
mx.eval(b)
del b
mx.synchronize()
# Releasing ``b`` must bring active memory back to the earlier value.
new_active_mem = mx.metal.get_active_memory()
self.assertEqual(new_active_mem, active_mem)
peak_mem = mx.metal.get_peak_memory()
self.assertTrue(peak_mem >= 4096 * 8)
cache_mem = mx.metal.get_cache_memory()
self.assertTrue(cache_mem >= 4096 * 4)
mx.metal.clear_cache()
self.assertEqual(mx.metal.get_cache_memory(), 0)
mx.metal.reset_peak_memory()
self.assertEqual(mx.metal.get_peak_memory(), 0)
# The second call must report the value installed by the first one.
old_limit = mx.metal.set_wired_limit(1000)
old_limit = mx.metal.set_wired_limit(0)
self.assertEqual(old_limit, 1000)
# A wired limit above the recommended working-set size must raise ValueError.
max_size = mx.metal.device_info()["max_recommended_working_set_size"]
with self.assertRaises(ValueError):
mx.metal.set_wired_limit(max_size + 10)
# Run this module's tests when the file is executed directly as a script.
if __name__ == "__main__":
unittest.main()

View File

@@ -1904,10 +1904,10 @@ class TestOps(mlx_tests.MLXTestCase):
mx.synchronize(mx.default_stream(mx.default_device()))
mx.eval(fn(2))
mx.synchronize(mx.default_stream(mx.default_device()))
mem2 = mx.metal.get_peak_memory()
mem2 = mx.get_peak_memory()
mx.eval(fn(4))
mx.synchronize(mx.default_stream(mx.default_device()))
mem4 = mx.metal.get_peak_memory()
mem4 = mx.get_peak_memory()
self.assertEqual(mem2, mem4)
def test_squeeze_expand(self):

View File

@@ -635,7 +635,7 @@ class TestVmap(mlx_tests.MLXTestCase):
def test_leaks(self):
if mx.metal.is_available():
mem_pre = mx.metal.get_active_memory()
mem_pre = mx.get_active_memory()
else:
mem_pre = 0
@@ -653,7 +653,7 @@ class TestVmap(mlx_tests.MLXTestCase):
gc.collect()
if mx.metal.is_available():
mem_post = mx.metal.get_active_memory()
mem_post = mx.get_active_memory()
else:
mem_post = 0