From 2e08e8b96c1179febb958a9805ab90908b1b517d Mon Sep 17 00:00:00 2001
From: ivanfioravanti
Date: Mon, 13 Jan 2025 23:06:58 +0000
Subject: [PATCH] moving all distributed ops to cpu

---
 llms/mlx_lm/tuner/trainer.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py
index 8269e547..63ca58bb 100644
--- a/llms/mlx_lm/tuner/trainer.py
+++ b/llms/mlx_lm/tuner/trainer.py
@@ -159,9 +159,8 @@ def evaluate(
         ntokens += toks
         mx.eval(all_losses, ntokens)
 
-    all_losses = mx.distributed.all_sum(all_losses)
-    stream = mx.cpu if mx.distributed.init().size() > 1 else None
-    ntokens = mx.distributed.all_sum(ntokens, stream=stream)
+    all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
+    ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
 
     return (all_losses / ntokens).item()
 
@@ -273,9 +272,9 @@ def train(
         if it % args.steps_per_report == 0 or it == args.iters:
             stop = time.perf_counter()
 
-            train_loss = mx.distributed.all_sum(losses).item()
+            train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
            train_loss /= steps * mx.distributed.init().size()
-            n_tokens = mx.distributed.all_sum(n_tokens).item()
+            n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
             learning_rate = optimizer.learning_rate.item()
             it_sec = args.steps_per_report / (stop - start)
             tokens_sec = float(n_tokens) / (stop - start)
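
Note (not part of the patch): the diff above drops the conditional stream selection and passes stream=mx.cpu unconditionally to every mx.distributed.all_sum call in trainer.py, for both the evaluate() and train() reductions. A minimal standalone sketch of that pattern follows, assuming an MLX build with distributed support; the tensor contents and print output are illustrative assumptions, not code from the patch.

    # sketch.py - illustrative only; mirrors the stream=mx.cpu pattern used above
    import mlx.core as mx

    # init() returns the process group; with a single process it is a trivial
    # group of size 1, so this also runs outside a distributed launch.
    group = mx.distributed.init()

    # Per-rank value (made up for illustration).
    local = mx.array([float(group.rank() + 1)])

    # Run the collective on the CPU stream, as the patched trainer.py does,
    # instead of whatever the default stream/device would be.
    total = mx.distributed.all_sum(local, stream=mx.cpu)
    mx.eval(total)
    print(group.rank(), total.item())

With a single process the all_sum reduces over a group of size 1 and simply returns the local value, so the snippet can be sanity-checked without a multi-node launch.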