mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Remove stream from average grads so it uses default (#2532)
* Remove stream from average grads so it uses default * comment
This commit is contained in:
		| @@ -65,21 +65,21 @@ class TestNCCLDistributed(mlx_tests.MLXTestCase): | ||||
|         mx.distributed.all_sum = new_all_sum | ||||
|         try: | ||||
|             grads = [mx.ones(10) for i in range(10)] | ||||
|             new_grads = average_gradients(grads, stream=mx.gpu) | ||||
|             new_grads = average_gradients(grads) | ||||
|             mx.eval(new_grads) | ||||
|             self.assertEqual(len(new_grads), 10) | ||||
|             self.assertTrue(all(mx.all(g == 1) for g in new_grads)) | ||||
|             self.assertEqual(n_calls, 1) | ||||
|  | ||||
|             n_calls = 0 | ||||
|             new_grads = average_gradients(grads, all_reduce_size=4 * 50, stream=mx.gpu) | ||||
|             new_grads = average_gradients(grads, all_reduce_size=4 * 50) | ||||
|             mx.eval(new_grads) | ||||
|             self.assertEqual(len(new_grads), 10) | ||||
|             self.assertTrue(all(mx.all(g == 1) for g in new_grads)) | ||||
|             self.assertEqual(n_calls, 2) | ||||
|  | ||||
|             n_calls = 0 | ||||
|             new_grads = average_gradients(grads, all_reduce_size=0, stream=mx.gpu) | ||||
|             new_grads = average_gradients(grads, all_reduce_size=0) | ||||
|             mx.eval(new_grads) | ||||
|             self.assertEqual(len(new_grads), 10) | ||||
|             self.assertTrue(all(mx.all(g == 1) for g in new_grads)) | ||||
| @@ -91,7 +91,6 @@ class TestNCCLDistributed(mlx_tests.MLXTestCase): | ||||
|                 grads, | ||||
|                 all_reduce_size=2 * 50, | ||||
|                 communication_type=mx.float16, | ||||
|                 stream=mx.gpu, | ||||
|             ) | ||||
|             mx.eval(new_grads) | ||||
|             self.assertEqual(len(new_grads), 10) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun