mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Fixes output donation for IO ops on the GPU (#1857)
This commit is contained in:
		 Angelos Katharopoulos
					Angelos Katharopoulos
				
			
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			 GitHub
						GitHub
					
				
			
						parent
						
							0a5215693e
						
					
				
				
					commit
					0145911bea
				
			| @@ -174,6 +174,21 @@ class TestDistributed(mlx_tests.MLXTestCase): | ||||
|         finally: | ||||
|             mx.distributed.all_sum = original_all_sum | ||||
|  | ||||
|     def test_donation(self): | ||||
|         x = mx.random.normal((1024,)) | ||||
|         mx.eval(x) | ||||
|  | ||||
|         mx.metal.reset_peak_memory() | ||||
|         scale = mx.array(2.0) | ||||
|         y = mx.distributed.all_sum(x) | ||||
|         mx.eval(y) | ||||
|         all_sum_only = mx.metal.get_peak_memory() | ||||
|         y = mx.distributed.all_sum(x) * scale | ||||
|         mx.eval(y) | ||||
|         all_sum_with_binary = mx.metal.get_peak_memory() | ||||
|  | ||||
|         self.assertEqual(all_sum_only, all_sum_with_binary) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
| @@ -356,6 +356,23 @@ class TestLoad(mlx_tests.MLXTestCase): | ||||
|         aload = mx.load(save_file)["a"] | ||||
|         self.assertTrue(mx.array_equal(a, aload)) | ||||
|  | ||||
|     def test_load_donation(self): | ||||
|         x = mx.random.normal((1024,)) | ||||
|         mx.eval(x) | ||||
|         save_file = os.path.join(self.test_dir, "donation.npy") | ||||
|         mx.save(save_file, x) | ||||
|  | ||||
|         mx.metal.reset_peak_memory() | ||||
|         scale = mx.array(2.0) | ||||
|         y = mx.load(save_file) | ||||
|         mx.eval(y) | ||||
|         load_only = mx.metal.get_peak_memory() | ||||
|         y = mx.load(save_file) * scale | ||||
|         mx.eval(y) | ||||
|         load_with_binary = mx.metal.get_peak_memory() | ||||
|  | ||||
|         self.assertEqual(load_only, load_with_binary) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user