From 92ec632ad5cc41ffaff9543ad9e80864f381c009 Mon Sep 17 00:00:00 2001 From: Danilo Peixoto Date: Thu, 2 Jan 2025 19:08:38 -0300 Subject: [PATCH] Fix Distributed Communication documentation (#1731) * Add missing `size()` method call for group --- docs/src/usage/distributed.rst | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/src/usage/distributed.rst b/docs/src/usage/distributed.rst index 702951a0c..ec3accab3 100644 --- a/docs/src/usage/distributed.rst +++ b/docs/src/usage/distributed.rst @@ -141,12 +141,13 @@ everything else remaining the same. from mlx.utils import tree_map def all_reduce_grads(grads): - N = mx.distributed.init() + N = mx.distributed.init().size() if N == 1: return grads return tree_map( - lambda x: mx.distributed.all_sum(x) / N, - grads) + lambda x: mx.distributed.all_sum(x) / N, + grads + ) def step(model, x, y): loss, grads = loss_grad_fn(model, x, y)