diff --git a/docs/src/usage/distributed.rst b/docs/src/usage/distributed.rst index 702951a0c..ec3accab3 100644 --- a/docs/src/usage/distributed.rst +++ b/docs/src/usage/distributed.rst @@ -141,12 +141,13 @@ everything else remaining the same. from mlx.utils import tree_map def all_reduce_grads(grads): - N = mx.distributed.init() + N = mx.distributed.init().size() if N == 1: return grads return tree_map( - lambda x: mx.distributed.all_sum(x) / N, - grads) + lambda x: mx.distributed.all_sum(x) / N, + grads + ) def step(model, x, y): loss, grads = loss_grad_fn(model, x, y)