mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 12:49:50 +08:00
@@ -8,7 +8,9 @@ from urllib import request
|
||||
import numpy as np
|
||||
|
||||
|
||||
def mnist(save_dir="/tmp"):
|
||||
def mnist(
|
||||
save_dir="/tmp", base_url="http://yann.lecun.com/exdb/mnist/", filename="mnist.pkl"
|
||||
):
|
||||
"""
|
||||
Load the MNIST dataset in 4 tensors: train images, train labels,
|
||||
test images, and test labels.
|
||||
@@ -20,7 +22,6 @@ def mnist(save_dir="/tmp"):
|
||||
"""
|
||||
|
||||
def download_and_save(save_file):
|
||||
base_url = "http://yann.lecun.com/exdb/mnist/"
|
||||
filename = [
|
||||
["training_images", "train-images-idx3-ubyte.gz"],
|
||||
["test_images", "t10k-images-idx3-ubyte.gz"],
|
||||
@@ -45,13 +46,15 @@ def mnist(save_dir="/tmp"):
|
||||
with open(save_file, "wb") as f:
|
||||
pickle.dump(mnist, f)
|
||||
|
||||
save_file = os.path.join(save_dir, "mnist.pkl")
|
||||
save_file = os.path.join(save_dir, filename)
|
||||
if not os.path.exists(save_file):
|
||||
download_and_save(save_file)
|
||||
with open(save_file, "rb") as f:
|
||||
mnist = pickle.load(f)
|
||||
|
||||
preproc = lambda x: x.astype(np.float32) / 255.0
|
||||
def preproc(x):
|
||||
return x.astype(np.float32) / 255.0
|
||||
|
||||
mnist["training_images"] = preproc(mnist["training_images"])
|
||||
mnist["test_images"] = preproc(mnist["test_images"])
|
||||
return (
|
||||
@@ -62,6 +65,14 @@ def mnist(save_dir="/tmp"):
|
||||
)
|
||||
|
||||
|
||||
def fashion_mnist(save_dir="/tmp"):
|
||||
return mnist(
|
||||
save_dir,
|
||||
base_url="http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/",
|
||||
filename="fashion_mnist.pkl",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train_x, train_y, test_x, test_y = mnist()
|
||||
assert train_x.shape == (60000, 28 * 28), "Wrong training set size"
|
||||
|
Reference in New Issue
Block a user