simplified ResNet, expanded README with throughput and performance

Sarthak Yadav 2023-12-14 09:05:04 +01:00
parent 2439333a57
commit 15a6c155a8
5 changed files with 56 additions and 36 deletions

View File

@@ -1,11 +1,10 @@
 # CIFAR and ResNets

-* This example shows how to run ResNets on CIFAR10 dataset, in accordance with the original [paper](https://arxiv.org/abs/1512.03385).
-* Also illustrates how to use `mlx-data` to download and load the dataset.
+An example of training a ResNet on CIFAR-10 with MLX. Several ResNet configurations in accordance with the original [paper](https://arxiv.org/abs/1512.03385) are available. Also illustrates how to use `mlx-data` to download and load the dataset.

 ## Pre-requisites

-* Install the dependencies:
+Install the dependencies:
 ```
 pip install -r requirements.txt
@@ -21,7 +20,7 @@ python main.py
 By default the example runs on the GPU. To run on the CPU, use:
 ```
-python main.py --cpu_only
+python main.py --cpu
 ```

 For all available options, run:
@@ -29,3 +28,24 @@ For all available options, run:
 ```
 python main.py --help
 ```
+
+## Throughput
+
+On the tested device (M1 MacBook Pro, 16GB RAM), the following throughput was observed with `batch_size=256`:
+```
+Epoch: 0 | avg. tr_loss 2.074 | avg. tr_acc 0.216 | Train Throughput: 415.39 images/sec
+```
+
+When training on the CPU only (with the `--cpu` argument), the throughput is significantly lower (almost 30x!):
+```
+Epoch: 0 | avg. tr_loss 2.074 | avg. tr_acc 0.216 | Train Throughput: 13.5 images/sec
+```
+
+## Results
+
+After training for 100 epochs, the following results were observed:
+```
+Epoch: 99 | avg. tr_loss 0.320 | avg. tr_acc 0.888 | Train Throughput: 416.77 images/sec
+Epoch: 99 | test_acc 0.807
+```
+
+At the time of writing, `mlx` doesn't have built-in learning rate schedulers or a `BatchNorm` layer. We'll revisit this example for an exact reproduction of the paper once these features are added.
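The `--cpu` flag mentioned in the README works by switching MLX's default device before training starts. A minimal sketch of that mechanism, using the `mx.cpu`/`mx.gpu` devices and `mx.set_default_device` from `mlx.core` (the toy array at the end is purely illustrative):

```
import mlx.core as mx

use_cpu = True  # corresponds to passing --cpu on the command line
mx.set_default_device(mx.cpu if use_cpu else mx.gpu)

# Any subsequent computation now runs on the selected device.
x = mx.random.normal((4, 4))
mx.eval(x)
print(mx.default_device())
```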

View File

@@ -36,4 +36,4 @@ def get_cifar10(batch_size, root=None):
     num_tr_steps_per_epoch = num_tr_samples // batch_size
     num_test_steps_per_epoch = num_test_samples // batch_size
-    return tr_iter, test_iter, num_tr_steps_per_epoch, num_test_steps_per_epoch
+    return tr_iter, test_iter
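Since `get_cifar10` now returns just the two `mlx-data` streams, the caller creates them once and resets them between epochs rather than rebuilding them each epoch. A minimal sketch of that usage pattern, assuming this function lives in the example's dataset module (batch handling elided):

```
from dataset import get_cifar10  # assumed module name for this example

num_epochs = 2
train_data, test_data = get_cifar10(batch_size=256)

for epoch in range(num_epochs):
    for batch in train_data:
        images, labels = batch["image"], batch["label"]
        # ... one training step per batch ...
    # the streams are exhausted after a full pass; reset them for the next epoch
    train_data.reset()
    test_data.reset()
```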

View File

@@ -1,6 +1,6 @@
 import argparse
+import time
 import resnet
-import numpy as np
 import mlx.nn as nn
 import mlx.core as mx
 import mlx.optimizers as optim
@@ -14,11 +14,11 @@ parser.add_argument(
     default="resnet20",
     help="model architecture [resnet20, resnet32, resnet44, resnet56, resnet110, resnet1202]",
 )
-parser.add_argument("--batch_size", type=int, default=128, help="batch size")
+parser.add_argument("--batch_size", type=int, default=256, help="batch size")
 parser.add_argument("--epochs", type=int, default=100, help="number of epochs")
 parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
 parser.add_argument("--seed", type=int, default=0, help="random seed")
-parser.add_argument("--cpu_only", action="store_true", help="use cpu only")
+parser.add_argument("--cpu", action="store_true", help="use cpu only")


 def loss_fn(model, inp, tgt):
@@ -40,27 +40,30 @@ def train_epoch(model, train_iter, optimizer, epoch):
     losses = []
     accs = []
+    samples_per_sec = []

     for batch_counter, batch in enumerate(train_iter):
         x = mx.array(batch["image"])
         y = mx.array(batch["label"])
+        tic = time.perf_counter()
         (loss, acc), grads = train_step_fn(model, x, y)
         optimizer.update(model, grads)
         mx.eval(model.parameters(), optimizer.state)
+        toc = time.perf_counter()
         loss_value = loss.item()
         acc_value = acc.item()
         losses.append(loss_value)
         accs.append(acc_value)
+        samples_per_sec.append(x.shape[0] / (toc - tic))
         if batch_counter % 10 == 0:
             print(
-                f"Epoch {epoch:02d}[{batch_counter:03d}]: tr_loss {loss_value:.3f}, tr_acc {acc_value:.3f}"
+                f"Epoch {epoch:02d} [{batch_counter:03d}] | tr_loss {loss_value:.3f} | tr_acc {acc_value:.3f} | Throughput: {x.shape[0] / (toc - tic):.2f} images/second"
             )

-    mean_tr_loss = np.mean(np.array(losses))
-    mean_tr_acc = np.mean(np.array(accs))
-    return mean_tr_loss, mean_tr_acc
+    mean_tr_loss = mx.mean(mx.array(losses))
+    mean_tr_acc = mx.mean(mx.array(accs))
+    samples_per_sec = mx.mean(mx.array(samples_per_sec))
+    return mean_tr_loss, mean_tr_acc, samples_per_sec


 def test_epoch(model, test_iter, epoch):
@@ -71,13 +74,11 @@ def test_epoch(model, test_iter, epoch):
         acc = eval_fn(model, x, y)
         acc_value = acc.item()
         accs.append(acc_value)
-    mean_acc = np.mean(np.array(accs))
+    mean_acc = mx.mean(mx.array(accs))
     return mean_acc


 def main(args):
-    np.random.seed(args.seed)
     mx.random.seed(args.seed)

     model = resnet.__dict__[args.arch]()
@@ -87,22 +88,24 @@ def main(args):
     optimizer = optim.Adam(learning_rate=args.lr)

+    train_data, test_data = get_cifar10(args.batch_size)
     for epoch in range(args.epochs):
-        # get data every epoch
-        # or set .repeat() on the data stream appropriately
-        train_data, test_data, tr_batches, _ = get_cifar10(args.batch_size)
-        epoch_tr_loss, epoch_tr_acc = train_epoch(model, train_data, optimizer, epoch)
+        epoch_tr_loss, epoch_tr_acc, train_throughput = train_epoch(
+            model, train_data, optimizer, epoch
+        )
         print(
-            f"Epoch {epoch}: avg. tr_loss {epoch_tr_loss:.3f}, avg. tr_acc {epoch_tr_acc:.3f}"
+            f"Epoch: {epoch} | avg. tr_loss {epoch_tr_loss.item():.3f} | avg. tr_acc {epoch_tr_acc.item():.3f} | Train Throughput: {train_throughput.item():.2f} images/sec"
         )
         epoch_test_acc = test_epoch(model, test_data, epoch)
-        print(f"Epoch {epoch}: Test_acc {epoch_test_acc:.3f}")
+        print(f"Epoch: {epoch} | test_acc {epoch_test_acc.item():.3f}")
+        train_data.reset()
+        test_data.reset()


 if __name__ == "__main__":
     args = parser.parse_args()
-    if args.cpu_only:
+    if args.cpu:
         mx.set_default_device(mx.cpu)
     main(args)
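Because MLX evaluates lazily, the per-batch timing added in `train_epoch` only measures real work if `mx.eval` is called inside the timed region; otherwise `toc` would be taken before the graph has actually been computed. A stripped-down sketch of that pattern with a stand-in computation (`fake_step` is purely illustrative):

```
import time
import mlx.core as mx

def fake_step(x):
    # stand-in for one lazily-evaluated training step
    return (x @ x.T).sum()

x = mx.random.normal((256, 1024))

tic = time.perf_counter()
out = fake_step(x)
mx.eval(out)  # forces the computation; without it only graph construction is timed
toc = time.perf_counter()

print(f"{x.shape[0] / (toc - tic):.2f} images/sec (illustrative)")
```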

View File

@@ -1,3 +1,2 @@
 mlx
 mlx-data
-numpy
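`numpy` can be dropped because its only remaining use was averaging per-batch scalars, which `mlx.core` handles directly. A one-line sketch of the replacement:

```
import mlx.core as mx

losses = [2.31, 2.12, 1.98]            # per-batch scalars collected during an epoch
mean_loss = mx.mean(mx.array(losses))  # call .item() to get a Python float
print(f"{mean_loss.item():.3f}")
```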

View File

@@ -38,7 +38,6 @@ class ShortcutA(nn.Module):

 class Block(nn.Module):
-    expansion = 1
     """
     Implements a ResNet block with two convolutional layers and a skip connection.
     As per the paper, CIFAR-10 uses Shortcut type-A skip connections. (See paper for details)
@@ -57,7 +56,7 @@ class Block(nn.Module):
         )
         self.bn2 = nn.LayerNorm(dims)

-        if stride != 1 or in_dims != dims:
+        if stride != 1:
             self.shortcut = ShortcutA(dims)
         else:
             self.shortcut = None
@@ -83,20 +82,19 @@ class ResNet(nn.Module):
         super().__init__()
         self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
         self.bn1 = nn.LayerNorm(16)
-        self.in_dims = 16

-        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
-        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
-        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
+        self.layer1 = self._make_layer(block, 16, 16, num_blocks[0], stride=1)
+        self.layer2 = self._make_layer(block, 16, 32, num_blocks[1], stride=2)
+        self.layer3 = self._make_layer(block, 32, 64, num_blocks[2], stride=2)

         self.linear = nn.Linear(64, num_classes)

-    def _make_layer(self, block, dims, num_blocks, stride):
+    def _make_layer(self, block, in_dims, dims, num_blocks, stride):
         strides = [stride] + [1] * (num_blocks - 1)
         layers = []
         for stride in strides:
-            layers.append(block(self.in_dims, dims, stride))
-            self.in_dims = dims * block.expansion
+            layers.append(block(in_dims, dims, stride))
+            in_dims = dims
         return nn.Sequential(*layers)

     def num_params(self):
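The simplified ResNet removes the mutable `self.in_dims` bookkeeping and the unused `expansion` attribute: each stage now receives its input width explicitly. A condensed sketch of the resulting stage builder (stripped down, with the block class left abstract):

```
import mlx.nn as nn

def make_layer(block, in_dims, dims, num_blocks, stride):
    # first block of a stage may downsample; the rest keep stride 1
    strides = [stride] + [1] * (num_blocks - 1)
    layers = []
    for s in strides:
        layers.append(block(in_dims, dims, s))
        in_dims = dims  # after the first block, the input width equals the stage width
    return nn.Sequential(*layers)

# e.g. resnet20: three stages of 3 blocks each, widths 16 -> 32 -> 64
# layer1 = make_layer(Block, 16, 16, 3, stride=1)
# layer2 = make_layer(Block, 16, 32, 3, stride=2)
# layer3 = make_layer(Block, 32, 64, 3, stride=2)
```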