mirror of
				https://github.com/ml-explore/mlx-examples.git
				synced 2025-10-31 19:18:09 +08:00 
			
		
		
		
	 f45a1ab83c
			
		
	
	f45a1ab83c
	
	
	
		
			
			* update a few examples to use compile * update mnist * add compile to vae and rename some stuff for simplicity * update reqs * use state in eval * GCN example with RNG + dropout * add a bit of prefetching
		
			
				
	
	
		
			55 lines
		
	
	
		
			1.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			55 lines
		
	
	
		
			1.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright © 2023-2024 Apple Inc.
 | |
| 
 | |
| from mlx.data.datasets import load_mnist
 | |
| 
 | |
| 
 | |
def mnist(batch_size, img_size, root=None):
    """Build batched MNIST train and test streams with mlx-data.

    Args:
        batch_size: Number of samples per batch.
        img_size: Target image size as ``(H, W)``; every image is resized to
            this height and width.
        root: Forwarded unchanged to ``load_mnist`` for both splits
            (presumably the dataset directory; ``None`` leaves the choice
            to mlx-data -- confirm against the mlx-data docs).

    Returns:
        A ``(tr_iter, test_iter)`` tuple of streams. Each yields samples
        whose ``"image"`` entry has been resized and converted to float32
        in [0, 1]. The training stream is additionally shuffled and
        prefetched.
    """
    # load train and test sets using mlx-data
    tr = load_mnist(root=root, train=True)
    test = load_mnist(root=root, train=False)

    # normalize to [0,1] (assumes the raw pixel values span 0..255)
    def normalize(x):
        return x.astype("float32") / 255.0

    # iterator over training set: shuffle the buffer before streaming so the
    # sample order is randomized, then prefetch to overlap loading with compute
    tr_iter = (
        tr.shuffle()
        .to_stream()
        .image_resize("image", h=img_size[0], w=img_size[1])
        .key_transform("image", normalize)
        .batch(batch_size)
        .prefetch(4, 4)
    )

    # iterator over test set: same preprocessing, original order, no prefetch
    test_iter = (
        test.to_stream()
        .image_resize("image", h=img_size[0], w=img_size[1])
        .key_transform("image", normalize)
        .batch(batch_size)
    )
    return tr_iter, test_iter
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    # Smoke test: build both streams and check the shape of the first batch
    # drawn from each one.
    batch_size = 32
    img_size = (64, 64)  # (H, W)

    tr_iter, test_iter = mnist(batch_size=batch_size, img_size=img_size)

    # Expected layout is (batch, height, width, channels) with one channel.
    B, H, W, C = batch_size, img_size[0], img_size[1], 1
    print(f"Batch size: {B}, Channels: {C}, Height: {H}, Width: {W}")

    train_batch = next(tr_iter)
    assert train_batch["image"].shape == (B, H, W, C), "Wrong training set size"
    assert train_batch["label"].shape == (B,), "Wrong training set size"

    # Use a distinct message so a failure points at the right split.
    test_batch = next(test_iter)
    assert test_batch["image"].shape == (B, H, W, C), "Wrong test set size"
    assert test_batch["label"].shape == (B,), "Wrong test set size"
 |