# Copyright © 2023 Apple Inc.

import argparse
import time
from functools import partial

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np

import mnist
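
# This example trains a small multi-layer perceptron (MLP) classifier on
# MNIST or Fashion-MNIST using MLX.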
					
						
class MLP(nn.Module):
    """A simple MLP."""

    def __init__(
        self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
    ):
        super().__init__()
        # Layer widths: input -> hidden (num_layers times) -> output.
        layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
        self.layers = [
            nn.Linear(idim, odim)
            for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
        ]

    def __call__(self, x):
        # ReLU between layers; the last layer returns raw logits.
        for l in self.layers[:-1]:
            x = nn.relu(l(x))
        return self.layers[-1](x)
					
						
def loss_fn(model, X, y):
    # Mean cross-entropy between the model's logits and the integer labels.
    return nn.losses.cross_entropy(model(X), y, reduction="mean")
					
						
def batch_iterate(batch_size, X, y):
    # Shuffle once per pass, then yield mini-batches of examples; the final
    # batch may be smaller when y.size is not a multiple of batch_size.
    perm = mx.array(np.random.permutation(y.size))
    for s in range(0, y.size, batch_size):
        ids = perm[s : s + batch_size]
        yield X[ids], y[ids]
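

# End-to-end training: load the data, initialize the model, compile the
# train and eval steps, and run plain SGD for a fixed number of epochs.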
def main(args):
    seed = 0
    num_layers = 2
    hidden_dim = 32
    num_classes = 10
    batch_size = 256
    num_epochs = 10
    learning_rate = 1e-1

    np.random.seed(seed)

    # Load the data
    train_images, train_labels, test_images, test_labels = map(
        mx.array, getattr(mnist, args.dataset)()
    )

    # Load the model
    model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes)
    # MLX is lazy: evaluate the parameters so they are materialized up front.
    mx.eval(model.parameters())

    optimizer = optim.SGD(learning_rate=learning_rate)
    # Computes both the loss and its gradients w.r.t. the model's parameters.
    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

    # Compile the training step; passing model.state as both inputs and
    # outputs lets the compiled function read the current parameters and
    # propagate the optimizer's updates across calls.
    @partial(mx.compile, inputs=model.state, outputs=model.state)
    def step(X, y):
        loss, grads = loss_and_grad_fn(model, X, y)
        optimizer.update(model, grads)
        return loss

    # Evaluation only reads the parameters, so they are captured as inputs.
    @partial(mx.compile, inputs=model.state)
    def eval_fn(X, y):
        return mx.mean(mx.argmax(model(X), axis=1) == y)

    for e in range(num_epochs):
        tic = time.perf_counter()
        for X, y in batch_iterate(batch_size, train_images, train_labels):
            step(X, y)
            # Evaluate the state to actually run the (lazy) computation.
            mx.eval(model.state)
        accuracy = eval_fn(test_images, test_labels)
        toc = time.perf_counter()
        print(
            f"Epoch {e}: Test accuracy {accuracy.item():.3f},"
            f" Time {toc - tic:.3f} (s)"
        )
					
						
							|  |  |  | if __name__ == "__main__": | 
					
						
							|  |  |  |     parser = argparse.ArgumentParser("Train a simple MLP on MNIST with MLX.") | 
					
						
							|  |  |  |     parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.") | 
					
						
							| 
									
										
										
										
											2023-12-23 16:34:45 +01:00
										 |  |  |     parser.add_argument( | 
					
						
							|  |  |  |         "--dataset", | 
					
						
							|  |  |  |         type=str, | 
					
						
							|  |  |  |         default="mnist", | 
					
						
							|  |  |  |         choices=["mnist", "fashion_mnist"], | 
					
						
							|  |  |  |         help="The dataset to use.", | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2023-11-29 08:17:26 -08:00
										 |  |  |     args = parser.parse_args() | 
					
						
							|  |  |  |     if not args.gpu: | 
					
						
							|  |  |  |         mx.set_default_device(mx.cpu) | 
					
						
							| 
									
										
										
										
											2023-12-23 16:34:45 +01:00
										 |  |  |     main(args) |
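
# Example invocation (the script name is an assumption; substitute this
# file's actual name):
#   python main.py --dataset fashion_mnist --gpu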