mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-15 17:58:54 +08:00
@@ -49,6 +49,13 @@ def batch_iterate(batch_size, X, y, device):
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser("Train a simple MLP on MNIST with PyTorch.")
|
||||
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
default="mnist",
|
||||
choices=["mnist", "fashion_mnist"],
|
||||
help="The dataset to use.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.gpu:
|
||||
@@ -71,7 +78,9 @@ if __name__ == "__main__":
|
||||
else:
|
||||
return torch.from_numpy(x.astype(int)).to(device)
|
||||
|
||||
train_images, train_labels, test_images, test_labels = map(to_tensor, mnist.mnist())
|
||||
train_images, train_labels, test_images, test_labels = map(
|
||||
to_tensor, getattr(mnist, args.dataset)()
|
||||
)
|
||||
|
||||
# Load the model
|
||||
model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes).to(device)
|
||||
|
||||
Reference in New Issue
Block a user