| 
									
										
										
										
											2024-02-07 05:02:27 +01:00
										 |  |  | # Copyright © 2023-2024 Apple Inc. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from mlx.data.datasets import load_mnist | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def mnist(batch_size, img_size, root=None): | 
					
						
							|  |  |  |     # load train and test sets using mlx-data | 
					
						
							|  |  |  |     load_fn = load_mnist | 
					
						
							|  |  |  |     tr = load_fn(root=root, train=True) | 
					
						
							|  |  |  |     test = load_fn(root=root, train=False) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # number of image channels is 1 for MNIST | 
					
						
							|  |  |  |     num_img_channels = 1 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # normalize to [0,1] | 
					
						
							|  |  |  |     def normalize(x): | 
					
						
							|  |  |  |         return x.astype("float32") / 255.0 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # iterator over training set | 
					
						
							|  |  |  |     tr_iter = ( | 
					
						
							|  |  |  |         tr.shuffle() | 
					
						
							|  |  |  |         .to_stream() | 
					
						
							|  |  |  |         .image_resize("image", h=img_size[0], w=img_size[1]) | 
					
						
							|  |  |  |         .key_transform("image", normalize) | 
					
						
							|  |  |  |         .batch(batch_size) | 
					
						
							| 
									
										
										
										
											2024-02-08 13:00:41 -08:00
										 |  |  |         .prefetch(4, 4) | 
					
						
							| 
									
										
										
										
											2024-02-07 05:02:27 +01:00
										 |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # iterator over test set | 
					
						
							|  |  |  |     test_iter = ( | 
					
						
							|  |  |  |         test.to_stream() | 
					
						
							|  |  |  |         .image_resize("image", h=img_size[0], w=img_size[1]) | 
					
						
							|  |  |  |         .key_transform("image", normalize) | 
					
						
							|  |  |  |         .batch(batch_size) | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     return tr_iter, test_iter | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | if __name__ == "__main__": | 
					
						
							|  |  |  |     batch_size = 32 | 
					
						
							|  |  |  |     img_size = (64, 64)  # (H, W) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     tr_iter, test_iter = mnist(batch_size=batch_size, img_size=img_size) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     B, H, W, C = batch_size, img_size[0], img_size[1], 1 | 
					
						
							|  |  |  |     print(f"Batch size: {B}, Channels: {C}, Height: {H}, Width: {W}") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     batch_tr_iter = next(tr_iter) | 
					
						
							|  |  |  |     assert batch_tr_iter["image"].shape == (B, H, W, C), "Wrong training set size" | 
					
						
							|  |  |  |     assert batch_tr_iter["label"].shape == (batch_size,), "Wrong training set size" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     batch_test_iter = next(test_iter) | 
					
						
							|  |  |  |     assert batch_test_iter["image"].shape == (B, H, W, C), "Wrong training set size" | 
					
						
							|  |  |  |     assert batch_test_iter["label"].shape == (batch_size,), "Wrong training set size" |