From e9ca65c939dd4ac0d93dbbb27d88e48c66710ea7 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Tue, 9 Jan 2024 11:54:51 -0800
Subject: [PATCH] Fix BN stats to not expand shape (#409)

* fix BN stats to not expand shape

* nit
---
 python/mlx/nn/layers/normalization.py |  4 ++--
 python/tests/test_nn.py               | 29 +++++++++++++++--------------
 2 files changed, 17 insertions(+), 16 deletions(-)

diff --git a/python/mlx/nn/layers/normalization.py b/python/mlx/nn/layers/normalization.py
index b2b60ccba..42107d658 100644
--- a/python/mlx/nn/layers/normalization.py
+++ b/python/mlx/nn/layers/normalization.py
@@ -333,8 +333,8 @@ class BatchNorm(Module):
         """
         reduction_axes = tuple(range(0, x.ndim - 1))
 
-        mean = mx.mean(x, axis=reduction_axes, keepdims=True)
-        var = mx.var(x, axis=reduction_axes, keepdims=True)
+        mean = mx.mean(x, axis=reduction_axes)
+        var = mx.var(x, axis=reduction_axes)
 
         return mean, var
 
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index 04d849a42..9ae8a2cd1 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -517,35 +517,36 @@ class TestLayers(mlx_tests.MLXTestCase):
         batch_norm = nn.BatchNorm(num_features)
 
         batch_norm.train()
-        running_mean = np.array(batch_norm.running_mean)
-        running_var = np.array(batch_norm.running_var)
+        running_mean = batch_norm.running_mean
+        running_var = batch_norm.running_var
 
         data = mx.random.normal((batch_size, num_features))
 
         normalized_data = batch_norm(data)
-        np_data = np.array(data)
-        means = np.mean(np_data, axis=0)
-        variances = np.var(np_data, axis=0)
+        means = mx.mean(data, axis=0)
+        variances = mx.var(data, axis=0)
         running_mean = (1 - momentum) * running_mean + momentum * means
         running_var = (1 - momentum) * running_var + momentum * variances
-        self.assertTrue(np.allclose(batch_norm.running_mean, running_mean, atol=1e-5))
-        self.assertTrue(np.allclose(batch_norm.running_var, running_var, atol=1e-5))
+        self.assertTrue(mx.allclose(batch_norm.running_mean, running_mean, atol=1e-5))
+        self.assertTrue(mx.allclose(batch_norm.running_var, running_var, atol=1e-5))
 
         batch_norm = nn.BatchNorm(num_features)
         batch_norm.train()
-        running_mean = np.array(batch_norm.running_mean)
-        running_var = np.array(batch_norm.running_var)
+        running_mean = batch_norm.running_mean
+        running_var = batch_norm.running_var
 
         data = mx.random.normal((batch_size, h, w, num_features))
 
         normalized_data = batch_norm(data)
-        np_data = np.array(data)
-        means = np.mean(np_data, axis=(0, 1, 2))
-        variances = np.var(np_data, axis=(0, 1, 2))
+        means = mx.mean(data, axis=(0, 1, 2))
+        variances = mx.var(data, axis=(0, 1, 2))
         running_mean = (1 - momentum) * running_mean + momentum * means
         running_var = (1 - momentum) * running_var + momentum * variances
-        self.assertTrue(np.allclose(batch_norm.running_mean, running_mean, atol=1e-5))
-        self.assertTrue(np.allclose(batch_norm.running_var, running_var, atol=1e-5))
+        self.assertTrue(mx.allclose(batch_norm.running_mean, running_mean, atol=1e-5))
+        self.assertTrue(mx.allclose(batch_norm.running_var, running_var, atol=1e-5))
+
+        self.assertEqual(batch_norm.running_mean.shape, running_mean.shape)
+        self.assertEqual(batch_norm.running_var.shape, running_var.shape)
 
     def test_conv1d(self):
         N = 5
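
Reviewer note (not part of the patch): below is a minimal sketch of the
behavior this change produces, assuming only the API surface already visible
in the diff (nn.BatchNorm, Module.train, mx.random.normal, mx.mean); the
batch and feature sizes are illustrative only.

    import mlx.core as mx
    import mlx.nn as nn

    num_features = 4

    # In training mode, BatchNorm updates its running statistics on each call.
    batch_norm = nn.BatchNorm(num_features)
    batch_norm.train()

    # 4D input: stats reduce over every axis except the trailing feature axis.
    x = mx.random.normal((2, 3, 3, num_features))
    y = batch_norm(x)

    # With keepdims dropped, the batch statistics come back 1-D, so the
    # stored running stats match a plain per-feature reduction rather than
    # an expanded (1, 1, 1, num_features) shape.
    expected = mx.mean(x, axis=(0, 1, 2))
    assert batch_norm.running_mean.shape == expected.shape
    assert batch_norm.running_var.shape == expected.shape

Dropping keepdims is safe for the normalization itself because the feature
axis is last: a (num_features,) mean and variance still broadcast correctly
against the input.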