* added conv3d

added conv3d

implemented explicit_gemm_conv_ND_cpu and bounds checks for slow_conv_3D

* incorporated reviewer comments

* fixed test

* reduced tensor shapes in test for conv3d

* Reviewer suggestion

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

Reviewer suggestion

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

Reviewer suggestion

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

Reviewer suggestion
This commit is contained in:
Max-Heinrich Laves
2024-05-11 15:15:02 +02:00
committed by GitHub
parent a9f80d60f6
commit ff4223904d
10 changed files with 951 additions and 13 deletions

View File

@@ -48,7 +48,7 @@ from mlx.nn.layers.activations import (
)
from mlx.nn.layers.base import Module
from mlx.nn.layers.containers import Sequential
from mlx.nn.layers.convolution import Conv1d, Conv2d
from mlx.nn.layers.convolution import Conv1d, Conv2d, Conv3d
from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d
from mlx.nn.layers.embedding import Embedding
from mlx.nn.layers.linear import Bilinear, Identity, Linear

View File

@@ -132,3 +132,66 @@ class Conv2d(Module):
if "bias" in self:
y = y + self.bias
return y
class Conv3d(Module):
"""Applies a 3-dimensional convolution over the multi-channel input image.
The channels are expected to be last i.e. the input shape should be ``NDHWC`` where:
- ``N`` is the batch dimension
- ``D`` is the input image depth
- ``H`` is the input image height
- ``W`` is the input image width
- ``C`` is the number of input channels
Args:
in_channels (int): The number of input channels.
out_channels (int): The number of output channels.
kernel_size (int or tuple): The size of the convolution filters.
stride (int or tuple, optional): The size of the stride when
applying the filter. Default: ``1``.
padding (int or tuple, optional): How many positions to 0-pad
the input with. Default: ``0``.
bias (bool, optional): If ``True`` add a learnable bias to the
output. Default: ``True``
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, tuple],
stride: Union[int, tuple] = 1,
padding: Union[int, tuple] = 0,
bias: bool = True,
):
super().__init__()
kernel_size, stride, padding = map(
lambda x: (x, x, x) if isinstance(x, int) else x,
(kernel_size, stride, padding),
)
scale = math.sqrt(
1 / (in_channels * kernel_size[0] * kernel_size[1] * kernel_size[2])
)
self.weight = mx.random.uniform(
low=-scale,
high=scale,
shape=(out_channels, *kernel_size, in_channels),
)
if bias:
self.bias = mx.zeros((out_channels,))
self.padding = padding
self.stride = stride
def _extra_repr(self):
return (
f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
f"kernel_size={self.weight.shape[1:3]}, stride={self.stride}, "
f"padding={self.padding}, bias={'bias' in self}"
)
def __call__(self, x):
y = mx.conv3d(x, self.weight, self.stride, self.padding)
if "bias" in self:
y = y + self.bias
return y