mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 18:26:37 +08:00
Fix average_stats
This commit is contained in:
parent
d20413a54d
commit
66b630b4f6
@ -54,10 +54,7 @@ def train_epoch(model, train_iter, optimizer, epoch):
|
|||||||
with mx.stream(mx.cpu):
|
with mx.stream(mx.cpu):
|
||||||
stats = mx.distributed.all_sum(mx.array(stats))
|
stats = mx.distributed.all_sum(mx.array(stats))
|
||||||
count = mx.distributed.all_sum(count)
|
count = mx.distributed.all_sum(count)
|
||||||
mx.eval(stats, count)
|
return (stats / count).tolist()
|
||||||
count = count.item()
|
|
||||||
|
|
||||||
return [s / count for s in stats.tolist()]
|
|
||||||
|
|
||||||
state = [model.state, optimizer.state]
|
state = [model.state, optimizer.state]
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user