mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Formatting fixes (#1606)
This commit is contained in:
parent
cb431dfc9f
commit
d8c824c594
@ -450,6 +450,7 @@ class AvgPool3d(_Pool3d):
|
|||||||
>>> pool = nn.AvgPool3d(kernel_size=2, stride=2)
|
>>> pool = nn.AvgPool3d(kernel_size=2, stride=2)
|
||||||
>>> pool(x)
|
>>> pool(x)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
kernel_size: Union[int, Tuple[int, int, int]],
|
kernel_size: Union[int, Tuple[int, int, int]],
|
||||||
|
@ -1653,30 +1653,34 @@ class TestLayers(mlx_tests.MLXTestCase):
|
|||||||
"MaxPool3d(kernel_size=(3, 3, 3), stride=(3, 3, 3), padding=(2, 2, 2))",
|
"MaxPool3d(kernel_size=(3, 3, 3), stride=(3, 3, 3), padding=(2, 2, 2))",
|
||||||
)
|
)
|
||||||
|
|
||||||
expected_avg_pool_output_no_padding_stride_1 = [[[[[19.5, 20.5, 21.5],
|
expected_avg_pool_output_no_padding_stride_1 = [
|
||||||
[22.5, 23.5, 24.5]],
|
[
|
||||||
[[28.5, 29.5, 30.5],
|
[
|
||||||
[31.5, 32.5, 33.5]]]]
|
[[19.5, 20.5, 21.5], [22.5, 23.5, 24.5]],
|
||||||
]
|
[[28.5, 29.5, 30.5], [31.5, 32.5, 33.5]],
|
||||||
|
]
|
||||||
|
]
|
||||||
|
]
|
||||||
|
|
||||||
expected_avg_pool_output_no_padding_stride_2 = [[[[[19.5, 20.5, 21.5]]]]]
|
expected_avg_pool_output_no_padding_stride_2 = [[[[[19.5, 20.5, 21.5]]]]]
|
||||||
expected_avg_pool_output_padding_1 = [
|
expected_avg_pool_output_padding_1 = [
|
||||||
[[[[0, 0.125, 0.25],
|
[
|
||||||
[1.125, 1.375, 1.625]],
|
[
|
||||||
[[3.375, 3.625, 3.875],
|
[[0, 0.125, 0.25], [1.125, 1.375, 1.625]],
|
||||||
[9, 9.5, 10]]],
|
[[3.375, 3.625, 3.875], [9, 9.5, 10]],
|
||||||
[[[3.375, 3.5, 3.625],
|
],
|
||||||
[7.875, 8.125, 8.375]],
|
[
|
||||||
[[10.125, 10.375, 10.625],
|
[[3.375, 3.5, 3.625], [7.875, 8.125, 8.375]],
|
||||||
[22.5, 23, 23.5]]]]
|
[[10.125, 10.375, 10.625], [22.5, 23, 23.5]],
|
||||||
|
],
|
||||||
]
|
]
|
||||||
expected_irregular_avg_pool_output = [[[[[4.5, 5.5, 6.5],
|
]
|
||||||
[7.5, 8.5, 9.5],
|
expected_irregular_avg_pool_output = [
|
||||||
[10.5, 11.5, 12.5]]],
|
[
|
||||||
[[[31.5, 32.5, 33.5],
|
[[[4.5, 5.5, 6.5], [7.5, 8.5, 9.5], [10.5, 11.5, 12.5]]],
|
||||||
[34.5, 35.5, 36.5],
|
[[[31.5, 32.5, 33.5], [34.5, 35.5, 36.5], [37.5, 38.5, 39.5]]],
|
||||||
[37.5, 38.5, 39.5]]]]
|
]
|
||||||
]
|
]
|
||||||
|
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
np.array_equal(
|
np.array_equal(
|
||||||
|
Loading…
Reference in New Issue
Block a user