mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Fix Distributed Communication documentation (#1731)
* Add missing `size()` method call for group
This commit is contained in:
		| @@ -141,12 +141,13 @@ everything else remaining the same. | ||||
|     from mlx.utils import tree_map | ||||
|  | ||||
|     def all_reduce_grads(grads): | ||||
|         N = mx.distributed.init() | ||||
|         N = mx.distributed.init().size() | ||||
|         if N == 1: | ||||
|             return grads | ||||
|         return tree_map( | ||||
|                 lambda x: mx.distributed.all_sum(x) / N, | ||||
|                 grads) | ||||
|             lambda x: mx.distributed.all_sum(x) / N, | ||||
|             grads | ||||
|         ) | ||||
|  | ||||
|     def step(model, x, y): | ||||
|         loss, grads = loss_grad_fn(model, x, y) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Danilo Peixoto
					Danilo Peixoto