mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 01:46:09 +08:00
add reset()
This commit is contained in:
parent
04c18832ab
commit
43d023948f
@ -48,7 +48,6 @@ def prepare_dataset(batch_size, split, root=None):
|
||||
.batch(batch_size)
|
||||
.to_stream()
|
||||
.prefetch(4, 4)
|
||||
.to_buffer()
|
||||
)
|
||||
return data_iter
|
||||
|
||||
@ -77,6 +76,7 @@ def train_epoch(model, train_iter, optimizer, epoch):
|
||||
samples_per_sec = []
|
||||
|
||||
model.train(True)
|
||||
train_iter.reset()
|
||||
for batch_counter, batch in enumerate(train_iter):
|
||||
x = mx.array(batch["audio"])
|
||||
y = mx.array(batch["label"])
|
||||
@ -112,6 +112,7 @@ def test_epoch(model, test_iter):
|
||||
model.train(False)
|
||||
accs = []
|
||||
throughput = []
|
||||
test_iter.reset()
|
||||
for batch_counter, batch in enumerate(test_iter):
|
||||
x = mx.array(batch["audio"])
|
||||
y = mx.array(batch["label"])
|
||||
|
Loading…
Reference in New Issue
Block a user