Adds 3D pooling (#1526)

This commit is contained in:
Saanidhya
2024-11-19 19:45:24 -05:00
committed by GitHub
parent 61d787726a
commit cb431dfc9f
3 changed files with 250 additions and 1 deletions

View File

@@ -70,7 +70,14 @@ from mlx.nn.layers.normalization import (
LayerNorm,
RMSNorm,
)
from mlx.nn.layers.pooling import AvgPool1d, AvgPool2d, MaxPool1d, MaxPool2d
from mlx.nn.layers.pooling import (
AvgPool1d,
AvgPool2d,
AvgPool3d,
MaxPool1d,
MaxPool2d,
MaxPool3d,
)
from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
from mlx.nn.layers.quantized import QuantizedEmbedding, QuantizedLinear, quantize
from mlx.nn.layers.recurrent import GRU, LSTM, RNN

View File

@@ -158,6 +158,30 @@ class _Pool2d(_Pool):
super().__init__(pooling_function, kernel_size, stride, padding, padding_value)
class _Pool3d(_Pool):
def __init__(
self,
pooling_function,
padding_value,
kernel_size: Union[int, Tuple[int, int, int]],
stride: Optional[Union[int, Tuple[int, int, int]]] = None,
padding: Optional[Union[int, Tuple[int, int, int]]] = 0,
):
class_name = type(self).__name__
msg = "[{}] '{}' must be an integer or a tuple containing 3 integers"
kernel_size = _value_or_list(
kernel_size, 3, msg.format(class_name, "kernel_size")
)
if stride is not None:
stride = _value_or_list(stride, 3, msg.format(class_name, "stride"))
else:
stride = kernel_size
padding = _value_or_list(padding, 3, msg.format(class_name, "padding"))
padding = [(p, p) for p in padding]
super().__init__(pooling_function, kernel_size, stride, padding, padding_value)
class MaxPool1d(_Pool1d):
r"""Applies 1-dimensional max pooling.
@@ -332,3 +356,104 @@ class AvgPool2d(_Pool2d):
padding: Optional[Union[int, Tuple[int, int]]] = 0,
):
super().__init__(mx.mean, 0, kernel_size, stride, padding)
class MaxPool3d(_Pool3d):
"""
Assuming an input of shape :math:`(N, D, H, W, C)` and ``kernel_size`` is
:math:`(k_D, k_H, k_W)`, the output is a tensor of shape :math:`(N, D_{out},
H_{out}, W_{out}, C)`, given by:
.. math::
\begin{aligned}
\text{out}(N_i, d, h, w, C_j) = & \max_{l=0, \ldots, k_D-1} \max_{m=0, \ldots, k_H-1} \max_{n=0, \ldots, k_W-1} \\
& \text{input}(N_i, \text{stride[0]} \times d + l,
\text{stride[1]} \times h + m,
\text{stride[2]} \times w + n, C_j),
\end{aligned}
where :math:`D_{out} = \left\lfloor\frac{D + 2 * \text{padding[0]} - \text{kernel\_size[0]}}{\text{stride[0]}}\right\rfloor + 1`,
:math:`H_{out} = \left\lfloor\frac{H + 2 * \text{padding[1]} - \text{kernel\_size[1]}}{\text{stride[1]}}\right\rfloor + 1`,
:math:`W_{out} = \left\lfloor\frac{W + 2 * \text{padding[2]} - \text{kernel\_size[2]}}{\text{stride[2]}}\right\rfloor + 1`.
The parameters ``kernel_size``, ``stride``, ``padding``, can either be:
- a single ``int`` -- in which case the same value is used for the depth,
height and width axis;
- a ``tuple`` of three ``int`` s -- in which case, the first ``int`` is used
for the depth axis, the second ``int`` for the height axis, and the third
``int`` for the width axis.
Args:
kernel_size (int or tuple(int, int, int)): The size of the pooling window.
stride (int or tuple(int, int, int), optional): The stride of the pooling
window. Default: ``kernel_size``.
padding (int or tuple(int, int, int), optional): How much negative infinity
padding to apply to the input. The padding is applied on both sides
of the depth, height and width axis. Default: ``0``.
Examples:
>>> import mlx.core as mx
>>> import mlx.nn.layers as nn
>>> x = mx.random.normal(shape=(8, 16, 32, 32, 4))
>>> pool = nn.MaxPool3d(kernel_size=2, stride=2)
>>> pool(x)
"""
def __init__(
self,
kernel_size: Union[int, Tuple[int, int, int]],
stride: Optional[Union[int, Tuple[int, int, int]]] = None,
padding: Optional[Union[int, Tuple[int, int, int]]] = 0,
):
super().__init__(mx.max, -float("inf"), kernel_size, stride, padding)
class AvgPool3d(_Pool3d):
"""
Assuming an input of shape :math:`(N, D, H, W, C)` and ``kernel_size`` is
:math:`(k_D, k_H, k_W)`, the output is a tensor of shape :math:`(N, D_{out},
H_{out}, W_{out}, C)`, given by:
.. math::
\begin{aligned}
\text{out}(N_i, d, h, w, C_j) = & \frac{1}{k_D k_H k_W} \sum_{l=0, \ldots, k_D-1} \sum_{m=0, \ldots, k_H-1} \sum_{n=0, \ldots, k_W-1} \\
& \text{input}(N_i, \text{stride[0]} \times d + l,
\text{stride[1]} \times h + m,
\text{stride[2]} \times w + n, C_j),
\end{aligned}
where :math:`D_{out} = \left\lfloor\frac{D + 2 * \text{padding[0]} - \text{kernel\_size[0]}}{\text{stride[0]}}\right\rfloor + 1`,
:math:`H_{out} = \left\lfloor\frac{H + 2 * \text{padding[1]} - \text{kernel\_size[1]}}{\text{stride[1]}}\right\rfloor + 1`,
:math:`W_{out} = \left\lfloor\frac{W + 2 * \text{padding[2]} - \text{kernel\_size[2]}}{\text{stride[2]}}\right\rfloor + 1`.
The parameters ``kernel_size``, ``stride``, ``padding``, can either be:
- a single ``int`` -- in which case the same value is used for the depth,
height and width axis;
- a ``tuple`` of three ``int`` s -- in which case, the first ``int`` is used
for the depth axis, the second ``int`` for the height axis, and the third
``int`` for the width axis.
Args:
kernel_size (int or tuple(int, int, int)): The size of the pooling window.
stride (int or tuple(int, int, int), optional): The stride of the pooling
window. Default: ``kernel_size``.
padding (int or tuple(int, int, int), optional): How much zero
padding to apply to the input. The padding is applied on both sides
of the depth, height and width axis. Default: ``0``.
Examples:
>>> import mlx.core as mx
>>> import mlx.nn.layers as nn
>>> x = mx.random.normal(shape=(8, 16, 32, 32, 4))
>>> pool = nn.AvgPool3d(kernel_size=2, stride=2)
>>> pool(x)
"""
def __init__(
self,
kernel_size: Union[int, Tuple[int, int, int]],
stride: Optional[Union[int, Tuple[int, int, int]]] = None,
padding: Optional[Union[int, Tuple[int, int, int]]] = 0,
):
super().__init__(mx.mean, 0, kernel_size, stride, padding)