improve batch norm code ^^

m0saan 2023-12-22 20:58:54 +01:00
parent b444a6a693
commit 019a85511c


@@ -227,7 +227,7 @@ class BatchNorm(Module):
         self.momentum = momentum
         self.affine = affine
         self.track_running_stats = track_running_stats
-        self.dims_expanded = False
+        self._dims_expanded = False
         if self.affine:
             self.weight = mx.ones((num_features,))
@@ -268,7 +268,7 @@ class BatchNorm(Module):
         self.running_mean = mx.expand_dims(self.running_mean, self.reduction_axes)
         self.running_var = mx.expand_dims(self.running_var, self.reduction_axes)
-        self.dims_expanded = True
+        self._dims_expanded = True
 
     def _calc_stats(self, x: mx.array) -> Tuple[mx.array, mx.array]:
         """
@@ -304,7 +304,7 @@ class BatchNorm(Module):
             mx.array: Output tensor.
         """
-        if not self.dims_expanded:
+        if not self._dims_expanded:
            self._check_and_expand_dims(x)
        if self.training or not self.track_running_stats:
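
For context, a minimal, self-contained sketch of how the now-private _dims_expanded flag could gate the one-time broadcasting of the running statistics. Only the names visible in the diff are taken from the commit; the class layout, axis handling, and constructor arguments are assumptions, not the repository's actual implementation.

import mlx.core as mx


class BatchNorm:
    """Sketch only: eval-mode normalization using running statistics."""

    def __init__(self, num_features: int, eps: float = 1e-5):
        self.eps = eps
        self.running_mean = mx.zeros((num_features,))
        self.running_var = mx.ones((num_features,))
        # Leading underscore marks this as internal bookkeeping state.
        self._dims_expanded = False

    def _check_and_expand_dims(self, x: mx.array) -> None:
        # Expand the (num_features,) running stats once so they broadcast
        # against x, e.g. (C,) -> (1, C, 1, 1) for NCHW input (assumed layout).
        for axis in [0] + list(range(2, x.ndim)):
            self.running_mean = mx.expand_dims(self.running_mean, axis)
            self.running_var = mx.expand_dims(self.running_var, axis)
        self._dims_expanded = True

    def __call__(self, x: mx.array) -> mx.array:
        # Expand dims lazily on the first call, exactly once.
        if not self._dims_expanded:
            self._check_and_expand_dims(x)
        return (x - self.running_mean) / mx.sqrt(self.running_var + self.eps)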