From 3dcb286baf09d8099d7c9f6ce2f921a75becabca Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Mon, 25 Aug 2025 15:56:29 -0700
Subject: [PATCH] Remove stream from average grads so it uses default (#2532)

* Remove stream from average grads so it uses default

* comment
---
 python/mlx/nn/utils.py                | 8 +++++---
 python/tests/nccl_test_distributed.py | 7 +++----
 2 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/python/mlx/nn/utils.py b/python/mlx/nn/utils.py
index 8c786454f..97354d112 100644
--- a/python/mlx/nn/utils.py
+++ b/python/mlx/nn/utils.py
@@ -76,7 +76,7 @@ def average_gradients(
     group: Optional[mx.distributed.Group] = None,
     all_reduce_size: int = 32 * 1024**2,
     communication_type: Optional[mx.Dtype] = None,
-    stream: mx.Stream = mx.cpu,
+    communication_stream: Optional[mx.Stream] = None,
 ):
     """Average the gradients across the distributed processes in the passed group.
 
@@ -95,7 +95,9 @@
         communication_type (Optional[mlx.core.Dtype]): If provided cast to this
             type before performing the communication. Typically cast to a
            smaller float to reduce the communication size. Default: ``None``.
-        stream (mlx.core.Stream): The stream to use for the reduction. Default: ``mlx.cpu``.
+        communication_stream (Optional[mlx.core.Stream]): The stream to use
+            for the communication. If unspecified, the default communication
+            stream is used, which can vary by back-end. Default: ``None``.
     """
     group = group or mx.distributed.init()
     N = group.size()
@@ -106,7 +108,7 @@
     def _average(x):
         dt = x.dtype
         x = x.astype(communication_type) if communication_type is not None else x
-        return mx.distributed.all_sum(x, stream=stream).astype(dt) / N
+        return mx.distributed.all_sum(x, stream=communication_stream).astype(dt) / N
 
     if all_reduce_size <= 0:
         return tree_map(_average, gradients)
diff --git a/python/tests/nccl_test_distributed.py b/python/tests/nccl_test_distributed.py
index c55fb5c1f..c9461d118 100644
--- a/python/tests/nccl_test_distributed.py
+++ b/python/tests/nccl_test_distributed.py
@@ -65,21 +65,21 @@ class TestNCCLDistributed(mlx_tests.MLXTestCase):
         mx.distributed.all_sum = new_all_sum
         try:
             grads = [mx.ones(10) for i in range(10)]
-            new_grads = average_gradients(grads, stream=mx.gpu)
+            new_grads = average_gradients(grads)
             mx.eval(new_grads)
             self.assertEqual(len(new_grads), 10)
             self.assertTrue(all(mx.all(g == 1) for g in new_grads))
             self.assertEqual(n_calls, 1)
 
             n_calls = 0
-            new_grads = average_gradients(grads, all_reduce_size=4 * 50, stream=mx.gpu)
+            new_grads = average_gradients(grads, all_reduce_size=4 * 50)
             mx.eval(new_grads)
             self.assertEqual(len(new_grads), 10)
             self.assertTrue(all(mx.all(g == 1) for g in new_grads))
             self.assertEqual(n_calls, 2)
 
             n_calls = 0
-            new_grads = average_gradients(grads, all_reduce_size=0, stream=mx.gpu)
+            new_grads = average_gradients(grads, all_reduce_size=0)
             mx.eval(new_grads)
             self.assertEqual(len(new_grads), 10)
             self.assertTrue(all(mx.all(g == 1) for g in new_grads))
@@ -91,7 +91,6 @@ class TestNCCLDistributed(mlx_tests.MLXTestCase):
                 grads,
                 all_reduce_size=2 * 50,
                 communication_type=mx.float16,
-                stream=mx.gpu,
             )
             mx.eval(new_grads)
             self.assertEqual(len(new_grads), 10)
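
Usage sketch (illustrative, not part of the patch): with this change, `average_gradients` no longer pins the all-reduce to `mx.cpu`; leaving `communication_stream` unset lets the backend pick its default communication stream. The snippet below assumes a distributed launch (e.g. via `mlx.launch`) and uses a toy linear model and random data as placeholders for a real training step.

# Illustrative usage sketch, not part of the patch. Assumes the script is
# started under a distributed backend (e.g. with `mlx.launch`); the tiny
# model and random batch are placeholders.
import mlx.core as mx
import mlx.nn as nn
from mlx.nn.utils import average_gradients

group = mx.distributed.init()

model = nn.Linear(4, 2)
x = mx.random.normal((8, 4))
y = mx.random.normal((8, 2))

def loss_fn(model, x, y):
    return nn.losses.mse_loss(model(x), y)

loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)

# After this patch, omitting `communication_stream` runs the all-reduce on
# the backend's default communication stream (previously forced to mx.cpu).
grads = average_gradients(grads, group=group)

# The stream can still be chosen explicitly if needed:
# grads = average_gradients(grads, group=group, communication_stream=mx.cpu)

mx.eval(grads)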