diff --git a/python/mlx/nn/utils.py b/python/mlx/nn/utils.py
index 8c786454f..e42ce994a 100644
--- a/python/mlx/nn/utils.py
+++ b/python/mlx/nn/utils.py
@@ -76,7 +76,6 @@ def average_gradients(
     group: Optional[mx.distributed.Group] = None,
     all_reduce_size: int = 32 * 1024**2,
     communication_type: Optional[mx.Dtype] = None,
-    stream: mx.Stream = mx.cpu,
 ):
     """Average the gradients across the distributed processes in the passed group.
 
@@ -95,7 +94,6 @@ def average_gradients(
         communication_type (Optional[mlx.core.Dtype]): If provided cast to this
             type before performing the communication. Typically cast to a
             smaller float to reduce the communication size. Default: ``None``.
-        stream (mlx.core.Stream): The stream to use for the reduction. Default: ``mlx.cpu``.
     """
     group = group or mx.distributed.init()
     N = group.size()
@@ -106,7 +104,7 @@ def average_gradients(
     def _average(x):
         dt = x.dtype
         x = x.astype(communication_type) if communication_type is not None else x
-        return mx.distributed.all_sum(x, stream=stream).astype(dt) / N
+        return mx.distributed.all_sum(x).astype(dt) / N
 
     if all_reduce_size <= 0:
         return tree_map(_average, gradients)
diff --git a/python/tests/nccl_test_distributed.py b/python/tests/nccl_test_distributed.py
index c55fb5c1f..c9461d118 100644
--- a/python/tests/nccl_test_distributed.py
+++ b/python/tests/nccl_test_distributed.py
@@ -65,21 +65,21 @@ class TestNCCLDistributed(mlx_tests.MLXTestCase):
         mx.distributed.all_sum = new_all_sum
         try:
             grads = [mx.ones(10) for i in range(10)]
-            new_grads = average_gradients(grads, stream=mx.gpu)
+            new_grads = average_gradients(grads)
             mx.eval(new_grads)
             self.assertEqual(len(new_grads), 10)
             self.assertTrue(all(mx.all(g == 1) for g in new_grads))
             self.assertEqual(n_calls, 1)
 
             n_calls = 0
-            new_grads = average_gradients(grads, all_reduce_size=4 * 50, stream=mx.gpu)
+            new_grads = average_gradients(grads, all_reduce_size=4 * 50)
             mx.eval(new_grads)
             self.assertEqual(len(new_grads), 10)
             self.assertTrue(all(mx.all(g == 1) for g in new_grads))
             self.assertEqual(n_calls, 2)
 
             n_calls = 0
-            new_grads = average_gradients(grads, all_reduce_size=0, stream=mx.gpu)
+            new_grads = average_gradients(grads, all_reduce_size=0)
             mx.eval(new_grads)
             self.assertEqual(len(new_grads), 10)
             self.assertTrue(all(mx.all(g == 1) for g in new_grads))
@@ -91,7 +91,6 @@ class TestNCCLDistributed(mlx_tests.MLXTestCase):
                 grads,
                 all_reduce_size=2 * 50,
                 communication_type=mx.float16,
-                stream=mx.gpu,
             )
             mx.eval(new_grads)
             self.assertEqual(len(new_grads), 10)
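
For context, a minimal usage sketch of average_gradients after this change (the gradient values are stand-ins; the point is only that callers no longer pass a stream argument, since the reduction stream is no longer user-selectable):

    import mlx.core as mx
    from mlx.nn.utils import average_gradients

    group = mx.distributed.init()

    # Stand-in gradient tree; in practice this would come from nn.value_and_grad.
    grads = [mx.ones(10) for _ in range(10)]

    # Before this change: average_gradients(grads, stream=mx.gpu)
    # After: no stream keyword is accepted.
    avg = average_gradients(grads, group=group, communication_type=mx.float16)
    mx.eval(avg)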