mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	synch before reading memory in test
This commit is contained in:
		| @@ -28,8 +28,7 @@ void finalize( | |||||||
|   auto& encoder = cpu::get_command_encoder(s); |   auto& encoder = cpu::get_command_encoder(s); | ||||||
|   encoder.dispatch([s, |   encoder.dispatch([s, | ||||||
|                     buffers = std::move(retain_buffers), |                     buffers = std::move(retain_buffers), | ||||||
|                     temps = std::move(encoder.temporaries())]() { |                     temps = std::move(encoder.temporaries())]() {}); | ||||||
|   }); |  | ||||||
| } | } | ||||||
|  |  | ||||||
| } // namespace mlx::core::cpu | } // namespace mlx::core::cpu | ||||||
|   | |||||||
| @@ -65,21 +65,18 @@ class MLXDistributedCommonTestCase(mlx_tests.MLXTestCase): | |||||||
|  |  | ||||||
|     def test_donation(self): |     def test_donation(self): | ||||||
|         x = mx.random.normal((1024,)) |         x = mx.random.normal((1024,)) | ||||||
|  |         scale = mx.array(2.0) | ||||||
|         mx.eval(x) |         mx.eval(x) | ||||||
|         mx.synchronize() |         mx.synchronize() | ||||||
|  |  | ||||||
|         mx.reset_peak_memory() |         mx.reset_peak_memory() | ||||||
|         scale = mx.array(2.0) |  | ||||||
|         y = mx.distributed.all_sum(x) |  | ||||||
|         mx.eval(y) |  | ||||||
|         mx.synchronize() |  | ||||||
|         all_sum_only = mx.get_peak_memory() |  | ||||||
|         y = mx.distributed.all_sum(x) * scale |  | ||||||
|         mx.eval(y) |  | ||||||
|         mx.synchronize() |  | ||||||
|         all_sum_with_binary = mx.get_peak_memory() |  | ||||||
|  |  | ||||||
|         self.assertEqual(all_sum_only, all_sum_with_binary) |         # Everything should be donated so peak memory is unchanged | ||||||
|  |         x = mx.distributed.all_sum(x) * scale | ||||||
|  |         mx.eval(x) | ||||||
|  |         mx.synchronize() | ||||||
|  |  | ||||||
|  |         self.assertEqual(mx.get_peak_memory(), 0) | ||||||
|  |  | ||||||
|     def test_shard_linear(self): |     def test_shard_linear(self): | ||||||
|         # Seed the prng to have the same inputs and weights generated everywhere |         # Seed the prng to have the same inputs and weights generated everywhere | ||||||
|   | |||||||
| @@ -391,9 +391,12 @@ class TestLoad(mlx_tests.MLXTestCase): | |||||||
|         scale = mx.array(2.0) |         scale = mx.array(2.0) | ||||||
|         y = mx.load(save_file) |         y = mx.load(save_file) | ||||||
|         mx.eval(y) |         mx.eval(y) | ||||||
|  |  | ||||||
|  |         mx.synchronize() | ||||||
|         load_only = mx.get_peak_memory() |         load_only = mx.get_peak_memory() | ||||||
|         y = mx.load(save_file) * scale |         y = mx.load(save_file) * scale | ||||||
|         mx.eval(y) |         mx.eval(y) | ||||||
|  |         mx.synchronize() | ||||||
|         load_with_binary = mx.get_peak_memory() |         load_with_binary = mx.get_peak_memory() | ||||||
|  |  | ||||||
|         self.assertEqual(load_only, load_with_binary) |         self.assertEqual(load_only, load_with_binary) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun