From d8c824c5943bb7cb06642e79ada476655571f0aa Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Wed, 20 Nov 2024 15:30:36 -0800 Subject: [PATCH] Formatting fixes (#1606) --- python/mlx/nn/layers/pooling.py | 1 + python/tests/test_nn.py | 44 ++++++++++++++++++--------------- 2 files changed, 25 insertions(+), 20 deletions(-) diff --git a/python/mlx/nn/layers/pooling.py b/python/mlx/nn/layers/pooling.py index dd5c67696..0610a8a71 100644 --- a/python/mlx/nn/layers/pooling.py +++ b/python/mlx/nn/layers/pooling.py @@ -450,6 +450,7 @@ class AvgPool3d(_Pool3d): >>> pool = nn.AvgPool3d(kernel_size=2, stride=2) >>> pool(x) """ + def __init__( self, kernel_size: Union[int, Tuple[int, int, int]], diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index e89fd5252..44cbdaf1e 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -1653,30 +1653,34 @@ class TestLayers(mlx_tests.MLXTestCase): "MaxPool3d(kernel_size=(3, 3, 3), stride=(3, 3, 3), padding=(2, 2, 2))", ) - expected_avg_pool_output_no_padding_stride_1 = [[[[[19.5, 20.5, 21.5], - [22.5, 23.5, 24.5]], - [[28.5, 29.5, 30.5], - [31.5, 32.5, 33.5]]]] - ] + expected_avg_pool_output_no_padding_stride_1 = [ + [ + [ + [[19.5, 20.5, 21.5], [22.5, 23.5, 24.5]], + [[28.5, 29.5, 30.5], [31.5, 32.5, 33.5]], + ] + ] + ] expected_avg_pool_output_no_padding_stride_2 = [[[[[19.5, 20.5, 21.5]]]]] expected_avg_pool_output_padding_1 = [ - [[[[0, 0.125, 0.25], - [1.125, 1.375, 1.625]], - [[3.375, 3.625, 3.875], - [9, 9.5, 10]]], - [[[3.375, 3.5, 3.625], - [7.875, 8.125, 8.375]], - [[10.125, 10.375, 10.625], - [22.5, 23, 23.5]]]] + [ + [ + [[0, 0.125, 0.25], [1.125, 1.375, 1.625]], + [[3.375, 3.625, 3.875], [9, 9.5, 10]], + ], + [ + [[3.375, 3.5, 3.625], [7.875, 8.125, 8.375]], + [[10.125, 10.375, 10.625], [22.5, 23, 23.5]], + ], ] - expected_irregular_avg_pool_output = [[[[[4.5, 5.5, 6.5], - [7.5, 8.5, 9.5], - [10.5, 11.5, 12.5]]], - [[[31.5, 32.5, 33.5], - [34.5, 35.5, 36.5], - [37.5, 38.5, 39.5]]]] - ] + ] + expected_irregular_avg_pool_output = [ + [ + [[[4.5, 5.5, 6.5], [7.5, 8.5, 9.5], [10.5, 11.5, 12.5]]], + [[[31.5, 32.5, 33.5], [34.5, 35.5, 36.5], [37.5, 38.5, 39.5]]], + ] + ] self.assertTrue( np.array_equal(