mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 01:17:28 +08:00
Fix data_iter in prepare_dataset from speechcommands example (#1113)
This commit is contained in:
parent
eb9277f574
commit
0ca162cfb2
@ -76,6 +76,7 @@ def train_epoch(model, train_iter, optimizer, epoch):
|
|||||||
samples_per_sec = []
|
samples_per_sec = []
|
||||||
|
|
||||||
model.train(True)
|
model.train(True)
|
||||||
|
train_iter.reset()
|
||||||
for batch_counter, batch in enumerate(train_iter):
|
for batch_counter, batch in enumerate(train_iter):
|
||||||
x = mx.array(batch["audio"])
|
x = mx.array(batch["audio"])
|
||||||
y = mx.array(batch["label"])
|
y = mx.array(batch["label"])
|
||||||
@ -111,6 +112,7 @@ def test_epoch(model, test_iter):
|
|||||||
model.train(False)
|
model.train(False)
|
||||||
accs = []
|
accs = []
|
||||||
throughput = []
|
throughput = []
|
||||||
|
test_iter.reset()
|
||||||
for batch_counter, batch in enumerate(test_iter):
|
for batch_counter, batch in enumerate(test_iter):
|
||||||
x = mx.array(batch["audio"])
|
x = mx.array(batch["audio"])
|
||||||
y = mx.array(batch["label"])
|
y = mx.array(batch["label"])
|
||||||
|
Loading…
Reference in New Issue
Block a user