a few examples
mnist/README.md (new file, 18 lines)
@@ -0,0 +1,18 @@
# MNIST

This example shows how to run some simple models on MNIST. The only
dependency is MLX.

Run the example with:

```
python main.py
```

By default the example runs on the CPU. To run on the GPU, use:

```
python main.py --gpu
```

To run the PyTorch or JAX examples, install the respective framework.
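For instance, assuming a standard pip setup, the framework-specific scripts can then be run directly:

```
pip install torch        # or: pip install jax
python torch_main.py     # or: python jax_main.py
```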
mnist/jax_main.py (new file, 80 lines)
@@ -0,0 +1,80 @@
import jax
import jax.numpy as jnp
import functools
import time

import mnist


def init_model(key, num_layers, input_dim, hidden_dim, output_dim):
    params = []
    layer_sizes = [hidden_dim] * num_layers
    for idim, odim in zip([input_dim] + layer_sizes, layer_sizes + [output_dim]):
        key, wk = jax.random.split(key, 2)
        W = 1e-2 * jax.random.normal(wk, (idim, odim))
        b = jnp.zeros((odim,))
        params.append((W, b))
    return params


def feed_forward(params, X):
    for W, b in params[:-1]:
        X = jnp.maximum(X @ W + b, 0)
    W, b = params[-1]
    return X @ W + b


def loss_fn(params, X, y):
    logits = feed_forward(params, X)
    logits = jax.nn.log_softmax(logits, 1)
    return -jnp.mean(logits[jnp.arange(y.size), y])


@jax.jit
def eval_fn(params, X, y):
    logits = feed_forward(params, X)
    return jnp.mean(jnp.argmax(logits, axis=1) == y)


def batch_iterate(key, batch_size, X, y):
    perm = jax.random.permutation(key, y.size)
    for s in range(0, y.size, batch_size):
        ids = perm[s : s + batch_size]
        yield X[ids], y[ids]


if __name__ == "__main__":
    seed = 0
    num_layers = 2
    hidden_dim = 32
    num_classes = 10
    batch_size = 256
    num_epochs = 10
    learning_rate = 1e-1

    # Load the data
    train_images, train_labels, test_images, test_labels = mnist.mnist()

    # Load the model
    key, subkey = jax.random.split(jax.random.PRNGKey(seed))
    params = init_model(
        subkey, num_layers, train_images.shape[-1], hidden_dim, num_classes
    )

    loss_and_grad_fn = jax.jit(jax.value_and_grad(loss_fn))
    update_fn = jax.jit(
        functools.partial(jax.tree_map, lambda p, g: p - learning_rate * g)
    )

    for e in range(num_epochs):
        tic = time.perf_counter()
        key, subkey = jax.random.split(key)
        for X, y in batch_iterate(subkey, batch_size, train_images, train_labels):
            loss, grads = loss_and_grad_fn(params, X, y)
            params = update_fn(params, grads)
        accuracy = eval_fn(params, test_images, test_labels)
        toc = time.perf_counter()
        print(
            f"Epoch {e}: Test accuracy {accuracy.item():.3f},"
            f" Time {toc - tic:.3f} (s)"
        )
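The `update_fn` above treats the parameter list as a pytree and applies plain SGD to every leaf. A minimal standalone sketch of the same idea, using `jax.tree_util.tree_map` (the stable spelling of `jax.tree_map` in recent JAX releases) on a toy one-layer pytree:

```
import jax.numpy as jnp
from jax import tree_util

lr = 1e-1
params = [(jnp.ones((2, 2)), jnp.zeros((2,)))]          # one (W, b) pair
grads = [(jnp.full((2, 2), 0.5), jnp.full((2,), 0.5))]  # matching gradients

# Subtract lr * g from every leaf while preserving the pytree structure.
new_params = tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
print(new_params[0][0])  # every entry of W becomes 1 - 0.1 * 0.5 = 0.95
```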
mnist/main.py (new file, 88 lines)
@@ -0,0 +1,88 @@
import argparse
import time

import numpy as np

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

import mnist


class MLP(nn.Module):
    """A simple MLP."""

    def __init__(
        self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
    ):
        super().__init__()
        layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
        self.layers = [
            nn.Linear(idim, odim)
            for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
        ]

    def __call__(self, x):
        for l in self.layers[:-1]:
            x = mx.maximum(l(x), 0.0)
        return self.layers[-1](x)


def loss_fn(model, X, y):
    return mx.mean(nn.losses.cross_entropy(model(X), y))


def eval_fn(model, X, y):
    return mx.mean(mx.argmax(model(X), axis=1) == y)


def batch_iterate(batch_size, X, y):
    perm = mx.array(np.random.permutation(y.size))
    for s in range(0, y.size, batch_size):
        ids = perm[s : s + batch_size]
        yield X[ids], y[ids]


def main():
    seed = 0
    num_layers = 2
    hidden_dim = 32
    num_classes = 10
    batch_size = 256
    num_epochs = 10
    learning_rate = 1e-1

    np.random.seed(seed)

    # Load the data
    train_images, train_labels, test_images, test_labels = map(mx.array, mnist.mnist())

    # Load the model
    model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes)
    mx.eval(model.parameters())

    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
    optimizer = optim.SGD(learning_rate=learning_rate)

    for e in range(num_epochs):
        tic = time.perf_counter()
        for X, y in batch_iterate(batch_size, train_images, train_labels):
            loss, grads = loss_and_grad_fn(model, X, y)
            optimizer.update(model, grads)
            mx.eval(model.parameters(), optimizer.state)
        accuracy = eval_fn(model, test_images, test_labels)
        toc = time.perf_counter()
        print(
            f"Epoch {e}: Test accuracy {accuracy.item():.3f},"
            f" Time {toc - tic:.3f} (s)"
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Train a simple MLP on MNIST with MLX.")
    parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
    args = parser.parse_args()
    if not args.gpu:
        mx.set_default_device(mx.cpu)
    main()
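A note on the `mx.eval` calls in `main.py`: MLX evaluates arrays lazily, so operations only build a computation graph until an evaluation is forced; the call after each optimizer update materializes the new parameters and optimizer state. A minimal sketch of the same behavior on a standalone array (assuming only `mlx.core` is available):

```
import mlx.core as mx

a = mx.array([1.0, 2.0, 3.0])
b = (a * 2).sum()  # builds a lazy computation; nothing runs yet
mx.eval(b)         # forces evaluation, like the per-step eval in the training loop
print(b.item())    # 12.0
```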
mnist/mnist.py (new file, 67 lines)
@@ -0,0 +1,67 @@
import gzip
import numpy as np
import os
import pickle
from urllib import request


def mnist(save_dir="/tmp"):
    """
    Load the MNIST dataset in 4 tensors: train images, train labels,
    test images, and test labels.

    Checks `save_dir` for already downloaded data, otherwise downloads.

    Download code modified from:
    https://github.com/hsjeong5/MNIST-for-Numpy
    """

    def download_and_save(save_file):
        base_url = "http://yann.lecun.com/exdb/mnist/"
        filename = [
            ["training_images", "train-images-idx3-ubyte.gz"],
            ["test_images", "t10k-images-idx3-ubyte.gz"],
            ["training_labels", "train-labels-idx1-ubyte.gz"],
            ["test_labels", "t10k-labels-idx1-ubyte.gz"],
        ]

        mnist = {}
        for name in filename:
            out_file = os.path.join("/tmp", name[1])
            request.urlretrieve(base_url + name[1], out_file)
        for name in filename[:2]:
            out_file = os.path.join("/tmp", name[1])
            with gzip.open(out_file, "rb") as f:
                mnist[name[0]] = np.frombuffer(f.read(), np.uint8, offset=16).reshape(
                    -1, 28 * 28
                )
        for name in filename[-2:]:
            out_file = os.path.join("/tmp", name[1])
            with gzip.open(out_file, "rb") as f:
                mnist[name[0]] = np.frombuffer(f.read(), np.uint8, offset=8)
        with open(save_file, "wb") as f:
            pickle.dump(mnist, f)

    save_file = os.path.join(save_dir, "mnist.pkl")
    if not os.path.exists(save_file):
        download_and_save(save_file)
    with open(save_file, "rb") as f:
        mnist = pickle.load(f)

    preproc = lambda x: x.astype(np.float32) / 255.0
    mnist["training_images"] = preproc(mnist["training_images"])
    mnist["test_images"] = preproc(mnist["test_images"])
    return (
        mnist["training_images"],
        mnist["training_labels"].astype(np.uint32),
        mnist["test_images"],
        mnist["test_labels"].astype(np.uint32),
    )


if __name__ == "__main__":
    train_x, train_y, test_x, test_y = mnist()
    assert train_x.shape == (60000, 28 * 28), "Wrong training set size"
    assert train_y.shape == (60000,), "Wrong training set size"
    assert test_x.shape == (10000, 28 * 28), "Wrong test set size"
    assert test_y.shape == (10000,), "Wrong test set size"
mnist/torch_main.py (new file, 88 lines)
@@ -0,0 +1,88 @@
import argparse
import torch
import time

import mnist


class MLP(torch.nn.Module):
    def __init__(self, num_layers, input_dim, hidden_dim, output_dim):
        super().__init__()
        layer_sizes = [hidden_dim] * num_layers
        self.layers = torch.nn.ModuleList(
            [
                torch.nn.Linear(idim, odim)
                for idim, odim in zip(
                    [input_dim] + layer_sizes, layer_sizes + [output_dim]
                )
            ]
        )

    def forward(self, x):
        x = self.layers[0](x)
        for l in self.layers[1:]:
            x = l(x.relu())
        return x


def loss_fn(model, X, y):
    logits = model(X)
    return torch.nn.functional.cross_entropy(logits, y)


@torch.no_grad()
def eval_fn(model, X, y):
    logits = model(X)
    return torch.mean((logits.argmax(-1) == y).float())


def batch_iterate(batch_size, X, y, device):
    perm = torch.randperm(len(y), device=device)
    for s in range(0, len(y), batch_size):
        ids = perm[s : s + batch_size]
        yield X[ids], y[ids]


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Train a simple MLP on MNIST with PyTorch.")
    parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
    args = parser.parse_args()

    if not args.gpu:
        torch.set_num_threads(1)
        device = "cpu"
    else:
        device = "mps"

    seed = 0
    num_layers = 2
    hidden_dim = 32
    num_classes = 10
    batch_size = 256
    num_epochs = 10
    learning_rate = 1e-1

    # Load the data
    def to_tensor(x):
        if x.dtype != "uint32":
            return torch.from_numpy(x).to(device)
        else:
            return torch.from_numpy(x.astype(int)).to(device)

    train_images, train_labels, test_images, test_labels = map(to_tensor, mnist.mnist())

    # Load the model
    model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes).to(device)
    opt = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.0)

    for e in range(num_epochs):
        tic = time.perf_counter()
        for X, y in batch_iterate(batch_size, train_images, train_labels, device):
            opt.zero_grad()
            loss_fn(model, X, y).backward()
            opt.step()
        accuracy = eval_fn(model, test_images, test_labels)
        toc = time.perf_counter()
        print(
            f"Epoch {e}: Test accuracy {accuracy.item():.3f},"
            f" Time {toc - tic:.3f} (s)"
        )