Fix data_iter in prepare_dataset from speechcommands example (#1113)

This commit is contained in:
sakares saengkaew 2024-12-03 14:56:07 +07:00 committed by GitHub
parent eb9277f574
commit 0ca162cfb2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -76,6 +76,7 @@ def train_epoch(model, train_iter, optimizer, epoch):
samples_per_sec = []
model.train(True)
train_iter.reset()
for batch_counter, batch in enumerate(train_iter):
x = mx.array(batch["audio"])
y = mx.array(batch["label"])
@ -111,6 +112,7 @@ def test_epoch(model, test_iter):
model.train(False)
accs = []
throughput = []
test_iter.reset()
for batch_counter, batch in enumerate(test_iter):
x = mx.array(batch["audio"])
y = mx.array(batch["label"])