From 019a85511cf3394d9013a67de83bc26cfc0c1b29 Mon Sep 17 00:00:00 2001 From: m0saan Date: Fri, 22 Dec 2023 20:58:54 +0100 Subject: [PATCH] improve batch norm code ^^ --- python/mlx/nn/layers/normalization.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/mlx/nn/layers/normalization.py b/python/mlx/nn/layers/normalization.py index 17e6ab669..cee59d6b3 100644 --- a/python/mlx/nn/layers/normalization.py +++ b/python/mlx/nn/layers/normalization.py @@ -227,7 +227,7 @@ class BatchNorm(Module): self.momentum = momentum self.affine = affine self.track_running_stats = track_running_stats - self.dims_expanded = False + self._dims_expanded = False if self.affine: self.weight = mx.ones((num_features,)) @@ -268,7 +268,7 @@ class BatchNorm(Module): self.running_mean = mx.expand_dims(self.running_mean, self.reduction_axes) self.running_var = mx.expand_dims(self.running_var, self.reduction_axes) - self.dims_expanded = True + self._dims_expanded = True def _calc_stats(self, x: mx.array) -> Tuple[mx.array, mx.array]: """ @@ -304,7 +304,7 @@ class BatchNorm(Module): mx.array: Output tensor. """ - if not self.dims_expanded: + if not self._dims_expanded: self._check_and_expand_dims(x) if self.training or not self.track_running_stats: