Mirror of https://github.com/ml-explore/mlx-examples.git
Updated CIFAR-10 ResNet example to use BatchNorm instead of LayerNorm (#257)
* replaced nn.LayerNorm with nn.BatchNorm
* updated the default number of epochs from 100 to 30
* requires mlx>=0.0.9
* updated README.md with results for mlx 0.0.9
Parent: 6217d7acd0
Commit: 2b61d9deb6
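The diffs below make the change concrete. As a quick orientation, here is a minimal sketch of what the updated residual block looks like with `nn.BatchNorm`: the constructor mirrors the model diff further down, while the import line, the `__call__` body, and the identity skip connection are not part of this commit and are reconstructed here as assumptions (the real example routes the skip through a `ShortcutA` module when the stride is not 1).

```python
# Sketch only: the constructor mirrors the diff below; the import, the __call__
# body, and the identity shortcut are assumptions, not code from this commit.
import mlx.nn as nn


class Block(nn.Module):
    """3x3 conv -> BatchNorm -> ReLU, twice, plus a residual connection."""

    def __init__(self, in_dims, dims, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_dims, dims, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm(dims)  # previously nn.LayerNorm(dims)

        self.conv2 = nn.Conv2d(
            dims, dims, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm(dims)  # previously nn.LayerNorm(dims)

    def __call__(self, x):
        out = nn.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Simplification: plain identity shortcut, valid when stride == 1 and
        # in_dims == dims; the example uses ShortcutA for the other cases.
        out = out + x
        return nn.relu(out)
```

Swapping in BatchNorm matches the normalization used in the original ResNet recipe, which is why the README diff below reports a higher test accuracy after 30 epochs than the old LayerNorm variant reached after 100.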
README.md:

````diff
@@ -36,16 +36,15 @@ python main.py --help
 ## Results
 
-After training with the default `resnet20` architecture for 100 epochs, you
+After training with the default `resnet20` architecture for 30 epochs, you
 should see the following results:
 
 ```
-Epoch: 99 | avg. Train loss 0.320 | avg. Train acc 0.888 | Throughput: 416.77 images/sec
-Epoch: 99 | Test acc 0.807
+Epoch: 29 | avg. Train loss 0.294 | avg. Train acc 0.897 | Throughput: 270.81 images/sec
+Epoch: 29 | Test acc 0.841
 ```
 
 Note this was run on an M1 Macbook Pro with 16GB RAM.
 
-At the time of writing, `mlx` doesn't have built-in learning rate schedules,
-or a `BatchNorm` layer. We intend to update this example once these features
-are added.
+At the time of writing, `mlx` doesn't have built-in learning rate schedules.
+We intend to update this example once these features are added.
 
````
Training script argument parser:

```diff
@@ -16,7 +16,7 @@ parser.add_argument(
     help="model architecture",
 )
 parser.add_argument("--batch_size", type=int, default=256, help="batch size")
-parser.add_argument("--epochs", type=int, default=100, help="number of epochs")
+parser.add_argument("--epochs", type=int, default=30, help="number of epochs")
 parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
 parser.add_argument("--seed", type=int, default=0, help="random seed")
 parser.add_argument("--cpu", action="store_true", help="use cpu only")
```
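Only the default changes here, so existing invocations keep working, and the old behavior is still available by passing `--epochs 100` explicitly. A minimal, self-contained illustration of how the new default resolves follows; the flag names and defaults are copied from the hunk above, while the standalone parser construction is just for demonstration and is not the example's full script.

```python
# Toy reproduction of the relevant parser lines; not the example's actual script.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--epochs", type=int, default=30, help="number of epochs")
parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")

args = parser.parse_args([])  # no flags on the command line -> defaults apply
print(args.epochs)            # 30
print(args.lr)                # 0.001
```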
Requirements:

```diff
@@ -1,3 +1,3 @@
-mlx
+mlx>=0.0.9
 mlx-data
 numpy
```
ResNet model definition, module docstring:

```diff
@@ -1,8 +1,6 @@
 """
 Implementation of ResNets for CIFAR-10 as per the original paper [https://arxiv.org/abs/1512.03385].
 Configurations include ResNet-20, ResNet-32, ResNet-44, ResNet-56, ResNet-110, ResNet-1202.
 
-There's no BatchNorm is mlx==0.0.4, using LayerNorm instead.
-
 """
 
 from typing import Any
```
ResNet model definition, residual block:

```diff
@@ -46,12 +44,12 @@ class Block(nn.Module):
         self.conv1 = nn.Conv2d(
             in_dims, dims, kernel_size=3, stride=stride, padding=1, bias=False
         )
-        self.bn1 = nn.LayerNorm(dims)
+        self.bn1 = nn.BatchNorm(dims)
 
         self.conv2 = nn.Conv2d(
             dims, dims, kernel_size=3, stride=1, padding=1, bias=False
         )
-        self.bn2 = nn.LayerNorm(dims)
+        self.bn2 = nn.BatchNorm(dims)
 
         if stride != 1:
             self.shortcut = ShortcutA(dims)
```
ResNet model definition, network stem:

```diff
@@ -77,7 +75,7 @@ class ResNet(nn.Module):
     def __init__(self, block, num_blocks, num_classes=10):
         super().__init__()
         self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
-        self.bn1 = nn.LayerNorm(16)
+        self.bn1 = nn.BatchNorm(16)
 
         self.layer1 = self._make_layer(block, 16, 16, num_blocks[0], stride=1)
         self.layer2 = self._make_layer(block, 16, 32, num_blocks[1], stride=2)
```