mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Fix Distributed Communication documentation (#1731)
* Add missing `size()` method call for group
This commit is contained in:
parent
8ecdfb718b
commit
92ec632ad5
@ -141,12 +141,13 @@ everything else remaining the same.
|
|||||||
from mlx.utils import tree_map
|
from mlx.utils import tree_map
|
||||||
|
|
||||||
def all_reduce_grads(grads):
|
def all_reduce_grads(grads):
|
||||||
N = mx.distributed.init()
|
N = mx.distributed.init().size()
|
||||||
if N == 1:
|
if N == 1:
|
||||||
return grads
|
return grads
|
||||||
return tree_map(
|
return tree_map(
|
||||||
lambda x: mx.distributed.all_sum(x) / N,
|
lambda x: mx.distributed.all_sum(x) / N,
|
||||||
grads)
|
grads
|
||||||
|
)
|
||||||
|
|
||||||
def step(model, x, y):
|
def step(model, x, y):
|
||||||
loss, grads = loss_grad_fn(model, x, y)
|
loss, grads = loss_grad_fn(model, x, y)
|
||||||
|
Loading…
Reference in New Issue
Block a user