diff --git a/speechcommands/main.py b/speechcommands/main.py index 3f0de98c..ed328f4c 100644 --- a/speechcommands/main.py +++ b/speechcommands/main.py @@ -48,7 +48,6 @@ def prepare_dataset(batch_size, split, root=None): .batch(batch_size) .to_stream() .prefetch(4, 4) - .to_buffer() ) return data_iter @@ -77,6 +76,7 @@ def train_epoch(model, train_iter, optimizer, epoch): samples_per_sec = [] model.train(True) + train_iter.reset() for batch_counter, batch in enumerate(train_iter): x = mx.array(batch["audio"]) y = mx.array(batch["label"]) @@ -112,6 +112,7 @@ def test_epoch(model, test_iter): model.train(False) accs = [] throughput = [] + test_iter.reset() for batch_counter, batch in enumerate(test_iter): x = mx.array(batch["audio"]) y = mx.array(batch["label"])