mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-11-01 03:28:08 +08:00)
# Copyright © 2023 Apple Inc.

import jax
import jax.numpy as jnp
import functools
import time

import mnist


def init_model(key, num_layers, input_dim, hidden_dim, output_dim):
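    """Initialize MLP parameters as a list of (W, b) tuples: `num_layers`
    hidden layers of width `hidden_dim`, plus a final projection to
    `output_dim`. Weights are small random normals; biases start at zero."""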
    params = []
    layer_sizes = [hidden_dim] * num_layers
    for idim, odim in zip([input_dim] + layer_sizes, layer_sizes + [output_dim]):
        key, wk = jax.random.split(key, 2)
        W = 1e-2 * jax.random.normal(wk, (idim, odim))
        b = jnp.zeros((odim,))
        params.append((W, b))
    return params


def feed_forward(params, X):
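    """Apply each hidden layer with a ReLU activation, then the final linear
    layer without an activation, returning the raw logits."""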
    for W, b in params[:-1]:
        X = jnp.maximum(X @ W + b, 0)
    W, b = params[-1]
    return X @ W + b


def loss_fn(params, X, y):
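    """Cross-entropy loss: mean negative log-likelihood of the true labels
    under the log-softmax of the logits."""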
    logits = feed_forward(params, X)
    logits = jax.nn.log_softmax(logits, 1)
    return -jnp.mean(logits[jnp.arange(y.size), y])


@jax.jit
def eval_fn(params, X, y):
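    """Classification accuracy: the fraction of examples whose argmax logit
    matches the label."""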
    logits = feed_forward(params, X)
    return jnp.mean(jnp.argmax(logits, axis=1) == y)


def batch_iterate(key, batch_size, X, y):
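    """Yield shuffled (X, y) minibatches; the final batch may be smaller."""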
    perm = jax.random.permutation(key, y.size)
    for s in range(0, y.size, batch_size):
        ids = perm[s : s + batch_size]
        yield X[ids], y[ids]


if __name__ == "__main__":
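    # Hyperparameters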
    seed = 0
    num_layers = 2
    hidden_dim = 32
    num_classes = 10
    batch_size = 256
    num_epochs = 10
    learning_rate = 1e-1

    # Load the data
    train_images, train_labels, test_images, test_labels = mnist.mnist()

    # Initialize the model
    key, subkey = jax.random.split(jax.random.PRNGKey(seed))
    params = init_model(
        subkey, num_layers, train_images.shape[-1], hidden_dim, num_classes
    )

    # JIT-compiled loss-and-gradient function, and a plain SGD update applied
    # leaf-wise over the parameter pytree
    loss_and_grad_fn = jax.jit(jax.value_and_grad(loss_fn))
    update_fn = jax.jit(
        functools.partial(jax.tree_util.tree_map, lambda p, g: p - learning_rate * g)
    )

    for e in range(num_epochs):
        tic = time.perf_counter()
        # Reshuffle the training set each epoch with a fresh PRNG key
        key, subkey = jax.random.split(key)
        for X, y in batch_iterate(subkey, batch_size, train_images, train_labels):
            loss, grads = loss_and_grad_fn(params, X, y)
            params = update_fn(params, grads)
        # Evaluate on the held-out test set at the end of every epoch
        accuracy = eval_fn(params, test_images, test_labels)
        toc = time.perf_counter()
        print(
            f"Epoch {e}: Test accuracy {accuracy.item():.3f},"
            f" Time {toc - tic:.3f} (s)"
        )
