Remove stream from average grads so it uses default (#2532)

* Remove stream from average grads so it uses default

* comment
Awni Hannun 2025-08-25 15:56:29 -07:00 committed by GitHub
parent 4822c3dbe9
commit 3dcb286baf
2 changed files with 8 additions and 7 deletions


@@ -76,7 +76,7 @@ def average_gradients(
     gradients,
     group: Optional[mx.distributed.Group] = None,
     all_reduce_size: int = 32 * 1024**2,
     communication_type: Optional[mx.Dtype] = None,
-    stream: mx.Stream = mx.cpu,
+    communication_stream: Optional[mx.Stream] = None,
 ):
     """Average the gradients across the distributed processes in the passed group.
@@ -95,7 +95,9 @@ def average_gradients(
         communication_type (Optional[mlx.core.Dtype]): If provided cast to this
             type before performing the communication. Typically cast to a
             smaller float to reduce the communication size. Default: ``None``.
-        stream (mlx.core.Stream): The stream to use for the reduction. Default: ``mlx.cpu``.
+        communication_stream (Optional[mlx.core.Stream]): The stream to use
+            for the communication. If unspecified, the default communication
+            stream is used, which can vary by back-end. Default: ``None``.
     """
     group = group or mx.distributed.init()
     N = group.size()
@@ -106,7 +108,7 @@ def average_gradients(
     def _average(x):
         dt = x.dtype
         x = x.astype(communication_type) if communication_type is not None else x
-        return mx.distributed.all_sum(x, stream=stream).astype(dt) / N
+        return mx.distributed.all_sum(x, stream=communication_stream).astype(dt) / N

     if all_reduce_size <= 0:
         return tree_map(_average, gradients)
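
With the new default in place, callers simply drop the explicit stream argument and the backend picks its default communication stream; the new keyword still allows an override. A minimal usage sketch, assuming a distributed setup via mx.distributed.init(); the gradient tree here is illustrative, and the old behavior hard-coded the CPU stream:

    import mlx.core as mx
    from mlx.nn import average_gradients

    # Illustrative gradient tree; any pytree of arrays works.
    grads = {"w": mx.ones((4, 4)), "b": mx.ones((4,))}

    # New default: the all-reduce runs on the backend's default
    # communication stream (previously hard-coded to mx.cpu).
    avg = average_gradients(grads)

    # Still overridable; passing mx.cpu reproduces the old behavior.
    avg_cpu = average_gradients(grads, communication_stream=mx.cpu)

    mx.eval(avg, avg_cpu)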


@@ -65,21 +65,21 @@ class TestNCCLDistributed(mlx_tests.MLXTestCase):
         mx.distributed.all_sum = new_all_sum
         try:
             grads = [mx.ones(10) for i in range(10)]
-            new_grads = average_gradients(grads, stream=mx.gpu)
+            new_grads = average_gradients(grads)
             mx.eval(new_grads)
             self.assertEqual(len(new_grads), 10)
             self.assertTrue(all(mx.all(g == 1) for g in new_grads))
             self.assertEqual(n_calls, 1)

             n_calls = 0
-            new_grads = average_gradients(grads, all_reduce_size=4 * 50, stream=mx.gpu)
+            new_grads = average_gradients(grads, all_reduce_size=4 * 50)
             mx.eval(new_grads)
             self.assertEqual(len(new_grads), 10)
             self.assertTrue(all(mx.all(g == 1) for g in new_grads))
             self.assertEqual(n_calls, 2)

             n_calls = 0
-            new_grads = average_gradients(grads, all_reduce_size=0, stream=mx.gpu)
+            new_grads = average_gradients(grads, all_reduce_size=0)
             mx.eval(new_grads)
             self.assertEqual(len(new_grads), 10)
             self.assertTrue(all(mx.all(g == 1) for g in new_grads))
@@ -91,7 +91,6 @@ class TestNCCLDistributed(mlx_tests.MLXTestCase):
                 grads,
                 all_reduce_size=2 * 50,
                 communication_type=mx.float16,
-                stream=mx.gpu,
             )
             mx.eval(new_grads)
             self.assertEqual(len(new_grads), 10)
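
For reference, the counting trick these tests rely on can be sketched as below; this is a hedged reconstruction built around the n_calls and new_all_sum names visible in the hunk, not the repo's exact test code, and it assumes a distributed launch as in the test class. The arithmetic behind n_calls == 2: ten mx.ones(10) float32 gradients are 40 bytes each (400 bytes total), so an all_reduce_size of 4 * 50 = 200 bytes splits the flattened gradients into two batched all_sum calls.

    import mlx.core as mx
    from mlx.nn import average_gradients

    # Wrap the real collective so each invocation increments a counter.
    n_calls = 0
    original_all_sum = mx.distributed.all_sum

    def new_all_sum(x, **kwargs):
        global n_calls
        n_calls += 1  # count each collective launched by average_gradients
        return original_all_sum(x, **kwargs)

    mx.distributed.all_sum = new_all_sum
    try:
        grads = [mx.ones(10) for _ in range(10)]  # 10 arrays * 40 bytes = 400 bytes
        # 400 bytes against a 200-byte budget -> two batched reductions.
        new_grads = average_gradients(grads, all_reduce_size=4 * 50)
        mx.eval(new_grads)
        assert n_calls == 2
    finally:
        mx.distributed.all_sum = original_all_sum  # restore the real collective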