mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-11-03 13:08:08 +08:00
a few examples
This commit is contained in:
67
mnist/mnist.py
Normal file
67
mnist/mnist.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import gzip
|
||||
import numpy as np
|
||||
import os
|
||||
import pickle
|
||||
from urllib import request
|
||||
|
||||
|
||||
def mnist(save_dir="/tmp"):
|
||||
"""
|
||||
Load the MNIST dataset in 4 tensors: train images, train labels,
|
||||
test images, and test labels.
|
||||
|
||||
Checks `save_dir` for already downloaded data otherwise downloads.
|
||||
|
||||
Download code modified from:
|
||||
https://github.com/hsjeong5/MNIST-for-Numpy
|
||||
"""
|
||||
|
||||
def download_and_save(save_file):
|
||||
base_url = "http://yann.lecun.com/exdb/mnist/"
|
||||
filename = [
|
||||
["training_images", "train-images-idx3-ubyte.gz"],
|
||||
["test_images", "t10k-images-idx3-ubyte.gz"],
|
||||
["training_labels", "train-labels-idx1-ubyte.gz"],
|
||||
["test_labels", "t10k-labels-idx1-ubyte.gz"],
|
||||
]
|
||||
|
||||
mnist = {}
|
||||
for name in filename:
|
||||
out_file = os.path.join("/tmp", name[1])
|
||||
request.urlretrieve(base_url + name[1], out_file)
|
||||
for name in filename[:2]:
|
||||
out_file = os.path.join("/tmp", name[1])
|
||||
with gzip.open(out_file, "rb") as f:
|
||||
mnist[name[0]] = np.frombuffer(f.read(), np.uint8, offset=16).reshape(
|
||||
-1, 28 * 28
|
||||
)
|
||||
for name in filename[-2:]:
|
||||
out_file = os.path.join("/tmp", name[1])
|
||||
with gzip.open(out_file, "rb") as f:
|
||||
mnist[name[0]] = np.frombuffer(f.read(), np.uint8, offset=8)
|
||||
with open(save_file, "wb") as f:
|
||||
pickle.dump(mnist, f)
|
||||
|
||||
save_file = os.path.join(save_dir, "mnist.pkl")
|
||||
if not os.path.exists(save_file):
|
||||
download_and_save(save_file)
|
||||
with open(save_file, "rb") as f:
|
||||
mnist = pickle.load(f)
|
||||
|
||||
preproc = lambda x: x.astype(np.float32) / 255.0
|
||||
mnist["training_images"] = preproc(mnist["training_images"])
|
||||
mnist["test_images"] = preproc(mnist["test_images"])
|
||||
return (
|
||||
mnist["training_images"],
|
||||
mnist["training_labels"].astype(np.uint32),
|
||||
mnist["test_images"],
|
||||
mnist["test_labels"].astype(np.uint32),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train_x, train_y, test_x, test_y = mnist()
|
||||
assert train_x.shape == (60000, 28 * 28), "Wrong training set size"
|
||||
assert train_y.shape == (60000,), "Wrong training set size"
|
||||
assert test_x.shape == (10000, 28 * 28), "Wrong test set size"
|
||||
assert test_y.shape == (10000,), "Wrong test set size"
|
||||
Reference in New Issue
Block a user