Merge pull request #77 from SarthakYadav/main
Added CIFAR-10 + ResNet example
Commit: 09fff84a85
cifar/README.md (new file, 51 lines)

# CIFAR and ResNets

An example of training a ResNet on CIFAR-10 with MLX. Several ResNet
configurations in accordance with the original
[paper](https://arxiv.org/abs/1512.03385) are available. The example also
illustrates how to use [MLX Data](https://github.com/ml-explore/mlx-data) to
load the dataset.
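
As a minimal sketch of the pattern used in `dataset.py` (which adds
augmentation and normalization on top), loading and iterating the training set
looks roughly like:

```
from mlx.data.datasets import load_cifar10

train = load_cifar10()  # fetches CIFAR-10 into the data root on first use
train_iter = train.shuffle().to_stream().batch(256)

for batch in train_iter:
    images, labels = batch["image"], batch["label"]
    print(images.shape, labels.shape)  # e.g. (256, 32, 32, 3) and (256,)
```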

## Pre-requisites

Install the dependencies:

```
pip install -r requirements.txt
```

## Running the example

Run the example with:

```
python main.py
```

By default the example runs on the GPU. To run on the CPU, use:

```
python main.py --cpu
```

For all available options, run:

```
python main.py --help
```

## Results

After training with the default `resnet20` architecture for 100 epochs, you
should see the following results:

```
Epoch: 99 | avg. Train loss 0.320 | avg. Train acc 0.888 | Throughput: 416.77 images/sec
Epoch: 99 | Test acc 0.807
```

Note this was run on an M1 MacBook Pro with 16GB RAM.

At the time of writing, `mlx` doesn't have built-in learning rate schedules or
a `BatchNorm` layer. We intend to update this example once these features are
added.
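
In the meantime, a simple schedule can be emulated in `main.py` by adjusting
the optimizer's learning rate between epochs. A minimal sketch, assuming the
optimizer's `learning_rate` attribute can be reassigned directly (this helper
is not part of the example):

```
# Hypothetical step schedule: decay the learning rate by 10x at epochs 50 and
# 75. Assumes optimizer.learning_rate can be reassigned between epochs.
def step_lr(base_lr, epoch):
    if epoch >= 75:
        return base_lr * 0.01
    if epoch >= 50:
        return base_lr * 0.1
    return base_lr

for epoch in range(args.epochs):
    optimizer.learning_rate = step_lr(args.lr, epoch)
    tr_loss, tr_acc, throughput = train_epoch(model, train_data, optimizer, epoch)
```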
cifar/dataset.py (new file, 30 lines)

import mlx.core as mx
from mlx.data.datasets import load_cifar10


def get_cifar10(batch_size, root=None):
    tr = load_cifar10(root=root)

    # Per-channel normalization statistics.
    mean = mx.array([0.485, 0.456, 0.406]).reshape((1, 1, 3))
    std = mx.array([0.229, 0.224, 0.225]).reshape((1, 1, 3))

    def normalize(x):
        x = x.astype("float32") / 255.0
        return (x - mean) / std

    # Standard CIFAR augmentation: random horizontal flip, pad each spatial
    # dimension by 4 pixels, then take a random 32x32 crop.
    tr_iter = (
        tr.shuffle()
        .to_stream()
        .image_random_h_flip("image", prob=0.5)
        .pad("image", 0, 4, 4, 0.0)
        .pad("image", 1, 4, 4, 0.0)
        .image_random_crop("image", 32, 32)
        .key_transform("image", normalize)
        .batch(batch_size)
    )

    test = load_cifar10(root=root, train=False)
    test_iter = test.to_stream().key_transform("image", normalize).batch(batch_size)

    return tr_iter, test_iter
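
A brief usage sketch for the loader above; `main.py` follows the same pattern,
wrapping each batch in `mx.array` before feeding the model:

```
# Sketch: iterate one epoch of training batches, then reset the stream.
train_iter, test_iter = get_cifar10(batch_size=256)
for batch in train_iter:
    x, y = batch["image"], batch["label"]  # normalized float32 images, integer labels
    ...
train_iter.reset()
```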

cifar/main.py (new file, 120 lines)

import argparse
import time

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

import resnet
from dataset import get_cifar10


parser = argparse.ArgumentParser(add_help=True)
parser.add_argument(
    "--arch",
    type=str,
    default="resnet20",
    choices=[f"resnet{d}" for d in [20, 32, 44, 56, 110, 1202]],
    help="model architecture",
)
parser.add_argument("--batch_size", type=int, default=256, help="batch size")
parser.add_argument("--epochs", type=int, default=100, help="number of epochs")
parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
parser.add_argument("--seed", type=int, default=0, help="random seed")
parser.add_argument("--cpu", action="store_true", help="use cpu only")


def eval_fn(model, inp, tgt):
    return mx.mean(mx.argmax(model(inp), axis=1) == tgt)


def train_epoch(model, train_iter, optimizer, epoch):
    def train_step(model, inp, tgt):
        output = model(inp)
        loss = mx.mean(nn.losses.cross_entropy(output, tgt))
        acc = mx.mean(mx.argmax(output, axis=1) == tgt)
        return loss, acc

    # Returns a function that computes both the outputs of train_step and the
    # gradients of the loss with respect to the model's trainable parameters.
    train_step_fn = nn.value_and_grad(model, train_step)

    losses = []
    accs = []
    samples_per_sec = []

    for batch_counter, batch in enumerate(train_iter):
        x = mx.array(batch["image"])
        y = mx.array(batch["label"])
        tic = time.perf_counter()
        (loss, acc), grads = train_step_fn(model, x, y)
        optimizer.update(model, grads)
        # Force evaluation of the lazy graph so the timing below is meaningful.
        mx.eval(model.parameters(), optimizer.state)
        toc = time.perf_counter()
        loss = loss.item()
        acc = acc.item()
        losses.append(loss)
        accs.append(acc)
        throughput = x.shape[0] / (toc - tic)
        samples_per_sec.append(throughput)
        if batch_counter % 10 == 0:
            print(
                " | ".join(
                    (
                        f"Epoch {epoch:02d} [{batch_counter:03d}]",
                        f"Train loss {loss:.3f}",
                        f"Train acc {acc:.3f}",
                        f"Throughput: {throughput:.2f} images/second",
                    )
                )
            )

    mean_tr_loss = mx.mean(mx.array(losses))
    mean_tr_acc = mx.mean(mx.array(accs))
    samples_per_sec = mx.mean(mx.array(samples_per_sec))
    return mean_tr_loss, mean_tr_acc, samples_per_sec


def test_epoch(model, test_iter, epoch):
    accs = []
    for batch_counter, batch in enumerate(test_iter):
        x = mx.array(batch["image"])
        y = mx.array(batch["label"])
        acc = eval_fn(model, x, y)
        acc_value = acc.item()
        accs.append(acc_value)
    mean_acc = mx.mean(mx.array(accs))
    return mean_acc


def main(args):
    mx.random.seed(args.seed)

    model = getattr(resnet, args.arch)()

    print("Number of params: {:0.04f} M".format(model.num_params() / 1e6))

    optimizer = optim.Adam(learning_rate=args.lr)

    train_data, test_data = get_cifar10(args.batch_size)
    for epoch in range(args.epochs):
        tr_loss, tr_acc, throughput = train_epoch(model, train_data, optimizer, epoch)
        print(
            " | ".join(
                (
                    f"Epoch: {epoch}",
                    f"avg. Train loss {tr_loss.item():.3f}",
                    f"avg. Train acc {tr_acc.item():.3f}",
                    f"Throughput: {throughput.item():.2f} images/sec",
                )
            )
        )

        test_acc = test_epoch(model, test_data, epoch)
        print(f"Epoch: {epoch} | Test acc {test_acc.item():.3f}")

        # Streams are exhausted after one pass; reset them for the next epoch.
        train_data.reset()
        test_data.reset()


if __name__ == "__main__":
    args = parser.parse_args()
    if args.cpu:
        mx.set_default_device(mx.cpu)
    main(args)

cifar/requirements.txt (new file, 2 lines)

mlx
mlx-data

cifar/resnet.py (new file, 131 lines)

"""
|
||||||
|
Implementation of ResNets for CIFAR-10 as per the original paper [https://arxiv.org/abs/1512.03385].
|
||||||
|
Configurations include ResNet-20, ResNet-32, ResNet-44, ResNet-56, ResNet-110, ResNet-1202.
|
||||||
|
|
||||||
|
There's no BatchNorm is mlx==0.0.4, using LayerNorm instead.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
from mlx.utils import tree_flatten
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ResNet",
|
||||||
|
"resnet20",
|
||||||
|
"resnet32",
|
||||||
|
"resnet44",
|
||||||
|
"resnet56",
|
||||||
|
"resnet110",
|
||||||
|
"resnet1202",
|
||||||
|
]


class ShortcutA(nn.Module):
    """Option-A shortcut: subsample the input spatially and zero-pad the channels."""

    def __init__(self, dims):
        super().__init__()
        self.dims = dims

    def __call__(self, x):
        # Take every other pixel (stride-2 downsampling) and pad the channel
        # dimension so it matches the block's output width.
        return mx.pad(
            x[:, ::2, ::2, :],
            pad_width=[(0, 0), (0, 0), (0, 0), (self.dims // 4, self.dims // 4)],
        )


class Block(nn.Module):
    """
    Implements a ResNet block with two convolutional layers and a skip connection.
    As per the paper, CIFAR-10 uses Shortcut type-A skip connections. (See paper for details)
    """

    def __init__(self, in_dims, dims, stride=1):
        super().__init__()

        self.conv1 = nn.Conv2d(
            in_dims, dims, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.LayerNorm(dims)

        self.conv2 = nn.Conv2d(
            dims, dims, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.LayerNorm(dims)

        if stride != 1:
            self.shortcut = ShortcutA(dims)
        else:
            self.shortcut = None

    def __call__(self, x):
        out = nn.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.shortcut is None:
            out += x
        else:
            out += self.shortcut(x)
        out = nn.relu(out)
        return out


class ResNet(nn.Module):
    """
    Creates a ResNet model for CIFAR-10, as specified in the original paper.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.LayerNorm(16)

        self.layer1 = self._make_layer(block, 16, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 16, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 32, 64, num_blocks[2], stride=2)

        self.linear = nn.Linear(64, num_classes)

    def _make_layer(self, block, in_dims, dims, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(in_dims, dims, stride))
            in_dims = dims
        return nn.Sequential(*layers)

    def num_params(self):
        nparams = sum(x.size for k, x in tree_flatten(self.parameters()))
        return nparams

    def __call__(self, x):
        x = nn.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = mx.mean(x, axis=[1, 2]).reshape(x.shape[0], -1)
        x = self.linear(x)
        return x


def resnet20(**kwargs):
    return ResNet(Block, [3, 3, 3], **kwargs)


def resnet32(**kwargs):
    return ResNet(Block, [5, 5, 5], **kwargs)


def resnet44(**kwargs):
    return ResNet(Block, [7, 7, 7], **kwargs)


def resnet56(**kwargs):
    return ResNet(Block, [9, 9, 9], **kwargs)


def resnet110(**kwargs):
    return ResNet(Block, [18, 18, 18], **kwargs)


def resnet1202(**kwargs):
    return ResNet(Block, [200, 200, 200], **kwargs)
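
A quick shape check for the models above. Note that the convolutions operate on
channels-last (NHWC) input, which matches the image layout produced by
`dataset.py`:

```
# Sketch: build the default architecture and run a dummy forward pass.
import mlx.core as mx
import resnet

model = resnet.resnet20()
print(model.num_params() / 1e6)  # roughly 0.27 M parameters

x = mx.zeros((4, 32, 32, 3))  # a batch of four 32x32 RGB images, channels last
logits = model(x)
print(logits.shape)  # (4, 10)
```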