mirror of https://github.com/ml-explore/mlx-examples.git
simplified ResNet, expanded README with throughput and performance
@@ -38,7 +38,6 @@ class ShortcutA(nn.Module):
 
 
 class Block(nn.Module):
-    expansion = 1
     """
     Implements a ResNet block with two convolutional layers and a skip connection.
     As per the paper, CIFAR-10 uses Shortcut type-A skip connections. (See paper for details)
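Note: `expansion` is a holdover from bottleneck ResNets; in this CIFAR-style network it is always 1, and its only use (`dims * block.expansion` in `_make_layer`, updated in the last hunk below) reduces to plain `dims`, so the class attribute can be dropped.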
@@ -57,7 +56,7 @@ class Block(nn.Module):
         )
         self.bn2 = nn.LayerNorm(dims)
 
-        if stride != 1 or in_dims != dims:
+        if stride != 1:
             self.shortcut = ShortcutA(dims)
         else:
             self.shortcut = None
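The tighter condition is equivalent here: in this network a block changes width (in_dims != dims) only at the stride-2 block that opens layer2 and layer3, so `stride != 1` already covers both cases. For orientation, a type-A shortcut in He et al.'s (2015) sense is a parameter-free identity that subsamples spatially and zero-pads the new channels. The sketch below is illustrative only (the function name `shortcut_a` is hypothetical, and the repo's actual ShortcutA may differ in padding details):

    # Illustrative type-A shortcut (He et al., 2015), NOT the repo's exact ShortcutA:
    # a parameter-free identity that halves H and W and zero-pads the channel axis.
    import mlx.core as mx

    def shortcut_a(x, dims):
        # MLX convolutions use NHWC layout; keep every other row and column.
        x = x[:, ::2, ::2, :]
        # Zero-pad channels up to `dims` (assumes dims >= current channel count).
        extra = dims - x.shape[-1]
        return mx.pad(x, [(0, 0), (0, 0), (0, 0), (extra // 2, extra - extra // 2)])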
@@ -83,20 +82,19 @@ class ResNet(nn.Module):
         super().__init__()
         self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
         self.bn1 = nn.LayerNorm(16)
-        self.in_dims = 16
 
-        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
-        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
-        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
+        self.layer1 = self._make_layer(block, 16, 16, num_blocks[0], stride=1)
+        self.layer2 = self._make_layer(block, 16, 32, num_blocks[1], stride=2)
+        self.layer3 = self._make_layer(block, 32, 64, num_blocks[2], stride=2)
 
         self.linear = nn.Linear(64, num_classes)
 
-    def _make_layer(self, block, dims, num_blocks, stride):
+    def _make_layer(self, block, in_dims, dims, num_blocks, stride):
         strides = [stride] + [1] * (num_blocks - 1)
         layers = []
         for stride in strides:
-            layers.append(block(self.in_dims, dims, stride))
-            self.in_dims = dims * block.expansion
+            layers.append(block(in_dims, dims, stride))
+            in_dims = dims
         return nn.Sequential(*layers)
 
     def num_params(self):
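The net effect: `_make_layer` becomes a pure function of its arguments instead of mutating `self.in_dims` across calls, and each stage's input width is now visible at the call site. A minimal standalone sketch of the new wiring, using a hypothetical StubBlock in place of the file's Block so it runs on its own (and assuming mlx.nn.Sequential's `.layers` attribute for inspection):

    import mlx.nn as nn

    class StubBlock(nn.Module):
        """Hypothetical stand-in for Block: records its construction args."""

        def __init__(self, in_dims, dims, stride):
            super().__init__()
            self.wiring = (in_dims, dims, stride)

        def __call__(self, x):
            return x

    def make_layer(block, in_dims, dims, num_blocks, stride):
        # Only the first block may downsample (stride) and widen (in_dims -> dims);
        # the remaining blocks run at stride 1 with dims -> dims.
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(in_dims, dims, stride))
            in_dims = dims  # thread the width forward locally, no shared state
        return nn.Sequential(*layers)

    stage = make_layer(StubBlock, 16, 32, num_blocks=3, stride=2)
    print([b.wiring for b in stage.layers])  # [(16, 32, 2), (32, 32, 1), (32, 32, 1)]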