reduction moved to CPU in case of distributed training (#1200)

Ivan Fioravanti authored on 2025-01-15 02:20:42 +01:00, committed by GitHub
parent c117af83b8
commit 6ae6c72c2e
2 changed files with 5 additions and 5 deletions


@@ -159,8 +159,8 @@ def evaluate(
         ntokens += toks
         mx.eval(all_losses, ntokens)
-    all_losses = mx.distributed.all_sum(all_losses)
-    ntokens = mx.distributed.all_sum(ntokens)
+    all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
+    ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
     return (all_losses / ntokens).item()
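For reference, a minimal standalone sketch of the pattern this commit applies: running the MLX distributed reduction on the CPU stream while the rest of the computation stays on the default device. The variable names and placeholder value below are illustrative, not taken from the changed files.

import mlx.core as mx

# Process group; size() is 1 when no distributed backend is launched.
group = mx.distributed.init()

# Per-rank scalar to be aggregated (placeholder value).
local_loss = mx.array(0.5)

# stream=mx.cpu forces the all_sum collective onto the CPU stream,
# which is what the changed lines above do for all_losses and ntokens.
total_loss = mx.distributed.all_sum(local_loss, stream=mx.cpu)

# Average across ranks, mirroring the normalization done in train().
mean_loss = (total_loss / group.size()).item()
print(mean_loss)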
@@ -272,9 +272,9 @@ def train(
         if it % args.steps_per_report == 0 or it == args.iters:
             stop = time.perf_counter()
-            train_loss = mx.distributed.all_sum(losses).item()
+            train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
             train_loss /= steps * mx.distributed.init().size()
-            n_tokens = mx.distributed.all_sum(n_tokens).item()
+            n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
             learning_rate = optimizer.learning_rate.item()
             it_sec = args.steps_per_report / (stop - start)
             tokens_sec = float(n_tokens) / (stop - start)


@@ -21,7 +21,7 @@ from mlx_lm.tuner.utils import build_schedule
 @contextmanager
 def swapped_with_identity(obj, func):
     old_func = getattr(obj, func)
-    setattr(obj, func, lambda x: x)
+    setattr(obj, func, lambda x, **kwargs: x)
     yield
     setattr(obj, func, old_func)
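A usage sketch of the test helper above, assuming mlx.core is imported as mx and the swapped_with_identity definition shown in the hunk: with the **kwargs catch-all, the identity stub also accepts the new stream keyword, so calls patched during the tests, such as all_sum(x, stream=mx.cpu), pass their input through unchanged.

with swapped_with_identity(mx.distributed, "all_sum"):
    # The stub is `lambda x, **kwargs: x`, so the stream keyword is
    # accepted and ignored, and the input comes back unchanged.
    out = mx.distributed.all_sum(mx.array([1.0, 2.0]), stream=mx.cpu)
print(out)  # prints the input values; the collective was bypassed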