mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	synch before reading memory in test
This commit is contained in:
		| @@ -28,8 +28,7 @@ void finalize( | ||||
|   auto& encoder = cpu::get_command_encoder(s); | ||||
|   encoder.dispatch([s, | ||||
|                     buffers = std::move(retain_buffers), | ||||
|                     temps = std::move(encoder.temporaries())]() { | ||||
|   }); | ||||
|                     temps = std::move(encoder.temporaries())]() {}); | ||||
| } | ||||
|  | ||||
| } // namespace mlx::core::cpu | ||||
|   | ||||
| @@ -65,21 +65,18 @@ class MLXDistributedCommonTestCase(mlx_tests.MLXTestCase): | ||||
|  | ||||
|     def test_donation(self): | ||||
|         x = mx.random.normal((1024,)) | ||||
|         scale = mx.array(2.0) | ||||
|         mx.eval(x) | ||||
|         mx.synchronize() | ||||
|  | ||||
|         mx.reset_peak_memory() | ||||
|         scale = mx.array(2.0) | ||||
|         y = mx.distributed.all_sum(x) | ||||
|         mx.eval(y) | ||||
|         mx.synchronize() | ||||
|         all_sum_only = mx.get_peak_memory() | ||||
|         y = mx.distributed.all_sum(x) * scale | ||||
|         mx.eval(y) | ||||
|         mx.synchronize() | ||||
|         all_sum_with_binary = mx.get_peak_memory() | ||||
|  | ||||
|         self.assertEqual(all_sum_only, all_sum_with_binary) | ||||
|         # Everything should be donated so peak memory is unchanged | ||||
|         x = mx.distributed.all_sum(x) * scale | ||||
|         mx.eval(x) | ||||
|         mx.synchronize() | ||||
|  | ||||
|         self.assertEqual(mx.get_peak_memory(), 0) | ||||
|  | ||||
|     def test_shard_linear(self): | ||||
|         # Seed the prng to have the same inputs and weights generated everywhere | ||||
|   | ||||
| @@ -391,9 +391,12 @@ class TestLoad(mlx_tests.MLXTestCase): | ||||
|         scale = mx.array(2.0) | ||||
|         y = mx.load(save_file) | ||||
|         mx.eval(y) | ||||
|  | ||||
|         mx.synchronize() | ||||
|         load_only = mx.get_peak_memory() | ||||
|         y = mx.load(save_file) * scale | ||||
|         mx.eval(y) | ||||
|         mx.synchronize() | ||||
|         load_with_binary = mx.get_peak_memory() | ||||
|  | ||||
|         self.assertEqual(load_only, load_with_binary) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun