"""
|
|
Implementation of ResNets for CIFAR-10 as per the original paper [https://arxiv.org/abs/1512.03385].
|
|
Configurations include ResNet-20, ResNet-32, ResNet-44, ResNet-56, ResNet-110, ResNet-1202.
|
|
"""
|
|
|
|
import mlx.core as mx
|
|
import mlx.nn as nn
|
|
from mlx.utils import tree_flatten
|
|
|
|
__all__ = [
|
|
"ResNet",
|
|
"resnet20",
|
|
"resnet32",
|
|
"resnet44",
|
|
"resnet56",
|
|
"resnet110",
|
|
"resnet1202",
|
|
]
|
|
|
|
|
|
class ShortcutA(nn.Module):
|
|
def __init__(self, dims):
|
|
super().__init__()
|
|
self.dims = dims
|
|
|
|
def __call__(self, x):
|
|
return mx.pad(
|
|
x[:, ::2, ::2, :],
|
|
pad_width=[(0, 0), (0, 0), (0, 0), (self.dims // 4, self.dims // 4)],
|
|
)
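
# An illustrative shape walkthrough of the type-A shortcut above (a sketch; the
# concrete sizes are assumptions, and the layout is NHWC as mlx.nn.Conv2d expects):
#   x: (N, 32, 32, 16) --[x[:, ::2, ::2, :]]--> (N, 16, 16, 16)
#      --[zero-pad channels by dims // 4 = 32 // 4 = 8 per side]--> (N, 16, 16, 32)
# i.e. it downsamples spatially by 2 and doubles the channel count with zeros,
# matching the output shape of a stride-2 block without adding any parameters.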


class Block(nn.Module):
    """
    Implements a ResNet block with two convolutional layers and a skip connection.
    Per the paper, the CIFAR-10 networks use type-A shortcut connections (see the
    paper for details).
    """

    def __init__(self, in_dims, dims, stride=1):
        super().__init__()

        self.conv1 = nn.Conv2d(
            in_dims, dims, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm(dims)

        self.conv2 = nn.Conv2d(
            dims, dims, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm(dims)

        if stride != 1:
            self.shortcut = ShortcutA(dims)
        else:
            self.shortcut = None

    def __call__(self, x):
        out = nn.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.shortcut is None:
            out += x
        else:
            out += self.shortcut(x)
        out = nn.relu(out)
        return out
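
# A quick sanity check of the block arithmetic (a sketch; the batch size of 4 and
# the zero input are assumptions). A stride-2 block halves H and W and widens the
# channels, and ShortcutA reshapes the input so the residual addition lines up:
#   blk = Block(16, 32, stride=2)
#   y = blk(mx.zeros((4, 32, 32, 16)))  # y.shape == (4, 16, 16, 32)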


class ResNet(nn.Module):
    """
    Creates a ResNet model for CIFAR-10, as specified in the original paper.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm(16)

        self.layer1 = self._make_layer(block, 16, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 16, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 32, 64, num_blocks[2], stride=2)

        self.linear = nn.Linear(64, num_classes)

    def _make_layer(self, block, in_dims, dims, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(in_dims, dims, stride))
            in_dims = dims
        return nn.Sequential(*layers)
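
    # For example (hypothetical arguments), _make_layer(Block, 16, 32, num_blocks=3,
    # stride=2) expands to strides [2, 1, 1]: the first block downsamples and widens
    # the features, and the remaining blocks preserve their shape.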

    def num_params(self):
        nparams = sum(x.size for k, x in tree_flatten(self.parameters()))
        return nparams

    def __call__(self, x):
        x = nn.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = mx.mean(x, axis=[1, 2]).reshape(x.shape[0], -1)
        x = self.linear(x)
        return x
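
    # Forward shape trace for a CIFAR-10 input (illustrative; NHWC layout and a
    # batch size of N are assumed):
    #   (N, 32, 32, 3) -> conv1/bn1 -> (N, 32, 32, 16)
    #   -> layer1 -> (N, 32, 32, 16) -> layer2 -> (N, 16, 16, 32)
    #   -> layer3 -> (N, 8, 8, 64)
    #   -> mean over H, W (global average pooling) -> (N, 64)
    #   -> linear -> (N, num_classes)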


def resnet20(**kwargs):
    return ResNet(Block, [3, 3, 3], **kwargs)


def resnet32(**kwargs):
    return ResNet(Block, [5, 5, 5], **kwargs)


def resnet44(**kwargs):
    return ResNet(Block, [7, 7, 7], **kwargs)


def resnet56(**kwargs):
    return ResNet(Block, [9, 9, 9], **kwargs)


def resnet110(**kwargs):
    return ResNet(Block, [18, 18, 18], **kwargs)


def resnet1202(**kwargs):
    return ResNet(Block, [200, 200, 200], **kwargs)
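

# Minimal usage sketch (not part of the original file; the batch size and the
# zero input are assumptions): build ResNet-20 and run a dummy CIFAR-10 batch.
if __name__ == "__main__":
    model = resnet20()
    # The paper reports roughly 0.27M parameters for ResNet-20.
    print(f"resnet20 parameters: {model.num_params() / 1e6:.3f}M")
    x = mx.zeros((1, 32, 32, 3))  # NHWC, as mlx.nn.Conv2d expects
    logits = model(x)
    print(logits.shape)  # (1, 10)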