mlx-examples/cvae/dataset.py
Markus Enzweiler 9b387007ab
Example of a Convolutional Variational Autoencoder (CVAE) on MNIST (#264)
* initial commit

* style fixes

* update of ACKNOWLEDGMENTS

* fixed comment

* minor refactoring; removed unused imports

* added cifar and cvae to top-level README.md

* removed mention of cuda/mps in argparse

* fixed training status output

* load_weights() with strict=True

* pretrained model update

* fixed imports and style

* requires mlx>=0.0.9

* updated with results using mlx 0.0.9

* removed mention of private repo

* simplify and combine to one file, more consistency with other examples

* few more nits

* nits

* spell

* format

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-02-06 20:02:27 -08:00

54 lines
1.6 KiB
Python

# Copyright © 2023-2024 Apple Inc.
from mlx.data.datasets import load_mnist
def mnist(batch_size, img_size, root=None):
    """Build batched mlx-data streams over the MNIST train and test splits.

    Both splits go through the same steps: conversion to a stream, resizing
    of the "image" entry, normalization of the "image" entry, and batching.
    The training split is additionally shuffled before it is turned into a
    stream.

    Args:
        batch_size: Number of samples per batch, forwarded to ``.batch()``.
        img_size: Indexable ``(height, width)``; ``img_size[0]`` and
            ``img_size[1]`` are used as the resize target.
        root: Forwarded unchanged to ``load_mnist`` as ``root``.

    Returns:
        A tuple ``(tr_iter, test_iter)`` with the training and test streams.
    """
    height, width = img_size[0], img_size[1]

    def normalize(x):
        # normalize to [0,1] (division by 255 assumes 8-bit pixel values)
        return x.astype("float32") / 255.0

    def to_batches(buffer):
        # Shared pipeline, so the train and test streams cannot drift apart.
        return (
            buffer.to_stream()
            .image_resize("image", h=height, w=width)
            .key_transform("image", normalize)
            .batch(batch_size)
        )

    # load train and test sets using mlx-data
    tr = load_mnist(root=root, train=True)
    test = load_mnist(root=root, train=False)

    # Only the training split is shuffled; the test split keeps its order.
    tr_iter = to_batches(tr.shuffle())
    test_iter = to_batches(test)
    return tr_iter, test_iter
if __name__ == "__main__":
    # Smoke test: pull one batch from each stream and check the shapes of
    # its "image" and "label" entries.
    batch_size = 32
    img_size = (64, 64)  # (H, W)

    tr_iter, test_iter = mnist(batch_size=batch_size, img_size=img_size)

    # Expected image batch layout is (B, H, W, C); C is 1 for MNIST.
    B, H, W, C = batch_size, img_size[0], img_size[1], 1
    print(f"Batch size: {B}, Channels: {C}, Height: {H}, Width: {W}")

    batch_tr_iter = next(tr_iter)
    assert batch_tr_iter["image"].shape == (B, H, W, C), "Wrong training set size"
    assert batch_tr_iter["label"].shape == (batch_size,), "Wrong training set size"

    # These messages name the test split so that a failure here is not
    # mistaken for a training-set problem.
    batch_test_iter = next(test_iter)
    assert batch_test_iter["image"].shape == (B, H, W, C), "Wrong test set size"
    assert batch_test_iter["label"].shape == (batch_size,), "Wrong test set size"