diff --git a/python/mlx/nn/layers/normalization.py b/python/mlx/nn/layers/normalization.py index 17e6ab669..cee59d6b3 100644 --- a/python/mlx/nn/layers/normalization.py +++ b/python/mlx/nn/layers/normalization.py @@ -227,7 +227,7 @@ class BatchNorm(Module): self.momentum = momentum self.affine = affine self.track_running_stats = track_running_stats - self.dims_expanded = False + self._dims_expanded = False if self.affine: self.weight = mx.ones((num_features,)) @@ -268,7 +268,7 @@ class BatchNorm(Module): self.running_mean = mx.expand_dims(self.running_mean, self.reduction_axes) self.running_var = mx.expand_dims(self.running_var, self.reduction_axes) - self.dims_expanded = True + self._dims_expanded = True def _calc_stats(self, x: mx.array) -> Tuple[mx.array, mx.array]: """ @@ -304,7 +304,7 @@ class BatchNorm(Module): mx.array: Output tensor. """ - if not self.dims_expanded: + if not self._dims_expanded: self._check_and_expand_dims(x) if self.training or not self.track_running_stats: