Add AdaptiveMaxPool1d, AdaptiveMaxPool2d, and AdaptiveMaxPool3d layers

This commit is contained in:
Vincent Amato
2025-08-11 23:45:52 -04:00
parent 7fde1b6a1e
commit 68c7be55bb
4 changed files with 396 additions and 0 deletions

View File

@@ -1818,6 +1818,133 @@ class TestLayers(mlx_tests.MLXTestCase):
"AvgPool3d(kernel_size=(3, 3, 3), stride=(3, 3, 3), padding=(2, 2, 2))",
)
# Test AdaptiveMaxPool1d
# Test exact division case (8 -> 4)
x = mx.arange(8, dtype=mx.float32).reshape(1, 8, 1)
pool = nn.AdaptiveMaxPool1d(4)
output = pool(x)
self.assertEqual(output.shape, (1, 4, 1))
# Check first few values manually
expected_0 = max(0, 1) # First 2 elements max
expected_1 = max(2, 3) # Next 2 elements max
self.assertAlmostEqual(output[0, 0, 0].item(), expected_0, places=5)
self.assertAlmostEqual(output[0, 1, 0].item(), expected_1, places=5)
# Test None value (keep original dimension)
x = mx.random.normal((2, 8, 3))
pool = nn.AdaptiveMaxPool1d(None)
output = pool(x)
self.assertEqual(output.shape, (2, 8, 3))
# Should be identical to input when output_size is None
self.assertTrue(mx.allclose(output, x))
# Test non-exact division case
x = mx.random.normal((1, 7, 2))
pool = nn.AdaptiveMaxPool1d(3)
output = pool(x)
self.assertEqual(output.shape, (1, 3, 2))
# Test 2D input (no batch dimension)
x = mx.random.normal((6, 4))
pool = nn.AdaptiveMaxPool1d(2)
output = pool(x)
self.assertEqual(output.shape, (2, 4))
# Test single output
x = mx.random.normal((1, 10, 2))
pool = nn.AdaptiveMaxPool1d(1)
output = pool(x)
self.assertEqual(output.shape, (1, 1, 2))
# Test larger output than input
x = mx.random.normal((1, 3, 2))
pool = nn.AdaptiveMaxPool1d(5)
output = pool(x)
self.assertEqual(output.shape, (1, 5, 2))
# Test AdaptiveMaxPool2d
# Test exact division case (4x4 -> 2x2)
x = mx.arange(16, dtype=mx.float32).reshape(1, 4, 4, 1)
pool = nn.AdaptiveMaxPool2d((2, 2))
output = pool(x)
self.assertEqual(output.shape, (1, 2, 2, 1))
# Check first value manually - max of top-left 2x2 block
expected_00 = max(0, 1, 4, 5) # Top-left 2x2 block
self.assertAlmostEqual(output[0, 0, 0, 0].item(), expected_00, places=5)
# Test square output format (int input)
pool = nn.AdaptiveMaxPool2d(1)
output = pool(x)
self.assertEqual(output.shape, (1, 1, 1, 1))
# Test None values (keep original dimension)
x = mx.random.normal((2, 6, 8, 3))
pool = nn.AdaptiveMaxPool2d((3, None))
output = pool(x)
self.assertEqual(output.shape, (2, 3, 8, 3))
pool = nn.AdaptiveMaxPool2d((None, 4))
output = pool(x)
self.assertEqual(output.shape, (2, 6, 4, 3))
# Test non-exact division case
x = mx.random.normal((1, 7, 5, 2))
pool = nn.AdaptiveMaxPool2d((3, 2))
output = pool(x)
self.assertEqual(output.shape, (1, 3, 2, 2))
# Test larger output than input
x = mx.random.normal((1, 2, 3, 2))
pool = nn.AdaptiveMaxPool2d((4, 5))
output = pool(x)
self.assertEqual(output.shape, (1, 4, 5, 2))
# Test 3D input (no batch dimension)
x = mx.random.normal((6, 6, 4))
pool = nn.AdaptiveMaxPool2d((2, 3))
output = pool(x)
self.assertEqual(output.shape, (2, 3, 4))
# Test AdaptiveMaxPool3d
# Test exact division case (4x4x4 -> 2x2x2)
x = mx.arange(64, dtype=mx.float32).reshape(1, 4, 4, 4, 1)
pool = nn.AdaptiveMaxPool3d((2, 2, 2))
output = pool(x)
self.assertEqual(output.shape, (1, 2, 2, 2, 1))
# Test cube output format (int input)
pool = nn.AdaptiveMaxPool3d(1)
output = pool(x)
self.assertEqual(output.shape, (1, 1, 1, 1, 1))
# Test None values (keep original dimensions)
x = mx.random.normal((2, 6, 8, 10, 3))
pool = nn.AdaptiveMaxPool3d((3, None, 5))
output = pool(x)
self.assertEqual(output.shape, (2, 3, 8, 5, 3))
pool = nn.AdaptiveMaxPool3d((None, 4, None))
output = pool(x)
self.assertEqual(output.shape, (2, 6, 4, 10, 3))
# Test non-exact division case
x = mx.random.normal((1, 7, 5, 9, 2))
pool = nn.AdaptiveMaxPool3d((3, 2, 4))
output = pool(x)
self.assertEqual(output.shape, (1, 3, 2, 4, 2))
# Test larger output than input
x = mx.random.normal((1, 2, 3, 2, 2))
pool = nn.AdaptiveMaxPool3d((4, 5, 3))
output = pool(x)
self.assertEqual(output.shape, (1, 4, 5, 3, 2))
# Test 4D input (no batch dimension)
x = mx.random.normal((6, 6, 6, 4))
pool = nn.AdaptiveMaxPool3d((2, 3, 2))
output = pool(x)
self.assertEqual(output.shape, (2, 3, 2, 4))
def test_set_dtype(self):
def assert_dtype(layer, dtype):
for k, v in tree_flatten(layer.parameters()):