diff --git a/python/mlx/nn/layers/normalization.py b/python/mlx/nn/layers/normalization.py index cee59d6b3..1cff5af56 100644 --- a/python/mlx/nn/layers/normalization.py +++ b/python/mlx/nn/layers/normalization.py @@ -234,8 +234,8 @@ class BatchNorm(Module): self.bias = mx.zeros((num_features,)) if self.track_running_stats: - self.running_mean = mx.zeros((num_features,)) - self.running_var = mx.ones((num_features,)) + self._running_mean = mx.zeros((num_features,)) + self._running_var = mx.ones((num_features,)) def _extra_repr(self): return f"{self.num_features}, eps={self.eps}, momentum={self.momentum}, affine={'weight' in self}, track_running_stats={self.track_running_stats}"