fashion-mnist example (#180)

* fashion mnist example

* fix from review
This commit is contained in:
Kashif Rasul
2023-12-23 16:34:45 +01:00
committed by GitHub
parent 848f118ac5
commit 0371d90ccb
5 changed files with 46 additions and 11 deletions

View File

@@ -54,10 +54,10 @@ if __name__ == "__main__":
batch_size = 256
num_epochs = 10
learning_rate = 1e-1
dataset = "mnist"
# Load the data
train_images, train_labels, test_images, test_labels = mnist.mnist()
train_images, train_labels, test_images, test_labels = getattr(mnist, dataset)()
# Load the model
key, subkey = jax.random.split(jax.random.PRNGKey(seed))
params = init_model(