mirror of https://github.com/ml-explore/mlx.git (synced 2025-08-28 00:36:32 +08:00)
This commit is contained in: parent d08fa4bef8, commit 3bb7b671a9
@@ -76,6 +76,7 @@ def average_gradients(
     group: Optional[mx.distributed.Group] = None,
     all_reduce_size: int = 32 * 1024**2,
     communication_type: Optional[mx.Dtype] = None,
+    communication_stream: Optional[mx.Stream] = None,
 ):
     """Average the gradients across the distributed processes in the passed group.
 
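For readability, the full signature as it stands after this hunk; the leading gradients parameter is inferred from the tree_map(_average, gradients) call further down and is not itself shown in the diff:

    def average_gradients(
        gradients,
        group: Optional[mx.distributed.Group] = None,
        all_reduce_size: int = 32 * 1024**2,
        communication_type: Optional[mx.Dtype] = None,
        communication_stream: Optional[mx.Stream] = None,
    ):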
@@ -94,6 +95,9 @@ def average_gradients(
         communication_type (Optional[mlx.core.Dtype]): If provided cast to this
             type before performing the communication. Typically cast to a
             smaller float to reduce the communication size. Default: ``None``.
+        communication_stream (Optional[mlx.core.Stream]): The stream to use
+            for the communication. If unspecified, the default communication
+            stream is used, which can vary by back-end. Default: ``None``.
     """
     group = group or mx.distributed.init()
     N = group.size()
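In other words, the communication_type path is a downcast before the all-reduce and an upcast after it. A minimal standalone sketch of that round trip (the array shape and the bfloat16 choice are illustrative, not from the commit):

    import mlx.core as mx

    group = mx.distributed.init()
    grad = mx.random.normal((1024, 1024))    # a float32 gradient
    small = grad.astype(mx.bfloat16)         # half the bytes on the wire
    summed = mx.distributed.all_sum(small)   # communicate in the smaller type
    avg = summed.astype(grad.dtype) / group.size()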
@@ -104,7 +108,7 @@ def average_gradients(
     def _average(x):
         dt = x.dtype
         x = x.astype(communication_type) if communication_type is not None else x
-        return mx.distributed.all_sum(x).astype(dt) / N
+        return mx.distributed.all_sum(x, stream=communication_stream).astype(dt) / N
 
     if all_reduce_size <= 0:
         return tree_map(_average, gradients)
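Taken together, a hedged usage sketch of the new parameter. The tiny model and the mlx.nn module path for average_gradients are assumptions for illustration; neither is visible in this diff:

    import mlx.core as mx
    import mlx.nn as nn

    group = mx.distributed.init()
    model = nn.Linear(4, 4)

    def loss_fn(model, x, y):
        return nn.losses.mse_loss(model(x), y)

    x, y = mx.random.normal((8, 4)), mx.random.normal((8, 4))
    loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)

    comm_stream = mx.new_stream(mx.cpu)  # a dedicated stream for the all_sum traffic
    grads = nn.average_gradients(        # assumed import path; not shown in the diff
        grads,
        group=group,
        communication_type=mx.bfloat16,    # cast down before communicating
        communication_stream=comm_stream,  # route the all_sum onto this stream
    )

Routing the all_sum onto its own stream presumably lets a back-end overlap gradient communication with compute on the default stream, which appears to be the point of exposing communication_stream.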