Add AdaptiveMaxPool1d, AdaptiveMaxPool2d, and AdaptiveMaxPool3d layers

This commit is contained in:
Vincent Amato
2025-08-11 23:45:52 -04:00
parent 7fde1b6a1e
commit 68c7be55bb
4 changed files with 396 additions and 0 deletions

View File

@@ -9,6 +9,9 @@ Layers
:toctree: _autosummary
:template: nn-module-template.rst
AdaptiveMaxPool1d
AdaptiveMaxPool2d
AdaptiveMaxPool3d
ALiBi
AvgPool1d
AvgPool2d

View File

@@ -77,6 +77,9 @@ from mlx.nn.layers.normalization import (
RMSNorm,
)
from mlx.nn.layers.pooling import (
AdaptiveMaxPool1d,
AdaptiveMaxPool2d,
AdaptiveMaxPool3d,
AvgPool1d,
AvgPool2d,
AvgPool3d,

View File

@@ -396,3 +396,266 @@ class AvgPool3d(_Pool3d):
padding: Optional[Union[int, Tuple[int, int, int]]] = 0,
):
super().__init__(mx.mean, 0, kernel_size, stride, padding)
class _AdaptivePool(Module):
"""Base class for adaptive pooling layers."""
def __init__(self, output_size):
super().__init__()
self.output_size = output_size
def _extra_repr(self):
return f"output_size={self.output_size}"
class AdaptiveMaxPool1d(_AdaptivePool):
r"""Applies 1-dimensional adaptive max pooling.
Spatially downsamples the input by taking the maximum over pooling regions
such that the output size is L, for any input size.
The parameter can be:
* a single ``int`` -- the target output size.
* ``None`` can be used to keep the input size unchanged.
Args:
output_size (int or None): The target output size L.
Can be an ``int``, or ``None`` which means the size
will be the same as that of the input.
Note:
Unlike PyTorch's implementation, this layer does not support
returning indices (`return_indices` parameter). This may be
added in a future version.
Examples:
>>> import mlx.core as mx
>>> import mlx.nn.layers as nn
>>> x = mx.random.normal(shape=(8, 32, 4))
>>> pool = nn.AdaptiveMaxPool1d(7)
>>> pool(x)
>>> pool = nn.AdaptiveMaxPool1d(None) # No change
>>> pool(x)
"""
def __init__(self, output_size: Union[int, None]):
super().__init__(output_size)
def __call__(self, x):
output_size = self.output_size
*batch_dims, L, C = x.shape
output_L = L if output_size is None else output_size
if L == output_L:
return x
kernel_L = max(1, L // output_L)
if L % output_L == 0 and kernel_L > 0:
new_shape = batch_dims + [output_L, kernel_L, C]
x_reshaped = x.reshape(new_shape)
return mx.max(x_reshaped, axis=-2)
else:
stride_L = max(1, (L - kernel_L) // (output_L - 1)) if output_L > 1 else 1
values = []
for i in range(output_L):
l_start = min(i * stride_L, L - 1)
l_end = min(l_start + kernel_L, L)
region = x[..., l_start:l_end, :]
values.append(mx.max(region, axis=-2))
return mx.stack(values, axis=-2)
class AdaptiveMaxPool2d(_AdaptivePool):
r"""Applies 2-dimensional adaptive max pooling.
Spatially downsamples the input by taking the maximum over pooling regions
such that the output size is H x W, for any input size.
The parameters can be:
* a single ``int`` -- in which case the same value is used for both the
height and width dimensions, creating a square output.
* a ``tuple`` of two ``int`` s -- in which case, the first ``int`` is
used for the height dimension, the second ``int`` for the width dimension.
* ``None`` can be used for either dimension to keep the input size unchanged.
Args:
output_size (int or tuple(int, int)): The target output size of the form H x W.
Can be a tuple (H, W) or a single int for a square output.
H and W can be either an ``int``, or ``None`` which means the size
will be the same as that of the input.
Note:
Unlike PyTorch's implementation, this layer does not support
returning indices (`return_indices` parameter). This may be
added in a future version.
Examples:
>>> import mlx.core as mx
>>> import mlx.nn.layers as nn
>>> x = mx.random.normal(shape=(8, 32, 32, 4))
>>> pool = nn.AdaptiveMaxPool2d((5, 7))
>>> pool(x)
>>> pool = nn.AdaptiveMaxPool2d(7)
>>> pool(x)
"""
def __init__(self, output_size: Union[int, Tuple[Optional[int], Optional[int]]]):
super().__init__(output_size)
def __call__(self, x):
output_size = self.output_size
if isinstance(output_size, int):
output_size = (output_size, output_size)
elif len(output_size) == 1:
output_size = (output_size[0], output_size[0])
*batch_dims, H, W, C = x.shape
output_H = H if output_size[0] is None else output_size[0]
output_W = W if output_size[1] is None else output_size[1]
if H == output_H and W == output_W:
return x
kernel_H = max(1, H // output_H)
kernel_W = max(1, W // output_W)
if H % output_H == 0 and W % output_W == 0 and kernel_H > 0 and kernel_W > 0:
new_shape = batch_dims + [output_H, kernel_H, output_W, kernel_W, C]
x_reshaped = x.reshape(new_shape)
return mx.max(x_reshaped, axis=[-4, -2])
else:
stride_H = max(1, (H - kernel_H) // (output_H - 1)) if output_H > 1 else 1
stride_W = max(1, (W - kernel_W) // (output_W - 1)) if output_W > 1 else 1
values = []
for i in range(output_H):
row_values = []
for j in range(output_W):
h_start = min(i * stride_H, H - 1)
h_end = min(h_start + kernel_H, H)
w_start = min(j * stride_W, W - 1)
w_end = min(w_start + kernel_W, W)
region = x[..., h_start:h_end, w_start:w_end, :]
row_values.append(mx.max(region, axis=(-3, -2)))
values.append(mx.stack(row_values, axis=-2))
return mx.stack(values, axis=-3)
class AdaptiveMaxPool3d(_AdaptivePool):
r"""Applies 3-dimensional adaptive max pooling.
Spatially downsamples the input by taking the maximum over pooling regions
such that the output size is D x H x W, for any input size.
The parameters can be:
* a single ``int`` -- in which case the same value is used for the depth,
height, and width dimensions, creating a cube output.
* a ``tuple`` of three ``int`` s -- in which case, the first ``int`` is used
for the depth dimension, the second ``int`` for the height dimension, and
the third ``int`` for the width dimension.
* ``None`` can be used for any dimension to keep the input size unchanged.
Args:
output_size (int or tuple(int, int, int)): The target output size of the form D x H x W.
Can be a tuple (D, H, W) or a single int for a cube output.
D, H and W can be either an ``int``, or ``None`` which means the size
will be the same as that of the input.
Note:
Unlike PyTorch's implementation, this layer does not support
returning indices (`return_indices` parameter). This may be
added in a future version.
Examples:
>>> import mlx.core as mx
>>> import mlx.nn.layers as nn
>>> x = mx.random.normal(shape=(8, 16, 32, 32, 4))
>>> pool = nn.AdaptiveMaxPool3d((5, 7, 9))
>>> pool(x)
>>> pool = nn.AdaptiveMaxPool3d(7)
>>> pool(x)
"""
def __init__(
self,
output_size: Union[int, Tuple[Optional[int], Optional[int], Optional[int]]],
):
super().__init__(output_size)
def __call__(self, x):
output_size = self.output_size
if isinstance(output_size, int):
output_size = (output_size, output_size, output_size)
elif len(output_size) == 1:
output_size = (output_size[0], output_size[0], output_size[0])
elif len(output_size) == 2:
output_size = (output_size[0], output_size[1], output_size[1])
*batch_dims, D, H, W, C = x.shape
output_D = D if output_size[0] is None else output_size[0]
output_H = H if output_size[1] is None else output_size[1]
output_W = W if output_size[2] is None else output_size[2]
if D == output_D and H == output_H and W == output_W:
return x
kernel_D = max(1, D // output_D)
kernel_H = max(1, H // output_H)
kernel_W = max(1, W // output_W)
if (
D % output_D == 0
and H % output_H == 0
and W % output_W == 0
and kernel_D > 0
and kernel_H > 0
and kernel_W > 0
):
new_shape = batch_dims + [
output_D,
kernel_D,
output_H,
kernel_H,
output_W,
kernel_W,
C,
]
x_reshaped = x.reshape(new_shape)
return mx.max(x_reshaped, axis=[-6, -4, -2])
else:
stride_D = max(1, (D - kernel_D) // (output_D - 1)) if output_D > 1 else 1
stride_H = max(1, (H - kernel_H) // (output_H - 1)) if output_H > 1 else 1
stride_W = max(1, (W - kernel_W) // (output_W - 1)) if output_W > 1 else 1
values = []
for i in range(output_D):
depth_values = []
for j in range(output_H):
row_values = []
for k in range(output_W):
d_start = min(i * stride_D, D - 1)
d_end = min(d_start + kernel_D, D)
h_start = min(j * stride_H, H - 1)
h_end = min(h_start + kernel_H, H)
w_start = min(k * stride_W, W - 1)
w_end = min(w_start + kernel_W, W)
region = x[..., d_start:d_end, h_start:h_end, w_start:w_end, :]
row_values.append(mx.max(region, axis=(-4, -3, -2)))
depth_values.append(mx.stack(row_values, axis=-2))
values.append(mx.stack(depth_values, axis=-3))
return mx.stack(values, axis=-4)

View File

@@ -1818,6 +1818,133 @@ class TestLayers(mlx_tests.MLXTestCase):
"AvgPool3d(kernel_size=(3, 3, 3), stride=(3, 3, 3), padding=(2, 2, 2))",
)
# Test AdaptiveMaxPool1d
# Test exact division case (8 -> 4)
x = mx.arange(8, dtype=mx.float32).reshape(1, 8, 1)
pool = nn.AdaptiveMaxPool1d(4)
output = pool(x)
self.assertEqual(output.shape, (1, 4, 1))
# Check first few values manually
expected_0 = max(0, 1) # First 2 elements max
expected_1 = max(2, 3) # Next 2 elements max
self.assertAlmostEqual(output[0, 0, 0].item(), expected_0, places=5)
self.assertAlmostEqual(output[0, 1, 0].item(), expected_1, places=5)
# Test None value (keep original dimension)
x = mx.random.normal((2, 8, 3))
pool = nn.AdaptiveMaxPool1d(None)
output = pool(x)
self.assertEqual(output.shape, (2, 8, 3))
# Should be identical to input when output_size is None
self.assertTrue(mx.allclose(output, x))
# Test non-exact division case
x = mx.random.normal((1, 7, 2))
pool = nn.AdaptiveMaxPool1d(3)
output = pool(x)
self.assertEqual(output.shape, (1, 3, 2))
# Test 2D input (no batch dimension)
x = mx.random.normal((6, 4))
pool = nn.AdaptiveMaxPool1d(2)
output = pool(x)
self.assertEqual(output.shape, (2, 4))
# Test single output
x = mx.random.normal((1, 10, 2))
pool = nn.AdaptiveMaxPool1d(1)
output = pool(x)
self.assertEqual(output.shape, (1, 1, 2))
# Test larger output than input
x = mx.random.normal((1, 3, 2))
pool = nn.AdaptiveMaxPool1d(5)
output = pool(x)
self.assertEqual(output.shape, (1, 5, 2))
# Test AdaptiveMaxPool2d
# Test exact division case (4x4 -> 2x2)
x = mx.arange(16, dtype=mx.float32).reshape(1, 4, 4, 1)
pool = nn.AdaptiveMaxPool2d((2, 2))
output = pool(x)
self.assertEqual(output.shape, (1, 2, 2, 1))
# Check first value manually - max of top-left 2x2 block
expected_00 = max(0, 1, 4, 5) # Top-left 2x2 block
self.assertAlmostEqual(output[0, 0, 0, 0].item(), expected_00, places=5)
# Test square output format (int input)
pool = nn.AdaptiveMaxPool2d(1)
output = pool(x)
self.assertEqual(output.shape, (1, 1, 1, 1))
# Test None values (keep original dimension)
x = mx.random.normal((2, 6, 8, 3))
pool = nn.AdaptiveMaxPool2d((3, None))
output = pool(x)
self.assertEqual(output.shape, (2, 3, 8, 3))
pool = nn.AdaptiveMaxPool2d((None, 4))
output = pool(x)
self.assertEqual(output.shape, (2, 6, 4, 3))
# Test non-exact division case
x = mx.random.normal((1, 7, 5, 2))
pool = nn.AdaptiveMaxPool2d((3, 2))
output = pool(x)
self.assertEqual(output.shape, (1, 3, 2, 2))
# Test larger output than input
x = mx.random.normal((1, 2, 3, 2))
pool = nn.AdaptiveMaxPool2d((4, 5))
output = pool(x)
self.assertEqual(output.shape, (1, 4, 5, 2))
# Test 3D input (no batch dimension)
x = mx.random.normal((6, 6, 4))
pool = nn.AdaptiveMaxPool2d((2, 3))
output = pool(x)
self.assertEqual(output.shape, (2, 3, 4))
# Test AdaptiveMaxPool3d
# Test exact division case (4x4x4 -> 2x2x2)
x = mx.arange(64, dtype=mx.float32).reshape(1, 4, 4, 4, 1)
pool = nn.AdaptiveMaxPool3d((2, 2, 2))
output = pool(x)
self.assertEqual(output.shape, (1, 2, 2, 2, 1))
# Test cube output format (int input)
pool = nn.AdaptiveMaxPool3d(1)
output = pool(x)
self.assertEqual(output.shape, (1, 1, 1, 1, 1))
# Test None values (keep original dimensions)
x = mx.random.normal((2, 6, 8, 10, 3))
pool = nn.AdaptiveMaxPool3d((3, None, 5))
output = pool(x)
self.assertEqual(output.shape, (2, 3, 8, 5, 3))
pool = nn.AdaptiveMaxPool3d((None, 4, None))
output = pool(x)
self.assertEqual(output.shape, (2, 6, 4, 10, 3))
# Test non-exact division case
x = mx.random.normal((1, 7, 5, 9, 2))
pool = nn.AdaptiveMaxPool3d((3, 2, 4))
output = pool(x)
self.assertEqual(output.shape, (1, 3, 2, 4, 2))
# Test larger output than input
x = mx.random.normal((1, 2, 3, 2, 2))
pool = nn.AdaptiveMaxPool3d((4, 5, 3))
output = pool(x)
self.assertEqual(output.shape, (1, 4, 5, 3, 2))
# Test 4D input (no batch dimension)
x = mx.random.normal((6, 6, 6, 4))
pool = nn.AdaptiveMaxPool3d((2, 3, 2))
output = pool(x)
self.assertEqual(output.shape, (2, 3, 2, 4))
def test_set_dtype(self):
def assert_dtype(layer, dtype):
for k, v in tree_flatten(layer.parameters()):