mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-16 02:08:55 +08:00
@@ -54,10 +54,10 @@ if __name__ == "__main__":
|
||||
batch_size = 256
|
||||
num_epochs = 10
|
||||
learning_rate = 1e-1
|
||||
dataset = "mnist"
|
||||
|
||||
# Load the data
|
||||
train_images, train_labels, test_images, test_labels = mnist.mnist()
|
||||
|
||||
train_images, train_labels, test_images, test_labels = getattr(mnist, dataset)()
|
||||
# Load the model
|
||||
key, subkey = jax.random.split(jax.random.PRNGKey(seed))
|
||||
params = init_model(
|
||||
|
||||
Reference in New Issue
Block a user