Mirror of https://github.com/ml-explore/mlx.git (synced 2025-09-06 16:51:24 +08:00)
Add stats and limit to common allocator and enable tests (#1988)
* add stats to common allocator and enable tests
* linux memory and default
* fix
@@ -36,7 +36,7 @@ void init_metal(nb::module_& m) {
   });
   metal.def("get_peak_memory", []() {
     DEPRECATE("mx.metal.get_peak_memory", "mx.get_peak_memory");
-    return mx::get_active_memory();
+    return mx::get_peak_memory();
   });
   metal.def("reset_peak_memory", []() {
     DEPRECATE("mx.metal.reset_peak_memory", "mx.reset_peak_memory");
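For context, the deprecated mx.metal.* wrappers now forward to backend-agnostic top-level functions. A minimal usage sketch of that memory API (not part of the diff, assuming a current MLX build):

import mlx.core as mx

a = mx.zeros((4096,))
mx.eval(a)

print(mx.get_active_memory())  # bytes held by live arrays
print(mx.get_peak_memory())    # high-water mark since the last reset
mx.reset_peak_memory()         # restart the peak counter
print(mx.get_cache_memory())   # bytes kept in the allocator cache for reuse
mx.clear_cache()               # release cached buffers back to the system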
@@ -177,17 +177,17 @@ class TestDistributed(mlx_tests.MLXTestCase):
     def test_donation(self):
         x = mx.random.normal((1024,))
         mx.eval(x)
-        mx.synchronize(mx.default_stream(mx.default_device()))
+        mx.synchronize()
 
         mx.reset_peak_memory()
         scale = mx.array(2.0)
         y = mx.distributed.all_sum(x)
         mx.eval(y)
-        mx.synchronize(mx.default_stream(mx.default_device()))
+        mx.synchronize()
         all_sum_only = mx.get_peak_memory()
         y = mx.distributed.all_sum(x) * scale
         mx.eval(y)
-        mx.synchronize(mx.default_stream(mx.default_device()))
+        mx.synchronize()
         all_sum_with_binary = mx.get_peak_memory()
 
         self.assertEqual(all_sum_only, all_sum_with_binary)
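The test above checks buffer donation by comparing peak memory with and without a trailing binary op. The same measurement pattern, sketched here with an ordinary element-wise op instead of mx.distributed.all_sum (a hypothetical standalone example, not from the commit):

import mlx.core as mx

x = mx.random.normal((1024,))
mx.eval(x)
mx.synchronize()

mx.reset_peak_memory()
y = mx.abs(x)
mx.eval(y)
mx.synchronize()
single_op = mx.get_peak_memory()

y = mx.abs(x) * mx.array(2.0)
mx.eval(y)
mx.synchronize()
with_binary = mx.get_peak_memory()

# If the temporary result is donated to the multiply, both peaks should match.
print(single_op == with_binary)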
@@ -1803,7 +1803,6 @@ class TestArray(mlx_tests.MLXTestCase):
         b = pickle.loads(pickle.dumps(a))
         self.assertTrue(mx.array_equal(mx.array(a), mx.array(b)))
 
-    @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
     def test_multi_output_leak(self):
         def fun():
             a = mx.zeros((2**20))
@@ -745,11 +745,8 @@ class TestAutograd(mlx_tests.MLXTestCase):
             mx.custom_function,
             mx.checkpoint,
         ]:
-            if mx.metal.is_available():
-                mx.synchronize(mx.default_stream(mx.default_device()))
-                mem_pre = mx.get_active_memory()
-            else:
-                mem_pre = 0
+            mx.synchronize()
+            mem_pre = mx.get_active_memory()
 
             def outer():
                 d = {}
@@ -763,12 +760,7 @@ class TestAutograd(mlx_tests.MLXTestCase):
             for _ in range(5):
                 outer()
                 gc.collect()
-
-            if mx.metal.is_available():
-                mem_post = mx.get_active_memory()
-            else:
-                mem_post = 0
-
+            mem_post = mx.get_active_memory()
             self.assertEqual(mem_pre, mem_post)
 
     def test_grad_with_copies(self):
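Both hunks above drop the Metal-only guard because allocator statistics now come from the common allocator on every backend. A rough sketch of the unguarded leak check (make_garbage is a hypothetical stand-in for the test's outer helper):

import gc
import mlx.core as mx

mx.synchronize()
mem_pre = mx.get_active_memory()

def make_garbage():
    d = {}
    d["x"] = mx.ones((1024,))
    d["self"] = d  # reference cycle; cleaned up by the garbage collector

for _ in range(5):
    make_garbage()
    gc.collect()

mem_post = mx.get_active_memory()
# With no leaks, no arrays from make_garbage should still be alive.
print(mem_pre == mem_post)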
@@ -117,10 +117,9 @@ class TestEval(mlx_tests.MLXTestCase):
 
         out = mx.vjp(fn, (x,), (y,))
         out = mx.vjp(fn, (x,), (y,))
-        if mx.metal.is_available():
-            peak_mem = mx.get_peak_memory()
-            out = mx.vjp(fn, (x,), (y,))
-            self.assertEqual(peak_mem, mx.get_peak_memory())
+        peak_mem = mx.get_peak_memory()
+        out = mx.vjp(fn, (x,), (y,))
+        self.assertEqual(peak_mem, mx.get_peak_memory())
 
     def test_async_eval_with_multiple_streams(self):
         x = mx.array([1.0])
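Here fn, x, and y are defined earlier in the test and elided from this hunk. A self-contained sketch of the mx.vjp call pattern it relies on (hypothetical fn and values, not from the commit):

import mlx.core as mx

def fn(x):
    return x * x

x = mx.array(3.0)
y = mx.array(1.0)  # cotangent for the scalar output
out, vjps = mx.vjp(fn, (x,), (y,))
print(out, vjps)   # expect the output [9.0] and the vector-Jacobian product [6.0]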
@@ -137,7 +136,6 @@ class TestEval(mlx_tests.MLXTestCase):
         mx.async_eval(x)
         mx.eval(a + b)
 
-    @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
     def test_donation_for_noops(self):
         def fun(x):
             s = x.shape
@@ -385,7 +385,7 @@ class TestLoad(mlx_tests.MLXTestCase):
         mx.eval(x)
         save_file = os.path.join(self.test_dir, "donation.npy")
         mx.save(save_file, x)
-        mx.synchronize(mx.default_stream(mx.default_device()))
+        mx.synchronize()
 
         mx.reset_peak_memory()
         scale = mx.array(2.0)
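The donation test above writes into self.test_dir. A minimal sketch of the save/load round trip it exercises, using a temporary directory instead (not from the commit):

import os
import tempfile
import mlx.core as mx

x = mx.random.normal((1024,))
mx.eval(x)
path = os.path.join(tempfile.gettempdir(), "donation.npy")
mx.save(path, x)
y = mx.load(path)
print(mx.array_equal(x, y))  # True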
@@ -7,7 +7,6 @@ import mlx_tests
 
 
 class TestMemory(mlx_tests.MLXTestCase):
-    @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
     def test_memory_info(self):
         old_limit = mx.set_cache_limit(0)
 
@@ -38,8 +37,10 @@ class TestMemory(mlx_tests.MLXTestCase):
         self.assertEqual(new_active_mem, active_mem)
         peak_mem = mx.get_peak_memory()
         self.assertTrue(peak_mem >= 4096 * 8)
-        cache_mem = mx.get_cache_memory()
-        self.assertTrue(cache_mem >= 4096 * 4)
+
+        if mx.metal.is_available():
+            cache_mem = mx.get_cache_memory()
+            self.assertTrue(cache_mem >= 4096 * 4)
 
         mx.clear_cache()
         self.assertEqual(mx.get_cache_memory(), 0)
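For reference, mx.set_cache_limit returns the previous limit, which is how the test restores state afterwards. A small sketch of the cache behavior being asserted (assuming a fresh process; not part of the commit):

import mlx.core as mx

old_limit = mx.set_cache_limit(0)  # disable caching; returns the previous limit
a = mx.zeros((4096,))
mx.eval(a)
del a                              # the buffer is freed rather than cached
print(mx.get_cache_memory())       # expected to stay at 0 with a zero limit

mx.set_cache_limit(old_limit)      # restore the original limit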
@@ -47,6 +48,8 @@ class TestMemory(mlx_tests.MLXTestCase):
         mx.reset_peak_memory()
         self.assertEqual(mx.get_peak_memory(), 0)
 
+    @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
     def test_wired_memory(self):
         old_limit = mx.set_wired_limit(1000)
+        old_limit = mx.set_wired_limit(0)
         self.assertEqual(old_limit, 1000)
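The wired-memory limit remains Metal-specific, hence the new skipIf. A short sketch of the set/restore pattern the test relies on (not from the commit):

import mlx.core as mx

if mx.metal.is_available():
    old_limit = mx.set_wired_limit(1000)      # returns the previous limit
    previous = mx.set_wired_limit(old_limit)  # restore; returns the 1000 we just set
    print(previous)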
@@ -1901,12 +1901,12 @@ class TestOps(mlx_tests.MLXTestCase):
             x = mx.cumsum(x)
             return x
 
-        mx.synchronize(mx.default_stream(mx.default_device()))
+        mx.synchronize()
         mx.eval(fn(2))
-        mx.synchronize(mx.default_stream(mx.default_device()))
+        mx.synchronize()
         mem2 = mx.get_peak_memory()
         mx.eval(fn(4))
-        mx.synchronize(mx.default_stream(mx.default_device()))
+        mx.synchronize()
         mem4 = mx.get_peak_memory()
         self.assertEqual(mem2, mem4)
 