Refactor adaptive pooling for style consistency

This commit is contained in:
Vincent Amato
2025-08-11 22:12:14 -04:00
parent 634ce07a3e
commit 89b3f69a56

View File

@@ -398,21 +398,40 @@ class AvgPool3d(_Pool3d):
super().__init__(mx.mean, 0, kernel_size, stride, padding) super().__init__(mx.mean, 0, kernel_size, stride, padding)
class _AdaptivePool(Module):
    """Base class for adaptive pooling layers.

    Holds the target ``output_size`` and exposes it through the module
    repr; the concrete pooling computation lives in each subclass's
    ``__call__``.
    """

    def __init__(self, output_size):
        super().__init__()
        self.output_size = output_size

    def _extra_repr(self):
        return "output_size={}".format(self.output_size)
class AdaptiveAvgPool2d(_AdaptivePool):
    r"""Applies 2-dimensional adaptive average pooling.

    Spatially downsamples the input by taking the average over pooling
    regions such that the output size is H x W, for any input size.

    The parameters can be:

    * a single ``int`` -- in which case the same value is used for both the
      height and width dimensions, creating a square output.
    * a ``tuple`` of two ``int`` s -- in which case, the first ``int`` is
      used for the height dimension, the second ``int`` for the width
      dimension.
    * ``None`` can be used for either dimension to keep the input size
      unchanged.

    Args:
        output_size (int or tuple(int, int)): The target output size of the
            form H x W. Can be a tuple (H, W) or a single int for a square
            output. H and W can be either an ``int``, or ``None`` which
            means the size will be the same as that of the input.

    Examples:
        >>> import mlx.core as mx
        >>> import mlx.nn.layers as nn
        >>> x = mx.random.normal(shape=(8, 32, 32, 4))
        >>> pool = nn.AdaptiveAvgPool2d((5, 7))
        >>> pool(x)
    """

    def __init__(self, output_size: Union[int, Tuple[Optional[int], Optional[int]]]):
        super().__init__(output_size)

    def __call__(self, x):
        # Normalize output_size to a (H, W) pair.
        output_size = self.output_size
        if isinstance(output_size, int):
            output_size = (output_size, output_size)
        elif len(output_size) == 1:
            output_size = (output_size[0], output_size[0])

        # Channels-last layout: (..., H, W, C).
        *batch_dims, H, W, C = x.shape

        # ``None`` keeps the corresponding input dimension unchanged.
        output_H = H if output_size[0] is None else output_size[0]
        output_W = W if output_size[1] is None else output_size[1]

        if H == output_H and W == output_W:
            return x

        # Fast path: the input divides evenly, so the pooling regions form a
        # uniform grid and a single reshape + mean suffices.
        if H % output_H == 0 and W % output_W == 0:
            kernel_H = H // output_H
            kernel_W = W // output_W
            new_shape = batch_dims + [output_H, kernel_H, output_W, kernel_W, C]
            return mx.mean(x.reshape(new_shape), axis=[-4, -2])

        # General path: use the standard adaptive-pooling index mapping
        #   start = floor(i * size / out), end = ceil((i + 1) * size / out)
        # so adjacent regions may overlap but every input element belongs to
        # at least one region. (A fixed floor kernel with a fixed stride can
        # skip trailing rows/columns and yields empty regions whenever the
        # requested output is larger than the input.)
        values = []
        for i in range(output_H):
            h_start = (i * H) // output_H
            h_end = -((-(i + 1) * H) // output_H)  # ceil division
            row_values = []
            for j in range(output_W):
                w_start = (j * W) // output_W
                w_end = -((-(j + 1) * W) // output_W)  # ceil division
                region = x[..., h_start:h_end, w_start:w_end, :]
                row_values.append(mx.mean(region, axis=(-3, -2)))
            # Stack along the W dimension.
            values.append(mx.stack(row_values, axis=-2))
        # Stack all rows along the H dimension.
        return mx.stack(values, axis=-3)
class AdaptiveAvgPool3d(_AdaptivePool):
    r"""Applies 3-dimensional adaptive average pooling.

    Spatially downsamples the input by taking the average over pooling
    regions such that the output size is D x H x W, for any input size.

    The parameters can be:

    * a single ``int`` -- in which case the same value is used for the
      depth, height, and width dimensions, creating a cube output.
    * a ``tuple`` of three ``int`` s -- in which case, the first ``int`` is
      used for the depth dimension, the second ``int`` for the height
      dimension, and the third ``int`` for the width dimension.
    * ``None`` can be used for any dimension to keep the input size
      unchanged.

    Args:
        output_size (int or tuple(int, int, int)): The target output size of
            the form D x H x W. Can be a tuple (D, H, W) or a single int for
            a cube output. D, H and W can be either an ``int``, or ``None``
            which means the size will be the same as that of the input.

    Examples:
        >>> import mlx.core as mx
        >>> import mlx.nn.layers as nn
        >>> x = mx.random.normal(shape=(8, 16, 32, 32, 4))
        >>> pool = nn.AdaptiveAvgPool3d((5, 7, 9))
        >>> pool(x)
    """

    def __init__(
        self,
        output_size: Union[int, Tuple[Optional[int], Optional[int], Optional[int]]],
    ):
        super().__init__(output_size)

    def __call__(self, x):
        # Normalize output_size to a (D, H, W) triple.
        output_size = self.output_size
        if isinstance(output_size, int):
            output_size = (output_size, output_size, output_size)
        elif len(output_size) == 1:
            output_size = (output_size[0], output_size[0], output_size[0])
        elif len(output_size) == 2:
            output_size = (output_size[0], output_size[1], output_size[1])

        # Channels-last layout: (..., D, H, W, C).
        *batch_dims, D, H, W, C = x.shape

        # ``None`` keeps the corresponding input dimension unchanged.
        output_D = D if output_size[0] is None else output_size[0]
        output_H = H if output_size[1] is None else output_size[1]
        output_W = W if output_size[2] is None else output_size[2]

        if D == output_D and H == output_H and W == output_W:
            return x

        # Fast path: the input divides evenly, so the pooling regions form a
        # uniform grid and a single reshape + mean suffices.
        if D % output_D == 0 and H % output_H == 0 and W % output_W == 0:
            new_shape = batch_dims + [
                output_D,
                D // output_D,
                output_H,
                H // output_H,
                output_W,
                W // output_W,
                C,
            ]
            return mx.mean(x.reshape(new_shape), axis=[-6, -4, -2])

        # General path: use the standard adaptive-pooling index mapping
        #   start = floor(i * size / out), end = ceil((i + 1) * size / out)
        # so adjacent regions may overlap but every input element belongs to
        # at least one region. (A fixed floor kernel with a fixed stride can
        # skip trailing slices and yields empty regions whenever the
        # requested output is larger than the input.)
        values = []
        for i in range(output_D):
            d_start = (i * D) // output_D
            d_end = -((-(i + 1) * D) // output_D)  # ceil division
            depth_values = []
            for j in range(output_H):
                h_start = (j * H) // output_H
                h_end = -((-(j + 1) * H) // output_H)  # ceil division
                row_values = []
                for k in range(output_W):
                    w_start = (k * W) // output_W
                    w_end = -((-(k + 1) * W) // output_W)  # ceil division
                    region = x[..., d_start:d_end, h_start:h_end, w_start:w_end, :]
                    row_values.append(mx.mean(region, axis=(-4, -3, -2)))
                # Stack along the W dimension.
                depth_values.append(mx.stack(row_values, axis=-2))
            # Stack along the H dimension.
            values.append(mx.stack(depth_values, axis=-3))
        # Stack all depths along the D dimension.
        return mx.stack(values, axis=-4)