From 7ec3cadf98ac075f9d6b8b0e7ce7ae4d018b1a83 Mon Sep 17 00:00:00 2001
From: m0saan
Date: Tue, 19 Dec 2023 10:06:35 +0100
Subject: [PATCH] added test cases for batch norm on 3D input & refactored code ^^

---
 NOTE(review): reviewer annotations only. Text between the "---" separator
 and the first "diff --git" line is neither commit message nor hunk content.

 What the hunks below do:
 * normalization.py, first hunk: the running-stat update no longer calls
   .squeeze() on means/var, so running_mean/running_var take whatever shape
   "(1 - momentum) * running + momentum * means" broadcasts to.
 * normalization.py, both hunks: the 2D/3D ndim validation and the
   weight/bias expand_dims move out of __call__ into the new
   _check_and_expand_dims(). The new method also expands
   running_mean/running_var when track_running_stats is set. The added
   ".ndim != x.ndim" guards keep the arrays from being expanded again on
   every 3D call, which the removed code did unconditionally.
 * test_nn.py: adds a (2, 4, 3) input case for BatchNorm1d and compares the
   output and the running stats against literals; the running-stat literals
   have shape (1, 4, 1).

 Open questions for the author:
 * _check_and_expand_dims() stores the expanded arrays back on self and
   nothing in these hunks restores the original shape. Presumably a 2D batch
   passed after a 3D batch then meets 3D-shaped weight/bias/running stats --
   please confirm and add a test for mixed 2D/3D use.
 * self.weight.ndim is read whenever x.ndim == 3. The test constructs the
   layer with affine=True, which suggests an affine=False mode exists; if the
   layer then has no weight/bias, does this branch still work? TODO confirm.
 * With .squeeze() gone, the stored running stats change shape depending on
   the input rank. Confirm this is intended for saved/loaded parameters.

 python/mlx/nn/layers/normalization.py | 31 +++++++++++++++++-------
 python/tests/test_nn.py               | 34 +++++++++++++++++++++++++++
 2 files changed, 57 insertions(+), 8 deletions(-)

diff --git a/python/mlx/nn/layers/normalization.py b/python/mlx/nn/layers/normalization.py
index 4f7ad7b60..541912ce3 100644
--- a/python/mlx/nn/layers/normalization.py
+++ b/python/mlx/nn/layers/normalization.py
@@ -253,12 +253,32 @@ class BatchNorm1d(Module):
         if self.track_running_stats and self.training:
             self.running_mean = (
                 1 - self.momentum
-            ) * self.running_mean + self.momentum * means.squeeze()
+            ) * self.running_mean + self.momentum * means
             self.running_var = (
                 1 - self.momentum
-            ) * self.running_var + self.momentum * var.squeeze()
+            ) * self.running_var + self.momentum * var
         return means, var
 
+    def _check_and_expand_dims(self, x: mx.array):
+        """
+        Check if the input is a 2D or 3D tensor and expand the weight, bias, running mean, and running variance accordingly.
+
+        Args:
+            x (mx.array): Input tensor.
+        """
+
+        if x.ndim != 2 and x.ndim != 3:
+            raise ValueError(f"expected 2D or 3D input (got {x.ndim}D input)")
+
+        if x.ndim == 3 and self.weight.ndim != x.ndim:
+            self.weight = mx.expand_dims(self.weight, [0, 2])
+            self.bias = mx.expand_dims(self.bias, [0, 2])
+
+        if self.track_running_stats:
+            if x.ndim == 3 and self.running_mean.ndim != x.ndim:
+                self.running_mean = mx.expand_dims(self.running_mean, [0, 2])
+                self.running_var = mx.expand_dims(self.running_var, [0, 2])
+
     def __call__(self, x: mx.array):
         """
         Forward pass of BatchNorm1d.
@@ -270,12 +290,7 @@ class BatchNorm1d(Module):
             mx.array: Output tensor.
         """
 
-        if x.ndim != 2 and x.ndim != 3:
-            raise ValueError(f"expected 2D or 3D input (got {x.ndim}D input)")
-
-        if x.ndim == 3:
-            self.weight = mx.expand_dims(self.weight, [0, 2])
-            self.bias = mx.expand_dims(self.bias, [0, 2])
+        self._check_and_expand_dims(x)
 
         if self.training or not self.track_running_stats:
             means, var = self._calc_stats(x)
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index 63805ab44..82850ca72 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -376,6 +376,40 @@ class TestNN(mlx_tests.MLXTestCase):
         self.assertTrue(x.shape == y.shape)
         self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
 
+        # test with 3D input
+        mx.random.seed(42)
+        x = mx.random.normal((2, 4, 3), dtype=mx.float32)
+
+        # Batch norm
+        bn = nn.BatchNorm1d(num_features=4, affine=True)
+        self.assertTrue(mx.allclose(bn.running_mean, mx.zeros_like(bn.running_mean)))
+        self.assertTrue(mx.allclose(bn.running_var, mx.ones_like(bn.running_var)))
+        y = bn(x)
+        self.assertTrue(x.shape == y.shape)
+        expected_y = mx.array(
+            [
+                [
+                    [-0.285056, -0.657241, 0.584881],
+                    [1.079424, 0.795527, 0.163417],
+                    [-0.351929, 0.669030, 1.713490],
+                    [-0.679080, -1.467115, 1.077580],
+                ],
+                [
+                    [-0.091968, -1.362007, 1.811391],
+                    [-1.654407, -1.017945, 0.633983],
+                    [-1.309168, 0.148356, -0.869779],
+                    [-0.742132, 1.037774, 0.772974],
+                ],
+            ]
+        )
+        self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
+        expected_mean = mx.array(
+            [[[0.0362097], [0.0360611], [0.0166926], [-0.0111884]]]
+        )
+        expected_var = mx.array([[[1.07218], [0.992639], [1.01724], [1.16217]]])
+        self.assertTrue(np.allclose(bn.running_mean, expected_mean, atol=1e-5))
+        self.assertTrue(np.allclose(bn.running_var, expected_var, atol=1e-5))
+
     def test_conv1d(self):
         N = 5
         L = 12