mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-28 16:16:27 +08:00
Fix the throughput calculation
This commit is contained in:
parent
8a76b421a0
commit
14faec4ca2
@ -81,7 +81,10 @@ def train_epoch(model, train_iter, optimizer, epoch):
|
||||
samples_per_sec += x.shape[0] / (toc - tic)
|
||||
count += 1
|
||||
if batch_counter % 10 == 0:
|
||||
l, a, s = average_stats([losses, accuracies, samples_per_sec], count)
|
||||
l, a, s = average_stats(
|
||||
[losses, accuracies, world.size() * samples_per_sec],
|
||||
count,
|
||||
)
|
||||
print_zero(
|
||||
world,
|
||||
" | ".join(
|
||||
@ -94,7 +97,7 @@ def train_epoch(model, train_iter, optimizer, epoch):
|
||||
),
|
||||
)
|
||||
|
||||
return average_stats([losses, accuracies, samples_per_sec], count)
|
||||
return average_stats([losses, accuracies, world.size() * samples_per_sec], count)
|
||||
|
||||
|
||||
def test_epoch(model, test_iter, epoch):
|
||||
|
Loading…
Reference in New Issue
Block a user