diff --git a/cifar/README.md b/cifar/README.md
index d6bdaf9a..763e641d 100644
--- a/cifar/README.md
+++ b/cifar/README.md
@@ -36,16 +36,15 @@ python main.py --help
 
 ## Results
 
-After training with the default `resnet20` architecture for 100 epochs, you
+After training with the default `resnet20` architecture for 30 epochs, you
 should see the following results:
 
 ```
-Epoch: 99 | avg. Train loss 0.320 | avg. Train acc 0.888 | Throughput: 416.77 images/sec
-Epoch: 99 | Test acc 0.807
+Epoch: 29 | avg. Train loss 0.294 | avg. Train acc 0.897 | Throughput: 270.81 images/sec
+Epoch: 29 | Test acc 0.841
 ```
 
 Note this was run on an M1 Macbook Pro with 16GB RAM.
 
-At the time of writing, `mlx` doesn't have built-in learning rate schedules,
-or a `BatchNorm` layer. We intend to update this example once these features
-are added.
+At the time of writing, `mlx` doesn't have built-in learning rate schedules.
+We intend to update this example once these features are added.
diff --git a/cifar/main.py b/cifar/main.py
index 42ed57d3..f2174772 100644
--- a/cifar/main.py
+++ b/cifar/main.py
@@ -16,7 +16,7 @@ parser.add_argument(
     help="model architecture",
 )
 parser.add_argument("--batch_size", type=int, default=256, help="batch size")
-parser.add_argument("--epochs", type=int, default=100, help="number of epochs")
+parser.add_argument("--epochs", type=int, default=30, help="number of epochs")
 parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
 parser.add_argument("--seed", type=int, default=0, help="random seed")
 parser.add_argument("--cpu", action="store_true", help="use cpu only")
diff --git a/cifar/requirements.txt b/cifar/requirements.txt
index c4c2e575..e4764d13 100644
--- a/cifar/requirements.txt
+++ b/cifar/requirements.txt
@@ -1,3 +1,3 @@
-mlx
+mlx>=0.0.9
 mlx-data
 numpy
\ No newline at end of file
diff --git a/cifar/resnet.py b/cifar/resnet.py
index e707574f..04300a36 100644
--- a/cifar/resnet.py
+++ b/cifar/resnet.py
@@ -1,8 +1,6 @@
 """
 Implementation of ResNets for CIFAR-10 as per the original paper [https://arxiv.org/abs/1512.03385].
 Configurations include ResNet-20, ResNet-32, ResNet-44, ResNet-56, ResNet-110, ResNet-1202.
-
-There's no BatchNorm is mlx==0.0.4, using LayerNorm instead.
 """
 
 from typing import Any
@@ -46,12 +44,12 @@ class Block(nn.Module):
         self.conv1 = nn.Conv2d(
             in_dims, dims, kernel_size=3, stride=stride, padding=1, bias=False
         )
-        self.bn1 = nn.LayerNorm(dims)
+        self.bn1 = nn.BatchNorm(dims)
 
         self.conv2 = nn.Conv2d(
             dims, dims, kernel_size=3, stride=1, padding=1, bias=False
         )
-        self.bn2 = nn.LayerNorm(dims)
+        self.bn2 = nn.BatchNorm(dims)
 
         if stride != 1:
             self.shortcut = ShortcutA(dims)
@@ -77,7 +75,7 @@ class ResNet(nn.Module):
     def __init__(self, block, num_blocks, num_classes=10):
         super().__init__()
         self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
-        self.bn1 = nn.LayerNorm(16)
+        self.bn1 = nn.BatchNorm(16)
 
         self.layer1 = self._make_layer(block, 16, 16, num_blocks[0], stride=1)
         self.layer2 = self._make_layer(block, 16, 32, num_blocks[1], stride=2)
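
For reference, here is a minimal sketch of the conv -> `nn.BatchNorm` -> ReLU pattern that `resnet.py` switches to in this patch, assuming `mlx>=0.0.9` as pinned in `requirements.txt`. The `ConvBNBlock` module and the toy input are illustrative only and are not part of the patch:

```python
# Illustrative sketch only: mirrors the conv + BatchNorm layer pair that
# resnet.py's Block now builds (ConvBNBlock is a hypothetical name, not
# something defined in this patch).
import mlx.core as mx
import mlx.nn as nn


class ConvBNBlock(nn.Module):
    def __init__(self, in_dims: int, dims: int, stride: int = 1):
        super().__init__()
        self.conv = nn.Conv2d(
            in_dims, dims, kernel_size=3, stride=stride, padding=1, bias=False
        )
        # Replaces the nn.LayerNorm(dims) workaround used before BatchNorm existed
        self.bn = nn.BatchNorm(dims)

    def __call__(self, x):
        return nn.relu(self.bn(self.conv(x)))


# mlx convolutions take NHWC inputs, e.g. a small batch of CIFAR-10 images
x = mx.random.normal((4, 32, 32, 3))
print(ConvBNBlock(3, 16)(x).shape)  # (4, 32, 32, 16)
```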