mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Refactor adaptive pooling for style consistency
This commit is contained in:
@@ -398,21 +398,40 @@ class AvgPool3d(_Pool3d):
|
||||
super().__init__(mx.mean, 0, kernel_size, stride, padding)
|
||||
|
||||
|
||||
class AdaptiveAvgPool2d(Module):
|
||||
class _AdaptivePool(Module):
|
||||
"""Base class for adaptive pooling layers."""
|
||||
|
||||
def __init__(self, output_size):
|
||||
super().__init__()
|
||||
self.output_size = output_size
|
||||
|
||||
def _extra_repr(self):
|
||||
return f"output_size={self.output_size}"
|
||||
|
||||
|
||||
class AdaptiveAvgPool2d(_AdaptivePool):
|
||||
r"""Applies 2-dimensional adaptive average pooling.
|
||||
|
||||
The output size is H x W, for any input size. The number of output
|
||||
features is equal to the number of input planes.
|
||||
Spatially downsamples the input by taking the average over pooling regions
|
||||
such that the output size is H x W, for any input size.
|
||||
|
||||
The parameters can be:
|
||||
|
||||
* a single ``int`` -- in which case the same value is used for both the
|
||||
height and width dimensions, creating a square output.
|
||||
* a ``tuple`` of two ``int`` s -- in which case, the first ``int`` is
|
||||
used for the height dimension, the second ``int`` for the width dimension.
|
||||
* ``None`` can be used for either dimension to keep the input size unchanged.
|
||||
|
||||
Args:
|
||||
output_size: the target output size of the form H x W.
|
||||
Can be a tuple (H, W) or a single int for a square image.
|
||||
output_size (int or tuple(int, int)): The target output size of the form H x W.
|
||||
Can be a tuple (H, W) or a single int for a square output.
|
||||
H and W can be either an ``int``, or ``None`` which means the size
|
||||
will be the same as that of the input.
|
||||
|
||||
Examples:
|
||||
>>> import mlx.core as mx
|
||||
>>> import mlx.nn as nn
|
||||
>>> import mlx.nn.layers as nn
|
||||
>>> x = mx.random.normal(shape=(8, 32, 32, 4))
|
||||
>>> pool = nn.AdaptiveAvgPool2d((5, 7))
|
||||
>>> pool(x)
|
||||
@@ -421,68 +440,34 @@ class AdaptiveAvgPool2d(Module):
|
||||
"""
|
||||
|
||||
def __init__(self, output_size: Union[int, Tuple[Optional[int], Optional[int]]]):
|
||||
super().__init__()
|
||||
self.output_size = output_size
|
||||
super().__init__(output_size)
|
||||
|
||||
def __call__(self, x):
|
||||
return adaptive_avg_pool2d(x, self.output_size)
|
||||
|
||||
|
||||
def adaptive_avg_pool2d(
|
||||
x: mx.array, output_size: Union[int, Tuple[Optional[int], Optional[int]]]
|
||||
) -> mx.array:
|
||||
r"""Apply 2-dimensional adaptive average pooling.
|
||||
|
||||
Args:
|
||||
x: Input array of shape (N, H, W, C) or (H, W, C).
|
||||
output_size: Target output size (H, W) or single int for square output.
|
||||
Values can be None to keep the corresponding input dimension.
|
||||
|
||||
Returns:
|
||||
Output array with spatial dimensions matching output_size.
|
||||
"""
|
||||
# Parse output_size
|
||||
output_size = self.output_size
|
||||
if isinstance(output_size, int):
|
||||
output_size = (output_size, output_size)
|
||||
elif len(output_size) == 1:
|
||||
output_size = (output_size[0], output_size[0])
|
||||
|
||||
# Get input dimensions
|
||||
*batch_dims, H, W, C = x.shape
|
||||
|
||||
# Handle None values in output_size
|
||||
output_H = H if output_size[0] is None else output_size[0]
|
||||
output_W = W if output_size[1] is None else output_size[1]
|
||||
|
||||
# If already the right size, return as is
|
||||
if H == output_H and W == output_W:
|
||||
return x
|
||||
|
||||
# Calculate kernel size and stride
|
||||
kernel_H = H // output_H
|
||||
kernel_W = W // output_W
|
||||
stride_H = H // output_H
|
||||
stride_W = W // output_W
|
||||
|
||||
# For exact division, use regular pooling
|
||||
if H % output_H == 0 and W % output_W == 0:
|
||||
# Reshape for pooling: (batch..., H, W, C) -> (batch..., output_H, kernel_H, output_W, kernel_W, C)
|
||||
new_shape = batch_dims + [output_H, kernel_H, output_W, kernel_W, C]
|
||||
x_reshaped = x.reshape(new_shape)
|
||||
|
||||
# Average over kernel dimensions (kernel_H is at -4, kernel_W is at -2)
|
||||
result = mx.mean(
|
||||
x_reshaped, axis=[-4, -2]
|
||||
) # Average over kernel_H and kernel_W
|
||||
return result
|
||||
|
||||
# For non-exact division, use strided approach with overlap
|
||||
return mx.mean(x_reshaped, axis=[-4, -2])
|
||||
else:
|
||||
# Calculate actual stride to fit exactly
|
||||
stride_H = (H - kernel_H) // (output_H - 1) if output_H > 1 else 1
|
||||
stride_W = (W - kernel_W) // (output_W - 1) if output_W > 1 else 1
|
||||
|
||||
# Collect all averaged values
|
||||
values = []
|
||||
for i in range(output_H):
|
||||
row_values = []
|
||||
@@ -492,32 +477,37 @@ def adaptive_avg_pool2d(
|
||||
w_start = j * stride_W
|
||||
w_end = min(w_start + kernel_W, W)
|
||||
|
||||
# Extract region and average
|
||||
region = x[..., h_start:h_end, w_start:w_end, :]
|
||||
avg_val = mx.mean(region, axis=(-3, -2)) # Average over H and W
|
||||
row_values.append(avg_val)
|
||||
values.append(mx.stack(row_values, axis=-2)) # Stack along W dimension
|
||||
row_values.append(mx.mean(region, axis=(-3, -2)))
|
||||
values.append(mx.stack(row_values, axis=-2))
|
||||
|
||||
# Stack all rows along H dimension
|
||||
result = mx.stack(values, axis=-3)
|
||||
return result
|
||||
return mx.stack(values, axis=-3)
|
||||
|
||||
|
||||
class AdaptiveAvgPool3d(Module):
|
||||
class AdaptiveAvgPool3d(_AdaptivePool):
|
||||
r"""Applies 3-dimensional adaptive average pooling.
|
||||
|
||||
The output size is D x H x W, for any input size. The number of output
|
||||
features is equal to the number of input planes.
|
||||
Spatially downsamples the input by taking the average over pooling regions
|
||||
such that the output size is D x H x W, for any input size.
|
||||
|
||||
The parameters can be:
|
||||
|
||||
* a single ``int`` -- in which case the same value is used for the depth,
|
||||
height, and width dimensions, creating a cube output.
|
||||
* a ``tuple`` of three ``int`` s -- in which case, the first ``int`` is used
|
||||
for the depth dimension, the second ``int`` for the height dimension, and
|
||||
the third ``int`` for the width dimension.
|
||||
* ``None`` can be used for any dimension to keep the input size unchanged.
|
||||
|
||||
Args:
|
||||
output_size: the target output size of the form D x H x W.
|
||||
Can be a tuple (D, H, W) or a single int for a cube D x D x D.
|
||||
output_size (int or tuple(int, int, int)): The target output size of the form D x H x W.
|
||||
Can be a tuple (D, H, W) or a single int for a cube output.
|
||||
D, H and W can be either an ``int``, or ``None`` which means the size
|
||||
will be the same as that of the input.
|
||||
|
||||
Examples:
|
||||
>>> import mlx.core as mx
|
||||
>>> import mlx.nn as nn
|
||||
>>> import mlx.nn.layers as nn
|
||||
>>> x = mx.random.normal(shape=(8, 16, 32, 32, 4))
|
||||
>>> pool = nn.AdaptiveAvgPool3d((5, 7, 9))
|
||||
>>> pool(x)
|
||||
@@ -529,28 +519,10 @@ class AdaptiveAvgPool3d(Module):
|
||||
self,
|
||||
output_size: Union[int, Tuple[Optional[int], Optional[int], Optional[int]]],
|
||||
):
|
||||
super().__init__()
|
||||
self.output_size = output_size
|
||||
super().__init__(output_size)
|
||||
|
||||
def __call__(self, x):
|
||||
return adaptive_avg_pool3d(x, self.output_size)
|
||||
|
||||
|
||||
def adaptive_avg_pool3d(
|
||||
x: mx.array,
|
||||
output_size: Union[int, Tuple[Optional[int], Optional[int], Optional[int]]],
|
||||
) -> mx.array:
|
||||
r"""Apply 3-dimensional adaptive average pooling.
|
||||
|
||||
Args:
|
||||
x: Input array of shape (N, D, H, W, C) or (D, H, W, C).
|
||||
output_size: Target output size (D, H, W) or single int for cube output.
|
||||
Values can be None to keep the corresponding input dimension.
|
||||
|
||||
Returns:
|
||||
Output array with spatial dimensions matching output_size.
|
||||
"""
|
||||
# Parse output_size
|
||||
output_size = self.output_size
|
||||
if isinstance(output_size, int):
|
||||
output_size = (output_size, output_size, output_size)
|
||||
elif len(output_size) == 1:
|
||||
@@ -558,26 +530,20 @@ def adaptive_avg_pool3d(
|
||||
elif len(output_size) == 2:
|
||||
output_size = (output_size[0], output_size[1], output_size[1])
|
||||
|
||||
# Get input dimensions
|
||||
*batch_dims, D, H, W, C = x.shape
|
||||
|
||||
# Handle None values in output_size
|
||||
output_D = D if output_size[0] is None else output_size[0]
|
||||
output_H = H if output_size[1] is None else output_size[1]
|
||||
output_W = W if output_size[2] is None else output_size[2]
|
||||
|
||||
# If already the right size, return as is
|
||||
if D == output_D and H == output_H and W == output_W:
|
||||
return x
|
||||
|
||||
# Calculate kernel size and stride
|
||||
kernel_D = D // output_D
|
||||
kernel_H = H // output_H
|
||||
kernel_W = W // output_W
|
||||
|
||||
# For exact division, use regular pooling
|
||||
if D % output_D == 0 and H % output_H == 0 and W % output_W == 0:
|
||||
# Reshape for pooling: (batch..., D, H, W, C) -> (batch..., output_D, kernel_D, output_H, kernel_H, output_W, kernel_W, C)
|
||||
new_shape = batch_dims + [
|
||||
output_D,
|
||||
kernel_D,
|
||||
@@ -588,21 +554,12 @@ def adaptive_avg_pool3d(
|
||||
C,
|
||||
]
|
||||
x_reshaped = x.reshape(new_shape)
|
||||
|
||||
# Average over kernel dimensions (kernel_D is at -6, kernel_H is at -4, kernel_W is at -2)
|
||||
result = mx.mean(
|
||||
x_reshaped, axis=[-6, -4, -2]
|
||||
) # Average over kernel_D, kernel_H and kernel_W
|
||||
return result
|
||||
|
||||
# For non-exact division, use strided approach with overlap
|
||||
return mx.mean(x_reshaped, axis=[-6, -4, -2])
|
||||
else:
|
||||
# Calculate actual stride to fit exactly
|
||||
stride_D = (D - kernel_D) // (output_D - 1) if output_D > 1 else 1
|
||||
stride_H = (H - kernel_H) // (output_H - 1) if output_H > 1 else 1
|
||||
stride_W = (W - kernel_W) // (output_W - 1) if output_W > 1 else 1
|
||||
|
||||
# Collect all averaged values
|
||||
values = []
|
||||
for i in range(output_D):
|
||||
depth_values = []
|
||||
@@ -616,17 +573,9 @@ def adaptive_avg_pool3d(
|
||||
w_start = k * stride_W
|
||||
w_end = min(w_start + kernel_W, W)
|
||||
|
||||
# Extract region and average
|
||||
region = x[..., d_start:d_end, h_start:h_end, w_start:w_end, :]
|
||||
avg_val = mx.mean(
|
||||
region, axis=(-4, -3, -2)
|
||||
) # Average over D, H and W
|
||||
row_values.append(avg_val)
|
||||
depth_values.append(
|
||||
mx.stack(row_values, axis=-2)
|
||||
) # Stack along W dimension
|
||||
values.append(mx.stack(depth_values, axis=-3)) # Stack along H dimension
|
||||
row_values.append(mx.mean(region, axis=(-4, -3, -2)))
|
||||
depth_values.append(mx.stack(row_values, axis=-2))
|
||||
values.append(mx.stack(depth_values, axis=-3))
|
||||
|
||||
# Stack all depths along D dimension
|
||||
result = mx.stack(values, axis=-4)
|
||||
return result
|
||||
return mx.stack(values, axis=-4)
|
||||
|
||||
Reference in New Issue
Block a user