mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 03:01:34 +08:00
add reset()
This commit is contained in:
parent
04c18832ab
commit
43d023948f
@ -48,7 +48,6 @@ def prepare_dataset(batch_size, split, root=None):
|
|||||||
.batch(batch_size)
|
.batch(batch_size)
|
||||||
.to_stream()
|
.to_stream()
|
||||||
.prefetch(4, 4)
|
.prefetch(4, 4)
|
||||||
.to_buffer()
|
|
||||||
)
|
)
|
||||||
return data_iter
|
return data_iter
|
||||||
|
|
||||||
@ -77,6 +76,7 @@ def train_epoch(model, train_iter, optimizer, epoch):
|
|||||||
samples_per_sec = []
|
samples_per_sec = []
|
||||||
|
|
||||||
model.train(True)
|
model.train(True)
|
||||||
|
train_iter.reset()
|
||||||
for batch_counter, batch in enumerate(train_iter):
|
for batch_counter, batch in enumerate(train_iter):
|
||||||
x = mx.array(batch["audio"])
|
x = mx.array(batch["audio"])
|
||||||
y = mx.array(batch["label"])
|
y = mx.array(batch["label"])
|
||||||
@ -112,6 +112,7 @@ def test_epoch(model, test_iter):
|
|||||||
model.train(False)
|
model.train(False)
|
||||||
accs = []
|
accs = []
|
||||||
throughput = []
|
throughput = []
|
||||||
|
test_iter.reset()
|
||||||
for batch_counter, batch in enumerate(test_iter):
|
for batch_counter, batch in enumerate(test_iter):
|
||||||
x = mx.array(batch["audio"])
|
x = mx.array(batch["audio"])
|
||||||
y = mx.array(batch["label"])
|
y = mx.array(batch["label"])
|
||||||
|
Loading…
Reference in New Issue
Block a user