mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-20 03:48:15 +08:00
Conv3d (#993)
* added conv3d added conv3d implemented explicit_gemm_conv_ND_cpu and bounds checks for slow_conv_3D * incorporated reviewer comments * fixed test * reduced tensor shapes in test for conv3d * Reviewer suggestion Co-authored-by: Awni Hannun <awni.hannun@gmail.com> Reviewer suggestion Co-authored-by: Awni Hannun <awni.hannun@gmail.com> Reviewer suggestion Co-authored-by: Awni Hannun <awni.hannun@gmail.com> Reviewer suggestion
This commit is contained in:

committed by
GitHub

parent
a9f80d60f6
commit
ff4223904d
@@ -48,7 +48,7 @@ from mlx.nn.layers.activations import (
|
||||
)
|
||||
from mlx.nn.layers.base import Module
|
||||
from mlx.nn.layers.containers import Sequential
|
||||
from mlx.nn.layers.convolution import Conv1d, Conv2d
|
||||
from mlx.nn.layers.convolution import Conv1d, Conv2d, Conv3d
|
||||
from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d
|
||||
from mlx.nn.layers.embedding import Embedding
|
||||
from mlx.nn.layers.linear import Bilinear, Identity, Linear
|
||||
|
@@ -132,3 +132,66 @@ class Conv2d(Module):
|
||||
if "bias" in self:
|
||||
y = y + self.bias
|
||||
return y
|
||||
|
||||
|
||||
class Conv3d(Module):
|
||||
"""Applies a 3-dimensional convolution over the multi-channel input image.
|
||||
The channels are expected to be last i.e. the input shape should be ``NDHWC`` where:
|
||||
- ``N`` is the batch dimension
|
||||
- ``D`` is the input image depth
|
||||
- ``H`` is the input image height
|
||||
- ``W`` is the input image width
|
||||
- ``C`` is the number of input channels
|
||||
Args:
|
||||
in_channels (int): The number of input channels.
|
||||
out_channels (int): The number of output channels.
|
||||
kernel_size (int or tuple): The size of the convolution filters.
|
||||
stride (int or tuple, optional): The size of the stride when
|
||||
applying the filter. Default: ``1``.
|
||||
padding (int or tuple, optional): How many positions to 0-pad
|
||||
the input with. Default: ``0``.
|
||||
bias (bool, optional): If ``True`` add a learnable bias to the
|
||||
output. Default: ``True``
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, tuple],
|
||||
stride: Union[int, tuple] = 1,
|
||||
padding: Union[int, tuple] = 0,
|
||||
bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
kernel_size, stride, padding = map(
|
||||
lambda x: (x, x, x) if isinstance(x, int) else x,
|
||||
(kernel_size, stride, padding),
|
||||
)
|
||||
scale = math.sqrt(
|
||||
1 / (in_channels * kernel_size[0] * kernel_size[1] * kernel_size[2])
|
||||
)
|
||||
self.weight = mx.random.uniform(
|
||||
low=-scale,
|
||||
high=scale,
|
||||
shape=(out_channels, *kernel_size, in_channels),
|
||||
)
|
||||
if bias:
|
||||
self.bias = mx.zeros((out_channels,))
|
||||
|
||||
self.padding = padding
|
||||
self.stride = stride
|
||||
|
||||
def _extra_repr(self):
|
||||
return (
|
||||
f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
|
||||
f"kernel_size={self.weight.shape[1:3]}, stride={self.stride}, "
|
||||
f"padding={self.padding}, bias={'bias' in self}"
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
y = mx.conv3d(x, self.weight, self.stride, self.padding)
|
||||
if "bias" in self:
|
||||
y = y + self.bias
|
||||
return y
|
||||
|
Reference in New Issue
Block a user