From c68a472b833743194c155235cb6edce356c2e2ba Mon Sep 17 00:00:00 2001 From: __mo_san__ <50895527+m0saan@users.noreply.github.com> Date: Sat, 23 Dec 2023 15:20:08 +0100 Subject: [PATCH] Update python/mlx/nn/layers/normalization.py Co-authored-by: Robert McCraith --- python/mlx/nn/layers/normalization.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/mlx/nn/layers/normalization.py b/python/mlx/nn/layers/normalization.py index cee59d6b3..1cff5af56 100644 --- a/python/mlx/nn/layers/normalization.py +++ b/python/mlx/nn/layers/normalization.py @@ -234,8 +234,8 @@ class BatchNorm(Module): self.bias = mx.zeros((num_features,)) if self.track_running_stats: - self.running_mean = mx.zeros((num_features,)) - self.running_var = mx.ones((num_features,)) + self._running_mean = mx.zeros((num_features,)) + self._running_var = mx.ones((num_features,)) def _extra_repr(self): return f"{self.num_features}, eps={self.eps}, momentum={self.momentum}, affine={'weight' in self}, track_running_stats={self.track_running_stats}"