diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py index a76b8336..8269e547 100644 --- a/llms/mlx_lm/tuner/trainer.py +++ b/llms/mlx_lm/tuner/trainer.py @@ -160,7 +160,8 @@ def evaluate( mx.eval(all_losses, ntokens) all_losses = mx.distributed.all_sum(all_losses) - ntokens = mx.distributed.all_sum(ntokens) + stream = mx.cpu if mx.distributed.init().size() > 1 else None + ntokens = mx.distributed.all_sum(ntokens, stream=stream) return (all_losses / ntokens).item()