move memory APIs into top level mlx.core (#1982)

This commit is contained in:
Awni Hannun
2025-03-21 07:25:12 -07:00
committed by GitHub
parent 65a38c452b
commit 4e1994e9d7
25 changed files with 418 additions and 323 deletions

View File

@@ -179,16 +179,16 @@ class TestDistributed(mlx_tests.MLXTestCase):
mx.eval(x)
mx.synchronize(mx.default_stream(mx.default_device()))
mx.metal.reset_peak_memory()
mx.reset_peak_memory()
scale = mx.array(2.0)
y = mx.distributed.all_sum(x)
mx.eval(y)
mx.synchronize(mx.default_stream(mx.default_device()))
all_sum_only = mx.metal.get_peak_memory()
all_sum_only = mx.get_peak_memory()
y = mx.distributed.all_sum(x) * scale
mx.eval(y)
mx.synchronize(mx.default_stream(mx.default_device()))
all_sum_with_binary = mx.metal.get_peak_memory()
all_sum_with_binary = mx.get_peak_memory()
self.assertEqual(all_sum_only, all_sum_with_binary)