fixed doc for ResNet

This commit is contained in:
Sarthak Yadav 2023-12-12 19:07:39 +01:00
parent f37e777243
commit 2439333a57

View File

@ -39,6 +39,10 @@ class ShortcutA(nn.Module):
class Block(nn.Module): class Block(nn.Module):
expansion = 1 expansion = 1
"""
Implements a ResNet block with two convolutional layers and a skip connection.
As per the paper, CIFAR-10 uses Shortcut type-A skip connections. (See paper for details)
"""
def __init__(self, in_dims, dims, stride=1): def __init__(self, in_dims, dims, stride=1):
super().__init__() super().__init__()
@ -71,6 +75,10 @@ class Block(nn.Module):
class ResNet(nn.Module): class ResNet(nn.Module):
"""
Creates a ResNet model for CIFAR-10, as specified in the original paper.
"""
def __init__(self, block, num_blocks, num_classes=10): def __init__(self, block, num_blocks, num_classes=10):
super().__init__() super().__init__()
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)