Mirror of https://github.com/ml-explore/mlx.git
Add AdaptiveMaxPool1d, AdaptiveMaxPool2d, and AdaptiveMaxPool3d layers
@@ -9,6 +9,9 @@ Layers
   :toctree: _autosummary
   :template: nn-module-template.rst

   AdaptiveMaxPool1d
   AdaptiveMaxPool2d
   AdaptiveMaxPool3d
   ALiBi
   AvgPool1d
   AvgPool2d
@@ -77,6 +77,9 @@ from mlx.nn.layers.normalization import (
    RMSNorm,
)
from mlx.nn.layers.pooling import (
    AdaptiveMaxPool1d,
    AdaptiveMaxPool2d,
    AdaptiveMaxPool3d,
    AvgPool1d,
    AvgPool2d,
    AvgPool3d,
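Once these names are exported, the new layers are used like the existing pooling layers. A minimal usage sketch (not part of the diff; it assumes the diff is applied, and that, like the other MLX pooling layers, inputs are channels-last: (N, L, C), (N, H, W, C), (N, D, H, W, C)):

import mlx.core as mx
import mlx.nn as nn

x1 = mx.random.normal(shape=(4, 100, 16))
x2 = mx.random.normal(shape=(4, 28, 28, 16))
x3 = mx.random.normal(shape=(4, 8, 28, 28, 16))

print(nn.AdaptiveMaxPool1d(10)(x1).shape)         # (4, 10, 16)
print(nn.AdaptiveMaxPool2d((7, None))(x2).shape)  # (4, 7, 28, 16)
print(nn.AdaptiveMaxPool3d(4)(x3).shape)          # (4, 4, 4, 4, 16)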
@@ -396,3 +396,266 @@ class AvgPool3d(_Pool3d):
        padding: Optional[Union[int, Tuple[int, int, int]]] = 0,
    ):
        super().__init__(mx.mean, 0, kernel_size, stride, padding)


class _AdaptivePool(Module):
    """Base class for adaptive pooling layers."""

    def __init__(self, output_size):
        super().__init__()
        self.output_size = output_size

    def _extra_repr(self):
        return f"output_size={self.output_size}"

class AdaptiveMaxPool1d(_AdaptivePool):
    r"""Applies 1-dimensional adaptive max pooling.

    Spatially downsamples the input by taking the maximum over pooling
    regions such that the output size is L, for any input size.

    The parameter can be:

    * a single ``int`` -- the target output size.
    * ``None`` -- keep the input size unchanged.

    Args:
        output_size (int or None): The target output size L.
            Can be an ``int``, or ``None`` which means the size
            will be the same as that of the input.

    Note:
        Unlike PyTorch's implementation, this layer does not support
        returning indices (the ``return_indices`` parameter). This may be
        added in a future version.

    Examples:
        >>> import mlx.core as mx
        >>> import mlx.nn.layers as nn
        >>> x = mx.random.normal(shape=(8, 32, 4))
        >>> pool = nn.AdaptiveMaxPool1d(7)
        >>> pool(x)
        >>> pool = nn.AdaptiveMaxPool1d(None)  # No change
        >>> pool(x)
    """

    def __init__(self, output_size: Union[int, None]):
        super().__init__(output_size)

    def __call__(self, x):
        output_size = self.output_size

        *batch_dims, L, C = x.shape

        output_L = L if output_size is None else output_size

        if L == output_L:
            return x

        kernel_L = max(1, L // output_L)

        if L % output_L == 0 and kernel_L > 0:
            new_shape = batch_dims + [output_L, kernel_L, C]
            x_reshaped = x.reshape(new_shape)
            return mx.max(x_reshaped, axis=-2)
        else:
            stride_L = max(1, (L - kernel_L) // (output_L - 1)) if output_L > 1 else 1

            values = []
            for i in range(output_L):
                l_start = min(i * stride_L, L - 1)
                l_end = min(l_start + kernel_L, L)
                region = x[..., l_start:l_end, :]
                values.append(mx.max(region, axis=-2))

            return mx.stack(values, axis=-2)

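To make the fallback branch above concrete, a small worked example (not part of the diff; it assumes the layer is importable as nn.AdaptiveMaxPool1d): for an input of length 7 pooled to 3 outputs, kernel_L = 2 and stride_L = 2, so the pooled regions are x[0:2], x[2:4], and x[4:6]; the last element never falls inside a region with this stride choice.

import mlx.core as mx
import mlx.nn as nn

x = mx.array([0.0, 3.0, 1.0, 4.0, 1.0, 5.0, 9.0]).reshape(1, 7, 1)
pool = nn.AdaptiveMaxPool1d(3)
# Maxima over x[0:2], x[2:4], x[4:6] -> [3, 4, 5]; the trailing 9 is not covered.
print(mx.squeeze(pool(x)))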
class AdaptiveMaxPool2d(_AdaptivePool):
    r"""Applies 2-dimensional adaptive max pooling.

    Spatially downsamples the input by taking the maximum over pooling
    regions such that the output size is H x W, for any input size.

    The parameters can be:

    * a single ``int`` -- in which case the same value is used for both the
      height and width dimensions, creating a square output.
    * a ``tuple`` of two ``int`` s -- in which case, the first ``int`` is used
      for the height dimension, the second ``int`` for the width dimension.
    * ``None`` can be used for either dimension to keep the input size
      unchanged.

    Args:
        output_size (int or tuple(int, int)): The target output size of the
            form H x W. Can be a tuple (H, W) or a single int for a square
            output. H and W can be either an ``int``, or ``None`` which means
            the size will be the same as that of the input.

    Note:
        Unlike PyTorch's implementation, this layer does not support
        returning indices (the ``return_indices`` parameter). This may be
        added in a future version.

    Examples:
        >>> import mlx.core as mx
        >>> import mlx.nn.layers as nn
        >>> x = mx.random.normal(shape=(8, 32, 32, 4))
        >>> pool = nn.AdaptiveMaxPool2d((5, 7))
        >>> pool(x)
        >>> pool = nn.AdaptiveMaxPool2d(7)
        >>> pool(x)
    """

    def __init__(self, output_size: Union[int, Tuple[Optional[int], Optional[int]]]):
        super().__init__(output_size)

    def __call__(self, x):
        output_size = self.output_size
        if isinstance(output_size, int):
            output_size = (output_size, output_size)
        elif len(output_size) == 1:
            output_size = (output_size[0], output_size[0])

        *batch_dims, H, W, C = x.shape

        output_H = H if output_size[0] is None else output_size[0]
        output_W = W if output_size[1] is None else output_size[1]

        if H == output_H and W == output_W:
            return x

        kernel_H = max(1, H // output_H)
        kernel_W = max(1, W // output_W)

        if H % output_H == 0 and W % output_W == 0 and kernel_H > 0 and kernel_W > 0:
            new_shape = batch_dims + [output_H, kernel_H, output_W, kernel_W, C]
            x_reshaped = x.reshape(new_shape)
            return mx.max(x_reshaped, axis=[-4, -2])
        else:
            stride_H = max(1, (H - kernel_H) // (output_H - 1)) if output_H > 1 else 1
            stride_W = max(1, (W - kernel_W) // (output_W - 1)) if output_W > 1 else 1

            values = []
            for i in range(output_H):
                row_values = []
                for j in range(output_W):
                    h_start = min(i * stride_H, H - 1)
                    h_end = min(h_start + kernel_H, H)
                    w_start = min(j * stride_W, W - 1)
                    w_end = min(w_start + kernel_W, W)

                    region = x[..., h_start:h_end, w_start:w_end, :]
                    row_values.append(mx.max(region, axis=(-3, -2)))
                values.append(mx.stack(row_values, axis=-2))

            return mx.stack(values, axis=-3)

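The exact-division branch above is just a block reshape followed by a max over the kernel axes. A hand-rolled sketch of the 2-d case using plain mx ops (not part of the diff):

import mlx.core as mx

# A 4x4 single-channel map pooled down to 2x2 by grouping it into 2x2 blocks
# and reducing over the kernel axes, mirroring the reshape used above.
x = mx.arange(16, dtype=mx.float32).reshape(1, 4, 4, 1)
blocks = x.reshape(1, 2, 2, 2, 2, 1)  # (batch, H_out, kernel_H, W_out, kernel_W, C)
print(mx.max(blocks, axis=[-4, -2]))  # values 5, 7, 13, 15 with shape (1, 2, 2, 1)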
class AdaptiveMaxPool3d(_AdaptivePool):
    r"""Applies 3-dimensional adaptive max pooling.

    Spatially downsamples the input by taking the maximum over pooling
    regions such that the output size is D x H x W, for any input size.

    The parameters can be:

    * a single ``int`` -- in which case the same value is used for the depth,
      height, and width dimensions, creating a cube output.
    * a ``tuple`` of three ``int`` s -- in which case, the first ``int`` is
      used for the depth dimension, the second ``int`` for the height
      dimension, and the third ``int`` for the width dimension.
    * ``None`` can be used for any dimension to keep the input size unchanged.

    Args:
        output_size (int or tuple(int, int, int)): The target output size of
            the form D x H x W. Can be a tuple (D, H, W) or a single int for a
            cube output. D, H and W can be either an ``int``, or ``None`` which
            means the size will be the same as that of the input.

    Note:
        Unlike PyTorch's implementation, this layer does not support
        returning indices (the ``return_indices`` parameter). This may be
        added in a future version.

    Examples:
        >>> import mlx.core as mx
        >>> import mlx.nn.layers as nn
        >>> x = mx.random.normal(shape=(8, 16, 32, 32, 4))
        >>> pool = nn.AdaptiveMaxPool3d((5, 7, 9))
        >>> pool(x)
        >>> pool = nn.AdaptiveMaxPool3d(7)
        >>> pool(x)
    """

    def __init__(
        self,
        output_size: Union[int, Tuple[Optional[int], Optional[int], Optional[int]]],
    ):
        super().__init__(output_size)

    def __call__(self, x):
        output_size = self.output_size
        if isinstance(output_size, int):
            output_size = (output_size, output_size, output_size)
        elif len(output_size) == 1:
            output_size = (output_size[0], output_size[0], output_size[0])
        elif len(output_size) == 2:
            output_size = (output_size[0], output_size[1], output_size[1])

        *batch_dims, D, H, W, C = x.shape

        output_D = D if output_size[0] is None else output_size[0]
        output_H = H if output_size[1] is None else output_size[1]
        output_W = W if output_size[2] is None else output_size[2]

        if D == output_D and H == output_H and W == output_W:
            return x

        kernel_D = max(1, D // output_D)
        kernel_H = max(1, H // output_H)
        kernel_W = max(1, W // output_W)

        if (
            D % output_D == 0
            and H % output_H == 0
            and W % output_W == 0
            and kernel_D > 0
            and kernel_H > 0
            and kernel_W > 0
        ):
            new_shape = batch_dims + [
                output_D,
                kernel_D,
                output_H,
                kernel_H,
                output_W,
                kernel_W,
                C,
            ]
            x_reshaped = x.reshape(new_shape)
            return mx.max(x_reshaped, axis=[-6, -4, -2])
        else:
            stride_D = max(1, (D - kernel_D) // (output_D - 1)) if output_D > 1 else 1
            stride_H = max(1, (H - kernel_H) // (output_H - 1)) if output_H > 1 else 1
            stride_W = max(1, (W - kernel_W) // (output_W - 1)) if output_W > 1 else 1

            values = []
            for i in range(output_D):
                depth_values = []
                for j in range(output_H):
                    row_values = []
                    for k in range(output_W):
                        d_start = min(i * stride_D, D - 1)
                        d_end = min(d_start + kernel_D, D)
                        h_start = min(j * stride_H, H - 1)
                        h_end = min(h_start + kernel_H, H)
                        w_start = min(k * stride_W, W - 1)
                        w_end = min(w_start + kernel_W, W)

                        region = x[..., d_start:d_end, h_start:h_end, w_start:w_end, :]
                        row_values.append(mx.max(region, axis=(-4, -3, -2)))
                    depth_values.append(mx.stack(row_values, axis=-2))
                values.append(mx.stack(depth_values, axis=-3))

            return mx.stack(values, axis=-4)

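A useful special case of the layers above: with an output size of 1 they reduce to a global spatial max. A hedged sanity check (not part of the diff; assumes the diff is applied):

import mlx.core as mx
import mlx.nn as nn

x = mx.random.normal(shape=(2, 4, 6, 8, 3))  # (N, D, H, W, C)
pooled = nn.AdaptiveMaxPool3d(1)(x)  # shape (2, 1, 1, 1, 3)
reference = mx.max(x, axis=(1, 2, 3), keepdims=True)  # global max over D, H, W
assert mx.allclose(pooled, reference).item()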
@@ -1818,6 +1818,133 @@ class TestLayers(mlx_tests.MLXTestCase):
            "AvgPool3d(kernel_size=(3, 3, 3), stride=(3, 3, 3), padding=(2, 2, 2))",
        )

        # Test AdaptiveMaxPool1d
        # Test exact division case (8 -> 4)
        x = mx.arange(8, dtype=mx.float32).reshape(1, 8, 1)
        pool = nn.AdaptiveMaxPool1d(4)
        output = pool(x)
        self.assertEqual(output.shape, (1, 4, 1))
        # Check the first few values manually
        expected_0 = max(0, 1)  # Max of the first 2 elements
        expected_1 = max(2, 3)  # Max of the next 2 elements
        self.assertAlmostEqual(output[0, 0, 0].item(), expected_0, places=5)
        self.assertAlmostEqual(output[0, 1, 0].item(), expected_1, places=5)

        # Test None value (keep original dimension)
        x = mx.random.normal((2, 8, 3))
        pool = nn.AdaptiveMaxPool1d(None)
        output = pool(x)
        self.assertEqual(output.shape, (2, 8, 3))
        # Should be identical to the input when output_size is None
        self.assertTrue(mx.allclose(output, x))

        # Test non-exact division case
        x = mx.random.normal((1, 7, 2))
        pool = nn.AdaptiveMaxPool1d(3)
        output = pool(x)
        self.assertEqual(output.shape, (1, 3, 2))

        # Test 2D input (no batch dimension)
        x = mx.random.normal((6, 4))
        pool = nn.AdaptiveMaxPool1d(2)
        output = pool(x)
        self.assertEqual(output.shape, (2, 4))

        # Test single output
        x = mx.random.normal((1, 10, 2))
        pool = nn.AdaptiveMaxPool1d(1)
        output = pool(x)
        self.assertEqual(output.shape, (1, 1, 2))

        # Test larger output than input
        x = mx.random.normal((1, 3, 2))
        pool = nn.AdaptiveMaxPool1d(5)
        output = pool(x)
        self.assertEqual(output.shape, (1, 5, 2))

        # Test AdaptiveMaxPool2d
        # Test exact division case (4x4 -> 2x2)
        x = mx.arange(16, dtype=mx.float32).reshape(1, 4, 4, 1)
        pool = nn.AdaptiveMaxPool2d((2, 2))
        output = pool(x)
        self.assertEqual(output.shape, (1, 2, 2, 1))
        # Check the first value manually - max of the top-left 2x2 block
        expected_00 = max(0, 1, 4, 5)  # Top-left 2x2 block
        self.assertAlmostEqual(output[0, 0, 0, 0].item(), expected_00, places=5)

        # Test square output format (int input)
        pool = nn.AdaptiveMaxPool2d(1)
        output = pool(x)
        self.assertEqual(output.shape, (1, 1, 1, 1))

        # Test None values (keep original dimension)
        x = mx.random.normal((2, 6, 8, 3))
        pool = nn.AdaptiveMaxPool2d((3, None))
        output = pool(x)
        self.assertEqual(output.shape, (2, 3, 8, 3))

        pool = nn.AdaptiveMaxPool2d((None, 4))
        output = pool(x)
        self.assertEqual(output.shape, (2, 6, 4, 3))

        # Test non-exact division case
        x = mx.random.normal((1, 7, 5, 2))
        pool = nn.AdaptiveMaxPool2d((3, 2))
        output = pool(x)
        self.assertEqual(output.shape, (1, 3, 2, 2))

        # Test larger output than input
        x = mx.random.normal((1, 2, 3, 2))
        pool = nn.AdaptiveMaxPool2d((4, 5))
        output = pool(x)
        self.assertEqual(output.shape, (1, 4, 5, 2))

        # Test 3D input (no batch dimension)
        x = mx.random.normal((6, 6, 4))
        pool = nn.AdaptiveMaxPool2d((2, 3))
        output = pool(x)
        self.assertEqual(output.shape, (2, 3, 4))

        # Test AdaptiveMaxPool3d
        # Test exact division case (4x4x4 -> 2x2x2)
        x = mx.arange(64, dtype=mx.float32).reshape(1, 4, 4, 4, 1)
        pool = nn.AdaptiveMaxPool3d((2, 2, 2))
        output = pool(x)
        self.assertEqual(output.shape, (1, 2, 2, 2, 1))

        # Test cube output format (int input)
        pool = nn.AdaptiveMaxPool3d(1)
        output = pool(x)
        self.assertEqual(output.shape, (1, 1, 1, 1, 1))

        # Test None values (keep original dimensions)
        x = mx.random.normal((2, 6, 8, 10, 3))
        pool = nn.AdaptiveMaxPool3d((3, None, 5))
        output = pool(x)
        self.assertEqual(output.shape, (2, 3, 8, 5, 3))

        pool = nn.AdaptiveMaxPool3d((None, 4, None))
        output = pool(x)
        self.assertEqual(output.shape, (2, 6, 4, 10, 3))

        # Test non-exact division case
        x = mx.random.normal((1, 7, 5, 9, 2))
        pool = nn.AdaptiveMaxPool3d((3, 2, 4))
        output = pool(x)
        self.assertEqual(output.shape, (1, 3, 2, 4, 2))

        # Test larger output than input
        x = mx.random.normal((1, 2, 3, 2, 2))
        pool = nn.AdaptiveMaxPool3d((4, 5, 3))
        output = pool(x)
        self.assertEqual(output.shape, (1, 4, 5, 3, 2))

        # Test 4D input (no batch dimension)
        x = mx.random.normal((6, 6, 6, 4))
        pool = nn.AdaptiveMaxPool3d((2, 3, 2))
        output = pool(x)
        self.assertEqual(output.shape, (2, 3, 2, 4))

    def test_set_dtype(self):
        def assert_dtype(layer, dtype):
            for k, v in tree_flatten(layer.parameters()):
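One further property that could be asserted, sketched here rather than taken from the diff: when the spatial size divides the target evenly, adaptive max pooling should agree with the fixed-kernel nn.MaxPool2d using kernel_size = stride = input // output.

import mlx.core as mx
import mlx.nn as nn

x = mx.random.normal(shape=(2, 8, 8, 3))
adaptive = nn.AdaptiveMaxPool2d((4, 4))         # layer added by this diff
fixed = nn.MaxPool2d(kernel_size=2, stride=2)   # existing MLX layer
assert mx.allclose(adaptive(x), fixed(x)).item()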