diff --git a/python/mlx/nn/layers/normalization.py b/python/mlx/nn/layers/normalization.py
index 8fb5cf15a..8ab41de13 100644
--- a/python/mlx/nn/layers/normalization.py
+++ b/python/mlx/nn/layers/normalization.py
@@ -181,7 +181,7 @@ class GroupNorm(Module):
         x = group_norm(x)
         return (self.weight * x + self.bias) if "weight" in self else x
 
-
+
 
 class BatchNorm(Module):
     r"""Applies Batch Normalization over a 2D or 3D input.
@@ -211,7 +211,7 @@ class BatchNorm(Module):
         >>> bn = nn.BatchNorm1d(num_features=4, affine=True)
         >>> output = bn(x)
     """
-
+
     def __init__(
         self,
         num_features: int,
@@ -239,7 +239,7 @@ class BatchNorm(Module):
 
     def _extra_repr(self):
         return f"{self.num_features}, eps={self.eps}, momentum={self.momentum}, affine={'weight' in self}, track_running_stats={self.track_running_stats}"
-
+
     def _check_and_expand_dims(self, x: mx.array):
         """
         Check if the input is a 2D or 3D tensor and expand the weight, bias, running mean, and running variance accordingly.
@@ -247,7 +247,7 @@ class BatchNorm(Module):
         Args:
             x (mx.array): Input tensor.
         """
-
+
         num_dims = len(x.shape)
         dims_dict = {
             2: ((1, self.num_features), (0,)),
@@ -259,17 +259,16 @@ class BatchNorm(Module):
             raise ValueError(f"expected num_dims to be 2, 3, or 4 (got {num_dims})")
 
         shape, self.reduction_axes = dims_dict[num_dims]
-
+
         if self.affine:
             self.weight = mx.expand_dims(self.weight, self.reduction_axes)
             self.bias = mx.expand_dims(self.bias, self.reduction_axes)
-
+
         if self.track_running_stats:
             self.running_mean = mx.expand_dims(self.running_mean, self.reduction_axes)
             self.running_var = mx.expand_dims(self.running_var, self.reduction_axes)
-
-            self.dims_expanded = True
+        self.dims_expanded = True
 
     def _calc_stats(self, x: mx.array) -> Tuple[mx.array, mx.array]:
         """
@@ -304,7 +303,7 @@ class BatchNorm(Module):
         Returns:
             mx.array: Output tensor.
         """
-
+
         if not self.dims_expanded:
             self._check_and_expand_dims(x)
 
@@ -313,4 +312,4 @@ class BatchNorm(Module):
         else:
             means, var = self.running_mean, self.running_var
         x = (x - means) * mx.rsqrt(var + self.eps)
-        return (self.weight * x + self.bias) if "weight" in self else x
\ No newline at end of file
+        return (self.weight * x + self.bias) if "weight" in self else x
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index 64fc4abc8..803a7cb72 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -3,6 +3,7 @@
 import os
 import tempfile
 import unittest
+from unittest.mock import Mock, patch
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -410,6 +411,60 @@ class TestNN(mlx_tests.MLXTestCase):
         self.assertTrue(np.allclose(bn.running_mean, expected_mean, atol=1e-5))
         self.assertTrue(np.allclose(bn.running_var, expected_var, atol=1e-5))
 
+    def test_batch_norm_stats(self):
+        batch_size = 4
+        num_features = 32
+        num_channels = 32
+        h = 28
+        w = 28
+        num_iterations = 100
+        momentum = 0.1
+
+        batch_norm = nn.BatchNorm(num_features)
+
+        batch_norm.train()
+        running_mean = np.array(batch_norm.running_mean.tolist())
+        running_var = np.array(batch_norm.running_var.tolist())
+
+        data = mx.random.normal((batch_size * num_features,)).reshape(
+            (batch_size, num_features)
+        )
+
+        for _ in range(num_iterations):
+            normalized_data = batch_norm(data)
+            means = np.mean(data.tolist(), axis=0)
+            variances = np.var(data.tolist(), axis=0)
+            running_mean = (1 - momentum) * running_mean + momentum * means
+            running_var = (1 - momentum) * running_var + momentum * variances
+            assert np.allclose(batch_norm.running_mean, running_mean, atol=1e-5)
+            assert np.allclose(batch_norm.running_var, running_var, atol=1e-5)
+            data = normalized_data
+
+        batch_norm = nn.BatchNorm(num_channels)
+
+        batch_norm.train()
+        running_mean = np.array(batch_norm.running_mean.tolist()).reshape(
+            1, num_channels, 1, 1
+        )
+        running_var = np.array(batch_norm.running_var.tolist()).reshape(
+            1, num_channels, 1, 1
+        )
+        data = mx.random.normal((batch_size, num_channels, h, w))
+
+        for _ in range(num_iterations):
+            normalized_data = batch_norm(data)
+            means = np.mean(data.tolist(), axis=(0, 2, 3)).reshape(
+                1, num_channels, 1, 1
+            )
+            variances = np.var(data.tolist(), axis=(0, 2, 3)).reshape(
+                1, num_channels, 1, 1
+            )
+            running_mean = (1 - momentum) * running_mean + momentum * means
+            running_var = (1 - momentum) * running_var + momentum * variances
+            assert np.allclose(batch_norm.running_mean, running_mean, atol=1e-5)
+            assert np.allclose(batch_norm.running_var, running_var, atol=1e-5)
+            data = normalized_data
+
     def test_conv1d(self):
         N = 5
         L = 12