mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
Adds 3D pooling (#1526)
This commit is contained in:
@@ -1589,6 +1589,123 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
str(nn.AvgPool2d(kernel_size=(1, 2), stride=2, padding=(1, 2))),
|
||||
"AvgPool2d(kernel_size=(1, 2), stride=(2, 2), padding=(1, 2))",
|
||||
)
|
||||
# Test 3d pooling
|
||||
x = mx.array(
|
||||
[
|
||||
[
|
||||
[
|
||||
[[0, 1, 2], [3, 4, 5], [6, 7, 8]],
|
||||
[[9, 10, 11], [12, 13, 14], [15, 16, 17]],
|
||||
[[18, 19, 20], [21, 22, 23], [24, 25, 26]],
|
||||
],
|
||||
[
|
||||
[[27, 28, 29], [30, 31, 32], [33, 34, 35]],
|
||||
[[36, 37, 38], [39, 40, 41], [42, 43, 44]],
|
||||
[[45, 46, 47], [48, 49, 50], [51, 52, 53]],
|
||||
],
|
||||
]
|
||||
]
|
||||
)
|
||||
expected_max_pool_output_no_padding_stride_1 = [
|
||||
[[[[39, 40, 41], [42, 43, 44]], [[48, 49, 50], [51, 52, 53]]]]
|
||||
]
|
||||
|
||||
expected_max_pool_output_no_padding_stride_2 = [[[[[39, 40, 41]]]]]
|
||||
expected_max_pool_output_padding_1 = [
|
||||
[
|
||||
[[[0, 1, 2], [6, 7, 8]], [[18, 19, 20], [24, 25, 26]]],
|
||||
[[[27, 28, 29], [33, 34, 35]], [[45, 46, 47], [51, 52, 53]]],
|
||||
]
|
||||
]
|
||||
expected_irregular_max_pool_output = [
|
||||
[
|
||||
[[[9, 10, 11], [12, 13, 14], [15, 16, 17]]],
|
||||
[[[36, 37, 38], [39, 40, 41], [42, 43, 44]]],
|
||||
]
|
||||
]
|
||||
|
||||
self.assertTrue(
|
||||
np.array_equal(
|
||||
nn.MaxPool3d(kernel_size=2, stride=1, padding=0)(x),
|
||||
expected_max_pool_output_no_padding_stride_1,
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
np.array_equal(
|
||||
nn.MaxPool3d(kernel_size=2, stride=2, padding=0)(x),
|
||||
expected_max_pool_output_no_padding_stride_2,
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
np.array_equal(
|
||||
nn.MaxPool3d(kernel_size=2, stride=2, padding=1)(x),
|
||||
expected_max_pool_output_padding_1,
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
np.array_equal(
|
||||
nn.MaxPool3d(kernel_size=(1, 2, 1), stride=(1, 2, 1))(x),
|
||||
expected_irregular_max_pool_output,
|
||||
)
|
||||
)
|
||||
self.assertEqual(
|
||||
str(nn.MaxPool3d(kernel_size=3, stride=3, padding=2)),
|
||||
"MaxPool3d(kernel_size=(3, 3, 3), stride=(3, 3, 3), padding=(2, 2, 2))",
|
||||
)
|
||||
|
||||
expected_avg_pool_output_no_padding_stride_1 = [[[[[19.5, 20.5, 21.5],
|
||||
[22.5, 23.5, 24.5]],
|
||||
[[28.5, 29.5, 30.5],
|
||||
[31.5, 32.5, 33.5]]]]
|
||||
]
|
||||
|
||||
expected_avg_pool_output_no_padding_stride_2 = [[[[[19.5, 20.5, 21.5]]]]]
|
||||
expected_avg_pool_output_padding_1 = [
|
||||
[[[[0, 0.125, 0.25],
|
||||
[1.125, 1.375, 1.625]],
|
||||
[[3.375, 3.625, 3.875],
|
||||
[9, 9.5, 10]]],
|
||||
[[[3.375, 3.5, 3.625],
|
||||
[7.875, 8.125, 8.375]],
|
||||
[[10.125, 10.375, 10.625],
|
||||
[22.5, 23, 23.5]]]]
|
||||
]
|
||||
expected_irregular_avg_pool_output = [[[[[4.5, 5.5, 6.5],
|
||||
[7.5, 8.5, 9.5],
|
||||
[10.5, 11.5, 12.5]]],
|
||||
[[[31.5, 32.5, 33.5],
|
||||
[34.5, 35.5, 36.5],
|
||||
[37.5, 38.5, 39.5]]]]
|
||||
]
|
||||
|
||||
self.assertTrue(
|
||||
np.array_equal(
|
||||
nn.AvgPool3d(kernel_size=2, stride=1, padding=0)(x),
|
||||
expected_avg_pool_output_no_padding_stride_1,
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
np.array_equal(
|
||||
nn.AvgPool3d(kernel_size=2, stride=2, padding=0)(x),
|
||||
expected_avg_pool_output_no_padding_stride_2,
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
np.array_equal(
|
||||
nn.AvgPool3d(kernel_size=2, stride=2, padding=1)(x),
|
||||
expected_avg_pool_output_padding_1,
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
np.array_equal(
|
||||
nn.AvgPool3d(kernel_size=(1, 2, 1), stride=(1, 2, 1))(x),
|
||||
expected_irregular_avg_pool_output,
|
||||
)
|
||||
)
|
||||
self.assertEqual(
|
||||
str(nn.AvgPool3d(kernel_size=3, stride=3, padding=2)),
|
||||
"AvgPool3d(kernel_size=(3, 3, 3), stride=(3, 3, 3), padding=(2, 2, 2))",
|
||||
)
|
||||
|
||||
def test_set_dtype(self):
|
||||
def assert_dtype(layer, dtype):
|
||||
|
Reference in New Issue
Block a user