diff --git a/speechcommands/main.py b/speechcommands/main.py index 0d8da9fd..ed328f4c 100644 --- a/speechcommands/main.py +++ b/speechcommands/main.py @@ -76,6 +76,7 @@ def train_epoch(model, train_iter, optimizer, epoch): samples_per_sec = [] model.train(True) + train_iter.reset() for batch_counter, batch in enumerate(train_iter): x = mx.array(batch["audio"]) y = mx.array(batch["label"]) @@ -111,6 +112,7 @@ def test_epoch(model, test_iter): model.train(False) accs = [] throughput = [] + test_iter.reset() for batch_counter, batch in enumerate(test_iter): x = mx.array(batch["audio"]) y = mx.array(batch["label"])