Add AdaptiveAvgPool2d and AdaptiveAvgPool3d

Vincent Amato
2025-08-11 21:17:49 -04:00
parent 7fde1b6a1e
commit 634ce07a3e
4 changed files with 310 additions and 0 deletions
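
The new layers average-pool arbitrary spatial sizes down to a requested output size and operate on channels-last inputs (NHWC for 2D, NDHWC for 3D). A quick usage sketch, mirroring the docstring examples added below; the output shapes in the comments follow from the implementation in this diff:

import mlx.core as mx
import mlx.nn as nn

x = mx.random.normal(shape=(8, 32, 32, 4))        # (N, H, W, C)
print(nn.AdaptiveAvgPool2d((5, 7))(x).shape)      # (8, 5, 7, 4)
print(nn.AdaptiveAvgPool2d(7)(x).shape)           # (8, 7, 7, 4); an int means a square output
print(nn.AdaptiveAvgPool2d((None, 16))(x).shape)  # (8, 32, 16, 4); None keeps that axis

v = mx.random.normal(shape=(8, 16, 32, 32, 4))    # (N, D, H, W, C)
print(nn.AdaptiveAvgPool3d((5, 7, 9))(v).shape)   # (8, 5, 7, 9, 4)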


@@ -77,6 +77,8 @@ from mlx.nn.layers.normalization import (
RMSNorm,
)
from mlx.nn.layers.pooling import (
AdaptiveAvgPool2d,
AdaptiveAvgPool3d,
AvgPool1d,
AvgPool2d,
AvgPool3d,


@@ -396,3 +396,237 @@ class AvgPool3d(_Pool3d):
padding: Optional[Union[int, Tuple[int, int, int]]] = 0,
):
super().__init__(mx.mean, 0, kernel_size, stride, padding)


class AdaptiveAvgPool2d(Module):
r"""Applies 2-dimensional adaptive average pooling.
The output size is H x W for any input size. The number of output
channels is equal to the number of input channels.
Args:
output_size: the target output size of the form H x W.
Can be a tuple (H, W) or a single int for a square image.
H and W can be either an ``int``, or ``None`` which means the size
will be the same as that of the input.
Examples:
>>> import mlx.core as mx
>>> import mlx.nn as nn
>>> x = mx.random.normal(shape=(8, 32, 32, 4))
>>> pool = nn.AdaptiveAvgPool2d((5, 7))
>>> pool(x)
>>> pool = nn.AdaptiveAvgPool2d(7)
>>> pool(x)
"""
def __init__(self, output_size: Union[int, Tuple[Optional[int], Optional[int]]]):
super().__init__()
self.output_size = output_size
def __call__(self, x):
return adaptive_avg_pool2d(x, self.output_size)


def adaptive_avg_pool2d(
x: mx.array, output_size: Union[int, Tuple[Optional[int], Optional[int]]]
) -> mx.array:
r"""Apply 2-dimensional adaptive average pooling.
Args:
x: Input array of shape (N, H, W, C) or (H, W, C).
output_size: Target output size (H, W) or single int for square output.
Values can be None to keep the corresponding input dimension.
Returns:
Output array with spatial dimensions matching output_size.
"""
# Parse output_size
if isinstance(output_size, int):
output_size = (output_size, output_size)
elif len(output_size) == 1:
output_size = (output_size[0], output_size[0])
# Get input dimensions
*batch_dims, H, W, C = x.shape
# Handle None values in output_size
output_H = H if output_size[0] is None else output_size[0]
output_W = W if output_size[1] is None else output_size[1]
# If already the right size, return as is
if H == output_H and W == output_W:
return x
# Calculate kernel size and stride
kernel_H = H // output_H
kernel_W = W // output_W
stride_H = H // output_H
stride_W = W // output_W
# For exact division, use regular pooling
if H % output_H == 0 and W % output_W == 0:
# Reshape for pooling: (batch..., H, W, C) -> (batch..., output_H, kernel_H, output_W, kernel_W, C)
new_shape = batch_dims + [output_H, kernel_H, output_W, kernel_W, C]
x_reshaped = x.reshape(new_shape)
# Average over kernel dimensions (kernel_H is at -4, kernel_W is at -2)
result = mx.mean(
x_reshaped, axis=[-4, -2]
) # Average over kernel_H and kernel_W
return result
# For non-exact division, use strided approach with overlap
else:
# Calculate actual stride to fit exactly
stride_H = (H - kernel_H) // (output_H - 1) if output_H > 1 else 1
stride_W = (W - kernel_W) // (output_W - 1) if output_W > 1 else 1
# Collect all averaged values
values = []
for i in range(output_H):
row_values = []
for j in range(output_W):
h_start = i * stride_H
h_end = min(h_start + kernel_H, H)
w_start = j * stride_W
w_end = min(w_start + kernel_W, W)
# Extract region and average
region = x[..., h_start:h_end, w_start:w_end, :]
avg_val = mx.mean(region, axis=(-3, -2)) # Average over H and W
row_values.append(avg_val)
values.append(mx.stack(row_values, axis=-2)) # Stack along W dimension
# Stack all rows along H dimension
result = mx.stack(values, axis=-3)
return result


class AdaptiveAvgPool3d(Module):
r"""Applies 3-dimensional adaptive average pooling.
The output size is D x H x W for any input size. The number of output
channels is equal to the number of input channels.
Args:
output_size: the target output size of the form D x H x W.
Can be a tuple (D, H, W) or a single int for a cube D x D x D.
D, H and W can be either an ``int``, or ``None`` which means the size
will be the same as that of the input.
Examples:
>>> import mlx.core as mx
>>> import mlx.nn as nn
>>> x = mx.random.normal(shape=(8, 16, 32, 32, 4))
>>> pool = nn.AdaptiveAvgPool3d((5, 7, 9))
>>> pool(x)
>>> pool = nn.AdaptiveAvgPool3d(7)
>>> pool(x)
"""
def __init__(
self,
output_size: Union[int, Tuple[Optional[int], Optional[int], Optional[int]]],
):
super().__init__()
self.output_size = output_size
def __call__(self, x):
return adaptive_avg_pool3d(x, self.output_size)


def adaptive_avg_pool3d(
x: mx.array,
output_size: Union[int, Tuple[Optional[int], Optional[int], Optional[int]]],
) -> mx.array:
r"""Apply 3-dimensional adaptive average pooling.
Args:
x: Input array of shape (N, D, H, W, C) or (D, H, W, C).
output_size: Target output size (D, H, W) or single int for cube output.
Values can be None to keep the corresponding input dimension.
Returns:
Output array with spatial dimensions matching output_size.
"""
# Parse output_size
if isinstance(output_size, int):
output_size = (output_size, output_size, output_size)
elif len(output_size) == 1:
output_size = (output_size[0], output_size[0], output_size[0])
elif len(output_size) == 2:
output_size = (output_size[0], output_size[1], output_size[1])
# Get input dimensions
*batch_dims, D, H, W, C = x.shape
# Handle None values in output_size
output_D = D if output_size[0] is None else output_size[0]
output_H = H if output_size[1] is None else output_size[1]
output_W = W if output_size[2] is None else output_size[2]
# If already the right size, return as is
if D == output_D and H == output_H and W == output_W:
return x
# Calculate kernel sizes (strides are derived below in the non-exact case)
kernel_D = D // output_D
kernel_H = H // output_H
kernel_W = W // output_W
# For exact division, use regular pooling
if D % output_D == 0 and H % output_H == 0 and W % output_W == 0:
# Reshape for pooling: (batch..., D, H, W, C) -> (batch..., output_D, kernel_D, output_H, kernel_H, output_W, kernel_W, C)
new_shape = batch_dims + [
output_D,
kernel_D,
output_H,
kernel_H,
output_W,
kernel_W,
C,
]
x_reshaped = x.reshape(new_shape)
# Average over kernel dimensions (kernel_D is at -6, kernel_H is at -4, kernel_W is at -2)
result = mx.mean(
x_reshaped, axis=[-6, -4, -2]
) # Average over kernel_D, kernel_H and kernel_W
return result
# For non-exact division, use strided approach with overlap
else:
# Calculate actual stride to fit exactly
stride_D = (D - kernel_D) // (output_D - 1) if output_D > 1 else 1
stride_H = (H - kernel_H) // (output_H - 1) if output_H > 1 else 1
stride_W = (W - kernel_W) // (output_W - 1) if output_W > 1 else 1
# Collect all averaged values
values = []
for i in range(output_D):
depth_values = []
for j in range(output_H):
row_values = []
for k in range(output_W):
d_start = i * stride_D
d_end = min(d_start + kernel_D, D)
h_start = j * stride_H
h_end = min(h_start + kernel_H, H)
w_start = k * stride_W
w_end = min(w_start + kernel_W, W)
# Extract region and average
region = x[..., d_start:d_end, h_start:h_end, w_start:w_end, :]
avg_val = mx.mean(
region, axis=(-4, -3, -2)
) # Average over D, H and W
row_values.append(avg_val)
depth_values.append(
mx.stack(row_values, axis=-2)
) # Stack along W dimension
values.append(mx.stack(depth_values, axis=-3)) # Stack along H dimension
# Stack all depths along D dimension
result = mx.stack(values, axis=-4)
return result
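
A note on the two code paths above: when each requested output dimension divides the corresponding input dimension exactly, the pooling is just a reshape followed by mx.mean over the per-window axes. Otherwise the fallback uses a fixed window of size input // output slid with an integer stride, which approximates adaptive pooling, can leave a few trailing rows or columns unused, and assumes the requested output is no larger than the input. For comparison, the conventional adaptive-average-pooling definition (PyTorch's, for instance) varies the window per output index so the whole input is always covered. The sketch below illustrates that reference windowing only; it is not part of this commit:

import math

def adaptive_window(i, in_size, out_size):
    # i-th pooling window in the conventional formulation: the windows cover
    # the whole input and may overlap or differ in size when out_size does
    # not divide in_size evenly.
    start = (i * in_size) // out_size
    end = math.ceil((i + 1) * in_size / out_size)
    return start, end

# Pooling a length-7 axis down to 3 outputs: the fixed-kernel fallback above
# uses windows [0:2], [2:4], [4:6] (index 6 is never averaged), whereas the
# reference windows are (0, 3), (2, 5), (4, 7).
for i in range(3):
    print(adaptive_window(i, 7, 3))

The two definitions coincide whenever the output size divides the input size exactly, which is the case the reshape fast path handles.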


@@ -1818,6 +1818,79 @@ class TestLayers(mlx_tests.MLXTestCase):
"AvgPool3d(kernel_size=(3, 3, 3), stride=(3, 3, 3), padding=(2, 2, 2))",
)
# Test AdaptiveAvgPool2d
# Test exact division case (8x8 -> 4x4)
x = mx.arange(64, dtype=mx.float32).reshape(1, 8, 8, 1)
pool = nn.AdaptiveAvgPool2d((4, 4))
output = pool(x)
self.assertEqual(output.shape, (1, 4, 4, 1))
# Check first few values manually
expected_00 = (0 + 1 + 8 + 9) / 4 # Top-left 2x2 block
expected_01 = (2 + 3 + 10 + 11) / 4 # Next 2x2 block
self.assertAlmostEqual(output[0, 0, 0, 0].item(), expected_00, places=5)
self.assertAlmostEqual(output[0, 0, 1, 0].item(), expected_01, places=5)
# Test square output format (int input)
pool = nn.AdaptiveAvgPool2d(2)
output = pool(x)
self.assertEqual(output.shape, (1, 2, 2, 1))
# Test None values (keep original dimension)
x = mx.random.normal((2, 6, 8, 3))
pool = nn.AdaptiveAvgPool2d((3, None))
output = pool(x)
self.assertEqual(output.shape, (2, 3, 8, 3))
pool = nn.AdaptiveAvgPool2d((None, 4))
output = pool(x)
self.assertEqual(output.shape, (2, 6, 4, 3))
# Test non-exact division case
x = mx.random.normal((1, 7, 5, 2))
pool = nn.AdaptiveAvgPool2d((3, 2))
output = pool(x)
self.assertEqual(output.shape, (1, 3, 2, 2))
# Test 3D input (no batch dimension)
x = mx.random.normal((6, 6, 4))
pool = nn.AdaptiveAvgPool2d((2, 3))
output = pool(x)
self.assertEqual(output.shape, (2, 3, 4))
# Test AdaptiveAvgPool3d
# Test exact division case (8x8x8 -> 4x4x4)
x = mx.arange(512, dtype=mx.float32).reshape(1, 8, 8, 8, 1)
pool = nn.AdaptiveAvgPool3d((4, 4, 4))
output = pool(x)
self.assertEqual(output.shape, (1, 4, 4, 4, 1))
# Test cube output format (int input)
pool = nn.AdaptiveAvgPool3d(2)
output = pool(x)
self.assertEqual(output.shape, (1, 2, 2, 2, 1))
# Test None values (keep original dimensions)
x = mx.random.normal((2, 6, 8, 10, 3))
pool = nn.AdaptiveAvgPool3d((3, None, 5))
output = pool(x)
self.assertEqual(output.shape, (2, 3, 8, 5, 3))
pool = nn.AdaptiveAvgPool3d((None, 4, None))
output = pool(x)
self.assertEqual(output.shape, (2, 6, 4, 10, 3))
# Test non-exact division case
x = mx.random.normal((1, 7, 5, 9, 2))
pool = nn.AdaptiveAvgPool3d((3, 2, 4))
output = pool(x)
self.assertEqual(output.shape, (1, 3, 2, 4, 2))
# Test 4D input (no batch dimension)
x = mx.random.normal((6, 6, 6, 4))
pool = nn.AdaptiveAvgPool3d((2, 3, 2))
output = pool(x)
self.assertEqual(output.shape, (2, 3, 2, 4))
def test_set_dtype(self):
def assert_dtype(layer, dtype):
for k, v in tree_flatten(layer.parameters()):