simplified ResNet, expanded README with throughput and performance

This commit is contained in:
Sarthak Yadav
2023-12-14 09:05:04 +01:00
parent 2439333a57
commit 15a6c155a8
5 changed files with 56 additions and 36 deletions

View File

@@ -38,7 +38,6 @@ class ShortcutA(nn.Module):
class Block(nn.Module):
expansion = 1
"""
Implements a ResNet block with two convolutional layers and a skip connection.
As per the paper, CIFAR-10 uses Shortcut type-A skip connections. (See paper for details)
@@ -57,7 +56,7 @@ class Block(nn.Module):
)
self.bn2 = nn.LayerNorm(dims)
if stride != 1 or in_dims != dims:
if stride != 1:
self.shortcut = ShortcutA(dims)
else:
self.shortcut = None
@@ -83,20 +82,19 @@ class ResNet(nn.Module):
super().__init__()
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.LayerNorm(16)
self.in_dims = 16
self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
self.layer1 = self._make_layer(block, 16, 16, num_blocks[0], stride=1)
self.layer2 = self._make_layer(block, 16, 32, num_blocks[1], stride=2)
self.layer3 = self._make_layer(block, 32, 64, num_blocks[2], stride=2)
self.linear = nn.Linear(64, num_classes)
def _make_layer(self, block, dims, num_blocks, stride):
def _make_layer(self, block, in_dims, dims, num_blocks, stride):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
layers.append(block(self.in_dims, dims, stride))
self.in_dims = dims * block.expansion
layers.append(block(in_dims, dims, stride))
in_dims = dims
return nn.Sequential(*layers)
def num_params(self):