Add AdaptiveAvgPool2d and AdaptiveAvgPool3d

This commit is contained in:
Vincent Amato
2025-08-11 21:17:49 -04:00
parent 7fde1b6a1e
commit 634ce07a3e
4 changed files with 310 additions and 0 deletions

View File

@@ -1818,6 +1818,79 @@ class TestLayers(mlx_tests.MLXTestCase):
"AvgPool3d(kernel_size=(3, 3, 3), stride=(3, 3, 3), padding=(2, 2, 2))",
)
# Test AdaptiveAvgPool2d
# Test exact division case (8x8 -> 4x4)
x = mx.arange(64, dtype=mx.float32).reshape(1, 8, 8, 1)
pool = nn.AdaptiveAvgPool2d((4, 4))
output = pool(x)
self.assertEqual(output.shape, (1, 4, 4, 1))
# Check first few values manually
expected_00 = (0 + 1 + 8 + 9) / 4 # Top-left 2x2 block
expected_01 = (2 + 3 + 10 + 11) / 4 # Next 2x2 block
self.assertAlmostEqual(output[0, 0, 0, 0].item(), expected_00, places=5)
self.assertAlmostEqual(output[0, 0, 1, 0].item(), expected_01, places=5)
# Test square output format (int input)
pool = nn.AdaptiveAvgPool2d(2)
output = pool(x)
self.assertEqual(output.shape, (1, 2, 2, 1))
# Test None values (keep original dimension)
x = mx.random.normal((2, 6, 8, 3))
pool = nn.AdaptiveAvgPool2d((3, None))
output = pool(x)
self.assertEqual(output.shape, (2, 3, 8, 3))
pool = nn.AdaptiveAvgPool2d((None, 4))
output = pool(x)
self.assertEqual(output.shape, (2, 6, 4, 3))
# Test non-exact division case
x = mx.random.normal((1, 7, 5, 2))
pool = nn.AdaptiveAvgPool2d((3, 2))
output = pool(x)
self.assertEqual(output.shape, (1, 3, 2, 2))
# Test 3D input (no batch dimension)
x = mx.random.normal((6, 6, 4))
pool = nn.AdaptiveAvgPool2d((2, 3))
output = pool(x)
self.assertEqual(output.shape, (2, 3, 4))
# Test AdaptiveAvgPool3d
# Test exact division case (8x8x8 -> 4x4x4)
x = mx.arange(512, dtype=mx.float32).reshape(1, 8, 8, 8, 1)
pool = nn.AdaptiveAvgPool3d((4, 4, 4))
output = pool(x)
self.assertEqual(output.shape, (1, 4, 4, 4, 1))
# Test cube output format (int input)
pool = nn.AdaptiveAvgPool3d(2)
output = pool(x)
self.assertEqual(output.shape, (1, 2, 2, 2, 1))
# Test None values (keep original dimensions)
x = mx.random.normal((2, 6, 8, 10, 3))
pool = nn.AdaptiveAvgPool3d((3, None, 5))
output = pool(x)
self.assertEqual(output.shape, (2, 3, 8, 5, 3))
pool = nn.AdaptiveAvgPool3d((None, 4, None))
output = pool(x)
self.assertEqual(output.shape, (2, 6, 4, 10, 3))
# Test non-exact division case
x = mx.random.normal((1, 7, 5, 9, 2))
pool = nn.AdaptiveAvgPool3d((3, 2, 4))
output = pool(x)
self.assertEqual(output.shape, (1, 3, 2, 4, 2))
# Test 4D input (no batch dimension)
x = mx.random.normal((6, 6, 6, 4))
pool = nn.AdaptiveAvgPool3d((2, 3, 2))
output = pool(x)
self.assertEqual(output.shape, (2, 3, 2, 4))
def test_set_dtype(self):
def assert_dtype(layer, dtype):
for k, v in tree_flatten(layer.parameters()):