mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 09:51:17 +08:00
synch before reading memory in test
This commit is contained in:
parent
f140792f1c
commit
67ec27d515
@ -28,8 +28,7 @@ void finalize(
|
|||||||
auto& encoder = cpu::get_command_encoder(s);
|
auto& encoder = cpu::get_command_encoder(s);
|
||||||
encoder.dispatch([s,
|
encoder.dispatch([s,
|
||||||
buffers = std::move(retain_buffers),
|
buffers = std::move(retain_buffers),
|
||||||
temps = std::move(encoder.temporaries())]() {
|
temps = std::move(encoder.temporaries())]() {});
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core::cpu
|
} // namespace mlx::core::cpu
|
||||||
|
@ -65,21 +65,18 @@ class MLXDistributedCommonTestCase(mlx_tests.MLXTestCase):
|
|||||||
|
|
||||||
def test_donation(self):
|
def test_donation(self):
|
||||||
x = mx.random.normal((1024,))
|
x = mx.random.normal((1024,))
|
||||||
|
scale = mx.array(2.0)
|
||||||
mx.eval(x)
|
mx.eval(x)
|
||||||
mx.synchronize()
|
mx.synchronize()
|
||||||
|
|
||||||
mx.reset_peak_memory()
|
mx.reset_peak_memory()
|
||||||
scale = mx.array(2.0)
|
|
||||||
y = mx.distributed.all_sum(x)
|
|
||||||
mx.eval(y)
|
|
||||||
mx.synchronize()
|
|
||||||
all_sum_only = mx.get_peak_memory()
|
|
||||||
y = mx.distributed.all_sum(x) * scale
|
|
||||||
mx.eval(y)
|
|
||||||
mx.synchronize()
|
|
||||||
all_sum_with_binary = mx.get_peak_memory()
|
|
||||||
|
|
||||||
self.assertEqual(all_sum_only, all_sum_with_binary)
|
# Everything should be donated so peak memory is unchanged
|
||||||
|
x = mx.distributed.all_sum(x) * scale
|
||||||
|
mx.eval(x)
|
||||||
|
mx.synchronize()
|
||||||
|
|
||||||
|
self.assertEqual(mx.get_peak_memory(), 0)
|
||||||
|
|
||||||
def test_shard_linear(self):
|
def test_shard_linear(self):
|
||||||
# Seed the prng to have the same inputs and weights generated everywhere
|
# Seed the prng to have the same inputs and weights generated everywhere
|
||||||
|
@ -391,9 +391,12 @@ class TestLoad(mlx_tests.MLXTestCase):
|
|||||||
scale = mx.array(2.0)
|
scale = mx.array(2.0)
|
||||||
y = mx.load(save_file)
|
y = mx.load(save_file)
|
||||||
mx.eval(y)
|
mx.eval(y)
|
||||||
|
|
||||||
|
mx.synchronize()
|
||||||
load_only = mx.get_peak_memory()
|
load_only = mx.get_peak_memory()
|
||||||
y = mx.load(save_file) * scale
|
y = mx.load(save_file) * scale
|
||||||
mx.eval(y)
|
mx.eval(y)
|
||||||
|
mx.synchronize()
|
||||||
load_with_binary = mx.get_peak_memory()
|
load_with_binary = mx.get_peak_memory()
|
||||||
|
|
||||||
self.assertEqual(load_only, load_with_binary)
|
self.assertEqual(load_only, load_with_binary)
|
||||||
|
Loading…
Reference in New Issue
Block a user