diff --git a/python/mlx/nn/layers/normalization.py b/python/mlx/nn/layers/normalization.py
index cf50afa0e..8fb5cf15a 100644
--- a/python/mlx/nn/layers/normalization.py
+++ b/python/mlx/nn/layers/normalization.py
@@ -181,8 +181,8 @@ class GroupNorm(Module):
         x = group_norm(x)
         return (self.weight * x + self.bias) if "weight" in self else x
 
-
-class BatchNorm1d(Module):
+
+class BatchNorm(Module):
     r"""Applies Batch Normalization over a 2D or 3D input.
 
     Computes
@@ -211,7 +211,7 @@ class BatchNorm1d(Module):
    >>> bn = nn.BatchNorm1d(num_features=4, affine=True)
    >>> output = bn(x)
    """
-
+
    def __init__(
        self,
        num_features: int,
@@ -221,11 +221,13 @@ class BatchNorm1d(Module):
        track_running_stats: bool = True,
    ):
        super().__init__()
+        self.num_features = num_features
 
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
+        self.dims_expanded = False
 
        if self.affine:
            self.weight = mx.ones((num_features,))
@@ -237,6 +239,37 @@ class BatchNorm1d(Module):
 
    def _extra_repr(self):
        return f"{self.num_features}, eps={self.eps}, momentum={self.momentum}, affine={'weight' in self}, track_running_stats={self.track_running_stats}"
+
+    def _check_and_expand_dims(self, x: mx.array):
+        """
+        Check if the input is a 2D, 3D, or 4D tensor and expand the weight, bias, running mean, and running variance accordingly.
+
+        Args:
+            x (mx.array): Input tensor.
+        """
+
+        num_dims = len(x.shape)
+        dims_dict = {
+            2: ((1, self.num_features), (0,)),
+            3: ((1, self.num_features, 1), (0, 2)),
+            4: ((1, self.num_features, 1, 1), (0, 2, 3)),
+        }
+
+        if num_dims not in dims_dict:
+            raise ValueError(f"expected num_dims to be 2, 3, or 4 (got {num_dims})")
+
+        shape, self.reduction_axes = dims_dict[num_dims]
+
+        if self.affine:
+            self.weight = mx.expand_dims(self.weight, self.reduction_axes)
+            self.bias = mx.expand_dims(self.bias, self.reduction_axes)
+
+        if self.track_running_stats:
+            self.running_mean = mx.expand_dims(self.running_mean, self.reduction_axes)
+            self.running_var = mx.expand_dims(self.running_var, self.reduction_axes)
+
+        self.dims_expanded = True
+
 
    def _calc_stats(self, x: mx.array) -> Tuple[mx.array, mx.array]:
        """
@@ -249,12 +282,8 @@ class BatchNorm1d(Module):
            tuple: Tuple containing mean and variance.
        """
 
-        if len(x.shape) == 2:
-            means = mx.mean(x, axis=0, keepdims=True)
-            var = mx.var(x, axis=0, keepdims=True)
-        else:
-            means = mx.mean(x, axis=(0, 2), keepdims=True)
-            var = mx.var(x, axis=(0, 2), keepdims=True)
+        means = mx.mean(x, axis=self.reduction_axes, keepdims=True)
+        var = mx.var(x, axis=self.reduction_axes, keepdims=True)
 
        if self.track_running_stats and self.training:
            self.running_mean = (
@@ -265,29 +294,9 @@ class BatchNorm1d(Module):
            ) * self.running_var + self.momentum * var
        return means, var
 
-    def _check_and_expand_dims(self, x: mx.array):
+    def __call__(self, x: mx.array) -> mx.array:
        """
-        Check if the input is a 2D or 3D tensor and expand the weight, bias, running mean, and running variance accordingly.
-
-        Args:
-            x (mx.array): Input tensor.
- """ - - if x.ndim != 2 and x.ndim != 3: - raise ValueError(f"expected 2D or 3D input (got {x.ndim}D input)") - - if x.ndim == 3 and self.weight.ndim != x.ndim: - self.weight = mx.expand_dims(self.weight, [0, 2]) - self.bias = mx.expand_dims(self.bias, [0, 2]) - - if self.track_running_stats: - if x.ndim == 3 and self.running_mean.ndim != x.ndim: - self.running_mean = mx.expand_dims(self.running_mean, [0, 2]) - self.running_var = mx.expand_dims(self.running_var, [0, 2]) - - def __call__(self, x: mx.array): - """ - Forward pass of BatchNorm1d. + Forward pass of BatchNorm. Args: x (mx.array): Input tensor. @@ -295,12 +304,13 @@ class BatchNorm1d(Module): Returns: mx.array: Output tensor. """ - - self._check_and_expand_dims(x) + + if not self.dims_expanded: + self._check_and_expand_dims(x) if self.training or not self.track_running_stats: means, var = self._calc_stats(x) else: means, var = self.running_mean, self.running_var x = (x - means) * mx.rsqrt(var + self.eps) - return (self.weight * x + self.bias) if "weight" in self else x + return (self.weight * x + self.bias) if "weight" in self else x \ No newline at end of file diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 82850ca72..64fc4abc8 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -325,7 +325,7 @@ class TestNN(mlx_tests.MLXTestCase): x = mx.random.normal((5, 4), dtype=mx.float32) # Batch norm - bn = nn.BatchNorm1d(num_features=4, affine=True) + bn = nn.BatchNorm(num_features=4, affine=True) self.assertTrue(mx.allclose(bn.running_mean, mx.zeros_like(bn.running_mean))) self.assertTrue(mx.allclose(bn.running_var, mx.ones_like(bn.running_var))) y = bn(x) @@ -362,7 +362,7 @@ class TestNN(mlx_tests.MLXTestCase): self.assertTrue(np.allclose(y, expected_y, atol=1e-5)) # test_no_affine - bn = nn.BatchNorm1d(num_features=4, affine=False) + bn = nn.BatchNorm(num_features=4, affine=False) y = bn(x) expected_y = mx.array( [ @@ -381,7 +381,7 @@ class TestNN(mlx_tests.MLXTestCase): x = mx.random.normal((2, 4, 3), dtype=mx.float32) # Batch norm - bn = nn.BatchNorm1d(num_features=4, affine=True) + bn = nn.BatchNorm(num_features=4, affine=True) self.assertTrue(mx.allclose(bn.running_mean, mx.zeros_like(bn.running_mean))) self.assertTrue(mx.allclose(bn.running_var, mx.ones_like(bn.running_var))) y = bn(x)