mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 12:49:50 +08:00
@@ -45,7 +45,7 @@ def batch_iterate(batch_size, X, y):
|
||||
yield X[ids], y[ids]
|
||||
|
||||
|
||||
def main():
|
||||
def main(args):
|
||||
seed = 0
|
||||
num_layers = 2
|
||||
hidden_dim = 32
|
||||
@@ -57,7 +57,9 @@ def main():
|
||||
np.random.seed(seed)
|
||||
|
||||
# Load the data
|
||||
train_images, train_labels, test_images, test_labels = map(mx.array, mnist.mnist())
|
||||
train_images, train_labels, test_images, test_labels = map(
|
||||
mx.array, getattr(mnist, args.dataset)()
|
||||
)
|
||||
|
||||
# Load the model
|
||||
model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes)
|
||||
@@ -83,7 +85,14 @@ def main():
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser("Train a simple MLP on MNIST with MLX.")
|
||||
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
default="mnist",
|
||||
choices=["mnist", "fashion_mnist"],
|
||||
help="The dataset to use.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
if not args.gpu:
|
||||
mx.set_default_device(mx.cpu)
|
||||
main()
|
||||
main(args)
|
||||
|
Reference in New Issue
Block a user