Adds 3D pooling (#1526)

This commit is contained in:
Saanidhya 2024-11-19 19:45:24 -05:00 committed by GitHub
parent 61d787726a
commit cb431dfc9f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 250 additions and 1 deletions

View File

@ -70,7 +70,14 @@ from mlx.nn.layers.normalization import (
LayerNorm, LayerNorm,
RMSNorm, RMSNorm,
) )
from mlx.nn.layers.pooling import AvgPool1d, AvgPool2d, MaxPool1d, MaxPool2d from mlx.nn.layers.pooling import (
AvgPool1d,
AvgPool2d,
AvgPool3d,
MaxPool1d,
MaxPool2d,
MaxPool3d,
)
from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
from mlx.nn.layers.quantized import QuantizedEmbedding, QuantizedLinear, quantize from mlx.nn.layers.quantized import QuantizedEmbedding, QuantizedLinear, quantize
from mlx.nn.layers.recurrent import GRU, LSTM, RNN from mlx.nn.layers.recurrent import GRU, LSTM, RNN

View File

@ -158,6 +158,30 @@ class _Pool2d(_Pool):
super().__init__(pooling_function, kernel_size, stride, padding, padding_value) super().__init__(pooling_function, kernel_size, stride, padding, padding_value)
class _Pool3d(_Pool):
    """Shared base for the 3-dimensional pooling layers.

    Normalizes ``kernel_size``, ``stride`` and ``padding`` to length-3
    sequences through ``_value_or_list`` and forwards them, together with
    the reduction function and the padding fill value, to ``_Pool``.

    Args:
        pooling_function: Reduction forwarded unchanged to ``_Pool``.
        padding_value: Fill value forwarded unchanged to ``_Pool``.
        kernel_size (int or tuple(int, int, int)): The size of the pooling
            window.
        stride (int or tuple(int, int, int), optional): The stride of the
            pooling window. When ``None`` the normalized ``kernel_size`` is
            used instead.
        padding (int or tuple(int, int, int), optional): Per-axis padding
            amount; each amount is expanded to a ``(before, after)`` pair
            with equal values. Default: ``0``.
    """

    def __init__(
        self,
        pooling_function,
        padding_value,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: Optional[Union[int, Tuple[int, int, int]]] = None,
        padding: Optional[Union[int, Tuple[int, int, int]]] = 0,
    ):
        template = "[{}] '{}' must be an integer or a tuple containing 3 integers"
        owner = type(self).__name__

        def _as_triple(value, label):
            # The message names the concrete subclass and the offending
            # argument; presumably _value_or_list reports it on bad input.
            return _value_or_list(value, 3, template.format(owner, label))

        # Arguments are normalized in the order kernel_size, stride, padding.
        kernel_size = _as_triple(kernel_size, "kernel_size")
        # A missing stride reuses the normalized kernel_size object itself.
        stride = kernel_size if stride is None else _as_triple(stride, "stride")
        # Symmetric padding: the same amount on both sides of every axis.
        padding = [(amount, amount) for amount in _as_triple(padding, "padding")]

        super().__init__(pooling_function, kernel_size, stride, padding, padding_value)
class MaxPool1d(_Pool1d): class MaxPool1d(_Pool1d):
r"""Applies 1-dimensional max pooling. r"""Applies 1-dimensional max pooling.
@ -332,3 +356,104 @@ class AvgPool2d(_Pool2d):
padding: Optional[Union[int, Tuple[int, int]]] = 0, padding: Optional[Union[int, Tuple[int, int]]] = 0,
): ):
super().__init__(mx.mean, 0, kernel_size, stride, padding) super().__init__(mx.mean, 0, kernel_size, stride, padding)
class MaxPool3d(_Pool3d):
    r"""Applies 3-dimensional max pooling.

    Assuming an input of shape :math:`(N, D, H, W, C)` and ``kernel_size`` is
    :math:`(k_D, k_H, k_W)`, the output is a tensor of shape :math:`(N, D_{out},
    H_{out}, W_{out}, C)`, given by:

    .. math::
        \begin{aligned}
        \text{out}(N_i, d, h, w, C_j) = & \max_{l=0, \ldots, k_D-1} \max_{m=0, \ldots, k_H-1} \max_{n=0, \ldots, k_W-1} \\
                                        & \text{input}(N_i, \text{stride[0]} \times d + l,
                                                       \text{stride[1]} \times h + m,
                                                       \text{stride[2]} \times w + n, C_j),
        \end{aligned}

    where :math:`D_{out} = \left\lfloor\frac{D + 2 * \text{padding[0]} - \text{kernel\_size[0]}}{\text{stride[0]}}\right\rfloor + 1`,
    :math:`H_{out} = \left\lfloor\frac{H + 2 * \text{padding[1]} - \text{kernel\_size[1]}}{\text{stride[1]}}\right\rfloor + 1`,
    :math:`W_{out} = \left\lfloor\frac{W + 2 * \text{padding[2]} - \text{kernel\_size[2]}}{\text{stride[2]}}\right\rfloor + 1`.

    The parameters ``kernel_size``, ``stride``, ``padding``, can either be:

    - a single ``int`` -- in which case the same value is used for the depth,
      height and width axis;
    - a ``tuple`` of three ``int`` s -- in which case, the first ``int`` is used
      for the depth axis, the second ``int`` for the height axis, and the third
      ``int`` for the width axis.

    Args:
        kernel_size (int or tuple(int, int, int)): The size of the pooling window.
        stride (int or tuple(int, int, int), optional): The stride of the pooling
            window. Default: ``kernel_size``.
        padding (int or tuple(int, int, int), optional): How much negative infinity
            padding to apply to the input. The padding is applied on both sides
            of the depth, height and width axis. Default: ``0``.

    Examples:
        >>> import mlx.core as mx
        >>> import mlx.nn.layers as nn
        >>> x = mx.random.normal(shape=(8, 16, 32, 32, 4))
        >>> pool = nn.MaxPool3d(kernel_size=2, stride=2)
        >>> pool(x)
    """

    def __init__(
        self,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: Optional[Union[int, Tuple[int, int, int]]] = None,
        padding: Optional[Union[int, Tuple[int, int, int]]] = 0,
    ):
        # Padding is filled with -inf so that a padded position can never be
        # selected as the maximum of a window.
        super().__init__(mx.max, -float("inf"), kernel_size, stride, padding)
class AvgPool3d(_Pool3d):
    r"""Applies 3-dimensional average pooling.

    Assuming an input of shape :math:`(N, D, H, W, C)` and ``kernel_size`` is
    :math:`(k_D, k_H, k_W)`, the output is a tensor of shape :math:`(N, D_{out},
    H_{out}, W_{out}, C)`, given by:

    .. math::
        \begin{aligned}
        \text{out}(N_i, d, h, w, C_j) = & \frac{1}{k_D k_H k_W} \sum_{l=0, \ldots, k_D-1} \sum_{m=0, \ldots, k_H-1} \sum_{n=0, \ldots, k_W-1} \\
                                        & \text{input}(N_i, \text{stride[0]} \times d + l,
                                                       \text{stride[1]} \times h + m,
                                                       \text{stride[2]} \times w + n, C_j),
        \end{aligned}

    where :math:`D_{out} = \left\lfloor\frac{D + 2 * \text{padding[0]} - \text{kernel\_size[0]}}{\text{stride[0]}}\right\rfloor + 1`,
    :math:`H_{out} = \left\lfloor\frac{H + 2 * \text{padding[1]} - \text{kernel\_size[1]}}{\text{stride[1]}}\right\rfloor + 1`,
    :math:`W_{out} = \left\lfloor\frac{W + 2 * \text{padding[2]} - \text{kernel\_size[2]}}{\text{stride[2]}}\right\rfloor + 1`.

    The parameters ``kernel_size``, ``stride``, ``padding``, can either be:

    - a single ``int`` -- in which case the same value is used for the depth,
      height and width axis;
    - a ``tuple`` of three ``int`` s -- in which case, the first ``int`` is used
      for the depth axis, the second ``int`` for the height axis, and the third
      ``int`` for the width axis.

    Args:
        kernel_size (int or tuple(int, int, int)): The size of the pooling window.
        stride (int or tuple(int, int, int), optional): The stride of the pooling
            window. Default: ``kernel_size``.
        padding (int or tuple(int, int, int), optional): How much zero
            padding to apply to the input. The padding is applied on both sides
            of the depth, height and width axis. Default: ``0``.

    Examples:
        >>> import mlx.core as mx
        >>> import mlx.nn.layers as nn
        >>> x = mx.random.normal(shape=(8, 16, 32, 32, 4))
        >>> pool = nn.AvgPool3d(kernel_size=2, stride=2)
        >>> pool(x)
    """

    def __init__(
        self,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: Optional[Union[int, Tuple[int, int, int]]] = None,
        padding: Optional[Union[int, Tuple[int, int, int]]] = 0,
    ):
        # Padding is filled with zeros; padded positions are part of the
        # window that mx.mean reduces over.
        super().__init__(mx.mean, 0, kernel_size, stride, padding)

View File

@ -1589,6 +1589,123 @@ class TestLayers(mlx_tests.MLXTestCase):
str(nn.AvgPool2d(kernel_size=(1, 2), stride=2, padding=(1, 2))), str(nn.AvgPool2d(kernel_size=(1, 2), stride=2, padding=(1, 2))),
"AvgPool2d(kernel_size=(1, 2), stride=(2, 2), padding=(1, 2))", "AvgPool2d(kernel_size=(1, 2), stride=(2, 2), padding=(1, 2))",
) )
# Test 3d pooling
x = mx.array(
[
[
[
[[0, 1, 2], [3, 4, 5], [6, 7, 8]],
[[9, 10, 11], [12, 13, 14], [15, 16, 17]],
[[18, 19, 20], [21, 22, 23], [24, 25, 26]],
],
[
[[27, 28, 29], [30, 31, 32], [33, 34, 35]],
[[36, 37, 38], [39, 40, 41], [42, 43, 44]],
[[45, 46, 47], [48, 49, 50], [51, 52, 53]],
],
]
]
)
expected_max_pool_output_no_padding_stride_1 = [
[[[[39, 40, 41], [42, 43, 44]], [[48, 49, 50], [51, 52, 53]]]]
]
expected_max_pool_output_no_padding_stride_2 = [[[[[39, 40, 41]]]]]
expected_max_pool_output_padding_1 = [
[
[[[0, 1, 2], [6, 7, 8]], [[18, 19, 20], [24, 25, 26]]],
[[[27, 28, 29], [33, 34, 35]], [[45, 46, 47], [51, 52, 53]]],
]
]
expected_irregular_max_pool_output = [
[
[[[9, 10, 11], [12, 13, 14], [15, 16, 17]]],
[[[36, 37, 38], [39, 40, 41], [42, 43, 44]]],
]
]
self.assertTrue(
np.array_equal(
nn.MaxPool3d(kernel_size=2, stride=1, padding=0)(x),
expected_max_pool_output_no_padding_stride_1,
)
)
self.assertTrue(
np.array_equal(
nn.MaxPool3d(kernel_size=2, stride=2, padding=0)(x),
expected_max_pool_output_no_padding_stride_2,
)
)
self.assertTrue(
np.array_equal(
nn.MaxPool3d(kernel_size=2, stride=2, padding=1)(x),
expected_max_pool_output_padding_1,
)
)
self.assertTrue(
np.array_equal(
nn.MaxPool3d(kernel_size=(1, 2, 1), stride=(1, 2, 1))(x),
expected_irregular_max_pool_output,
)
)
self.assertEqual(
str(nn.MaxPool3d(kernel_size=3, stride=3, padding=2)),
"MaxPool3d(kernel_size=(3, 3, 3), stride=(3, 3, 3), padding=(2, 2, 2))",
)
expected_avg_pool_output_no_padding_stride_1 = [[[[[19.5, 20.5, 21.5],
[22.5, 23.5, 24.5]],
[[28.5, 29.5, 30.5],
[31.5, 32.5, 33.5]]]]
]
expected_avg_pool_output_no_padding_stride_2 = [[[[[19.5, 20.5, 21.5]]]]]
expected_avg_pool_output_padding_1 = [
[[[[0, 0.125, 0.25],
[1.125, 1.375, 1.625]],
[[3.375, 3.625, 3.875],
[9, 9.5, 10]]],
[[[3.375, 3.5, 3.625],
[7.875, 8.125, 8.375]],
[[10.125, 10.375, 10.625],
[22.5, 23, 23.5]]]]
]
expected_irregular_avg_pool_output = [[[[[4.5, 5.5, 6.5],
[7.5, 8.5, 9.5],
[10.5, 11.5, 12.5]]],
[[[31.5, 32.5, 33.5],
[34.5, 35.5, 36.5],
[37.5, 38.5, 39.5]]]]
]
self.assertTrue(
np.array_equal(
nn.AvgPool3d(kernel_size=2, stride=1, padding=0)(x),
expected_avg_pool_output_no_padding_stride_1,
)
)
self.assertTrue(
np.array_equal(
nn.AvgPool3d(kernel_size=2, stride=2, padding=0)(x),
expected_avg_pool_output_no_padding_stride_2,
)
)
self.assertTrue(
np.array_equal(
nn.AvgPool3d(kernel_size=2, stride=2, padding=1)(x),
expected_avg_pool_output_padding_1,
)
)
self.assertTrue(
np.array_equal(
nn.AvgPool3d(kernel_size=(1, 2, 1), stride=(1, 2, 1))(x),
expected_irregular_avg_pool_output,
)
)
self.assertEqual(
str(nn.AvgPool3d(kernel_size=3, stride=3, padding=2)),
"AvgPool3d(kernel_size=(3, 3, 3), stride=(3, 3, 3), padding=(2, 2, 2))",
)
def test_set_dtype(self): def test_set_dtype(self):
def assert_dtype(layer, dtype): def assert_dtype(layer, dtype):