diff --git a/cifar/README.md b/cifar/README.md
index 0d793853..abb2c0f5 100644
--- a/cifar/README.md
+++ b/cifar/README.md
@@ -1,11 +1,10 @@
 # CIFAR and ResNets
 
-* This example shows how to run ResNets on CIFAR10 dataset, in accordance with the original [paper](https://arxiv.org/abs/1512.03385).
-* Also illustrates how to use `mlx-data` to download and load the dataset.
+An example of training a ResNet on CIFAR-10 with MLX. Several ResNet configurations from the original [paper](https://arxiv.org/abs/1512.03385) are available. The example also illustrates how to use `mlx-data` to download and load the dataset.
 
 ## Pre-requisites
 
-* Install the dependencies:
+Install the dependencies:
 
 ```
 pip install -r requirements.txt
@@ -21,7 +20,7 @@ python main.py
 By default the example runs on the GPU. To run on the CPU, use:
 
 ```
-python main.py --cpu_only
+python main.py --cpu
 ```
 
 For all available options, run:
@@ -29,3 +28,24 @@ For all available options, run:
 ```
 python main.py --help
 ```
+
+
+## Throughput
+
+On the tested device (M1 MacBook Pro, 16GB RAM), I get the following throughput with `batch_size=256`:
+```
+Epoch: 0 | avg. tr_loss 2.074 | avg. tr_acc 0.216 | Train Throughput: 415.39 images/sec
+```
+
+When training only on the CPU (with the `--cpu` argument), the throughput is roughly 30x lower:
+```
+Epoch: 0 | avg. tr_loss 2.074 | avg. tr_acc 0.216 | Train Throughput: 13.5 images/sec
+```
+
+## Results
+After training for 100 epochs, the following results were observed:
+```
+Epoch: 99 | avg. tr_loss 0.320 | avg. tr_acc 0.888 | Train Throughput: 416.77 images/sec
+Epoch: 99 | test_acc 0.807
+```
+At the time of writing, `mlx` doesn't have built-in learning rate schedulers or a `BatchNorm` layer. We'll revisit this example for an exact reproduction of the paper once these features are added.
\ No newline at end of file
diff --git a/cifar/dataset.py b/cifar/dataset.py
index f4a3cd63..29f558d1 100644
--- a/cifar/dataset.py
+++ b/cifar/dataset.py
@@ -36,4 +36,4 @@ def get_cifar10(batch_size, root=None):
     num_tr_steps_per_epoch = num_tr_samples // batch_size
     num_test_steps_per_epoch = num_test_samples // batch_size
 
-    return tr_iter, test_iter, num_tr_steps_per_epoch, num_test_steps_per_epoch
+    return tr_iter, test_iter
diff --git a/cifar/main.py b/cifar/main.py
index 5272733a..29b0cbc7 100644
--- a/cifar/main.py
+++ b/cifar/main.py
@@ -1,6 +1,6 @@
 import argparse
+import time
 import resnet
-import numpy as np
 import mlx.nn as nn
 import mlx.core as mx
 import mlx.optimizers as optim
@@ -14,11 +14,11 @@ parser.add_argument(
     default="resnet20",
     help="model architecture [resnet20, resnet32, resnet44, resnet56, resnet110, resnet1202]",
 )
-parser.add_argument("--batch_size", type=int, default=128, help="batch size")
+parser.add_argument("--batch_size", type=int, default=256, help="batch size")
 parser.add_argument("--epochs", type=int, default=100, help="number of epochs")
 parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
 parser.add_argument("--seed", type=int, default=0, help="random seed")
-parser.add_argument("--cpu_only", action="store_true", help="use cpu only")
+parser.add_argument("--cpu", action="store_true", help="use cpu only")
 
 
 def loss_fn(model, inp, tgt):
@@ -40,27 +40,30 @@ def train_epoch(model, train_iter, optimizer, epoch):
 
     losses = []
     accs = []
+    samples_per_sec = []
 
     for batch_counter, batch in enumerate(train_iter):
         x = mx.array(batch["image"])
         y = mx.array(batch["label"])
+        tic = time.perf_counter()
         (loss, acc), grads = train_step_fn(model, x, y)
         optimizer.update(model, grads)
         mx.eval(model.parameters(), optimizer.state)
-
+        toc = time.perf_counter()
         loss_value = loss.item()
         acc_value = acc.item()
         losses.append(loss_value)
         accs.append(acc_value)
-
+        samples_per_sec.append(x.shape[0] / (toc - tic))
         if batch_counter % 10 == 0:
             print(
-                f"Epoch {epoch:02d}[{batch_counter:03d}]: tr_loss {loss_value:.3f}, tr_acc {acc_value:.3f}"
+                f"Epoch {epoch:02d} [{batch_counter:03d}] | tr_loss {loss_value:.3f} | tr_acc {acc_value:.3f} | Throughput: {x.shape[0] / (toc - tic):.2f} images/second"
             )
 
-    mean_tr_loss = np.mean(np.array(losses))
-    mean_tr_acc = np.mean(np.array(accs))
-    return mean_tr_loss, mean_tr_acc
+    mean_tr_loss = mx.mean(mx.array(losses))
+    mean_tr_acc = mx.mean(mx.array(accs))
+    samples_per_sec = mx.mean(mx.array(samples_per_sec))
+    return mean_tr_loss, mean_tr_acc, samples_per_sec
 
 
 def test_epoch(model, test_iter, epoch):
@@ -71,13 +74,11 @@
         acc = eval_fn(model, x, y)
         acc_value = acc.item()
         accs.append(acc_value)
-    mean_acc = np.mean(np.array(accs))
-
+    mean_acc = mx.mean(mx.array(accs))
     return mean_acc
 
 
 def main(args):
-    np.random.seed(args.seed)
     mx.random.seed(args.seed)
 
     model = resnet.__dict__[args.arch]()
@@ -87,22 +88,24 @@
 
     optimizer = optim.Adam(learning_rate=args.lr)
 
+    train_data, test_data = get_cifar10(args.batch_size)
     for epoch in range(args.epochs):
-        # get data every epoch
-        # or set .repeat() on the data stream appropriately
-        train_data, test_data, tr_batches, _ = get_cifar10(args.batch_size)
-
-        epoch_tr_loss, epoch_tr_acc = train_epoch(model, train_data, optimizer, epoch)
+        epoch_tr_loss, epoch_tr_acc, train_throughput = train_epoch(
+            model, train_data, optimizer, epoch
+        )
         print(
-            f"Epoch {epoch}: avg. tr_loss {epoch_tr_loss:.3f}, avg. tr_acc {epoch_tr_acc:.3f}"
+            f"Epoch: {epoch} | avg. tr_loss {epoch_tr_loss.item():.3f} | avg. tr_acc {epoch_tr_acc.item():.3f} | Train Throughput: {train_throughput.item():.2f} images/sec"
         )
 
         epoch_test_acc = test_epoch(model, test_data, epoch)
-        print(f"Epoch {epoch}: Test_acc {epoch_test_acc:.3f}")
+        print(f"Epoch: {epoch} | test_acc {epoch_test_acc.item():.3f}")
+
+        train_data.reset()
+        test_data.reset()
 
 
 if __name__ == "__main__":
     args = parser.parse_args()
-    if args.cpu_only:
+    if args.cpu:
         mx.set_default_device(mx.cpu)
     main(args)
diff --git a/cifar/requirements.txt b/cifar/requirements.txt
index c4c2e575..6ff78a64 100644
--- a/cifar/requirements.txt
+++ b/cifar/requirements.txt
@@ -1,3 +1,2 @@
 mlx
-mlx-data
-numpy
\ No newline at end of file
+mlx-data
\ No newline at end of file
diff --git a/cifar/resnet.py b/cifar/resnet.py
index b89a612b..6eeadda6 100644
--- a/cifar/resnet.py
+++ b/cifar/resnet.py
@@ -38,7 +38,6 @@ class ShortcutA(nn.Module):
 
 
 class Block(nn.Module):
-    expansion = 1
     """
     Implements a ResNet block with two convolutional layers and a skip connection.
     As per the paper, CIFAR-10 uses Shortcut type-A skip connections. (See paper for details)
@@ -57,7 +56,7 @@ class Block(nn.Module):
         )
         self.bn2 = nn.LayerNorm(dims)
 
-        if stride != 1 or in_dims != dims:
+        if stride != 1:
             self.shortcut = ShortcutA(dims)
         else:
             self.shortcut = None
@@ -83,20 +82,19 @@
         super().__init__()
         self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
         self.bn1 = nn.LayerNorm(16)
-        self.in_dims = 16
 
-        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
-        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
-        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
+        self.layer1 = self._make_layer(block, 16, 16, num_blocks[0], stride=1)
+        self.layer2 = self._make_layer(block, 16, 32, num_blocks[1], stride=2)
+        self.layer3 = self._make_layer(block, 32, 64, num_blocks[2], stride=2)
 
         self.linear = nn.Linear(64, num_classes)
 
-    def _make_layer(self, block, dims, num_blocks, stride):
+    def _make_layer(self, block, in_dims, dims, num_blocks, stride):
         strides = [stride] + [1] * (num_blocks - 1)
         layers = []
         for stride in strides:
-            layers.append(block(self.in_dims, dims, stride))
-            self.in_dims = dims * block.expansion
+            layers.append(block(in_dims, dims, stride))
+            in_dims = dims
         return nn.Sequential(*layers)
 
     def num_params(self):