diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py
index a76b8336..63ca58bb 100644
--- a/llms/mlx_lm/tuner/trainer.py
+++ b/llms/mlx_lm/tuner/trainer.py
@@ -159,8 +159,8 @@ def evaluate(
         ntokens += toks
         mx.eval(all_losses, ntokens)
 
-    all_losses = mx.distributed.all_sum(all_losses)
-    ntokens = mx.distributed.all_sum(ntokens)
+    all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
+    ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
 
     return (all_losses / ntokens).item()
 
@@ -272,9 +272,9 @@ def train(
         if it % args.steps_per_report == 0 or it == args.iters:
             stop = time.perf_counter()
 
-            train_loss = mx.distributed.all_sum(losses).item()
+            train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
             train_loss /= steps * mx.distributed.init().size()
-            n_tokens = mx.distributed.all_sum(n_tokens).item()
+            n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
             learning_rate = optimizer.learning_rate.item()
             it_sec = args.steps_per_report / (stop - start)
             tokens_sec = float(n_tokens) / (stop - start)
diff --git a/llms/tests/test_finetune.py b/llms/tests/test_finetune.py
index 6ba81628..a6d53747 100644
--- a/llms/tests/test_finetune.py
+++ b/llms/tests/test_finetune.py
@@ -21,7 +21,7 @@ from mlx_lm.tuner.utils import build_schedule
 @contextmanager
 def swapped_with_identity(obj, func):
     old_func = getattr(obj, func)
-    setattr(obj, func, lambda x: x)
+    setattr(obj, func, lambda x, **kwargs: x)
     yield
     setattr(obj, func, old_func)
 
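Note: this patch pins the distributed all_sum collectives to the CPU stream via stream=mx.cpu, and widens the test's identity stand-in to lambda x, **kwargs: x so the swapped-out function still accepts the new keyword. Below is a minimal sketch of the same pattern, assuming an initialized distributed group; the per-rank loss value is hypothetical.

    import mlx.core as mx

    group = mx.distributed.init()

    local_loss = mx.array(1.5)  # hypothetical per-rank value

    # Run the all_sum collective on the CPU stream, as in the diff
    # above, so communication stays off the default compute stream.
    total_loss = mx.distributed.all_sum(local_loss, stream=mx.cpu)
    mean_loss = (total_loss / group.size()).item()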