diff --git a/cifar/main.py b/cifar/main.py index 27074133..7eb6efdf 100644 --- a/cifar/main.py +++ b/cifar/main.py @@ -81,7 +81,10 @@ def train_epoch(model, train_iter, optimizer, epoch): samples_per_sec += x.shape[0] / (toc - tic) count += 1 if batch_counter % 10 == 0: - l, a, s = average_stats([losses, accuracies, samples_per_sec], count) + l, a, s = average_stats( + [losses, accuracies, world.size() * samples_per_sec], + count, + ) print_zero( world, " | ".join( @@ -94,7 +97,7 @@ def train_epoch(model, train_iter, optimizer, epoch): ), ) - return average_stats([losses, accuracies, samples_per_sec], count) + return average_stats([losses, accuracies, world.size() * samples_per_sec], count) def test_epoch(model, test_iter, epoch):