mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Adds 3D pooling (#1526)
This commit is contained in:
parent
61d787726a
commit
cb431dfc9f
@ -70,7 +70,14 @@ from mlx.nn.layers.normalization import (
|
|||||||
LayerNorm,
|
LayerNorm,
|
||||||
RMSNorm,
|
RMSNorm,
|
||||||
)
|
)
|
||||||
from mlx.nn.layers.pooling import AvgPool1d, AvgPool2d, MaxPool1d, MaxPool2d
|
from mlx.nn.layers.pooling import (
|
||||||
|
AvgPool1d,
|
||||||
|
AvgPool2d,
|
||||||
|
AvgPool3d,
|
||||||
|
MaxPool1d,
|
||||||
|
MaxPool2d,
|
||||||
|
MaxPool3d,
|
||||||
|
)
|
||||||
from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
|
from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
|
||||||
from mlx.nn.layers.quantized import QuantizedEmbedding, QuantizedLinear, quantize
|
from mlx.nn.layers.quantized import QuantizedEmbedding, QuantizedLinear, quantize
|
||||||
from mlx.nn.layers.recurrent import GRU, LSTM, RNN
|
from mlx.nn.layers.recurrent import GRU, LSTM, RNN
|
||||||
|
@ -158,6 +158,30 @@ class _Pool2d(_Pool):
|
|||||||
super().__init__(pooling_function, kernel_size, stride, padding, padding_value)
|
super().__init__(pooling_function, kernel_size, stride, padding, padding_value)
|
||||||
|
|
||||||
|
|
||||||
|
class _Pool3d(_Pool):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
pooling_function,
|
||||||
|
padding_value,
|
||||||
|
kernel_size: Union[int, Tuple[int, int, int]],
|
||||||
|
stride: Optional[Union[int, Tuple[int, int, int]]] = None,
|
||||||
|
padding: Optional[Union[int, Tuple[int, int, int]]] = 0,
|
||||||
|
):
|
||||||
|
class_name = type(self).__name__
|
||||||
|
msg = "[{}] '{}' must be an integer or a tuple containing 3 integers"
|
||||||
|
kernel_size = _value_or_list(
|
||||||
|
kernel_size, 3, msg.format(class_name, "kernel_size")
|
||||||
|
)
|
||||||
|
if stride is not None:
|
||||||
|
stride = _value_or_list(stride, 3, msg.format(class_name, "stride"))
|
||||||
|
else:
|
||||||
|
stride = kernel_size
|
||||||
|
padding = _value_or_list(padding, 3, msg.format(class_name, "padding"))
|
||||||
|
padding = [(p, p) for p in padding]
|
||||||
|
|
||||||
|
super().__init__(pooling_function, kernel_size, stride, padding, padding_value)
|
||||||
|
|
||||||
|
|
||||||
class MaxPool1d(_Pool1d):
|
class MaxPool1d(_Pool1d):
|
||||||
r"""Applies 1-dimensional max pooling.
|
r"""Applies 1-dimensional max pooling.
|
||||||
|
|
||||||
@ -332,3 +356,104 @@ class AvgPool2d(_Pool2d):
|
|||||||
padding: Optional[Union[int, Tuple[int, int]]] = 0,
|
padding: Optional[Union[int, Tuple[int, int]]] = 0,
|
||||||
):
|
):
|
||||||
super().__init__(mx.mean, 0, kernel_size, stride, padding)
|
super().__init__(mx.mean, 0, kernel_size, stride, padding)
|
||||||
|
|
||||||
|
|
||||||
|
class MaxPool3d(_Pool3d):
|
||||||
|
"""
|
||||||
|
Assuming an input of shape :math:`(N, D, H, W, C)` and ``kernel_size`` is
|
||||||
|
:math:`(k_D, k_H, k_W)`, the output is a tensor of shape :math:`(N, D_{out},
|
||||||
|
H_{out}, W_{out}, C)`, given by:
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
\begin{aligned}
|
||||||
|
\text{out}(N_i, d, h, w, C_j) = & \max_{l=0, \ldots, k_D-1} \max_{m=0, \ldots, k_H-1} \max_{n=0, \ldots, k_W-1} \\
|
||||||
|
& \text{input}(N_i, \text{stride[0]} \times d + l,
|
||||||
|
\text{stride[1]} \times h + m,
|
||||||
|
\text{stride[2]} \times w + n, C_j),
|
||||||
|
\end{aligned}
|
||||||
|
|
||||||
|
where :math:`D_{out} = \left\lfloor\frac{D + 2 * \text{padding[0]} - \text{kernel\_size[0]}}{\text{stride[0]}}\right\rfloor + 1`,
|
||||||
|
:math:`H_{out} = \left\lfloor\frac{H + 2 * \text{padding[1]} - \text{kernel\_size[1]}}{\text{stride[1]}}\right\rfloor + 1`,
|
||||||
|
:math:`W_{out} = \left\lfloor\frac{W + 2 * \text{padding[2]} - \text{kernel\_size[2]}}{\text{stride[2]}}\right\rfloor + 1`.
|
||||||
|
|
||||||
|
The parameters ``kernel_size``, ``stride``, ``padding``, can either be:
|
||||||
|
|
||||||
|
- a single ``int`` -- in which case the same value is used for the depth,
|
||||||
|
height and width axis;
|
||||||
|
- a ``tuple`` of three ``int`` s -- in which case, the first ``int`` is used
|
||||||
|
for the depth axis, the second ``int`` for the height axis, and the third
|
||||||
|
``int`` for the width axis.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
kernel_size (int or tuple(int, int, int)): The size of the pooling window.
|
||||||
|
stride (int or tuple(int, int, int), optional): The stride of the pooling
|
||||||
|
window. Default: ``kernel_size``.
|
||||||
|
padding (int or tuple(int, int, int), optional): How much negative infinity
|
||||||
|
padding to apply to the input. The padding is applied on both sides
|
||||||
|
of the depth, height and width axis. Default: ``0``.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> import mlx.core as mx
|
||||||
|
>>> import mlx.nn.layers as nn
|
||||||
|
>>> x = mx.random.normal(shape=(8, 16, 32, 32, 4))
|
||||||
|
>>> pool = nn.MaxPool3d(kernel_size=2, stride=2)
|
||||||
|
>>> pool(x)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
kernel_size: Union[int, Tuple[int, int, int]],
|
||||||
|
stride: Optional[Union[int, Tuple[int, int, int]]] = None,
|
||||||
|
padding: Optional[Union[int, Tuple[int, int, int]]] = 0,
|
||||||
|
):
|
||||||
|
super().__init__(mx.max, -float("inf"), kernel_size, stride, padding)
|
||||||
|
|
||||||
|
|
||||||
|
class AvgPool3d(_Pool3d):
|
||||||
|
"""
|
||||||
|
Assuming an input of shape :math:`(N, D, H, W, C)` and ``kernel_size`` is
|
||||||
|
:math:`(k_D, k_H, k_W)`, the output is a tensor of shape :math:`(N, D_{out},
|
||||||
|
H_{out}, W_{out}, C)`, given by:
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
\begin{aligned}
|
||||||
|
\text{out}(N_i, d, h, w, C_j) = & \frac{1}{k_D k_H k_W} \sum_{l=0, \ldots, k_D-1} \sum_{m=0, \ldots, k_H-1} \sum_{n=0, \ldots, k_W-1} \\
|
||||||
|
& \text{input}(N_i, \text{stride[0]} \times d + l,
|
||||||
|
\text{stride[1]} \times h + m,
|
||||||
|
\text{stride[2]} \times w + n, C_j),
|
||||||
|
\end{aligned}
|
||||||
|
|
||||||
|
where :math:`D_{out} = \left\lfloor\frac{D + 2 * \text{padding[0]} - \text{kernel\_size[0]}}{\text{stride[0]}}\right\rfloor + 1`,
|
||||||
|
:math:`H_{out} = \left\lfloor\frac{H + 2 * \text{padding[1]} - \text{kernel\_size[1]}}{\text{stride[1]}}\right\rfloor + 1`,
|
||||||
|
:math:`W_{out} = \left\lfloor\frac{W + 2 * \text{padding[2]} - \text{kernel\_size[2]}}{\text{stride[2]}}\right\rfloor + 1`.
|
||||||
|
|
||||||
|
The parameters ``kernel_size``, ``stride``, ``padding``, can either be:
|
||||||
|
|
||||||
|
- a single ``int`` -- in which case the same value is used for the depth,
|
||||||
|
height and width axis;
|
||||||
|
- a ``tuple`` of three ``int`` s -- in which case, the first ``int`` is used
|
||||||
|
for the depth axis, the second ``int`` for the height axis, and the third
|
||||||
|
``int`` for the width axis.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
kernel_size (int or tuple(int, int, int)): The size of the pooling window.
|
||||||
|
stride (int or tuple(int, int, int), optional): The stride of the pooling
|
||||||
|
window. Default: ``kernel_size``.
|
||||||
|
padding (int or tuple(int, int, int), optional): How much zero
|
||||||
|
padding to apply to the input. The padding is applied on both sides
|
||||||
|
of the depth, height and width axis. Default: ``0``.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> import mlx.core as mx
|
||||||
|
>>> import mlx.nn.layers as nn
|
||||||
|
>>> x = mx.random.normal(shape=(8, 16, 32, 32, 4))
|
||||||
|
>>> pool = nn.AvgPool3d(kernel_size=2, stride=2)
|
||||||
|
>>> pool(x)
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
kernel_size: Union[int, Tuple[int, int, int]],
|
||||||
|
stride: Optional[Union[int, Tuple[int, int, int]]] = None,
|
||||||
|
padding: Optional[Union[int, Tuple[int, int, int]]] = 0,
|
||||||
|
):
|
||||||
|
super().__init__(mx.mean, 0, kernel_size, stride, padding)
|
||||||
|
@ -1589,6 +1589,123 @@ class TestLayers(mlx_tests.MLXTestCase):
|
|||||||
str(nn.AvgPool2d(kernel_size=(1, 2), stride=2, padding=(1, 2))),
|
str(nn.AvgPool2d(kernel_size=(1, 2), stride=2, padding=(1, 2))),
|
||||||
"AvgPool2d(kernel_size=(1, 2), stride=(2, 2), padding=(1, 2))",
|
"AvgPool2d(kernel_size=(1, 2), stride=(2, 2), padding=(1, 2))",
|
||||||
)
|
)
|
||||||
|
# Test 3d pooling
|
||||||
|
x = mx.array(
|
||||||
|
[
|
||||||
|
[
|
||||||
|
[
|
||||||
|
[[0, 1, 2], [3, 4, 5], [6, 7, 8]],
|
||||||
|
[[9, 10, 11], [12, 13, 14], [15, 16, 17]],
|
||||||
|
[[18, 19, 20], [21, 22, 23], [24, 25, 26]],
|
||||||
|
],
|
||||||
|
[
|
||||||
|
[[27, 28, 29], [30, 31, 32], [33, 34, 35]],
|
||||||
|
[[36, 37, 38], [39, 40, 41], [42, 43, 44]],
|
||||||
|
[[45, 46, 47], [48, 49, 50], [51, 52, 53]],
|
||||||
|
],
|
||||||
|
]
|
||||||
|
]
|
||||||
|
)
|
||||||
|
expected_max_pool_output_no_padding_stride_1 = [
|
||||||
|
[[[[39, 40, 41], [42, 43, 44]], [[48, 49, 50], [51, 52, 53]]]]
|
||||||
|
]
|
||||||
|
|
||||||
|
expected_max_pool_output_no_padding_stride_2 = [[[[[39, 40, 41]]]]]
|
||||||
|
expected_max_pool_output_padding_1 = [
|
||||||
|
[
|
||||||
|
[[[0, 1, 2], [6, 7, 8]], [[18, 19, 20], [24, 25, 26]]],
|
||||||
|
[[[27, 28, 29], [33, 34, 35]], [[45, 46, 47], [51, 52, 53]]],
|
||||||
|
]
|
||||||
|
]
|
||||||
|
expected_irregular_max_pool_output = [
|
||||||
|
[
|
||||||
|
[[[9, 10, 11], [12, 13, 14], [15, 16, 17]]],
|
||||||
|
[[[36, 37, 38], [39, 40, 41], [42, 43, 44]]],
|
||||||
|
]
|
||||||
|
]
|
||||||
|
|
||||||
|
self.assertTrue(
|
||||||
|
np.array_equal(
|
||||||
|
nn.MaxPool3d(kernel_size=2, stride=1, padding=0)(x),
|
||||||
|
expected_max_pool_output_no_padding_stride_1,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertTrue(
|
||||||
|
np.array_equal(
|
||||||
|
nn.MaxPool3d(kernel_size=2, stride=2, padding=0)(x),
|
||||||
|
expected_max_pool_output_no_padding_stride_2,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertTrue(
|
||||||
|
np.array_equal(
|
||||||
|
nn.MaxPool3d(kernel_size=2, stride=2, padding=1)(x),
|
||||||
|
expected_max_pool_output_padding_1,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertTrue(
|
||||||
|
np.array_equal(
|
||||||
|
nn.MaxPool3d(kernel_size=(1, 2, 1), stride=(1, 2, 1))(x),
|
||||||
|
expected_irregular_max_pool_output,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
str(nn.MaxPool3d(kernel_size=3, stride=3, padding=2)),
|
||||||
|
"MaxPool3d(kernel_size=(3, 3, 3), stride=(3, 3, 3), padding=(2, 2, 2))",
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_avg_pool_output_no_padding_stride_1 = [[[[[19.5, 20.5, 21.5],
|
||||||
|
[22.5, 23.5, 24.5]],
|
||||||
|
[[28.5, 29.5, 30.5],
|
||||||
|
[31.5, 32.5, 33.5]]]]
|
||||||
|
]
|
||||||
|
|
||||||
|
expected_avg_pool_output_no_padding_stride_2 = [[[[[19.5, 20.5, 21.5]]]]]
|
||||||
|
expected_avg_pool_output_padding_1 = [
|
||||||
|
[[[[0, 0.125, 0.25],
|
||||||
|
[1.125, 1.375, 1.625]],
|
||||||
|
[[3.375, 3.625, 3.875],
|
||||||
|
[9, 9.5, 10]]],
|
||||||
|
[[[3.375, 3.5, 3.625],
|
||||||
|
[7.875, 8.125, 8.375]],
|
||||||
|
[[10.125, 10.375, 10.625],
|
||||||
|
[22.5, 23, 23.5]]]]
|
||||||
|
]
|
||||||
|
expected_irregular_avg_pool_output = [[[[[4.5, 5.5, 6.5],
|
||||||
|
[7.5, 8.5, 9.5],
|
||||||
|
[10.5, 11.5, 12.5]]],
|
||||||
|
[[[31.5, 32.5, 33.5],
|
||||||
|
[34.5, 35.5, 36.5],
|
||||||
|
[37.5, 38.5, 39.5]]]]
|
||||||
|
]
|
||||||
|
|
||||||
|
self.assertTrue(
|
||||||
|
np.array_equal(
|
||||||
|
nn.AvgPool3d(kernel_size=2, stride=1, padding=0)(x),
|
||||||
|
expected_avg_pool_output_no_padding_stride_1,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertTrue(
|
||||||
|
np.array_equal(
|
||||||
|
nn.AvgPool3d(kernel_size=2, stride=2, padding=0)(x),
|
||||||
|
expected_avg_pool_output_no_padding_stride_2,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertTrue(
|
||||||
|
np.array_equal(
|
||||||
|
nn.AvgPool3d(kernel_size=2, stride=2, padding=1)(x),
|
||||||
|
expected_avg_pool_output_padding_1,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertTrue(
|
||||||
|
np.array_equal(
|
||||||
|
nn.AvgPool3d(kernel_size=(1, 2, 1), stride=(1, 2, 1))(x),
|
||||||
|
expected_irregular_avg_pool_output,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
str(nn.AvgPool3d(kernel_size=3, stride=3, padding=2)),
|
||||||
|
"AvgPool3d(kernel_size=(3, 3, 3), stride=(3, 3, 3), padding=(2, 2, 2))",
|
||||||
|
)
|
||||||
|
|
||||||
def test_set_dtype(self):
|
def test_set_dtype(self):
|
||||||
def assert_dtype(layer, dtype):
|
def assert_dtype(layer, dtype):
|
||||||
|
Loading…
Reference in New Issue
Block a user