diff --git a/cifar/resnet.py b/cifar/resnet.py index 3d88397b..b89a612b 100644 --- a/cifar/resnet.py +++ b/cifar/resnet.py @@ -39,6 +39,10 @@ class ShortcutA(nn.Module): class Block(nn.Module): expansion = 1 + """ + Implements a ResNet block with two convolutional layers and a skip connection. + As per the paper, CIFAR-10 uses Shortcut type-A skip connections. (See paper for details) + """ def __init__(self, in_dims, dims, stride=1): super().__init__() @@ -71,6 +75,10 @@ class Block(nn.Module): class ResNet(nn.Module): + """ + Creates a ResNet model for CIFAR-10, as specified in the original paper. + """ + def __init__(self, block, num_blocks, num_classes=10): super().__init__() self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)