Adds 3D pooling (#1526)

This commit is contained in:
Saanidhya
2024-11-19 19:45:24 -05:00
committed by GitHub
parent 61d787726a
commit cb431dfc9f
3 changed files with 250 additions and 1 deletions

View File

@@ -1589,6 +1589,123 @@ class TestLayers(mlx_tests.MLXTestCase):
str(nn.AvgPool2d(kernel_size=(1, 2), stride=2, padding=(1, 2))),
"AvgPool2d(kernel_size=(1, 2), stride=(2, 2), padding=(1, 2))",
)
# Test 3d pooling
x = mx.array(
[
[
[
[[0, 1, 2], [3, 4, 5], [6, 7, 8]],
[[9, 10, 11], [12, 13, 14], [15, 16, 17]],
[[18, 19, 20], [21, 22, 23], [24, 25, 26]],
],
[
[[27, 28, 29], [30, 31, 32], [33, 34, 35]],
[[36, 37, 38], [39, 40, 41], [42, 43, 44]],
[[45, 46, 47], [48, 49, 50], [51, 52, 53]],
],
]
]
)
expected_max_pool_output_no_padding_stride_1 = [
[[[[39, 40, 41], [42, 43, 44]], [[48, 49, 50], [51, 52, 53]]]]
]
expected_max_pool_output_no_padding_stride_2 = [[[[[39, 40, 41]]]]]
expected_max_pool_output_padding_1 = [
[
[[[0, 1, 2], [6, 7, 8]], [[18, 19, 20], [24, 25, 26]]],
[[[27, 28, 29], [33, 34, 35]], [[45, 46, 47], [51, 52, 53]]],
]
]
expected_irregular_max_pool_output = [
[
[[[9, 10, 11], [12, 13, 14], [15, 16, 17]]],
[[[36, 37, 38], [39, 40, 41], [42, 43, 44]]],
]
]
self.assertTrue(
np.array_equal(
nn.MaxPool3d(kernel_size=2, stride=1, padding=0)(x),
expected_max_pool_output_no_padding_stride_1,
)
)
self.assertTrue(
np.array_equal(
nn.MaxPool3d(kernel_size=2, stride=2, padding=0)(x),
expected_max_pool_output_no_padding_stride_2,
)
)
self.assertTrue(
np.array_equal(
nn.MaxPool3d(kernel_size=2, stride=2, padding=1)(x),
expected_max_pool_output_padding_1,
)
)
self.assertTrue(
np.array_equal(
nn.MaxPool3d(kernel_size=(1, 2, 1), stride=(1, 2, 1))(x),
expected_irregular_max_pool_output,
)
)
self.assertEqual(
str(nn.MaxPool3d(kernel_size=3, stride=3, padding=2)),
"MaxPool3d(kernel_size=(3, 3, 3), stride=(3, 3, 3), padding=(2, 2, 2))",
)
expected_avg_pool_output_no_padding_stride_1 = [[[[[19.5, 20.5, 21.5],
[22.5, 23.5, 24.5]],
[[28.5, 29.5, 30.5],
[31.5, 32.5, 33.5]]]]
]
expected_avg_pool_output_no_padding_stride_2 = [[[[[19.5, 20.5, 21.5]]]]]
expected_avg_pool_output_padding_1 = [
[[[[0, 0.125, 0.25],
[1.125, 1.375, 1.625]],
[[3.375, 3.625, 3.875],
[9, 9.5, 10]]],
[[[3.375, 3.5, 3.625],
[7.875, 8.125, 8.375]],
[[10.125, 10.375, 10.625],
[22.5, 23, 23.5]]]]
]
expected_irregular_avg_pool_output = [[[[[4.5, 5.5, 6.5],
[7.5, 8.5, 9.5],
[10.5, 11.5, 12.5]]],
[[[31.5, 32.5, 33.5],
[34.5, 35.5, 36.5],
[37.5, 38.5, 39.5]]]]
]
self.assertTrue(
np.array_equal(
nn.AvgPool3d(kernel_size=2, stride=1, padding=0)(x),
expected_avg_pool_output_no_padding_stride_1,
)
)
self.assertTrue(
np.array_equal(
nn.AvgPool3d(kernel_size=2, stride=2, padding=0)(x),
expected_avg_pool_output_no_padding_stride_2,
)
)
self.assertTrue(
np.array_equal(
nn.AvgPool3d(kernel_size=2, stride=2, padding=1)(x),
expected_avg_pool_output_padding_1,
)
)
self.assertTrue(
np.array_equal(
nn.AvgPool3d(kernel_size=(1, 2, 1), stride=(1, 2, 1))(x),
expected_irregular_avg_pool_output,
)
)
self.assertEqual(
str(nn.AvgPool3d(kernel_size=3, stride=3, padding=2)),
"AvgPool3d(kernel_size=(3, 3, 3), stride=(3, 3, 3), padding=(2, 2, 2))",
)
def test_set_dtype(self):
def assert_dtype(layer, dtype):