Add AdaptiveAvgPool1d layer
@@ -9,6 +9,7 @@ Layers
    :toctree: _autosummary
    :template: nn-module-template.rst
 
+   AdaptiveAvgPool1d
    AdaptiveAvgPool2d
    AdaptiveAvgPool3d
    ALiBi
@@ -77,6 +77,7 @@ from mlx.nn.layers.normalization import (
     RMSNorm,
 )
 from mlx.nn.layers.pooling import (
+    AdaptiveAvgPool1d,
     AdaptiveAvgPool2d,
     AdaptiveAvgPool3d,
     AvgPool1d,
@@ -409,6 +409,66 @@ class _AdaptivePool(Module):
         return f"output_size={self.output_size}"
 
 
+class AdaptiveAvgPool1d(_AdaptivePool):
+    r"""Applies 1-dimensional adaptive average pooling.
+
+    Spatially downsamples the input by taking the average over pooling
+    regions such that the output size is ``L``, for any input size.
+
+    The parameter can be:
+
+    * a single ``int`` -- the target output size.
+    * ``None`` -- keep the input size unchanged.
+
+    Args:
+        output_size (int or None): The target output size ``L``.
+            Can be an ``int``, or ``None`` which means the size
+            will be the same as that of the input.
+
+    Examples:
+        >>> import mlx.core as mx
+        >>> import mlx.nn.layers as nn
+        >>> x = mx.random.normal(shape=(8, 32, 4))
+        >>> pool = nn.AdaptiveAvgPool1d(7)
+        >>> pool(x)
+        >>> pool = nn.AdaptiveAvgPool1d(None)  # No change
+        >>> pool(x)
+    """
+
+    def __init__(self, output_size: Union[int, None]):
+        super().__init__(output_size)
+
+    def __call__(self, x):
+        output_size = self.output_size
+
+        *batch_dims, L, C = x.shape
+
+        output_L = L if output_size is None else output_size
+
+        if L == output_L:
+            return x
+
+        kernel_L = max(1, L // output_L)  # >= 1 so output_L > L never yields empty regions
+
+        if L % output_L == 0:
+            # Efficient path for exact division
+            new_shape = batch_dims + [output_L, kernel_L, C]
+            x_reshaped = x.reshape(new_shape)
+            return mx.mean(x_reshaped, axis=-2)
+        else:
+            # Manual indexing for non-exact division
+            stride_L = (L - kernel_L) // (output_L - 1) if output_L > 1 else 1
+
+            values = []
+            for i in range(output_L):
+                l_start = i * stride_L
+                l_end = min(l_start + kernel_L, L)
+                region = x[..., l_start:l_end, :]
+                values.append(mx.mean(region, axis=-2))
+
+            return mx.stack(values, axis=-2)
+
+
 class AdaptiveAvgPool2d(_AdaptivePool):
     r"""Applies 2-dimensional adaptive average pooling.
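The two branches of `__call__` are easiest to follow in isolation. Below is a minimal standalone sketch (not part of the diff) that reproduces each path with plain `mlx.core` ops; the inputs mirror the tests further down. It also makes one semantic limit visible: with this window scheme the non-exact path can drop trailing elements, unlike PyTorch's adaptive pooling, whose windows cover every input element.

```python
import mlx.core as mx

# Exact division (L = 8 -> output_L = 4): fold each kernel_L = 2 window
# into its own axis and average over it -- one reshape plus one reduction.
x = mx.arange(8, dtype=mx.float32).reshape(1, 8, 1)  # (N, L, C)
windows = x.reshape(1, 4, 2, 1)                      # (N, output_L, kernel_L, C)
print(mx.mean(windows, axis=-2).squeeze())           # -> 0.5, 2.5, 4.5, 6.5

# Non-exact division (L = 7 -> output_L = 3): kernel_L = 7 // 3 = 2 and
# stride_L = (7 - 2) // (3 - 1) = 2, giving windows [0:2], [2:4], [4:6].
# Element 6 falls outside every window and is dropped.
x = mx.arange(7, dtype=mx.float32).reshape(1, 7, 1)
cols = [mx.mean(x[..., i * 2 : i * 2 + 2, :], axis=-2) for i in range(3)]
print(mx.stack(cols, axis=-2).squeeze())             # -> 0.5, 2.5, 4.5
```

The exact-division branch is a single reshape plus one reduction, which is why the implementation special-cases it rather than always looping over windows.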
@@ -1818,6 +1818,50 @@ class TestLayers(mlx_tests.MLXTestCase):
             "AvgPool3d(kernel_size=(3, 3, 3), stride=(3, 3, 3), padding=(2, 2, 2))",
         )
 
+        # Test AdaptiveAvgPool1d
+        # Test exact division case (8 -> 4)
+        x = mx.arange(8, dtype=mx.float32).reshape(1, 8, 1)
+        pool = nn.AdaptiveAvgPool1d(4)
+        output = pool(x)
+        self.assertEqual(output.shape, (1, 4, 1))
+        # Check first few values manually
+        expected_0 = (0 + 1) / 2  # First 2 elements average
+        expected_1 = (2 + 3) / 2  # Next 2 elements average
+        self.assertAlmostEqual(output[0, 0, 0].item(), expected_0, places=5)
+        self.assertAlmostEqual(output[0, 1, 0].item(), expected_1, places=5)
+
+        # Test None value (keep original dimension)
+        x = mx.random.normal((2, 8, 3))
+        pool = nn.AdaptiveAvgPool1d(None)
+        output = pool(x)
+        self.assertEqual(output.shape, (2, 8, 3))
+        # Should be identical to input when output_size is None
+        self.assertTrue(mx.allclose(output, x))
+
+        # Test non-exact division case
+        x = mx.random.normal((1, 7, 2))
+        pool = nn.AdaptiveAvgPool1d(3)
+        output = pool(x)
+        self.assertEqual(output.shape, (1, 3, 2))
+
+        # Test 2D input (no batch dimension)
+        x = mx.random.normal((6, 4))
+        pool = nn.AdaptiveAvgPool1d(2)
+        output = pool(x)
+        self.assertEqual(output.shape, (2, 4))
+
+        # Test single output
+        x = mx.random.normal((1, 10, 2))
+        pool = nn.AdaptiveAvgPool1d(1)
+        output = pool(x)
+        self.assertEqual(output.shape, (1, 1, 2))
+
+        # Test larger output than input (should still work)
+        x = mx.random.normal((1, 3, 2))
+        pool = nn.AdaptiveAvgPool1d(5)
+        output = pool(x)
+        self.assertEqual(output.shape, (1, 5, 2))
+
         # Test AdaptiveAvgPool2d
         # Test exact division case (8x8 -> 4x4)
         x = mx.arange(64, dtype=mx.float32).reshape(1, 8, 8, 1)
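Taken together, the tests show the point of the layer: any input length is normalized to a fixed output length, so a dense head can follow a variable-length feature extractor. A hypothetical end-to-end sketch (the `Classifier` module, its sizes, and the loop are illustrative assumptions, not part of the commit):

```python
import mlx.core as mx
import mlx.nn as nn

class Classifier(nn.Module):
    """Toy head: pool any sequence length down to 4 steps, then project."""

    def __init__(self, channels: int, num_classes: int):
        super().__init__()
        # Hypothetical use of the layer added by this commit.
        self.pool = nn.AdaptiveAvgPool1d(4)
        self.fc = nn.Linear(4 * channels, num_classes)

    def __call__(self, x):
        x = self.pool(x)  # (N, 4, C) regardless of the input length L
        return self.fc(x.reshape(x.shape[0], -1))

model = Classifier(channels=16, num_classes=10)
for L in (12, 32, 100):  # all divisible by 4, so the fast reshape path is taken
    out = model(mx.random.normal(shape=(2, L, 16)))
    print(L, out.shape)  # always (2, 10)
```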