From 89b3f69a56630e43ccdf28eacfb4e08563677eb3 Mon Sep 17 00:00:00 2001 From: Vincent Amato Date: Mon, 11 Aug 2025 22:12:14 -0400 Subject: [PATCH] Refactor adaptive pooling for style consistency --- python/mlx/nn/layers/pooling.py | 311 +++++++++++++------------------- 1 file changed, 130 insertions(+), 181 deletions(-) diff --git a/python/mlx/nn/layers/pooling.py b/python/mlx/nn/layers/pooling.py index 098e6f987..68171430e 100644 --- a/python/mlx/nn/layers/pooling.py +++ b/python/mlx/nn/layers/pooling.py @@ -398,21 +398,40 @@ class AvgPool3d(_Pool3d): super().__init__(mx.mean, 0, kernel_size, stride, padding) -class AdaptiveAvgPool2d(Module): +class _AdaptivePool(Module): + """Base class for adaptive pooling layers.""" + + def __init__(self, output_size): + super().__init__() + self.output_size = output_size + + def _extra_repr(self): + return f"output_size={self.output_size}" + + +class AdaptiveAvgPool2d(_AdaptivePool): r"""Applies 2-dimensional adaptive average pooling. - The output size is H x W, for any input size. The number of output - features is equal to the number of input planes. + Spatially downsamples the input by taking the average over pooling regions + such that the output size is H x W, for any input size. + + The parameters can be: + + * a single ``int`` -- in which case the same value is used for both the + height and width dimensions, creating a square output. + * a ``tuple`` of two ``int`` s -- in which case, the first ``int`` is + used for the height dimension, the second ``int`` for the width dimension. + * ``None`` can be used for either dimension to keep the input size unchanged. Args: - output_size: the target output size of the form H x W. - Can be a tuple (H, W) or a single int for a square image. + output_size (int or tuple(int, int)): The target output size of the form H x W. + Can be a tuple (H, W) or a single int for a square output. H and W can be either an ``int``, or ``None`` which means the size will be the same as that of the input. Examples: >>> import mlx.core as mx - >>> import mlx.nn as nn + >>> import mlx.nn.layers as nn >>> x = mx.random.normal(shape=(8, 32, 32, 4)) >>> pool = nn.AdaptiveAvgPool2d((5, 7)) >>> pool(x) @@ -421,103 +440,74 @@ class AdaptiveAvgPool2d(Module): """ def __init__(self, output_size: Union[int, Tuple[Optional[int], Optional[int]]]): - super().__init__() - self.output_size = output_size + super().__init__(output_size) def __call__(self, x): - return adaptive_avg_pool2d(x, self.output_size) + output_size = self.output_size + if isinstance(output_size, int): + output_size = (output_size, output_size) + elif len(output_size) == 1: + output_size = (output_size[0], output_size[0]) + + *batch_dims, H, W, C = x.shape + + output_H = H if output_size[0] is None else output_size[0] + output_W = W if output_size[1] is None else output_size[1] + + if H == output_H and W == output_W: + return x + + kernel_H = H // output_H + kernel_W = W // output_W + + if H % output_H == 0 and W % output_W == 0: + new_shape = batch_dims + [output_H, kernel_H, output_W, kernel_W, C] + x_reshaped = x.reshape(new_shape) + return mx.mean(x_reshaped, axis=[-4, -2]) + else: + stride_H = (H - kernel_H) // (output_H - 1) if output_H > 1 else 1 + stride_W = (W - kernel_W) // (output_W - 1) if output_W > 1 else 1 + + values = [] + for i in range(output_H): + row_values = [] + for j in range(output_W): + h_start = i * stride_H + h_end = min(h_start + kernel_H, H) + w_start = j * stride_W + w_end = min(w_start + kernel_W, W) + + region = x[..., h_start:h_end, w_start:w_end, :] + row_values.append(mx.mean(region, axis=(-3, -2))) + values.append(mx.stack(row_values, axis=-2)) + + return mx.stack(values, axis=-3) -def adaptive_avg_pool2d( - x: mx.array, output_size: Union[int, Tuple[Optional[int], Optional[int]]] -) -> mx.array: - r"""Apply 2-dimensional adaptive average pooling. - - Args: - x: Input array of shape (N, H, W, C) or (H, W, C). - output_size: Target output size (H, W) or single int for square output. - Values can be None to keep the corresponding input dimension. - - Returns: - Output array with spatial dimensions matching output_size. - """ - # Parse output_size - if isinstance(output_size, int): - output_size = (output_size, output_size) - elif len(output_size) == 1: - output_size = (output_size[0], output_size[0]) - - # Get input dimensions - *batch_dims, H, W, C = x.shape - - # Handle None values in output_size - output_H = H if output_size[0] is None else output_size[0] - output_W = W if output_size[1] is None else output_size[1] - - # If already the right size, return as is - if H == output_H and W == output_W: - return x - - # Calculate kernel size and stride - kernel_H = H // output_H - kernel_W = W // output_W - stride_H = H // output_H - stride_W = W // output_W - - # For exact division, use regular pooling - if H % output_H == 0 and W % output_W == 0: - # Reshape for pooling: (batch..., H, W, C) -> (batch..., output_H, kernel_H, output_W, kernel_W, C) - new_shape = batch_dims + [output_H, kernel_H, output_W, kernel_W, C] - x_reshaped = x.reshape(new_shape) - - # Average over kernel dimensions (kernel_H is at -4, kernel_W is at -2) - result = mx.mean( - x_reshaped, axis=[-4, -2] - ) # Average over kernel_H and kernel_W - return result - - # For non-exact division, use strided approach with overlap - else: - # Calculate actual stride to fit exactly - stride_H = (H - kernel_H) // (output_H - 1) if output_H > 1 else 1 - stride_W = (W - kernel_W) // (output_W - 1) if output_W > 1 else 1 - - # Collect all averaged values - values = [] - for i in range(output_H): - row_values = [] - for j in range(output_W): - h_start = i * stride_H - h_end = min(h_start + kernel_H, H) - w_start = j * stride_W - w_end = min(w_start + kernel_W, W) - - # Extract region and average - region = x[..., h_start:h_end, w_start:w_end, :] - avg_val = mx.mean(region, axis=(-3, -2)) # Average over H and W - row_values.append(avg_val) - values.append(mx.stack(row_values, axis=-2)) # Stack along W dimension - - # Stack all rows along H dimension - result = mx.stack(values, axis=-3) - return result - - -class AdaptiveAvgPool3d(Module): +class AdaptiveAvgPool3d(_AdaptivePool): r"""Applies 3-dimensional adaptive average pooling. - The output size is D x H x W, for any input size. The number of output - features is equal to the number of input planes. + Spatially downsamples the input by taking the average over pooling regions + such that the output size is D x H x W, for any input size. + + The parameters can be: + + * a single ``int`` -- in which case the same value is used for the depth, + height, and width dimensions, creating a cube output. + * a ``tuple`` of three ``int`` s -- in which case, the first ``int`` is used + for the depth dimension, the second ``int`` for the height dimension, and + the third ``int`` for the width dimension. + * ``None`` can be used for any dimension to keep the input size unchanged. Args: - output_size: the target output size of the form D x H x W. - Can be a tuple (D, H, W) or a single int for a cube D x D x D. + output_size (int or tuple(int, int, int)): The target output size of the form D x H x W. + Can be a tuple (D, H, W) or a single int for a cube output. D, H and W can be either an ``int``, or ``None`` which means the size will be the same as that of the input. Examples: >>> import mlx.core as mx - >>> import mlx.nn as nn + >>> import mlx.nn.layers as nn >>> x = mx.random.normal(shape=(8, 16, 32, 32, 4)) >>> pool = nn.AdaptiveAvgPool3d((5, 7, 9)) >>> pool(x) @@ -529,104 +519,63 @@ class AdaptiveAvgPool3d(Module): self, output_size: Union[int, Tuple[Optional[int], Optional[int], Optional[int]]], ): - super().__init__() - self.output_size = output_size + super().__init__(output_size) def __call__(self, x): - return adaptive_avg_pool3d(x, self.output_size) + output_size = self.output_size + if isinstance(output_size, int): + output_size = (output_size, output_size, output_size) + elif len(output_size) == 1: + output_size = (output_size[0], output_size[0], output_size[0]) + elif len(output_size) == 2: + output_size = (output_size[0], output_size[1], output_size[1]) + *batch_dims, D, H, W, C = x.shape -def adaptive_avg_pool3d( - x: mx.array, - output_size: Union[int, Tuple[Optional[int], Optional[int], Optional[int]]], -) -> mx.array: - r"""Apply 3-dimensional adaptive average pooling. + output_D = D if output_size[0] is None else output_size[0] + output_H = H if output_size[1] is None else output_size[1] + output_W = W if output_size[2] is None else output_size[2] - Args: - x: Input array of shape (N, D, H, W, C) or (D, H, W, C). - output_size: Target output size (D, H, W) or single int for cube output. - Values can be None to keep the corresponding input dimension. + if D == output_D and H == output_H and W == output_W: + return x - Returns: - Output array with spatial dimensions matching output_size. - """ - # Parse output_size - if isinstance(output_size, int): - output_size = (output_size, output_size, output_size) - elif len(output_size) == 1: - output_size = (output_size[0], output_size[0], output_size[0]) - elif len(output_size) == 2: - output_size = (output_size[0], output_size[1], output_size[1]) + kernel_D = D // output_D + kernel_H = H // output_H + kernel_W = W // output_W - # Get input dimensions - *batch_dims, D, H, W, C = x.shape + if D % output_D == 0 and H % output_H == 0 and W % output_W == 0: + new_shape = batch_dims + [ + output_D, + kernel_D, + output_H, + kernel_H, + output_W, + kernel_W, + C, + ] + x_reshaped = x.reshape(new_shape) + return mx.mean(x_reshaped, axis=[-6, -4, -2]) + else: + stride_D = (D - kernel_D) // (output_D - 1) if output_D > 1 else 1 + stride_H = (H - kernel_H) // (output_H - 1) if output_H > 1 else 1 + stride_W = (W - kernel_W) // (output_W - 1) if output_W > 1 else 1 - # Handle None values in output_size - output_D = D if output_size[0] is None else output_size[0] - output_H = H if output_size[1] is None else output_size[1] - output_W = W if output_size[2] is None else output_size[2] + values = [] + for i in range(output_D): + depth_values = [] + for j in range(output_H): + row_values = [] + for k in range(output_W): + d_start = i * stride_D + d_end = min(d_start + kernel_D, D) + h_start = j * stride_H + h_end = min(h_start + kernel_H, H) + w_start = k * stride_W + w_end = min(w_start + kernel_W, W) - # If already the right size, return as is - if D == output_D and H == output_H and W == output_W: - return x + region = x[..., d_start:d_end, h_start:h_end, w_start:w_end, :] + row_values.append(mx.mean(region, axis=(-4, -3, -2))) + depth_values.append(mx.stack(row_values, axis=-2)) + values.append(mx.stack(depth_values, axis=-3)) - # Calculate kernel size and stride - kernel_D = D // output_D - kernel_H = H // output_H - kernel_W = W // output_W - - # For exact division, use regular pooling - if D % output_D == 0 and H % output_H == 0 and W % output_W == 0: - # Reshape for pooling: (batch..., D, H, W, C) -> (batch..., output_D, kernel_D, output_H, kernel_H, output_W, kernel_W, C) - new_shape = batch_dims + [ - output_D, - kernel_D, - output_H, - kernel_H, - output_W, - kernel_W, - C, - ] - x_reshaped = x.reshape(new_shape) - - # Average over kernel dimensions (kernel_D is at -6, kernel_H is at -4, kernel_W is at -2) - result = mx.mean( - x_reshaped, axis=[-6, -4, -2] - ) # Average over kernel_D, kernel_H and kernel_W - return result - - # For non-exact division, use strided approach with overlap - else: - # Calculate actual stride to fit exactly - stride_D = (D - kernel_D) // (output_D - 1) if output_D > 1 else 1 - stride_H = (H - kernel_H) // (output_H - 1) if output_H > 1 else 1 - stride_W = (W - kernel_W) // (output_W - 1) if output_W > 1 else 1 - - # Collect all averaged values - values = [] - for i in range(output_D): - depth_values = [] - for j in range(output_H): - row_values = [] - for k in range(output_W): - d_start = i * stride_D - d_end = min(d_start + kernel_D, D) - h_start = j * stride_H - h_end = min(h_start + kernel_H, H) - w_start = k * stride_W - w_end = min(w_start + kernel_W, W) - - # Extract region and average - region = x[..., d_start:d_end, h_start:h_end, w_start:w_end, :] - avg_val = mx.mean( - region, axis=(-4, -3, -2) - ) # Average over D, H and W - row_values.append(avg_val) - depth_values.append( - mx.stack(row_values, axis=-2) - ) # Stack along W dimension - values.append(mx.stack(depth_values, axis=-3)) # Stack along H dimension - - # Stack all depths along D dimension - result = mx.stack(values, axis=-4) - return result + return mx.stack(values, axis=-4)