diff --git a/cifar/README.md b/cifar/README.md new file mode 100644 index 00000000..d6bdaf9a --- /dev/null +++ b/cifar/README.md @@ -0,0 +1,51 @@ +# CIFAR and ResNets + +An example of training a ResNet on CIFAR-10 with MLX. Several ResNet +configurations in accordance with the original +[paper](https://arxiv.org/abs/1512.03385) are available. The example also +illustrates how to use [MLX Data](https://github.com/ml-explore/mlx-data) to +load the dataset. + +## Pre-requisites + +Install the dependencies: + +``` +pip install -r requirements.txt +``` + +## Running the example + +Run the example with: + +``` +python main.py +``` + +By default the example runs on the GPU. To run on the CPU, use: + +``` +python main.py --cpu +``` + +For all available options, run: + +``` +python main.py --help +``` + +## Results + +After training with the default `resnet20` architecture for 100 epochs, you +should see the following results: + +``` +Epoch: 99 | avg. Train loss 0.320 | avg. Train acc 0.888 | Throughput: 416.77 images/sec +Epoch: 99 | Test acc 0.807 +``` + +Note this was run on an M1 Macbook Pro with 16GB RAM. + +At the time of writing, `mlx` doesn't have built-in learning rate schedules, +or a `BatchNorm` layer. We intend to update this example once these features +are added. diff --git a/cifar/dataset.py b/cifar/dataset.py new file mode 100644 index 00000000..89b10136 --- /dev/null +++ b/cifar/dataset.py @@ -0,0 +1,30 @@ +import mlx.core as mx +from mlx.data.datasets import load_cifar10 +import math + + +def get_cifar10(batch_size, root=None): + tr = load_cifar10(root=root) + + mean = mx.array([0.485, 0.456, 0.406]).reshape((1, 1, 3)) + std = mx.array([0.229, 0.224, 0.225]).reshape((1, 1, 3)) + + def normalize(x): + x = x.astype("float32") / 255.0 + return (x - mean) / std + + tr_iter = ( + tr.shuffle() + .to_stream() + .image_random_h_flip("image", prob=0.5) + .pad("image", 0, 4, 4, 0.0) + .pad("image", 1, 4, 4, 0.0) + .image_random_crop("image", 32, 32) + .key_transform("image", normalize) + .batch(batch_size) + ) + + test = load_cifar10(root=root, train=False) + test_iter = test.to_stream().key_transform("image", normalize).batch(batch_size) + + return tr_iter, test_iter diff --git a/cifar/main.py b/cifar/main.py new file mode 100644 index 00000000..829417b1 --- /dev/null +++ b/cifar/main.py @@ -0,0 +1,120 @@ +import argparse +import time +import resnet +import mlx.nn as nn +import mlx.core as mx +import mlx.optimizers as optim +from dataset import get_cifar10 + + +parser = argparse.ArgumentParser(add_help=True) +parser.add_argument( + "--arch", + type=str, + default="resnet20", + choices=[f"resnet{d}" for d in [20, 32, 44, 56, 110, 1202]], + help="model architecture", +) +parser.add_argument("--batch_size", type=int, default=256, help="batch size") +parser.add_argument("--epochs", type=int, default=100, help="number of epochs") +parser.add_argument("--lr", type=float, default=1e-3, help="learning rate") +parser.add_argument("--seed", type=int, default=0, help="random seed") +parser.add_argument("--cpu", action="store_true", help="use cpu only") + + +def eval_fn(model, inp, tgt): + return mx.mean(mx.argmax(model(inp), axis=1) == tgt) + + +def train_epoch(model, train_iter, optimizer, epoch): + def train_step(model, inp, tgt): + output = model(inp) + loss = mx.mean(nn.losses.cross_entropy(output, tgt)) + acc = mx.mean(mx.argmax(output, axis=1) == tgt) + return loss, acc + + train_step_fn = nn.value_and_grad(model, train_step) + + losses = [] + accs = [] + samples_per_sec = [] + + for batch_counter, batch in enumerate(train_iter): + x = mx.array(batch["image"]) + y = mx.array(batch["label"]) + tic = time.perf_counter() + (loss, acc), grads = train_step_fn(model, x, y) + optimizer.update(model, grads) + mx.eval(model.parameters(), optimizer.state) + toc = time.perf_counter() + loss = loss.item() + acc = acc.item() + losses.append(loss) + accs.append(acc) + throughput = x.shape[0] / (toc - tic) + samples_per_sec.append(throughput) + if batch_counter % 10 == 0: + print( + " | ".join( + ( + f"Epoch {epoch:02d} [{batch_counter:03d}]", + f"Train loss {loss:.3f}", + f"Train acc {acc:.3f}", + f"Throughput: {throughput:.2f} images/second", + ) + ) + ) + + mean_tr_loss = mx.mean(mx.array(losses)) + mean_tr_acc = mx.mean(mx.array(accs)) + samples_per_sec = mx.mean(mx.array(samples_per_sec)) + return mean_tr_loss, mean_tr_acc, samples_per_sec + + +def test_epoch(model, test_iter, epoch): + accs = [] + for batch_counter, batch in enumerate(test_iter): + x = mx.array(batch["image"]) + y = mx.array(batch["label"]) + acc = eval_fn(model, x, y) + acc_value = acc.item() + accs.append(acc_value) + mean_acc = mx.mean(mx.array(accs)) + return mean_acc + + +def main(args): + mx.random.seed(args.seed) + + model = getattr(resnet, args.arch)() + + print("Number of params: {:0.04f} M".format(model.num_params() / 1e6)) + + optimizer = optim.Adam(learning_rate=args.lr) + + train_data, test_data = get_cifar10(args.batch_size) + for epoch in range(args.epochs): + tr_loss, tr_acc, throughput = train_epoch(model, train_data, optimizer, epoch) + print( + " | ".join( + ( + f"Epoch: {epoch}", + f"avg. Train loss {tr_loss.item():.3f}", + f"avg. Train acc {tr_acc.item():.3f}", + f"Throughput: {throughput.item():.2f} images/sec", + ) + ) + ) + + test_acc = test_epoch(model, test_data, epoch) + print(f"Epoch: {epoch} | Test acc {test_acc.item():.3f}") + + train_data.reset() + test_data.reset() + + +if __name__ == "__main__": + args = parser.parse_args() + if args.cpu: + mx.set_default_device(mx.cpu) + main(args) diff --git a/cifar/requirements.txt b/cifar/requirements.txt new file mode 100644 index 00000000..6ff78a64 --- /dev/null +++ b/cifar/requirements.txt @@ -0,0 +1,2 @@ +mlx +mlx-data \ No newline at end of file diff --git a/cifar/resnet.py b/cifar/resnet.py new file mode 100644 index 00000000..758ee3de --- /dev/null +++ b/cifar/resnet.py @@ -0,0 +1,131 @@ +""" +Implementation of ResNets for CIFAR-10 as per the original paper [https://arxiv.org/abs/1512.03385]. +Configurations include ResNet-20, ResNet-32, ResNet-44, ResNet-56, ResNet-110, ResNet-1202. + +There's no BatchNorm is mlx==0.0.4, using LayerNorm instead. +""" + +from typing import Any +import mlx.core as mx +import mlx.nn as nn +from mlx.utils import tree_flatten + + +__all__ = [ + "ResNet", + "resnet20", + "resnet32", + "resnet44", + "resnet56", + "resnet110", + "resnet1202", +] + + +class ShortcutA(nn.Module): + def __init__(self, dims): + super().__init__() + self.dims = dims + + def __call__(self, x): + return mx.pad( + x[:, ::2, ::2, :], + pad_width=[(0, 0), (0, 0), (0, 0), (self.dims // 4, self.dims // 4)], + ) + + +class Block(nn.Module): + """ + Implements a ResNet block with two convolutional layers and a skip connection. + As per the paper, CIFAR-10 uses Shortcut type-A skip connections. (See paper for details) + """ + + def __init__(self, in_dims, dims, stride=1): + super().__init__() + + self.conv1 = nn.Conv2d( + in_dims, dims, kernel_size=3, stride=stride, padding=1, bias=False + ) + self.bn1 = nn.LayerNorm(dims) + + self.conv2 = nn.Conv2d( + dims, dims, kernel_size=3, stride=1, padding=1, bias=False + ) + self.bn2 = nn.LayerNorm(dims) + + if stride != 1: + self.shortcut = ShortcutA(dims) + else: + self.shortcut = None + + def __call__(self, x): + out = nn.relu(self.bn1(self.conv1(x))) + out = self.bn2(self.conv2(out)) + if self.shortcut is None: + out += x + else: + out += self.shortcut(x) + out = nn.relu(out) + return out + + +class ResNet(nn.Module): + """ + Creates a ResNet model for CIFAR-10, as specified in the original paper. + """ + + def __init__(self, block, num_blocks, num_classes=10): + super().__init__() + self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.LayerNorm(16) + + self.layer1 = self._make_layer(block, 16, 16, num_blocks[0], stride=1) + self.layer2 = self._make_layer(block, 16, 32, num_blocks[1], stride=2) + self.layer3 = self._make_layer(block, 32, 64, num_blocks[2], stride=2) + + self.linear = nn.Linear(64, num_classes) + + def _make_layer(self, block, in_dims, dims, num_blocks, stride): + strides = [stride] + [1] * (num_blocks - 1) + layers = [] + for stride in strides: + layers.append(block(in_dims, dims, stride)) + in_dims = dims + return nn.Sequential(*layers) + + def num_params(self): + nparams = sum(x.size for k, x in tree_flatten(self.parameters())) + return nparams + + def __call__(self, x): + x = nn.relu(self.bn1(self.conv1(x))) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = mx.mean(x, axis=[1, 2]).reshape(x.shape[0], -1) + x = self.linear(x) + return x + + +def resnet20(**kwargs): + return ResNet(Block, [3, 3, 3], **kwargs) + + +def resnet32(**kwargs): + return ResNet(Block, [5, 5, 5], **kwargs) + + +def resnet44(**kwargs): + return ResNet(Block, [7, 7, 7], **kwargs) + + +def resnet56(**kwargs): + return ResNet(Block, [9, 9, 9], **kwargs) + + +def resnet110(**kwargs): + return ResNet(Block, [18, 18, 18], **kwargs) + + +def resnet1202(**kwargs): + return ResNet(Block, [200, 200, 200], **kwargs)