diff --git a/speechcommands/main.py b/speechcommands/main.py index 0d8da9fd..3f0de98c 100644 --- a/speechcommands/main.py +++ b/speechcommands/main.py @@ -48,6 +48,7 @@ def prepare_dataset(batch_size, split, root=None): .batch(batch_size) .to_stream() .prefetch(4, 4) + .to_buffer() ) return data_iter