Add AdaptiveAvgPool1d layer

This commit is contained in:
Vincent Amato
2025-08-11 23:10:56 -04:00
parent c320e72635
commit e4530007ae
4 changed files with 106 additions and 0 deletions

View File

@@ -9,6 +9,7 @@ Layers
:toctree: _autosummary
:template: nn-module-template.rst
AdaptiveAvgPool1d
AdaptiveAvgPool2d
AdaptiveAvgPool3d
ALiBi

View File

@@ -77,6 +77,7 @@ from mlx.nn.layers.normalization import (
RMSNorm,
)
from mlx.nn.layers.pooling import (
AdaptiveAvgPool1d,
AdaptiveAvgPool2d,
AdaptiveAvgPool3d,
AvgPool1d,

View File

@@ -409,6 +409,66 @@ class _AdaptivePool(Module):
return f"output_size={self.output_size}"
class AdaptiveAvgPool1d(_AdaptivePool):
    r"""Applies 1-dimensional adaptive average pooling.

    Spatially downsamples the input by taking the average over pooling
    regions such that the output size is L, for any input size.

    The parameter can be:

    * a single ``int`` -- the target output size.
    * ``None`` can be used to keep the input size unchanged.

    Args:
        output_size (int or None): The target output size L.
            Can be an ``int``, or ``None`` which means the size
            will be the same as that of the input.

    Examples:
        >>> import mlx.core as mx
        >>> import mlx.nn.layers as nn
        >>> x = mx.random.normal(shape=(8, 32, 4))
        >>> pool = nn.AdaptiveAvgPool1d(7)
        >>> pool(x)
        >>> pool = nn.AdaptiveAvgPool1d(None)  # No change
        >>> pool(x)
    """

    def __init__(self, output_size: Union[int, None]):
        super().__init__(output_size)

    def __call__(self, x):
        output_size = self.output_size
        *batch_dims, L, C = x.shape
        output_L = L if output_size is None else output_size
        if L == output_L:
            # Includes the ``output_size is None`` case: identity.
            return x
        if output_L > 0 and L % output_L == 0:
            # Efficient path for exact division: fold each pooling window
            # into its own axis and reduce it.
            kernel_L = L // output_L
            x_reshaped = x.reshape(batch_dims + [output_L, kernel_L, C])
            return mx.mean(x_reshaped, axis=-2)
        # General case. Region i covers [floor(i*L/out), ceil((i+1)*L/out)),
        # the standard adaptive-pooling index formula: every input position
        # is covered, adjacent regions may overlap or differ in size by one,
        # and no region is ever empty -- including output_L > L, where
        # size-1 regions are repeated instead of averaging empty slices
        # (which would produce NaNs).
        values = []
        for i in range(output_L):
            l_start = (i * L) // output_L
            # ceil((i + 1) * L / output_L) via floor division on negatives.
            l_end = -((-(i + 1) * L) // output_L)
            region = x[..., l_start:l_end, :]
            values.append(mx.mean(region, axis=-2))
        return mx.stack(values, axis=-2)
class AdaptiveAvgPool2d(_AdaptivePool):
r"""Applies 2-dimensional adaptive average pooling.

View File

@@ -1818,6 +1818,50 @@ class TestLayers(mlx_tests.MLXTestCase):
"AvgPool3d(kernel_size=(3, 3, 3), stride=(3, 3, 3), padding=(2, 2, 2))",
)
# Test AdaptiveAvgPool1d
# Test exact division case (8 -> 4)
x = mx.arange(8, dtype=mx.float32).reshape(1, 8, 1)
pool = nn.AdaptiveAvgPool1d(4)
output = pool(x)
self.assertEqual(output.shape, (1, 4, 1))
# Check first few values manually
expected_0 = (0 + 1) / 2 # First 2 elements average
expected_1 = (2 + 3) / 2 # Next 2 elements average
self.assertAlmostEqual(output[0, 0, 0].item(), expected_0, places=5)
self.assertAlmostEqual(output[0, 1, 0].item(), expected_1, places=5)
# Test None value (keep original dimension)
x = mx.random.normal((2, 8, 3))
pool = nn.AdaptiveAvgPool1d(None)
output = pool(x)
self.assertEqual(output.shape, (2, 8, 3))
# Should be identical to input when output_size is None
self.assertTrue(mx.allclose(output, x))
# Test non-exact division case
x = mx.random.normal((1, 7, 2))
pool = nn.AdaptiveAvgPool1d(3)
output = pool(x)
self.assertEqual(output.shape, (1, 3, 2))
# Test 2D input (no batch dimension)
x = mx.random.normal((6, 4))
pool = nn.AdaptiveAvgPool1d(2)
output = pool(x)
self.assertEqual(output.shape, (2, 4))
# Test single output
x = mx.random.normal((1, 10, 2))
pool = nn.AdaptiveAvgPool1d(1)
output = pool(x)
self.assertEqual(output.shape, (1, 1, 2))
# Test larger output than input (should still work)
x = mx.random.normal((1, 3, 2))
pool = nn.AdaptiveAvgPool1d(5)
output = pool(x)
self.assertEqual(output.shape, (1, 5, 2))
# Test AdaptiveAvgPool2d
# Test exact division case (8x8 -> 4x4)
x = mx.arange(64, dtype=mx.float32).reshape(1, 8, 8, 1)