mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-28 03:40:22 +08:00
Fix average_stats
This commit is contained in:
parent
d20413a54d
commit
66b630b4f6
@ -54,10 +54,7 @@ def train_epoch(model, train_iter, optimizer, epoch):
|
||||
with mx.stream(mx.cpu):
|
||||
stats = mx.distributed.all_sum(mx.array(stats))
|
||||
count = mx.distributed.all_sum(count)
|
||||
mx.eval(stats, count)
|
||||
count = count.item()
|
||||
|
||||
return [s / count for s in stats.tolist()]
|
||||
return (stats / count).tolist()
|
||||
|
||||
state = [model.state, optimizer.state]
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user