Fix the throughput calculation

This commit is contained in:
Angelos Katharopoulos 2025-02-25 17:20:15 -08:00
parent 8a76b421a0
commit 14faec4ca2

View File

@@ -81,7 +81,10 @@ def train_epoch(model, train_iter, optimizer, epoch):
samples_per_sec += x.shape[0] / (toc - tic)
count += 1
if batch_counter % 10 == 0:
l, a, s = average_stats([losses, accuracies, samples_per_sec], count)
l, a, s = average_stats(
[losses, accuracies, world.size() * samples_per_sec],
count,
)
print_zero(
world,
" | ".join(
@@ -94,7 +97,7 @@ def train_epoch(model, train_iter, optimizer, epoch):
),
)
return average_stats([losses, accuracies, samples_per_sec], count)
return average_stats([losses, accuracies, world.size() * samples_per_sec], count)
def test_epoch(model, test_iter, epoch):