diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md index 786c9042c..560555aed 100644 --- a/ACKNOWLEDGMENTS.md +++ b/ACKNOWLEDGMENTS.md @@ -20,6 +20,7 @@ MLX was developed with contributions from the following individuals: - Paul Paczuski: Improved stability of BCE loss calculation - Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops. - Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer. +- Vincent Amato: Added `AdaptiveAvgPool2d` and `AdaptiveAvgPool3d` layers. diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py index 26f77917f..5f475fcc9 100644 --- a/python/mlx/nn/layers/__init__.py +++ b/python/mlx/nn/layers/__init__.py @@ -77,6 +77,8 @@ from mlx.nn.layers.normalization import ( RMSNorm, ) from mlx.nn.layers.pooling import ( + AdaptiveAvgPool2d, + AdaptiveAvgPool3d, AvgPool1d, AvgPool2d, AvgPool3d, diff --git a/python/mlx/nn/layers/pooling.py b/python/mlx/nn/layers/pooling.py index 203117634..098e6f987 100644 --- a/python/mlx/nn/layers/pooling.py +++ b/python/mlx/nn/layers/pooling.py @@ -396,3 +396,237 @@ class AvgPool3d(_Pool3d): padding: Optional[Union[int, Tuple[int, int, int]]] = 0, ): super().__init__(mx.mean, 0, kernel_size, stride, padding) + + +class AdaptiveAvgPool2d(Module): + r"""Applies 2-dimensional adaptive average pooling. + + The output size is H x W, for any input size. The number of output + features is equal to the number of input planes. + + Args: + output_size: the target output size of the form H x W. + Can be a tuple (H, W) or a single int for a square image. + H and W can be either an ``int``, or ``None`` which means the size + will be the same as that of the input. + + Examples: + >>> import mlx.core as mx + >>> import mlx.nn as nn + >>> x = mx.random.normal(shape=(8, 32, 32, 4)) + >>> pool = nn.AdaptiveAvgPool2d((5, 7)) + >>> pool(x) + >>> pool = nn.AdaptiveAvgPool2d(7) + >>> pool(x) + """ + + def __init__(self, output_size: Union[int, Tuple[Optional[int], Optional[int]]]): + super().__init__() + self.output_size = output_size + + def __call__(self, x): + return adaptive_avg_pool2d(x, self.output_size) + + +def adaptive_avg_pool2d( + x: mx.array, output_size: Union[int, Tuple[Optional[int], Optional[int]]] +) -> mx.array: + r"""Apply 2-dimensional adaptive average pooling. + + Args: + x: Input array of shape (N, H, W, C) or (H, W, C). + output_size: Target output size (H, W) or single int for square output. + Values can be None to keep the corresponding input dimension. + + Returns: + Output array with spatial dimensions matching output_size. + """ + # Parse output_size + if isinstance(output_size, int): + output_size = (output_size, output_size) + elif len(output_size) == 1: + output_size = (output_size[0], output_size[0]) + + # Get input dimensions + *batch_dims, H, W, C = x.shape + + # Handle None values in output_size + output_H = H if output_size[0] is None else output_size[0] + output_W = W if output_size[1] is None else output_size[1] + + # If already the right size, return as is + if H == output_H and W == output_W: + return x + + # Calculate kernel size and stride + kernel_H = H // output_H + kernel_W = W // output_W + stride_H = H // output_H + stride_W = W // output_W + + # For exact division, use regular pooling + if H % output_H == 0 and W % output_W == 0: + # Reshape for pooling: (batch..., H, W, C) -> (batch..., output_H, kernel_H, output_W, kernel_W, C) + new_shape = batch_dims + [output_H, kernel_H, output_W, kernel_W, C] + x_reshaped = x.reshape(new_shape) + + # Average over kernel dimensions (kernel_H is at -4, kernel_W is at -2) + result = mx.mean( + x_reshaped, axis=[-4, -2] + ) # Average over kernel_H and kernel_W + return result + + # For non-exact division, use strided approach with overlap + else: + # Calculate actual stride to fit exactly + stride_H = (H - kernel_H) // (output_H - 1) if output_H > 1 else 1 + stride_W = (W - kernel_W) // (output_W - 1) if output_W > 1 else 1 + + # Collect all averaged values + values = [] + for i in range(output_H): + row_values = [] + for j in range(output_W): + h_start = i * stride_H + h_end = min(h_start + kernel_H, H) + w_start = j * stride_W + w_end = min(w_start + kernel_W, W) + + # Extract region and average + region = x[..., h_start:h_end, w_start:w_end, :] + avg_val = mx.mean(region, axis=(-3, -2)) # Average over H and W + row_values.append(avg_val) + values.append(mx.stack(row_values, axis=-2)) # Stack along W dimension + + # Stack all rows along H dimension + result = mx.stack(values, axis=-3) + return result + + +class AdaptiveAvgPool3d(Module): + r"""Applies 3-dimensional adaptive average pooling. + + The output size is D x H x W, for any input size. The number of output + features is equal to the number of input planes. + + Args: + output_size: the target output size of the form D x H x W. + Can be a tuple (D, H, W) or a single int for a cube D x D x D. + D, H and W can be either an ``int``, or ``None`` which means the size + will be the same as that of the input. + + Examples: + >>> import mlx.core as mx + >>> import mlx.nn as nn + >>> x = mx.random.normal(shape=(8, 16, 32, 32, 4)) + >>> pool = nn.AdaptiveAvgPool3d((5, 7, 9)) + >>> pool(x) + >>> pool = nn.AdaptiveAvgPool3d(7) + >>> pool(x) + """ + + def __init__( + self, + output_size: Union[int, Tuple[Optional[int], Optional[int], Optional[int]]], + ): + super().__init__() + self.output_size = output_size + + def __call__(self, x): + return adaptive_avg_pool3d(x, self.output_size) + + +def adaptive_avg_pool3d( + x: mx.array, + output_size: Union[int, Tuple[Optional[int], Optional[int], Optional[int]]], +) -> mx.array: + r"""Apply 3-dimensional adaptive average pooling. + + Args: + x: Input array of shape (N, D, H, W, C) or (D, H, W, C). + output_size: Target output size (D, H, W) or single int for cube output. + Values can be None to keep the corresponding input dimension. + + Returns: + Output array with spatial dimensions matching output_size. + """ + # Parse output_size + if isinstance(output_size, int): + output_size = (output_size, output_size, output_size) + elif len(output_size) == 1: + output_size = (output_size[0], output_size[0], output_size[0]) + elif len(output_size) == 2: + output_size = (output_size[0], output_size[1], output_size[1]) + + # Get input dimensions + *batch_dims, D, H, W, C = x.shape + + # Handle None values in output_size + output_D = D if output_size[0] is None else output_size[0] + output_H = H if output_size[1] is None else output_size[1] + output_W = W if output_size[2] is None else output_size[2] + + # If already the right size, return as is + if D == output_D and H == output_H and W == output_W: + return x + + # Calculate kernel size and stride + kernel_D = D // output_D + kernel_H = H // output_H + kernel_W = W // output_W + + # For exact division, use regular pooling + if D % output_D == 0 and H % output_H == 0 and W % output_W == 0: + # Reshape for pooling: (batch..., D, H, W, C) -> (batch..., output_D, kernel_D, output_H, kernel_H, output_W, kernel_W, C) + new_shape = batch_dims + [ + output_D, + kernel_D, + output_H, + kernel_H, + output_W, + kernel_W, + C, + ] + x_reshaped = x.reshape(new_shape) + + # Average over kernel dimensions (kernel_D is at -6, kernel_H is at -4, kernel_W is at -2) + result = mx.mean( + x_reshaped, axis=[-6, -4, -2] + ) # Average over kernel_D, kernel_H and kernel_W + return result + + # For non-exact division, use strided approach with overlap + else: + # Calculate actual stride to fit exactly + stride_D = (D - kernel_D) // (output_D - 1) if output_D > 1 else 1 + stride_H = (H - kernel_H) // (output_H - 1) if output_H > 1 else 1 + stride_W = (W - kernel_W) // (output_W - 1) if output_W > 1 else 1 + + # Collect all averaged values + values = [] + for i in range(output_D): + depth_values = [] + for j in range(output_H): + row_values = [] + for k in range(output_W): + d_start = i * stride_D + d_end = min(d_start + kernel_D, D) + h_start = j * stride_H + h_end = min(h_start + kernel_H, H) + w_start = k * stride_W + w_end = min(w_start + kernel_W, W) + + # Extract region and average + region = x[..., d_start:d_end, h_start:h_end, w_start:w_end, :] + avg_val = mx.mean( + region, axis=(-4, -3, -2) + ) # Average over D, H and W + row_values.append(avg_val) + depth_values.append( + mx.stack(row_values, axis=-2) + ) # Stack along W dimension + values.append(mx.stack(depth_values, axis=-3)) # Stack along H dimension + + # Stack all depths along D dimension + result = mx.stack(values, axis=-4) + return result diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 296f6ee8d..b4b78e0fc 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -1818,6 +1818,79 @@ class TestLayers(mlx_tests.MLXTestCase): "AvgPool3d(kernel_size=(3, 3, 3), stride=(3, 3, 3), padding=(2, 2, 2))", ) + # Test AdaptiveAvgPool2d + # Test exact division case (8x8 -> 4x4) + x = mx.arange(64, dtype=mx.float32).reshape(1, 8, 8, 1) + pool = nn.AdaptiveAvgPool2d((4, 4)) + output = pool(x) + self.assertEqual(output.shape, (1, 4, 4, 1)) + # Check first few values manually + expected_00 = (0 + 1 + 8 + 9) / 4 # Top-left 2x2 block + expected_01 = (2 + 3 + 10 + 11) / 4 # Next 2x2 block + self.assertAlmostEqual(output[0, 0, 0, 0].item(), expected_00, places=5) + self.assertAlmostEqual(output[0, 0, 1, 0].item(), expected_01, places=5) + + # Test square output format (int input) + pool = nn.AdaptiveAvgPool2d(2) + output = pool(x) + self.assertEqual(output.shape, (1, 2, 2, 1)) + + # Test None values (keep original dimension) + x = mx.random.normal((2, 6, 8, 3)) + pool = nn.AdaptiveAvgPool2d((3, None)) + output = pool(x) + self.assertEqual(output.shape, (2, 3, 8, 3)) + + pool = nn.AdaptiveAvgPool2d((None, 4)) + output = pool(x) + self.assertEqual(output.shape, (2, 6, 4, 3)) + + # Test non-exact division case + x = mx.random.normal((1, 7, 5, 2)) + pool = nn.AdaptiveAvgPool2d((3, 2)) + output = pool(x) + self.assertEqual(output.shape, (1, 3, 2, 2)) + + # Test 3D input (no batch dimension) + x = mx.random.normal((6, 6, 4)) + pool = nn.AdaptiveAvgPool2d((2, 3)) + output = pool(x) + self.assertEqual(output.shape, (2, 3, 4)) + + # Test AdaptiveAvgPool3d + # Test exact division case (8x8x8 -> 4x4x4) + x = mx.arange(512, dtype=mx.float32).reshape(1, 8, 8, 8, 1) + pool = nn.AdaptiveAvgPool3d((4, 4, 4)) + output = pool(x) + self.assertEqual(output.shape, (1, 4, 4, 4, 1)) + + # Test cube output format (int input) + pool = nn.AdaptiveAvgPool3d(2) + output = pool(x) + self.assertEqual(output.shape, (1, 2, 2, 2, 1)) + + # Test None values (keep original dimensions) + x = mx.random.normal((2, 6, 8, 10, 3)) + pool = nn.AdaptiveAvgPool3d((3, None, 5)) + output = pool(x) + self.assertEqual(output.shape, (2, 3, 8, 5, 3)) + + pool = nn.AdaptiveAvgPool3d((None, 4, None)) + output = pool(x) + self.assertEqual(output.shape, (2, 6, 4, 10, 3)) + + # Test non-exact division case + x = mx.random.normal((1, 7, 5, 9, 2)) + pool = nn.AdaptiveAvgPool3d((3, 2, 4)) + output = pool(x) + self.assertEqual(output.shape, (1, 3, 2, 4, 2)) + + # Test 4D input (no batch dimension) + x = mx.random.normal((6, 6, 6, 4)) + pool = nn.AdaptiveAvgPool3d((2, 3, 2)) + output = pool(x) + self.assertEqual(output.shape, (2, 3, 2, 4)) + def test_set_dtype(self): def assert_dtype(layer, dtype): for k, v in tree_flatten(layer.parameters()):