fashion-mnist example (#180)

* fashion mnist example

* fix from review
This commit is contained in:
Kashif Rasul
2023-12-23 16:34:45 +01:00
committed by GitHub
parent 848f118ac5
commit 0371d90ccb
5 changed files with 46 additions and 11 deletions

View File

@@ -45,7 +45,7 @@ def batch_iterate(batch_size, X, y):
yield X[ids], y[ids]
def main():
def main(args):
seed = 0
num_layers = 2
hidden_dim = 32
@@ -57,7 +57,9 @@ def main():
np.random.seed(seed)
# Load the data
train_images, train_labels, test_images, test_labels = map(mx.array, mnist.mnist())
train_images, train_labels, test_images, test_labels = map(
mx.array, getattr(mnist, args.dataset)()
)
# Load the model
model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes)
@@ -83,7 +85,14 @@ def main():
if __name__ == "__main__":
parser = argparse.ArgumentParser("Train a simple MLP on MNIST with MLX.")
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
parser.add_argument(
"--dataset",
type=str,
default="mnist",
choices=["mnist", "fashion_mnist"],
help="The dataset to use.",
)
args = parser.parse_args()
if not args.gpu:
mx.set_default_device(mx.cpu)
main()
main(args)