mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add AdaptiveAvgPool1d layer
This commit is contained in:
@@ -1818,6 +1818,50 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
"AvgPool3d(kernel_size=(3, 3, 3), stride=(3, 3, 3), padding=(2, 2, 2))",
|
||||
)
|
||||
|
||||
# Test AdaptiveAvgPool1d
|
||||
# Test exact division case (8 -> 4)
|
||||
x = mx.arange(8, dtype=mx.float32).reshape(1, 8, 1)
|
||||
pool = nn.AdaptiveAvgPool1d(4)
|
||||
output = pool(x)
|
||||
self.assertEqual(output.shape, (1, 4, 1))
|
||||
# Check first few values manually
|
||||
expected_0 = (0 + 1) / 2 # First 2 elements average
|
||||
expected_1 = (2 + 3) / 2 # Next 2 elements average
|
||||
self.assertAlmostEqual(output[0, 0, 0].item(), expected_0, places=5)
|
||||
self.assertAlmostEqual(output[0, 1, 0].item(), expected_1, places=5)
|
||||
|
||||
# Test None value (keep original dimension)
|
||||
x = mx.random.normal((2, 8, 3))
|
||||
pool = nn.AdaptiveAvgPool1d(None)
|
||||
output = pool(x)
|
||||
self.assertEqual(output.shape, (2, 8, 3))
|
||||
# Should be identical to input when output_size is None
|
||||
self.assertTrue(mx.allclose(output, x))
|
||||
|
||||
# Test non-exact division case
|
||||
x = mx.random.normal((1, 7, 2))
|
||||
pool = nn.AdaptiveAvgPool1d(3)
|
||||
output = pool(x)
|
||||
self.assertEqual(output.shape, (1, 3, 2))
|
||||
|
||||
# Test 2D input (no batch dimension)
|
||||
x = mx.random.normal((6, 4))
|
||||
pool = nn.AdaptiveAvgPool1d(2)
|
||||
output = pool(x)
|
||||
self.assertEqual(output.shape, (2, 4))
|
||||
|
||||
# Test single output
|
||||
x = mx.random.normal((1, 10, 2))
|
||||
pool = nn.AdaptiveAvgPool1d(1)
|
||||
output = pool(x)
|
||||
self.assertEqual(output.shape, (1, 1, 2))
|
||||
|
||||
# Test larger output than input (should still work)
|
||||
x = mx.random.normal((1, 3, 2))
|
||||
pool = nn.AdaptiveAvgPool1d(5)
|
||||
output = pool(x)
|
||||
self.assertEqual(output.shape, (1, 5, 2))
|
||||
|
||||
# Test AdaptiveAvgPool2d
|
||||
# Test exact division case (8x8 -> 4x4)
|
||||
x = mx.arange(64, dtype=mx.float32).reshape(1, 8, 8, 1)
|
||||
|
||||
Reference in New Issue
Block a user