mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 06:54:39 +08:00
reduction moved to CPU in case of distributed training
This commit is contained in:
parent
514502da22
commit
ff1719afc3
@@ -160,7 +160,8 @@ def evaluate(
|
|||||||
mx.eval(all_losses, ntokens)
|
mx.eval(all_losses, ntokens)
|
||||||
|
|
||||||
all_losses = mx.distributed.all_sum(all_losses)
|
all_losses = mx.distributed.all_sum(all_losses)
|
||||||
ntokens = mx.distributed.all_sum(ntokens)
|
stream = mx.cpu if mx.distributed.init().size() > 1 else None
|
||||||
|
ntokens = mx.distributed.all_sum(ntokens, stream=stream)
|
||||||
|
|
||||||
return (all_losses / ntokens).item()
|
return (all_losses / ntokens).item()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user