From ff1719afc3623a4eaecdb0368ab19f5747909ca9 Mon Sep 17 00:00:00 2001
From: ivanfioravanti
Date: Sat, 11 Jan 2025 00:32:54 +0100
Subject: [PATCH] reduction moved to CPU in case of distributed training

---
 llms/mlx_lm/tuner/trainer.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py
index a76b8336..8269e547 100644
--- a/llms/mlx_lm/tuner/trainer.py
+++ b/llms/mlx_lm/tuner/trainer.py
@@ -160,7 +160,8 @@ def evaluate(
         mx.eval(all_losses, ntokens)
 
     all_losses = mx.distributed.all_sum(all_losses)
-    ntokens = mx.distributed.all_sum(ntokens)
+    stream = mx.cpu if mx.distributed.init().size() > 1 else None
+    ntokens = mx.distributed.all_sum(ntokens, stream=stream)
 
     return (all_losses / ntokens).item()
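
Taken outside the diff context, the change amounts to the following pattern: pick the CPU stream for the collective only when a distributed group with more than one rank is active, and leave single-process runs on the default stream. A minimal sketch of that pattern is below; the `evaluate_reduction` helper and its sample inputs are illustrative only and not part of mlx_lm, and as in the patch only the token-count reduction is routed through the chosen stream.

```python
import mlx.core as mx


def evaluate_reduction(all_losses: mx.array, ntokens: mx.array) -> float:
    # Materialize the per-rank accumulators before communicating.
    mx.eval(all_losses, ntokens)

    # Sum the loss accumulator across ranks on the default stream.
    all_losses = mx.distributed.all_sum(all_losses)

    # Use the CPU stream for the reduction only when more than one rank
    # is running; otherwise fall back to the default stream.
    stream = mx.cpu if mx.distributed.init().size() > 1 else None
    ntokens = mx.distributed.all_sum(ntokens, stream=stream)

    # Average loss per token across all ranks.
    return (all_losses / ntokens).item()


if __name__ == "__main__":
    # Illustrative values; in a single-process run the all_sums are no-ops.
    print(evaluate_reduction(mx.array(12.5), mx.array(100)))
```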