diff --git a/cifar/main.py b/cifar/main.py index 3fe5d2e0..ac010636 100644 --- a/cifar/main.py +++ b/cifar/main.py @@ -54,10 +54,7 @@ def train_epoch(model, train_iter, optimizer, epoch): with mx.stream(mx.cpu): stats = mx.distributed.all_sum(mx.array(stats)) count = mx.distributed.all_sum(count) - mx.eval(stats, count) - count = count.item() - - return [s / count for s in stats.tolist()] + return (stats / count).tolist() state = [model.state, optimizer.state]