Reduction moved to the CPU in the case of distributed training

This commit is contained in:
ivanfioravanti 2025-01-11 00:32:54 +01:00
parent 514502da22
commit ff1719afc3

View File

@ -160,7 +160,8 @@ def evaluate(
     mx.eval(all_losses, ntokens)
     all_losses = mx.distributed.all_sum(all_losses)
-    ntokens = mx.distributed.all_sum(ntokens)
+    stream = mx.cpu if mx.distributed.init().size() > 1 else None
+    ntokens = mx.distributed.all_sum(ntokens, stream=stream)
     return (all_losses / ntokens).item()