mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 17:37:56 +08:00
Fix the throughput calculation
This commit is contained in:
parent
8a76b421a0
commit
14faec4ca2
@ -81,7 +81,10 @@ def train_epoch(model, train_iter, optimizer, epoch):
|
|||||||
samples_per_sec += x.shape[0] / (toc - tic)
|
samples_per_sec += x.shape[0] / (toc - tic)
|
||||||
count += 1
|
count += 1
|
||||||
if batch_counter % 10 == 0:
|
if batch_counter % 10 == 0:
|
||||||
l, a, s = average_stats([losses, accuracies, samples_per_sec], count)
|
l, a, s = average_stats(
|
||||||
|
[losses, accuracies, world.size() * samples_per_sec],
|
||||||
|
count,
|
||||||
|
)
|
||||||
print_zero(
|
print_zero(
|
||||||
world,
|
world,
|
||||||
" | ".join(
|
" | ".join(
|
||||||
@ -94,7 +97,7 @@ def train_epoch(model, train_iter, optimizer, epoch):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
return average_stats([losses, accuracies, samples_per_sec], count)
|
return average_stats([losses, accuracies, world.size() * samples_per_sec], count)
|
||||||
|
|
||||||
|
|
||||||
def test_epoch(model, test_iter, epoch):
|
def test_epoch(model, test_iter, epoch):
|
||||||
|
Loading…
Reference in New Issue
Block a user