Fix average_stats

This commit is contained in:
Angelos Katharopoulos 2025-02-28 16:00:06 -08:00
parent d20413a54d
commit 66b630b4f6

View File

@@ -54,10 +54,7 @@ def train_epoch(model, train_iter, optimizer, epoch):
with mx.stream(mx.cpu):
stats = mx.distributed.all_sum(mx.array(stats))
count = mx.distributed.all_sum(count)
mx.eval(stats, count)
count = count.item()
return [s / count for s in stats.tolist()]
return (stats / count).tolist()
state = [model.state, optimizer.state]