mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add AdaptiveAvgPool2d and AdaptiveAvgPool3d
This commit is contained in:
@@ -1818,6 +1818,79 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
"AvgPool3d(kernel_size=(3, 3, 3), stride=(3, 3, 3), padding=(2, 2, 2))",
|
||||
)
|
||||
|
||||
# Test AdaptiveAvgPool2d
|
||||
# Test exact division case (8x8 -> 4x4)
|
||||
x = mx.arange(64, dtype=mx.float32).reshape(1, 8, 8, 1)
|
||||
pool = nn.AdaptiveAvgPool2d((4, 4))
|
||||
output = pool(x)
|
||||
self.assertEqual(output.shape, (1, 4, 4, 1))
|
||||
# Check first few values manually
|
||||
expected_00 = (0 + 1 + 8 + 9) / 4 # Top-left 2x2 block
|
||||
expected_01 = (2 + 3 + 10 + 11) / 4 # Next 2x2 block
|
||||
self.assertAlmostEqual(output[0, 0, 0, 0].item(), expected_00, places=5)
|
||||
self.assertAlmostEqual(output[0, 0, 1, 0].item(), expected_01, places=5)
|
||||
|
||||
# Test square output format (int input)
|
||||
pool = nn.AdaptiveAvgPool2d(2)
|
||||
output = pool(x)
|
||||
self.assertEqual(output.shape, (1, 2, 2, 1))
|
||||
|
||||
# Test None values (keep original dimension)
|
||||
x = mx.random.normal((2, 6, 8, 3))
|
||||
pool = nn.AdaptiveAvgPool2d((3, None))
|
||||
output = pool(x)
|
||||
self.assertEqual(output.shape, (2, 3, 8, 3))
|
||||
|
||||
pool = nn.AdaptiveAvgPool2d((None, 4))
|
||||
output = pool(x)
|
||||
self.assertEqual(output.shape, (2, 6, 4, 3))
|
||||
|
||||
# Test non-exact division case
|
||||
x = mx.random.normal((1, 7, 5, 2))
|
||||
pool = nn.AdaptiveAvgPool2d((3, 2))
|
||||
output = pool(x)
|
||||
self.assertEqual(output.shape, (1, 3, 2, 2))
|
||||
|
||||
# Test 3D input (no batch dimension)
|
||||
x = mx.random.normal((6, 6, 4))
|
||||
pool = nn.AdaptiveAvgPool2d((2, 3))
|
||||
output = pool(x)
|
||||
self.assertEqual(output.shape, (2, 3, 4))
|
||||
|
||||
# Test AdaptiveAvgPool3d
|
||||
# Test exact division case (8x8x8 -> 4x4x4)
|
||||
x = mx.arange(512, dtype=mx.float32).reshape(1, 8, 8, 8, 1)
|
||||
pool = nn.AdaptiveAvgPool3d((4, 4, 4))
|
||||
output = pool(x)
|
||||
self.assertEqual(output.shape, (1, 4, 4, 4, 1))
|
||||
|
||||
# Test cube output format (int input)
|
||||
pool = nn.AdaptiveAvgPool3d(2)
|
||||
output = pool(x)
|
||||
self.assertEqual(output.shape, (1, 2, 2, 2, 1))
|
||||
|
||||
# Test None values (keep original dimensions)
|
||||
x = mx.random.normal((2, 6, 8, 10, 3))
|
||||
pool = nn.AdaptiveAvgPool3d((3, None, 5))
|
||||
output = pool(x)
|
||||
self.assertEqual(output.shape, (2, 3, 8, 5, 3))
|
||||
|
||||
pool = nn.AdaptiveAvgPool3d((None, 4, None))
|
||||
output = pool(x)
|
||||
self.assertEqual(output.shape, (2, 6, 4, 10, 3))
|
||||
|
||||
# Test non-exact division case
|
||||
x = mx.random.normal((1, 7, 5, 9, 2))
|
||||
pool = nn.AdaptiveAvgPool3d((3, 2, 4))
|
||||
output = pool(x)
|
||||
self.assertEqual(output.shape, (1, 3, 2, 4, 2))
|
||||
|
||||
# Test 4D input (no batch dimension)
|
||||
x = mx.random.normal((6, 6, 6, 4))
|
||||
pool = nn.AdaptiveAvgPool3d((2, 3, 2))
|
||||
output = pool(x)
|
||||
self.assertEqual(output.shape, (2, 3, 2, 4))
|
||||
|
||||
def test_set_dtype(self):
|
||||
def assert_dtype(layer, dtype):
|
||||
for k, v in tree_flatten(layer.parameters()):
|
||||
|
||||
Reference in New Issue
Block a user