mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-17 22:51:19 +08:00
Pooling layers (#357)
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com> Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
40c108766b
commit
e54cbb7ba6
@ -11,7 +11,7 @@ MLX was developed with contributions from the following individuals:
|
||||
- Juarez Bochi: Fixed bug in cross attention.
|
||||
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
|
||||
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile` and safetensor support
|
||||
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer.
|
||||
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented ``MaxPool1d``, ``MaxPool2d``, ``AvgPool1d``, ``AvgPool2d``.
|
||||
|
||||
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
|
||||
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
|
||||
|
@ -10,6 +10,8 @@ Layers
|
||||
:template: nn-module-template.rst
|
||||
|
||||
ALiBi
|
||||
AvgPool1d
|
||||
AvgPool2d
|
||||
BatchNorm
|
||||
Conv1d
|
||||
Conv2d
|
||||
@ -22,6 +24,8 @@ Layers
|
||||
InstanceNorm
|
||||
LayerNorm
|
||||
Linear
|
||||
MaxPool1d
|
||||
MaxPool2d
|
||||
Mish
|
||||
MultiHeadAttention
|
||||
PReLU
|
||||
|
@ -58,6 +58,7 @@ from mlx.nn.layers.normalization import (
|
||||
LayerNorm,
|
||||
RMSNorm,
|
||||
)
|
||||
from mlx.nn.layers.pooling import AvgPool1d, AvgPool2d, MaxPool1d, MaxPool2d
|
||||
from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
|
||||
from mlx.nn.layers.quantized import QuantizedLinear
|
||||
from mlx.nn.layers.transformer import (
|
||||
|
308
python/mlx/nn/layers/pooling.py
Normal file
308
python/mlx/nn/layers/pooling.py
Normal file
@ -0,0 +1,308 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import operator
|
||||
from itertools import accumulate
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx.nn.layers.base import Module
|
||||
|
||||
|
||||
def _value_or_list(x, n, msg):
    """Normalize a per-axis pooling argument to a list of ``n`` integers.

    Args:
        x: Either a single ``int``, which is repeated for every axis, or a
            ``list``/``tuple`` holding exactly ``n`` ``int`` values.
        n (int): The number of entries the returned list must have.
        msg (str): The message of the ``ValueError`` raised on invalid input.

    Returns:
        list: A new list containing ``n`` integers.

    Raises:
        ValueError: If ``x`` is a sequence with the wrong length or with a
            non-integer element, or if it is neither a sequence nor an
            ``int``.
    """
    if isinstance(x, (list, tuple)):
        # Check the elements as well as the length so that inputs such as
        # ``(2.5,)`` or ``("a", 1)`` are rejected here with ``msg`` instead of
        # failing later with an unrelated error.
        if len(x) != n or not all(isinstance(v, int) for v in x):
            raise ValueError(msg)
        return list(x)

    if not isinstance(x, int):
        raise ValueError(msg)

    return [x] * n
|
||||
|
||||
|
||||
def _sliding_windows(x, window_shape, window_strides):
    """Expose every pooling window of ``x`` through ``mx.as_strided``.

    The first and last axes of ``x`` are kept as they are and all axes in
    between are treated as spatial axes. The result has the shape
    ``(x.shape[0], *num_windows, *window_shape, x.shape[-1])`` where
    ``num_windows[i] = (size_i - window_i) // stride_i + 1``.

    Args:
        x: The input array. It needs at least 3 dimensions.
        window_shape: The window size along each spatial axis.
        window_strides: The step between windows along each spatial axis.

    Returns:
        The array returned by ``mx.as_strided`` for the computed shape and
        strides.

    Raises:
        ValueError: If ``x`` has fewer than 3 dimensions or if the number of
            window sizes or strides differs from the number of spatial axes.
    """
    if x.ndim < 3:
        raise ValueError(
            f"To extract sliding windows at least 1 spatial dimension "
            f"(3 total) is needed but the input only has {x.ndim} dimensions."
        )

    in_shape = x.shape
    spatial_dims = in_shape[1:-1]
    num_spatial = len(spatial_dims)
    if len(window_shape) != num_spatial or len(window_strides) != num_spatial:
        raise ValueError(
            f"To extract sliding windows the window shapes and strides must have "
            f"the same number of spatial dimensions as the signal but the signal "
            f"has {len(spatial_dims)} dims and the window shape has {len(window_shape)} "
            f"and strides have {len(window_strides)}."
        )

    # Row-major element strides of the input, filled in from the last axis
    # (stride 1) towards the first.
    elem_strides = [1] * len(in_shape)
    for axis in range(len(in_shape) - 2, -1, -1):
        elem_strides[axis] = elem_strides[axis + 1] * in_shape[axis + 1]
    inner_strides = elem_strides[1:-1]

    # Number of windows per spatial axis and the stride between the starts of
    # two consecutive windows.
    num_windows = []
    hop_strides = []
    for size, window, step, base in zip(
        spatial_dims, window_shape, window_strides, inner_strides
    ):
        num_windows.append((size - window) // step + 1)
        hop_strides.append(base * step)

    final_shape = [in_shape[0], *num_windows, *window_shape, in_shape[-1]]
    final_strides = [
        elem_strides[0],
        *hop_strides,
        *inner_strides,  # moving inside a window uses the input's own strides
        elem_strides[-1],  # should always be 1
    ]

    return mx.as_strided(x, final_shape, final_strides)
|
||||
|
||||
|
||||
class _Pool(Module):
    """Shared implementation of the window based pooling layers.

    Calling the layer optionally pads the spatial axes of the input, extracts
    the pooling windows with ``_sliding_windows`` and reduces the window axes
    with ``pooling_function``.

    Args:
        pooling_function: A callable invoked as ``pooling_function(x, axes)``
            that reduces the window axes.
        kernel_size: The window size for each spatial axis.
        stride: The window stride for each spatial axis.
        padding: One ``(before, after)`` pair per spatial axis.
        padding_value: The constant used to fill the padded positions.
    """

    def __init__(self, pooling_function, kernel_size, stride, padding, padding_value):
        super().__init__()

        num_spatial = len(kernel_size)

        self._pooling_function = pooling_function
        self._kernel_size = kernel_size
        self._stride = stride
        self._padding = padding
        self._padding_value = padding_value
        # The window axes sit right before the last axis of the array returned
        # by ``_sliding_windows``, e.g. ``(-3, -2)`` for two spatial axes.
        self._axes = tuple(-1 - offset for offset in range(num_spatial, 0, -1))

    def _extra_repr(self):
        """Describe the layer configuration, one tuple entry per spatial axis."""
        kernel = tuple(self._kernel_size)
        stride = tuple(self._stride)
        # Only the first amount of each padding pair is reported.
        pad = tuple(pair[0] for pair in self._padding)

        return f"kernel_size={kernel}, stride={stride}, padding={pad}"

    def __call__(self, x):
        """Pool ``x`` over its spatial axes and return the reduced array."""
        if any(pair[0] > 0 for pair in self._padding):
            # The first and last axes are never padded.
            untouched = [(0, 0)]
            x = mx.pad(x, untouched + self._padding + untouched, self._padding_value)
        windows = _sliding_windows(x, self._kernel_size, self._stride)
        return self._pooling_function(windows, self._axes)
|
||||
|
||||
|
||||
class _Pool1d(_Pool):
    """Argument handling shared by the 1-dimensional pooling layers.

    Every size argument may be an integer or a sequence with one integer and
    is normalized with ``_value_or_list`` before being passed on to ``_Pool``.

    Args:
        pooling_function: The reduction passed on to ``_Pool``.
        padding_value: The fill value passed on to ``_Pool``.
        kernel_size (int or tuple(int)): The size of the pooling window.
        stride (int or tuple(int), optional): The stride of the pooling
            window. Default: ``kernel_size``.
        padding (int or tuple(int)): The amount of padding added to both
            sides of the spatial axis. Default: ``0``.

    Raises:
        ValueError: If an argument is neither an integer nor a sequence with
            exactly one entry.
    """

    def __init__(
        self,
        pooling_function,
        padding_value,
        kernel_size: Union[int, Tuple[int]],
        stride: Optional[Union[int, Tuple[int]]] = None,
        padding: Union[int, Tuple[int]] = 0,
    ):
        layer_name = type(self).__name__
        template = "[{}] '{}' must be an integer or a tuple containing 1 integer"

        def normalize(value, arg_name):
            return _value_or_list(value, 1, template.format(layer_name, arg_name))

        kernel_size = normalize(kernel_size, "kernel_size")
        # An omitted stride reuses the kernel size, i.e. non overlapping windows.
        stride = kernel_size if stride is None else normalize(stride, "stride")
        # Apply the same amount of padding before and after the axis.
        pad_pairs = [(amount, amount) for amount in normalize(padding, "padding")]

        super().__init__(pooling_function, kernel_size, stride, pad_pairs, padding_value)
|
||||
|
||||
|
||||
class _Pool2d(_Pool):
    """Argument handling shared by the 2-dimensional pooling layers.

    Every size argument may be an integer, used for both spatial axes, or a
    sequence of two integers. The arguments are normalized with
    ``_value_or_list`` before being passed on to ``_Pool``.

    Args:
        pooling_function: The reduction passed on to ``_Pool``.
        padding_value: The fill value passed on to ``_Pool``.
        kernel_size (int or tuple(int, int)): The size of the pooling window.
        stride (int or tuple(int, int), optional): The stride of the pooling
            window. Default: ``kernel_size``.
        padding (int or tuple(int, int)): The amount of padding added to both
            sides of each spatial axis. Default: ``0``.

    Raises:
        ValueError: If an argument is neither an integer nor a sequence with
            exactly two entries.
    """

    def __init__(
        self,
        pooling_function,
        padding_value,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Optional[Union[int, Tuple[int, int]]] = None,
        # Not ``Optional``: ``None`` is rejected by ``_value_or_list`` below.
        padding: Union[int, Tuple[int, int]] = 0,
    ):
        class_name = type(self).__name__
        msg = "[{}] '{}' must be an integer or a tuple containing 2 integers"
        kernel_size = _value_or_list(
            kernel_size, 2, msg.format(class_name, "kernel_size")
        )
        if stride is not None:
            stride = _value_or_list(stride, 2, msg.format(class_name, "stride"))
        else:
            # An omitted stride reuses the kernel size.
            stride = kernel_size
        padding = _value_or_list(padding, 2, msg.format(class_name, "padding"))
        # Apply the same amount of padding before and after each axis.
        padding = [(p, p) for p in padding]

        super().__init__(pooling_function, kernel_size, stride, padding, padding_value)
|
||||
|
||||
|
||||
class MaxPool1d(_Pool1d):
    r"""Applies 1-dimensional max pooling.

    Assuming an input of shape :math:`(N, L, C)` and ``kernel_size`` is
    :math:`k`, the output is a tensor of shape :math:`(N, L_{out}, C)`, given
    by:

    .. math::
        \text{out}(N_i, t, C_j) = \max_{m=0, \ldots, k - 1}
                \text{input}(N_i, \text{stride} \times t + m, C_j),

    where :math:`L_{out} = \left\lfloor \frac{L + 2 \times \text{padding} -
    \text{kernel_size}}{\text{stride}}\right\rfloor + 1`.

    Args:
        kernel_size (int or tuple(int)): The size of the pooling window kernel.
        stride (int or tuple(int), optional): The stride of the pooling window.
            Default: ``kernel_size``.
        padding (int or tuple(int), optional): How much negative infinity
            padding to apply to the input. The padding amount is applied to
            both sides of the spatial axis. Default: ``0``.

    Examples:
        >>> import mlx.core as mx
        >>> import mlx.nn.layers as nn
        >>> x = mx.random.normal(shape=(4, 16, 5))
        >>> pool = nn.MaxPool1d(kernel_size=2, stride=2)
        >>> pool(x)
    """

    def __init__(
        self,
        # A 1-dimensional layer takes an int or a 1-tuple; a 2-tuple is
        # rejected by the base class, so the annotations must not advertise it.
        kernel_size: Union[int, Tuple[int]],
        stride: Optional[Union[int, Tuple[int]]] = None,
        padding: Union[int, Tuple[int]] = 0,
    ):
        # Padding with -inf guarantees a padded position never wins the max.
        super().__init__(mx.max, -float("inf"), kernel_size, stride, padding)
|
||||
|
||||
|
||||
class AvgPool1d(_Pool1d):
    r"""Applies 1-dimensional average pooling.

    Assuming an input of shape :math:`(N, L, C)` and ``kernel_size`` is
    :math:`k`, the output is a tensor of shape :math:`(N, L_{out}, C)`, given
    by:

    .. math::
        \text{out}(N_i, t, C_j) = \frac{1}{k} \sum_{m=0, \ldots, k - 1}
                \text{input}(N_i, \text{stride} \times t + m, C_j),

    where :math:`L_{out} = \left\lfloor \frac{L + 2 \times \text{padding} -
    \text{kernel_size}}{\text{stride}}\right\rfloor + 1`.

    Args:
        kernel_size (int or tuple(int)): The size of the pooling window kernel.
        stride (int or tuple(int), optional): The stride of the pooling window.
            Default: ``kernel_size``.
        padding (int or tuple(int), optional): How much zero padding to apply to
            the input. The padding amount is applied to both sides of the spatial
            axis. The padded zeros take part in the average. Default: ``0``.

    Examples:
        >>> import mlx.core as mx
        >>> import mlx.nn.layers as nn
        >>> x = mx.random.normal(shape=(4, 16, 5))
        >>> pool = nn.AvgPool1d(kernel_size=2, stride=2)
        >>> pool(x)
    """

    def __init__(
        self,
        # A 1-dimensional layer takes an int or a 1-tuple; a 2-tuple is
        # rejected by the base class, so the annotations must not advertise it.
        kernel_size: Union[int, Tuple[int]],
        stride: Optional[Union[int, Tuple[int]]] = None,
        padding: Union[int, Tuple[int]] = 0,
    ):
        super().__init__(mx.mean, 0, kernel_size, stride, padding)
|
||||
|
||||
|
||||
class MaxPool2d(_Pool2d):
    r"""Applies 2-dimensional max pooling.

    Assuming an input of shape :math:`(N, H, W, C)` and ``kernel_size`` is
    :math:`(k_H, k_W)`, the output is a tensor of shape :math:`(N, H_{out},
    W_{out}, C)`, given by:

    .. math::
        \begin{aligned}
            \text{out}(N_i, h, w, C_j) = & \max_{m=0, \ldots, k_H-1} \max_{n=0, \ldots, k_W-1} \\
                                        & \text{input}(N_i, \text{stride[0]} \times h + m,
                                                       \text{stride[1]} \times w + n, C_j),
        \end{aligned}

    where :math:`H_{out} = \left\lfloor\frac{H + 2 * \text{padding[0]} - \text{kernel_size[0]}}{\text{stride[0]}}\right\rfloor + 1`,
    :math:`W_{out} = \left\lfloor\frac{W + 2 * \text{padding[1]} - \text{kernel_size[1]}}{\text{stride[1]}}\right\rfloor + 1`.

    The parameters ``kernel_size``, ``stride``, ``padding``, can either be:

        - a single ``int`` -- in which case the same value is used for both the
          height and width axis;
        - a ``tuple`` of two ``int`` s -- in which case, the first ``int`` is
          used for the height axis, the second ``int`` for the width axis.

    Args:
        kernel_size (int or tuple(int, int)): The size of the pooling window.
        stride (int or tuple(int, int), optional): The stride of the pooling
            window. Default: ``kernel_size``.
        padding (int or tuple(int, int), optional): How much negative infinity
            padding to apply to the input. The padding is applied on both sides
            of the height and width axis. Default: ``0``.

    Examples:
        >>> import mlx.core as mx
        >>> import mlx.nn.layers as nn
        >>> x = mx.random.normal(shape=(8, 32, 32, 4))
        >>> pool = nn.MaxPool2d(kernel_size=2, stride=2)
        >>> pool(x)
    """

    def __init__(
        self,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Optional[Union[int, Tuple[int, int]]] = None,
        # Not ``Optional``: passing ``None`` raises a ``ValueError``.
        padding: Union[int, Tuple[int, int]] = 0,
    ):
        # Padding with -inf guarantees a padded position never wins the max.
        super().__init__(mx.max, -float("inf"), kernel_size, stride, padding)
|
||||
|
||||
|
||||
class AvgPool2d(_Pool2d):
    r"""Applies 2-dimensional average pooling.

    Assuming an input of shape :math:`(N, H, W, C)` and ``kernel_size`` is
    :math:`(k_H, k_W)`, the output is a tensor of shape :math:`(N, H_{out},
    W_{out}, C)`, given by:

    .. math::
        \begin{aligned}
            \text{out}(N_i, h, w, C_j) = & \frac{1}{k_H k_W} \sum_{m=0, \ldots, k_H-1} \sum_{n=0, \ldots, k_W-1} \\
                                        & \text{input}(N_i, \text{stride[0]} \times h + m,
                                                       \text{stride[1]} \times w + n, C_j),
        \end{aligned}

    where :math:`H_{out} = \left\lfloor\frac{H + 2 * \text{padding[0]} - \text{kernel_size[0]}}{\text{stride[0]}}\right\rfloor + 1`,
    :math:`W_{out} = \left\lfloor\frac{W + 2 * \text{padding[1]} - \text{kernel_size[1]}}{\text{stride[1]}}\right\rfloor + 1`.

    The parameters ``kernel_size``, ``stride``, ``padding``, can either be:

        - a single ``int`` -- in which case the same value is used for both the
          height and width axis;
        - a ``tuple`` of two ``int`` s -- in which case, the first ``int`` is
          used for the height axis, the second ``int`` for the width axis.

    Args:
        kernel_size (int or tuple(int, int)): The size of the pooling window.
        stride (int or tuple(int, int), optional): The stride of the pooling
            window. Default: ``kernel_size``.
        padding (int or tuple(int, int), optional): How much zero
            padding to apply to the input. The padding is applied on both sides
            of the height and width axis. The padded zeros take part in the
            average. Default: ``0``.

    Examples:
        >>> import mlx.core as mx
        >>> import mlx.nn.layers as nn
        >>> x = mx.random.normal(shape=(8, 32, 32, 4))
        >>> pool = nn.AvgPool2d(kernel_size=2, stride=2)
        >>> pool(x)
    """

    def __init__(
        self,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Optional[Union[int, Tuple[int, int]]] = None,
        # Not ``Optional``: passing ``None`` raises a ``ValueError``.
        padding: Union[int, Tuple[int, int]] = 0,
    ):
        super().__init__(mx.mean, 0, kernel_size, stride, padding)
|
@ -905,6 +905,347 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(y.shape, x.shape)
|
||||
self.assertTrue(y.dtype, mx.float16)
|
||||
|
||||
def test_pooling(self):
    """Check the 1-D and 2-D max/average pooling layers against fixed outputs.

    Max pooling results are compared exactly. Average pooling results are
    floating point values and are always compared with a tolerance.
    """

    def assert_exact(layer, inputs, expected):
        # Exact comparison, used for max pooling which only selects inputs.
        self.assertTrue(
            np.array_equal(layer(inputs), expected),
            msg=f"{layer} produced an unexpected output",
        )

    def assert_close(layer, inputs, expected):
        # Tolerant comparison, used for every average pooling check.
        self.assertTrue(
            np.allclose(layer(inputs), expected),
            msg=f"{layer} produced an unexpected output",
        )

    # Test 1d pooling
    x = mx.array(
        [
            [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]],
            [[12, 13, 14], [15, 16, 17], [18, 19, 20], [21, 22, 23]],
        ]
    )
    expected_max_pool_output_no_padding_stride_1 = [
        [[3.0, 4.0, 5.0], [6.0, 7.0, 8.0], [9.0, 10.0, 11.0]],
        [[15.0, 16.0, 17.0], [18.0, 19.0, 20.0], [21.0, 22.0, 23.0]],
    ]
    expected_max_pool_output_no_padding_stride_2 = [
        [[3.0, 4.0, 5.0], [9.0, 10.0, 11.0]],
        [[15.0, 16.0, 17.0], [21.0, 22.0, 23.0]],
    ]
    expected_max_pool_output_padding_1_stride_2 = [
        [[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [9.0, 10.0, 11.0]],
        [[12.0, 13.0, 14.0], [18.0, 19.0, 20.0], [21.0, 22.0, 23.0]],
    ]
    expected_max_pool_output_padding_1_stride_2_kernel_3 = [
        [[3.0, 4.0, 5.0], [9.0, 10.0, 11.0]],
        [[15.0, 16.0, 17.0], [21.0, 22.0, 23.0]],
    ]
    expected_avg_pool_output_no_padding_stride_1 = [
        [
            [1.5000, 2.5000, 3.5000],
            [4.5000, 5.5000, 6.5000],
            [7.5000, 8.5000, 9.5000],
        ],
        [
            [13.5000, 14.5000, 15.5000],
            [16.5000, 17.5000, 18.5000],
            [19.5000, 20.5000, 21.5000],
        ],
    ]
    expected_avg_pool_output_no_padding_stride_2 = [
        [[1.5000, 2.5000, 3.5000], [7.5000, 8.5000, 9.5000]],
        [[13.5000, 14.5000, 15.5000], [19.5000, 20.5000, 21.5000]],
    ]
    expected_avg_pool_output_padding_1_stride_2 = [
        [
            [0.0000, 0.5000, 1.0000],
            [4.5000, 5.5000, 6.5000],
            [4.5000, 5.0000, 5.5000],
        ],
        [
            [6.0000, 6.5000, 7.0000],
            [16.5000, 17.5000, 18.5000],
            [10.5000, 11.0000, 11.5000],
        ],
    ]
    expected_avg_pool_output_padding_1_kernel_3 = [
        [[1, 1.66667, 2.33333], [6, 7, 8]],
        [[9, 9.66667, 10.3333], [18, 19, 20]],
    ]
    assert_exact(
        nn.MaxPool1d(kernel_size=2, stride=1, padding=0),
        x,
        expected_max_pool_output_no_padding_stride_1,
    )
    assert_exact(
        nn.MaxPool1d(kernel_size=2, stride=2, padding=0),
        x,
        expected_max_pool_output_no_padding_stride_2,
    )
    assert_exact(
        nn.MaxPool1d(kernel_size=2, stride=2, padding=1),
        x,
        expected_max_pool_output_padding_1_stride_2,
    )
    assert_exact(
        nn.MaxPool1d(kernel_size=3, stride=2, padding=1),
        x,
        expected_max_pool_output_padding_1_stride_2_kernel_3,
    )
    assert_close(
        nn.AvgPool1d(kernel_size=2, stride=1, padding=0),
        x,
        expected_avg_pool_output_no_padding_stride_1,
    )
    assert_close(
        nn.AvgPool1d(kernel_size=2, stride=2, padding=0),
        x,
        expected_avg_pool_output_no_padding_stride_2,
    )
    assert_close(
        nn.AvgPool1d(kernel_size=2, stride=2, padding=1),
        x,
        expected_avg_pool_output_padding_1_stride_2,
    )
    assert_close(
        nn.AvgPool1d(kernel_size=3, stride=2, padding=1),
        x,
        expected_avg_pool_output_padding_1_kernel_3,
    )

    # Test 2d pooling
    x = mx.array(
        [
            [
                [[0, 16], [1, 17], [2, 18], [3, 19]],
                [[4, 20], [5, 21], [6, 22], [7, 23]],
                [[8, 24], [9, 25], [10, 26], [11, 27]],
                [[12, 28], [13, 29], [14, 30], [15, 31]],
            ]
        ]
    )
    expected_max_pool_output_no_padding_stride_1 = [
        [
            [[5, 21], [6, 22], [7, 23]],
            [[9, 25], [10, 26], [11, 27]],
            [[13, 29], [14, 30], [15, 31]],
        ]
    ]
    expected_max_pool_output_no_padding_stride_2 = [
        [[[5, 21], [7, 23]], [[13, 29], [15, 31]]]
    ]
    expected_max_pool_output_padding_1 = [
        [
            [[0, 16], [2, 18], [3, 19]],
            [[8, 24], [10, 26], [11, 27]],
            [[12, 28], [14, 30], [15, 31]],
        ]
    ]
    expected_mean_pool_output_no_padding_stride_1 = [
        [
            [[2.5000, 18.5000], [3.5000, 19.5000], [4.5000, 20.5000]],
            [[6.5000, 22.5000], [7.5000, 23.5000], [8.5000, 24.5000]],
            [[10.5000, 26.5000], [11.5000, 27.5000], [12.5000, 28.5000]],
        ]
    ]
    expected_mean_pool_output_no_padding_stride_2 = [
        [
            [[2.5000, 18.5000], [4.5000, 20.5000]],
            [[10.5000, 26.5000], [12.5000, 28.5000]],
        ]
    ]
    expected_mean_pool_output_padding_1 = [
        [
            [[0.0000, 4.0000], [0.7500, 8.7500], [0.7500, 4.7500]],
            [[3.0000, 11.0000], [7.5000, 23.5000], [4.5000, 12.5000]],
            [[3.0000, 7.0000], [6.7500, 14.7500], [3.7500, 7.7500]],
        ]
    ]
    assert_exact(
        nn.MaxPool2d(kernel_size=2, stride=1, padding=0),
        x,
        expected_max_pool_output_no_padding_stride_1,
    )
    assert_exact(
        nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
        x,
        expected_max_pool_output_no_padding_stride_2,
    )
    assert_exact(
        nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
        x,
        expected_max_pool_output_padding_1,
    )
    # Average pooling
    assert_close(
        nn.AvgPool2d(kernel_size=2, stride=1, padding=0),
        x,
        expected_mean_pool_output_no_padding_stride_1,
    )
    assert_close(
        nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
        x,
        expected_mean_pool_output_no_padding_stride_2,
    )
    assert_close(
        nn.AvgPool2d(kernel_size=2, stride=2, padding=1),
        x,
        expected_mean_pool_output_padding_1,
    )

    # Test multiple batches
    x = mx.array(
        [
            [
                [[0, 1], [2, 3], [4, 5], [6, 7]],
                [[8, 9], [10, 11], [12, 13], [14, 15]],
                [[16, 17], [18, 19], [20, 21], [22, 23]],
                [[24, 25], [26, 27], [28, 29], [30, 31]],
            ],
            [
                [[32, 33], [34, 35], [36, 37], [38, 39]],
                [[40, 41], [42, 43], [44, 45], [46, 47]],
                [[48, 49], [50, 51], [52, 53], [54, 55]],
                [[56, 57], [58, 59], [60, 61], [62, 63]],
            ],
        ]
    )
    expected_max_pool_output = [
        [[[10.0, 11.0], [14.0, 15.0]], [[26.0, 27.0], [30.0, 31.0]]],
        [[[42.0, 43.0], [46.0, 47.0]], [[58.0, 59.0], [62.0, 63.0]]],
    ]
    expected_avg_pool_output = [
        [[[2.22222, 2.66667], [5.33333, 6]], [[11.3333, 12], [20, 21]]],
        [[[16.4444, 16.8889], [26.6667, 27.3333]], [[32.6667, 33.3333], [52, 53]]],
    ]
    assert_exact(
        nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        x,
        expected_max_pool_output,
    )
    assert_close(
        nn.AvgPool2d(kernel_size=3, stride=2, padding=1),
        x,
        expected_avg_pool_output,
    )

    # Test irregular kernel (2, 4), stride (3, 1) and padding (1, 2)
    x = mx.array(
        [
            [
                [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]],
                [[12, 13, 14], [15, 16, 17], [18, 19, 20], [21, 22, 23]],
                [[24, 25, 26], [27, 28, 29], [30, 31, 32], [33, 34, 35]],
                [[36, 37, 38], [39, 40, 41], [42, 43, 44], [45, 46, 47]],
            ],
            [
                [[48, 49, 50], [51, 52, 53], [54, 55, 56], [57, 58, 59]],
                [[60, 61, 62], [63, 64, 65], [66, 67, 68], [69, 70, 71]],
                [[72, 73, 74], [75, 76, 77], [78, 79, 80], [81, 82, 83]],
                [[84, 85, 86], [87, 88, 89], [90, 91, 92], [93, 94, 95]],
            ],
        ]
    )
    expected_irregular_max_pool_output = [
        [
            [
                [3.0, 4.0, 5.0],
                [6.0, 7.0, 8.0],
                [9.0, 10.0, 11.0],
                [9.0, 10.0, 11.0],
                [9.0, 10.0, 11.0],
            ],
            [
                [39.0, 40.0, 41.0],
                [42.0, 43.0, 44.0],
                [45.0, 46.0, 47.0],
                [45.0, 46.0, 47.0],
                [45.0, 46.0, 47.0],
            ],
        ],
        [
            [
                [51.0, 52.0, 53.0],
                [54.0, 55.0, 56.0],
                [57.0, 58.0, 59.0],
                [57.0, 58.0, 59.0],
                [57.0, 58.0, 59.0],
            ],
            [
                [87.0, 88.0, 89.0],
                [90.0, 91.0, 92.0],
                [93.0, 94.0, 95.0],
                [93.0, 94.0, 95.0],
                [93.0, 94.0, 95.0],
            ],
        ],
    ]
    expected_irregular_average_pool_output = [
        [
            [
                [0.3750, 0.6250, 0.8750],
                [1.1250, 1.5000, 1.8750],
                [2.2500, 2.7500, 3.2500],
                [2.2500, 2.6250, 3.0000],
                [1.8750, 2.1250, 2.3750],
            ],
            [
                [15.7500, 16.2500, 16.7500],
                [24.7500, 25.5000, 26.2500],
                [34.5000, 35.5000, 36.5000],
                [27.0000, 27.7500, 28.5000],
                [18.7500, 19.2500, 19.7500],
            ],
        ],
        [
            [
                [12.3750, 12.6250, 12.8750],
                [19.1250, 19.5000, 19.8750],
                [26.2500, 26.7500, 27.2500],
                [20.2500, 20.6250, 21.0000],
                [13.8750, 14.1250, 14.3750],
            ],
            [
                [39.7500, 40.2500, 40.7500],
                [60.7500, 61.5000, 62.2500],
                [82.5000, 83.5000, 84.5000],
                [63.0000, 63.7500, 64.5000],
                [42.7500, 43.2500, 43.7500],
            ],
        ],
    ]
    assert_exact(
        nn.MaxPool2d(kernel_size=(2, 4), stride=(3, 1), padding=(1, 2)),
        x,
        expected_irregular_max_pool_output,
    )
    assert_close(
        nn.AvgPool2d(kernel_size=(2, 4), stride=(3, 1), padding=(1, 2)),
        x,
        expected_irregular_average_pool_output,
    )

    # Test repr
    self.assertEqual(
        str(nn.MaxPool1d(kernel_size=3, padding=2)),
        "MaxPool1d(kernel_size=(3,), stride=(3,), padding=(2,))",
    )
    self.assertEqual(
        str(nn.AvgPool1d(kernel_size=2, stride=3)),
        "AvgPool1d(kernel_size=(2,), stride=(3,), padding=(0,))",
    )
    self.assertEqual(
        str(nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
        "MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))",
    )
    self.assertEqual(
        str(nn.AvgPool2d(kernel_size=(1, 2), stride=2, padding=(1, 2))),
        "AvgPool2d(kernel_size=(1, 2), stride=(2, 2), padding=(1, 2))",
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run the tests of this module when the file is executed as a script.
    unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user