mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
reduction moved to CPU in case of distributed training (#1200)
This commit is contained in:
parent
c117af83b8
commit
6ae6c72c2e
@ -159,8 +159,8 @@ def evaluate(
|
|||||||
ntokens += toks
|
ntokens += toks
|
||||||
mx.eval(all_losses, ntokens)
|
mx.eval(all_losses, ntokens)
|
||||||
|
|
||||||
all_losses = mx.distributed.all_sum(all_losses)
|
all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
|
||||||
ntokens = mx.distributed.all_sum(ntokens)
|
ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
|
||||||
|
|
||||||
return (all_losses / ntokens).item()
|
return (all_losses / ntokens).item()
|
||||||
|
|
||||||
@ -272,9 +272,9 @@ def train(
|
|||||||
if it % args.steps_per_report == 0 or it == args.iters:
|
if it % args.steps_per_report == 0 or it == args.iters:
|
||||||
stop = time.perf_counter()
|
stop = time.perf_counter()
|
||||||
|
|
||||||
train_loss = mx.distributed.all_sum(losses).item()
|
train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
|
||||||
train_loss /= steps * mx.distributed.init().size()
|
train_loss /= steps * mx.distributed.init().size()
|
||||||
n_tokens = mx.distributed.all_sum(n_tokens).item()
|
n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
|
||||||
learning_rate = optimizer.learning_rate.item()
|
learning_rate = optimizer.learning_rate.item()
|
||||||
it_sec = args.steps_per_report / (stop - start)
|
it_sec = args.steps_per_report / (stop - start)
|
||||||
tokens_sec = float(n_tokens) / (stop - start)
|
tokens_sec = float(n_tokens) / (stop - start)
|
||||||
|
@ -21,7 +21,7 @@ from mlx_lm.tuner.utils import build_schedule
|
|||||||
@contextmanager
|
@contextmanager
|
||||||
def swapped_with_identity(obj, func):
|
def swapped_with_identity(obj, func):
|
||||||
old_func = getattr(obj, func)
|
old_func = getattr(obj, func)
|
||||||
setattr(obj, func, lambda x: x)
|
setattr(obj, func, lambda x, **kwargs: x)
|
||||||
yield
|
yield
|
||||||
setattr(obj, func, old_func)
|
setattr(obj, func, old_func)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user