synch before reading memory in test

Awni Hannun, 2025-03-04 20:25:12 -08:00, committed by Awni Hannun
parent f140792f1c
commit 67ec27d515
3 changed files with 11 additions and 12 deletions
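
Both test files apply the same three-step pattern around the peak-memory checks: evaluate, synchronize, then read or reset the counter, presumably so that work still in flight on the stream cannot skew the measurement. A minimal standalone sketch of that pattern, assuming an installed MLX build (the array shape and the scale multiply are only illustrative):

    import mlx.core as mx

    # Settle the stream before taking a memory baseline.
    x = mx.random.normal((1024,))
    scale = mx.array(2.0)
    mx.eval(x, scale)
    mx.synchronize()          # wait for pending work before touching counters
    mx.reset_peak_memory()    # baseline now excludes the setup allocations

    # Run the op under test, then synchronize again before reading the counter.
    y = x * scale
    mx.eval(y)
    mx.synchronize()
    print(mx.get_peak_memory())

The diffs below wrap mx.distributed.all_sum and mx.load in this same sequence.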


@@ -28,8 +28,7 @@ void finalize(
   auto& encoder = cpu::get_command_encoder(s);
   encoder.dispatch([s,
                     buffers = std::move(retain_buffers),
-                    temps = std::move(encoder.temporaries())]() {
-  });
+                    temps = std::move(encoder.temporaries())]() {});
 }
 
 } // namespace mlx::core::cpu


@@ -65,21 +65,18 @@ class MLXDistributedCommonTestCase(mlx_tests.MLXTestCase):
     def test_donation(self):
         x = mx.random.normal((1024,))
+        scale = mx.array(2.0)
         mx.eval(x)
+        mx.synchronize()
         mx.reset_peak_memory()
-        scale = mx.array(2.0)
         y = mx.distributed.all_sum(x)
         mx.eval(y)
+        mx.synchronize()
         all_sum_only = mx.get_peak_memory()
         y = mx.distributed.all_sum(x) * scale
         mx.eval(y)
+        mx.synchronize()
         all_sum_with_binary = mx.get_peak_memory()
         self.assertEqual(all_sum_only, all_sum_with_binary)
-
-        # Everything should be donated so peak memory is unchanged
-        x = mx.distributed.all_sum(x) * scale
-        mx.eval(x)
-        mx.synchronize()
-        self.assertEqual(mx.get_peak_memory(), 0)
 
     def test_shard_linear(self):
         # Seed the prng to have the same inputs and weights generated everywhere


@@ -391,9 +391,12 @@ class TestLoad(mlx_tests.MLXTestCase):
         scale = mx.array(2.0)
         y = mx.load(save_file)
         mx.eval(y)
+        mx.synchronize()
         load_only = mx.get_peak_memory()
         y = mx.load(save_file) * scale
         mx.eval(y)
+        mx.synchronize()
         load_with_binary = mx.get_peak_memory()
         self.assertEqual(load_only, load_with_binary)