From 3bb7b671a9ed704aa871e05c0e48f8fa4294d68c Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 25 Aug 2025 15:03:12 -0700 Subject: [PATCH] comment --- python/mlx/nn/utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/mlx/nn/utils.py b/python/mlx/nn/utils.py index e42ce994a..97354d112 100644 --- a/python/mlx/nn/utils.py +++ b/python/mlx/nn/utils.py @@ -76,6 +76,7 @@ def average_gradients( group: Optional[mx.distributed.Group] = None, all_reduce_size: int = 32 * 1024**2, communication_type: Optional[mx.Dtype] = None, + communication_stream: Optional[mx.Stream] = None, ): """Average the gradients across the distributed processes in the passed group. @@ -94,6 +95,9 @@ def average_gradients( communication_type (Optional[mlx.core.Dtype]): If provided cast to this type before performing the communication. Typically cast to a smaller float to reduce the communication size. Default: ``None``. + communication_stream (Optional[mlx.core.Stream]): The stream to use + for the communication. If unspecified the default communication + stream is used which can vary by back-end. Default: ``None``. """ group = group or mx.distributed.init() N = group.size() @@ -104,7 +108,7 @@ def average_gradients( def _average(x): dt = x.dtype x = x.astype(communication_type) if communication_type is not None else x - return mx.distributed.all_sum(x).astype(dt) / N + return mx.distributed.all_sum(x, stream=communication_stream).astype(dt) / N if all_reduce_size <= 0: return tree_map(_average, gradients)