diff --git a/docs/src/python/nn/layers.rst b/docs/src/python/nn/layers.rst
index 4eb14b088..c9f5e090e 100644
--- a/docs/src/python/nn/layers.rst
+++ b/docs/src/python/nn/layers.rst
@@ -9,6 +9,9 @@ Layers
    :toctree: _autosummary
    :template: nn-module-template.rst
 
+   AdaptiveMaxPool1d
+   AdaptiveMaxPool2d
+   AdaptiveMaxPool3d
    ALiBi
    AvgPool1d
    AvgPool2d
diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py
index 26f77917f..a478070a3 100644
--- a/python/mlx/nn/layers/__init__.py
+++ b/python/mlx/nn/layers/__init__.py
@@ -77,6 +77,9 @@ from mlx.nn.layers.normalization import (
     RMSNorm,
 )
 from mlx.nn.layers.pooling import (
+    AdaptiveMaxPool1d,
+    AdaptiveMaxPool2d,
+    AdaptiveMaxPool3d,
     AvgPool1d,
     AvgPool2d,
     AvgPool3d,
diff --git a/python/mlx/nn/layers/pooling.py b/python/mlx/nn/layers/pooling.py
index 203117634..b660c5ac8 100644
--- a/python/mlx/nn/layers/pooling.py
+++ b/python/mlx/nn/layers/pooling.py
@@ -396,3 +396,266 @@ class AvgPool3d(_Pool3d):
         padding: Optional[Union[int, Tuple[int, int, int]]] = 0,
     ):
         super().__init__(mx.mean, 0, kernel_size, stride, padding)
+
+
+class _AdaptivePool(Module):
+    """Base class for adaptive pooling layers."""
+
+    def __init__(self, output_size):
+        super().__init__()
+        self.output_size = output_size
+
+    def _extra_repr(self):
+        return f"output_size={self.output_size}"
+
+
+class AdaptiveMaxPool1d(_AdaptivePool):
+    r"""Applies 1-dimensional adaptive max pooling.
+
+    Spatially downsamples the input by taking the maximum over pooling
+    regions such that the output size is L, for any input size.
+
+    The parameter can be:
+
+    * a single ``int`` -- the target output size.
+    * ``None`` -- keep the input size unchanged.
+
+    Args:
+        output_size (int or None): The target output size L. Can be an
+            ``int``, or ``None``, which means the size will be the same
+            as that of the input.
+
+    Note:
+        Unlike PyTorch's implementation, this layer does not support
+        returning indices (``return_indices`` parameter). This may be
+        added in a future version.
+
+    Examples:
+        >>> import mlx.core as mx
+        >>> import mlx.nn.layers as nn
+        >>> x = mx.random.normal(shape=(8, 32, 4))
+        >>> pool = nn.AdaptiveMaxPool1d(7)
+        >>> pool(x)
+        >>> pool = nn.AdaptiveMaxPool1d(None)  # no change
+        >>> pool(x)
+    """
+
+    def __init__(self, output_size: Union[int, None]):
+        super().__init__(output_size)
+
+    def __call__(self, x):
+        output_size = self.output_size
+
+        *batch_dims, L, C = x.shape
+
+        output_L = L if output_size is None else output_size
+
+        if L == output_L:
+            return x
+
+        kernel_L = max(1, L // output_L)
+
+        if L % output_L == 0 and kernel_L > 0:
+            new_shape = batch_dims + [output_L, kernel_L, C]
+            x_reshaped = x.reshape(new_shape)
+            return mx.max(x_reshaped, axis=-2)
+        else:
+            stride_L = max(1, (L - kernel_L) // (output_L - 1)) if output_L > 1 else 1
+
+            values = []
+            for i in range(output_L):
+                l_start = min(i * stride_L, L - 1)
+                l_end = min(l_start + kernel_L, L)
+                region = x[..., l_start:l_end, :]
+                values.append(mx.max(region, axis=-2))
+
+            return mx.stack(values, axis=-2)
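The exact-division fast path in ``AdaptiveMaxPool1d.__call__`` is a reshape trick: when ``L`` is a multiple of ``output_L``, the length axis folds into ``(output_L, kernel_L)`` and a single ``mx.max`` reduces each window. A minimal sketch of that equivalence, assuming the ``(N, L, C)`` layout these layers use (the names ``out_l``, ``kernel``, and ``folded`` are illustrative, not from this PR):

    import mlx.core as mx

    x = mx.arange(8, dtype=mx.float32).reshape(1, 8, 1)  # (N, L, C) with L = 8
    out_l = 4                      # target length; divides L exactly
    kernel = x.shape[1] // out_l   # each output position covers L // out_l inputs
    folded = x.reshape(1, out_l, kernel, 1)  # fold L into (out_l, kernel)
    y = mx.max(folded, axis=-2)    # reduce each window
    print(y.reshape(out_l).tolist())  # [1.0, 3.0, 5.0, 7.0]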
+
+
+class AdaptiveMaxPool2d(_AdaptivePool):
+    r"""Applies 2-dimensional adaptive max pooling.
+
+    Spatially downsamples the input by taking the maximum over pooling
+    regions such that the output size is H x W, for any input size.
+
+    The parameters can be:
+
+    * a single ``int`` -- in which case the same value is used for both
+      the height and width dimensions, creating a square output.
+    * a ``tuple`` of two ``int`` s -- in which case, the first ``int`` is
+      used for the height dimension, the second ``int`` for the width
+      dimension.
+    * ``None`` can be used for either dimension to keep the input size
+      unchanged.
+
+    Args:
+        output_size (int or tuple(int, int)): The target output size of
+            the form H x W. Can be a tuple (H, W) or a single int for a
+            square output. H and W can be either an ``int``, or ``None``,
+            which means the size will be the same as that of the input.
+
+    Note:
+        Unlike PyTorch's implementation, this layer does not support
+        returning indices (``return_indices`` parameter). This may be
+        added in a future version.
+
+    Examples:
+        >>> import mlx.core as mx
+        >>> import mlx.nn.layers as nn
+        >>> x = mx.random.normal(shape=(8, 32, 32, 4))
+        >>> pool = nn.AdaptiveMaxPool2d((5, 7))
+        >>> pool(x)
+        >>> pool = nn.AdaptiveMaxPool2d(7)
+        >>> pool(x)
+    """
+
+    def __init__(self, output_size: Union[int, Tuple[Optional[int], Optional[int]]]):
+        super().__init__(output_size)
+
+    def __call__(self, x):
+        output_size = self.output_size
+        if isinstance(output_size, int):
+            output_size = (output_size, output_size)
+        elif len(output_size) == 1:
+            output_size = (output_size[0], output_size[0])
+
+        *batch_dims, H, W, C = x.shape
+
+        output_H = H if output_size[0] is None else output_size[0]
+        output_W = W if output_size[1] is None else output_size[1]
+
+        if H == output_H and W == output_W:
+            return x
+
+        kernel_H = max(1, H // output_H)
+        kernel_W = max(1, W // output_W)
+
+        if H % output_H == 0 and W % output_W == 0 and kernel_H > 0 and kernel_W > 0:
+            new_shape = batch_dims + [output_H, kernel_H, output_W, kernel_W, C]
+            x_reshaped = x.reshape(new_shape)
+            return mx.max(x_reshaped, axis=[-4, -2])
+        else:
+            stride_H = max(1, (H - kernel_H) // (output_H - 1)) if output_H > 1 else 1
+            stride_W = max(1, (W - kernel_W) // (output_W - 1)) if output_W > 1 else 1
+
+            values = []
+            for i in range(output_H):
+                row_values = []
+                for j in range(output_W):
+                    h_start = min(i * stride_H, H - 1)
+                    h_end = min(h_start + kernel_H, H)
+                    w_start = min(j * stride_W, W - 1)
+                    w_end = min(w_start + kernel_W, W)
+
+                    region = x[..., h_start:h_end, w_start:w_end, :]
+                    row_values.append(mx.max(region, axis=(-3, -2)))
+                values.append(mx.stack(row_values, axis=-2))
+
+            return mx.stack(values, axis=-3)
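With ``output_size=1`` the adaptive pool degenerates to a global spatial max, which gives a cheap sanity check for the 2-D layer. A sketch, again assuming the ``(N, H, W, C)`` layout (not part of this PR's test suite):

    import mlx.core as mx
    import mlx.nn.layers as nn

    x = mx.random.normal(shape=(2, 6, 6, 3))
    pooled = nn.AdaptiveMaxPool2d(1)(x)                # (2, 1, 1, 3)
    reference = mx.max(x, axis=(1, 2), keepdims=True)  # reduce H and W directly
    print(mx.allclose(pooled, reference))              # expected: True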
+
+
+class AdaptiveMaxPool3d(_AdaptivePool):
+    r"""Applies 3-dimensional adaptive max pooling.
+
+    Spatially downsamples the input by taking the maximum over pooling
+    regions such that the output size is D x H x W, for any input size.
+
+    The parameters can be:
+
+    * a single ``int`` -- in which case the same value is used for the
+      depth, height, and width dimensions, creating a cubic output.
+    * a ``tuple`` of three ``int`` s -- in which case, the first ``int``
+      is used for the depth dimension, the second ``int`` for the height
+      dimension, and the third ``int`` for the width dimension.
+    * ``None`` can be used for any dimension to keep the input size
+      unchanged.
+
+    Args:
+        output_size (int or tuple(int, int, int)): The target output size
+            of the form D x H x W. Can be a tuple (D, H, W) or a single
+            int for a cubic output. D, H, and W can be either an ``int``,
+            or ``None``, which means the size will be the same as that of
+            the input.
+
+    Note:
+        Unlike PyTorch's implementation, this layer does not support
+        returning indices (``return_indices`` parameter). This may be
+        added in a future version.
+
+    Examples:
+        >>> import mlx.core as mx
+        >>> import mlx.nn.layers as nn
+        >>> x = mx.random.normal(shape=(8, 16, 32, 32, 4))
+        >>> pool = nn.AdaptiveMaxPool3d((5, 7, 9))
+        >>> pool(x)
+        >>> pool = nn.AdaptiveMaxPool3d(7)
+        >>> pool(x)
+    """
+
+    def __init__(
+        self,
+        output_size: Union[int, Tuple[Optional[int], Optional[int], Optional[int]]],
+    ):
+        super().__init__(output_size)
+
+    def __call__(self, x):
+        output_size = self.output_size
+        if isinstance(output_size, int):
+            output_size = (output_size, output_size, output_size)
+        elif len(output_size) == 1:
+            output_size = (output_size[0], output_size[0], output_size[0])
+        elif len(output_size) == 2:
+            output_size = (output_size[0], output_size[1], output_size[1])
+
+        *batch_dims, D, H, W, C = x.shape
+
+        output_D = D if output_size[0] is None else output_size[0]
+        output_H = H if output_size[1] is None else output_size[1]
+        output_W = W if output_size[2] is None else output_size[2]
+
+        if D == output_D and H == output_H and W == output_W:
+            return x
+
+        kernel_D = max(1, D // output_D)
+        kernel_H = max(1, H // output_H)
+        kernel_W = max(1, W // output_W)
+
+        if (
+            D % output_D == 0
+            and H % output_H == 0
+            and W % output_W == 0
+            and kernel_D > 0
+            and kernel_H > 0
+            and kernel_W > 0
+        ):
+            new_shape = batch_dims + [
+                output_D,
+                kernel_D,
+                output_H,
+                kernel_H,
+                output_W,
+                kernel_W,
+                C,
+            ]
+            x_reshaped = x.reshape(new_shape)
+            return mx.max(x_reshaped, axis=[-6, -4, -2])
+        else:
+            stride_D = max(1, (D - kernel_D) // (output_D - 1)) if output_D > 1 else 1
+            stride_H = max(1, (H - kernel_H) // (output_H - 1)) if output_H > 1 else 1
+            stride_W = max(1, (W - kernel_W) // (output_W - 1)) if output_W > 1 else 1
+
+            values = []
+            for i in range(output_D):
+                depth_values = []
+                for j in range(output_H):
+                    row_values = []
+                    for k in range(output_W):
+                        d_start = min(i * stride_D, D - 1)
+                        d_end = min(d_start + kernel_D, D)
+                        h_start = min(j * stride_H, H - 1)
+                        h_end = min(h_start + kernel_H, H)
+                        w_start = min(k * stride_W, W - 1)
+                        w_end = min(w_start + kernel_W, W)
+
+                        region = x[..., d_start:d_end, h_start:h_end, w_start:w_end, :]
+                        row_values.append(mx.max(region, axis=(-4, -3, -2)))
+                    depth_values.append(mx.stack(row_values, axis=-2))
+                values.append(mx.stack(depth_values, axis=-3))
+
+            return mx.stack(values, axis=-4)
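One behavioral caveat for the non-exact path above: it approximates the windows with a fixed ``kernel``/``stride`` pair, while PyTorch's adaptive pooling derives each window as ``floor(i * L / out)`` to ``ceil((i + 1) * L / out)``, which always covers the whole input. The two can therefore disagree when the input size is not a multiple of the target size. A PyTorch-style 1-D reference, written here only for comparison (``adaptive_max_pool1d_reference`` is a hypothetical helper, not part of this PR):

    import math

    import mlx.core as mx

    def adaptive_max_pool1d_reference(x, out_l):
        # Window i spans [floor(i * L / out_l), ceil((i + 1) * L / out_l)),
        # so every input position belongs to at least one window.
        L = x.shape[-2]  # length axis; channels are last in MLX layers
        slices = []
        for i in range(out_l):
            start = (i * L) // out_l
            end = math.ceil((i + 1) * L / out_l)
            slices.append(mx.max(x[..., start:end, :], axis=-2))
        return mx.stack(slices, axis=-2)

    x = mx.random.normal(shape=(1, 7, 2))
    print(adaptive_max_pool1d_reference(x, 3).shape)  # (1, 3, 2)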
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index 296f6ee8d..d56aab172 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -1818,6 +1818,133 @@ class TestLayers(mlx_tests.MLXTestCase):
             "AvgPool3d(kernel_size=(3, 3, 3), stride=(3, 3, 3), padding=(2, 2, 2))",
         )
 
+        # Test AdaptiveMaxPool1d
+        # Test exact division case (8 -> 4)
+        x = mx.arange(8, dtype=mx.float32).reshape(1, 8, 1)
+        pool = nn.AdaptiveMaxPool1d(4)
+        output = pool(x)
+        self.assertEqual(output.shape, (1, 4, 1))
+        # Check the first few values manually
+        expected_0 = max(0, 1)  # max of the first two elements
+        expected_1 = max(2, 3)  # max of the next two elements
+        self.assertAlmostEqual(output[0, 0, 0].item(), expected_0, places=5)
+        self.assertAlmostEqual(output[0, 1, 0].item(), expected_1, places=5)
+
+        # Test None value (keep the original dimension)
+        x = mx.random.normal((2, 8, 3))
+        pool = nn.AdaptiveMaxPool1d(None)
+        output = pool(x)
+        self.assertEqual(output.shape, (2, 8, 3))
+        # Should be identical to the input when output_size is None
+        self.assertTrue(mx.allclose(output, x))
+
+        # Test non-exact division case
+        x = mx.random.normal((1, 7, 2))
+        pool = nn.AdaptiveMaxPool1d(3)
+        output = pool(x)
+        self.assertEqual(output.shape, (1, 3, 2))
+
+        # Test 2D input (no batch dimension)
+        x = mx.random.normal((6, 4))
+        pool = nn.AdaptiveMaxPool1d(2)
+        output = pool(x)
+        self.assertEqual(output.shape, (2, 4))
+
+        # Test single output
+        x = mx.random.normal((1, 10, 2))
+        pool = nn.AdaptiveMaxPool1d(1)
+        output = pool(x)
+        self.assertEqual(output.shape, (1, 1, 2))
+
+        # Test output larger than the input
+        x = mx.random.normal((1, 3, 2))
+        pool = nn.AdaptiveMaxPool1d(5)
+        output = pool(x)
+        self.assertEqual(output.shape, (1, 5, 2))
+
+        # Test AdaptiveMaxPool2d
+        # Test exact division case (4x4 -> 2x2)
+        x = mx.arange(16, dtype=mx.float32).reshape(1, 4, 4, 1)
+        pool = nn.AdaptiveMaxPool2d((2, 2))
+        output = pool(x)
+        self.assertEqual(output.shape, (1, 2, 2, 1))
+        # Check the first value manually: max of the top-left 2x2 block
+        expected_00 = max(0, 1, 4, 5)
+        self.assertAlmostEqual(output[0, 0, 0, 0].item(), expected_00, places=5)
+
+        # Test square output format (int input)
+        pool = nn.AdaptiveMaxPool2d(1)
+        output = pool(x)
+        self.assertEqual(output.shape, (1, 1, 1, 1))
+
+        # Test None values (keep the original dimension)
+        x = mx.random.normal((2, 6, 8, 3))
+        pool = nn.AdaptiveMaxPool2d((3, None))
+        output = pool(x)
+        self.assertEqual(output.shape, (2, 3, 8, 3))
+
+        pool = nn.AdaptiveMaxPool2d((None, 4))
+        output = pool(x)
+        self.assertEqual(output.shape, (2, 6, 4, 3))
+
+        # Test non-exact division case
+        x = mx.random.normal((1, 7, 5, 2))
+        pool = nn.AdaptiveMaxPool2d((3, 2))
+        output = pool(x)
+        self.assertEqual(output.shape, (1, 3, 2, 2))
+
+        # Test output larger than the input
+        x = mx.random.normal((1, 2, 3, 2))
+        pool = nn.AdaptiveMaxPool2d((4, 5))
+        output = pool(x)
+        self.assertEqual(output.shape, (1, 4, 5, 2))
+
+        # Test 3D input (no batch dimension)
+        x = mx.random.normal((6, 6, 4))
+        pool = nn.AdaptiveMaxPool2d((2, 3))
+        output = pool(x)
+        self.assertEqual(output.shape, (2, 3, 4))
+
+        # Test AdaptiveMaxPool3d
+        # Test exact division case (4x4x4 -> 2x2x2)
+        x = mx.arange(64, dtype=mx.float32).reshape(1, 4, 4, 4, 1)
+        pool = nn.AdaptiveMaxPool3d((2, 2, 2))
+        output = pool(x)
+        self.assertEqual(output.shape, (1, 2, 2, 2, 1))
+
+        # Test cubic output format (int input)
+        pool = nn.AdaptiveMaxPool3d(1)
+        output = pool(x)
+        self.assertEqual(output.shape, (1, 1, 1, 1, 1))
+
+        # Test None values (keep the original dimensions)
+        x = mx.random.normal((2, 6, 8, 10, 3))
+        pool = nn.AdaptiveMaxPool3d((3, None, 5))
+        output = pool(x)
+        self.assertEqual(output.shape, (2, 3, 8, 5, 3))
+
+        pool = nn.AdaptiveMaxPool3d((None, 4, None))
+        output = pool(x)
+        self.assertEqual(output.shape, (2, 6, 4, 10, 3))
+
+        # Test non-exact division case
+        x = mx.random.normal((1, 7, 5, 9, 2))
+        pool = nn.AdaptiveMaxPool3d((3, 2, 4))
+        output = pool(x)
+        self.assertEqual(output.shape, (1, 3, 2, 4, 2))
+
+        # Test output larger than the input
+        x = mx.random.normal((1, 2, 3, 2, 2))
+        pool = nn.AdaptiveMaxPool3d((4, 5, 3))
+        output = pool(x)
+        self.assertEqual(output.shape, (1, 4, 5, 3, 2))
+
+        # Test 4D input (no batch dimension)
+        x = mx.random.normal((6, 6, 6, 4))
+        pool = nn.AdaptiveMaxPool3d((2, 3, 2))
+        output = pool(x)
+        self.assertEqual(output.shape, (2, 3, 2, 4))
+
     def test_set_dtype(self):
         def assert_dtype(layer, dtype):
             for k, v in tree_flatten(layer.parameters()):
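For the 7 -> 3 case exercised above, the implementation picks ``kernel_L = 7 // 3 = 2`` and ``stride_L = (7 - 2) // (3 - 1) = 2``, so the pooled windows are ``[0:2]``, ``[2:4]``, and ``[4:6]``, and the last input position never contributes. A small deterministic probe makes that visible (illustrative only, not part of the test suite):

    import mlx.core as mx
    import mlx.nn.layers as nn

    x = mx.arange(7, dtype=mx.float32).reshape(1, 7, 1)  # values 0..6 along L
    y = nn.AdaptiveMaxPool1d(3)(x)
    print(y.reshape(3).tolist())  # [1.0, 3.0, 5.0]; note that 6.0 is dropped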