Formatting fixes (#1606)

This commit is contained in:
Angelos Katharopoulos 2024-11-20 15:30:36 -08:00 committed by GitHub
parent cb431dfc9f
commit d8c824c594
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 25 additions and 20 deletions

View File

@ -450,6 +450,7 @@ class AvgPool3d(_Pool3d):
>>> pool = nn.AvgPool3d(kernel_size=2, stride=2)
>>> pool(x)
"""
def __init__(
self,
kernel_size: Union[int, Tuple[int, int, int]],

View File

@ -1653,29 +1653,33 @@ class TestLayers(mlx_tests.MLXTestCase):
"MaxPool3d(kernel_size=(3, 3, 3), stride=(3, 3, 3), padding=(2, 2, 2))",
)
expected_avg_pool_output_no_padding_stride_1 = [[[[[19.5, 20.5, 21.5],
[22.5, 23.5, 24.5]],
[[28.5, 29.5, 30.5],
[31.5, 32.5, 33.5]]]]
expected_avg_pool_output_no_padding_stride_1 = [
[
[
[[19.5, 20.5, 21.5], [22.5, 23.5, 24.5]],
[[28.5, 29.5, 30.5], [31.5, 32.5, 33.5]],
]
]
]
expected_avg_pool_output_no_padding_stride_2 = [[[[[19.5, 20.5, 21.5]]]]]
expected_avg_pool_output_padding_1 = [
[[[[0, 0.125, 0.25],
[1.125, 1.375, 1.625]],
[[3.375, 3.625, 3.875],
[9, 9.5, 10]]],
[[[3.375, 3.5, 3.625],
[7.875, 8.125, 8.375]],
[[10.125, 10.375, 10.625],
[22.5, 23, 23.5]]]]
[
[
[[0, 0.125, 0.25], [1.125, 1.375, 1.625]],
[[3.375, 3.625, 3.875], [9, 9.5, 10]],
],
[
[[3.375, 3.5, 3.625], [7.875, 8.125, 8.375]],
[[10.125, 10.375, 10.625], [22.5, 23, 23.5]],
],
]
]
expected_irregular_avg_pool_output = [
[
[[[4.5, 5.5, 6.5], [7.5, 8.5, 9.5], [10.5, 11.5, 12.5]]],
[[[31.5, 32.5, 33.5], [34.5, 35.5, 36.5], [37.5, 38.5, 39.5]]],
]
expected_irregular_avg_pool_output = [[[[[4.5, 5.5, 6.5],
[7.5, 8.5, 9.5],
[10.5, 11.5, 12.5]]],
[[[31.5, 32.5, 33.5],
[34.5, 35.5, 36.5],
[37.5, 38.5, 39.5]]]]
]
self.assertTrue(