From 0371d90ccb88c51f3565f7e922d174792658203b Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 23 Dec 2023 16:34:45 +0100 Subject: [PATCH] fashion-mnist example (#180) * fashion mnist example * fix from review --- mnist/README.md | 8 +++++++- mnist/jax_main.py | 4 ++-- mnist/main.py | 15 ++++++++++++--- mnist/mnist.py | 19 +++++++++++++++---- mnist/torch_main.py | 11 ++++++++++- 5 files changed, 46 insertions(+), 11 deletions(-) diff --git a/mnist/README.md b/mnist/README.md index d80c7f23..6d548ec3 100644 --- a/mnist/README.md +++ b/mnist/README.md @@ -14,10 +14,16 @@ Run the example with: python main.py ``` -By default the example runs on the CPU. To run on the GPU, use: +By default, the example runs on the CPU. To run on the GPU, use: ``` python main.py --gpu ``` +For a full list of options run: + +``` +python main.py --help +``` + To run the PyTorch or Jax examples install the respective framework. diff --git a/mnist/jax_main.py b/mnist/jax_main.py index 088a4993..9d1b30cf 100644 --- a/mnist/jax_main.py +++ b/mnist/jax_main.py @@ -54,10 +54,10 @@ if __name__ == "__main__": batch_size = 256 num_epochs = 10 learning_rate = 1e-1 + dataset = "mnist" # Load the data - train_images, train_labels, test_images, test_labels = mnist.mnist() - + train_images, train_labels, test_images, test_labels = getattr(mnist, dataset)() # Load the model key, subkey = jax.random.split(jax.random.PRNGKey(seed)) params = init_model( diff --git a/mnist/main.py b/mnist/main.py index b921261f..47092041 100644 --- a/mnist/main.py +++ b/mnist/main.py @@ -45,7 +45,7 @@ def batch_iterate(batch_size, X, y): yield X[ids], y[ids] -def main(): +def main(args): seed = 0 num_layers = 2 hidden_dim = 32 @@ -57,7 +57,9 @@ def main(): np.random.seed(seed) # Load the data - train_images, train_labels, test_images, test_labels = map(mx.array, mnist.mnist()) + train_images, train_labels, test_images, test_labels = map( + mx.array, getattr(mnist, args.dataset)() + ) # Load the model model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes) @@ -83,7 +85,14 @@ def main(): if __name__ == "__main__": parser = argparse.ArgumentParser("Train a simple MLP on MNIST with MLX.") parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.") + parser.add_argument( + "--dataset", + type=str, + default="mnist", + choices=["mnist", "fashion_mnist"], + help="The dataset to use.", + ) args = parser.parse_args() if not args.gpu: mx.set_default_device(mx.cpu) - main() + main(args) diff --git a/mnist/mnist.py b/mnist/mnist.py index b2e5c510..1a74fdd7 100644 --- a/mnist/mnist.py +++ b/mnist/mnist.py @@ -8,7 +8,9 @@ from urllib import request import numpy as np -def mnist(save_dir="/tmp"): +def mnist( + save_dir="/tmp", base_url="http://yann.lecun.com/exdb/mnist/", filename="mnist.pkl" +): """ Load the MNIST dataset in 4 tensors: train images, train labels, test images, and test labels. @@ -20,7 +22,6 @@ def mnist(save_dir="/tmp"): """ def download_and_save(save_file): - base_url = "http://yann.lecun.com/exdb/mnist/" filename = [ ["training_images", "train-images-idx3-ubyte.gz"], ["test_images", "t10k-images-idx3-ubyte.gz"], @@ -45,13 +46,15 @@ def mnist(save_dir="/tmp"): with open(save_file, "wb") as f: pickle.dump(mnist, f) - save_file = os.path.join(save_dir, "mnist.pkl") + save_file = os.path.join(save_dir, filename) if not os.path.exists(save_file): download_and_save(save_file) with open(save_file, "rb") as f: mnist = pickle.load(f) - preproc = lambda x: x.astype(np.float32) / 255.0 + def preproc(x): + return x.astype(np.float32) / 255.0 + mnist["training_images"] = preproc(mnist["training_images"]) mnist["test_images"] = preproc(mnist["test_images"]) return ( @@ -62,6 +65,14 @@ def mnist(save_dir="/tmp"): ) +def fashion_mnist(save_dir="/tmp"): + return mnist( + save_dir, + base_url="http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/", + filename="fashion_mnist.pkl", + ) + + if __name__ == "__main__": train_x, train_y, test_x, test_y = mnist() assert train_x.shape == (60000, 28 * 28), "Wrong training set size" diff --git a/mnist/torch_main.py b/mnist/torch_main.py index a6c0380a..befde04d 100644 --- a/mnist/torch_main.py +++ b/mnist/torch_main.py @@ -49,6 +49,13 @@ def batch_iterate(batch_size, X, y, device): if __name__ == "__main__": parser = argparse.ArgumentParser("Train a simple MLP on MNIST with PyTorch.") parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.") + parser.add_argument( + "--dataset", + type=str, + default="mnist", + choices=["mnist", "fashion_mnist"], + help="The dataset to use.", + ) args = parser.parse_args() if not args.gpu: @@ -71,7 +78,9 @@ if __name__ == "__main__": else: return torch.from_numpy(x.astype(int)).to(device) - train_images, train_labels, test_images, test_labels = map(to_tensor, mnist.mnist()) + train_images, train_labels, test_images, test_labels = map( + to_tensor, getattr(mnist, args.dataset)() + ) # Load the model model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes).to(device)