From d29770eeaa7605a6e42476fce140ce2d9c22733c Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Thu, 28 Dec 2023 14:31:10 -0800
Subject: [PATCH 1/3] Update batchnorm to have the running stats in parameters
 (#305)

---
 python/mlx/nn/layers/normalization.py | 50 +++++++++++++++-----------
 python/tests/test_nn.py               | 52 ++++++++++++++++++---------
 2 files changed, 65 insertions(+), 37 deletions(-)

diff --git a/python/mlx/nn/layers/normalization.py b/python/mlx/nn/layers/normalization.py
index 9cd578fb2..d5e1a1c6e 100644
--- a/python/mlx/nn/layers/normalization.py
+++ b/python/mlx/nn/layers/normalization.py
@@ -243,8 +243,15 @@ class BatchNorm(Module):
             self.bias = mx.zeros((num_features,))
 
         if self.track_running_stats:
-            self._running_mean = mx.zeros((num_features,))
-            self._running_var = mx.ones((num_features,))
+            self.running_mean = mx.zeros((num_features,))
+            self.running_var = mx.ones((num_features,))
+            self.freeze(keys=["running_mean", "running_var"], recurse=False)
+
+    def unfreeze(self, *args, **kwargs):
+        """Wrap unfreeze to make sure that running_mean and running_var are
+        always frozen parameters."""
+        super().unfreeze(*args, **kwargs)
+        self.freeze(keys=["running_mean", "running_var"], recurse=False)
 
     def _extra_repr(self):
         return (
@@ -255,46 +262,47 @@ class BatchNorm(Module):
 
     def _calc_stats(self, x: mx.array) -> Tuple[mx.array, mx.array]:
         """
-        Calculate the mean and variance of the input tensor.
+        Calculate the mean and variance of the input tensor across the batch
+        and spatial dimensions.
 
         Args:
-            x (mx.array): Input tensor.
+            x (array): Input tensor.
 
         Returns:
             tuple: Tuple containing mean and variance.
         """
         reduction_axes = tuple(range(0, x.ndim - 1))
-        means = mx.mean(x, axis=reduction_axes, keepdims=True)
+
+        mean = mx.mean(x, axis=reduction_axes, keepdims=True)
         var = mx.var(x, axis=reduction_axes, keepdims=True)
 
-        if self.track_running_stats and self.training:
-            self._running_mean = (
-                1 - self.momentum
-            ) * self._running_mean + self.momentum * means
-            self._running_var = (
-                1 - self.momentum
-            ) * self._running_var + self.momentum * var
-        return means, var
+        return mean, var
 
     def __call__(self, x: mx.array) -> mx.array:
         """
         Forward pass of BatchNorm.
 
         Args:
-            x (mx.array): Input tensor.
+            x (array): Input tensor.
 
         Returns:
-            mx.array: Output tensor.
+            array: Normalized output tensor.
         """
-
         if x.ndim < 2 or x.ndim > 4:
             raise ValueError(
                 f"Expected input tensor to have 2, 3 or 4 dimensions, but got {x.ndim}"
             )
 
-        if self.training or not self.track_running_stats:
-            means, var = self._calc_stats(x)
-        else:
-            means, var = self._running_mean, self._running_var
-        x = (x - means) * mx.rsqrt(var + self.eps)
+        # Calculate the mean and variance used to normalize the input x. If we
+        # are in training mode, update the running stats if needed.
+        mean, var = self._calc_stats(x)
+        if self.training and self.track_running_stats:
+            mu = self.momentum
+            self.running_mean = (1 - mu) * self.running_mean + mu * mean
+            self.running_var = (1 - mu) * self.running_var + mu * var
+        elif self.track_running_stats:
+            mean = self.running_mean
+            var = self.running_var
+
+        x = (x - mean) * mx.rsqrt(var + self.eps)
         return (self.weight * x + self.bias) if "weight" in self else x
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index 2cfac4475..8210b4a48 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -326,8 +326,8 @@ class TestNN(mlx_tests.MLXTestCase):
 
         # Batch norm
         bn = nn.BatchNorm(num_features=4, affine=True)
-        self.assertTrue(mx.allclose(bn._running_mean, mx.zeros_like(bn._running_mean)))
-        self.assertTrue(mx.allclose(bn._running_var, mx.ones_like(bn._running_var)))
+        self.assertTrue(mx.allclose(bn.running_mean, mx.zeros_like(bn.running_mean)))
+        self.assertTrue(mx.allclose(bn.running_var, mx.ones_like(bn.running_var)))
         y = bn(x)
         expected_y = mx.array(
             [
@@ -342,8 +342,8 @@ class TestNN(mlx_tests.MLXTestCase):
         expected_var = mx.array([0.928435, 1.00455, 1.04117, 0.94258])
         self.assertTrue(x.shape == y.shape)
         self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))
-        self.assertTrue(mx.allclose(bn._running_mean, expected_mean, atol=1e-5))
-        self.assertTrue(mx.allclose(bn._running_var, expected_var, atol=1e-5))
+        self.assertTrue(mx.allclose(bn.running_mean, expected_mean, atol=1e-5))
+        self.assertTrue(mx.allclose(bn.running_var, expected_var, atol=1e-5))
 
         # test eval mode
         bn.eval()
@@ -385,8 +385,8 @@ class TestNN(mlx_tests.MLXTestCase):
 
         # Batch norm
         bn = nn.BatchNorm(num_features=C, affine=True)
-        self.assertTrue(mx.allclose(bn._running_mean, mx.zeros_like(bn._running_mean)))
-        self.assertTrue(mx.allclose(bn._running_var, mx.ones_like(bn._running_var)))
+        self.assertTrue(mx.allclose(bn.running_mean, mx.zeros_like(bn.running_mean)))
+        self.assertTrue(mx.allclose(bn.running_var, mx.ones_like(bn.running_var)))
         y = bn(x)
         self.assertTrue(x.shape == y.shape)
         expected_y = mx.array(
@@ -410,13 +410,33 @@ class TestNN(mlx_tests.MLXTestCase):
             [[[0.00207845, -5.3259e-05, 0.04755, -0.0697296, 0.0236228]]]
         )
         expected_var = mx.array([[[0.968415, 1.05322, 0.96913, 0.932305, 0.967224]]])
-        self.assertTrue(mx.allclose(bn._running_mean, expected_mean, atol=1e-5))
-        self.assertTrue(mx.allclose(bn._running_var, expected_var, atol=1e-5))
+        self.assertTrue(mx.allclose(bn.running_mean, expected_mean, atol=1e-5))
+        self.assertTrue(mx.allclose(bn.running_var, expected_var, atol=1e-5))
 
         x = mx.random.normal((N, L, C, L, C), dtype=mx.float32)
         with self.assertRaises(ValueError):
             y = bn(x)
 
+        # Check that the running stats are in the param dictionary
+        bn_parameters = bn.parameters()
+        self.assertIn("running_mean", bn_parameters)
+        self.assertIn("running_var", bn_parameters)
+        self.assertIn("weight", bn_parameters)
+        self.assertIn("bias", bn_parameters)
+
+        bn_trainable = bn.trainable_parameters()
+        self.assertNotIn("running_mean", bn_trainable)
+        self.assertNotIn("running_var", bn_trainable)
+        self.assertIn("weight", bn_trainable)
+        self.assertIn("bias", bn_trainable)
+
+        bn.unfreeze()
+        bn_trainable = bn.trainable_parameters()
+        self.assertNotIn("running_mean", bn_trainable)
+        self.assertNotIn("running_var", bn_trainable)
+        self.assertIn("weight", bn_trainable)
+        self.assertIn("bias", bn_trainable)
+
     def test_batch_norm_stats(self):
         batch_size = 2
         num_features = 4
@@ -427,8 +447,8 @@ class TestNN(mlx_tests.MLXTestCase):
         h = 3
         w = 3
         momentum = 0.1
 
         batch_norm = nn.BatchNorm(num_features)
         batch_norm.train()
-        running_mean = np.array(batch_norm._running_mean)
-        running_var = np.array(batch_norm._running_var)
+        running_mean = np.array(batch_norm.running_mean)
+        running_var = np.array(batch_norm.running_var)
 
         data = mx.random.normal((batch_size, num_features))
 
@@ -438,14 +458,14 @@ class TestNN(mlx_tests.MLXTestCase):
         means = np.mean(np_data, axis=0)
         variances = np.var(np_data, axis=0)
         running_mean = (1 - momentum) * running_mean + momentum * means
         running_var = (1 - momentum) * running_var + momentum * variances
-        self.assertTrue(np.allclose(batch_norm._running_mean, running_mean, atol=1e-5))
-        self.assertTrue(np.allclose(batch_norm._running_var, running_var, atol=1e-5))
+        self.assertTrue(np.allclose(batch_norm.running_mean, running_mean, atol=1e-5))
+        self.assertTrue(np.allclose(batch_norm.running_var, running_var, atol=1e-5))
 
         batch_norm = nn.BatchNorm(num_features)
         batch_norm.train()
-        running_mean = np.array(batch_norm._running_mean)
-        running_var = np.array(batch_norm._running_var)
+        running_mean = np.array(batch_norm.running_mean)
+        running_var = np.array(batch_norm.running_var)
 
         data = mx.random.normal((batch_size, h, w, num_features))
         normalized_data = batch_norm(data)
@@ -454,8 +474,8 @@ class TestNN(mlx_tests.MLXTestCase):
         variances = np.var(np_data, axis=(0, 1, 2))
         running_mean = (1 - momentum) * running_mean + momentum * means
         running_var = (1 - momentum) * running_var + momentum * variances
-        self.assertTrue(np.allclose(batch_norm._running_mean, running_mean, atol=1e-5))
-        self.assertTrue(np.allclose(batch_norm._running_var, running_var, atol=1e-5))
+        self.assertTrue(np.allclose(batch_norm.running_mean, running_mean, atol=1e-5))
+        self.assertTrue(np.allclose(batch_norm.running_var, running_var, atol=1e-5))
 
     def test_conv1d(self):
         N = 5

From 473b6b43b428270a8cc38c38a803bdd57ea4c268 Mon Sep 17 00:00:00 2001
From: Chunyang Wen
Date: Fri, 29 Dec 2023 06:46:13 +0800
Subject: [PATCH 2/3] Use defaultdict (#307)

Co-authored-by: Chunyang Wen
---
 python/mlx/utils.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/python/mlx/utils.py b/python/mlx/utils.py
index 8cb8e90c8..daa387420 100644
--- a/python/mlx/utils.py
+++ b/python/mlx/utils.py
@@ -1,5 +1,7 @@
 # Copyright © 2023 Apple Inc.
 
+from collections import defaultdict
+
 
 def tree_map(fn, tree, *rest, is_leaf=None):
     """Applies ``fn`` to the leaves of the python tree ``tree`` and
@@ -128,12 +130,10 @@ def tree_unflatten(tree):
         is_list = False
 
     # collect children
-    children = {}
+    children = defaultdict(list)
     for key, value in tree:
         current_idx, *next_idx = key.split(".", maxsplit=1)
         next_idx = "" if not next_idx else next_idx[0]
-        if current_idx not in children:
-            children[current_idx] = []
         children[current_idx].append((next_idx, value))
 
     # recursively map them to the original container

From 2aedf3e791be97201ff7580ed49ef82783ecfb4e Mon Sep 17 00:00:00 2001
From: Chunyang Wen
Date: Fri, 29 Dec 2023 12:55:10 +0800
Subject: [PATCH 3/3] Minor refactor for tree_map and tree_unflatten (#311)

* Minor refactor for tree_map and tree_unflatten

* Remove the if statement

---------

Co-authored-by: Chunyang Wen
---
 python/mlx/utils.py | 14 +++++---------
 1 file changed, 5 insertions(+), 9 deletions(-)

diff --git a/python/mlx/utils.py b/python/mlx/utils.py
index daa387420..137a8aae4 100644
--- a/python/mlx/utils.py
+++ b/python/mlx/utils.py
@@ -39,13 +39,9 @@ def tree_map(fn, tree, *rest, is_leaf=None):
     """
     if is_leaf is not None and is_leaf(tree):
         return fn(tree, *rest)
-    elif isinstance(tree, list):
-        return [
-            tree_map(fn, child, *(r[i] for r in rest), is_leaf=is_leaf)
-            for i, child in enumerate(tree)
-        ]
-    elif isinstance(tree, tuple):
-        return tuple(
+    elif isinstance(tree, (list, tuple)):
+        TreeType = type(tree)
+        return TreeType(
             tree_map(fn, child, *(r[i] for r in rest), is_leaf=is_leaf)
             for i, child in enumerate(tree)
         )
@@ -141,8 +137,8 @@ def tree_unflatten(tree):
         keys = sorted((int(idx), idx) for idx in children.keys())
         l = []
         for i, k in keys:
-            while i > len(l):
-                l.append({})
+            # If i <= len(l), the range is empty and no placeholder dicts are appended.
+            l.extend([{} for _ in range(i - len(l))])
             l.append(tree_unflatten(children[k]))
         return l
     else:
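
A usage note on PATCH 1/3: moving the running stats into the parameter
dictionary means they now serialize with the rest of the model, while the
freeze/unfreeze handling keeps them out of the optimizer's reach. The minimal
sketch below simply restates the new test assertions as user-facing behavior;
it uses only the mlx.nn API shown in the diff.

    import mlx.nn as nn

    bn = nn.BatchNorm(num_features=4)

    # running_mean/running_var are ordinary parameters now, so they show up
    # in parameters() and travel with the rest of the model when saving.
    assert "running_mean" in bn.parameters()
    assert "running_var" in bn.parameters()

    # ...but they are frozen, so trainable_parameters() -- what optimizers
    # consume -- never includes them.
    assert "running_mean" not in bn.trainable_parameters()

    # The unfreeze() override re-freezes them, so even a blanket unfreeze
    # cannot accidentally make the running stats trainable.
    bn.unfreeze()
    assert "running_mean" not in bn.trainable_parameters()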
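The running-stat update that moved into __call__ is a standard exponential
moving average, running <- (1 - momentum) * running + momentum * batch_stat,
exactly as test_batch_norm_stats recomputes it in NumPy. A tiny numeric
illustration (plain Python, values chosen by me for the example):

    # Running mean after three batches whose mean is 1.0, with momentum 0.1:
    running, momentum = 0.0, 0.1
    for batch_mean in (1.0, 1.0, 1.0):
        running = (1 - momentum) * running + momentum * batch_mean
    print(round(running, 3))  # 0.271, drifting toward the batch mean of 1.0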
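For PATCH 2/3 and the tree_unflatten half of PATCH 3/3 the observable behavior
is unchanged: the defaultdict removes the membership check when grouping
children, and l.extend replaces the while loop that padded list gaps with
empty dicts. A sketch of the behavior these helpers implement (the example
inputs are mine, not from the patches):

    from mlx.utils import tree_unflatten

    # Dot-separated keys rebuild nested containers; digit components
    # become list indices.
    print(tree_unflatten([("layers.0.w", 1), ("layers.1.w", 2)]))
    # {'layers': [{'w': 1}, {'w': 2}]}

    # A gap in the indices is padded with empty dicts -- this is the
    # range(i - len(l)) placeholder logic from PATCH 3/3 at work.
    print(tree_unflatten([("layers.2.w", 3)]))
    # {'layers': [{}, {}, {'w': 3}]}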
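The tree_map half of PATCH 3/3 folds the list and tuple branches into one by
calling type(tree) on a generator, which preserves the concrete container
type. A small sketch (example values are mine); note that this construction
assumes type(tree) accepts a single iterable, which holds for list and tuple
but would not hold for, say, a namedtuple subclass, whose constructor takes
the fields as separate positional arguments.

    from mlx.utils import tree_map

    print(tree_map(lambda x: x * 2, [1, (2, 3)]))  # [2, (4, 6)] -- stays a list
    print(tree_map(lambda x: x * 2, (1, [2, 3])))  # (2, [4, 6]) -- stays a tuple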