mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-25 01:41:19 +08:00
fixed doc for ResNet
This commit is contained in:
parent
f37e777243
commit
2439333a57
@ -39,6 +39,10 @@ class ShortcutA(nn.Module):
|
|||||||
|
|
||||||
class Block(nn.Module):
|
class Block(nn.Module):
|
||||||
expansion = 1
|
expansion = 1
|
||||||
|
"""
|
||||||
|
Implements a ResNet block with two convolutional layers and a skip connection.
|
||||||
|
As per the paper, CIFAR-10 uses Shortcut type-A skip connections. (See paper for details)
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, in_dims, dims, stride=1):
|
def __init__(self, in_dims, dims, stride=1):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -71,6 +75,10 @@ class Block(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class ResNet(nn.Module):
|
class ResNet(nn.Module):
|
||||||
|
"""
|
||||||
|
Creates a ResNet model for CIFAR-10, as specified in the original paper.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, block, num_blocks, num_classes=10):
|
def __init__(self, block, num_blocks, num_classes=10):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
|
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
|
||||||
|
Loading…
Reference in New Issue
Block a user