mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 04:25:06 +08:00
moving all distributed ops to cpu
This commit is contained in:
parent
ff1719afc3
commit
2e08e8b96c
@ -159,9 +159,8 @@ def evaluate(
|
||||
ntokens += toks
|
||||
mx.eval(all_losses, ntokens)
|
||||
|
||||
all_losses = mx.distributed.all_sum(all_losses)
|
||||
stream = mx.cpu if mx.distributed.init().size() > 1 else None
|
||||
ntokens = mx.distributed.all_sum(ntokens, stream=stream)
|
||||
all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
|
||||
ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
|
||||
|
||||
return (all_losses / ntokens).item()
|
||||
|
||||
@ -273,9 +272,9 @@ def train(
|
||||
if it % args.steps_per_report == 0 or it == args.iters:
|
||||
stop = time.perf_counter()
|
||||
|
||||
train_loss = mx.distributed.all_sum(losses).item()
|
||||
train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
|
||||
train_loss /= steps * mx.distributed.init().size()
|
||||
n_tokens = mx.distributed.all_sum(n_tokens).item()
|
||||
n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
|
||||
learning_rate = optimizer.learning_rate.item()
|
||||
it_sec = args.steps_per_report / (stop - start)
|
||||
tokens_sec = float(n_tokens) / (stop - start)
|
||||
|
Loading…
Reference in New Issue
Block a user