Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-10-24 06:28:07 +08:00)
Updated CIFAR-10 ResNet example to use BatchNorm instead of LayerNorm (#257)
* replaced nn.LayerNorm with nn.BatchNorm
* mlx>=0.0.8 required
* updated default to 30 epochs instead of 100
* updated README after adding BatchNorm
* requires mlx>=0.0.9
* updated README.md with results for mlx-0.0.9
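For context, the change is a one-line swap per normalization layer. Below is a minimal sketch, not taken from the commit, of what such a conv + norm block looks like in MLX; `ConvBNBlock` is a hypothetical name for illustration, and the code assumes mlx>=0.0.9, where nn.BatchNorm is available.

```python
# A minimal sketch, NOT code from the commit: it only illustrates the swap,
# assuming mlx >= 0.0.9, where nn.BatchNorm is available.
import mlx.core as mx
import mlx.nn as nn


class ConvBNBlock(nn.Module):
    """Hypothetical conv + norm block; the name is illustrative only."""

    def __init__(self, in_dims: int, dims: int, stride: int = 1):
        super().__init__()
        self.conv = nn.Conv2d(
            in_dims, dims, kernel_size=3, stride=stride, padding=1, bias=False
        )
        # Before #257: self.bn = nn.LayerNorm(dims)
        # After #257: per-channel statistics over batch and spatial dims,
        # with running estimates kept for inference.
        self.bn = nn.BatchNorm(dims)

    def __call__(self, x: mx.array) -> mx.array:
        # MLX convolutions expect NHWC layout: (batch, height, width, channels).
        return nn.relu(self.bn(self.conv(x)))


x = mx.random.normal((4, 32, 32, 3))  # a fake batch of CIFAR-sized images
print(ConvBNBlock(3, 16)(x).shape)  # (4, 32, 32, 16)
```

Unlike nn.LayerNorm(dims), which normalizes each spatial position over its channels independently of the rest of the batch, nn.BatchNorm(dims) normalizes each channel over the whole batch and spatial extent and keeps running estimates for inference, matching the recipe of the original ResNet paper.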
```diff
@@ -1,8 +1,6 @@
 """
 Implementation of ResNets for CIFAR-10 as per the original paper [https://arxiv.org/abs/1512.03385].
 Configurations include ResNet-20, ResNet-32, ResNet-44, ResNet-56, ResNet-110, ResNet-1202.
-
-There's no BatchNorm is mlx==0.0.4, using LayerNorm instead.
 """
 
 from typing import Any
```
```diff
@@ -46,12 +44,12 @@ class Block(nn.Module):
         self.conv1 = nn.Conv2d(
             in_dims, dims, kernel_size=3, stride=stride, padding=1, bias=False
         )
-        self.bn1 = nn.LayerNorm(dims)
+        self.bn1 = nn.BatchNorm(dims)
 
         self.conv2 = nn.Conv2d(
             dims, dims, kernel_size=3, stride=1, padding=1, bias=False
         )
-        self.bn2 = nn.LayerNorm(dims)
+        self.bn2 = nn.BatchNorm(dims)
 
         if stride != 1:
             self.shortcut = ShortcutA(dims)
```
```diff
@@ -77,7 +75,7 @@ class ResNet(nn.Module):
     def __init__(self, block, num_blocks, num_classes=10):
         super().__init__()
         self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
-        self.bn1 = nn.LayerNorm(16)
+        self.bn1 = nn.BatchNorm(16)
 
         self.layer1 = self._make_layer(block, 16, 16, num_blocks[0], stride=1)
         self.layer2 = self._make_layer(block, 16, 32, num_blocks[1], stride=2)
```
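One practical consequence of the swap: BatchNorm, unlike LayerNorm, behaves differently at train and test time, so evaluation code must take the model out of training mode. A hedged sketch follows, using MLX's Module.train()/Module.eval() toggles; the `predict` helper and the `model` argument are illustrative, not code from the repository.

```python
# Hedged sketch: BatchNorm keeps running statistics during training, so an
# evaluation loop must switch modes. `model` stands in for a ResNet built by
# the example; this helper is illustrative, not code from the repository.
import mlx.core as mx
import mlx.nn as nn


def predict(model: nn.Module, images: mx.array) -> mx.array:
    model.eval()  # freeze BatchNorm: use running mean/var, not batch stats
    preds = mx.argmax(model(images), axis=-1)
    model.train()  # restore training mode before the next epoch
    return preds
```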