add reset()

This commit is contained in:
Sakares Saengkaew 2024-12-03 15:53:00 +08:00
parent 04c18832ab
commit 43d023948f
Failed to extract signature

View File

@ -48,7 +48,6 @@ def prepare_dataset(batch_size, split, root=None):
.batch(batch_size) .batch(batch_size)
.to_stream() .to_stream()
.prefetch(4, 4) .prefetch(4, 4)
.to_buffer()
) )
return data_iter return data_iter
@ -77,6 +76,7 @@ def train_epoch(model, train_iter, optimizer, epoch):
samples_per_sec = [] samples_per_sec = []
model.train(True) model.train(True)
train_iter.reset()
for batch_counter, batch in enumerate(train_iter): for batch_counter, batch in enumerate(train_iter):
x = mx.array(batch["audio"]) x = mx.array(batch["audio"])
y = mx.array(batch["label"]) y = mx.array(batch["label"])
@ -112,6 +112,7 @@ def test_epoch(model, test_iter):
model.train(False) model.train(False)
accs = [] accs = []
throughput = [] throughput = []
test_iter.reset()
for batch_counter, batch in enumerate(test_iter): for batch_counter, batch in enumerate(test_iter):
x = mx.array(batch["audio"]) x = mx.array(batch["audio"])
y = mx.array(batch["label"]) y = mx.array(batch["label"])