Formatting fixes (#1606)

This commit is contained in:
Angelos Katharopoulos 2024-11-20 15:30:36 -08:00 committed by GitHub
parent cb431dfc9f
commit d8c824c594
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 25 additions and 20 deletions

View File

@ -450,6 +450,7 @@ class AvgPool3d(_Pool3d):
>>> pool = nn.AvgPool3d(kernel_size=2, stride=2) >>> pool = nn.AvgPool3d(kernel_size=2, stride=2)
>>> pool(x) >>> pool(x)
""" """
def __init__( def __init__(
self, self,
kernel_size: Union[int, Tuple[int, int, int]], kernel_size: Union[int, Tuple[int, int, int]],

View File

@ -1653,29 +1653,33 @@ class TestLayers(mlx_tests.MLXTestCase):
"MaxPool3d(kernel_size=(3, 3, 3), stride=(3, 3, 3), padding=(2, 2, 2))", "MaxPool3d(kernel_size=(3, 3, 3), stride=(3, 3, 3), padding=(2, 2, 2))",
) )
expected_avg_pool_output_no_padding_stride_1 = [[[[[19.5, 20.5, 21.5], expected_avg_pool_output_no_padding_stride_1 = [
[22.5, 23.5, 24.5]], [
[[28.5, 29.5, 30.5], [
[31.5, 32.5, 33.5]]]] [[19.5, 20.5, 21.5], [22.5, 23.5, 24.5]],
[[28.5, 29.5, 30.5], [31.5, 32.5, 33.5]],
]
]
] ]
expected_avg_pool_output_no_padding_stride_2 = [[[[[19.5, 20.5, 21.5]]]]] expected_avg_pool_output_no_padding_stride_2 = [[[[[19.5, 20.5, 21.5]]]]]
expected_avg_pool_output_padding_1 = [ expected_avg_pool_output_padding_1 = [
[[[[0, 0.125, 0.25], [
[1.125, 1.375, 1.625]], [
[[3.375, 3.625, 3.875], [[0, 0.125, 0.25], [1.125, 1.375, 1.625]],
[9, 9.5, 10]]], [[3.375, 3.625, 3.875], [9, 9.5, 10]],
[[[3.375, 3.5, 3.625], ],
[7.875, 8.125, 8.375]], [
[[10.125, 10.375, 10.625], [[3.375, 3.5, 3.625], [7.875, 8.125, 8.375]],
[22.5, 23, 23.5]]]] [[10.125, 10.375, 10.625], [22.5, 23, 23.5]],
],
]
]
expected_irregular_avg_pool_output = [
[
[[[4.5, 5.5, 6.5], [7.5, 8.5, 9.5], [10.5, 11.5, 12.5]]],
[[[31.5, 32.5, 33.5], [34.5, 35.5, 36.5], [37.5, 38.5, 39.5]]],
] ]
expected_irregular_avg_pool_output = [[[[[4.5, 5.5, 6.5],
[7.5, 8.5, 9.5],
[10.5, 11.5, 12.5]]],
[[[31.5, 32.5, 33.5],
[34.5, 35.5, 36.5],
[37.5, 38.5, 39.5]]]]
] ]
self.assertTrue( self.assertTrue(