mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-10-24 06:28:07 +08:00
Example of a Convolutional Variational Autoencoder (CVAE) on MNIST (#264)
* initial commit * style fixes * update of ACKNOWLEDGMENTS * fixed comment * minor refactoring; removed unused imports * added cifar and cvae to top-level README.md * removed mention of cuda/mps in argparse * fixed training status output * load_weights() with strict=True * pretrained model update * fixed imports and style * requires mlx>=0.0.9 * updated with results using mlx 0.0.9 * removed mention of private repo * simplify and combine to one file, more consistency with other exmaples * few more nits * nits * spell * format --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
53
cvae/dataset.py
Normal file
53
cvae/dataset.py
Normal file
@@ -0,0 +1,53 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from mlx.data.datasets import load_mnist
|
||||
|
||||
|
||||
def mnist(batch_size, img_size, root=None):
|
||||
# load train and test sets using mlx-data
|
||||
load_fn = load_mnist
|
||||
tr = load_fn(root=root, train=True)
|
||||
test = load_fn(root=root, train=False)
|
||||
|
||||
# number of image channels is 1 for MNIST
|
||||
num_img_channels = 1
|
||||
|
||||
# normalize to [0,1]
|
||||
def normalize(x):
|
||||
return x.astype("float32") / 255.0
|
||||
|
||||
# iterator over training set
|
||||
tr_iter = (
|
||||
tr.shuffle()
|
||||
.to_stream()
|
||||
.image_resize("image", h=img_size[0], w=img_size[1])
|
||||
.key_transform("image", normalize)
|
||||
.batch(batch_size)
|
||||
)
|
||||
|
||||
# iterator over test set
|
||||
test_iter = (
|
||||
test.to_stream()
|
||||
.image_resize("image", h=img_size[0], w=img_size[1])
|
||||
.key_transform("image", normalize)
|
||||
.batch(batch_size)
|
||||
)
|
||||
return tr_iter, test_iter
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
batch_size = 32
|
||||
img_size = (64, 64) # (H, W)
|
||||
|
||||
tr_iter, test_iter = mnist(batch_size=batch_size, img_size=img_size)
|
||||
|
||||
B, H, W, C = batch_size, img_size[0], img_size[1], 1
|
||||
print(f"Batch size: {B}, Channels: {C}, Height: {H}, Width: {W}")
|
||||
|
||||
batch_tr_iter = next(tr_iter)
|
||||
assert batch_tr_iter["image"].shape == (B, H, W, C), "Wrong training set size"
|
||||
assert batch_tr_iter["label"].shape == (batch_size,), "Wrong training set size"
|
||||
|
||||
batch_test_iter = next(test_iter)
|
||||
assert batch_test_iter["image"].shape == (B, H, W, C), "Wrong training set size"
|
||||
assert batch_test_iter["label"].shape == (batch_size,), "Wrong training set size"
|
||||
Reference in New Issue
Block a user