diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md index 18a8c5599..2a3c6c612 100644 --- a/ACKNOWLEDGMENTS.md +++ b/ACKNOWLEDGMENTS.md @@ -11,7 +11,7 @@ MLX was developed with contributions from the following individuals: - Juarez Bochi: Fixed bug in cross attention. - Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example. - Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile` and safetensor support -- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. +- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented ``MaxPool1d``, ``MaxPool2d``, ``AvgPool1d``, ``AvgPool2d``. diff --git a/docs/src/python/nn/layers.rst b/docs/src/python/nn/layers.rst index ef099ef2f..0f5fca9db 100644 --- a/docs/src/python/nn/layers.rst +++ b/docs/src/python/nn/layers.rst @@ -10,6 +10,8 @@ Layers :template: nn-module-template.rst ALiBi + AvgPool1d + AvgPool2d BatchNorm Conv1d Conv2d @@ -22,6 +24,8 @@ Layers InstanceNorm LayerNorm Linear + MaxPool1d + MaxPool2d Mish MultiHeadAttention PReLU diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py index f5092418b..207cb01b2 100644 --- a/python/mlx/nn/layers/__init__.py +++ b/python/mlx/nn/layers/__init__.py @@ -58,6 +58,7 @@ from mlx.nn.layers.normalization import ( LayerNorm, RMSNorm, ) +from mlx.nn.layers.pooling import AvgPool1d, AvgPool2d, MaxPool1d, MaxPool2d from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding from mlx.nn.layers.quantized import QuantizedLinear from mlx.nn.layers.transformer import ( diff --git a/python/mlx/nn/layers/pooling.py b/python/mlx/nn/layers/pooling.py new file mode 100644 index 000000000..ffa05f5d2 --- /dev/null +++ b/python/mlx/nn/layers/pooling.py @@ -0,0 +1,308 @@ +# Copyright © 2023-2024 Apple Inc. + +import operator +from itertools import accumulate +from typing import Optional, Tuple, Union + +import mlx.core as mx +from mlx.nn.layers.base import Module + + +def _value_or_list(x, n, msg): + if isinstance(x, (list, tuple)): + if len(x) != n: + raise ValueError(msg) + return list(x) + + if not isinstance(x, int): + raise ValueError(msg) + + return [x] * n + + +def _sliding_windows(x, window_shape, window_strides): + if x.ndim < 3: + raise ValueError( + f"To extract sliding windows at least 1 spatial dimension " + f"(3 total) is needed but the input only has {x.ndim} dimensions." + ) + + spatial_dims = x.shape[1:-1] + if not (len(spatial_dims) == len(window_shape) == len(window_strides)): + raise ValueError( + f"To extract sliding windows the window shapes and strides must have " + f"the same number of spatial dimensions as the signal but the signal " + f"has {len(spatial_dims)} dims and the window shape has {len(window_shape)} " + f"and strides have {len(window_strides)}." + ) + + shape = x.shape + strides = list(reversed(list(accumulate(reversed(shape + (1,)), operator.mul))))[1:] + + # Compute the output shape + final_shape = [shape[0]] + final_shape += [ + (size - window) // stride + 1 + for size, window, stride in zip(spatial_dims, window_shape, window_strides) + ] + final_shape += window_shape + final_shape += [shape[-1]] + + # Compute the output strides + final_strides = strides[:1] + final_strides += [ + og_stride * stride for og_stride, stride in zip(strides[1:-1], window_strides) + ] + final_strides += strides[1:-1] + final_strides += strides[-1:] # should always be [1] + + return mx.as_strided(x, final_shape, final_strides) + + +class _Pool(Module): + def __init__(self, pooling_function, kernel_size, stride, padding, padding_value): + super().__init__() + + self._pooling_function = pooling_function + self._kernel_size = kernel_size + self._stride = stride + self._padding = padding + self._padding_value = padding_value + self._axes = tuple(range(-len(self._kernel_size) - 1, -1, 1)) + + def _extra_repr(self): + ks = tuple(self._kernel_size) + st = tuple(self._stride) + pd = tuple(p[0] for p in self._padding) + + return f"kernel_size={ks}, stride={st}, padding={pd}" + + def __call__(self, x): + if any(p[0] > 0 for p in self._padding): + x = mx.pad(x, [(0, 0)] + self._padding + [(0, 0)], self._padding_value) + x = _sliding_windows(x, self._kernel_size, self._stride) + return self._pooling_function(x, self._axes) + + +class _Pool1d(_Pool): + def __init__( + self, + pooling_function, + padding_value, + kernel_size: Union[int, Tuple[int]], + stride: Optional[Union[int, Tuple[int]]] = None, + padding: Union[int, Tuple[int]] = 0, + ): + class_name = type(self).__name__ + msg = "[{}] '{}' must be an integer or a tuple containing 1 integer" + kernel_size = _value_or_list( + kernel_size, 1, msg.format(class_name, "kernel_size") + ) + if stride is not None: + stride = _value_or_list(stride, 1, msg.format(class_name, "stride")) + else: + stride = kernel_size + padding = _value_or_list(padding, 1, msg.format(class_name, "padding")) + padding = [(p, p) for p in padding] + + super().__init__(pooling_function, kernel_size, stride, padding, padding_value) + + +class _Pool2d(_Pool): + def __init__( + self, + pooling_function, + padding_value, + kernel_size: Union[int, Tuple[int, int]], + stride: Optional[Union[int, Tuple[int, int]]] = None, + padding: Optional[Union[int, Tuple[int, int]]] = 0, + ): + class_name = type(self).__name__ + msg = "[{}] '{}' must be an integer or a tuple containing 2 integers" + kernel_size = _value_or_list( + kernel_size, 2, msg.format(class_name, "kernel_size") + ) + if stride is not None: + stride = _value_or_list(stride, 2, msg.format(class_name, "stride")) + else: + stride = kernel_size + padding = _value_or_list(padding, 2, msg.format(class_name, "padding")) + padding = [(p, p) for p in padding] + + super().__init__(pooling_function, kernel_size, stride, padding, padding_value) + + +class MaxPool1d(_Pool1d): + r"""Applies 1-dimensional max pooling. + + Assuming an input of shape :math:`(N, L, C)` and ``kernel_size`` is + :math:`k`, the output is a tensor of shape :math:`(N, L_{out}, C)`, given + by: + + .. math:: + \text{out}(N_i, t, C_j) = \max_{m=0, \ldots, k - 1} + \text{input}(N_i, \text{stride} \times t + m, C_j), + + where :math:`L_{out} = \left\lfloor \frac{L + 2 \times \text{padding} - + \text{kernel_size}}{\text{stride}}\right\rfloor + 1`. + + Args: + kernel_size (int or tuple(int)): The size of the pooling window kernel. + stride (int or tuple(int), optional): The stride of the pooling window. + Default: ``kernel_size``. + padding (int or tuple(int), optional): How much negative infinity + padding to apply to the input. The padding amount is applied to + both sides of the spatial axis. Default: ``0``. + + Examples: + >>> import mlx.core as mx + >>> import mlx.nn.layers as nn + >>> x = mx.random.normal(shape=(4, 16, 5)) + >>> pool = nn.MaxPool1d(kernel_size=2, stride=2) + >>> pool(x) + """ + + def __init__( + self, + kernel_size: Union[int, Tuple[int, int]], + stride: Optional[Union[int, Tuple[int, int]]] = None, + padding: Optional[Union[int, Tuple[int, int]]] = 0, + ): + super().__init__(mx.max, -float("inf"), kernel_size, stride, padding) + + +class AvgPool1d(_Pool1d): + r"""Applies 1-dimensional average pooling. + + Assuming an input of shape :math:`(N, L, C)` and ``kernel_size`` is + :math:`k`, the output is a tensor of shape :math:`(N, L_{out}, C)`, given + by: + + .. math:: + \text{out}(N_i, t, C_j) = \frac{1}{k} \sum_{m=0, \ldots, k - 1} + \text{input}(N_i, \text{stride} \times t + m, C_j), + + where :math:`L_{out} = \left\lfloor \frac{L + 2 \times \text{padding} - + \text{kernel_size}}{\text{stride}}\right\rfloor + 1`. + + Args: + kernel_size (int or tuple(int)): The size of the pooling window kernel. + stride (int or tuple(int), optional): The stride of the pooling window. + Default: ``kernel_size``. + padding (int or tuple(int), optional): How much zero padding to apply to + the input. The padding amount is applied to both sides of the spatial + axis. Default: ``0``. + + Examples: + >>> import mlx.core as mx + >>> import mlx.nn.layers as nn + >>> x = mx.random.normal(shape=(4, 16, 5)) + >>> pool = nn.AvgPool1d(kernel_size=2, stride=2) + >>> pool(x) + """ + + def __init__( + self, + kernel_size: Union[int, Tuple[int, int]], + stride: Optional[Union[int, Tuple[int, int]]] = None, + padding: Optional[Union[int, Tuple[int, int]]] = 0, + ): + super().__init__(mx.mean, 0, kernel_size, stride, padding) + + +class MaxPool2d(_Pool2d): + r"""Applies 2-dimensional max pooling. + + Assuming an input of shape :math:`(N, H, W, C)` and ``kernel_size`` is + :math:`(k_H, k_W)`, the output is a tensor of shape :math:`(N, H_{out}, + W_{out}, C)`, given by: + + .. math:: + \begin{aligned} + \text{out}(N_i, h, w, C_j) = & \max_{m=0, \ldots, k_H-1} \max_{n=0, \ldots, k_W-1} \\ + & \text{input}(N_i, \text{stride[0]} \times h + m, + \text{stride[1]} \times w + n, C_j), + \end{aligned} + + where :math:`H_{out} = \left\lfloor\frac{H + 2 * \text{padding[0]} - \text{kernel_size[0]}}{\text{stride[0]}}\right\rfloor + 1`, + :math:`W_{out} = \left\lfloor\frac{W + 2 * \text{padding[1]} - \text{kernel_size[1]}}{\text{stride[1]}}\right\rfloor + 1`. + + The parameters ``kernel_size``, ``stride``, ``padding``, can either be: + + - a single ``int`` -- in which case the same value is used for both the + height and width axis; + - a ``tuple`` of two ``int`` s -- in which case, the first ``int`` is + used for the height axis, the second ``int`` for the width axis. + + Args: + kernel_size (int or tuple(int, int)): The size of the pooling window. + stride (int or tuple(int, int), optional): The stride of the pooling + window. Default: ``kernel_size``. + padding (int or tuple(int, int), optional): How much negative infinity + padding to apply to the input. The padding is applied on both sides + of the height and width axis. Default: ``0``. + + Examples: + >>> import mlx.core as mx + >>> import mlx.nn.layers as nn + >>> x = mx.random.normal(shape=(8, 32, 32, 4)) + >>> pool = nn.MaxPool2d(kernel_size=2, stride=2) + >>> pool(x) + """ + + def __init__( + self, + kernel_size: Union[int, Tuple[int, int]], + stride: Optional[Union[int, Tuple[int, int]]] = None, + padding: Optional[Union[int, Tuple[int, int]]] = 0, + ): + super().__init__(mx.max, -float("inf"), kernel_size, stride, padding) + + +class AvgPool2d(_Pool2d): + r"""Applies 2-dimensional average pooling. + + Assuming an input of shape :math:`(N, H, W, C)` and ``kernel_size`` is + :math:`(k_H, k_W)`, the output is a tensor of shape :math:`(N, H_{out}, + W_{out}, C)`, given by: + + .. math:: + \begin{aligned} + \text{out}(N_i, h, w, C_j) = & \frac{1}{k_H k_W} \sum_{m=0, \ldots, k_H-1} \sum_{n=0, \ldots, k_W-1} \\ + & \text{input}(N_i, \text{stride[0]} \times h + m, + \text{stride[1]} \times w + n, C_j), + \end{aligned} + + where :math:`H_{out} = \left\lfloor\frac{H + 2 * \text{padding[0]} - \text{kernel_size[0]}}{\text{stride[0]}}\right\rfloor + 1`, + :math:`W_{out} = \left\lfloor\frac{W + 2 * \text{padding[1]} - \text{kernel_size[1]}}{\text{stride[1]}}\right\rfloor + 1`. + + The parameters ``kernel_size``, ``stride``, ``padding``, can either be: + + - a single ``int`` -- in which case the same value is used for both the + height and width axis; + - a ``tuple`` of two ``int`` s -- in which case, the first ``int`` is + used for the height axis, the second ``int`` for the width axis. + + Args: + kernel_size (int or tuple(int, int)): The size of the pooling window. + stride (int or tuple(int, int), optional): The stride of the pooling + window. Default: ``kernel_size``. + padding (int or tuple(int, int), optional): How much zero + padding to apply to the input. The padding is applied on both sides + of the height and width axis. Default: ``0``. + + Examples: + >>> import mlx.core as mx + >>> import mlx.nn.layers as nn + >>> x = mx.random.normal(shape=(8, 32, 32, 4)) + >>> pool = nn.MaxPool2d(kernel_size=2, stride=2) + >>> pool(x) + """ + + def __init__( + self, + kernel_size: Union[int, Tuple[int, int]], + stride: Optional[Union[int, Tuple[int, int]]] = None, + padding: Optional[Union[int, Tuple[int, int]]] = 0, + ): + super().__init__(mx.mean, 0, kernel_size, stride, padding) diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 201665f7f..eaaf3bb9c 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -905,6 +905,347 @@ class TestLayers(mlx_tests.MLXTestCase): self.assertTrue(y.shape, x.shape) self.assertTrue(y.dtype, mx.float16) + def test_pooling(self): + # Test 1d pooling + x = mx.array( + [ + [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], + [[12, 13, 14], [15, 16, 17], [18, 19, 20], [21, 22, 23]], + ] + ) + expected_max_pool_output_no_padding_stride_1 = [ + [[3.0, 4.0, 5.0], [6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], + [[15.0, 16.0, 17.0], [18.0, 19.0, 20.0], [21.0, 22.0, 23.0]], + ] + expected_max_pool_output_no_padding_stride_2 = [ + [[3.0, 4.0, 5.0], [9.0, 10.0, 11.0]], + [[15.0, 16.0, 17.0], [21.0, 22.0, 23.0]], + ] + expected_max_pool_output_padding_1_stride_2 = [ + [[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], + [[12.0, 13.0, 14.0], [18.0, 19.0, 20.0], [21.0, 22.0, 23.0]], + ] + expected_max_pool_output_padding_1_stride_2_kernel_3 = [ + [[3.0, 4.0, 5.0], [9.0, 10.0, 11.0]], + [[15.0, 16.0, 17.0], [21.0, 22.0, 23.0]], + ] + expected_avg_pool_output_no_padding_stride_1 = [ + [ + [1.5000, 2.5000, 3.5000], + [4.5000, 5.5000, 6.5000], + [7.5000, 8.5000, 9.5000], + ], + [ + [13.5000, 14.5000, 15.5000], + [16.5000, 17.5000, 18.5000], + [19.5000, 20.5000, 21.5000], + ], + ] + expected_avg_pool_output_no_padding_stride_2 = [ + [[1.5000, 2.5000, 3.5000], [7.5000, 8.5000, 9.5000]], + [[13.5000, 14.5000, 15.5000], [19.5000, 20.5000, 21.5000]], + ] + expected_avg_pool_output_padding_1_stride_2 = [ + [ + [0.0000, 0.5000, 1.0000], + [4.5000, 5.5000, 6.5000], + [4.5000, 5.0000, 5.5000], + ], + [ + [6.0000, 6.5000, 7.0000], + [16.5000, 17.5000, 18.5000], + [10.5000, 11.0000, 11.5000], + ], + ] + expected_avg_pool_output_padding_1_kernel_3 = [ + [[1, 1.66667, 2.33333], [6, 7, 8]], + [[9, 9.66667, 10.3333], [18, 19, 20]], + ] + self.assertTrue( + np.array_equal( + nn.MaxPool1d(kernel_size=2, stride=1, padding=0)(x), + expected_max_pool_output_no_padding_stride_1, + ) + ) + self.assertTrue( + np.array_equal( + nn.MaxPool1d(kernel_size=2, stride=2, padding=0)(x), + expected_max_pool_output_no_padding_stride_2, + ) + ) + self.assertTrue( + np.array_equal( + nn.MaxPool1d(kernel_size=2, stride=2, padding=1)(x), + expected_max_pool_output_padding_1_stride_2, + ) + ) + self.assertTrue( + np.array_equal( + nn.MaxPool1d(kernel_size=3, stride=2, padding=1)(x), + expected_max_pool_output_padding_1_stride_2_kernel_3, + ) + ) + self.assertTrue( + np.allclose( + nn.AvgPool1d(kernel_size=2, stride=1, padding=0)(x), + expected_avg_pool_output_no_padding_stride_1, + ) + ) + self.assertTrue( + np.allclose( + nn.AvgPool1d(kernel_size=2, stride=2, padding=0)(x), + expected_avg_pool_output_no_padding_stride_2, + ) + ) + self.assertTrue( + np.allclose( + nn.AvgPool1d(kernel_size=2, stride=2, padding=1)(x), + expected_avg_pool_output_padding_1_stride_2, + ) + ) + self.assertTrue( + np.allclose( + nn.AvgPool1d(kernel_size=3, stride=2, padding=1)(x), + expected_avg_pool_output_padding_1_kernel_3, + ) + ) + # Test 2d pooling + x = mx.array( + [ + [ + [[0, 16], [1, 17], [2, 18], [3, 19]], + [[4, 20], [5, 21], [6, 22], [7, 23]], + [[8, 24], [9, 25], [10, 26], [11, 27]], + [[12, 28], [13, 29], [14, 30], [15, 31]], + ] + ] + ) + expected_max_pool_output_no_padding_stride_1 = [ + [ + [[5, 21], [6, 22], [7, 23]], + [[9, 25], [10, 26], [11, 27]], + [[13, 29], [14, 30], [15, 31]], + ] + ] + expected_max_pool_output_no_padding_stride_2 = [ + [[[5, 21], [7, 23]], [[13, 29], [15, 31]]] + ] + expected_max_pool_output_padding_1 = [ + [ + [[0, 16], [2, 18], [3, 19]], + [[8, 24], [10, 26], [11, 27]], + [[12, 28], [14, 30], [15, 31]], + ] + ] + expected_mean_pool_output_no_padding_stride_1 = [ + [ + [[2.5000, 18.5000], [3.5000, 19.5000], [4.5000, 20.5000]], + [[6.5000, 22.5000], [7.5000, 23.5000], [8.5000, 24.5000]], + [[10.5000, 26.5000], [11.5000, 27.5000], [12.5000, 28.5000]], + ] + ] + expected_mean_pool_output_no_padding_stride_2 = [ + [ + [[2.5000, 18.5000], [4.5000, 20.5000]], + [[10.5000, 26.5000], [12.5000, 28.5000]], + ] + ] + expected_mean_pool_output_padding_1 = [ + [ + [[0.0000, 4.0000], [0.7500, 8.7500], [0.7500, 4.7500]], + [[3.0000, 11.0000], [7.5000, 23.5000], [4.5000, 12.5000]], + [[3.0000, 7.0000], [6.7500, 14.7500], [3.7500, 7.7500]], + ] + ] + self.assertTrue( + np.array_equal( + nn.MaxPool2d(kernel_size=2, stride=1, padding=0)(x), + expected_max_pool_output_no_padding_stride_1, + ) + ) + self.assertTrue( + np.array_equal( + nn.MaxPool2d(kernel_size=2, stride=2, padding=0)(x), + expected_max_pool_output_no_padding_stride_2, + ) + ) + self.assertTrue( + np.array_equal( + nn.MaxPool2d(kernel_size=2, stride=2, padding=1)(x), + expected_max_pool_output_padding_1, + ) + ) + # Average pooling + self.assertTrue( + np.allclose( + nn.AvgPool2d(kernel_size=2, stride=1, padding=0)(x), + expected_mean_pool_output_no_padding_stride_1, + ) + ) + self.assertTrue( + np.array_equal( + nn.AvgPool2d(kernel_size=2, stride=2, padding=0)(x), + expected_mean_pool_output_no_padding_stride_2, + ) + ) + self.assertTrue( + np.array_equal( + nn.AvgPool2d(kernel_size=2, stride=2, padding=1)(x), + expected_mean_pool_output_padding_1, + ) + ) + # Test multiple batches + x = mx.array( + [ + [ + [[0, 1], [2, 3], [4, 5], [6, 7]], + [[8, 9], [10, 11], [12, 13], [14, 15]], + [[16, 17], [18, 19], [20, 21], [22, 23]], + [[24, 25], [26, 27], [28, 29], [30, 31]], + ], + [ + [[32, 33], [34, 35], [36, 37], [38, 39]], + [[40, 41], [42, 43], [44, 45], [46, 47]], + [[48, 49], [50, 51], [52, 53], [54, 55]], + [[56, 57], [58, 59], [60, 61], [62, 63]], + ], + ] + ) + expected_max_pool_output = [ + [[[10.0, 11.0], [14.0, 15.0]], [[26.0, 27.0], [30.0, 31.0]]], + [[[42.0, 43.0], [46.0, 47.0]], [[58.0, 59.0], [62.0, 63.0]]], + ] + expected_avg_pool_output = [ + [[[2.22222, 2.66667], [5.33333, 6]], [[11.3333, 12], [20, 21]]], + [[[16.4444, 16.8889], [26.6667, 27.3333]], [[32.6667, 33.3333], [52, 53]]], + ] + self.assertTrue( + np.array_equal( + nn.MaxPool2d(kernel_size=3, stride=2, padding=1)(x), + expected_max_pool_output, + ) + ) + self.assertTrue( + np.allclose( + nn.AvgPool2d(kernel_size=3, stride=2, padding=1)(x), + expected_avg_pool_output, + ) + ) + # Test irregular kernel (2, 4), stride (3, 1) and padding (1, 2) + x = mx.array( + [ + [ + [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], + [[12, 13, 14], [15, 16, 17], [18, 19, 20], [21, 22, 23]], + [[24, 25, 26], [27, 28, 29], [30, 31, 32], [33, 34, 35]], + [[36, 37, 38], [39, 40, 41], [42, 43, 44], [45, 46, 47]], + ], + [ + [[48, 49, 50], [51, 52, 53], [54, 55, 56], [57, 58, 59]], + [[60, 61, 62], [63, 64, 65], [66, 67, 68], [69, 70, 71]], + [[72, 73, 74], [75, 76, 77], [78, 79, 80], [81, 82, 83]], + [[84, 85, 86], [87, 88, 89], [90, 91, 92], [93, 94, 95]], + ], + ] + ) + expected_irregular_max_pool_output = [ + [ + [ + [3.0, 4.0, 5.0], + [6.0, 7.0, 8.0], + [9.0, 10.0, 11.0], + [9.0, 10.0, 11.0], + [9.0, 10.0, 11.0], + ], + [ + [39.0, 40.0, 41.0], + [42.0, 43.0, 44.0], + [45.0, 46.0, 47.0], + [45.0, 46.0, 47.0], + [45.0, 46.0, 47.0], + ], + ], + [ + [ + [51.0, 52.0, 53.0], + [54.0, 55.0, 56.0], + [57.0, 58.0, 59.0], + [57.0, 58.0, 59.0], + [57.0, 58.0, 59.0], + ], + [ + [87.0, 88.0, 89.0], + [90.0, 91.0, 92.0], + [93.0, 94.0, 95.0], + [93.0, 94.0, 95.0], + [93.0, 94.0, 95.0], + ], + ], + ] + expected_irregular_average_pool_output = [ + [ + [ + [0.3750, 0.6250, 0.8750], + [1.1250, 1.5000, 1.8750], + [2.2500, 2.7500, 3.2500], + [2.2500, 2.6250, 3.0000], + [1.8750, 2.1250, 2.3750], + ], + [ + [15.7500, 16.2500, 16.7500], + [24.7500, 25.5000, 26.2500], + [34.5000, 35.5000, 36.5000], + [27.0000, 27.7500, 28.5000], + [18.7500, 19.2500, 19.7500], + ], + ], + [ + [ + [12.3750, 12.6250, 12.8750], + [19.1250, 19.5000, 19.8750], + [26.2500, 26.7500, 27.2500], + [20.2500, 20.6250, 21.0000], + [13.8750, 14.1250, 14.3750], + ], + [ + [39.7500, 40.2500, 40.7500], + [60.7500, 61.5000, 62.2500], + [82.5000, 83.5000, 84.5000], + [63.0000, 63.7500, 64.5000], + [42.7500, 43.2500, 43.7500], + ], + ], + ] + self.assertTrue( + np.array_equal( + nn.MaxPool2d(kernel_size=(2, 4), stride=(3, 1), padding=(1, 2))(x), + expected_irregular_max_pool_output, + ) + ) + self.assertTrue( + np.allclose( + nn.AvgPool2d(kernel_size=(2, 4), stride=(3, 1), padding=(1, 2))(x), + expected_irregular_average_pool_output, + ) + ) + # Test repr + self.assertEqual( + str(nn.MaxPool1d(kernel_size=3, padding=2)), + "MaxPool1d(kernel_size=(3,), stride=(3,), padding=(2,))", + ) + self.assertEqual( + str(nn.AvgPool1d(kernel_size=2, stride=3)), + "AvgPool1d(kernel_size=(2,), stride=(3,), padding=(0,))", + ) + self.assertEqual( + str(nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), + "MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))", + ) + self.assertEqual( + str(nn.AvgPool2d(kernel_size=(1, 2), stride=2, padding=(1, 2))), + "AvgPool2d(kernel_size=(1, 2), stride=(2, 2), padding=(1, 2))", + ) + if __name__ == "__main__": unittest.main()