diff --git a/mlx/backend/cpu/eval.cpp b/mlx/backend/cpu/eval.cpp
index 3b7cd7977..dd9cab983 100644
--- a/mlx/backend/cpu/eval.cpp
+++ b/mlx/backend/cpu/eval.cpp
@@ -28,8 +28,7 @@ void finalize(
   auto& encoder = cpu::get_command_encoder(s);
   encoder.dispatch([s,
                     buffers = std::move(retain_buffers),
-                    temps = std::move(encoder.temporaries())]() {
-  });
+                    temps = std::move(encoder.temporaries())]() {});
 }

 } // namespace mlx::core::cpu
diff --git a/python/tests/mlx_distributed_tests.py b/python/tests/mlx_distributed_tests.py
index 5feb51bc9..5926489da 100644
--- a/python/tests/mlx_distributed_tests.py
+++ b/python/tests/mlx_distributed_tests.py
@@ -65,21 +65,18 @@ class MLXDistributedCommonTestCase(mlx_tests.MLXTestCase):

     def test_donation(self):
         x = mx.random.normal((1024,))
+        scale = mx.array(2.0)
         mx.eval(x)
         mx.synchronize()

         mx.reset_peak_memory()
-        scale = mx.array(2.0)
-        y = mx.distributed.all_sum(x)
-        mx.eval(y)
-        mx.synchronize()
-        all_sum_only = mx.get_peak_memory()

-        y = mx.distributed.all_sum(x) * scale
-        mx.eval(y)
-        mx.synchronize()
-        all_sum_with_binary = mx.get_peak_memory()
-        self.assertEqual(all_sum_only, all_sum_with_binary)
+        # Everything should be donated so peak memory is unchanged
+        x = mx.distributed.all_sum(x) * scale
+        mx.eval(x)
+        mx.synchronize()
+
+        self.assertEqual(mx.get_peak_memory(), 0)

     def test_shard_linear(self):
         # Seed the prng to have the same inputs and weights generated everywhere
diff --git a/python/tests/test_load.py b/python/tests/test_load.py
index 341564dae..6105ed38f 100644
--- a/python/tests/test_load.py
+++ b/python/tests/test_load.py
@@ -391,9 +391,12 @@ class TestLoad(mlx_tests.MLXTestCase):
         scale = mx.array(2.0)
         y = mx.load(save_file)
         mx.eval(y)
+
+        mx.synchronize()
         load_only = mx.get_peak_memory()

         y = mx.load(save_file) * scale
         mx.eval(y)
+        mx.synchronize()
         load_with_binary = mx.get_peak_memory()
         self.assertEqual(load_only, load_with_binary)
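For reference, the measurement pattern the updated tests rely on looks like this in isolation. This is a minimal sketch, assuming `mlx.core` with lazy evaluation and buffer donation; it uses a plain elementwise multiply in place of `mx.distributed.all_sum` so it runs without a distributed backend:

```python
import mlx.core as mx

x = mx.random.normal((1024,))
scale = mx.array(2.0)
mx.eval(x)
mx.synchronize()  # drain any pending work before touching the peak counter

mx.reset_peak_memory()
# Rebinding `x` drops the last reference to its old buffer, so the
# multiply can donate that buffer to the output instead of allocating.
x = x * scale
mx.eval(x)
mx.synchronize()  # work is dispatched asynchronously; wait before reading

print(mx.get_peak_memory())  # expected to be 0 when every buffer is donated
```

Without the `mx.synchronize()` calls, reading `mx.get_peak_memory()` can race the asynchronously dispatched work, which is why the `test_load.py` hunk adds a synchronize before each read.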