mirror of
				https://github.com/ml-explore/mlx-examples.git
				synced 2025-11-01 03:28:08 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			70 lines
		
	
	
		
			2.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			70 lines
		
	
	
		
			2.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright © 2023 Apple Inc.
 | |
| 
 | |
| import gzip
 | |
| import numpy as np
 | |
| import os
 | |
| import pickle
 | |
| from urllib import request
 | |
| 
 | |
| 
 | |
def mnist(save_dir="/tmp", base_url="http://yann.lecun.com/exdb/mnist/"):
    """
    Load the MNIST dataset in 4 tensors: train images, train labels,
    test images, and test labels.

    Checks `save_dir` for an already built `mnist.pkl` cache. If it is
    missing, downloads the four gzipped archives into `save_dir`, parses
    them and writes the cache before loading it.

    Args:
        save_dir: Directory that holds the downloaded archives and the
            `mnist.pkl` cache. It is created when it does not exist.
        base_url: URL prefix to which each archive file name is appended.
            Only used when the cache has to be built.

    Returns:
        A tuple `(train_images, train_labels, test_images, test_labels)`.
        Images are float32 arrays of shape `(N, 784)` scaled to [0, 1];
        labels are 1-D uint32 arrays.

    Download code modified from:
      https://github.com/hsjeong5/MNIST-for-Numpy
    """

    # (cache key, archive name, header bytes to skip, flattened row width).
    # Image archives carry a 16-byte header, label archives an 8-byte one;
    # labels stay 1-D, hence no row width.
    archives = [
        ("training_images", "train-images-idx3-ubyte.gz", 16, 28 * 28),
        ("test_images", "t10k-images-idx3-ubyte.gz", 16, 28 * 28),
        ("training_labels", "train-labels-idx1-ubyte.gz", 8, None),
        ("test_labels", "t10k-labels-idx1-ubyte.gz", 8, None),
    ]

    def download_and_save(save_file):
        """Download every archive, parse it, and pickle the arrays."""
        # An empty save_dir means "current directory"; makedirs("") raises.
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

        data = {}
        for key, fname, header_len, row_width in archives:
            # Keep the raw archives next to the cache instead of a
            # hard-coded "/tmp", so `save_dir` is honoured everywhere.
            out_file = os.path.join(save_dir, fname)
            request.urlretrieve(base_url + fname, out_file)
            with gzip.open(out_file, "rb") as f:
                values = np.frombuffer(f.read(), np.uint8, offset=header_len)
            if row_width is not None:
                values = values.reshape(-1, row_width)
            data[key] = values

        # Write to a sibling file and rename, so an interrupted dump never
        # leaves a truncated cache that later calls would try to load.
        tmp_file = save_file + ".tmp"
        with open(tmp_file, "wb") as f:
            pickle.dump(data, f)
        os.replace(tmp_file, save_file)

    save_file = os.path.join(save_dir, "mnist.pkl")
    if not os.path.exists(save_file):
        download_and_save(save_file)

    # NOTE: unpickling can execute arbitrary code. Only point `save_dir` at
    # a location whose `mnist.pkl` you trust.
    with open(save_file, "rb") as f:
        data = pickle.load(f)

    def to_unit_float(images):
        """Convert uint8 pixel values to float32 in [0, 1]."""
        return images.astype(np.float32) / 255.0

    return (
        to_unit_float(data["training_images"]),
        data["training_labels"].astype(np.uint32),
        to_unit_float(data["test_images"]),
        data["test_labels"].astype(np.uint32),
    )
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    # Smoke test: load the dataset with the default settings and confirm
    # each split has the expected shape.
    train_x, train_y, test_x, test_y = mnist()
    flat_pixels = 28 * 28
    shape_checks = (
        (train_x, (60000, flat_pixels), "Wrong training set size"),
        (train_y, (60000,), "Wrong training set size"),
        (test_x, (10000, flat_pixels), "Wrong test set size"),
        (test_y, (10000,), "Wrong test set size"),
    )
    for split, wanted_shape, failure_message in shape_checks:
        assert split.shape == wanted_shape, failure_message
 | 
