Fix Distributed Communication documentation (#1731)

* Add missing `size()` method call for group
This commit is contained in:
Danilo Peixoto 2025-01-02 19:08:38 -03:00 committed by GitHub
parent 8ecdfb718b
commit 92ec632ad5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -141,12 +141,13 @@ everything else remaining the same.
from mlx.utils import tree_map
def all_reduce_grads(grads):
N = mx.distributed.init()
N = mx.distributed.init().size()
if N == 1:
return grads
return tree_map(
lambda x: mx.distributed.all_sum(x) / N,
grads)
grads
)
def step(model, x, y):
loss, grads = loss_grad_fn(model, x, y)