mlx-examples/cifar/resnet.py
dmdaksh 7d7e236061
- Removed unused Python imports (#683)
  - bert/model.py:10: tree_unflatten
  - bert/model.py:2: dataclass
  - bert/model.py:8: numpy
  - cifar/resnet.py:6: Any
  - clip/model.py:15: tree_flatten
  - clip/model.py:9: Union
  - gcn/main.py:8: download_cora
  - gcn/main.py:9: cross_entropy
  - llms/gguf_llm/models.py:12: tree_flatten, tree_unflatten
  - llms/gguf_llm/models.py:9: numpy
  - llms/mixtral/mixtral.py:12: tree_map
  - llms/mlx_lm/models/dbrx.py:2: Dict, Union
  - llms/mlx_lm/tuner/trainer.py:5: partial
  - llms/speculative_decoding/decoder.py:1: dataclass, field
  - llms/speculative_decoding/decoder.py:2: Optional
  - llms/speculative_decoding/decoder.py:5: mlx.nn
  - llms/speculative_decoding/decoder.py:6: numpy
  - llms/speculative_decoding/main.py:2: glob
  - llms/speculative_decoding/main.py:3: json
  - llms/speculative_decoding/main.py:5: Path
  - llms/speculative_decoding/main.py:8: mlx.nn
  - llms/speculative_decoding/model.py:6: tree_unflatten
  - llms/speculative_decoding/model.py:7: AutoTokenizer
  - llms/tests/test_lora.py:13: yaml_loader
  - lora/lora.py:14: tree_unflatten
  - lora/models.py:11: numpy
  - lora/models.py:3: glob
  - speechcommands/kwt.py:1: Any
  - speechcommands/main.py:7: mlx.data
  - stable_diffusion/stable_diffusion/model_io.py:4: partial
  - whisper/benchmark.py:5: sys
  - whisper/test.py:5: subprocess
  - whisper/whisper/audio.py:6: Optional
  - whisper/whisper/decoding.py:8: mlx.nn
2024-04-16 07:50:32 -07:00

128 lines
3.3 KiB
Python

"""
Implementation of ResNets for CIFAR-10 as per the original paper [https://arxiv.org/abs/1512.03385].
Configurations include ResNet-20, ResNet-32, ResNet-44, ResNet-56, ResNet-110, ResNet-1202.
"""
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten
__all__ = [
"ResNet",
"resnet20",
"resnet32",
"resnet44",
"resnet56",
"resnet110",
"resnet1202",
]
class ShortcutA(nn.Module):
    """
    Parameter-free "option A" shortcut from the ResNet paper.

    The input is subsampled spatially by keeping every ``stride``-th element
    along axes 1 and 2. The last axis (channels) is then zero-padded so the
    result has ``dims`` channels. No weights are introduced.

    Args:
        dims: Number of output channels.
        in_dims: Number of input channels. When ``None`` (the original
            behaviour), ``dims // 4`` zero channels are added on each side.
            That is intended for a channel-doubling transition such as
            16 -> 32 or 32 -> 64.
        stride: Spatial subsampling factor. It should equal the stride of the
            residual branch's first convolution so the two outputs line up.

    Raises:
        ValueError: If ``in_dims`` is larger than ``dims`` (padding cannot
            remove channels).
    """

    def __init__(self, dims, in_dims=None, stride=2):
        super().__init__()
        self.dims = dims
        self.stride = stride
        if in_dims is None:
            # Legacy behaviour: symmetric padding of dims // 4 per side.
            self.pad_before = dims // 4
            self.pad_after = dims // 4
        else:
            extra = dims - in_dims
            if extra < 0:
                raise ValueError(
                    f"ShortcutA cannot reduce channels ({in_dims} -> {dims})"
                )
            # Split the extra channels as evenly as possible, so an odd
            # difference still yields exactly `dims` output channels.
            self.pad_before = extra // 2
            self.pad_after = extra - self.pad_before

    def __call__(self, x):
        s = self.stride
        return mx.pad(
            x[:, ::s, ::s, :],
            pad_width=[(0, 0), (0, 0), (0, 0), (self.pad_before, self.pad_after)],
        )


class Block(nn.Module):
    """
    Implements a ResNet block with two convolutional layers and a skip connection.
    As per the paper, CIFAR-10 uses Shortcut type-A skip connections. (See paper for details)

    Args:
        in_dims: Number of input channels.
        dims: Number of output channels of both convolutions.
        stride: Stride of the first convolution (1 keeps the spatial size).
    """

    def __init__(self, in_dims, dims, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_dims, dims, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm(dims)
        self.conv2 = nn.Conv2d(
            dims, dims, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm(dims)
        # An identity skip is only shape-compatible when the residual branch
        # keeps both the spatial size and the channel count. Otherwise, match
        # shapes with the parameter-free type-A shortcut, sized from the real
        # input/output channels and stride.
        if stride != 1 or in_dims != dims:
            self.shortcut = ShortcutA(dims, in_dims=in_dims, stride=stride)
        else:
            self.shortcut = None

    def __call__(self, x):
        out = nn.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        residual = x if self.shortcut is None else self.shortcut(x)
        return nn.relu(out + residual)
class ResNet(nn.Module):
    """
    ResNet for CIFAR-10, laid out as in the original paper.

    Structure:
        - a 3x3, 16-channel stem convolution with batch norm and ReLU;
        - three stages of residual blocks with 16, 32 and 64 channels, where
          the second and third stages open with a stride-2 block;
        - a mean over axes 1 and 2;
        - a linear classifier.

    Args:
        block: Callable invoked as ``block(in_dims, dims, stride)`` that
            builds one residual block.
        num_blocks: Indexable of three ints, the block count of each stage.
        num_classes: Number of output logits (default 10).
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm(16)
        # Arguments per stage: (input channels, output channels, block count, stride).
        self.layer1 = self._make_layer(block, 16, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 16, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 32, 64, num_blocks[2], stride=2)
        self.linear = nn.Linear(64, num_classes)

    def _make_layer(self, block, in_dims, dims, num_blocks, stride):
        # Only the first block of a stage changes channels/resolution; the
        # rest map dims -> dims with stride 1. At least one block is always
        # built, even if num_blocks < 1.
        stage = [block(in_dims, dims, stride)]
        stage.extend(block(dims, dims, 1) for _ in range(num_blocks - 1))
        return nn.Sequential(*stage)

    def num_params(self):
        """Return the total element count of all parameters of the model."""
        return sum(leaf.size for _, leaf in tree_flatten(self.parameters()))

    def __call__(self, x):
        h = nn.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3):
            h = stage(h)
        batch = h.shape[0]
        pooled = mx.mean(h, axis=[1, 2]).reshape(batch, -1)
        return self.linear(pooled)
def _cifar_resnet(n, **kwargs):
    # Three stages of `n` Blocks each. Each Block has two convolutions, and
    # adding the stem convolution and the linear classifier gives 6n + 2
    # weighted layers.
    return ResNet(Block, [n, n, n], **kwargs)


def resnet20(**kwargs):
    """Build ResNet-20 (3 blocks per stage); kwargs go to ``ResNet``."""
    return _cifar_resnet(3, **kwargs)


def resnet32(**kwargs):
    """Build ResNet-32 (5 blocks per stage); kwargs go to ``ResNet``."""
    return _cifar_resnet(5, **kwargs)


def resnet44(**kwargs):
    """Build ResNet-44 (7 blocks per stage); kwargs go to ``ResNet``."""
    return _cifar_resnet(7, **kwargs)


def resnet56(**kwargs):
    """Build ResNet-56 (9 blocks per stage); kwargs go to ``ResNet``."""
    return _cifar_resnet(9, **kwargs)


def resnet110(**kwargs):
    """Build ResNet-110 (18 blocks per stage); kwargs go to ``ResNet``."""
    return _cifar_resnet(18, **kwargs)


def resnet1202(**kwargs):
    """Build ResNet-1202 (200 blocks per stage); kwargs go to ``ResNet``."""
    return _cifar_resnet(200, **kwargs)