Refactor adaptive pooling for style consistency

This commit is contained in:
Vincent Amato
2025-08-11 22:12:14 -04:00
parent 634ce07a3e
commit 89b3f69a56

View File

@@ -398,21 +398,40 @@ class AvgPool3d(_Pool3d):
super().__init__(mx.mean, 0, kernel_size, stride, padding)
class _AdaptivePool(Module):
    """Base class for adaptive pooling layers.

    Stores the requested ``output_size`` and provides a shared
    ``_extra_repr`` so every adaptive pooling subclass prints the same way.

    Args:
        output_size: Target spatial output size; the exact form (int or
            tuple, possibly containing ``None``) is interpreted by each
            subclass's ``__call__``.
    """

    def __init__(self, output_size):
        super().__init__()
        # Kept verbatim (not normalized here) so subclasses can interpret
        # int / tuple / None forms themselves.
        self.output_size = output_size

    def _extra_repr(self):
        return f"output_size={self.output_size}"
class AdaptiveAvgPool2d(_AdaptivePool):
    r"""Applies 2-dimensional adaptive average pooling.

    Spatially downsamples the input by taking the average over pooling regions
    such that the output size is H x W, for any input size.

    The parameters can be:

    * a single ``int`` -- in which case the same value is used for both the
      height and width dimensions, creating a square output.
    * a ``tuple`` of two ``int`` s -- in which case, the first ``int`` is
      used for the height dimension, the second ``int`` for the width dimension.
    * ``None`` can be used for either dimension to keep the input size unchanged.

    Args:
        output_size (int or tuple(int, int)): The target output size of the form H x W.
            Can be a tuple (H, W) or a single int for a square output.
            H and W can be either an ``int``, or ``None`` which means the size
            will be the same as that of the input.

    Examples:
        >>> import mlx.core as mx
        >>> import mlx.nn.layers as nn
        >>> x = mx.random.normal(shape=(8, 32, 32, 4))
        >>> pool = nn.AdaptiveAvgPool2d((5, 7))
        >>> pool(x)
    """

    def __init__(self, output_size: Union[int, Tuple[Optional[int], Optional[int]]]):
        super().__init__(output_size)

    def __call__(self, x):
        # Input is channels-last: (..., H, W, C).
        # Normalize output_size to a 2-tuple.
        output_size = self.output_size
        if isinstance(output_size, int):
            output_size = (output_size, output_size)
        elif len(output_size) == 1:
            output_size = (output_size[0], output_size[0])

        *batch_dims, H, W, C = x.shape

        # ``None`` in either position keeps that input dimension unchanged.
        output_H = H if output_size[0] is None else output_size[0]
        output_W = W if output_size[1] is None else output_size[1]

        # Already the right size: nothing to do.
        if H == output_H and W == output_W:
            return x

        kernel_H = H // output_H
        kernel_W = W // output_W

        if H % output_H == 0 and W % output_W == 0:
            # Exact division: fold each pooling window into its own axis and
            # average over the window axes in one vectorized reduction.
            # (..., H, W, C) -> (..., output_H, kernel_H, output_W, kernel_W, C)
            new_shape = batch_dims + [output_H, kernel_H, output_W, kernel_W, C]
            x_reshaped = x.reshape(new_shape)
            # kernel_H sits at axis -4, kernel_W at axis -2.
            return mx.mean(x_reshaped, axis=[-4, -2])
        else:
            # Non-exact division: slide (possibly overlapping) windows whose
            # strides are chosen so the windows span the full input.
            stride_H = (H - kernel_H) // (output_H - 1) if output_H > 1 else 1
            stride_W = (W - kernel_W) // (output_W - 1) if output_W > 1 else 1
            values = []
            for i in range(output_H):
                row_values = []
                for j in range(output_W):
                    h_start = i * stride_H
                    h_end = min(h_start + kernel_H, H)
                    w_start = j * stride_W
                    w_end = min(w_start + kernel_W, W)
                    region = x[..., h_start:h_end, w_start:w_end, :]
                    row_values.append(mx.mean(region, axis=(-3, -2)))
                values.append(mx.stack(row_values, axis=-2))
            # Stack all rows along the H dimension.
            return mx.stack(values, axis=-3)
class AdaptiveAvgPool3d(_AdaptivePool):
    r"""Applies 3-dimensional adaptive average pooling.

    Spatially downsamples the input by taking the average over pooling regions
    such that the output size is D x H x W, for any input size.

    The parameters can be:

    * a single ``int`` -- in which case the same value is used for the depth,
      height, and width dimensions, creating a cube output.
    * a ``tuple`` of three ``int`` s -- in which case, the first ``int`` is used
      for the depth dimension, the second ``int`` for the height dimension, and
      the third ``int`` for the width dimension.
    * ``None`` can be used for any dimension to keep the input size unchanged.

    Args:
        output_size (int or tuple(int, int, int)): The target output size of the form D x H x W.
            Can be a tuple (D, H, W) or a single int for a cube output.
            D, H and W can be either an ``int``, or ``None`` which means the size
            will be the same as that of the input.

    Examples:
        >>> import mlx.core as mx
        >>> import mlx.nn.layers as nn
        >>> x = mx.random.normal(shape=(8, 16, 32, 32, 4))
        >>> pool = nn.AdaptiveAvgPool3d((5, 7, 9))
        >>> pool(x)
    """

    def __init__(
        self,
        output_size: Union[int, Tuple[Optional[int], Optional[int], Optional[int]]],
    ):
        super().__init__(output_size)

    def __call__(self, x):
        # Input is channels-last: (..., D, H, W, C).
        # Normalize output_size to a 3-tuple; short tuples repeat the last value.
        output_size = self.output_size
        if isinstance(output_size, int):
            output_size = (output_size, output_size, output_size)
        elif len(output_size) == 1:
            output_size = (output_size[0], output_size[0], output_size[0])
        elif len(output_size) == 2:
            output_size = (output_size[0], output_size[1], output_size[1])

        *batch_dims, D, H, W, C = x.shape

        # ``None`` in any position keeps that input dimension unchanged.
        output_D = D if output_size[0] is None else output_size[0]
        output_H = H if output_size[1] is None else output_size[1]
        output_W = W if output_size[2] is None else output_size[2]

        # Already the right size: nothing to do.
        if D == output_D and H == output_H and W == output_W:
            return x

        kernel_D = D // output_D
        kernel_H = H // output_H
        kernel_W = W // output_W

        if D % output_D == 0 and H % output_H == 0 and W % output_W == 0:
            # Exact division: fold each pooling window into its own axis and
            # average over the window axes in one vectorized reduction.
            # (..., D, H, W, C) ->
            # (..., output_D, kernel_D, output_H, kernel_H, output_W, kernel_W, C)
            new_shape = batch_dims + [
                output_D,
                kernel_D,
                output_H,
                kernel_H,
                output_W,
                kernel_W,
                C,
            ]
            x_reshaped = x.reshape(new_shape)
            # kernel_D sits at axis -6, kernel_H at -4, kernel_W at -2.
            return mx.mean(x_reshaped, axis=[-6, -4, -2])
        else:
            # Non-exact division: slide (possibly overlapping) windows whose
            # strides are chosen so the windows span the full input.
            stride_D = (D - kernel_D) // (output_D - 1) if output_D > 1 else 1
            stride_H = (H - kernel_H) // (output_H - 1) if output_H > 1 else 1
            stride_W = (W - kernel_W) // (output_W - 1) if output_W > 1 else 1
            values = []
            for i in range(output_D):
                depth_values = []
                for j in range(output_H):
                    row_values = []
                    for k in range(output_W):
                        d_start = i * stride_D
                        d_end = min(d_start + kernel_D, D)
                        h_start = j * stride_H
                        h_end = min(h_start + kernel_H, H)
                        w_start = k * stride_W
                        w_end = min(w_start + kernel_W, W)
                        region = x[..., d_start:d_end, h_start:h_end, w_start:w_end, :]
                        row_values.append(mx.mean(region, axis=(-4, -3, -2)))
                    depth_values.append(mx.stack(row_values, axis=-2))
                values.append(mx.stack(depth_values, axis=-3))
            # Stack all depth slices along the D dimension.
            return mx.stack(values, axis=-4)