mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Fix Distributed Communication documentation (#1731)
* Add missing `size()` method call for group
This commit is contained in:
parent
8ecdfb718b
commit
92ec632ad5
@ -141,12 +141,13 @@ everything else remaining the same.
|
||||
from mlx.utils import tree_map
|
||||
|
||||
def all_reduce_grads(grads):
|
||||
N = mx.distributed.init()
|
||||
N = mx.distributed.init().size()
|
||||
if N == 1:
|
||||
return grads
|
||||
return tree_map(
|
||||
lambda x: mx.distributed.all_sum(x) / N,
|
||||
grads)
|
||||
lambda x: mx.distributed.all_sum(x) / N,
|
||||
grads
|
||||
)
|
||||
|
||||
def step(model, x, y):
|
||||
loss, grads = loss_grad_fn(model, x, y)
|
||||
|
Loading…
Reference in New Issue
Block a user