fashion-mnist example (#180)

* fashion mnist example

* fix from review
This commit is contained in:
Kashif Rasul
2023-12-23 16:34:45 +01:00
committed by GitHub
parent 848f118ac5
commit 0371d90ccb
5 changed files with 46 additions and 11 deletions

View File

@@ -49,6 +49,13 @@ def batch_iterate(batch_size, X, y, device):
if __name__ == "__main__":
parser = argparse.ArgumentParser("Train a simple MLP on MNIST with PyTorch.")
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
parser.add_argument(
"--dataset",
type=str,
default="mnist",
choices=["mnist", "fashion_mnist"],
help="The dataset to use.",
)
args = parser.parse_args()
if not args.gpu:
@@ -71,7 +78,9 @@ if __name__ == "__main__":
else:
return torch.from_numpy(x.astype(int)).to(device)
train_images, train_labels, test_images, test_labels = map(to_tensor, mnist.mnist())
train_images, train_labels, test_images, test_labels = map(
to_tensor, getattr(mnist, args.dataset)()
)
# Load the model
model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes).to(device)