Add stats and limit to common allocator and enable tests (#1988)

* add stats to common allocator and enable tests

* linux memory and default

* fix
This commit is contained in:
Awni Hannun
2025-03-21 12:28:36 -07:00
committed by GitHub
parent d343782c8b
commit 2a980a76ce
13 changed files with 151 additions and 68 deletions

View File

@@ -177,17 +177,17 @@ class TestDistributed(mlx_tests.MLXTestCase):
def test_donation(self):
x = mx.random.normal((1024,))
mx.eval(x)
mx.synchronize(mx.default_stream(mx.default_device()))
mx.synchronize()
mx.reset_peak_memory()
scale = mx.array(2.0)
y = mx.distributed.all_sum(x)
mx.eval(y)
mx.synchronize(mx.default_stream(mx.default_device()))
mx.synchronize()
all_sum_only = mx.get_peak_memory()
y = mx.distributed.all_sum(x) * scale
mx.eval(y)
mx.synchronize(mx.default_stream(mx.default_device()))
mx.synchronize()
all_sum_with_binary = mx.get_peak_memory()
self.assertEqual(all_sum_only, all_sum_with_binary)